├── .gitignore ├── READING_LIST.md ├── README.md ├── proposal.pdf ├── requirements.txt └── src ├── README.md ├── config ├── 5shot.yaml ├── __init__.py ├── lstm_baseline.yaml ├── lstm_baseline_test_seed.yaml ├── lyrics.yaml ├── midi.yaml └── unigram.yaml ├── data ├── __init__.py ├── base_loader.py ├── dataset.py ├── episode.py ├── lyrics_loader.py └── midi_loader.py ├── evaluation └── __init__.py ├── models ├── __init__.py ├── base_model.py ├── lstm_baseline.py ├── tf_model.py └── unigram_model.py └── train ├── __init__.py ├── test_seed.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .*.swp 3 | -------------------------------------------------------------------------------- /READING_LIST.md: -------------------------------------------------------------------------------- 1 | # Reading List 2 | 3 | ## On few-shot learning based on learning an initialization and fine-tuning procedure 4 | 5 | * [Optimization as a model for few-shot learning](https://openreview.net/forum?id=rJY0-Kcll) 6 | * [Model-agnostic meta-learning for fast adaptation of deep networks](https://arxiv.org/abs/1703.03400) 7 | * [One-Shot Visual Imitation Learning via Meta-Learning](https://arxiv.org/abs/1709.04905) 8 | 9 | ## On few-shot learning based on a generic neural network architecture 10 | 11 | * [One-Shot Learning with Memory-Augmented Neural Networks](https://arxiv.org/abs/1605.06065) 12 | * [A Simple Neural Attentive Meta-Learner](https://openreview.net/forum?id=B1DmUzWAW) 13 | * [One-Shot Imitation Learning](https://arxiv.org/abs/1703.07326) 14 | 15 | ## On adaptive language models 16 | 17 | * [Improving neural language models with a continuous cache](https://arxiv.org/abs/1612.04426) 18 | * [Dynamic evaluation of neural sequence models](https://arxiv.org/abs/1709.07432) 19 | 20 | ## On conditional language models using attention 21 | 22 | * [Attention is all you need](https://arxiv.org/abs/1706.03762) 23 | 24 | ## On meta-learning for distribution learning: 25 | 26 | * [One-Shot Generalization in Deep Generative Models](https://arxiv.org/abs/1603.05106) 27 | * [Few-shot Autoregressive Density Estimation: Towards Learning to Learn Distributions](https://arxiv.org/abs/1710.10304) 28 | 29 | ## On models for music generation 30 | 31 | * [Modeling Temporal Dependencies in High-Dimensional Sequences: Application to Polyphonic Music Generation and Transcription](http://www-etud.iro.umontreal.ca/~boulanni/ICML2012.pdf) 32 | * [Sequence Tutor: Conservative Fine-Tuning of Sequence Generation Models with KL-control](https://arxiv.org/abs/1611.02796) 33 | * [Counterpoint by Convolution](https://openreview.net/forum?id=r1Usiwcex) 34 | * [Hierarchical Variational Autoencoders for Music](https://nips2017creativity.github.io/doc/Hierarchical_Variational_Autoencoders_for_Music.pdf) 35 | 36 | ## On MIDI files 37 | 38 | * [ISMIR 2017 Tutorial: Leveraging MIDI Files for Music Information Retrieval](https://youtu.be/iZt7tpYR6MI) 39 | [[Slides](http://colinraffel.com/talks/ismir2017leveraging.pdf)] 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Few-Shot Music Generation 2 | 3 | ## Few-Shot Distribution Learning for Music Generation 4 | 5 | * Tagline: Learning a generative model for music data using a small amount of examples. 
6 | * Date: December 2017 7 | * Category: Fundamental Research 8 | * Author(s): [Hugo Larochelle](https://github.com/larocheh), [Chelsea Finn](https://github.com/cbfinn), [Sachin Ravi](https://github.com/sachinravi14) 9 | 10 | ## Project Status 11 | 12 | * ~~Brainstorming for datasets phase: Currently collecting ideas for dataset collection for lyrics and MIDI data. See the Issues page for details.~~ 13 | * ~~Collecting actual data for lyrics and MIDI.~~ 14 | * ~~Decide and implement pre-processing scheme for data (specifically for MIDI).~~ 15 | * ~~Release training script and model API code.~~ 16 | * Experiment with new models on both datasets. 17 | 18 | ## Community Links 19 | 20 | * [Project Slack](https://few-shot-music-gen.slack.com/join/shared_invite/enQtMjgwMTA0NTA3MzQ3LTA3MTc3M2E4MjEyNDlhZDNlMTU2ZmUyMmNmMDlhYmQ2ZmFkMDRiZTAzZDJmYmYwYmE0NjRmZGMyMmYxOWEzMWU) 21 | * [Project Mailing List](https://groups.google.com/forum/#!forum/few-shot-music-generation) 22 | 23 | ## Problem description: 24 | 25 | See Introduction section of the proposal. 26 | 27 | ## Why this problem matters: 28 | 29 | See Introduction section of the proposal. 30 | 31 | ## How to measure success: 32 | 33 | See Experiments section of the proposal. 34 | 35 | ## Datasets: 36 | 37 | See Datasets subsection of the proposal. 38 | 39 | ## Relevant Work: 40 | 41 | See Related Work section of the proposal. 42 | 43 | ## Contribute 44 | 45 | * Please begin by reading papers from the [Reading List](https://github.com/AI-ON/Few-Shot-Music-Generation/blob/master/READING_LIST.md) to familiarize yourself with work in this area. 46 | 47 | ### Data 48 | 49 | Both the lyrics and freemidi data can be downloaded [here](https://drive.google.com/drive/u/1/folders/1sI1K3CjzpN81QjjpaEDVKW79c7AOUdyQ). Place the `raw-data` directory in the home folder of the repository and make sure to unzip both `.zip` files in the data sub-directories. 50 | 51 | For example, for the lyrics data, make sure in the given path the following files and directories exist: 52 | ``` 53 | ls Few-Shot-Music-Generation/raw-data/lyrics/ 54 | >> lyrics_data test.csv train.csv val.csv 55 | ``` 56 | 57 | ### Training Models 58 | Sample run (check the different yaml files for different ways to run): 59 | ``` 60 | $ CONFIG=lyrics.yaml 61 | $ MODEL=lstm_baseline.yaml 62 | $ TASK=5shot.yaml 63 | python -um train.train --data=config/${CONFIG} --model=config/${MODEL} --task=config/${TASK} --checkpt_dir=/tmp/fewshot/lstm_baseline 64 | ``` 65 | 66 | To view the tensorboard (only works for `lstm_baseline.yaml` model): 67 | ``` 68 | $ tensorboard --logdir=/tmp/fewshot 69 | ``` 70 | 71 | If you have any trouble running the code, please create an issue describing your problem. 72 | 73 | 74 | ### Logging results 75 | Please log all your results in [this spreadsheet](https://docs.google.com/spreadsheets/d/18Wb2ct78WnHX2Z9TUgd1mHaJ1zDXznapU5MO8f-9ou0/edit#gid=0). 76 | -------------------------------------------------------------------------------- /proposal.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-ON/Few-Shot-Music-Generation/dcc3709db41113761614508c681a0196ea6011c7/proposal.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Library dependencies for the python code. You need to install these with 2 | # `pip install -r requirements.txt` before you can run this. 
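
# For models (the code in src/models imports TensorFlow and relies on TF 1.x tf.contrib APIs)
tensorflow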
3 | 
4 | # General
5 | pyyaml
6 | numpy
7 | 
8 | # For data loading
9 | ## Need to also run python -m nltk.downloader punkt
10 | nltk
11 | pretty_midi
12 | 
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
1 | # Few-Shot Music Generation
2 | TODO(all): Add more documentation.
3 | 
4 | ## Instructions
5 | Download the zip and csv files from
6 | https://drive.google.com/corp/drive/u/0/folders/1sI1K3CjzpN81QjjpaEDVKW79c7AOUdyQ
7 | and store them in `../raw-data/lyrics` and `../raw-data/freemidi`, respectively.
8 | 
9 | Unzip the zip files.
10 | 
11 | Sample run (check the different yaml files for different ways to run):
12 | ```
13 | $ CONFIG=lyrics.yaml
14 | $ MODEL=lstm_baseline.yaml
15 | $ TASK=5shot.yaml
16 | python -um train.train --data=config/${CONFIG} --model=config/${MODEL} --task=config/${TASK} --checkpt_dir=/tmp/fewshot/lstm_baseline
17 | ```
18 | 
19 | To view the tensorboard (only works for the `lstm_baseline.yaml` model):
20 | ```
21 | $ tensorboard --logdir=/tmp/fewshot
22 | ```
--------------------------------------------------------------------------------
/src/config/5shot.yaml:
--------------------------------------------------------------------------------
1 | query_size: 4
2 | support_size: 5
3 | seed: !!int 1234
--------------------------------------------------------------------------------
/src/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-ON/Few-Shot-Music-Generation/dcc3709db41113761614508c681a0196ea6011c7/src/config/__init__.py
--------------------------------------------------------------------------------
/src/config/lstm_baseline.yaml:
--------------------------------------------------------------------------------
1 | name: 'lstm_baseline'
2 | model_module_name: 'models.lstm_baseline'
3 | model_class_name: 'LSTMBaseline'
4 | 
5 | n_train: !!int 60000
6 | n_decay: !!int 10000
7 | print_every_n: !!int 1000
8 | val_every_n: !!float 5000
9 | n_val: !!int 600
10 | n_test: !!int 600
11 | n_samples: !!int 20
12 | 
13 | lr: !!float 5e-3
14 | max_grad_norm: !!int 5
15 | batch_size: !!int 5
16 | embedding_size: !!int 250
17 | n_layers: !!int 1
18 | hidden_size: !!int 200
--------------------------------------------------------------------------------
/src/config/lstm_baseline_test_seed.yaml:
--------------------------------------------------------------------------------
1 | name: 'lstm_baseline'
2 | model_module_name: 'models.lstm_baseline'
3 | model_class_name: 'LSTMBaseline'
4 | 
5 | n_train: !!int 60000
6 | n_decay: !!int 10000
7 | print_every_n: !!int 1000
8 | val_every_n: !!float 5000
9 | n_val: !!int 600
10 | n_test: !!int 600
11 | n_samples: !!int 20
12 | 
13 | lr: !!float 5e-3
14 | max_grad_norm: !!int 5
15 | batch_size: !!int 5
16 | embedding_size: !!int 250
17 | n_layers: !!int 1
18 | hidden_size: !!int 200
--------------------------------------------------------------------------------
/src/config/lyrics.yaml:
--------------------------------------------------------------------------------
1 | dataset: 'lyrics'
2 | dataset_path: '../raw-data/lyrics/lyrics_data/'
3 | splits: ['train', 'val', 'test']
4 | max_len: 50
--------------------------------------------------------------------------------
/src/config/midi.yaml:
--------------------------------------------------------------------------------
1 | dataset: 'midi'
2 | dataset_path:
'../raw-data/freemidi/freemidi_data/' 3 | splits: ['train', 'val', 'test'] 4 | max_len: 50 5 | -------------------------------------------------------------------------------- /src/config/unigram.yaml: -------------------------------------------------------------------------------- 1 | name: 'unigram_model' 2 | model_module_name: 'models.unigram_model' 3 | model_class_name: 'UnigramModel' 4 | 5 | n_train: !!int 60000 6 | print_every_n: !!int 1000 7 | val_every_n: !!float 5000 8 | n_val: !!int 600 9 | n_test: !!int 600 10 | n_samples: !!int 20 11 | 12 | batch_size: !!int 1 13 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['episode', 'dataset', 'loaders'] 2 | -------------------------------------------------------------------------------- /src/data/base_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | """A module for the parent loader class. 3 | """ 4 | import os 5 | import numpy as np 6 | import logging 7 | 8 | 9 | log = logging.getLogger("few-shot") 10 | 11 | 12 | class Loader(object): 13 | """A class for turning data into a sequence of tokens. 14 | """ 15 | def __init__(self, max_len, dtype=np.int32, persist=True): 16 | self.max_len = max_len 17 | self.dtype = dtype 18 | self.persist = persist 19 | 20 | def is_song(self, filepath): 21 | raise NotImplementedError 22 | 23 | def read(self, filepath): 24 | raise NotImplementedError 25 | 26 | def tokenize(self, data): 27 | raise NotImplementedError 28 | 29 | def detokenize(self, numpy_data): 30 | raise NotImplementedError 31 | 32 | def get_num_tokens(self): 33 | raise NotImplementedError 34 | 35 | def validate(self, filepath): 36 | try: 37 | self.load(filepath) 38 | return True 39 | except OSError: 40 | return False 41 | except KeyError: 42 | return False 43 | except EOFError: 44 | return False 45 | except IndexError: 46 | return False 47 | except ValueError: 48 | return False 49 | except IOError: 50 | return False 51 | 52 | def load(self, filepath): 53 | npfile = '%s.%s.npy' % (filepath, self.max_len) 54 | if self.persist and os.path.isfile(npfile): 55 | return np.load(npfile).astype(self.dtype) 56 | else: 57 | data = self.read(filepath) 58 | tokens = self.tokenize(data) 59 | numpy_tokens = np.zeros(self.max_len, dtype=self.dtype) 60 | for token_index in range(min(self.max_len, len(tokens))): 61 | numpy_tokens[token_index] = tokens[token_index] 62 | if self.persist: 63 | np.save(npfile, numpy_tokens) 64 | return numpy_tokens 65 | -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | """A dataset class lyrics and MIDI data for the few-shot-music-gen project 3 | """ 4 | import os 5 | import logging 6 | import time 7 | import multiprocessing 8 | import itertools 9 | try: 10 | from urllib import quote, unquote # python 2 11 | except ImportError: 12 | from urllib.parse import quote, unquote # python 3 13 | 14 | import numpy as np 15 | 16 | 17 | log = logging.getLogger('few-shot') 18 | logging.basicConfig(level=logging.INFO) 19 | 20 | 21 | class Metadata(object): 22 | """An object for tracking the metadata associated with a configuration of 23 | the sampler. 
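
    Metadata files live under `root/name/` and include the persisted
    train/val/test split CSVs, the `valid_songs.csv` validation cache, and
    (for lyrics) the `word_ids.csv` vocabulary file.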
24 | """ 25 | def __init__(self, root, name): 26 | self.dir = os.path.join(root, name) 27 | self.open_files = {} 28 | if not os.path.exists(self.dir): 29 | os.makedirs(self.dir) 30 | 31 | def exists(self, filename): 32 | return os.path.exists(os.path.join(self.dir, filename)) 33 | 34 | def lines(self, filename): 35 | if self.exists(os.path.join(self.dir, filename)): 36 | for line in open(os.path.join(self.dir, filename), 'r'): 37 | yield line 38 | 39 | def write(self, filename, line): 40 | if filename not in self.open_files: 41 | self.open_files[filename] = open(os.path.join(self.dir, filename), 'a') 42 | self.open_files[filename].write(line) 43 | 44 | def close(self): 45 | for filename in self.open_files: 46 | self.open_files[filename].close() 47 | 48 | 49 | class Dataset(object): 50 | """A class for train/val/test sets. 51 | 52 | NOTE/TODO: This object only creates disjoint subsets of the full dataset 53 | when persistence is enabled. It might be good to support disjoint subsets 54 | when persistence is disabled. 55 | 56 | This class is initialized with the following arguments: 57 | Arguments: 58 | root (str): the root directory of the dataset 59 | split ("train", "val", or "test"): the split of the dataset which 60 | this object represents. 61 | loader (Loader): the object used for reading and parsing 62 | metadata (Metadata): the object used for persisting metadata 63 | split_proportions (tuple of three numbers): the unnormalized 64 | (train, val, test) split. Specifying the split for all three subsets 65 | (despite the Dataset object only being initialized for a single 66 | subset) is necessary because the dataset will be split and persisted 67 | on the first instantiation of the Dataset object. 68 | persist (bool): persists the train/val/test split information in csv 69 | files in `root`, so future runs will use the same splits. If those 70 | csvs already exist, the sampler uses the splits from those files. 71 | cache (bool): if true, caches the loaded/parsed songs in memory. 72 | Otherwise it loads and parses songs on every episode. 73 | validate (bool): if true, validates every song at initialization. If 74 | the song doesn't pass validation, it is removed from the dataset. 75 | If persist is also set to true, the validation info will be 76 | persisted. 77 | min_songs (int): the minimum number of songs which an artist must have. 78 | If they don't have `min_songs` songs, they will not be present in 79 | the dataset. 80 | valid_songs_file (str): the name file which contains persists a list 81 | of valid songs. 82 | seed (int or None): the random seed which is used for shuffling the 83 | artists. 84 | """ 85 | def __init__(self, root, split, loader, metadata, split_proportions=(8,1,1), 86 | persist=True, cache=True, validate=True, min_songs=0, parallel=False, 87 | valid_songs_file='valid_songs.csv', seed=None): 88 | self.root = root 89 | self.cache = cache 90 | self.cache_data = {} 91 | self.loader = loader 92 | self.metadata = metadata 93 | self.artists = [] 94 | self.valid_songs_file = valid_songs_file 95 | valid_songs = {} 96 | artist_in_split = [] 97 | 98 | # If we're both validating and using persistence, load any validation 99 | # data from disk. The format of the validation file is just a CSV 100 | # with two entries: artist and song. The artist is the name of the 101 | # artist (i.e. the directory (e.g. 'K_s Choice')) and the song is 102 | # the song file name (e.g. 'ironflowers.mid'). 
103 | if validate and persist: 104 | for line in self.metadata.lines(valid_songs_file): 105 | artist, song = line.rstrip('\n').split(',', 1) 106 | artist = unquote(artist) 107 | song = unquote(song) 108 | if artist not in valid_songs: 109 | valid_songs[artist] = set() 110 | valid_songs[artist].add(song) 111 | 112 | if persist and self.metadata.exists('%s.csv' % split): 113 | artists_in_split = [] 114 | for line in self.metadata.lines('%s.csv' % split): 115 | artists_in_split.append(line.rstrip('\n')) 116 | else: 117 | dirs = [] 118 | all_artists = [] 119 | skipped_count = 0 120 | pool = multiprocessing.Pool(multiprocessing.cpu_count()) 121 | 122 | for artist in os.listdir(root): 123 | if os.path.isdir(os.path.join(root, artist)): 124 | songs = os.listdir(os.path.join(root, artist)) 125 | songs = [song for song in songs if loader.is_song(song)] 126 | if len(songs) > 0: 127 | dirs.append(artist) 128 | 129 | num_dirs = len(dirs) 130 | progress_logger = ProgressLogger(num_dirs) 131 | 132 | for artist_index, artist in enumerate(dirs): 133 | songs = os.listdir(os.path.join(root, artist)) 134 | # We only want .txt and .mid files. Filter all others. 135 | songs = [song for song in songs if loader.is_song(song)] 136 | # populate `valid_songs[artist]` 137 | if validate: 138 | progress_logger.maybe_log(artist_index) 139 | if artist not in valid_songs: 140 | valid_songs[artist] = set() 141 | songs_to_validate = [song for song in songs if song not in valid_songs[artist]] 142 | song_files = [os.path.join(root, artist, song) for song in songs_to_validate] 143 | if parallel: 144 | mapped = pool.map(loader.validate, song_files) 145 | else: 146 | mapped = map(loader.validate, song_files) 147 | validated = itertools.compress(songs_to_validate, mapped) 148 | for song in validated: 149 | song_file = os.path.join(root, artist, song) 150 | if persist: 151 | line = '%s,%s\n' % (quote(artist), quote(song)) 152 | self.metadata.write(self.valid_songs_file, line) 153 | valid_songs[artist].add(song) 154 | else: 155 | valid_songs[artist] = set(songs) 156 | 157 | if len(valid_songs[artist]) >= min_songs: 158 | all_artists.append(artist) 159 | else: 160 | skipped_count += 1 161 | pool.close() 162 | pool.join() 163 | if skipped_count > 0: 164 | log.info("%s artists don't have K+K'=%s songs. Using %s artists" % ( 165 | skipped_count, min_songs, len(all_artists))) 166 | train_count = int(float(split_proportions[0]) / sum(split_proportions) * len(all_artists)) 167 | val_count = int(float(split_proportions[1]) / sum(split_proportions) * len(all_artists)) 168 | # Use RandomState(seed) so that shuffles with the same set of 169 | # artists will result in the same shuffle on different computers. 170 | np.random.RandomState(seed).shuffle(all_artists) 171 | if persist: 172 | self.metadata.write('train.csv', '\n'.join(all_artists[:train_count])) 173 | self.metadata.write('val.csv', '\n'.join(all_artists[train_count:train_count+val_count])) 174 | self.metadata.write('test.csv', '\n'.join(all_artists[train_count+val_count:])) 175 | if split == 'train': 176 | artists_in_split = all_artists[:train_count] 177 | elif split == 'val': 178 | artists_in_split = all_artists[train_count:train_count+val_count] 179 | else: 180 | artists_in_split = all_artists[train_count+val_count:] 181 | 182 | self.metadata.close() 183 | 184 | for artist in artists_in_split: 185 | self.artists.append(ArtistDataset(artist, list(valid_songs[artist]))) 186 | 187 | def load(self, artist, song): 188 | """Read and parse `song` by `artist`. 
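        When `cache` is enabled, parsed songs are kept in memory so repeated
        loads of the same (artist, song) pair do not re-read the file.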
189 | 190 | Arguments: 191 | song (str): the name of the song file. e.g. `"lateralus.txt"` 192 | artist (str): the name of the artist directory. e.g. `"tool"` 193 | """ 194 | if self.cache and (artist, song) in self.cache_data: 195 | return self.cache_data[(artist, song)] 196 | else: 197 | data = self.loader.load(os.path.join(self.root, artist, song)) 198 | self.cache_data[(artist, song)] = data 199 | return data 200 | 201 | def __len__(self): 202 | return len(self.artists) 203 | 204 | def __getitem__(self, index): 205 | return self.artists[index] 206 | 207 | 208 | class ArtistDataset(object): 209 | def __init__(self, artist, songs): 210 | self.name = artist 211 | self.songs = songs 212 | 213 | def __len__(self): 214 | return len(self.songs) 215 | 216 | def __getitem__(self, index): 217 | return self.songs[index] 218 | 219 | 220 | class ProgressLogger(object): 221 | def __init__(self, num_dirs): 222 | self.last_log = 0 223 | self.last_log_percent = None 224 | self.num_dirs = num_dirs 225 | 226 | def maybe_log(self, index): 227 | # log progress at most every second 228 | if time.time() - self.last_log >= 1: 229 | if self.last_log_percent != '%.2f' % (100*index/self.num_dirs): 230 | self.last_log_percent = '%.2f' % (100*index/self.num_dirs) 231 | log.info("Preprocessing data. %s%%" % self.last_log_percent) 232 | self.last_log = time.time() 233 | -------------------------------------------------------------------------------- /src/data/episode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import os 3 | import time 4 | import logging 5 | import yaml 6 | 7 | import numpy as np 8 | from numpy.random import RandomState 9 | 10 | from data.midi_loader import MIDILoader 11 | from data.lyrics_loader import LyricsLoader 12 | from data.dataset import Dataset, Metadata 13 | 14 | 15 | class Episode(object): 16 | def __init__(self, support, query): 17 | self.support = support 18 | self.query = query 19 | 20 | 21 | class SQSampler(object): 22 | """A sampler for randomly sampling support/query sets. 
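
    Songs are drawn from a single artist without replacement; the first
    `query_size` sampled songs become the query set and the remaining
    `support_size` songs become the support set.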
23 | 
24 |     Arguments:
25 |         support_size (int): number of songs in the support set
26 |         query_size (int): number of songs in the query set
27 |         random (RandomState): random generator to use
28 |     """
29 |     def __init__(self, support_size, query_size, random):
30 |         self.support_size = support_size
31 |         self.query_size = query_size
32 |         self.random = random
33 | 
34 |     def sample(self, artist):
35 |         sample = self.random.choice(
36 |             artist,
37 |             size=self.support_size+self.query_size,
38 |             replace=False)
39 |         query = sample[:self.query_size]
40 |         support = sample[self.query_size:]
41 |         return query, support
42 | 
43 | 
44 | class EpisodeSampler(object):
45 |     def __init__(self, dataset, batch_size, support_size, query_size, max_len,
46 |                  dtype=np.int32, seed=None):
47 |         self.dataset = dataset
48 |         self.batch_size = batch_size
49 |         self.support_size = support_size
50 |         self.query_size = query_size
51 |         self.max_len = max_len
52 |         self.dtype = dtype
53 |         self.random = get_random(seed)
54 |         self.sq_sampler = SQSampler(support_size, query_size, self.random)
55 | 
56 |     def __len__(self):
57 |         return len(self.dataset)
58 | 
59 |     def __repr__(self):
60 |         return 'EpisodeSampler(support_size=%s, query_size=%s)' % (self.support_size, self.query_size)
61 | 
62 |     def get_episode(self):
63 |         support = np.zeros((self.batch_size, self.support_size, self.max_len), dtype=self.dtype)
64 |         query = np.zeros((self.batch_size, self.query_size, self.max_len), dtype=self.dtype)
65 |         artists = self.random.choice(self.dataset, size=self.batch_size, replace=False)
66 |         for batch_index, artist in enumerate(artists):
67 |             query_songs, support_songs = self.sq_sampler.sample(artist)
68 |             for support_index, song in enumerate(support_songs):
69 |                 parsed_song = self.dataset.load(artist.name, song)
70 |                 support[batch_index,support_index,:] = parsed_song
71 |             for query_index, song in enumerate(query_songs):
72 |                 parsed_song = self.dataset.load(artist.name, song)
73 |                 query[batch_index,query_index,:] = parsed_song
74 |         return Episode(support, query)
75 | 
76 |     def get_num_unique_words(self):
77 |         return self.dataset.loader.get_num_tokens()
78 | 
79 |     def detokenize(self, numpy_data):
80 |         return self.dataset.loader.detokenize(numpy_data)
81 | 
82 | def load_sampler_from_config(config):
83 |     """Create an EpisodeSampler from a yaml config."""
84 |     if isinstance(config, str):
85 |         config = yaml.load(open(config, 'r'))
86 |     elif isinstance(config, dict):
87 |         pass  # config is already a dict
88 |     else:
89 |         config = yaml.load(config)
90 |     required_keys = [
91 |         'dataset_path',
92 |         'query_size',
93 |         'support_size',
94 |         'batch_size',
95 |         'max_len',
96 |         'dataset',
97 |         'split'
98 |     ]
99 |     optional_keys = [
100 |         'train_proportion',
101 |         'val_proportion',
102 |         'test_proportion',
103 |         'persist',
104 |         'cache',
105 |         'seed',
106 |         'dataset_seed'
107 |     ]
108 |     for key in required_keys:
109 |         if key not in config:
110 |             raise RuntimeError('required config key "%s" not found' % key)
111 |     props = (
112 |         config.get('train_proportion', 8),
113 |         config.get('val_proportion', 1),
114 |         config.get('test_proportion', 1)
115 |     )
116 |     root = config['dataset_path']
117 |     if not os.path.isdir(root):
118 |         raise RuntimeError('required data directory %s does not exist' % root)
119 | 
120 |     metadata_dir = 'few_shot_metadata_%s_%s' % (config['dataset'], config['max_len'])
121 |     metadata = Metadata(root, metadata_dir)
122 |     if config['dataset'] == 'lyrics':
123 |         loader = LyricsLoader(config['max_len'], metadata=metadata)
124 |         parallel = False
125 |     elif config['dataset'] == 'midi':
126 |         loader =
MIDILoader(config['max_len']) 127 | parallel = False 128 | else: 129 | raise RuntimeError('unknown dataset "%s"' % config['dataset']) 130 | dataset = Dataset( 131 | root, 132 | config['split'], 133 | loader, 134 | metadata, 135 | split_proportions=props, 136 | cache=config.get('cache', True), 137 | persist=config.get('persist', True), 138 | validate=config.get('validate', True), 139 | min_songs=config['support_size']+config['query_size'], 140 | parallel=parallel, 141 | seed=config.get('dataset_seed', 0) 142 | ) 143 | return EpisodeSampler( 144 | dataset, 145 | config['batch_size'], 146 | config['support_size'], 147 | config['query_size'], 148 | config['max_len'], 149 | seed=config.get('seed', None)) 150 | 151 | 152 | def get_random(seed): 153 | if seed is not None: 154 | return RandomState(seed) 155 | else: 156 | return np.random 157 | -------------------------------------------------------------------------------- /src/data/lyrics_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | """A module for lyrics dataset loader. 3 | """ 4 | import logging 5 | import string 6 | import codecs 7 | 8 | import nltk 9 | import numpy as np 10 | 11 | from data.base_loader import Loader 12 | 13 | 14 | log = logging.getLogger("few-shot") 15 | 16 | 17 | class LyricsLoader(Loader): 18 | """Objects of this class parse lyrics files and persist word IDs. 19 | 20 | Arguments: 21 | max_len (int): maximum length of sequence of words 22 | metadata (Metadata): a Metadata object 23 | tokenizer (callable): a callable which takes a file name and returns a 24 | list of words. Defaults to nltk's `word_tokenize`, which requires 25 | the punkt tokenizer models. You can download the models with 26 | `nltk.download('punkt')` 27 | persist (bool): if true, the tokenizer will persist the IDs of each word 28 | to a file. If the file already exists, the tokenizer will bootstrap 29 | from the file. 30 | """ 31 | def __init__(self, max_len, metadata, tokenizer=nltk.word_tokenize, 32 | persist=True, dtype=np.int32): 33 | super(LyricsLoader, self).__init__(max_len, dtype=dtype) 34 | self.tokenizer = tokenizer 35 | self.metadata = metadata 36 | self.word_to_id = {} 37 | self.id_to_word = {} 38 | self.highest_word_id = -1 39 | # read persisted word ids 40 | if persist: 41 | log.info('Loading lyrics metadata...') 42 | for line in self.metadata.lines('word_ids.csv'): 43 | row = line.rstrip('\n').split(',', 1) 44 | word_id = int(row[0]) 45 | self.word_to_id[row[1]] = word_id 46 | self.id_to_word[word_id] = row[1] 47 | if word_id > self.highest_word_id: 48 | self.highest_word_id = word_id 49 | 50 | def is_song(self, filepath): 51 | return filepath.endswith('.txt') 52 | 53 | def read(self, filepath): 54 | """Read a file. 55 | 56 | Arguments: 57 | filepath (str): path to the lyrics file. e.g. 58 | "/home/user/lyrics_data/tool/lateralus.txt" 59 | """ 60 | return ''.join(codecs.open(filepath, 'r', errors='ignore').readlines()) 61 | 62 | def get_num_tokens(self): 63 | return self.highest_word_id + 1 64 | 65 | def tokenize(self, raw_lyrics): 66 | """Turns a string of lyrics data into a numpy array of int "word" IDs. 
67 | 
68 |         Arguments:
69 |             raw_lyrics (str): Stringified lyrics data
70 |         """
71 |         tokens = []
72 |         for token in self.tokenizer(raw_lyrics):
73 |             if token not in self.word_to_id:
74 |                 self.highest_word_id += 1
75 |                 self.word_to_id[token] = self.highest_word_id
76 |                 self.id_to_word[self.highest_word_id] = token
77 |                 if self.persist:
78 |                     self.metadata.write(
79 |                         'word_ids.csv',
80 |                         '%s,%s\n' % (self.highest_word_id, token)
81 |                     )
82 |             tokens.append(self.word_to_id[token])
83 |         return tokens
84 | 
85 |     def detokenize(self, numpy_data):
86 |         ret = ''
87 |         for token in numpy_data:
88 |             word = self.id_to_word[token]
89 |             if word == "n't":
90 |                 ret += word
91 |             elif word not in string.punctuation and not word.startswith("'"):
92 |                 ret += " " + word
93 |             else:
94 |                 ret += word
95 |         return "".join(ret).strip()
--------------------------------------------------------------------------------
/src/data/midi_loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | """A module for the MIDI dataset loader.
3 | """
4 | import logging
5 | import string
6 | import collections
7 | import math
8 | from operator import itemgetter
9 | 
10 | import pretty_midi
11 | 
12 | from data.base_loader import Loader
13 | 
14 | _SUSTAIN_ON = 0
15 | _SUSTAIN_OFF = 1
16 | _NOTE_ON = 2
17 | _NOTE_OFF = 3
18 | 
19 | NOTE_ON = 1
20 | NOTE_OFF = 2
21 | TIME_SHIFT = 3
22 | VELOCITY = 4
23 | 
24 | MAX_SHIFT_STEPS = 100
25 | 
26 | MIN_MIDI_VELOCITY = 1
27 | MAX_MIDI_VELOCITY = 127
28 | 
29 | MIN_MIDI_PROGRAM = 0
30 | MAX_MIDI_PROGRAM = 127
31 | PROGRAMS_PER_FAMILY = 8
32 | 
33 | 
34 | log = logging.getLogger("few-shot")
35 | 
36 | 
37 | class MIDILoader(Loader):
38 |     """Objects of this class parse MIDI files into a sequence of note IDs.
39 |     """
40 |     def read(self, filepath):
41 |         """Reads a MIDI file.
42 | 
43 |         Arguments:
44 |             filepath (str): path to the MIDI file. e.g.
45 |                 "/home/user/freemidi_data/Tool/lateralus.mid"
46 | 
47 |         """
48 |         return pretty_midi.PrettyMIDI(filepath)
49 | 
50 |     def is_song(self, filepath):
51 |         return filepath.endswith('.mid')
52 | 
53 |     def get_num_tokens(self):
54 |         """Get total number of possible MIDI tokens.
55 | 
56 |         These are: 128 on/off notes for each of 16 instruments,
57 |         32 velocity buckets for each of 16 instruments,
58 |         and 100 for different time-shifts.
59 |         """
60 |         return 16 * 128 * 2 + 32 * 16 + 100
61 | 
62 |     def tokenize(self, midi):
63 |         """Turns a parsed MIDI file into a list of event IDs.
64 | 
65 |         Arguments:
66 |             midi (pretty_midi.PrettyMIDI): the MIDI data returned by `read`, e.g. loaded from
67 | "/home/user/freemidi_data/Tool/lateralus.mid" 68 | """ 69 | tokens = [] 70 | midi_notes = get_notes(midi) 71 | midi_control_changes = get_control_changes(midi) 72 | midi_notes = apply_sustain_control_changes(midi_notes, midi_control_changes) 73 | midi_notes = quantize_notes(midi_notes) 74 | no_drum_notes = remove_drums(midi_notes) 75 | no_clash_notes = resolve_pitch_clashes(no_drum_notes) 76 | events = get_event_list(no_clash_notes) 77 | for event_type, event_value, family in events: 78 | if event_type == NOTE_ON: 79 | token = family * 128 + event_value 80 | elif event_type == NOTE_OFF: 81 | token = 16 * 128 + family * 128 + event_value 82 | elif event_type == VELOCITY: 83 | token = 16 * 128 * 2 + 32 * family + event_value 84 | elif event_type == TIME_SHIFT: 85 | # subtract one, because TIME_SHIFT event values are 1-indexed 86 | token = 16 * 128 * 2 + 32 * 16 + event_value - 1 87 | tokens.append(token) 88 | return tokens 89 | 90 | def detokenize(self, numpy_data): 91 | current_time = 0 92 | current_velocity = [16 for _ in range(16)] 93 | unsorted_notes = [[] for _ in range(16)] 94 | active_notes = [[None for _ in range(128)] for _ in range(16)] 95 | for token in numpy_data: 96 | if token < 16 * 128: 97 | instr_class = token // 128 98 | note_number = token % 128 99 | active_notes[instr_class][note_number] = (current_velocity[instr_class], current_time) 100 | elif token < 16 * 128 * 2: 101 | instr_class = (token-16*128) // 128 102 | pitch = (token-16*128) % 128 103 | if active_notes[instr_class][pitch] is not None: 104 | (velocity, start_time) = active_notes[instr_class][pitch] 105 | unsorted_notes[instr_class].append((start_time, current_time, pitch, velocity)) 106 | active_notes[instr_class][pitch] = None 107 | elif token < 16 * 128 * 2 + 32 * 16: 108 | instr_class = (token-16*128*2) // 32 109 | velocity = (token-16*128*2) % 32 110 | current_velocity[instr_class] = velocity 111 | else: 112 | current_time += (token-16*128*2-32*16+1) 113 | 114 | midi = pretty_midi.PrettyMIDI() 115 | for instr_class, instr_notes in enumerate(unsorted_notes): 116 | instr_notes.sort() 117 | if instr_notes != []: 118 | instr = pretty_midi.Instrument(program=(instr_class*8)) 119 | for (start_time, end_time, pitch, velocity) in instr_notes: 120 | note = pretty_midi.Note( 121 | start=0.01*start_time, 122 | end=0.01*end_time, 123 | pitch=pitch, 124 | velocity=velocity*4 125 | ) 126 | instr.notes.append(note) 127 | midi.instruments.append(instr) 128 | return midi 129 | 130 | 131 | def resolve_pitch_clashes(midi_notes): 132 | """This function resolve note conflicts resulting from merging instruments 133 | of the same class. 134 | 135 | MIDI specifies 16 instrument classes, with 8 instruments per class. For 136 | this project, we merge together all instruments for a class into a single 137 | instrument. This can create issues if you have multiple instruments of the 138 | same class in the same song (e.g. two electric guitars, or a viola and a 139 | violin). The conflict occurs when you have two instruments play the same 140 | note at the same time. For example, if you have guitar 1 begin note 55 at 141 | time-step 100, and then guitar 2 begin note 55 at time-step 101, and then 142 | guitar 1 end note 55 at time-step 102, it becomes unclear how to represent 143 | that as a single instrument. Does note 55 end at the guitar 1 note end 144 | event? Or does it wait until guitar 2 ends? Does it play the overlapping 145 | notes as a single note, or does it try to split it into two notes somehow? 
146 | 147 | This code solves that problem by allowing the first note to finish. If the 148 | duration of the second note extends beyond the duration of the first, the 149 | remaining duration will be played after the first note ends. 150 | 151 | Arguments: 152 | midi_notes ([(int, int, int, pretty_midi.Note)]): a tuple list of 153 | information on the MIDI notes of all instruments in a song. The 154 | first element is the quantized start time of the note. The second 155 | element is the quantized end time of the note. The third element is 156 | the instrument number (refer to the General MIDI spec for more 157 | info). 158 | """ 159 | num_program_families = int((MAX_MIDI_PROGRAM - MIN_MIDI_PROGRAM + 1) / \ 160 | PROGRAMS_PER_FAMILY) 161 | no_clash_notes = [] 162 | active_pitch_notes = {} 163 | 164 | sorted_notes = sorted(midi_notes, key=lambda element: element[0:3]) 165 | for program_family in range(num_program_families): 166 | active_pitch_notes[program_family + 1] = [] 167 | 168 | for quantized_start, quantized_end, program, midi_note in sorted_notes: 169 | program_family = (program - MIN_MIDI_PROGRAM) // PROGRAMS_PER_FAMILY + 1 170 | new_active_pitch_notes = [(pitch, end) for pitch, end 171 | in active_pitch_notes[program_family] if end > quantized_start] 172 | active_pitch_notes[program_family] = new_active_pitch_notes 173 | note_pitch = midi_note.pitch 174 | max_end = 0 175 | for pitch, end in active_pitch_notes[program_family]: 176 | if pitch == note_pitch and end > max_end: 177 | max_end = end 178 | if max_end >= quantized_end: 179 | continue 180 | quantized_start = max(quantized_start, max_end) 181 | active_pitch_notes[program_family].append((note_pitch, quantized_end)) 182 | no_clash_notes.append((quantized_start, quantized_end, program, midi_note)) 183 | 184 | return no_clash_notes 185 | 186 | def remove_drums(midi_notes): 187 | """Removes all drum notes from a sequence of MIDI notes. 188 | 189 | Argument: 190 | midi_notes ([(int, int, bool, int, int, pretty_midi.Note)]): a list 191 | containing tuples of info on each MIDI note. The first and second 192 | elements are discarded. The third element is a boolean representing 193 | if the note is a drum or not. The fourth and fifth are the start 194 | and end time respectively. The last is the note. 195 | """ 196 | return [(start, end, program, midi_note) 197 | for program, instrument, is_drum, start, end, midi_note 198 | in midi_notes if not is_drum] 199 | 200 | def get_event_list(midi_notes, num_velocity_bins=32): 201 | """Transforms a sequence of MIDI notes into a sequence of events. 202 | 203 | Arguments: 204 | midi_notes ([(int, int, int, pretty_midi.Note)]): A list containing 205 | info on each MIDI note. 206 | num_velocity_bins (int): the number of bins to split the velocity 207 | into. The MIDI standardizes on 128 possible values (0-127) but 208 | we bucket subranges together to reduce the dimensionality. Must 209 | evenly divide into 128. 
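
    Returns:
        A list of (event_type, event_value, program_family) tuples, where
        TIME_SHIFT events use family 0 and NOTE_ON/NOTE_OFF/VELOCITY events
        use the 1-indexed instrument family of the note's program.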
210 | """ 211 | note_on_set = [] 212 | note_off_set = [] 213 | for index, element in enumerate(midi_notes): 214 | quantized_start, quantized_end, program, midi_note = element 215 | note_on_set.append((quantized_start, index, program, False)) 216 | note_off_set.append((quantized_end, index, program, True)) 217 | note_events = sorted(note_on_set + note_off_set) 218 | 219 | velocity_bin_size = int(math.ceil( 220 | (MAX_MIDI_VELOCITY - MIN_MIDI_VELOCITY + 1) / num_velocity_bins)) 221 | num_program_families = int((MAX_MIDI_PROGRAM - MIN_MIDI_PROGRAM + 1) / \ 222 | PROGRAMS_PER_FAMILY) 223 | 224 | current_step = 0 225 | current_velocity_bin = {} 226 | for program_family in range(num_program_families): 227 | current_velocity_bin[program_family + 1] = 0 228 | events = [] 229 | 230 | for step, index, program, is_off in note_events: 231 | if step > current_step: 232 | while step > current_step + MAX_SHIFT_STEPS: 233 | events.append((TIME_SHIFT, MAX_SHIFT_STEPS, 0)) 234 | current_step += MAX_SHIFT_STEPS 235 | events.append((TIME_SHIFT, step-current_step, 0)) 236 | current_step = step 237 | 238 | note_velocity = midi_notes[index][3].velocity 239 | note_pitch = midi_notes[index][3].pitch 240 | velocity_bin = (note_velocity - MIN_MIDI_VELOCITY) // velocity_bin_size + 1 241 | program_family = (program - MIN_MIDI_PROGRAM) // PROGRAMS_PER_FAMILY + 1 242 | if not is_off and velocity_bin != current_velocity_bin[program_family]: 243 | current_velocity_bin[program_family] = velocity_bin 244 | # NOTE: velocity is set per-program-family, but that's not strictly 245 | # necessary. We could have set-velocity events set the velocity 246 | # for all instruments. This would change the required number of 247 | # set-velocity events, but it would reduce the dimensionality of 248 | # the encoding. 249 | events.append((VELOCITY, velocity_bin, program_family)) 250 | if not is_off: 251 | events.append((NOTE_ON, note_pitch, program_family)) 252 | if is_off: 253 | events.append((NOTE_OFF, note_pitch, program_family)) 254 | 255 | return events 256 | 257 | 258 | def quantize_notes(midi_notes, steps_per_second=100): 259 | """Quantize MIDI notes into integers. The unit represents a unit of time, 260 | determined by `steps_per_second`. 261 | 262 | midi_notes ([(int, int, bool, pretty_midi.Note)]): A list containing tuples 263 | of info describing individual MIDI notes in a song. The first element 264 | is the MIDI instrument number of the note. The second element is an 265 | identifier of the MIDI instrument, unique to all instruments within 266 | the song. The third element is a flag indicating whether the instrument 267 | is a drum or not. The last element is the note object. 268 | steps_per_second (int): The number of steps per second. Which each note 269 | gets rounded toward. 270 | """ 271 | new_midi_notes = [] 272 | 273 | for program, instrument, is_drum, midi_note in midi_notes: 274 | quantized_start = int(midi_note.start*steps_per_second + 0.5) 275 | quantized_end = int(midi_note.end*steps_per_second + 0.5) 276 | if quantized_start == quantized_end: 277 | quantized_end = quantized_end + 1 278 | new_midi_notes.append((program, instrument, is_drum, quantized_start, 279 | quantized_end, midi_note)) 280 | 281 | return new_midi_notes 282 | 283 | 284 | def apply_sustain_control_changes(midi_notes, midi_control_changes, 285 | sustain_control_number=64): 286 | """Applies sustain to the MIDI notes by modifying the notes in-place. 287 | 288 | Normally, MIDI note start/end times simply describe e.g. 
when a piano key 289 | is pressed. It's possible that the sound from the note continues beyond 290 | the pressing of the note if a sustain on the instrument is active. The 291 | activity of sustain on MIDI instruments is determined by certain control 292 | events. This function alters the start/end time of MIDI notes with respect 293 | to the sustain control messages to mimic sustain. 294 | 295 | Arguments: 296 | midi_notes ([(int, int, bool, pretty_midi.Note)]): A list of tuples of 297 | info on each MIDI note. 298 | midi_control_changes ([(int, int, bool, pretty_midi.ControlChange)]): 299 | A list of tuples on each control change event. 300 | """ 301 | events = [] 302 | events.extend([(midi_note.start, _NOTE_ON, instrument, midi_note) for 303 | _1, instrument, _2, midi_note in midi_notes]) 304 | events.extend([(midi_note.end, _NOTE_OFF, instrument, midi_note) for 305 | _1, instrument, _2, midi_note in midi_notes]) 306 | 307 | for _1, instrument, _2, control_change in midi_control_changes: 308 | if control_change.number != sustain_control_number: 309 | continue 310 | value = control_change.value 311 | # MIDI spec specifies that >= 64 means ON and < 64 means OFF. 312 | if value >= 64: 313 | events.append((control_change.time, _SUSTAIN_ON, instrument, 314 | control_change)) 315 | if value < 64: 316 | events.append((control_change.time, _SUSTAIN_OFF, instrument, 317 | control_change)) 318 | 319 | events.sort(key=itemgetter(0)) 320 | 321 | active_notes = collections.defaultdict(list) 322 | sus_active = collections.defaultdict(lambda: False) 323 | 324 | time = 0 325 | for time, event_type, instrument, event in events: 326 | if event_type == _SUSTAIN_ON: 327 | sus_active[instrument] = True 328 | elif event_type == _SUSTAIN_OFF: 329 | sus_active[instrument] = False 330 | new_active_notes = [] 331 | for note in active_notes[instrument]: 332 | if note.end < time: 333 | note.end = time 334 | else: 335 | new_active_notes.append(note) 336 | active_notes[instrument] = new_active_notes 337 | elif event_type == _NOTE_ON: 338 | if sus_active[instrument]: 339 | new_active_notes = [] 340 | for note in active_notes[instrument]: 341 | if note.pitch == event.pitch: 342 | note.end = time 343 | if note.start == note.end: 344 | try: 345 | midi_notes.remove(note) 346 | except ValueError: 347 | continue 348 | else: 349 | new_active_notes.append(note) 350 | active_notes[instrument] = new_active_notes 351 | active_notes[instrument].append(event) 352 | elif event_type == _NOTE_OFF: 353 | if sus_active[instrument]: 354 | pass 355 | else: 356 | if event in active_notes[instrument]: 357 | active_notes[instrument].remove(event) 358 | 359 | for instrument in active_notes.values(): 360 | for note in instrument: 361 | note.end = time 362 | 363 | return midi_notes 364 | 365 | 366 | def get_control_changes(midi): 367 | """Retrieves a list of control change events from a given MIDI song. 368 | 369 | Arguments: 370 | midi (PrettyMIDI): The MIDI song. 371 | """ 372 | midi_control_changes = [] 373 | for num_instrument, midi_instrument in enumerate(midi.instruments): 374 | for midi_control_change in midi_instrument.control_changes: 375 | midi_control_changes.append(( 376 | midi_instrument.program, 377 | num_instrument, 378 | midi_instrument.is_drum, 379 | midi_control_change 380 | )) 381 | return midi_control_changes 382 | 383 | 384 | def get_notes(midi): 385 | """Retrieves a list of MIDI notes (for all instruments) given a MIDI song. 386 | 387 | Arguments: 388 | midi (PrettyMIDI): The MIDI song. 
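
    Returns:
        A list of (program, instrument_index, is_drum, pretty_midi.Note)
        tuples covering the notes of every instrument in the song.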
389 | """ 390 | midi_notes = [] 391 | for num_instrument, midi_instrument in enumerate(midi.instruments): 392 | for midi_note in midi_instrument.notes: 393 | midi_notes.append(( 394 | midi_instrument.program, 395 | num_instrument, 396 | midi_instrument.is_drum, 397 | midi_note 398 | )) 399 | return midi_notes 400 | -------------------------------------------------------------------------------- /src/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-ON/Few-Shot-Music-Generation/dcc3709db41113761614508c681a0196ea6011c7/src/evaluation/__init__.py -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-ON/Few-Shot-Music-Generation/dcc3709db41113761614508c681a0196ea6011c7/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/base_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class BaseModel(object): 5 | 6 | def __init__(self, config): 7 | self._config = config 8 | 9 | @property 10 | def name(self): 11 | return self._config['name'] 12 | 13 | def train(self, episode): 14 | """Train model on episode. 15 | 16 | Args: 17 | episode: Episode object containing support and query set. 18 | """ 19 | raise NotImplementedError() 20 | 21 | def eval(self, episode): 22 | """Evaluate model on episode. 23 | 24 | Args: 25 | episode: Episode object containing support and query set. 26 | """ 27 | raise NotImplementedError() 28 | 29 | def sample(self, support_set, num): 30 | """Sample a sequence of size num conditioned on support_set. 31 | 32 | Args: 33 | support_set (numpy array): support set to condition the sample. 34 | num: size of sequence to sample. 35 | """ 36 | raise NotImplementedError() 37 | 38 | def save(self, checkpt_path): 39 | """Save model's current parameters at checkpt_path. 40 | 41 | Args: 42 | checkpt_path (string): path where to save parameters. 43 | """ 44 | raise NotImplementedError() 45 | 46 | def recover_or_init(self, init_path): 47 | """Recover or initialize model based on init_path. 48 | 49 | If init_path has appropriate model parameters, load them; otherwise, 50 | initialize parameters randomly. 51 | Args: 52 | init_path (string): path from where to load parameters. 53 | """ 54 | raise NotImplementedError() 55 | 56 | 57 | def flatten_first_two_dims(token_array): 58 | """Convert shape from [B,S,N] => [BxS,N].""" 59 | shape = token_array.shape 60 | return np.reshape(token_array, (shape[0] * shape[1], shape[2])) 61 | 62 | 63 | def convert_tokens_to_input_and_target(token_array, start_word=None): 64 | """Convert token_array to input and target to use for model for 65 | sequence generation. 66 | 67 | If start_word is given, add to start of each sequence of tokens. 68 | Input is token_array without last item; Target is token_array without first item. 
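    (With a start_word, Input becomes start_word followed by token_array
    without its last item, and Target is the full token_array.)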
69 | 70 | Arguments: 71 | token_array (numpy int array): tokens array of size [B,S,N] where 72 | B is batch_size, S is number of songs, N is size of the song 73 | start_word (int): token to use for start word 74 | """ 75 | X = flatten_first_two_dims(token_array) 76 | 77 | if start_word is None: 78 | Y = np.copy(X[:, 1:]) 79 | X_new = X[:, :-1] 80 | else: 81 | Y = np.copy(X) 82 | start_word_column = np.full( 83 | shape=[np.shape(X)[0], 1], fill_value=start_word) 84 | X_new = np.concatenate([start_word_column, X[:, :-1]], axis=1) 85 | 86 | return X_new, Y 87 | -------------------------------------------------------------------------------- /src/models/lstm_baseline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from models.tf_model import TFModel 5 | from models.base_model import convert_tokens_to_input_and_target 6 | 7 | 8 | class LSTMBaseline(TFModel): 9 | """LSTM language model 10 | 11 | Trained on songs from the meta-training set. During evaluation, 12 | ignore each episode's support set and evaluate only on query set. 13 | """ 14 | 15 | def __init__(self, config): 16 | super(LSTMBaseline, self).__init__(config) 17 | 18 | def _define_placedholders(self): 19 | # Add start word that starts every song 20 | # Adding start word increases the size of vocabulary by 1 21 | self._start_word = self._config['input_size'] 22 | self._input_size = self._config['input_size'] + 1 23 | 24 | self._time_steps = self._config['max_len'] 25 | self._embd_size = self._config['embedding_size'] 26 | self._hidden_size = self._config['hidden_size'] 27 | self._n_layers = self._config['n_layers'] 28 | self._lr = self._config['lr'] 29 | self._max_grad_norm = self._config['max_grad_norm'] 30 | 31 | self._batch_size = tf.placeholder(tf.int32, shape=()) 32 | self._seq_length = tf.placeholder(tf.int32, [None]) 33 | self._words = tf.placeholder( 34 | tf.int32, [None, self._time_steps]) 35 | self._target = tf.placeholder( 36 | tf.int32, [None, self._time_steps]) 37 | 38 | def _build_graph(self): 39 | embedding = tf.get_variable( 40 | 'embedding', [self._input_size, self._embd_size]) 41 | inputs = tf.nn.embedding_lookup(embedding, self._words) 42 | inputs = tf.unstack(inputs, axis=1) 43 | 44 | def make_cell(): 45 | return tf.contrib.rnn.BasicLSTMCell( 46 | self._hidden_size, forget_bias=1., state_is_tuple=True) 47 | 48 | self._cell = tf.contrib.rnn.MultiRNNCell( 49 | [make_cell() for _ in range(self._n_layers)]) 50 | self._initial_state = self._cell.zero_state( 51 | self._batch_size, dtype=tf.float32) 52 | outputs, state = tf.nn.static_rnn( 53 | self._cell, inputs, initial_state=self._initial_state, 54 | sequence_length=self._seq_length) 55 | self._state = state 56 | 57 | output = tf.concat(outputs, 1) 58 | self._output = tf.reshape(output, [-1, self._hidden_size]) 59 | 60 | softmax_w = tf.get_variable( 61 | 'softmax_w', [self._hidden_size, self._input_size]) 62 | softmax_b = tf.get_variable('softmax_b', [self._input_size]) 63 | # Reshape logits to be a 3-D tensor for sequence loss 64 | logits = tf.nn.xw_plus_b(self._output, softmax_w, softmax_b) 65 | logits = tf.reshape( 66 | logits, [self._batch_size, self._time_steps, self._input_size]) 67 | self._logits = logits 68 | self._prob = tf.nn.softmax(self._logits) 69 | 70 | self._avg_neg_log = tf.contrib.seq2seq.sequence_loss( 71 | logits, 72 | self._target, 73 | tf.ones([self._batch_size, self._time_steps], dtype=tf.float32), 74 | average_across_timesteps=True, 75 | 
average_across_batch=True) 76 | 77 | lr = tf.train.exponential_decay( 78 | self._lr, 79 | self._global_step, 80 | self._config['n_decay'], 0.5, staircase=False 81 | ) 82 | optimizer = tf.train.AdamOptimizer(lr) 83 | grads, _ = tf.clip_by_global_norm(tf.gradients(self._avg_neg_log, 84 | self.get_vars()), 85 | self._max_grad_norm) 86 | self._train_op = optimizer.apply_gradients(zip(grads, self.get_vars()), 87 | self._global_step) 88 | 89 | def train(self, episode): 90 | """Concatenate query and support sets to train.""" 91 | X, Y = convert_tokens_to_input_and_target( 92 | episode.support, self._start_word) 93 | X2, Y2 = convert_tokens_to_input_and_target( 94 | episode.query, self._start_word) 95 | X = np.concatenate([X, X2]) 96 | Y = np.concatenate([Y, Y2]) 97 | 98 | feed_dict = {} 99 | feed_dict[self._words] = X 100 | feed_dict[self._target] = Y 101 | feed_dict[self._batch_size] = np.shape(X)[0] 102 | feed_dict[self._seq_length] = [np.shape(X)[1]] * np.shape(X)[0] 103 | 104 | _, loss = self._sess.run([self._train_op, self._avg_neg_log], 105 | feed_dict=feed_dict) 106 | if self._summary_writer: 107 | summary = tf.Summary(value=[ 108 | tf.Summary.Value(tag='Train/loss', 109 | simple_value=loss)]) 110 | self._summary_writer.add_summary(summary, self._train_calls) 111 | self._train_calls += 1 112 | 113 | return loss 114 | 115 | def eval(self, episode): 116 | """Ignore support set and evaluate only on query set.""" 117 | X, Y = convert_tokens_to_input_and_target( 118 | episode.query, self._start_word) 119 | 120 | feed_dict = {} 121 | feed_dict[self._words] = X 122 | feed_dict[self._target] = Y 123 | feed_dict[self._batch_size] = np.shape(X)[0] 124 | feed_dict[self._seq_length] = [np.shape(X)[1]] * np.shape(X)[0] 125 | avg_neg_log = self._sess.run(self._avg_neg_log, feed_dict=feed_dict) 126 | if self._summary_writer is not None: 127 | summary = tf.Summary(value=[ 128 | tf.Summary.Value(tag='Eval/Avg_NLL', 129 | simple_value=avg_neg_log)]) 130 | self._summary_writer.add_summary(summary, self._eval_calls) 131 | self._eval_calls += 1 132 | 133 | return avg_neg_log 134 | 135 | def sample(self, support_set, num): 136 | """Ignore support set for sampling.""" 137 | pred_words = [] 138 | word = self._start_word 139 | 140 | state = self._sess.run(self._cell.zero_state(1, tf.float32)) 141 | x = np.zeros((1, self._time_steps)) 142 | for i in range(num): 143 | x[0, 0] = word 144 | feed_dict = {} 145 | feed_dict[self._words] = x 146 | feed_dict[self._batch_size] = 1 147 | feed_dict[self._seq_length] = [1] 148 | feed_dict[self._initial_state] = state 149 | 150 | probs, state = self._sess.run([self._prob, self._state], 151 | feed_dict=feed_dict) 152 | p = probs[0][0] 153 | word = np.argmax(p) 154 | pred_words.append(word) 155 | 156 | return pred_words 157 | -------------------------------------------------------------------------------- /src/models/tf_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import pprint 4 | 5 | from models.base_model import BaseModel 6 | 7 | PP = pprint.PrettyPrinter(depth=6) 8 | 9 | 10 | def start_session(): 11 | sess_config = tf.ConfigProto(allow_soft_placement=True) 12 | sess_config.gpu_options.allow_growth = True 13 | return tf.Session(config=sess_config) 14 | 15 | 16 | def init_vars_op(sess): 17 | variables = tf.global_variables() 18 | init_flag = sess.run( 19 | tf.stack([tf.is_variable_initialized(v) for v in variables])) 20 | uninit_variables = [v for v, f in zip(variables, init_flag) if not 
f] 21 | 22 | print('Initializing vars:') 23 | print(PP.pformat([v.name for v in uninit_variables])) 24 | 25 | return tf.variables_initializer(uninit_variables) 26 | 27 | 28 | def optimistic_restore(session, save_file, only_load_trainable_vars=False, 29 | stat_name_prefix='moving_'): 30 | """Restore variables of model from save_file. 31 | 32 | Argument trainable_vars is used to determine whether to fetch model 33 | variables & batch-norm statistics OR whether to fetch ALL variables 34 | (includes model variables, batch-norm statistics, training variables ~ 35 | such as global step and learning rate decay). 36 | 37 | Args: 38 | session: tf.Session to use in recovery 39 | save_file: file to load variables from 40 | trainable_vars: to recover only variables that are trained (model 41 | variables and running batch-norm statistics) or not (all variables 42 | including global step) 43 | """ 44 | reader = tf.train.NewCheckpointReader(save_file) 45 | saved_shapes = reader.get_variable_to_shape_map() 46 | if only_load_trainable_vars: 47 | var_names = sorted([(var.name, var.name.split(':')[0]) 48 | for var in tf.trainable_variables() 49 | if var.name.split(':')[0] in saved_shapes]) 50 | running_stat_names = sorted([(var.name, var.name.split(':')[0]) 51 | for var in tf.global_variables() 52 | if (var.name.split(':')[0] 53 | in saved_shapes and 54 | stat_name_prefix in var.name)]) 55 | var_names += running_stat_names 56 | else: 57 | var_names = sorted([(var.name, var.name.split(':')[0]) 58 | for var in tf.global_variables() 59 | if var.name.split(':')[0] in saved_shapes]) 60 | 61 | print('Loading vars:') 62 | print(PP.pformat(var_names)) 63 | 64 | restore_vars = [] 65 | name2var = dict(zip(map(lambda x: x.name.split( 66 | ':')[0], tf.global_variables()), tf.global_variables())) 67 | 68 | with tf.variable_scope('', reuse=True): 69 | for var_name, saved_var_name in var_names: 70 | curr_var = name2var[saved_var_name] 71 | var_shape = curr_var.get_shape().as_list() 72 | if var_shape == saved_shapes[saved_var_name]: 73 | restore_vars.append(curr_var) 74 | saver = tf.train.Saver(restore_vars) 75 | saver.restore(session, save_file) 76 | 77 | 78 | class TFModel(BaseModel): 79 | 80 | def __init__(self, config): 81 | tf.set_random_seed(config['seed']) 82 | 83 | super(TFModel, self).__init__(config) 84 | self._summary_writer = None 85 | if 'checkpt_dir' in config: 86 | self._summary_writer = tf.summary.FileWriter(config['checkpt_dir']) 87 | self._train_calls = 0 88 | self._eval_calls = 0 89 | self._sess = start_session() 90 | 91 | with tf.variable_scope(self.name): 92 | self._global_step = tf.Variable(0, trainable=False) 93 | self._define_placedholders() 94 | self._build_graph() 95 | 96 | self._saver = tf.train.Saver(self.get_vars(only_trainable=False), 97 | max_to_keep=10) 98 | 99 | def get_vars(self, name=None, only_trainable=True): 100 | name = name or self.name 101 | if only_trainable: 102 | return [v for v in tf.trainable_variables() if name in v.name] 103 | else: 104 | return [v for v in tf.global_variables() if name in v.name] 105 | 106 | def _get_checkpt_prefix(self, checkpt_path): 107 | directory = os.path.join(checkpt_path, self.name) 108 | if not os.path.exists(directory): 109 | os.makedirs(directory) 110 | return os.path.join(directory, self.name) 111 | 112 | def save(self, checkpt_path): 113 | self._saver.save(self._sess, self._get_checkpt_prefix(checkpt_path), 114 | global_step=self._global_step) 115 | 116 | def _recover(self, checkpt_path, only_load_trainable_vars): 117 | latest_checkpt = 
118 |             os.path.join(checkpt_path, self.name)
119 |         )
120 |         if latest_checkpt is None:
121 |             return False
122 | 
123 |         print('recovering %s from %s' % (self.name, latest_checkpt))
124 |         optimistic_restore(self._sess, latest_checkpt, only_load_trainable_vars)
125 |         return True
126 | 
127 |     def recover_or_init(self, checkpt_path, only_load_trainable_vars=False):
128 |         self._recover(checkpt_path, only_load_trainable_vars)
129 |         self._sess.run(init_vars_op(self._sess))
130 | 
--------------------------------------------------------------------------------
/src/models/unigram_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | 
4 | from models.tf_model import TFModel
5 | from models.base_model import convert_tokens_to_input_and_target
6 | 
7 | 
8 | class UnigramModel(TFModel):
9 |     """Unigram model that uses word frequencies to compute word probabilities.
10 | 
11 |     Use the meta-training set to approximate word frequencies. During
12 |     evaluation, ignore each episode's support set and evaluate only on the query set.
13 |     """
14 | 
15 |     def __init__(self, config):
16 |         super(UnigramModel, self).__init__(config)
17 | 
18 |     def _define_placedholders(self):
19 |         self._input_size = self._config['input_size']
20 |         self._time_steps = self._config['max_len']
21 | 
22 |         self._words = tf.placeholder(
23 |             tf.int32, [None, self._time_steps - 1])
24 |         self._alpha = 1
25 | 
26 |     def _build_graph(self):
27 |         word_count = tf.get_variable(
28 |             'word_count', [self._input_size],
29 |             initializer=tf.constant_initializer(self._alpha),
30 |             trainable=False)
31 |         flatten_words = tf.reshape(self._words, [-1])
32 |         ones = tf.ones_like(flatten_words, dtype=tf.float32)
33 |         self._train_op = tf.scatter_add(word_count, flatten_words, ones)
34 | 
35 |         sum_ = tf.reduce_sum(word_count)
36 |         self._prob = tf.gather(word_count, flatten_words) / sum_
37 |         self._avg_neg_log = -tf.reduce_mean(tf.log(self._prob))
38 | 
39 |         self._prob_all = word_count / sum_
40 | 
41 |     def train(self, episode):
42 |         """Concatenate query and support sets to train."""
43 |         X, Y = convert_tokens_to_input_and_target(
44 |             episode.support)
45 |         X2, Y2 = convert_tokens_to_input_and_target(
46 |             episode.query)
47 |         X = np.concatenate([X, X2])
48 | 
49 |         feed_dict = {}
50 |         feed_dict[self._words] = X
51 | 
52 |         _, loss = self._sess.run([self._train_op, self._avg_neg_log],
53 |                                  feed_dict=feed_dict)
54 | 
55 |         return loss
56 | 
57 |     def eval(self, episode):
58 |         """Ignore support set and evaluate only on query set."""
59 |         query_set = episode.query
60 |         X, Y = convert_tokens_to_input_and_target(query_set)
61 | 
62 |         feed_dict = {}
63 |         feed_dict[self._words] = Y
64 |         avg_neg_log = self._sess.run(self._avg_neg_log,
65 |                                      feed_dict=feed_dict)
66 | 
67 |         return avg_neg_log
68 | 
69 |     def sample(self, support_set, num):
70 |         """Ignore support set for sampling."""
71 |         pred_words = []
72 | 
73 |         for i in range(num):
74 |             prob = self._sess.run(self._prob_all)
75 |             word = np.argmax(prob)
76 |             pred_words.append(word)
77 | 
78 |         return pred_words
79 | 
--------------------------------------------------------------------------------
/src/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-ON/Few-Shot-Music-Generation/dcc3709db41113761614508c681a0196ea6011c7/src/train/__init__.py
--------------------------------------------------------------------------------
/src/train/test_seed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pprint
3 | import argparse
4 | import yaml
5 | 
6 | from data.episode import load_sampler_from_config
7 | from train import load_model_from_config
8 | 
9 | PP = pprint.PrettyPrinter(depth=6)
10 | 
11 | parser = argparse.ArgumentParser(description='Test seeding of model training.')
12 | parser.add_argument('--data', dest='data', default='')
13 | parser.add_argument('--task', dest='task', default='')
14 | args = parser.parse_args()
15 | 
16 | 
17 | def main():
18 |     print('Args:')
19 |     print(PP.pformat(vars(args)))
20 | 
21 |     config = yaml.load(open(args.data, 'r'))
22 |     config.update(yaml.load(open(args.task, 'r')))
23 |     config.update(yaml.load(open('config/lstm_baseline_test_seed.yaml', 'r')))
24 |     config['dataset_path'] = os.path.abspath(config['dataset_path'])
25 |     print('Config:')
26 |     print(PP.pformat(config))
27 | 
28 |     episode_sampler = {}
29 |     config['split'] = 'train'
30 |     episode_sampler['train'] = load_sampler_from_config(config)
31 | 
32 |     config['input_size'] = episode_sampler['train'].get_num_unique_words()
33 |     if not config['input_size'] > 0:
34 |         raise RuntimeError(
35 |             'error reading data: %d unique tokens processed' % config['input_size'])
36 |     print('Num unique words: %d' % config['input_size'])
37 | 
38 |     model = load_model_from_config(config)
39 |     model.recover_or_init('')
40 | 
41 |     ############################################################################
42 |     # Run test to compare loss on training set after n updates to what we expect
43 |     ############################################################################
44 | 
45 |     EXP_LOSS = {
46 |         'config/midi.yaml' + 'config/5shot.yaml': 6.1458707,
47 |         'config/lyrics.yaml' + 'config/5shot.yaml': 7.1389594
48 |     }
49 |     EPSILON = 0.001
50 |     N_UPDATES = 10
51 |     ERROR_MSG = """Test failed: there is an issue with the seeding, as the
52 |     model loss differs from the expected value"""
53 | 
54 |     loss = 0.
55 |     for i in range(0, N_UPDATES):
56 |         episode = episode_sampler['train'].get_episode()
57 |         loss = model.train(episode)
58 | 
59 |     k = args.data + args.task
60 |     if k not in EXP_LOSS:
61 |         raise RuntimeError(
62 |             'No test for data: %s and task: %s' % (args.data, args.task))
63 | 
64 |     expected_loss = EXP_LOSS[k]
65 |     assert abs(loss - expected_loss) <= EPSILON, ERROR_MSG
66 | 
67 | 
68 | if __name__ == '__main__':
69 |     main()
70 | 
--------------------------------------------------------------------------------
/src/train/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pprint
3 | import argparse
4 | import yaml
5 | from importlib import import_module
6 | 
7 | from data.episode import load_sampler_from_config
8 | 
9 | PP = pprint.PrettyPrinter(depth=6)
10 | 
11 | 
12 | def load_model_from_config(config):
13 |     Model = getattr(import_module(config['model_module_name']),
14 |                     config['model_class_name'])
15 |     return Model(config)
16 | 
17 | 
18 | def write_seq(seq, dir, name):
19 |     if isinstance(seq, str):
20 |         text_file = open(os.path.join(dir, name + '.txt'), "w")
21 |         text_file.write(seq)
22 |         text_file.close()
23 |     else:
24 |         seq.write(os.path.join(dir, name + '.mid'))
25 | 
26 | 
27 | def evaluate(model, episode_sampler, n_episodes):
28 |     avg_nll = 0.
29 |     for i in range(n_episodes):
30 |         episode = episode_sampler.get_episode()
31 |         avg_nll += model.eval(episode)
32 | 
33 |     return avg_nll / n_episodes
34 | 
35 | 
36 | parser = argparse.ArgumentParser(description='Train a model.')
37 | parser.add_argument('--data', dest='data', default='')
38 | parser.add_argument('--model', dest='model', default='')
39 | parser.add_argument('--task', dest='task', default='')
40 | parser.add_argument('--checkpt_dir', dest='checkpt_dir', default='')
41 | parser.add_argument('--init_dir', dest='init_dir', default='')
42 | args = parser.parse_args()
43 | 
44 | 
45 | def main():
46 |     print('Args:')
47 |     print(PP.pformat(vars(args)))
48 | 
49 |     config = yaml.load(open(args.data, 'r'))
50 |     config.update(yaml.load(open(args.task, 'r')))
51 |     config.update(yaml.load(open(args.model, 'r')))
52 |     config['dataset_path'] = os.path.abspath(config['dataset_path'])
53 |     config['checkpt_dir'] = args.checkpt_dir
54 |     print('Config:')
55 |     print(PP.pformat(config))
56 | 
57 |     episode_sampler = {}
58 |     for split in config['splits']:
59 |         config['split'] = split
60 |         episode_sampler[split] = load_sampler_from_config(config)
61 | 
62 |     config['input_size'] = episode_sampler['train'].get_num_unique_words()
63 |     if not config['input_size'] > 0:
64 |         raise RuntimeError(
65 |             'error reading data: %d unique tokens processed' % config['input_size'])
66 |     print('Num unique words: %d' % config['input_size'])
67 | 
68 |     n_train = config['n_train']
69 |     print_every_n = config['print_every_n']
70 |     val_every_n = config['val_every_n']
71 |     n_val = config['n_val']
72 |     n_test = config['n_test']
73 |     n_samples = config['n_samples']
74 |     max_len = config['max_len']
75 | 
76 |     model = load_model_from_config(config)
77 |     model.recover_or_init(args.init_dir)
78 | 
79 |     # Train model and evaluate
80 |     avg_nll = evaluate(model, episode_sampler['val'], n_val)
81 |     print("Iter: %d, val-nll: %.3e" % (0, avg_nll))
82 | 
83 |     avg_loss = 0.
84 |     for i in range(1, n_train + 1):
85 |         episode = episode_sampler['train'].get_episode()
86 |         loss = model.train(episode)
87 |         avg_loss += loss
88 | 
89 |         if i % val_every_n == 0:
90 |             avg_nll = evaluate(model, episode_sampler['val'], n_val)
91 |             print("Iter: %d, val-nll: %.3e" % (i, avg_nll))
92 | 
93 |             if args.checkpt_dir != '':
94 |                 model.save(args.checkpt_dir)
95 | 
96 |         if i % print_every_n == 0:
97 |             print("Iter: %d, loss: %.3e" % (i, avg_loss / print_every_n))
98 |             avg_loss = 0.
99 | 
100 |     # Evaluate model after training on training, validation, and test sets
101 |     avg_nll = evaluate(model, episode_sampler['train'], n_test)
102 |     print("Train Avg NLL: %.3e" % (avg_nll))
103 |     avg_nll = evaluate(model, episode_sampler['val'], n_test)
104 |     print("Validation Avg NLL: %.3e" % (avg_nll))
105 |     avg_nll = evaluate(model, episode_sampler['test'], n_test)
106 |     print("Test Avg NLL: %.3e" % (avg_nll))
107 | 
108 |     # Generate samples from trained model for test episodes
109 |     samples_dir = os.path.join(args.checkpt_dir, 'samples')
110 |     if not os.path.exists(samples_dir):
111 |         os.makedirs(samples_dir)
112 | 
113 |     for i in range(n_samples):
114 |         curr_sample_dir = os.path.join(samples_dir, 'sample_%d' % i)
115 |         os.makedirs(curr_sample_dir)
116 | 
117 |         episode = episode_sampler['test'].get_episode()
118 |         support_set = episode.support[0]
119 |         sample = model.sample(support_set, max_len)
120 | 
121 |         for j in range(support_set.shape[0]):
122 |             write_seq(episode_sampler['test'].detokenize(support_set[j]),
123 |                       curr_sample_dir, 'support_%d' % j)
124 | 
125 |         write_seq(episode_sampler['test'].detokenize(sample), curr_sample_dir,
126 |                   'model_sample')
127 | 
128 | 
129 | if __name__ == '__main__':
130 |     main()
131 | 
--------------------------------------------------------------------------------
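Editor's note: `train.py` and `test_seed.py` only interact with a model through the small API implemented by `TFModel` and its subclasses: a constructor taking the merged YAML config, plus `recover_or_init`, `train(episode)`, `eval(episode)`, `save`, and `sample(support_set, num)`. As a rough illustration of how a new model would plug into that loop, the sketch below mirrors the structure of `UnigramModel` above. It is a hypothetical example, not part of the repository: the class name `UniformModel`, its trivial graph, and any config values shown in comments are assumptions for illustration only.

```python
# Hypothetical sketch (not in the repository): a model that assigns a uniform
# probability to every token, written against the same API as UnigramModel.
import numpy as np
import tensorflow as tf

from models.tf_model import TFModel
from models.base_model import convert_tokens_to_input_and_target


class UniformModel(TFModel):
    """Assumed example: uniform distribution over the vocabulary."""

    def _define_placedholders(self):
        self._input_size = self._config['input_size']
        self._time_steps = self._config['max_len']
        self._words = tf.placeholder(tf.int32, [None, self._time_steps - 1])

    def _build_graph(self):
        # Every token has probability 1 / input_size, so the average
        # negative log-likelihood is constant at log(input_size).
        self._avg_neg_log = tf.constant(
            np.log(self._input_size), dtype=tf.float32)
        self._train_op = tf.no_op()

    def train(self, episode):
        """Nothing to fit; just report the (constant) loss."""
        X, _ = convert_tokens_to_input_and_target(episode.support)
        _, loss = self._sess.run([self._train_op, self._avg_neg_log],
                                 feed_dict={self._words: X})
        return loss

    def eval(self, episode):
        """Ignore the support set and evaluate on the query set."""
        X, _ = convert_tokens_to_input_and_target(episode.query)
        return self._sess.run(self._avg_neg_log, feed_dict={self._words: X})

    def sample(self, support_set, num):
        """Draw `num` token ids uniformly at random."""
        return list(np.random.randint(0, self._input_size, size=num))
```

Such a model would be selected through a model YAML that sets `model_module_name` and `model_class_name` (see `load_model_from_config` in `train.py` above); the module path and class name here are illustrative. Note also that `test_seed.py` exercises exactly this API: given only `--data` and `--task` configs (presumably run from `src/`, in the same style as `train.train`), it runs ten training updates of the LSTM baseline config it loads and asserts that the final loss matches the hard-coded expected value within `EPSILON`.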