├── .gitignore
├── README.md
├── pytorchtextvae
│   ├── __init__.py
│   ├── datasets.py
│   ├── generate.py
│   ├── helpers.py
│   ├── interpolate.py
│   ├── model.py
│   └── train.py
├── requirements.txt
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.swp
3 | *.swo
4 | *.pt
5 | __pycache__
6 | *.egg-info/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | A partial reimplementation of "Generating Sentences From a Continuous Space", Bowman, Vilnis, Vinyals, Dai, Jozefowicz, Bengio (https://arxiv.org/abs/1511.06349).
2 |
3 | Based on code from Kyle Kastner (`@kastnerkyle`), adapted to support the [`deephypebot`](https://github.com/iconix/deephypebot) project.
4 |
5 | ---
6 |
7 | Changes in this detached fork of [`kastnerkyle/pytorch-text-vae`](https://github.com/kastnerkyle/pytorch-text-vae/):
8 | - Update compatibility to Python 3 and PyTorch 0.4
9 | - Add `generate.py` for sampling
10 | - Add special support for JSON reading and thought vector conditioning
11 | - Some code cleanup
12 | - Add `setup.py` for package support as `pytorchtextvae`
13 | - Add train/test data split support
--------------------------------------------------------------------------------
/pytorchtextvae/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iconix/pytorch-text-vae/59fbd356717df7417add3902b7db9f9f7e8b413a/pytorchtextvae/__init__.py
--------------------------------------------------------------------------------
/pytorchtextvae/datasets.py:
--------------------------------------------------------------------------------
1 | # Author: Nadja Rhodes
2 | # License: BSD 3-Clause
3 | # Modified from Kyle Kastner's example here:
4 | # https://github.com/kastnerkyle/pytorch-text-vae
5 | import time
6 | import os
7 | try:
8 |     import Queue
9 | except ImportError:
10 |     import queue as Queue
11 | import multiprocessing as mp
12 | import dill as pickle
13 | from enum import Enum
14 | import json  # module-level, so AFDataSplit.read_json_gen can see it
15 | import numpy as np
16 | import re
17 | import sys
18 | import unidecode
19 | import unicodedata
20 | import collections
21 | import pandas as pd
22 |
23 | SOS_token = 0
24 | EOS_token = 1
25 | UNK_token = 2
26 | N_CORE = 24
27 |
28 | class Condition(Enum):
29 |     NONE = 0
30 |     GENRE = 1
31 |     AF = 2 # audio features
32 |
33 | class DataSplit:
34 |     def __init__(self, filename, data_type):
35 |         self.filename = filename
36 |         self.data_type = data_type
37 |
38 |         self.df = pd.read_json(self.filename) if data_type == Dataset.DataType.JSON else None  # plain-text sources stream via read_file_line_gen instead
39 |         self.n_conditions = -1
40 |
41 |     def __iter__(self):
42 |         if self.data_type == Dataset.DataType.JSON:
43 |             return self.read_json_gen()
44 |         else:
45 |             return self.read_file_line_gen()
46 |
47 |     def read_file_line_gen(self):
48 |         with open(self.filename) as f:
49 |             for line in f:
50 |                 yield unidecode.unidecode(line)
51 |
52 |     def encode_conditions(self, conditions):
53 |         raise NotImplementedError
54 |
55 |     def decode_conditions(self, tensor):
56 |         raise NotImplementedError
57 |
58 |     def read_json_gen(self):
59 |         for i, row in self.df.iterrows():
60 |             for sent in row.content_sentences:
61 |                 yield sent
62 |
63 | class GenreDataSplit(DataSplit):
64 |     def __init__(self, filename, data_type, condition_set=None):
65 |         super(GenreDataSplit, self).__init__(filename, data_type)
66 |
67 |         if condition_set:
68 |             self.condition_set = condition_set
69 |         else:
70 |             self.condition_set = set([g for gg in self.df.spotify_genres for g in gg])
71 |
72 |         self.genre_to_idx = {unique_g: i for i, unique_g in enumerate(sorted(self.condition_set))}
73 |         self.idx_to_genre = {i: unique_g for i, unique_g in enumerate(sorted(self.condition_set))}
74 |
75 |         self.n_conditions = len(self.condition_set) + 1
76 |
77 |     def encode_conditions(self, conditions):
78 |         e = np.zeros(self.n_conditions)
79 |         for g in conditions:
80 |             if g in self.genre_to_idx:
81 |                 e[self.genre_to_idx[g]] = 1
82 |             else:
83 |                 # for unknown genres
84 |                 e[len(e) - 1] = 1
85 |         return e
86 |
87 |     def decode_conditions(self, tensor):
88 |         genres = []
89 |         for i, x in enumerate(tensor.squeeze()):
90 |             if x.item() == 1:
91 |                 if i in self.idx_to_genre:
92 |                     genres.append(self.idx_to_genre[i])
93 |                 else:
94 |                     genres.append('UNK')
95 |         return genres
96 |
97 |     def read_json_gen(self):
98 |         for i, row in self.df.iterrows():
99 |             gs = self.encode_conditions(row.spotify_genres)
100 |             for sent in row.content_sentences:
101 |                 yield sent, gs
102 |
103 | class AFDataSplit(DataSplit):
104 |     def __init__(self, filename, data_type):
105 |         super(AFDataSplit, self).__init__(filename, data_type)
106 |
107 |         import json
108 |         # all rows should have the same condition keys
109 |         self.ignore_keys = ['analysis_url', 'duration_ms', 'id', 'track_href', 'type', 'uri']
110 |         condition_list = [k for (k, v) in sorted(json.loads(self.df.audio_features[0].replace("'", "\"")).items()) if k not in self.ignore_keys]
111 |
112 |         self.n_conditions = len(condition_list)
113 |         self.idx_to_af = {i: c for i, c in enumerate(condition_list)}
114 |
115 |     def encode_conditions(self, conditions):
116 |         return np.array([v for (k, v) in sorted(conditions.items()) if k not in self.ignore_keys])
117 |
118 |     def decode_conditions(self, tensor):
119 |         afs = {}
120 |         for i, x in enumerate(tensor.squeeze()):
121 |             afs[self.idx_to_af[i]] = x.item()
122 |         return afs
123 |
124 |     def read_json_gen(self):
125 |         for i, row in self.df.iterrows():
126 |             try:
127 |                 fs = self.encode_conditions(json.loads(row.audio_features.replace("'", "\"")))
128 |             except json.decoder.JSONDecodeError:
129 |                 # TODO: why audio_features = None ever?
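                # fall back to an all-zero condition vector so the row still yields training pairs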
130 | fs = np.zeros(self.n_conditions) 131 | for sent in row.content_sentences: 132 | yield sent, fs 133 | 134 | class Dataset: 135 | class DataType(Enum): 136 | DEFAULT = 0 137 | JSON = 1 138 | 139 | def __init__(self, trn_path, test_path=None): 140 | if trn_path.endswith('.json'): 141 | self.data_type = Dataset.DataType.JSON 142 | else: 143 | self.data_type = Dataset.DataType.DEFAULT 144 | 145 | self.trn_split = DataSplit(trn_path, self.data_type) 146 | if test_path: 147 | self.test_split = DataSplit(test_path, self.data_type) 148 | else: 149 | self.test_split = None 150 | 151 | class GenreDataset(Dataset): 152 | def __init__(self, trn_path, test_path=None): 153 | super(GenreDataset, self).__init__(trn_path, test_path) 154 | 155 | self.trn_split = GenreDataSplit(trn_path, self.data_type) 156 | if test_path: 157 | self.test_split = GenreDataSplit(test_path, self.data_type, self.trn_split.condition_set) 158 | else: 159 | self.test_split = None 160 | 161 | class AFDataset(Dataset): 162 | def __init__(self, trn_path, test_path=None): 163 | super(AFDataset, self).__init__(trn_path, test_path) 164 | 165 | self.trn_split = AFDataSplit(trn_path, self.data_type) 166 | if test_path: 167 | self.test_split = AFDataSplit(test_path, self.data_type) 168 | else: 169 | self.test_split = None 170 | 171 | def get_mean_condition(self, pairs): 172 | if not hasattr(self, 'mean_condition'): 173 | conditions = np.array([p[2] for p in pairs]) 174 | self.mean_condition = np.mean(conditions, axis=0) 175 | 176 | return self.mean_condition 177 | 178 | 179 | norvig_list = None 180 | # http://norvig.com/ngrams/count_1w.txt 181 | # TODO: replace with spacy tokenization? or is it better to stick to common words? 182 | '''Things turned to UNK: 183 | - numbers 184 | ''' 185 | def get_vocabulary(tmp_path): 186 | global norvig_list 187 | global reverse_norvig_list 188 | if norvig_list == None: 189 | with open(os.path.join(tmp_path, "count_1w.txt")) as f: 190 | r = f.readlines() 191 | norvig_list = [tuple(ri.strip().split("\t")) for ri in r] 192 | return norvig_list 193 | 194 | 195 | # Turn a Unicode string to plain ASCII, thanks to http://stackoverflow.com/a/518232/2809427 196 | def unicode_to_ascii(s): 197 | return ''.join( 198 | c for c in unicodedata.normalize(u'NFD', s) 199 | if unicodedata.category(c) != u'Mn' 200 | ) 201 | 202 | 203 | # Lowercase, trim, and remove non-letter characters 204 | def normalize_string(s): 205 | s = unicode_to_ascii(s.lower().strip()) 206 | s = re.sub(r"'", r"", s) 207 | s = re.sub(r"([.!?])", r" \1", s) 208 | #s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) 209 | s = re.sub(r"[^\w]", r" ", s) 210 | s = re.sub(r"\s+", r" ", s).strip().lstrip().rstrip() 211 | return s 212 | 213 | 214 | class Lang: 215 | def __init__(self, name, tmp_path, vocabulary_size=-1, reverse=False): 216 | self.name = name 217 | if reverse: 218 | self.vocabulary = [w[::-1] for w in ["SOS", "EOS", "UNK"]] + [w[0][::-1] for w in get_vocabulary(tmp_path)] 219 | else: 220 | self.vocabulary = ["SOS", "EOS", "UNK"] + [w[0] for w in get_vocabulary(tmp_path)] 221 | 222 | if vocabulary_size < 0: 223 | vocabulary_size = len(self.vocabulary) 224 | 225 | self.reverse = reverse 226 | self.vocabulary_size = vocabulary_size 227 | if vocabulary_size < len(self.vocabulary): 228 | print(f"Trimming vocabulary size from {len(self.vocabulary)} to {vocabulary_size}") 229 | else: 230 | print(f"Vocabulary size: {vocabulary_size}") 231 | self.vocabulary = self.vocabulary[:vocabulary_size] 232 | self.word2index = {v: k for k, v in 
enumerate(self.vocabulary)}
233 |         self.index2word = {v: k for k, v in self.word2index.items()}
234 |         self.n_words = len(self.vocabulary) # Count SOS, EOS, UNK
235 |         # dict.keys() do not pickle in Python 3.x - convert to list
236 |         # https://groups.google.com/d/msg/pyomo-forum/XOf6zwvEbt4/ZfkbHzvDBgAJ
237 |         self.words = list(self.word2index.keys())
238 |         self.indices = list(self.index2word.keys())
239 |
240 |     def index_to_word(self, index):
241 |         try:
242 |             return self.index2word[index.item()]
243 |         except KeyError:
244 |             return self.vocabulary[UNK_token]
245 |
246 |     def word_to_index(self, word):
247 |         try:
248 |             return self.word2index[word.lower()]
249 |         except KeyError:
250 |             #print(f"[WARNING] {word.lower()}")
251 |             return self.word2index[self.vocabulary[UNK_token]]
252 |
253 |     def word_check(self, word):
254 |         if word in self.word2index.keys():
255 |             return word
256 |         else:
257 |             return self.vocabulary[UNK_token]  # return the UNK *word*, matching the known-word branch
258 |
259 |     def process_sentence(self, sentence, normalize=True):
260 |         if normalize:
261 |             s = normalize_string(sentence)
262 |         else:
263 |             s = sentence
264 |         return " ".join([w if w in self.words else self.vocabulary[UNK_token] for w in s.split(" ")])
265 |
266 | def filter_pair(p):
267 |     return MIN_LENGTH < len(p[0].split(' ')) < MAX_LENGTH and MIN_LENGTH < len(p[1].split(' ')) < MAX_LENGTH
268 |
269 |
270 | def process_input_side(s):
271 |     return " ".join([WORDS[w] for w in s.split(" ")])
272 |
273 |
274 | def process_output_side(s):
275 |     return " ".join([REVERSE_WORDS[w] for w in s.split(" ")])
276 |
277 |
278 | WORDS = None
279 | REVERSE_WORDS = None
280 |
281 | def unk_func():
282 |     return "UNK"
283 |
284 | def _get_line(data_type, elem):
285 |     # JSON data can come with extra conditional info
286 |     if data_type == Dataset.DataType.JSON and not isinstance(elem, str):
287 |         line = elem[0]
288 |     else:
289 |         line = elem
290 |
291 |     return line
292 |
293 | def _setup_vocab(trn_path, vocabulary_size, condition_on):
294 |     global WORDS
295 |     global REVERSE_WORDS
296 |     wc = collections.Counter()
297 |     if condition_on == Condition.GENRE:
298 |         dataset = GenreDataset(trn_path)
299 |     elif condition_on == Condition.AF:
300 |         dataset = AFDataset(trn_path)
301 |     else:
302 |         dataset = Dataset(trn_path)
303 |     for n, elem in enumerate(iter(dataset.trn_split)):
304 |         if n % 100000 == 0:
305 |             print("Fetching vocabulary from line {}".format(n))
306 |             print("Current word count {}".format(len(wc.keys())))
307 |
308 |         line = _get_line(dataset.data_type, elem)
309 |
310 |         l = line.strip()
311 |         if MIN_LENGTH < len(l.split(' ')) < MAX_LENGTH:
312 |             l = normalize_string(l)
313 |             tokens = l.split(" ")
314 |             wc.update(tokens)
315 |         else:
316 |             continue
317 |
318 |     the_words = ["SOS", "EOS", "UNK"]
319 |     the_reverse_words = [w[::-1] for w in the_words]
320 |     the_words += [wi[0] for wi in wc.most_common(None if vocabulary_size < 3 else vocabulary_size - 3)]  # None = no cap (max_vocab=-1)
321 |     the_reverse_words += [wi[0][::-1] for wi in wc.most_common(None if vocabulary_size < 3 else vocabulary_size - 3)]
322 |
323 |     WORDS = collections.defaultdict(unk_func)
324 |     REVERSE_WORDS = collections.defaultdict(unk_func)
325 |     for k in range(len(the_words)):
326 |         WORDS[the_words[k]] = the_words[k]
327 |         REVERSE_WORDS[the_reverse_words[k]] = the_reverse_words[k]
328 |
329 |
330 | def proc_line(line, reverse):
331 |     if len(line.strip()) == 0:
332 |         return None
333 |     else:
334 |         l = line.strip()
335 |         # try to bail as early as possible to minimize processing
336 |         if MIN_LENGTH < len(l.split(' '))
< MAX_LENGTH: 337 | l = normalize_string(l) 338 | l2 = l 339 | pair = (l, l2) 340 | 341 | if filter_pair(pair): 342 | if reverse: 343 | pair = (l, "".join(list(reversed(l2)))) 344 | p0 = process_input_side(pair[0]) 345 | p1 = process_output_side(pair[1]) 346 | return (p0, p1) 347 | else: 348 | return None 349 | else: 350 | return None 351 | 352 | 353 | def process(q, oq, iolock): 354 | while True: 355 | stuff = q.get() 356 | if stuff is None: 357 | break 358 | r = [(proc_line(s[0], True), s[1]) if isinstance(s, tuple) else proc_line(s, True) for s in stuff] 359 | r = [ri for ri in r if ri != None and ri[0] != None] 360 | # flatten any tuples 361 | r = [ri[0] + (ri[1], ) if isinstance(ri[0], tuple) else ri for ri in r] 362 | if len(r) > 0: 363 | oq.put(r) 364 | 365 | def _setup_pairs(datasplit): 366 | print("Setting up queues") 367 | # some nasty multiprocessing 368 | # ~ 40 per second was the single core number 369 | q = mp.Queue(maxsize=1000000 * N_CORE) 370 | oq = mp.Queue(maxsize=1000000 * N_CORE) 371 | print("Queue setup complete") 372 | print("Getting lock") 373 | iolock = mp.Lock() 374 | print("Setting up pool") 375 | pool = mp.Pool(N_CORE, initializer=process, initargs=(q, oq, iolock)) 376 | print("Pool setup complete") 377 | 378 | start_time = time.time() 379 | pairs = [] 380 | last_empty = time.time() 381 | 382 | curr_block = [] 383 | block_size = 1000 384 | last_send = 0 385 | # takes ~ 30s to get a block done 386 | empty_wait = 2 387 | avg_time_per_block = 30 388 | status_every = 100000 389 | print("Starting block processing") 390 | 391 | for n, elem in enumerate(iter(datasplit)): 392 | curr_block.append(elem) 393 | if len(curr_block) > block_size: 394 | # this could block, oy 395 | q.put(curr_block) 396 | curr_block = [] 397 | 398 | if last_empty < time.time() - empty_wait: 399 | try: 400 | while True: 401 | with iolock: 402 | r = oq.get(block=True, timeout=.0001) 403 | pairs.extend(r) 404 | except: 405 | last_empty = time.time() 406 | if n % status_every == 0: 407 | with iolock: 408 | print("Queued line {}".format(n)) 409 | tt = time.time() - start_time 410 | print("Elapsed time {}".format(tt)) 411 | tl = len(pairs) 412 | print("Total lines {}".format(tl)) 413 | avg_time_per_block = max(30, block_size * (tt / (tl + 1))) 414 | print("Approximate lines / s {}".format(tl / tt)) 415 | # finish the queue 416 | q.put(curr_block) 417 | print("Finalizing line processing") 418 | for _ in range(N_CORE): # tell workers we're done 419 | q.put(None) 420 | empty_checks = 0 421 | prev_len = len(pairs) 422 | last_status = time.time() 423 | print("Total lines {}".format(len(pairs))) 424 | while True: 425 | if empty_checks > 10: 426 | break 427 | if status_every < (len(pairs) - prev_len) or last_status < time.time() - empty_wait: 428 | print("Total lines {}".format(len(pairs))) 429 | prev_len = len(pairs) 430 | last_status = time.time() 431 | if not oq.empty(): 432 | try: 433 | while True: 434 | with iolock: 435 | r = oq.get(block=True, timeout=.0001) 436 | pairs.extend(r) 437 | empty_checks = 0 438 | except: 439 | # Queue.Empty 440 | pass 441 | elif oq.empty(): 442 | empty_checks += 1 443 | time.sleep(empty_wait) 444 | print("Line processing complete") 445 | print("Final line count {}".format(len(pairs))) 446 | pool.close() 447 | pool.join() 448 | 449 | return pairs 450 | 451 | # https://stackoverflow.com/questions/43078980/python-multiprocessing-with-generator 452 | def prepare_pair_data(path, vocabulary_size, tmp_path, min_length, max_length, condition_on, reverse=False): 453 | global 
MIN_LENGTH 454 | global MAX_LENGTH 455 | MIN_LENGTH, MAX_LENGTH = min_length, max_length 456 | 457 | print("Reading lines...") 458 | print(f'MIN_LENGTH: {MIN_LENGTH}; MAX_LENGTH: {MAX_LENGTH}') 459 | 460 | if os.path.isdir(path): 461 | # assume folder contains separate train.json and test.json 462 | # TODO: would be cool not to assume .json format 463 | trn_path = os.path.join(path, 'train.json') 464 | test_path = os.path.join(path, 'test.json') 465 | else: 466 | trn_path = path 467 | test_path = None 468 | 469 | pkl_path = trn_path.split(os.sep)[-1].split(".")[0] + "_vocabulary.pkl" 470 | vocab_cache_path = os.path.join(tmp_path, pkl_path) 471 | global WORDS 472 | global REVERSE_WORDS 473 | if not os.path.exists(vocab_cache_path): 474 | print("Vocabulary cache {} not found".format(vocab_cache_path)) 475 | print("Prepping vocabulary") 476 | _setup_vocab(trn_path, vocabulary_size, condition_on) 477 | with open(vocab_cache_path, "wb") as f: 478 | pickle.dump((WORDS, REVERSE_WORDS), f) 479 | else: 480 | print("Vocabulary cache {} found".format(vocab_cache_path)) 481 | print("Loading...".format(vocab_cache_path)) 482 | with open(vocab_cache_path, "rb") as f: 483 | r = pickle.load(f) 484 | WORDS = r[0] 485 | REVERSE_WORDS = r[1] 486 | print("Vocabulary prep complete") 487 | 488 | if condition_on == Condition.GENRE: 489 | dataset = GenreDataset(trn_path, test_path) 490 | elif condition_on == Condition.AF: 491 | dataset = AFDataset(trn_path, test_path) 492 | else: 493 | dataset = Dataset(trn_path, test_path) 494 | 495 | # don't use these for processing, but pass for ease of use later on 496 | dataset.input_side = Lang("in", tmp_path, vocabulary_size) 497 | dataset.output_side = Lang("out", tmp_path, vocabulary_size, reverse) 498 | 499 | print("Pair preparation for train split") 500 | dataset.trn_pairs = _setup_pairs(dataset.trn_split) 501 | 502 | if dataset.test_split: 503 | print("Pair preparation for test split") 504 | dataset.test_pairs = _setup_pairs(dataset.test_split) 505 | else: 506 | dataset.test_pairs = None 507 | 508 | print("Pair preparation complete") 509 | return dataset 510 | -------------------------------------------------------------------------------- /pytorchtextvae/generate.py: -------------------------------------------------------------------------------- 1 | import dill as pickle 2 | import numpy as np 3 | import os 4 | from pathlib import Path 5 | import time 6 | import torch 7 | 8 | if __package__ is None or __package__ == '': 9 | # uses current directory visibility 10 | import model 11 | from datasets import EOS_token 12 | else: 13 | # uses current package visibility 14 | import pytorchtextvae.model as model 15 | from pytorchtextvae.datasets import EOS_token 16 | 17 | def load_model(saved_vae, stored_info, device, cache_path=str(Path('../tmp')), seed=None): 18 | stored_info = stored_info.split(os.sep)[-1] 19 | cache_file = os.path.join(cache_path, stored_info) 20 | 21 | start_load = time.time() 22 | print(f"Fetching cached info at {cache_file}") 23 | with open(cache_file, "rb") as f: 24 | dataset, z_size, condition_size, condition_on, decoder_hidden_size, encoder_hidden_size, n_encoder_layers = pickle.load(f) 25 | end_load = time.time() 26 | print(f"Cache {cache_file} loaded (load time: {end_load - start_load:.2f}s)") 27 | 28 | if os.path.exists(saved_vae): 29 | print(f"Found saved model {saved_vae}") 30 | start_load_model = time.time() 31 | 32 | e = model.EncoderRNN(dataset.input_side.n_words, encoder_hidden_size, z_size, n_encoder_layers, bidirectional=True) 33 | d = 
model.DecoderRNN(z_size, dataset.trn_split.n_conditions, condition_size, decoder_hidden_size, dataset.input_side.n_words, 1, word_dropout=0) 34 | vae = model.VAE(e, d).to(device) 35 | vae.load_state_dict(torch.load(saved_vae, map_location=lambda storage, loc: storage)) 36 | vae.eval() 37 | print(f"Trained for {vae.steps_seen} steps (load time: {time.time() - start_load_model:.2f}s)") 38 | 39 | print("Setting new random seed") 40 | if seed is None: 41 | # TODO: torch.manual_seed(1999) in model.py is affecting this 42 | new_seed = int(time.time()) 43 | new_seed = abs(new_seed) % 4294967295 # must be between 0 and 4294967295 44 | else: 45 | new_seed = seed 46 | torch.manual_seed(new_seed) 47 | 48 | random_state = np.random.RandomState(new_seed) 49 | #random_state.shuffle(dataset.trn_pairs) 50 | 51 | return vae, dataset, z_size, random_state 52 | 53 | def generate(vae, dataset, z_size, random_state, device, condition_inputs=None, max_length=50, num_sample=10, temp=0.75, print_z=False, clean_gen=False): 54 | gens = [] 55 | zs = [] 56 | conditions = [] 57 | 58 | if dataset.trn_split.n_conditions > -1 and condition_inputs is not None and not isinstance(condition_inputs, list): 59 | print(f'[WARNING] condition_inputs provided is of type "{type(condition_inputs).__name__}" but should be of type "list". Continuing with random condition_inputs...') 60 | 61 | for i in range(num_sample): 62 | z = torch.randn(z_size).unsqueeze(0).to(device) 63 | 64 | if dataset.trn_split.n_conditions > -1: 65 | if isinstance(condition_inputs, list): 66 | condition = torch.tensor(dataset.trn_split.encode_conditions(condition_inputs), dtype=torch.float).unsqueeze(0).to(device) 67 | else: 68 | condition = model.random_training_set(dataset, random_state, device)[2] 69 | else: 70 | condition = None 71 | 72 | generated = vae.decoder.generate(z, condition, max_length, temp, device) 73 | generated_str = model.float_word_tensor_to_string(dataset.output_side, generated) 74 | 75 | EOS_str = f' {dataset.output_side.index_to_word(torch.LongTensor([EOS_token]))} ' 76 | 77 | if generated_str.endswith(EOS_str): 78 | generated_str = generated_str[:-5] 79 | 80 | # flip it back 81 | generated_str = generated_str[::-1] 82 | 83 | if clean_gen: 84 | # remove 1) UNKs, 2) consecutive duplicated words 85 | gen_list = generated_str.replace('UNK', '').split() 86 | generated_str = ' '.join([v for i, v in enumerate(gen_list) if i == 0 or v != gen_list[i-1]]) 87 | 88 | print('---') 89 | if dataset.trn_split.n_conditions > -1: 90 | print(dataset.trn_split.decode_conditions(condition)) 91 | print(generated_str) 92 | gens.append(generated_str) 93 | zs.append(z) 94 | if dataset.trn_split.n_conditions > -1: 95 | conditions.append(condition) 96 | if print_z: 97 | print(z) 98 | 99 | return gens, zs, conditions 100 | 101 | def run(saved_vae, stored_info, cache_path=str(Path(f'..{os.sep}tmp')), condition_inputs=None, max_length=50, num_sample=10, seed=None, temp=0.75, 102 | use_cuda=True, print_z=False, clean_gen=False): 103 | 104 | args_passed = locals() 105 | print(args_passed) 106 | 107 | DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() and use_cuda else 'cpu') 108 | 109 | with torch.no_grad(): 110 | vae, dataset, z_size, random_state = load_model(saved_vae, stored_info, DEVICE, cache_path, seed) 111 | gens, zs, conditions = generate(vae, dataset, z_size, random_state, DEVICE, condition_inputs, max_length, num_sample, temp, print_z, clean_gen) 112 | 113 | if __name__ == "__main__": 114 | import fire; fire.Fire(run) 
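For quick reference, a minimal sketch of driving this module from Python instead of the `fire` CLI. The file names below are hypothetical placeholders for the `<name>_state.pt` weights and `<name>_stored_info.pkl` cache that `train.py` writes:

```python
import torch
from pytorchtextvae import generate  # assumes the package was installed via setup.py

device = torch.device('cpu')
# load_model() reads the cached dataset/hyperparameters, then the trained weights
vae, dataset, z_size, random_state = generate.load_model(
    'lyrics_state.pt', 'lyrics_stored_info.pkl', device, cache_path='tmp', seed=1999)
with torch.no_grad():
    gens, zs, conditions = generate.generate(
        vae, dataset, z_size, random_state, device,
        num_sample=5, temp=0.75, clean_gen=True)
```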
115 |
--------------------------------------------------------------------------------
/pytorchtextvae/helpers.py:
--------------------------------------------------------------------------------
1 | # https://github.com/spro/char-rnn.pytorch
2 |
3 | import unidecode
4 | import string
5 | import random
6 | import time
7 | import math
8 | import torch
9 | from torch.autograd import Variable
10 |
11 | USE_CUDA = True
12 | all_characters = string.printable  # character vocabulary for char_tensor/index_to_char
13 | # Reading and un-unicode-encoding data
14 | SOS_token = 0
15 | EOS_token = 1
16 | UNK_token = 2
17 |
18 | def read_file(filename):
19 |     file = unidecode.unidecode(open(filename).read())
20 |     return file, len(file)
21 |
22 | # Turning a string into a tensor
23 |
24 | def char_tensor(s):
25 |     size = len(s) + 1
26 |     tensor = torch.zeros(size).long()
27 |     for c in range(len(s)):
28 |         tensor[c] = all_characters.index(s[c])
29 |     tensor[-1] = EOS_token
30 |     tensor = Variable(tensor)
31 |     if USE_CUDA:
32 |         tensor = tensor.cuda()
33 |     return tensor
34 |
35 | # Turn a tensor into a string
36 |
37 | def index_to_char(top_i):
38 |     if top_i == EOS_token:
39 |         return '$'
40 |     elif top_i == SOS_token:
41 |         return '^'
42 |     elif top_i == UNK_token:
43 |         return '*'
44 |     else:
45 |         return all_characters[top_i]
46 |
47 | def tensor_to_string(t):
48 |     s = ''
49 |     for i in range(t.size(0)):
50 |         ti = t[i]
51 |         top_k = ti.data.topk(1)
52 |         top_i = top_k[1][0]
53 |         s += index_to_char(top_i)
54 |     return s.split(index_to_char(EOS_token))[0]
55 |     #return s
56 |
57 | def longtensor_to_string(t):
58 |     s = ''
59 |     for i in range(t.size(0)):
60 |         top_i = t.data[i]
61 |         s += index_to_char(top_i)
62 |     return s
63 |
64 | # Readable time elapsed
65 |
66 | def time_since(since):
67 |     s = time.time() - since
68 |     m = math.floor(s / 60)
69 |     s -= m * 60
70 |     return '%dm %ds' % (m, s)
71 |
--------------------------------------------------------------------------------
/pytorchtextvae/interpolate.py:
--------------------------------------------------------------------------------
1 | # slerp, lerp and associated from Tom White in plat (https://github.com/dribnet/plat)
2 | from model import *
3 | import numpy as np
4 | import sys
5 | from scipy.stats import norm
6 | import argparse
7 |
8 | default_data_path = "books_large_all_stored_info.pkl"
9 | default_vae_path = "vae.pt"
10 | default_temperature = 1.
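# NOTE: this standalone script predates the packaged pipeline: it loads a whole
# pickled model via torch.load and an (input_side, output_side, pairs) cache,
# which no longer matches what train.py saves, and it calls decoder.generate
# with the old (z, n_steps, temperature) signature. Kept for reference.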
11 | default_n_samples = 10 12 | default_length = 5 13 | default_path = "slerp" 14 | default_seed = 1999 15 | default_s1 = None 16 | default_s2 = None 17 | 18 | parser = argparse.ArgumentParser(description="Interpolation tests for trained RNN-VAE", 19 | # epilog="Simple usage:\n python minimal_beamsearch.py shakespeare_input.txt -o 10\nFull usage:\n python minimal_beamsearch.py shakespeare_input.txt -o 10 -d 0 -s 'HOLOFERNES' -e 'crew?\\n' -r 2177", 20 | formatter_class=argparse.RawTextHelpFormatter) 21 | parser.add_argument("-f", "--filepath", help="Path to pickled dataset info\nDefault: {}".format(default_data_path), default=default_data_path) 22 | parser.add_argument("-s", "--saved", help="Path to saved vae.pt file\nDefault: {}".format(default_vae_path), default=default_vae_path) 23 | parser.add_argument("-l", "--length", help="Length of sample path\nDefault: {}".format(default_length), default=default_length) 24 | parser.add_argument("-p", "--path", help="Path to use for sampling\nDefault: {}".format(default_path), default=default_path) 25 | parser.add_argument("-r", "--seed", help="Random seed to use\nDefault: {}".format(default_seed), default=default_seed) 26 | parser.add_argument("-n", "--nsamples", help="Number of samples\nDefault: {}".format(default_n_samples), default=default_n_samples) 27 | parser.add_argument("-t", "--temperature", help="Temperature to use when sampling\nDefault: {}".format(default_temperature), default=default_temperature) 28 | parser.add_argument("-1", "--s1", help="First sentence on the path, None for random\nDefault: {}".format(default_s1), default=default_s1) 29 | parser.add_argument("-2", "--s2", help="Second sentence on the path, None for random\nDefault: {}".format(default_s2), default=default_s2) 30 | 31 | args = parser.parse_args() 32 | filepath = args.filepath 33 | saved = args.saved 34 | length = int(args.length) 35 | path = args.path 36 | seed = int(args.seed) 37 | n_samples = int(args.nsamples) 38 | temperature = float(args.temperature) 39 | s1 = args.s1 40 | s2 = args.s2 41 | 42 | # Don't need it for sampling 43 | USE_CUDA = True 44 | 45 | vae = torch.load(saved) 46 | vae.train(False) 47 | 48 | torch.manual_seed(seed) 49 | random_state = np.random.RandomState(seed) 50 | 51 | reverse = True 52 | csv = False 53 | 54 | if filepath.endswith(".pkl"): 55 | cache_path = filepath 56 | lang_cache_path = filepath.split(os.sep)[-1].split(".")[0] + "_stored_lang.pkl" 57 | else: 58 | raise ValueError("Must be a pkl file") 59 | 60 | if not os.path.exists(cache_path): 61 | raise ValueError("Must have stored info already!") 62 | else: 63 | if os.path.exists(lang_cache_path): 64 | start_load = time.time() 65 | print("Fetching cached language info at {}".format(lang_cache_path)) 66 | with open(lang_cache_path, "rb") as f: 67 | input_side, output_side = pickle.load(f) 68 | end_load = time.time() 69 | print("Language only cache {} loaded, total load time {}".format(lang_cache_path, end_load - start_load)) 70 | else: 71 | start_load = time.time() 72 | print("Fetching cached info at {}".format(cache_path)) 73 | with open(cache_path, "rb") as f: 74 | input_side, output_side, pairs = pickle.load(f) 75 | end_load = time.time() 76 | print("Cache {} loaded, total load time {}".format(cache_path, end_load - start_load)) 77 | 78 | with open(lang_cache_path, "wb") as f: 79 | pickle.dump((input_side, output_side), f) 80 | 81 | 82 | def encode_sample(encode_sentence=None, stochastic=True): 83 | size = vae.encoder.output_size 84 | if encode_sentence is None: 85 | rm = 
Variable(torch.FloatTensor(1, size).normal_()) 86 | rl = Variable(torch.FloatTensor(1, size).normal_()) 87 | else: 88 | inp = word_tensor(input_side, encode_sentence) 89 | # temporary 90 | try: 91 | m, l, z = vae.encode(inp) 92 | except AttributeError: 93 | m, l, z = vae.encoder(inp) 94 | rm = m 95 | rl = l 96 | 97 | if USE_CUDA: 98 | rm = rm.cuda() 99 | rl = rl.cuda() 100 | 101 | if stochastic: 102 | z = vae.encoder.sample(rm, rl) 103 | return z 104 | 105 | 106 | 107 | def lerp(val, low, high): 108 | """Linear interpolation""" 109 | return low + (high - low) * val 110 | 111 | 112 | def lerp_gaussian(val, low, high): 113 | """Linear interpolation with gaussian CDF""" 114 | low_gau = norm.cdf(low) 115 | high_gau = norm.cdf(high) 116 | lerped_gau = lerp(val, low_gau, high_gau) 117 | return norm.ppf(lerped_gau) 118 | 119 | 120 | def slerp(val, low, high): 121 | """Spherical interpolation. val has a range of 0 to 1.""" 122 | if val <= 0: 123 | return low 124 | elif val >= 1: 125 | return high 126 | elif np.allclose(low, high): 127 | return low 128 | omega = np.arccos(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high))) 129 | so = np.sin(omega) 130 | return np.sin((1.0-val)*omega) / so * low + np.sin(val*omega)/so * high 131 | 132 | 133 | def slerp_gaussian(val, low, high): 134 | """Spherical interpolation with gaussian CDF (generally not useful)""" 135 | offset = norm.cdf(np.zeros_like(low)) # offset is just [0.5, 0.5, ...] 136 | low_gau_shifted = norm.cdf(low) - offset 137 | high_gau_shifted = norm.cdf(high) - offset 138 | circle_lerped_gau = slerp(val, low_gau_shifted, high_gau_shifted) 139 | epsilon = 0.001 140 | clipped_sum = np.clip(circle_lerped_gau + offset, epsilon, 1.0 - epsilon) 141 | result = norm.ppf(clipped_sum) 142 | return result 143 | 144 | 145 | for s in range(1, n_samples): 146 | if s1 is None: 147 | sent0 = None 148 | z0 = encode_sample() 149 | else: 150 | sent0 = input_side.process_sentence(str(s1)) 151 | z0 = encode_sample(sent0, False) 152 | 153 | if s2 is None: 154 | sent1 = None 155 | z1 = encode_sample() 156 | else: 157 | sent1 = input_side.process_sentence(str(s2)) 158 | z1 = encode_sample(sent1, False) 159 | 160 | z0_np = z0.cpu().data.numpy().ravel() 161 | z1_np = z1.cpu().data.numpy().ravel() 162 | last_s = '' 163 | 164 | generated_str = float_word_tensor_to_string(output_side, vae.decoder.generate(z0, MAX_LENGTH, temperature)) 165 | if generated_str.endswith("EOS "): 166 | generated_str = generated_str[:-4] 167 | generated_str = generated_str[::-1] 168 | 169 | end_str = float_word_tensor_to_string(output_side, vae.decoder.generate(z1, MAX_LENGTH, temperature)) 170 | if end_str.endswith("EOS "): 171 | end_str = end_str[:-4] 172 | end_str = end_str[::-1] 173 | 174 | if sent0 is not None: 175 | print('(s0)', sent0) 176 | print('(z0)', generated_str) 177 | 178 | last_s = generated_str 179 | 180 | for i in range(1, length): 181 | t = i * 1.0 / length 182 | 183 | #sph_z = slerp(t, z0_np, z1_np) 184 | #sph_z = slerp_gaussian(t, z0_np, z1_np) 185 | sph_z = lerp(t, z0_np, z1_np) 186 | interp_z = Variable(torch.FloatTensor(sph_z[None])) 187 | if USE_CUDA: 188 | interp_z = interp_z.cuda() 189 | s = float_word_tensor_to_string(output_side, vae.decoder.generate(interp_z, MAX_LENGTH, temperature)) 190 | generated_str = s 191 | if generated_str.endswith("EOS "): 192 | generated_str = generated_str[:-4] 193 | generated_str = generated_str[::-1] 194 | 195 | if generated_str != last_s and generated_str != end_str: 196 | print(' .)', generated_str) 197 | 198 | last_s = 
generated_str 199 | 200 | print('(z1)', end_str) 201 | if sent1 is not None: 202 | print('(s1)', sent1) 203 | print('\n') 204 | -------------------------------------------------------------------------------- /pytorchtextvae/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | from torch.nn import Parameter 6 | from functools import wraps 7 | 8 | if __package__ is None or __package__ == '': 9 | from datasets import * 10 | else: 11 | from pytorchtextvae.datasets import * 12 | 13 | MAX_SAMPLE = False 14 | TRUNCATED_SAMPLE = True 15 | model_random_state = np.random.RandomState(1988) 16 | torch.manual_seed(1999) 17 | 18 | 19 | def _decorate(forward, module, name, name_g, name_v): 20 | @wraps(forward) 21 | def decorated_forward(*args, **kwargs): 22 | g = module.__getattr__(name_g) 23 | v = module.__getattr__(name_v) 24 | w = v*(g/torch.norm(v)).expand_as(v) 25 | module.__setattr__(name, w) 26 | return forward(*args, **kwargs) 27 | return decorated_forward 28 | 29 | 30 | def weight_norm(module, name): 31 | param = module.__getattr__(name) 32 | 33 | # construct g,v such that w = g/||v|| * v 34 | g = torch.norm(param) 35 | v = param/g.expand_as(param) 36 | g = Parameter(g.data) 37 | v = Parameter(v.data) 38 | name_g = name + '_g' 39 | name_v = name + '_v' 40 | 41 | # remove w from parameter list 42 | del module._parameters[name] 43 | 44 | # add g and v as new parameters 45 | module.register_parameter(name_g, g) 46 | module.register_parameter(name_v, v) 47 | 48 | # construct w every time before forward is called 49 | module.forward = _decorate(module.forward, module, name, name_g, name_v) 50 | return module 51 | 52 | 53 | def word_tensor(lang, string): 54 | split_string = string.split(" ") 55 | size = len(split_string) + 1 56 | tensor = torch.zeros(size).long() 57 | for c in range(len(split_string)): 58 | tensor[c] = lang.word_to_index(split_string[c]) 59 | tensor[-1] = EOS_token 60 | tensor = Variable(tensor) 61 | return tensor 62 | 63 | def _pair_to_tensors(input_side, output_side, pair, device): 64 | inp = word_tensor(input_side, pair[0]).to(device) 65 | target = word_tensor(output_side, pair[1]).to(device) 66 | condition = torch.tensor(pair[2], dtype=torch.float).unsqueeze(0).to(device) if len(pair) == 3 else None 67 | 68 | return inp, target, condition 69 | 70 | def random_training_set(dataset, random_state, device): 71 | pair_i = random_state.choice(len(dataset.trn_pairs)) 72 | pair = dataset.trn_pairs[pair_i] 73 | return _pair_to_tensors(dataset.input_side, dataset.output_side, pair, device) 74 | 75 | def random_test_set(dataset, random_state, device): 76 | pair_i = random_state.choice(len(dataset.test_pairs)) 77 | pair = dataset.test_pairs[pair_i] 78 | return _pair_to_tensors(dataset.input_side, dataset.output_side, pair, device) 79 | 80 | def index_to_word(lang, top_i): 81 | return lang.index_to_word(top_i) + " " 82 | 83 | 84 | def long_word_tensor_to_string(lang, t): 85 | s = '' 86 | for i in range(t.size(0)): 87 | top_i = t.data[i] 88 | s += index_to_word(lang, top_i) 89 | return s 90 | 91 | 92 | def float_word_tensor_to_string(lang, t): 93 | s = '' 94 | for i in range(t.size(0)): 95 | ti = t[i] 96 | top_k = ti.data.topk(1) 97 | top_i = top_k[1][0] 98 | s += index_to_word(lang, top_i) 99 | if top_i == EOS_token: 100 | break 101 | return s 102 | 103 | 104 | class Encoder(nn.Module): 105 | def sample(self, mu, logvar, device): 106 
| eps = Variable(torch.randn(mu.size())).to(device) 107 | std = torch.exp(logvar / 2.0) 108 | return mu + eps * std 109 | 110 | # Encoder 111 | # ------------------------------------------------------------------------------ 112 | 113 | # Encode into Z with mu and log_var 114 | 115 | class EncoderRNN(Encoder): 116 | def __init__(self, input_size, hidden_size, output_size, n_layers=1, bidirectional=True): 117 | super(EncoderRNN, self).__init__() 118 | self.input_size = input_size 119 | self.hidden_size = hidden_size 120 | self.output_size = output_size 121 | self.n_layers = n_layers 122 | self.bidirectional = bidirectional 123 | 124 | self.embed = nn.Embedding(input_size, hidden_size) 125 | self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=0.1, bidirectional=bidirectional) 126 | self.o2p = nn.Linear(hidden_size, output_size * 2) 127 | 128 | def forward(self, input, device): 129 | embedded = self.embed(input).unsqueeze(1) 130 | 131 | output, hidden = self.gru(embedded, None) 132 | # mean loses positional info? 133 | #output = torch.mean(output, 0).squeeze(0) #output[-1] # Take only the last value 134 | output = output[-1]#.squeeze(0) 135 | if self.bidirectional: 136 | output = output[:, :self.hidden_size] + output[: ,self.hidden_size:] # Sum bidirectional outputs 137 | else: 138 | output = output[:, :self.hidden_size] 139 | 140 | ps = self.o2p(output) 141 | mu, logvar = torch.chunk(ps, 2, dim=1) 142 | z = self.sample(mu, logvar, device) 143 | return mu, logvar, z 144 | 145 | # Decoder 146 | # ------------------------------------------------------------------------------ 147 | 148 | # Decode from Z into sequence 149 | 150 | class DecoderRNN(nn.Module): 151 | def __init__(self, z_size, n_conditions, condition_size, hidden_size, output_size, n_layers=1, word_dropout=1.): 152 | super(DecoderRNN, self).__init__() 153 | self.output_size = output_size 154 | self.n_layers = n_layers 155 | self.word_dropout = word_dropout 156 | 157 | input_size = z_size + condition_size 158 | 159 | self.embed = nn.Embedding(output_size, hidden_size) 160 | self.gru = nn.GRU(hidden_size + input_size, hidden_size, n_layers) 161 | self.i2h = nn.Linear(input_size, hidden_size) 162 | if n_conditions > 0 and condition_size > 0 and n_conditions != condition_size: 163 | self.c2h = nn.Linear(n_conditions, condition_size) 164 | #self.dropout = nn.Dropout() 165 | self.h2o = nn.Linear(hidden_size * 2, hidden_size) 166 | self.out = nn.Linear(hidden_size + input_size, output_size) 167 | 168 | print(f'MAX_SAMPLE: {MAX_SAMPLE}; TRUNCATED_SAMPLE: {TRUNCATED_SAMPLE}') 169 | 170 | def sample(self, output, temperature, device, max_sample=MAX_SAMPLE, trunc_sample=TRUNCATED_SAMPLE): 171 | if max_sample: 172 | # Sample top value only 173 | top_i = output.data.topk(1)[1].item() 174 | else: 175 | # Sample from the network as a multinomial distribution 176 | if trunc_sample: 177 | # Sample from top k values only 178 | k = 10 179 | new_output = torch.empty_like(output).fill_(float('-inf')) 180 | top_v, top_i = output.data.topk(k) 181 | new_output.data.scatter_(1, top_i, top_v) 182 | output = new_output 183 | 184 | output_dist = output.data.view(-1).div(temperature).exp() 185 | if len(torch.nonzero(output_dist)) > 0: 186 | top_i = torch.multinomial(output_dist, 1)[0] 187 | else: 188 | # TODO: how does this happen? 
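                # (one way this can happen: with truncation the non-top-k entries are exp(-inf) = 0,
                # and exp(logit / temperature) can underflow to 0 for the top-k logits as well)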
189 | print(f'[WARNING] output_dist is all zeroes') 190 | top_i = UNK_token 191 | 192 | input = Variable(torch.LongTensor([top_i])).to(device) 193 | return input, top_i 194 | 195 | def forward(self, z, condition, inputs, temperature, device): 196 | n_steps = inputs.size(0) 197 | outputs = Variable(torch.zeros(n_steps, 1, self.output_size)).to(device) 198 | 199 | input = Variable(torch.LongTensor([SOS_token])).to(device) 200 | if condition is None: 201 | decode_embed = z 202 | else: 203 | if hasattr(self, 'c2h'): 204 | #squashed_condition = self.c2h(self.dropout(condition)) 205 | squashed_condition = self.c2h(condition) 206 | decode_embed = torch.cat([z, squashed_condition], 1) 207 | else: 208 | decode_embed = torch.cat([z, condition], 1) 209 | 210 | 211 | hidden = self.i2h(decode_embed).unsqueeze(0).repeat(self.n_layers, 1, 1) 212 | 213 | for i in range(n_steps): 214 | output, hidden = self.step(i, decode_embed, input, hidden) 215 | outputs[i] = output 216 | 217 | use_word_dropout = model_random_state.rand() < self.word_dropout 218 | if use_word_dropout and i < (n_steps - 1): 219 | unk_input = Variable(torch.LongTensor([UNK_token])).to(device) 220 | input = unk_input 221 | continue 222 | 223 | use_teacher_forcing = model_random_state.rand() < temperature 224 | if use_teacher_forcing: 225 | input = inputs[i] 226 | else: 227 | input, top_i = self.sample(output, temperature, device, max_sample=True) 228 | 229 | if input.dim() == 0: 230 | input = input.unsqueeze(0) 231 | 232 | return outputs.squeeze(1) 233 | 234 | def generate_with_embed(self, embed, n_steps, temperature, device, max_sample=MAX_SAMPLE, trunc_sample=TRUNCATED_SAMPLE): 235 | outputs = Variable(torch.zeros(n_steps, 1, self.output_size)).to(device) 236 | input = Variable(torch.LongTensor([SOS_token])).to(device) 237 | 238 | hidden = self.i2h(embed).unsqueeze(0).repeat(self.n_layers, 1, 1) 239 | 240 | for i in range(n_steps): 241 | output, hidden = self.step(i, embed, input, hidden) 242 | outputs[i] = output 243 | input, top_i = self.sample(output, temperature, device, max_sample=max_sample, trunc_sample=trunc_sample) 244 | #if top_i == EOS: break 245 | return outputs.squeeze(1) 246 | 247 | def generate(self, z, condition, n_steps, temperature, device, max_sample=MAX_SAMPLE, trunc_sample=TRUNCATED_SAMPLE): 248 | if condition is None: 249 | decode_embed = z 250 | else: 251 | if condition.dim() == 1: 252 | condition = condition.unsqueeze(0) 253 | 254 | if hasattr(self, 'c2h'): 255 | #squashed_condition = self.c2h(self.dropout(condition)) 256 | squashed_condition = self.c2h(condition) 257 | decode_embed = torch.cat([z, squashed_condition], 1) 258 | else: 259 | decode_embed = torch.cat([z, condition], 1) 260 | 261 | return self.generate_with_embed(decode_embed, n_steps, temperature, device, max_sample, trunc_sample) 262 | 263 | def step(self, s, decode_embed, input, hidden): 264 | # print('[DecoderRNN.step] s =', s, 'decode_embed =', decode_embed.size(), 'i =', input.size(), 'h =', hidden.size()) 265 | input = F.relu(self.embed(input)) 266 | input = torch.cat((input, decode_embed), 1) 267 | input = input.unsqueeze(0) 268 | output, hidden = self.gru(input, hidden) 269 | output = output.squeeze(0) 270 | output = torch.cat((output, decode_embed), 1) 271 | output = self.out(output) 272 | return output, hidden 273 | 274 | # Container 275 | # ------------------------------------------------------------------------------ 276 | 277 | class VAE(nn.Module): 278 | def __init__(self, encoder, decoder, n_steps=None): 279 | super(VAE, 
self).__init__()
280 |         self.encoder = encoder
281 |         self.decoder = decoder
282 |
283 |         self.register_buffer('steps_seen', torch.tensor(0, dtype=torch.long))
284 |         self.register_buffer('kld_max', torch.tensor(1.0, dtype=torch.float))
285 |         self.register_buffer('kld_weight', torch.tensor(0.0, dtype=torch.float))
286 |         if n_steps is not None:
287 |             self.register_buffer('kld_inc', torch.tensor((self.kld_max - self.kld_weight) / (n_steps // 2), dtype=torch.float))
288 |         else:
289 |             self.register_buffer('kld_inc', torch.tensor(0, dtype=torch.float))
290 |
291 |     def encode(self, inputs, device=None):  # device defaults to wherever the encoder weights live
292 |         m, l, z = self.encoder(inputs, device if device is not None else next(self.encoder.parameters()).device)
293 |         return m, l, z
294 |
295 |     def forward(self, inputs, targets, condition, device, temperature=1.0):
296 |         m, l, z = self.encoder(inputs, device)
297 |         decoded = self.decoder(z, condition, targets, temperature, device)
298 |         return m, l, z, decoded
299 |
300 | # Test
301 |
302 | if __name__ == '__main__':
303 |     device = torch.device(f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() else 'cpu')
304 |     hidden_size = 20
305 |     z_size = 10
306 |     n_words = 100  # dummy vocabulary size; this is only a shape smoke test
307 |     e = EncoderRNN(n_words, hidden_size, z_size).to(device)
308 |     d = DecoderRNN(z_size, 0, 0, hidden_size, n_words, 2).to(device)  # no conditioning
309 |     vae = VAE(e, d).to(device)
310 |     inputs = torch.randint(3, n_words, (8,), dtype=torch.long).to(device)  # a fake 8-word sentence
311 |     m, l, z, decoded = vae(inputs, inputs, None, device)
312 |     print('m =', m.size()); print('l =', l.size()); print('z =', z.size())
313 |     print('decoded =', decoded.size())
--------------------------------------------------------------------------------
/pytorchtextvae/train.py:
--------------------------------------------------------------------------------
1 | import dill as pickle
2 | import numpy as np
3 | import os
4 | import shutil
5 |
6 | from datasets import Condition, get_vocabulary, prepare_pair_data
7 | from model import *
8 |
9 | def train_vae(data_path, tmp_path=f'..{os.sep}tmp',
10 |               encoder_hidden_size=512, n_encoder_layers=2, decoder_hidden_size=512, z_size=128,
11 |               condition_size=16, max_vocab=-1, lr=0.0001, n_steps=1500000, grad_clip=10.0,
12 |               save_every=None, log_every_n_seconds=5*60, log_every_n_steps=1000,
13 |               kld_start_inc=10000, habits_lambda=0.2,
14 |               word_dropout=0.25, temperature=1.0, temperature_min=0.75, condition_on=0,
15 |               use_cuda=True, generate_samples=True, generate_interpolations=True, min_gen_len=10, max_gen_len=200):
16 |
17 |     args_passed = locals()
18 |     print(args_passed)
19 |
20 |     DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() and use_cuda else 'cpu')
21 |     if 'cuda' in DEVICE.type:
22 |         print('Using CUDA!')
23 |
24 |     # should get to the temperature around 50% through training, then hold
25 |     temperature_dec = (temperature - temperature_min) / (0.5 * n_steps)
26 |     if save_every is None:
27 |         save_every = log_every_n_steps
28 |
29 |     filename = data_path.split(os.sep)[-1].split(".")[0]
30 |     if data_path.endswith(".pkl"):
31 |         cache_path = os.path.join(*data_path.split(os.sep)[:-1])
32 |         cache_file = os.path.join(cache_path, filename + ".pkl")
33 |     else:
34 |         if os.path.isdir(data_path): # subdir within tmp_path dir
35 |             filename = data_path.split(os.sep)[-2]
36 |             tmp_path = os.path.join(tmp_path, filename)
37 |             print(f'Updating tmp_path: {tmp_path}')
38 |         cache_path = tmp_path
39 |         cache_file = os.path.join(cache_path, filename + "_stored_info.pkl")
40 |
41 |     if not os.path.exists(cache_file):
42 |         print("Cached info at {} not found".format(cache_file))
43 |         print("Creating cache...
this may take some time") 44 | 45 | if not os.path.exists(cache_path): 46 | os.mkdir(cache_path) 47 | 48 | if condition_on < 0 or condition_on > (len(Condition) - 1): 49 | print(f'Invalid condition_on value of {condition_on}. Falling back to {Condition.NONE}') 50 | condition_on = Condition.NONE 51 | else: 52 | condition_on = Condition(condition_on) 53 | 54 | dataset = prepare_pair_data(data_path, max_vocab, tmp_path, min_gen_len, max_gen_len, condition_on, reverse=True) 55 | 56 | if condition_on == Condition.NONE: 57 | condition_size = 0 58 | elif condition_on == Condition.AF: 59 | condition_size = dataset.trn_split.n_conditions 60 | 61 | with open(cache_file, "wb") as f: 62 | pickle.dump((dataset, z_size, condition_size, condition_on, decoder_hidden_size, encoder_hidden_size, n_encoder_layers), f) 63 | else: 64 | start_load = time.time() 65 | print("Fetching cached info at {}".format(cache_file)) 66 | with open(cache_file, "rb") as f: 67 | dataset, z_size, condition_size, condition_on, decoder_hidden_size, encoder_hidden_size, n_encoder_layers = pickle.load(f) 68 | end_load = time.time() 69 | print(f"Cache {cache_file} loaded (load time: {end_load - start_load:.2f}s)") 70 | 71 | print("Shuffling training data") 72 | random_state = np.random.RandomState(1999) 73 | random_state.shuffle(dataset.trn_pairs) 74 | 75 | print("Initializing model") 76 | print(f'condition_on {condition_on.name} (condition_size: {condition_size})') 77 | n_words = dataset.input_side.n_words 78 | e = EncoderRNN(n_words, encoder_hidden_size, z_size, n_encoder_layers, bidirectional=True).to(DEVICE) 79 | 80 | # custom weights initialization # TODO: should we do this if using saved_vae? 81 | def rnn_weights_init(m): 82 | for c in m.children(): 83 | classname = c.__class__.__name__ 84 | if classname.find("GRU") != -1: 85 | for k, v in c.named_parameters(): 86 | if "weight" in k: 87 | v.data.normal_(0.0, 0.02) 88 | 89 | d = DecoderRNN(z_size, dataset.trn_split.n_conditions, condition_size, decoder_hidden_size, n_words, 1, word_dropout=word_dropout).to(DEVICE) 90 | rnn_weights_init(d) 91 | 92 | vae = VAE(e, d, n_steps).to(DEVICE) 93 | saved_vae = filename + "_state.pt" 94 | if os.path.exists(saved_vae): 95 | start_load_model = time.time() 96 | print("Found saved model {}, continuing...".format(saved_vae)) 97 | shutil.copyfile(saved_vae, saved_vae + ".bak") 98 | vae.load_state_dict(torch.load(saved_vae)) 99 | print(f"Found model was already trained for {vae.steps_seen} steps (load time: {time.time() - start_load_model:.2f}s)") 100 | print(f'kld_max: {vae.kld_max}; kld_weight: {vae.kld_weight}') 101 | temperature = temperature_min 102 | temperature_min = temperature_min 103 | temperature_dec = 0. 104 | 105 | print("Setting new random seed") 106 | # change random seed and reshuffle the data, so that we don't repeat the same 107 | # use hash of the weights and biases? try with float16 to avoid numerical issues in the tails... 
108 |         new_seed = hash(tuple([hash(tuple(vae.state_dict()[k].cpu().numpy().ravel().astype("float16"))) for k, v in vae.state_dict().items()]))
109 |         # must be between 0 and 4294967295 (np.random.RandomState takes a 32-bit seed)
110 |         new_seed = abs(new_seed) % 4294967295
111 |         print(new_seed)
112 |         random_state = np.random.RandomState(new_seed)
113 |         print("Reshuffling training data")
114 |         random_state.shuffle(dataset.trn_pairs)
115 |
116 |     optimizer = torch.optim.Adam(vae.parameters(), lr=lr)
117 |     criterion = nn.CrossEntropyLoss()
118 |     criterion.to(DEVICE)
119 |     vae.train()
120 |
121 |     def save():
122 |         save_state_filename = filename + '_state.pt'
123 |         torch.save(vae.state_dict(), save_state_filename)
124 |         print('Saved as %s' % (save_state_filename))
125 |
126 |     try:
127 |         # set it so that the first one logs
128 |         start_time = time.time()
129 |         last_log_time = time.time() - log_every_n_seconds
130 |         last_log_step = -log_every_n_steps - 1
131 |         start_steps = int(vae.steps_seen)  # buffer -> plain int for range()
132 |         for step in range(start_steps, n_steps):
133 |             input, target, condition = random_training_set(dataset, random_state, DEVICE)
134 |             optimizer.zero_grad()
135 |
136 |             m, l, z, decoded = vae(input, target, condition, DEVICE, temperature)
137 |             if temperature > temperature_min:
138 |                 temperature -= temperature_dec
139 |
140 |             ll_loss = criterion(decoded, target)
141 |
142 |             KLD = -0.5 * (1 + l - torch.pow(m, 2) - torch.exp(l))  # l is log-variance, per Encoder.sample (std = exp(l / 2))
143 |             # habits lambda: like free bits, but over the whole layer
144 |             clamp_KLD = torch.clamp(KLD.mean(), min=habits_lambda).squeeze()
145 |             loss = ll_loss + clamp_KLD * vae.kld_weight
146 |
147 |             loss.backward()
148 |
149 |             if step > kld_start_inc and vae.kld_weight < vae.kld_max:
150 |                 vae.kld_weight += vae.kld_inc
151 |
152 |             ec = torch.nn.utils.clip_grad_norm_(vae.parameters(), grad_clip)
153 |             optimizer.step()
154 |
155 |             def log_and_generate(tag, value):
156 |                 if dataset.test_pairs:
157 |                     with torch.no_grad():
158 |                         test_input, test_target, test_condition = random_test_set(dataset, random_state, DEVICE)
159 |                         _, _, _, test_decoded = vae(test_input, test_target, test_condition, DEVICE, temperature)
160 |                         test_ll_loss = criterion(test_decoded, test_target)
161 |                 else:
162 |                     test_ll_loss = torch.FloatTensor([float('inf')]).to(DEVICE)
163 |
164 |                 if tag == "step":
165 |                     #print('|%s|[%d] %.4f (k=%.4f, t=%.4f, kl=%.4f, ckl=%.4f, nll=%.4f, ec=%.4f)' % (
166 |                     #    tag, value, loss.item(), vae.kld_weight, temperature, KLD.data.mean(), clamp_KLD.item(), ll_loss.item(), ec
167 |                     #))
168 |                     print(f'|{tag}|[{value}] {loss.item():.4f} (k={vae.kld_weight:.4f}, t={temperature:.4f}, kl={KLD.data.mean():.4f}, ckl={clamp_KLD.item():.4f}, nll={ll_loss.item():.4f}, test_nll={test_ll_loss.item():.4f}, ec={ec:.4f})')
169 |                     with open('plots.txt', 'a') as f:
170 |                         f.write(f'{value}\t{loss.item()}\t{ll_loss.item()}\t{test_ll_loss.item()}\t{KLD.data.mean()}\n')
171 |                 elif tag == "time":
172 |                     #print('|%s|[%.4f] %.4f (k=%.4f, t=%.4f, kl=%.4f, ckl=%.4f, nll=%.4f, ec=%.4f)' % (
173 |                     #    tag, value, loss.item(), vae.kld_weight, temperature, KLD.data.mean(), clamp_KLD.item(), ll_loss.item(), ec
174 |                     #))
175 |                     print(f'|{tag}|[{value:.4f}] {loss.item():.4f} (k={vae.kld_weight:.4f}, t={temperature:.4f}, kl={KLD.data.mean():.4f}, ckl={clamp_KLD.item():.4f}, nll={ll_loss.item():.4f}, test_nll={test_ll_loss.item():.4f}, ec={ec:.4f})')
176 |
177 |                 EOS_str = f' {dataset.output_side.index_to_word(torch.LongTensor([EOS_token]))} '
178 |
179 |                 if generate_samples:
180 |                     rand_z = torch.randn(z_size).unsqueeze(0).to(DEVICE)
181 |                     if condition_on == Condition.GENRE:
182
| fixed_condition = torch.FloatTensor(dataset.trn_split.encode_conditions(['vapor soul'])).to(DEVICE) 183 | elif condition_on == Condition.AF: 184 | fixed_condition = torch.FloatTensor(dataset.get_mean_condition(dataset.test_pairs)).to(DEVICE) 185 | else: 186 | fixed_condition = None 187 | 188 | generated = vae.decoder.generate(rand_z, fixed_condition, max_gen_len, temperature, DEVICE, max_sample=True) 189 | generated_str = float_word_tensor_to_string(dataset.output_side, generated) 190 | 191 | if generated_str.endswith(EOS_str): 192 | generated_str = generated_str[:-5] 193 | 194 | # flip it back 195 | print('----') 196 | print(' (sample {}) "{}"'.format(tag, generated_str[::-1])) 197 | 198 | if generate_interpolations: 199 | inp_str = long_word_tensor_to_string(dataset.input_side, input) 200 | print('----') 201 | print(' (input/target {}) "{}"'.format(tag, inp_str)) 202 | 203 | generated = vae.decoder.generate(z, condition, max_gen_len, temperature, DEVICE, max_sample=True) 204 | generated_str = float_word_tensor_to_string(dataset.output_side, generated) 205 | if generated_str.endswith(EOS_str): 206 | generated_str = generated_str[:-5] 207 | 208 | # flip it back 209 | print(' (interpolation {}) "{}"'.format(tag, generated_str[::-1])) 210 | print('----') 211 | 212 | if last_log_time <= time.time() - log_every_n_seconds: 213 | log_and_generate("time", time.time() - start_time) 214 | last_log_time = time.time() 215 | 216 | if last_log_step <= step - log_every_n_steps: 217 | log_and_generate("step", step) 218 | last_log_step = step 219 | 220 | if step > 0 and step % save_every == 0 or step == (n_steps - 1): 221 | vae.steps_seen = torch.tensor(step, dtype=torch.long).to(DEVICE) 222 | save() 223 | 224 | save() 225 | 226 | except KeyboardInterrupt as err: 227 | print("ERROR", err) 228 | print("Saving before quit...") 229 | vae.steps_seen = torch.tensor(step, dtype=torch.long).to(DEVICE) 230 | save() 231 | 232 | if __name__ == "__main__": 233 | import fire; fire.Fire(train_vae) 234 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | . 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setup( 7 | name = "pytorchtextvae", 8 | version = "0.0.1", 9 | description = "A partial reimplementation of \"Generating Sentences From a Continuous Space\" by Bowman, Vilnis, Vinyals, Dai, Jozefowicz, Bengio (https://arxiv.org/abs/1511.06349).", 10 | long_description = long_description, 11 | long_description_content_type = "text/markdown", 12 | license = "MIT", 13 | url = "https://github.com/iconix/pytorch-text-vae", 14 | packages = [ 'pytorchtextvae' ], 15 | install_requires = [ 'dill', 'fire', 'unidecode' ], 16 | keywords = [ 'deeplearning', 'pytorch', 'vae', 'nlp' ], 17 | classifiers = ['Development Status :: 3 - Alpha', 18 | 'Programming Language :: Python', 19 | 'Programming Language :: Python :: 3.6', 20 | 'Topic :: Scientific/Engineering :: Artificial Intelligence'] 21 | ) 22 | --------------------------------------------------------------------------------
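As a usage sketch (all paths below are hypothetical; `train.py` and `generate.py` can also be run directly as `fire` CLIs, since both end in `fire.Fire(...)`):

```python
# Run from inside the pytorchtextvae/ source directory: train.py imports its
# siblings (`datasets`, `model`) as plain modules, not package-relative ones.
from train import train_vae

train_vae(
    data_path='data/reviews.json',  # hypothetical corpus; a directory holding train.json/test.json also works
    tmp_path='tmp',                 # vocabulary and pair caches are written here
    n_steps=200000,
    condition_on=1,                 # 0 = Condition.NONE, 1 = GENRE, 2 = AF (audio features)
)
```

On the first run this builds and caches the vocabulary and sentence pairs under `tmp_path`; subsequent runs reload the cache and resume from the saved `<name>_state.pt` if present.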