├── .gitignore ├── model ├── __init__.py ├── utils.py ├── bottles.py ├── hash_counter.py ├── hasher.py ├── generator.py └── discriminator.py ├── dataset ├── __init__.py ├── replay_buffer.py ├── samplers.py ├── real.py └── gen.py ├── scripts ├── make_links.sh ├── prep_wiki_dataset.sh ├── dx2_rnn.patch ├── make_w2v.py ├── format_embs_projector.py ├── prep_wiki_dataset.py └── prep_qa_dataset.py ├── environ ├── __init__.py ├── real.py ├── synth.py └── environment.py ├── LICENSE.txt ├── README.md ├── common.py ├── main.py ├── notebooks └── utils.py └── .pylintrc /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | data 3 | run 4 | __pycache__ 5 | *.ipynb* 6 | .cache 7 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from . import hasher, generator, discriminator, hash_counter, utils 2 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | """Datasets and Samplers.""" 2 | 3 | from .gen import GenDataset 4 | from .real import NLDataset 5 | from .replay_buffer import ReplayBuffer 6 | from . import samplers 7 | -------------------------------------------------------------------------------- /scripts/make_links.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for f in $(git ls-tree --full-tree -r --name-only HEAD | grep "\.py$") 4 | do 5 | if [ "$1" == "--cp" ]; then 6 | rm "$f" 7 | cp "../../$f" "$f" 8 | else 9 | ln -Lrsf "../../$f" "$f" 10 | fi 11 | done 12 | -------------------------------------------------------------------------------- /scripts/prep_wiki_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CACHE_DIR="../data/_cache" 4 | WIKI="simplewiki" 5 | WIKI_DUMP_DATE="20171103" 6 | WIKI_NAME="$WIKI-$WIKI_DUMP_DATE" 7 | WIKI_DUMP_URL="https://dumps.wikimedia.org/$WIKI/$WIKI_DUMP_DATE/$WIKI_NAME-pages-articles.xml.bz2" 8 | WIKI_DUMP_PATH="$CACHE_DIR/$WIKI_NAME.xml.bz2" 9 | 10 | WIKI_EXTRACTOR="$HOME/tools/wikiextractor/WikiExtractor.py" 11 | 12 | curl -L "$WIKI_DUMP_URL" > "$WIKI_DUMP_PATH" 13 | 14 | $WIKI_EXTRACTOR --filter_disambig_pages --no-templates --min_text_length 100 "$WIKI_DUMP_PATH" -o - --json -q > "$CACHE_DIR/$WIKI_NAME.json" 15 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Apply(nn.Module): 6 | """A Module that wraps a function.""" 7 | def __init__(self, fn, detach=False): 8 | super(Apply, self).__init__() 9 | self.fn = fn 10 | self.detach = detach 11 | 12 | def forward(self, input): 13 | output = self.fn(input) 14 | if self.detach: 15 | output = output.detach() 16 | return output 17 | 18 | 19 | def load_w2v_file(w2v_file): 20 | """Loads a textual word2vec file in which the tokens are numeric.""" 21 | return torch.np.loadtxt(w2v_file)[:, 1:] 22 | -------------------------------------------------------------------------------- /model/bottles.py: -------------------------------------------------------------------------------- 1 | """Nd wrappers for modules that operate on the columns of a matrix.""" 2 | 3 | import torch 4 | from torch 
import nn 5 | 6 | 7 | class Bottle(nn.Module): 8 | """Allows a 2D module to process an Nd input.""" 9 | def forward(self, *args, **kwargs): 10 | return bottle(super(Bottle, self).forward, *args, **kwargs) 11 | 12 | 13 | def bottle(fn, inp, *args, **kwargs): 14 | """Applies a fn defined on matrices to tensors.""" 15 | sz = inp.size() 16 | if len(sz) <= 2: 17 | return fn(inp, *args, **kwargs) 18 | out = fn(inp.view(-1, sz[-1]), *args, **kwargs) 19 | out_sz = out.size() 20 | return out.view(*sz[:-1], *out_sz[-(len(out_sz) - 1):]) 21 | 22 | 23 | class BottledLinear(Bottle, nn.Linear): 24 | pass 25 | 26 | 27 | class BottledEmbedding(Bottle, nn.Embedding): 28 | pass 29 | -------------------------------------------------------------------------------- /environ/__init__.py: -------------------------------------------------------------------------------- 1 | from .real import NLEnvironment 2 | from .synth import SynthEnvironment 3 | 4 | 5 | ENVS = ('real', 'synth') 6 | REAL, SYNTH = ENVS 7 | 8 | 9 | def _get_env(env): 10 | if env == SYNTH: 11 | return SynthEnvironment 12 | return NLEnvironment 13 | 14 | def create(env, opts): 15 | """Creates the Environment appropriate for the given opts.""" 16 | return _get_env(env)(opts) 17 | 18 | def parse_env_opts(init_opts, remaining_opts, no_defaults=False): 19 | """Returns environment-specific options.""" 20 | env_type = _get_env(init_opts.env) 21 | parser = env_type.get_opt_parser() 22 | opts = parser.parse_args(remaining_opts) 23 | if no_defaults: 24 | parser.set_defaults(**{opt: None for opt in vars(opts)}) 25 | opts = parser.parse_args(remaining_opts) 26 | for k, v in vars(init_opts).items(): 27 | if k == 'env' and no_defaults: 28 | continue 29 | setattr(opts, k, v) 30 | return opts 31 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2017 Nick Hynes 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Behavioral Cloning 2 | 3 | Improves on the [SeqGAN](https://aaai.org/ocs/index.php/AAAI/AAAI17/paper/view/14344) 4 | idea by adding more reinforcement learning and GAN techniques like: 5 | * a replay buffer 6 | * [Consensus Optimization](https://arxiv.org/abs/1705.10461) 7 | * [count-based exploration bonus](https://arxiv.org/abs/1611.04717) 8 | * Proximal Policy Optimization (was not found to help, but can be found in 9 | [this commit](https://github.com/nhynes/abc/commit/af7d921ab96e20ba75a60558e1e293b8667b4480)) 10 | * advantage normalization 11 | 12 | 13 | ## How to run 14 | 15 | If you wish to enable Consensus Optimization (via the `--grad-reg` option), you'll need 16 | to [patch](scripts/dx2_rnn.patch) PyTorch to allow forcing the use of a 17 | twice-differentiable RNN. 18 | 19 | `python3 main.py` will run the project with the default options. 20 | Output will be written to the `run/` directory. 21 | 22 | ## Shameless plug 23 | 24 | The [em](https://github.com/nhynes/em) tool makes it really easy to twiddle 25 | hyperparameters by tracking changes to code (no need to make everything an option!). 26 | 27 | Just run `em run -g 0 exp_name` with your desired options and you'll find a reproducible 28 | snapshot in `experiments/`! 29 | 30 | If you want to resume from a snapshot (perhaps with different options), 31 | use `em resume -g 0 exp_name ...` 32 | 33 | You can also fork an experiment and its changes using `em fork`, but the quick and dirty 34 | solution is to run `bash scripts/make_links.sh` :) 35 | -------------------------------------------------------------------------------- /scripts/dx2_rnn.patch: -------------------------------------------------------------------------------- 1 | diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py 2 | index 477c551..c034e3b 100644 3 | --- a/torch/nn/_functions/rnn.py 4 | +++ b/torch/nn/_functions/rnn.py 5 | @@ -11,9 +11,6 @@ except ImportError: 6 | pass 7 | 8 | 9 | -force_unfused = False 10 | - 11 | - 12 | def RNNReLUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): 13 | hy = F.relu(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh)) 14 | return hy 15 | @@ -25,7 +22,7 @@ def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): 16 | 17 | 18 | def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): 19 | - if input.is_cuda and not force_unfused: 20 | + if input.is_cuda: 21 | igates = F.linear(input, w_ih) 22 | hgates = F.linear(hidden[0], w_hh) 23 | state = fusedBackend.LSTMFused.apply 24 | @@ -49,7 +46,7 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): 25 | 26 | def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): 27 | 28 | - if input.is_cuda and not force_unfused: 29 | + if input.is_cuda: 30 | gi = F.linear(input, w_ih) 31 | gh = F.linear(hidden, w_hh) 32 | state = fusedBackend.GRUFused.apply 33 | @@ -373,7 +370,7 @@ def hack_onnx_rnn(fargs, output, args, kwargs): 34 | 35 | def RNN(*args, **kwargs): 36 | def forward(input, *fargs, **fkwargs): 37 | - if not force_unfused and cudnn.is_acceptable(input.data): 38 | + if cudnn.is_acceptable(input.data): 39 | func = CudnnRNN(*args, **kwargs) 40 | else: 41 | func = AutogradRNN(*args, **kwargs) 42 | -------------------------------------------------------------------------------- /dataset/replay_buffer.py:
-------------------------------------------------------------------------------- 1 | """A Dataset that acts as a replay buffer.""" 2 | from collections import deque 3 | 4 | import torch 5 | import torch.utils.data 6 | from torch.autograd import Variable 7 | 8 | 9 | class ReplayBuffer(torch.utils.data.ConcatDataset): 10 | """Loads data from a replay buffer.""" 11 | 12 | def __init__(self, max_history, label, **unused_kwargs): 13 | # pylint: disable=super-init-not-called 14 | self.datasets = deque(maxlen=max_history) 15 | self.cumulative_sizes = [0] 16 | self.label = label 17 | 18 | def add_samples(self, samples): 19 | """Adds a batch of samples to the replay buffer.""" 20 | if samples.is_cuda: 21 | samples = samples.cpu() 22 | if isinstance(samples, Variable): 23 | samples = samples.data 24 | 25 | dataset = torch.utils.data.TensorDataset( 26 | samples, torch.LongTensor(len(samples)).fill_(self.label)) 27 | self.datasets.append(dataset) 28 | self.cumulative_sizes = self.cumsum(self.datasets) 29 | 30 | 31 | def test_replay_buffer(): 32 | """Tests the replay buffer.""" 33 | 34 | rbuf = ReplayBuffer(label=-2, max_history=2) 35 | 36 | rbuf.add_samples(torch.zeros(2, 3)) 37 | rbuf.add_samples(torch.ones(3, 3)) 38 | 39 | assert len(rbuf) == 5 40 | 41 | assert (rbuf[0][0] == 0).all() 42 | assert rbuf[0][1] == -2 43 | assert (rbuf[len(rbuf) - 1][0] == 1).all() 44 | 45 | rbuf.add_samples(Variable(torch.ones(4, 3)*2)) 46 | assert len(rbuf) == 7 47 | assert (rbuf[0][0] == 1).all() 48 | assert (rbuf[len(rbuf) - 1][0] == 2).all() 49 | -------------------------------------------------------------------------------- /scripts/make_w2v.py: -------------------------------------------------------------------------------- 1 | """Prunes token vectors to the subset actually present in a vocabulary.""" 2 | 3 | import argparse 4 | import os 5 | import pickle 6 | import sys 7 | 8 | import numpy as np 9 | 10 | PROJ_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 11 | DATA_DIR = os.path.join(PROJ_ROOT, 'data') 12 | 13 | sys.path.insert(0, PROJ_ROOT) 14 | from common import EXTRA_VOCAB, UNK, BOS, EOS 15 | import common 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--vocab', 21 | default='../data/qa/vocab.pkl') 22 | parser.add_argument('--word-vecs', default='../data/glove.840B.300d') 23 | parser.add_argument('--vocab-size', default=20000, type=int) 24 | args = parser.parse_args() 25 | 26 | tok_vocab = EXTRA_VOCAB 27 | tok_vocab.extend(tok for tok, _ in common.unpickle(args.vocab).tok_counts) 28 | tok_vocab = tok_vocab[:args.vocab_size] 29 | 30 | vecs_vocab = common.unpickle(f'{args.word_vecs}_vocab.pkl') 31 | vecs = np.load(f'{args.word_vecs}.npy') 32 | vecs_w2i = {w: i for i, w in enumerate(vecs_vocab)} 33 | for i, w in enumerate(vecs_vocab): 34 | if not w.lower() in vecs_w2i: 35 | vecs_w2i[w.lower()] = i 36 | 37 | filt_vecs = np.random.randn(len(tok_vocab), vecs.shape[1]) 38 | n = 0 39 | for i, w in enumerate(tok_vocab): 40 | if w in vecs_w2i: 41 | filt_vecs[i] = vecs[vecs_w2i[w]] 42 | n += 1 43 | print(f'found {n} words') 44 | 45 | np.save(os.path.join(DATA_DIR, 'tok_vecs_pruned.npy'), filt_vecs) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /dataset/samplers.py: -------------------------------------------------------------------------------- 1 | "Dataset samplers." 
2 | 3 | import sys 4 | 5 | import torch.utils.data 6 | 7 | 8 | class InfiniteRandomSampler(torch.utils.data.sampler.RandomSampler): 9 | """A RandomSampler that cycles forever.""" 10 | def __iter__(self): 11 | index_iter = iter(()) 12 | while True: 13 | try: 14 | yield next(index_iter) 15 | except StopIteration: 16 | index_iter = super(InfiniteRandomSampler, self).__iter__() 17 | 18 | def __len__(self): 19 | return sys.maxsize 20 | 21 | 22 | class ReplayBufferSampler(torch.utils.data.sampler.Sampler): 23 | """A Sampler that uniforly samples batches of indices forever.""" 24 | def __init__(self, replay_buffer, batch_size): 25 | super(ReplayBufferSampler, self).__init__(replay_buffer) 26 | self.replay_buffer = replay_buffer 27 | self.batch_size = batch_size 28 | 29 | def __iter__(self): 30 | while True: 31 | rbuf_len = len(self.replay_buffer) 32 | num_samps = min(rbuf_len, self.batch_size) 33 | yield torch.LongTensor(num_samps).random_(rbuf_len) 34 | 35 | def __len__(self): 36 | return sys.maxsize 37 | 38 | 39 | def test_inf_rand_sampler(): 40 | """Tests the InfiniteRandomSampler.""" 41 | import itertools 42 | 43 | sampler = InfiniteRandomSampler(torch.randn(4)) 44 | inds = list(itertools.islice(iter(sampler), 8)) 45 | 46 | assert len(sampler) > 1e10 47 | assert len(inds) == 8 48 | 49 | 50 | def test_replay_buffer_sampler(): 51 | """Tests the ReplayBufferSampler.""" 52 | t = torch.randn(2) 53 | 54 | sampler_it = iter(ReplayBufferSampler(t, 4)) 55 | 56 | assert len(next(sampler_it)) == 2 57 | 58 | t.resize_(8) 59 | assert len(next(sampler_it)) == 4 60 | -------------------------------------------------------------------------------- /dataset/real.py: -------------------------------------------------------------------------------- 1 | """A Dataset that loads a natural language corpus.""" 2 | 3 | import os 4 | 5 | import torch 6 | import torch.utils.data 7 | 8 | from common import EXTRA_VOCAB, UNK, BOS, EOS 9 | import common 10 | 11 | 12 | class NLDataset(torch.utils.data.Dataset): 13 | """Loads the data.""" 14 | 15 | def __init__(self, data_dir, vocab_size, seqlen, part, **unused_kwargs): 16 | super(NLDataset, self).__init__() 17 | 18 | self.seqlen = seqlen 19 | self.part = part 20 | 21 | self.vocab = (common.unpickle(os.path.join(data_dir, 'vocab.pkl')) 22 | .add_extra_vocab(EXTRA_VOCAB) 23 | .truncate(vocab_size).set_unk_tok(UNK)) 24 | 25 | qs = common.unpickle(os.path.join(data_dir, part + '.pkl')) 26 | self.qtoks = [] 27 | for q in qs: 28 | qtoks = q.split(' ') 29 | if len(qtoks) >= self.seqlen: 30 | continue 31 | pct_unk = sum(qtok not in self.vocab.w2i 32 | for qtok in qtoks) / len(qtoks) 33 | if pct_unk > 0.1: 34 | continue 35 | self.qtoks.append(qtoks) 36 | 37 | def __getitem__(self, index): 38 | toks = self.qtoks[index] 39 | qtoks = torch.LongTensor(self.seqlen + 1).zero_() 40 | qtoks[0] = self.vocab[BOS] 41 | for i, tok in enumerate(toks, 1): 42 | qtoks[i] = self.vocab[tok] 43 | qtoks[len(toks)] = self.vocab[EOS] # replaces final punct with 44 | 45 | return qtoks, 1 46 | 47 | def __len__(self): 48 | return len(self.qtoks) 49 | 50 | def decode(self, toks_vec): 51 | """Turns a vector of token indices into a string.""" 52 | toks = [] 53 | for idx in toks_vec: 54 | toks.append(self.vocab[idx]) 55 | if idx == 0 or idx == self.vocab[EOS]: 56 | break 57 | return ' '.join(toks) 58 | 59 | 60 | def create(*args, **kwargs): 61 | """Returns a NLDataset.""" 62 | return NLDataset(*args, **kwargs) 63 | 64 | 65 | def test_dataset(): 66 | """Tests the NLDataset.""" 67 | 68 | # pylint: 
disable=unused-variable 69 | data_dir = os.path.join(os.path.dirname(__file__), '..', 'data', 'qa') 70 | part = 'test' 71 | vocab_size = 25000 72 | seqlen = 21 73 | debug = True 74 | 75 | ds = NLDataset(**locals()) 76 | rp = torch.randperm(len(ds)) 77 | toks, labels = ds[rp[0]] 78 | print(toks) 79 | print(ds.decode(toks)) 80 | 81 | for i in rp: 82 | toks, labels = ds[i] 83 | assert (toks >= 0).all() and (toks < vocab_size).all() 84 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | """Utility functions for training neural networks.""" 2 | 3 | import pickle 4 | from contextlib import contextmanager 5 | 6 | 7 | RUN_DIR = 'run' 8 | LOG_FILE = 'log.txt' 9 | OPTS_FILE = 'opts.pkl' 10 | STATE_FILE = 'state.pth' 11 | 12 | PHASES = ('hasher', 'g_ml', 'd_ml', 'adv') 13 | HASHER, G_ML, D_ML, ADV = PHASES 14 | LABEL_GEN, LABEL_REAL = 0, 1 15 | 16 | EXTRA_VOCAB = ['PAD', 'UNK', '', ''] 17 | PAD, UNK, BOS, EOS = EXTRA_VOCAB 18 | 19 | 20 | @contextmanager 21 | def rand_state(th, rand_state): 22 | """Pushes and pops a random state. 23 | th: torch or torch.cuda 24 | rand_state: an integer or tensor returned by `get_rng_state` 25 | """ 26 | orig_rand_state = th.get_rng_state() 27 | if isinstance(rand_state, int): 28 | th.manual_seed(rand_state) # this is a slow operation! 29 | rand_state = th.get_rng_state() 30 | th.set_rng_state(rand_state) 31 | yield rand_state 32 | th.set_rng_state(orig_rand_state) 33 | 34 | 35 | def unpickle(path_pkl): 36 | """Loads the contents of a pickle file.""" 37 | with open(path_pkl, 'rb') as f_pkl: 38 | return pickle.load(f_pkl) 39 | 40 | 41 | def load_txt(path_txt): 42 | """Loads a text file.""" 43 | with open(path_txt) as f_txt: 44 | return [line.rstrip() for line in f_txt] 45 | 46 | 47 | class Vocab(object): 48 | """Represents a token2index and index2token map.""" 49 | 50 | def __init__(self, tok_counts, unk_tok=None): 51 | """Constructs a Vocab ADT.""" 52 | self.tok_counts = tok_counts 53 | self.w2i = {w: i for i, (w, _) in enumerate(self.tok_counts)} 54 | 55 | self.unk_tok = unk_tok 56 | if unk_tok is not None: 57 | assert unk_tok in self.w2i 58 | self.unk_idx = self.w2i[unk_tok] 59 | 60 | def __getitem__(self, index): 61 | if isinstance(index, int): 62 | if index < len(self): 63 | return self.tok_counts[index][0] 64 | elif self.unk_tok: 65 | return self.unk_idx 66 | else: 67 | raise IndexError(f'No token in position {index}!') 68 | elif isinstance(index, str): 69 | if index in self.w2i: 70 | return self.w2i[index] 71 | elif self.unk_tok: 72 | return self.unk_idx 73 | else: 74 | raise KeyError(f'{index} not in vocab!') 75 | else: 76 | raise ValueError('Index to Vocab must be string or int.') 77 | 78 | def add_extra_vocab(self, extra_vocab): 79 | """Returns a new Vocab with extra tokens prepended.""" 80 | extra_tok_counts = [(w, float('inf')) for w in extra_vocab] 81 | return Vocab(extra_tok_counts + self.tok_counts, 82 | unk_tok=self.unk_tok) 83 | 84 | def set_unk_tok(self, unk_tok): 85 | """Sets the token/index to return when looking up an OOV token.""" 86 | return Vocab(self.tok_counts, unk_tok=unk_tok) 87 | 88 | def truncate(self, size): 89 | """Returns a new Vocab containing the top `size` tokens.""" 90 | return Vocab(self.tok_counts[:size], unk_tok=self.unk_tok) 91 | 92 | def __len__(self): 93 | return len(self.tok_counts) 94 | -------------------------------------------------------------------------------- /environ/real.py: 
-------------------------------------------------------------------------------- 1 | """An Environment for use with a natural language dataset.""" 2 | 3 | import os 4 | import logging 5 | 6 | import torch 7 | 8 | import common 9 | import dataset 10 | from .environment import Environment 11 | 12 | class NLEnvironment(Environment): 13 | """Functions for training a model on the NL dataset.""" 14 | 15 | _EVAL_METRIC = 'val' 16 | 17 | @classmethod 18 | def get_opt_parser(cls): 19 | """Returns an `ArgumentParser` that parses env-specific opts.""" 20 | parser = super(NLEnvironment, cls).get_opt_parser() 21 | parser.add_argument( 22 | '--data-dir', default='data/qa', type=os.path.abspath) 23 | parser.set_defaults( 24 | seqlen=22, 25 | vocab_size=20000, 26 | g_tok_emb_dim=32, 27 | d_tok_emb_dim=32, 28 | rnn_dim=64, 29 | num_gen_layers=2, 30 | pretrain_g_epochs=10, 31 | pretrain_d_epochs=10, 32 | train_hasher_epochs=15, 33 | adv_train_iters=750, 34 | code_len=11, 35 | dropout=0.25, 36 | batch_size=256, 37 | lr_g=0.001, 38 | lr_d=0.001, 39 | lr_hasher=0.002, 40 | hasher_ent_reg=0.3, 41 | log_freq=20, 42 | ) 43 | return parser 44 | 45 | def __init__(self, opts): 46 | """Creates a NLEnvironment.""" 47 | if opts.load_w2v: 48 | w2v = torch.from_numpy(torch.np.load(opts.load_w2v)) 49 | w2v[0] = 0 50 | opts.g_tok_emb_dim = opts.d_tok_emb_dim = w2v.shape[1] 51 | 52 | super(NLEnvironment, self).__init__(opts) 53 | 54 | if opts.load_w2v: 55 | def _grad_mask(grad): 56 | masked_grad = grad.clone() 57 | masked_grad[len(common.EXTRA_VOCAB):] = 0 58 | return masked_grad 59 | 60 | def _set_w2v(emb): 61 | tok_embs = emb.weight 62 | tok_embs.data.copy_(w2v) 63 | tok_embs.register_hook(_grad_mask) 64 | 65 | for net in (self.g, self.d): 66 | _set_w2v(net.tok_emb) 67 | if opts.exploration_bonus: 68 | _set_w2v(self.hasher.tok_emb) 69 | 70 | self.train_dataset = dataset.NLDataset(part='train', **vars(opts)) 71 | self.test_dataset = dataset.NLDataset(part='val', **vars(opts)) 72 | 73 | self.ro_init_toks.data.fill_(self.train_dataset.vocab[common.BOS]) 74 | self.opts.padding_idx = self.train_dataset.vocab[common.PAD] 75 | self.opts.eos_idx = self.train_dataset.vocab[common.EOS] 76 | 77 | def _compute_eval_metric(self): 78 | test_loader = self._create_dataloader(self.test_dataset) 79 | val_loss = sum(self._forward_g_ml(batch, volatile=True)[0].data[0] 80 | for batch in test_loader) / len(test_loader) 81 | 82 | init_toks_volatile = self.init_toks.volatile 83 | self.init_toks.volatile = True 84 | gen_toks, _ = self.g.rollout(self.init_toks[:5], self.opts.seqlen) 85 | for tok_vec in torch.cat(gen_toks, -1).data: 86 | logging.debug(self.test_dataset.decode(tok_vec)) 87 | logging.debug('\n---') 88 | self.init_toks.volatile = init_toks_volatile 89 | 90 | return val_loss 91 | -------------------------------------------------------------------------------- /scripts/format_embs_projector.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import random 5 | import sys 6 | 7 | import torch 8 | import numpy as np 9 | 10 | 11 | PROJ_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 12 | DATA_DIR = os.path.join(PROJ_ROOT, 'data') 13 | RUN_DIR = os.path.join(PROJ_ROOT, 'run') 14 | 15 | sys.path.insert(0, PROJ_ROOT) 16 | import environ 17 | import common 18 | 19 | 20 | def _get_metadata(embs_path): 21 | embs_name = os.path.splitext(os.path.basename(embs_path))[0] 22 | embs_name_parts = embs_name.split('_') 23 | questions_name = 
'_'.join(embs_name_parts[1:-1]) 24 | part = embs_name_parts[-1] 25 | questions = common.unpickle( 26 | os.path.join(DATA_DIR, questions_name, f'{part}.pkl')) 27 | return { 28 | 'questions': questions, 29 | 'dataset': [questions_name] * len(questions), 30 | } 31 | 32 | 33 | def main(): 34 | # -------------------------------------------------------------------------- 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--phase', default='g_ml') 37 | parser.add_argument('--subsample', default=1, type=int) 38 | parser.add_argument('--out-dir', default='wembs_projector') 39 | args = parser.parse_args() 40 | # -------------------------------------------------------------------------- 41 | 42 | opts = argparse.Namespace(**common.unpickle(os.path.join(RUN_DIR, 'opts.pkl'))) 43 | env = environ.create(opts.env, opts) 44 | env.state = torch.load(os.path.join(RUN_DIR, 'g_ml', 'state.pth')) 45 | 46 | wembs = env.g.tok_emb.weight.data.cpu().numpy() 47 | vocab = [tok for tok, _ in env.train_dataset.vocab.tok_counts] 48 | # wembs = np.load('../data/tok_vecs_pruned.npy') 49 | # vocab = common.EXTRA_VOCAB + [t for t, _ in common.unpickle('../data/qa/vocab.pkl').tok_counts] 50 | # vocab = vocab[:len(wembs)] 51 | # print(vocab) 52 | 53 | if args.subsample: 54 | wembs = wembs[::args.subsample] 55 | vocab = vocab[::args.subsample] 56 | 57 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 58 | 59 | import tensorflow as tf 60 | from tensorflow.contrib.tensorboard.plugins import projector 61 | 62 | out_dir = os.path.join(RUN_DIR, args.out_dir) 63 | if not os.path.isdir(out_dir): 64 | os.makedirs(out_dir) 65 | 66 | metadata_path = os.path.join(out_dir, 'metadata.tsv') 67 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.0) 68 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess, tf.device("/cpu:0"): 69 | with open(metadata_path, 'w') as f_md: 70 | for tok in vocab: 71 | print(tok, file=f_md) 72 | 73 | embs_var = tf.Variable(wembs, trainable=False, name='tok_embs') 74 | sess.run(tf.global_variables_initializer()) 75 | print(sess.run(embs_var[0]) - wembs[0]) 76 | 77 | saver = tf.train.Saver() 78 | saver.save(sess, os.path.join(out_dir, 'model.ckpt'), global_step=42) 79 | 80 | projector_config = projector.ProjectorConfig() 81 | projector_config.embeddings.add(tensor_name=embs_var.name, 82 | metadata_path=metadata_path) 83 | 84 | summary_writer = tf.summary.FileWriter(out_dir, sess.graph) 85 | projector.visualize_embeddings(summary_writer, projector_config) 86 | 87 | 88 | if __name__ == '__main__': 89 | main() 90 | -------------------------------------------------------------------------------- /scripts/prep_wiki_dataset.py: -------------------------------------------------------------------------------- 1 | """Formats a Wikipedia dump dataset.""" 2 | 3 | from collections import Counter 4 | import argparse 5 | import gzip 6 | import json 7 | import os 8 | import pickle 9 | import random 10 | import sys 11 | 12 | from tqdm import tqdm 13 | import spacy 14 | 15 | PROJ_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 16 | sys.path.append(PROJ_ROOT) 17 | 18 | import common 19 | 20 | DATA_DIR = os.path.join(PROJ_ROOT, 'data') 21 | CACHE_DIR = os.path.join(DATA_DIR, '_cache', 'wiki') 22 | OUT_DIR = os.path.join(DATA_DIR, 'wiki') 23 | WIKI_PATH = os.path.join(CACHE_DIR, 'simplewiki-20171103.json.gz') 24 | 25 | 26 | def _unpickle(pkl_path): 27 | with open(pkl_path, 'rb') as f: 28 | return pickle.load(f) 29 | 30 | 31 | def _pickle(data, pkl_path): 32 | with 
open(pkl_path, 'wb') as f: 33 | pickle.dump(data, f) 34 | 35 | 36 | def _load(): 37 | paragraphs = [] 38 | with gzip.open(WIKI_PATH, 'rt') as f_wiki: 39 | for line in f_wiki: 40 | text = json.loads(line)['text'] 41 | paragraphs.extend(filter(bool, text.split('\n'))) 42 | return paragraphs 43 | 44 | 45 | def _tokenize(paragraphs): 46 | nlp = spacy.load('en') 47 | tok_sents = [] 48 | for i, para in enumerate(nlp.pipe(tqdm(paragraphs, desc='tokenize', leave=False))): 49 | for sent in para.sents: 50 | if not sent[-1].is_punct: 51 | continue 52 | tok_sents.append(tuple(tok.text for tok in sent if tok)) 53 | return tok_sents 54 | 55 | 56 | def _concatenate(tok_sents): 57 | return tuple(map(' '.join, tok_sents)) 58 | 59 | 60 | def _run_pipeline(pipeline, cache_dir): 61 | for i, stage in enumerate(pipeline): 62 | cache_path = os.path.join(cache_dir, stage.__name__[1:]) + '.pkl' 63 | if os.path.isfile(cache_path): 64 | output = _unpickle(cache_path) 65 | else: 66 | output = stage(output) if i > 0 else stage() 67 | _pickle(output, cache_path) 68 | return output 69 | 70 | 71 | def main(): 72 | """Runs a pipeline that formats the Wikipedia dump dataset.""" 73 | # -------------------------------------------------------------------------- 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('--train-frac', type=float, default=0.8) 76 | parser.add_argument('--cased', action='store_true') 77 | parser.add_argument('--seed', type=int, default=42) 78 | args = parser.parse_args() 79 | # -------------------------------------------------------------------------- 80 | 81 | for d in (CACHE_DIR, OUT_DIR): 82 | if not os.path.isdir(d): 83 | os.makedirs(d) 84 | 85 | sents = _run_pipeline([_load, _tokenize, _concatenate], CACHE_DIR) 86 | 87 | random.seed(args.seed) 88 | if not args.cased: 89 | sents = list(map(str.lower, sents)) 90 | random.shuffle(sents) 91 | 92 | n_val = n_test = int(len(sents) * (1 - args.train_frac) / 2) 93 | n_train = len(sents) - n_val - n_test 94 | part_bounds = [None, n_train, n_train + n_val, None] 95 | 96 | for i, part in enumerate(('train', 'val', 'test')): 97 | part_sents = sents[slice(*part_bounds[i:i+2])] 98 | 99 | if part == 'train': 100 | vocab = Counter(tok for q in part_sents for tok in q.split(' ')) 101 | _pickle(common.Vocab(vocab.most_common()), 102 | os.path.join(OUT_DIR, 'vocab.pkl')) 103 | 104 | _pickle(part_sents, os.path.join(OUT_DIR, part + '.pkl')) 105 | 106 | 107 | if __name__ == '__main__': 108 | main() 109 | -------------------------------------------------------------------------------- /model/hash_counter.py: -------------------------------------------------------------------------------- 1 | """A Module that uses locality sensitve hashing to count visited states.""" 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | 7 | 8 | class HashCounter(nn.Module): 9 | """Accumulates counts of items that hash to a particular value.""" 10 | 11 | def __init__(self, hash_fn, num_hash_buckets, **unused_kwargs): 12 | super(HashCounter, self).__init__() 13 | 14 | self.code_len = int(torch.np.ceil(torch.np.log2(num_hash_buckets))) 15 | self.hash_fn = hash_fn 16 | 17 | self.register_buffer('counts', torch.zeros(num_hash_buckets)) 18 | 19 | self.register_buffer('_ones', torch.ones(1)) 20 | self.register_buffer('_powers_of_two', 21 | 2**torch.arange(self.code_len-1, -1, -1)) 22 | 23 | def forward(self, items, accumulator='counts', **unused_kwargs): 24 | """ Accumulates hashed item counts. 
25 | 26 | items: N*(size of single item accepted by hash_fn) 27 | accumulator: the name of an accumulator 28 | 29 | Returns: A LongTensor of hash_bucket indices: N 30 | """ 31 | ones = self._ones.expand(items.size(0)) 32 | acc = self._buffers.get(accumulator, torch.zeros_like(self.counts)) 33 | if accumulator not in self._buffers: 34 | self.register_buffer(accumulator, acc) 35 | 36 | hash_codes = self.hash_fn(items).data # N*code_len 37 | hash_buckets = (hash_codes @ self._powers_of_two).long() # N 38 | acc.put_(hash_buckets, ones, accumulate=True) 39 | 40 | return hash_buckets 41 | 42 | 43 | def create(hash_fn, **opts): 44 | """Creates a token generator.""" 45 | return HashCounter(hash_fn, **opts) 46 | 47 | 48 | def test_simhash_table(): 49 | """Tests the HashCounter.""" 50 | # pylint: disable=too-many-locals,unused-variable 51 | 52 | num_hash_buckets = 4 53 | debug = True 54 | 55 | class HashFn(object): 56 | """A mock hash function. Big-endian.""" 57 | codes = None 58 | buckets = None 59 | 60 | @staticmethod 61 | def _i2b(i): 62 | bitwidth = int(torch.np.log2(num_hash_buckets)) 63 | bin_rep = list(map(int, bin(i)[2:])) 64 | return [0]*(bitwidth - len(bin_rep)) + bin_rep 65 | 66 | def set_codes(self, bin_counts): 67 | """Sets the big-endian binary codes that the HashFn will return.""" 68 | codes = [] 69 | buckets = [] 70 | for i, count in enumerate(bin_counts): 71 | codes.extend([self._i2b(i)]*count) 72 | buckets.extend([i]*count) 73 | rp = torch.randperm(len(codes)) 74 | self.codes = torch.FloatTensor(codes)[rp] 75 | self.buckets = torch.LongTensor(buckets)[rp] 76 | 77 | def __call__(self, _): 78 | return Variable(self.codes) 79 | 80 | hash_fn = HashFn() 81 | simhash_table = HashCounter(**locals()) 82 | 83 | expected_counts_train = [1, 2, 0, 4] 84 | hash_fn.set_codes(expected_counts_train) 85 | toks = Variable(torch.LongTensor(sum(expected_counts_train), 4)) 86 | 87 | assert (simhash_table(toks, 'counts2') == hash_fn.buckets).all() 88 | assert (simhash_table.counts2.numpy() == expected_counts_train).all() 89 | assert (simhash_table.counts == 0).all() 90 | 91 | expected_counts_test = [4, 3, 2, 1] 92 | hash_fn.set_codes(expected_counts_test) 93 | toks = Variable(torch.LongTensor(sum(expected_counts_test), 4)) 94 | 95 | assert (simhash_table(toks) == hash_fn.buckets).all() 96 | assert (simhash_table.counts2.numpy() == expected_counts_train).all() 97 | assert (simhash_table.counts.numpy() == expected_counts_test).all() 98 | -------------------------------------------------------------------------------- /scripts/prep_qa_dataset.py: -------------------------------------------------------------------------------- 1 | """Formats the Yahoo Answers dataset.""" 2 | 3 | from collections import Counter 4 | import argparse 5 | import os 6 | import random 7 | import re 8 | import sys 9 | import pickle 10 | 11 | from tqdm import tqdm 12 | from lxml import etree 13 | import spacy 14 | 15 | PROJ_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 16 | sys.path.append(PROJ_ROOT) 17 | 18 | import common 19 | 20 | DATA_DIR = os.path.join(PROJ_ROOT, 'data') 21 | CACHE_DIR = os.path.join(DATA_DIR, '_cache', 'qa') 22 | OUT_DIR = os.path.join(DATA_DIR, 'qa') 23 | QS_PATH = os.path.join(DATA_DIR, 'FullOct2007.xml.part{}.gz') 24 | 25 | 26 | QUESTION_WORDS = { 27 | 'any', 'are', 'can', 'could', 'did', 'do', 'does', 'has', 'have', 'how', 28 | 'is', 'must', 'should', 'was', 'what', 'when', 'where', 'which', 'who', 29 | 'whom', 'whose', 'why', 'will', 'would'} 30 | 31 | 32 | def _unpickle(pkl_path): 
33 | with open(pkl_path, 'rb') as f: 34 | return pickle.load(f) 35 | 36 | 37 | def _pickle(data, pkl_path): 38 | with open(pkl_path, 'wb') as f: 39 | pickle.dump(data, f) 40 | 41 | 42 | def _load(): 43 | qs = [] 44 | for part in range(1, 3): 45 | try: 46 | data_path = QS_PATH.format(part) 47 | for _, subj in etree.iterparse(data_path, tag='subject'): 48 | qs.append(subj.text) 49 | except etree.XMLSyntaxError: 50 | pass 51 | return qs 52 | 53 | 54 | def _tokenize(qs): 55 | nlp = spacy.load('en', disable=['ner']) 56 | qtoks = [] 57 | for q in nlp.pipe(tqdm(qs, desc='tokenize', leave=False)): 58 | for sent in q.sents: 59 | sent_toks = [tok.text for tok in sent if not tok.is_space] 60 | if sent_toks: 61 | qtoks.append(sent_toks) 62 | return qtoks 63 | 64 | 65 | def _filter(qs): 66 | good_qs = [] 67 | for q in qs: 68 | if not q[0] in QUESTION_WORDS or q[-1] != '?': 69 | continue 70 | qcat = ' '.join(q) 71 | qcat = re.sub('( [?!.]+)+ \?$', ' ?', qcat) 72 | good_qs.append(qcat) 73 | return good_qs 74 | 75 | 76 | def _run_pipeline(pipeline, cache_dir): 77 | for i, stage in enumerate(pipeline): 78 | cache_path = os.path.join(cache_dir, stage.__name__[1:]) + '.pkl' 79 | if os.path.isfile(cache_path): 80 | output = _unpickle(cache_path) 81 | else: 82 | output = stage(output) if i > 0 else stage() 83 | _pickle(output, cache_path) 84 | return output 85 | 86 | 87 | def main(): 88 | """Runs a pipeline that formats the Yahoo Answers dataset.""" 89 | # -------------------------------------------------------------------------- 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument('--train-frac', type=float, default=0.8) 92 | parser.add_argument('--cased', action='store_true') 93 | parser.add_argument('--seed', type=int, default=42) 94 | args = parser.parse_args() 95 | # -------------------------------------------------------------------------- 96 | 97 | for d in (CACHE_DIR, OUT_DIR): 98 | if not os.path.isdir(d): 99 | os.makedirs(d) 100 | 101 | qs = _run_pipeline([_load, _tokenize, _filter], CACHE_DIR) 102 | 103 | random.seed(args.seed) 104 | if not args.cased: 105 | qs = [q.lower() for q in qs] 106 | random.shuffle(qs) 107 | 108 | n_val = n_test = int(len(qs) * (1 - args.train_frac) / 2) 109 | n_train = len(qs) - n_val - n_test 110 | part_bounds = [None, n_train, n_train + n_val, None] 111 | 112 | for i, part in enumerate(('train', 'val', 'test')): 113 | part_qs = qs[slice(*part_bounds[i:i+2])] 114 | 115 | if part == 'train': 116 | vocab = Counter(tok for q in part_qs for tok in q.split(' ')) 117 | _pickle(common.Vocab(vocab.most_common()), 118 | os.path.join(OUT_DIR, 'vocab.pkl')) 119 | 120 | _pickle(part_qs, os.path.join(OUT_DIR, part + '.pkl')) 121 | 122 | 123 | if __name__ == '__main__': 124 | main() 125 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """Main training script.""" 2 | 3 | from contextlib import contextmanager 4 | import argparse 5 | import os 6 | import pickle 7 | import logging 8 | 9 | import torch 10 | 11 | import common 12 | from common import PHASES, HASHER, G_ML, D_ML, ADV 13 | from common import RUN_DIR, STATE_FILE, OPTS_FILE, LOG_FILE 14 | import environ 15 | 16 | 17 | def main(): 18 | """Trains the model.""" 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--env', choices=environ.ENVS, default=environ.SYNTH) 22 | parser.add_argument('--resume', action='store_true') 23 | parser.add_argument('--seed', default=42, 
type=int) 24 | parser.add_argument('--prefix') 25 | parser.add_argument('--rerun', nargs='+', default=[], choices=PHASES) 26 | init_opts, remaining_opts = parser.parse_known_args() 27 | 28 | opts_file = os.path.join(RUN_DIR, OPTS_FILE) 29 | if init_opts.resume: 30 | new_opts = environ.parse_env_opts( 31 | init_opts, remaining_opts, no_defaults=True) 32 | opts = argparse.Namespace(**common.unpickle(opts_file)) 33 | for k, v in vars(new_opts).items(): 34 | if k not in opts or v is not None: 35 | setattr(opts, k, v) 36 | else: 37 | opts = environ.parse_env_opts(init_opts, remaining_opts) 38 | os.mkdir(RUN_DIR) 39 | with open(opts_file, 'wb') as f_opts: 40 | pickle.dump(vars(opts), f_opts) 41 | 42 | logging.basicConfig(format='%(message)s', level=logging.DEBUG, filemode='w') 43 | 44 | torch.manual_seed(opts.seed) 45 | torch.cuda.manual_seed_all(opts.seed) 46 | 47 | env = environ.create(opts.env, opts) 48 | 49 | for phase in PHASES: 50 | if phase == HASHER and not opts.exploration_bonus: 51 | continue 52 | torch.manual_seed(opts.seed) 53 | torch.cuda.manual_seed_all(opts.seed) 54 | with _phase(env, phase, opts) as phase_runner: 55 | if phase_runner: 56 | logging.debug(f'# running phase: {phase}') 57 | phase_runner() # pylint: disable=not-callable 58 | 59 | 60 | @contextmanager 61 | def _phase(env, phase, opts): 62 | phase_dir = os.path.join(RUN_DIR, phase) 63 | if not os.path.isdir(phase_dir): 64 | os.mkdir(phase_dir) 65 | 66 | prefixes = [opts.prefix]*bool(opts.prefix) 67 | def _prefix(suffixes): 68 | suffixes = suffixes if isinstance(suffixes, list) else [suffixes] 69 | return '_'.join(prefixes + suffixes) 70 | 71 | snap_file = os.path.join(phase_dir, STATE_FILE) 72 | prefix_snap_file = os.path.join(phase_dir, _prefix(STATE_FILE)) 73 | if os.path.isfile(prefix_snap_file): 74 | snap_file = prefix_snap_file 75 | 76 | if os.path.isfile(snap_file) and phase not in opts.rerun: 77 | env.state = torch.load(snap_file) 78 | yield None 79 | return 80 | 81 | if phase == HASHER: 82 | # import functools 83 | # def _saver(env, epoch): 84 | # torch.save(env.hasher.state_dict(), 85 | # os.path.join(phase_dir, f'{epoch}.pth')) 86 | # runner = functools.partial(env.train_hasher, hook=_saver) 87 | runner = env.train_hasher 88 | elif phase == G_ML: 89 | runner = env.pretrain_g 90 | elif phase == D_ML: 91 | runner = env.pretrain_d 92 | elif phase == ADV: 93 | runner = env.train_adv 94 | 95 | logger = logging.getLogger() 96 | def _add_file_handler(lvl, log_prefix=None): 97 | suffixes = [log_prefix]*bool(log_prefix) + [LOG_FILE] 98 | log_path = os.path.join(phase_dir, _prefix(suffixes)) 99 | handler = logging.FileHandler(log_path, mode='w') 100 | handler.setLevel(lvl) 101 | logger.addHandler(handler) 102 | return handler 103 | 104 | file_handlers = [ 105 | _add_file_handler(logging.INFO), 106 | _add_file_handler(logging.DEBUG, 'debug'), 107 | ] 108 | 109 | yield runner 110 | 111 | torch.save(env.state, prefix_snap_file) 112 | for handler in file_handlers: 113 | logger.removeHandler(handler) 114 | 115 | 116 | if __name__ == '__main__': 117 | main() 118 | -------------------------------------------------------------------------------- /environ/synth.py: -------------------------------------------------------------------------------- 1 | """A class for training a SeqGAN model on synthetic data.""" 2 | 3 | import logging 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as nnf 8 | 9 | import common 10 | from common import LABEL_REAL 11 | import model 12 | 13 | from .environment import 
Environment 14 | 15 | class SynthEnvironment(Environment): 16 | """Functions for training a model on a synthetic dataset.""" 17 | 18 | _EVAL_METRIC = 'nll' 19 | 20 | @classmethod 21 | def get_opt_parser(cls): 22 | """Returns an `ArgumentParser` that parses env-specific opts.""" 23 | parser = super(SynthEnvironment, cls).get_opt_parser() 24 | parser.add_argument( 25 | '--oracle-type', default=model.generator.RNN, 26 | choices=model.generator.TYPES) 27 | parser.add_argument('--oracle-dim', default=128, type=int) 28 | parser.add_argument('--num-gen-samps', default=100000, type=int) 29 | parser.set_defaults( 30 | seqlen=20, 31 | vocab_size=5000, 32 | g_tok_emb_dim=32, 33 | d_tok_emb_dim=32, 34 | pretrain_g_epochs=50, # try 20 when using pretrained w2v 35 | pretrain_d_epochs=10, 36 | train_hasher_epochs=25, 37 | adv_train_iters=750, 38 | rnn_dim=32, 39 | code_len=6, 40 | dropout=0.25, 41 | num_gen_layers=1, 42 | batch_size=64, 43 | lr_g=0.01, 44 | lr_d=0.001, 45 | lr_hasher=0.002, 46 | ) 47 | return parser 48 | 49 | def __init__(self, opts): 50 | """Creates a SynthEnvironment.""" 51 | super(SynthEnvironment, self).__init__(opts) 52 | 53 | self.ro_init_toks.data.zero_() 54 | self.opts.padding_idx = self.opts.eos_idx = -1 55 | 56 | self.oracle = self._create_oracle().cuda() 57 | oracle_checksum = sum(p.data.sum() for p in self.oracle.parameters()) 58 | logging.debug(f'#oracle: {oracle_checksum:.3f}') 59 | 60 | self.train_dataset = self._create_gen_dataset( 61 | self.oracle, LABEL_REAL, num_samples=opts.num_gen_samps) 62 | self.test_dataset = self._create_gen_dataset( 63 | self.oracle, LABEL_REAL, 64 | num_samples=len(self.ro_init_toks)*5, seed=-1) 65 | 66 | if self.opts.load_w2v: 67 | oracle_w2v = model.utils.Apply(self.oracle.tok_emb, detach=True) 68 | for net in (self.g, self.d): 69 | net.tok_emb = oracle_w2v 70 | if opts.exploration_bonus: 71 | self.hasher.encoder.tok_emb = oracle_w2v 72 | 73 | def _create_oracle(self): 74 | """Returns a randomly initialized generator.""" 75 | with common.rand_state(torch, self.opts.seed): 76 | opt_vars = vars(self.opts) 77 | opt_vars.pop('rnn_dim') 78 | oracle = model.generator.create( 79 | gen_type=self.opts.oracle_type, 80 | rnn_dim=self.opts.oracle_dim, 81 | **opt_vars) 82 | for param in oracle.parameters(): 83 | nn.init.normal(param, std=1) 84 | return oracle 85 | 86 | def _compute_eval_metric(self, num_samples=256): 87 | test_nll = 0 88 | num_test_batches = max(num_samples // len(self.init_toks), 1) 89 | with common.rand_state(torch.cuda, -1): 90 | for _ in range(num_test_batches): 91 | gen_seqs, _ = self.g.rollout(self.init_toks, self.opts.seqlen) 92 | test_nll += self.compute_oracle_nll(gen_seqs) 93 | test_nll /= num_test_batches 94 | return test_nll 95 | 96 | def compute_oracle_nll(self, toks, return_probs=False): 97 | """ 98 | toks: [N]*T 99 | """ 100 | toks = torch.cat([self.init_toks] + toks).view(len(toks)+1, -1) 101 | log_probs = self.oracle(toks.t())[0][:-1] # T*N*V 102 | flat_log_probs = log_probs.view(-1, log_probs.size(-1)) # (T*N)*V 103 | nll = nnf.nll_loss(flat_log_probs, toks[1:].view(-1)).data[0] 104 | if return_probs: 105 | return nll, log_probs 106 | return nll 107 | -------------------------------------------------------------------------------- /notebooks/utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for analyzing results in notebooks.""" 2 | 3 | from collections import defaultdict 4 | import re 5 | 6 | import torch 7 | import numpy as np 8 | import pandas as pd 9 | import 
matplotlib.pyplot as plt 10 | 11 | 12 | FG = r'(\d+\.\d+)' # Float Group 13 | LOG_RES = { 14 | 'loss': { 15 | 'expr': rf'loss: train={FG} test={FG}', 16 | 'groups': ['train', 'test'], 17 | }, 18 | 'iter': r'\[(\d+)]', 19 | 'nll': rf'nll: {FG}', 20 | 'acc': { 21 | 'expr': rf'acc: o={FG} g={FG}', 22 | 'groups': ['o', 'g'], 23 | }, 24 | 'gnorm': { 25 | 'expr': rf'gnorm: g={FG} d={FG}', 26 | 'groups': ['g', 'd'], 27 | }, 28 | } 29 | 30 | 31 | def rolling_window_lastaxis(a, window): 32 | """Directly taken from Erik Rigtorp's post to numpy-discussion. 33 | """ 34 | if window < 1: 35 | raise ValueError("`window` must be at least 1.") 36 | if window > a.shape[-1]: 37 | raise ValueError("`window` is too long.") 38 | shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) 39 | strides = a.strides + (a.strides[-1],) 40 | return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) 41 | 42 | 43 | def ngram_count(seqs, n): 44 | """Returns a dict of {ngram: counts} for a numpy array of tokens.""" 45 | rw = rolling_window_lastaxis(seqs, n).reshape(-1, n) 46 | ngrams, counts = np.unique(rw, axis=0, return_counts=True) 47 | counts = counts / counts.sum() 48 | return {tuple(ngram): count for ngram, count in zip(ngrams, counts)} 49 | 50 | 51 | def sample_gen(env, num_samps=10000, gen='g', temperature=1, 52 | return_probs=False): 53 | """Samples from a generative model.""" 54 | samps = [] 55 | probs = [] 56 | gen = getattr(env, gen) 57 | num_batches = (num_samps + len(env.ro_init_toks)) // len(env.ro_init_toks) 58 | for _ in range(num_batches): 59 | ro, ro_probs = gen.rollout(env.ro_init_toks, 20, 60 | temperature=temperature) 61 | samps.append(torch.cat(ro, -1).data.cpu()) 62 | if return_probs: 63 | probs.append(torch.stack(ro_probs).data.cpu()) 64 | 65 | samps = torch.cat(samps, 0).numpy() 66 | if return_probs: 67 | return samps, torch.cat(probs, 0).numpy() 68 | return samps 69 | 70 | 71 | def load_log(log_path): 72 | """Loads a log file.""" 73 | 74 | if not LOG_RES.get('_compiled'): 75 | for stat, spec in LOG_RES.items(): 76 | if isinstance(spec, dict): 77 | spec['expr'] = re.compile(spec['expr']) 78 | else: 79 | LOG_RES[stat] = re.compile(spec) 80 | LOG_RES['_compiled'] = True 81 | 82 | cols = defaultdict(list) 83 | with open(log_path) as f_log: 84 | for line in f_log: 85 | if line.startswith('#'): 86 | continue 87 | for stat, spec in sorted(LOG_RES.items()): 88 | if stat.startswith('_'): 89 | continue 90 | 91 | if isinstance(spec, dict): 92 | matcher = spec['expr'] 93 | colnames = [f'{stat}_{substat}' 94 | for substat in spec['groups']] 95 | else: 96 | matcher = spec 97 | colnames = [stat] 98 | 99 | match = matcher.search(line) 100 | if not match: 101 | continue 102 | 103 | for colname, val in zip(colnames, match.groups()): 104 | cols[colname].append(val) 105 | log_df = pd.DataFrame.from_dict(cols) 106 | log_df = log_df.astype(float) 107 | log_df.index = log_df.iter.astype(int) 108 | return log_df 109 | 110 | 111 | def do_plot(get_data, logs, filt=None, baseline=None): 112 | """Plots data from several logs. 
113 | 114 | Args: 115 | get_data: a function (log_name, log_data) -> plot_data 116 | """ 117 | for exp_name, log in logs.items(): 118 | if filt and (exp_name != baseline and not filt in exp_name): 119 | continue 120 | get_data(exp_name, log).plot(label=exp_name) 121 | plt.legend() 122 | 123 | 124 | def plot_ts(col, logs, *args, **kwargs): 125 | """Plots a column from logs as a time series.""" 126 | do_plot(lambda name, log: log[col], logs, *args, **kwargs) 127 | -------------------------------------------------------------------------------- /dataset/gen.py: -------------------------------------------------------------------------------- 1 | """A Dataset that loads the output of a generative model.""" 2 | import torch 3 | import torch.utils.data 4 | from torch.autograd import Variable 5 | 6 | import common 7 | 8 | 9 | class GenDataset(torch.utils.data.Dataset): 10 | """Loads data from a generative model.""" 11 | 12 | def __init__(self, generator, label, seqlen, num_samples, 13 | gen_init_toks, seed, eos_idx=None, **unused_kwargs): 14 | super(GenDataset, self).__init__() 15 | 16 | self.label = label 17 | 18 | th = torch.cuda if gen_init_toks.is_cuda else torch 19 | with common.rand_state(th, seed): 20 | init_toks = gen_init_toks.data.cpu() 21 | batch_size = gen_init_toks.size(0) 22 | num_batches = (num_samples + batch_size - 1) // batch_size 23 | samples = [] 24 | for _ in range(num_batches): 25 | gen_seqs, _ = generator.rollout(gen_init_toks, seqlen) 26 | samps = torch.cat([gen_init_toks] + gen_seqs, -1).data 27 | if eos_idx: 28 | self.mask_gen_seqs_(samps, eos_idx) 29 | samples.append(samps.cpu()) 30 | self.samples = torch.cat(samples) 31 | 32 | @staticmethod 33 | def mask_gen_seqs_(seqs, eos_idx): 34 | """ 35 | Zeroes out all entries after the first occurrence of EOS. 36 | Operates in-place on seqs. 37 | 38 | seqs: N*(T+1); must include init toks 39 | eos_idx: the number assigned to the end-of-sentence token 40 | """ 41 | # 1. create a mask of ones up until the first eos token 42 | mask = (seqs != eos_idx).cumprod(-1) 43 | # 2. create a mask for the first the eos token 44 | # this method requires that the init tok exists 45 | eos_pos = (seqs == eos_idx)[:, 1:] * mask[:, :-1] 46 | # 3. zero out everything including+after the first eos tok 47 | seqs.masked_fill_(1 - mask, 0) 48 | # 4. 
put back the eos tok 49 | seqs[:, 1:].masked_fill_(eos_pos, eos_idx) 50 | 51 | 52 | def __getitem__(self, index): 53 | label = self.label 54 | if not isinstance(index, int): 55 | label = torch.LongTensor(len(index)).fill_(self.label) 56 | return self.samples[index], label 57 | 58 | def __len__(self): 59 | return len(self.samples) 60 | 61 | 62 | def test_dataset(): 63 | """Tests the Dataset.""" 64 | import model 65 | 66 | # pylint: disable=unused-variable 67 | vocab_size = 50 68 | batch_size = 32 69 | label = 0 70 | num_samples = 1000 71 | seqlen = 21 72 | seed = 42 73 | 74 | generator = model.generator.RNNGenerator( 75 | vocab_size=50, tok_emb_dim=32, rnn_dim=16, num_layers=1) 76 | gen_init_toks = Variable(torch.LongTensor(batch_size, 1).fill_(1)) 77 | 78 | ds = GenDataset(**locals()) 79 | toks, labels = ds[0] 80 | print(toks) 81 | print(labels) 82 | 83 | assert len(ds) == torch.np.ceil(num_samples / batch_size) * batch_size 84 | 85 | for i in torch.randperm(len(ds)): 86 | toks, labels = ds[i] 87 | assert (toks >= 0).all() and (toks < vocab_size).all() 88 | 89 | batch_toks, batch_labels = ds[torch.randperm(batch_size)] 90 | assert len(batch_toks) == batch_size 91 | assert len(batch_labels) == batch_size 92 | 93 | def test_dataset_mask(): 94 | """Tests the Dataset.""" 95 | import model 96 | 97 | # pylint: disable=unused-variable 98 | vocab_size = 50 99 | batch_size = 5 100 | label = 0 101 | num_samples = batch_size 102 | seqlen = 21 103 | seed = 42 104 | eos_idx = 2 105 | 106 | gen_samps = torch.LongTensor([ 107 | [1, 1, 2, 1, 2, 1], 108 | [1, 1, 1, 2, 1, 1], 109 | [2, 1, 1, 2, 1, 1], 110 | [1, 1, 1, 1, 1, 1], 111 | [1, 1, 1, 1, 1, 2], 112 | ]) 113 | 114 | expected_samps = torch.LongTensor([ 115 | [-1, 1, 1, 2, 0, 0, 0], 116 | [-1, 1, 1, 1, 2, 0, 0], 117 | [-1, 2, 0, 0, 0, 0, 0], 118 | [-1, 1, 1, 1, 1, 1, 1], 119 | [-1, 1, 1, 1, 1, 1, 2], 120 | ]) 121 | 122 | class MockGenerator(object): 123 | @staticmethod 124 | def rollout(*args, **kwargs): 125 | return list(Variable(gen_samps).split(1, dim=1)), None 126 | 127 | generator = MockGenerator() 128 | gen_init_toks = Variable(torch.LongTensor(batch_size, 1).fill_(-1)) 129 | 130 | ds = GenDataset(**locals()) 131 | assert (ds.samples == expected_samps).all() 132 | -------------------------------------------------------------------------------- /model/hasher.py: -------------------------------------------------------------------------------- 1 | """A Hasher for locality sensitive hashing.""" 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as nnf 6 | from torch.autograd import Variable 7 | 8 | import environ 9 | from .bottles import BottledLinear, bottle 10 | from .utils import Apply 11 | 12 | 13 | TYPES = ('ae',) 14 | AE, = TYPES 15 | 16 | 17 | class _RNNEncoder(nn.Module): 18 | """Encodes tokens to binary codes.""" 19 | 20 | def __init__(self, code_len, tok_emb_dim, rnn_dim, num_layers, 21 | **unused_kwargs): 22 | super(_RNNEncoder, self).__init__() 23 | 24 | self.code_len = code_len 25 | 26 | self.enc = nn.LSTM(tok_emb_dim, rnn_dim, num_layers=num_layers, 27 | bidirectional=True) 28 | self.precoder = nn.Linear(rnn_dim*2, code_len*2) 29 | 30 | def forward(self, tok_embs): 31 | """ 32 | tok_embs: T*N*tok_emb_dim 33 | """ 34 | seq_embs, _ = self.enc(tok_embs) 35 | code_embs = self.precoder(seq_embs[-1]).view(-1, self.code_len, 2) 36 | logits = nnf.log_softmax(code_embs, dim=2) 37 | 38 | codes = bottle(nnf.gumbel_softmax, logits, hard=True)[:, :, 0] 39 | # codes = precodes.max(-1)[1].float() # doesn't train very 
well! 40 | return codes.contiguous(), logits 41 | 42 | 43 | class _RNNDecoder(nn.Module): 44 | """Decodes codes to token log-probs.""" 45 | 46 | def __init__(self, code_len, tok_emb_dim, seqlen, rnn_dim, vocab_size, 47 | **unused_kwargs): 48 | super(_RNNDecoder, self).__init__() 49 | 50 | self.seqlen = seqlen 51 | self.emb_dim = code_len #+ tok_emb_dim 52 | 53 | self.dec = nn.LSTM(self.emb_dim, rnn_dim, bidirectional=True) 54 | self.tok_dec = BottledLinear(rnn_dim*2, vocab_size) 55 | 56 | def forward(self, tok_embs, codes, *unused_kwargs): 57 | """ 58 | tok_embs: (T+1)*N*tok_emb_dim; +1 for init toks 59 | codes: N*code_len 60 | """ 61 | batch_size, code_dim = codes.size() 62 | 63 | rep_codes = codes[None].expand(self.seqlen, *codes.size()) 64 | # tok_codes = torch.cat((tok_embs[:-1], rep_codes), dim=-1) 65 | tok_embs, _ = self.dec(rep_codes) # forwarding tok_codes ignores codes 66 | tok_logits = self.tok_dec(tok_embs) 67 | return nnf.log_softmax(tok_logits, dim=2) 68 | 69 | 70 | class AEHasher(nn.Module): 71 | """An autoencoder-based hasher.""" 72 | 73 | def __init__(self, code_len, num_hash_buckets=None, 74 | padding_idx=None, **kwargs): 75 | super(AEHasher, self).__init__() 76 | 77 | tok_emb_dim = kwargs['tok_emb_dim'] 78 | vocab_size = kwargs['vocab_size'] 79 | padding_idx = None if kwargs.get('env') == environ.SYNTH else 0 80 | self.tok_emb = nn.Embedding(vocab_size, tok_emb_dim, 81 | padding_idx=padding_idx) 82 | 83 | self.encoder = _RNNEncoder(code_len, **kwargs) 84 | self.decoder = _RNNDecoder(code_len, **kwargs) 85 | 86 | num_hash_buckets = num_hash_buckets or 2**code_len 87 | hash_code_len = int(torch.np.ceil(torch.np.log2(num_hash_buckets))) 88 | self.proj = Apply(torch.round) 89 | if hash_code_len != code_len: 90 | self.proj = nn.Sequential( 91 | Apply(lambda x: x * 2 - 1), # {0, 1} -> {-1, 1} 92 | nn.Linear(code_len, hash_code_len, bias=False), 93 | Apply(torch.sign), 94 | nn.ReLU(True)) 95 | 96 | def forward(self, toks, **unused_kwargs): 97 | """ 98 | In training mode, return log-probs of reconstructed tokens. 
99 | toks: N*(T+1); the first timestep is init toks 100 | 101 | In evaluate mode, return binary codes 102 | toks: N*T; no init toks 103 | """ 104 | tok_embs = self.tok_emb(toks).transpose(0, 1) 105 | codes, code_logits = self.encoder(tok_embs[self.training:]) 106 | if self.training: 107 | return self.decoder(tok_embs, codes), code_logits 108 | return self.proj(codes).detach() 109 | 110 | 111 | def create(g_tok_emb_dim, num_gen_layers, **opts): 112 | """Creates a token generator.""" 113 | return AEHasher(tok_emb_dim=g_tok_emb_dim, 114 | num_layers=num_gen_layers, 115 | **opts) 116 | 117 | 118 | def test_ae_hasher(): 119 | """Tests the AEHashser.""" 120 | # pylint: disable=too-many-locals,unused-variable 121 | import common 122 | 123 | batch_size = 4 124 | code_len = 3 125 | num_hash_buckets = 8 126 | seqlen = 4 127 | vocab_size = 32 128 | tok_emb_dim = 8 129 | rnn_dim = 12 130 | num_layers = 1 131 | debug = True 132 | 133 | hasher = AEHasher(**locals()) 134 | 135 | toks = Variable(torch.LongTensor(batch_size, seqlen+1).random_(vocab_size)) 136 | 137 | hasher.train() 138 | tok_log_probs, code_logits = hasher(toks) 139 | assert tok_log_probs.size()[1:] == (seqlen, vocab_size) 140 | assert code_logits.size()[1:] == (code_len, 2) 141 | assert torch.np.allclose(code_logits.data.exp().sum(-1).numpy(), 1) 142 | 143 | hasher.eval() 144 | hash_code = hasher(toks) 145 | assert hash_code.size(1) == torch.np.log2(num_hash_buckets) 146 | assert (nnf.relu(hash_code) == hash_code).all() 147 | -------------------------------------------------------------------------------- /model/generator.py: -------------------------------------------------------------------------------- 1 | """The Generators.""" 2 | import itertools 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as nnf 7 | from torch.autograd import Variable 8 | 9 | import environ 10 | from .bottles import BottledLinear 11 | 12 | 13 | TYPES = ('rnn',) 14 | RNN, = TYPES 15 | 16 | 17 | class RNNGenerator(nn.Module): 18 | """An RNN token generator.""" 19 | 20 | def __init__(self, vocab_size, tok_emb_dim, rnn_dim, num_layers, 21 | padding_idx=None, **unused_kwargs): 22 | super(RNNGenerator, self).__init__() 23 | 24 | self.tok_emb = nn.Embedding(vocab_size, tok_emb_dim, 25 | padding_idx=padding_idx) 26 | 27 | self.gen = nn.LSTM(tok_emb_dim, rnn_dim, num_layers=num_layers) 28 | self.tok_dec = BottledLinear(rnn_dim, vocab_size) 29 | 30 | def forward(self, toks, prev_state=None, temperature=1, **unused_kwargs): 31 | """ 32 | toks: N*T 33 | """ 34 | wembs = self.tok_emb(toks).transpose(0, 1) # T*N*d_wemb 35 | tok_embs, next_state = self.gen(wembs, prev_state) 36 | logits = self.tok_dec(tok_embs) # T*N*vocab_size 37 | if temperature != 1: 38 | logits /= temperature 39 | tok_probs = nnf.log_softmax(logits, 2) 40 | return tok_probs, next_state 41 | 42 | def rollout(self, init_state, ro_steps, return_first_state=False, 43 | temperature=1): 44 | """ 45 | init_state: 46 | toks: N*T or (toks, prev_hidden state) 47 | ro_steps: roll out this many steps 48 | temperature: non-negative scalar that controls entropy 49 | (0 = max likelihood) 50 | 51 | This method does not modify the global random state. 
52 | 53 | Returns: 54 | a list of T samples of size N*1, 55 | a list of T word log-probs of size N*V 56 | """ 57 | 58 | if isinstance(init_state, Variable): 59 | init_state = (init_state, None) 60 | gen_toks, gen_state = init_state 61 | 62 | gen_seqs = [] 63 | gen_log_probs = [] 64 | for i in range(ro_steps): 65 | tok_log_probs, gen_state = self( 66 | gen_toks, gen_state, temperature=temperature or 1) 67 | tok_log_probs = tok_log_probs[-1] # N*V 68 | if temperature == 0: 69 | gen_toks = tok_log_probs.max(-1)[1][:, None] 70 | else: 71 | gen_toks = torch.multinomial(tok_log_probs.exp(), 1) # N*1 72 | gen_toks = gen_toks.detach() 73 | gen_seqs.append(gen_toks) 74 | gen_log_probs.append(tok_log_probs) 75 | if i == 0 and return_first_state: 76 | th = torch.cuda if gen_toks.is_cuda else torch 77 | first_state = (gen_state, th.get_rng_state()) 78 | 79 | if return_first_state: 80 | return gen_seqs, gen_log_probs, first_state 81 | return gen_seqs, gen_log_probs 82 | 83 | def parameters(self, dx2=False): 84 | """ 85 | Returns an iterator over module parameters. 86 | If dx2=True, only yield parameters that are twice differentiable. 87 | """ 88 | if not dx2: 89 | return super(RNNGenerator, self).parameters() 90 | return itertools.chain(*[ 91 | m.parameters() for m in self.children() 92 | if m != self.tok_emb]) 93 | 94 | 95 | def create(g_tok_emb_dim, num_gen_layers, gen_type=RNN, **opts): 96 | """Creates a token generator.""" 97 | return RNNGenerator(tok_emb_dim=g_tok_emb_dim, 98 | num_layers=num_gen_layers, 99 | **opts) 100 | 101 | 102 | def test_rnn_generator(): 103 | """Tests the RNNGenerator.""" 104 | # pylint: disable=too-many-locals,unused-variable 105 | import common 106 | 107 | batch_size = 4 108 | seqlen = 4 109 | vocab_size = 32 110 | tok_emb_dim = 8 111 | rnn_dim = 12 112 | num_layers = 1 113 | debug = True 114 | 115 | gen = RNNGenerator(**locals()) 116 | 117 | toks = Variable(torch.LongTensor(batch_size, 1).fill_(1)) 118 | 119 | gen_probs, gen_state = gen(toks) 120 | gen_toks = torch.multinomial(gen_probs.exp(), 1).detach() 121 | gen_probs, gen_state = gen(toks=gen_toks, prev_state=gen_state) 122 | 123 | # test basic rollout 124 | init_toks = Variable(torch.LongTensor(batch_size, seqlen).fill_(1)) 125 | ro_seqs, ro_log_probs = gen.rollout(init_toks, seqlen, 0) 126 | assert len(ro_seqs) == seqlen and len(ro_log_probs) == seqlen 127 | assert torch.np.allclose( 128 | torch.stack(ro_log_probs).data.exp().sum(-1).numpy(), 1) 129 | 130 | # test reproducability 131 | init_rand_state = torch.get_rng_state() 132 | with common.rand_state(torch, 42) as rand_state: 133 | ro1, _ = gen.rollout(init_toks, 8) 134 | with common.rand_state(torch, rand_state): 135 | ro2, _ = gen.rollout(init_toks, 8) 136 | assert all((t1.data == t2.data).all() for t1, t2 in zip(ro1, ro2)) 137 | assert (torch.get_rng_state() == init_rand_state).all() 138 | 139 | # test continuation 140 | rand_toks = Variable(torch.LongTensor(batch_size, 2).random_(vocab_size)) 141 | ro_seqs, _, (ro_hid, ro_rng) = gen.rollout(rand_toks, 2, 142 | return_first_state=True) 143 | with common.rand_state(torch, ro_rng): 144 | next_ro, _ = gen.rollout((ro_seqs[0], ro_hid), 1) 145 | assert (ro_seqs[1].data == next_ro[0].data).all() 146 | 147 | # test double-backward 148 | sum(gen_probs).sum().backward(create_graph=True) 149 | sum(p.grad.norm() for p in gen.parameters(dx2=True)).backward() 150 | -------------------------------------------------------------------------------- /model/discriminator.py: 
-------------------------------------------------------------------------------- 1 | """The Discriminators.""" 2 | import functools 3 | import itertools 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as nnf 8 | from torch.autograd import Variable 9 | 10 | import environ 11 | 12 | 13 | TYPES = ('cnn', 'rnn') 14 | CNN, RNN = TYPES 15 | 16 | 17 | def _l2_reg(mod, l=1e-4): 18 | def _reg(var, grad): 19 | return grad + l*var 20 | mod.weight.register_hook(functools.partial(_reg, mod.weight)) 21 | mod.bias.register_hook(functools.partial(_reg, mod.bias)) 22 | return mod 23 | 24 | 25 | class Highway(nn.Module): 26 | """A Highway layer Module.""" 27 | 28 | def __init__(self, in_features, activation=nn.ReLU(True)): 29 | super(Highway, self).__init__() 30 | 31 | self.gate = nn.Sequential( 32 | nn.Linear(in_features, in_features), 33 | nn.Sigmoid(), 34 | ) 35 | 36 | self.tx = nn.Sequential( 37 | nn.Linear(in_features, in_features), 38 | activation, 39 | ) 40 | 41 | def forward(self, x): 42 | g = self.gate(x) 43 | return g * self.tx(x) + (1 - g) * x 44 | 45 | 46 | class Discriminator(nn.Module): 47 | """A base class for Discriminators.""" 48 | 49 | def __init__(self, vocab_size, tok_emb_dim, **kwargs): 50 | super(Discriminator, self).__init__() 51 | 52 | padding_idx = None if kwargs.get('env') == environ.SYNTH else 0 53 | self.tok_emb = nn.Embedding(vocab_size, tok_emb_dim, 54 | padding_idx=padding_idx) 55 | 56 | self.stdev = nn.Parameter(torch.randn(1) * 0.1) 57 | 58 | def parameters(self, dx2=False): 59 | """ 60 | Returns an iterator over module parameters. 61 | If dx2=True, only yield parameters that are twice differentiable. 62 | """ 63 | if not dx2: 64 | return super(Discriminator, self).parameters() 65 | return itertools.chain(*[ 66 | m.parameters() for m in self.children() if m != self.tok_emb]) 67 | 68 | def forward(self, toks, return_embs=False): 69 | """ 70 | toks: N*T or [N*1]*T 71 | """ 72 | if isinstance(toks, (list, tuple)): 73 | toks = torch.cat(toks, -1) 74 | logits, embs = self._forward(toks) 75 | fuzz = Variable(logits.data.new(logits.size()).normal_()) * self.stdev 76 | log_probs = nnf.log_softmax(logits + fuzz, dim=1) 77 | if return_embs: 78 | return log_probs, Variable(embs.data, requires_grad=True) 79 | return log_probs 80 | 81 | 82 | 83 | class _CNNDiscriminator(Discriminator): 84 | """A CNN token discriminator.""" 85 | 86 | def __init__(self, tok_emb_dim, filter_widths, num_filters, dropout, 87 | **kwargs): 88 | super(_CNNDiscriminator, self).__init__( 89 | tok_emb_dim=tok_emb_dim, **kwargs) 90 | 91 | assert len(filter_widths) == len(num_filters) 92 | 93 | cnn_layers = [] 94 | for kw, c in zip(filter_widths, num_filters): 95 | cnn_layers.append(nn.Sequential( 96 | nn.Conv1d(tok_emb_dim, c, kw), 97 | nn.ReLU(True), 98 | )) 99 | self.cnn_layers = nn.ModuleList(cnn_layers) 100 | 101 | emb_dim = sum(num_filters) 102 | self.logits = nn.Sequential( 103 | Highway(emb_dim), 104 | nn.Dropout(dropout), 105 | _l2_reg(nn.Linear(emb_dim, 2))) 106 | 107 | def _forward(self, toks): 108 | """ 109 | toks_embs: N*T*d_wemb 110 | """ 111 | tok_embs = self.tok_emb(toks).transpose(1, 2).detach() # N*d_wemb*T 112 | 113 | layer_acts = torch.cat([ # N*sum(num_filters) 114 | layer(tok_embs).sum(-1) for layer in self.cnn_layers], -1) 115 | 116 | return self.logits(layer_acts), tok_embs 117 | 118 | 119 | class _RNNDiscriminator(Discriminator): 120 | """An RNN token discriminator.""" 121 | 122 | def __init__(self, tok_emb_dim, rnn_dim, **kwargs): 123 | 
super(_RNNDiscriminator, self).__init__(tok_emb_dim=tok_emb_dim, 124 | **kwargs) 125 | 126 | self.rnn = nn.LSTM(tok_emb_dim, rnn_dim, num_layers=3) 127 | self.logits = nn.Linear(rnn_dim, 2) 128 | 129 | def _forward(self, toks): 130 | """ 131 | toks: N*T 132 | """ 133 | tok_embs = self.tok_emb(toks).transpose(0, 1) # T*N*d_wemb 134 | seq_embs, _ = self.rnn(tok_embs) 135 | masked_embs = seq_embs.masked_fill( # T*N*rnn_dim 136 | (toks.t() == 0)[..., None].expand_as(seq_embs), 0) 137 | num_non_pad = (toks != 0).float().sum(1, keepdim=True) # N*1 138 | pre_logits = nnf.relu(masked_embs).sum(0) / num_non_pad 139 | logits = self.logits(pre_logits) 140 | return logits, masked_embs 141 | 142 | 143 | def create(d_type, d_tok_emb_dim, **opts): 144 | """Creates a token discriminator.""" 145 | d_cls = _RNNDiscriminator if d_type == RNN else _CNNDiscriminator 146 | return d_cls(tok_emb_dim=d_tok_emb_dim, **opts) 147 | 148 | 149 | def test_cnn_discriminator(): 150 | """Tests the CNNDiscriminator.""" 151 | # pylint: disable=unused-variable 152 | batch_size = 3 153 | vocab_size = 32 154 | tok_emb_dim = 10 155 | filter_widths = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20] 156 | num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160] 157 | dropout = 0.25 158 | debug = True 159 | 160 | d = _CNNDiscriminator(**locals()) 161 | 162 | preds = d(Variable(torch.LongTensor(batch_size, 20).fill_(1))) 163 | assert preds.size(0) == batch_size 164 | assert preds.size(1) == 2 165 | assert torch.np.allclose(preds.data.exp().sum(1).numpy(), 1) 166 | 167 | preds.sum().backward(create_graph=True) 168 | sum(p.grad.norm() for p in d.parameters(dx2=True)).backward() 169 | 170 | 171 | def test_rnn_discriminator(): 172 | """Tests the RNNDiscriminator.""" 173 | # pylint: disable=unused-variable 174 | batch_size = 3 175 | vocab_size = 32 176 | tok_emb_dim = 10 177 | rnn_dim = 4 178 | debug = True 179 | 180 | d = _RNNDiscriminator(**locals()) 181 | 182 | preds = d(Variable(torch.LongTensor(batch_size, 20).fill_(1))) 183 | assert preds.size(0) == batch_size 184 | assert preds.size(1) == 2 185 | assert torch.np.allclose(preds.data.exp().sum(1).numpy(), 1) 186 | 187 | preds.sum().backward(create_graph=True) 188 | sum(p.grad.norm() for p in d.parameters(dx2=True)).backward() 189 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # A comma-separated list of package or module names from where C extensions may 4 | # be loaded. Extensions are loading into the active Python interpreter and may 5 | # run arbitrary code 6 | extension-pkg-whitelist= 7 | 8 | # Add files or directories to the blacklist. They should be base names, not 9 | # paths. 10 | ignore= 11 | 12 | # Add files or directories matching the regex patterns to the blacklist. The 13 | # regex matches against base names, not paths. 14 | ignore-patterns= 15 | 16 | # Python code to execute, usually for sys.path manipulation such as 17 | # pygtk.require(). 18 | #init-hook= 19 | 20 | # Use multiple processes to speed up Pylint. 21 | jobs=8 22 | 23 | # List of plugins (as comma separated values of python modules names) to load, 24 | # usually to register additional checkers. 25 | load-plugins= 26 | 27 | # Pickle collected data for later comparisons. 28 | persistent=yes 29 | 30 | # Specify a configuration file. 31 | #rcfile= 32 | 33 | # Allow loading of arbitrary C extensions. 
Extensions are imported into the 34 | # active Python interpreter and may run arbitrary code. 35 | unsafe-load-any-extension=no 36 | 37 | 38 | [MESSAGES CONTROL] 39 | 40 | # Only show warnings with the listed confidence levels. Leave empty to show 41 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 42 | confidence= 43 | 44 | # Disable the message, report, category or checker with the given id(s). You 45 | # can either give multiple identifiers separated by comma (,) or put this 46 | # option multiple times (only on the command line, not in the configuration 47 | # file where it should appear only once).You can also use "--disable=all" to 48 | # disable everything first and then reenable specific checks. For example, if 49 | # you want to run only the similarities checker, you can use "--disable=all 50 | # --enable=similarities". If you want to run only the classes checker, but have 51 | # no Warning level messages displayed, use"--disable=all --enable=classes 52 | # --disable=W" 53 | disable=print-statement,parameter-unpacking,unpacking-in-except,old-raise-syntax,backtick,long-suffix,old-ne-operator,old-octal-literal,import-star-module-level,raw-checker-failed,bad-inline-option,locally-disabled,locally-enabled,file-ignored,suppressed-message,useless-suppression,deprecated-pragma,apply-builtin,basestring-builtin,buffer-builtin,cmp-builtin,coerce-builtin,execfile-builtin,file-builtin,long-builtin,raw_input-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,no-absolute-import,old-division,dict-iter-method,dict-view-method,next-method-called,metaclass-assignment,indexing-exception,raising-string,reload-builtin,oct-method,hex-method,nonzero-method,cmp-method,input-builtin,round-builtin,intern-builtin,unichr-builtin,map-builtin-not-iterating,zip-builtin-not-iterating,range-builtin-not-iterating,filter-builtin-not-iterating,using-cmp-argument,eq-without-hash,div-method,idiv-method,rdiv-method,exception-message-attribute,invalid-str-codec,sys-max-int,bad-python3-import,deprecated-string-function,deprecated-str-translate-call,redefined-outer-name,invalid-name,too-few-public-methods,arguments-differ,no-self-use,wrong-import-position,fixme 54 | 55 | # Enable the message, report, category or checker with the given id(s). You can 56 | # either give multiple identifier separated by comma (,) or put this option 57 | # multiple time (only on the command line, not in the configuration file where 58 | # it should appear only once). See also the "--disable" option for examples. 59 | enable= 60 | 61 | 62 | [REPORTS] 63 | 64 | # Python expression which should return a note less than 10 (10 is the highest 65 | # note). You have access to the variables errors warning, statement which 66 | # respectively contain the number of errors / warnings messages and the total 67 | # number of statements analyzed. This is used by the global evaluation report 68 | # (RP0004). 69 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 70 | 71 | # Template used to display messages. This is a python new-style format string 72 | # used to format the message information. See doc for all details 73 | #msg-template= 74 | 75 | # Set the output format. Available formats are text, parseable, colorized, json 76 | # and msvs (visual studio).You can also give a reporter class, eg 77 | # mypackage.mymodule.MyReporterClass. 
78 | output-format=text 79 | 80 | # Tells whether to display a full report or only the messages 81 | reports=no 82 | 83 | # Activate the evaluation score. 84 | score=yes 85 | 86 | 87 | [REFACTORING] 88 | 89 | # Maximum number of nested blocks for function / method body 90 | max-nested-blocks=5 91 | 92 | 93 | [FORMAT] 94 | 95 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 96 | expected-line-ending-format=LF 97 | 98 | # Regexp for a line that is allowed to be longer than the limit. 99 | ignore-long-lines=^\s*(# )??$ 100 | 101 | # Number of spaces of indent required inside a hanging or continued line. 102 | indent-after-paren=4 103 | 104 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 105 | # tab). 106 | indent-string=' ' 107 | 108 | # Maximum number of characters on a single line. 109 | max-line-length=80 110 | 111 | # Maximum number of lines in a module 112 | max-module-lines=1000 113 | 114 | # List of optional constructs for which whitespace checking is disabled. `dict- 115 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 116 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 117 | # `empty-line` allows space-only lines. 118 | no-space-check=trailing-comma,dict-separator 119 | 120 | # Allow the body of a class to be on the same line as the declaration if body 121 | # contains single statement. 122 | single-line-class-stmt=no 123 | 124 | # Allow the body of an if to be on the same line as the test if there is no 125 | # else. 126 | single-line-if-stmt=no 127 | 128 | 129 | [SPELLING] 130 | 131 | # Spelling dictionary name. Available dictionaries: none. To make it working 132 | # install python-enchant package. 133 | spelling-dict= 134 | 135 | # List of comma separated words that should not be checked. 136 | spelling-ignore-words= 137 | 138 | # A path to a file that contains private dictionary; one word per line. 139 | spelling-private-dict-file= 140 | 141 | # Tells whether to store unknown words to indicated private dictionary in 142 | # --spelling-private-dict-file option instead of raising a message. 143 | spelling-store-unknown-words=no 144 | 145 | 146 | [LOGGING] 147 | 148 | # Logging modules to check that the string format arguments are in logging 149 | # function parameter format 150 | logging-modules=logging 151 | 152 | 153 | [TYPECHECK] 154 | 155 | # List of decorators that produce context managers, such as 156 | # contextlib.contextmanager. Add to this list to register other decorators that 157 | # produce valid context managers. 158 | contextmanager-decorators=contextlib.contextmanager 159 | 160 | # List of members which are set dynamically and missed by pylint inference 161 | # system, and so shouldn't trigger E1101 when accessed. Python regular 162 | # expressions are accepted. 163 | generated-members= 164 | 165 | # Tells whether missing members accessed in mixin class should be ignored. A 166 | # mixin class is detected if its name ends with "mixin" (case insensitive). 167 | ignore-mixin-members=yes 168 | 169 | # This flag controls whether pylint should warn about no-member and similar 170 | # checks whenever an opaque object is returned when inferring. The inference 171 | # can return multiple potential results while evaluating a Python object, but 172 | # some branches might not be evaluated, which results in partial inference. In 173 | # that case, it might be useful to still emit no-member and other checks for 174 | # the rest of the inferred objects. 
175 | ignore-on-opaque-inference=yes 176 | 177 | # List of class names for which member attributes should not be checked (useful 178 | # for classes with dynamically set attributes). This supports the use of 179 | # qualified names. 180 | ignored-classes=optparse.Values,thread._local,_thread._local 181 | 182 | # List of module names for which member attributes should not be checked 183 | # (useful for modules/projects where namespaces are manipulated during runtime 184 | # and thus existing member attributes cannot be deduced by static analysis. It 185 | # supports qualified module names, as well as Unix pattern matching. 186 | ignored-modules=numpy*,torch,pandas,lxml* 187 | 188 | # Show a hint with possible names when a member name was not found. The aspect 189 | # of finding the hint is based on edit distance. 190 | missing-member-hint=yes 191 | 192 | # The minimum edit distance a name should have in order to be considered a 193 | # similar match for a missing member name. 194 | missing-member-hint-distance=1 195 | 196 | # The total number of similar names that should be taken in consideration when 197 | # showing a hint for a missing member. 198 | missing-member-max-choices=1 199 | 200 | 201 | [MISCELLANEOUS] 202 | 203 | # List of note tags to take in consideration, separated by a comma. 204 | notes=FIXME,XXX,TODO 205 | 206 | 207 | [VARIABLES] 208 | 209 | # List of additional names supposed to be defined in builtins. Remember that 210 | # you should avoid to define new builtins when possible. 211 | additional-builtins= 212 | 213 | # Tells whether unused global variables should be treated as a violation. 214 | allow-global-unused-variables=yes 215 | 216 | # List of strings which can identify a callback function by name. A callback 217 | # name must start or end with one of those strings. 218 | callbacks=cb_,_cb 219 | 220 | # A regular expression matching the name of dummy variables (i.e. expectedly 221 | # not used). 222 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ 223 | 224 | # Argument names that match this expression will be ignored. Default to name 225 | # with leading underscore 226 | ignored-argument-names=_.*|^ignored_|^unused_|^args|^kwargs 227 | 228 | # Tells whether we should check for unused import in __init__ files. 229 | init-import=no 230 | 231 | # List of qualified module names which can have objects that can redefine 232 | # builtins. 
233 | redefining-builtins-modules=six.moves,future.builtins 234 | 235 | 236 | [BASIC] 237 | 238 | # Naming hint for argument names 239 | argument-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 240 | 241 | # Regular expression matching correct argument names 242 | argument-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 243 | 244 | # Naming hint for attribute names 245 | attr-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 246 | 247 | # Regular expression matching correct attribute names 248 | attr-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 249 | 250 | # Bad variable names which should always be refused, separated by a comma 251 | bad-names=foo,bar,baz,toto,tutu,tata 252 | 253 | # Naming hint for class attribute names 254 | class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ 255 | 256 | # Regular expression matching correct class attribute names 257 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ 258 | 259 | # Naming hint for class names 260 | class-name-hint=[A-Z_][a-zA-Z0-9]+$ 261 | 262 | # Regular expression matching correct class names 263 | class-rgx=[A-Z_][a-zA-Z0-9]+$ 264 | 265 | # Naming hint for constant names 266 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 267 | 268 | # Regular expression matching correct constant names 269 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 270 | 271 | # Minimum line length for functions/classes that require docstrings, shorter 272 | # ones are exempt. 273 | docstring-min-length=-1 274 | 275 | # Naming hint for function names 276 | function-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 277 | 278 | # Regular expression matching correct function names 279 | function-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 280 | 281 | # Good variable names which should always be accepted, separated by a comma 282 | good-names=i,j,k,ex,Run,_ 283 | 284 | # Include a hint for the correct naming format with invalid-name 285 | include-naming-hint=no 286 | 287 | # Naming hint for inline iteration names 288 | inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ 289 | 290 | # Regular expression matching correct inline iteration names 291 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ 292 | 293 | # Naming hint for method names 294 | method-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 295 | 296 | # Regular expression matching correct method names 297 | method-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 298 | 299 | # Naming hint for module names 300 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 301 | 302 | # Regular expression matching correct module names 303 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 304 | 305 | # Colon-delimited sets of names that determine each other's naming style when 306 | # the name regexes allow several styles. 307 | name-group= 308 | 309 | # Regular expression which should only match function or class names that do 310 | # not require a docstring. 311 | no-docstring-rgx=^_ 312 | 313 | # List of decorators that produce properties, such as abc.abstractproperty. Add 314 | # to this list to register other decorators that produce valid properties. 315 | property-classes=abc.abstractproperty 316 | 317 | # Naming hint for variable names 318 | variable-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 319 | 320 | # Regular expression matching correct variable names 321 | variable-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 322 | 323 | 324 | [SIMILARITIES] 325 | 326 | # Ignore comments when computing similarities. 327 | ignore-comments=yes 328 | 329 | # Ignore docstrings when computing similarities. 
330 | ignore-docstrings=yes 331 | 332 | # Ignore imports when computing similarities. 333 | ignore-imports=no 334 | 335 | # Minimum lines number of a similarity. 336 | min-similarity-lines=4 337 | 338 | 339 | [CLASSES] 340 | 341 | # List of method names used to declare (i.e. assign) instance attributes. 342 | defining-attr-methods=__init__,__new__,setUp 343 | 344 | # List of member names, which should be excluded from the protected access 345 | # warning. 346 | exclude-protected=_asdict,_fields,_replace,_source,_make 347 | 348 | # List of valid names for the first argument in a class method. 349 | valid-classmethod-first-arg=cls 350 | 351 | # List of valid names for the first argument in a metaclass class method. 352 | valid-metaclass-classmethod-first-arg=mcs 353 | 354 | 355 | [DESIGN] 356 | 357 | # Maximum number of arguments for function / method 358 | max-args=10 359 | 360 | # Maximum number of attributes for a class (see R0902). 361 | max-attributes=15 362 | 363 | # Maximum number of boolean expressions in a if statement 364 | max-bool-expr=5 365 | 366 | # Maximum number of branch for function / method body 367 | max-branches=12 368 | 369 | # Maximum number of locals for function / method body 370 | max-locals=20 371 | 372 | # Maximum number of parents for a class (see R0901). 373 | max-parents=7 374 | 375 | # Maximum number of public methods for a class (see R0904). 376 | max-public-methods=20 377 | 378 | # Maximum number of return / yield for function / method body 379 | max-returns=6 380 | 381 | # Maximum number of statements in function / method body 382 | max-statements=50 383 | 384 | # Minimum number of public methods for a class (see R0903). 385 | min-public-methods=2 386 | 387 | 388 | [IMPORTS] 389 | 390 | # Allow wildcard imports from modules that define __all__. 391 | allow-wildcard-with-all=no 392 | 393 | # Analyse import fallback blocks. This can be used to support both Python 2 and 394 | # 3 compatible code, which means that the block might have code that exists 395 | # only in one or another interpreter, leading to false positives when analysed. 396 | analyse-fallback-blocks=no 397 | 398 | # Deprecated modules which should not be used, separated by a comma 399 | deprecated-modules=optparse,tkinter.tix 400 | 401 | # Create a graph of external dependencies in the given file (report RP0402 must 402 | # not be disabled) 403 | ext-import-graph= 404 | 405 | # Create a graph of every (i.e. internal and external) dependencies in the 406 | # given file (report RP0402 must not be disabled) 407 | import-graph= 408 | 409 | # Create a graph of internal dependencies in the given file (report RP0402 must 410 | # not be disabled) 411 | int-import-graph= 412 | 413 | # Force import order to recognize a module as part of the standard 414 | # compatibility libraries. 415 | known-standard-library= 416 | 417 | # Force import order to recognize a module as part of a third party library. 418 | known-third-party=enchant 419 | 420 | 421 | [EXCEPTIONS] 422 | 423 | # Exceptions that will emit a warning when being caught. 
Defaults to 424 | # "Exception" 425 | overgeneral-exceptions=Exception 426 | -------------------------------------------------------------------------------- /environ/environment.py: -------------------------------------------------------------------------------- 1 | """A class for training a SeqGAN model.""" 2 | 3 | import argparse 4 | import logging 5 | import os 6 | import time 7 | 8 | import torch 9 | from torch.nn import functional as nnf 10 | from torch.autograd import Variable 11 | 12 | import common 13 | from common import LABEL_GEN, LABEL_REAL 14 | import dataset 15 | from dataset import samplers 16 | import model 17 | 18 | 19 | class Environment(object): 20 | """A base class for training a SeqGAN model.""" 21 | 22 | _EVAL_METRIC = None # string describing generator eval metric 23 | 24 | _STATEFUL = ('hasher', 'g', 'd', 'optim_g', 'optim_d') 25 | 26 | def __init__(self, opts): 27 | """Creates an Environment.""" 28 | 29 | opts.num_hash_buckets = opts.num_hash_buckets or 2**opts.code_len 30 | self.opts = opts 31 | 32 | torch.nn._functions.rnn.force_unfused = opts.grad_reg # pylint: disable=protected-access 33 | 34 | self.train_dataset = self.test_dataset = None # `Dataset`s of real data 35 | 36 | self.g = model.generator.create(**vars(opts)).cuda() 37 | self.d = model.discriminator.create(**vars(opts)).cuda() 38 | 39 | self.optim_g = torch.optim.Adam(self.g.parameters(), lr=opts.lr_g) 40 | self.optim_d = torch.optim.Adam(self.d.parameters(), lr=opts.lr_d) 41 | 42 | if opts.exploration_bonus: 43 | self.hasher = model.hasher.create(**vars(opts)).cuda() 44 | self.optim_hasher = torch.optim.Adam(self.hasher.parameters(), 45 | lr=opts.lr_hasher) 46 | self.state_counter = model.hash_counter.HashCounter( 47 | self.hasher, opts.num_hash_buckets).cuda() 48 | 49 | num_inits = max(opts.num_rollouts, 1) * opts.batch_size 50 | self.ro_init_toks = Variable(torch.cuda.LongTensor(num_inits, 1)) 51 | self.init_toks = self.ro_init_toks[:opts.batch_size].detach() 52 | 53 | self._labels = torch.cuda.LongTensor(self.opts.batch_size) 54 | # self._inv_idx = torch.arange(opts.seqlen-1, -1, -1).long().cuda() 55 | 56 | @classmethod 57 | def get_opt_parser(cls): 58 | """Returns an `ArgumentParser` that parses env-specific opts.""" 59 | parser = argparse.ArgumentParser(add_help=False) 60 | 61 | # general 62 | parser.add_argument('--seed', default=42, type=int) 63 | parser.add_argument('--debug', action='store_true') 64 | parser.add_argument('--log-freq', default=1, type=int) 65 | 66 | # data 67 | parser.add_argument('--nworkers', default=4, type=int) 68 | 69 | # model 70 | parser.add_argument('--d-type', choices=model.discriminator.TYPES, 71 | default=model.discriminator.RNN) 72 | parser.add_argument('--vocab-size', type=int) 73 | parser.add_argument('--load-w2v', type=os.path.abspath) 74 | parser.add_argument('--g-tok-emb-dim', type=int) 75 | parser.add_argument('--d-tok-emb-dim', type=int) 76 | parser.add_argument('--rnn-dim', type=int) 77 | parser.add_argument('--num-gen-layers', default=1, type=int) 78 | parser.add_argument('--dropout', type=float) 79 | parser.add_argument('--seqlen', type=int) 80 | parser.add_argument('--batch-size', type=int) 81 | parser.add_argument('--num-filters', 82 | default=[100] + [200]*4 + [100]*5 + [160]*2, 83 | nargs='+', type=int) 84 | parser.add_argument('--filter-widths', 85 | default=list(range(1, 11)) + [15, 20], 86 | nargs='+', type=int) 87 | parser.add_argument('--exploration-bonus', default=0, type=float) 88 | parser.add_argument('--code-len', type=int) 89 | 
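# The lambda below snaps --num-hash-buckets to a power of two (nearest in log2 space), so the bucket count always corresponds to a whole number of hash-code bits.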
parser.add_argument('--num-hash-buckets', 90 | type=lambda x: 2**int( 91 | torch.np.round(torch.np.log2(float(x))))) 92 | 93 | # training 94 | parser.add_argument('--lr-g', type=float) 95 | parser.add_argument('--lr-d', type=float) 96 | parser.add_argument('--lr-hasher', type=float) 97 | parser.add_argument('--pretrain-g-epochs', type=int) 98 | parser.add_argument('--pretrain-d-epochs', type=int) 99 | parser.add_argument('--train-hasher-epochs', type=int) 100 | parser.add_argument('--adv-train-iters', type=int) 101 | parser.add_argument('--adv-g-iters', default=2, type=int) 102 | parser.add_argument('--num-rollouts', default=8, type=int) 103 | parser.add_argument('--discount', default=0.95, type=float) 104 | parser.add_argument('--g-ent-reg', default=1e-3, type=float) 105 | parser.add_argument('--d-ent-reg', default=1e-2, type=float) 106 | parser.add_argument('--hasher-ent-reg', default=1e-1, type=float) 107 | parser.add_argument('--grad-reg', default=0, type=float) 108 | parser.add_argument('--temperature', default=1, type=float) 109 | parser.add_argument('--rbuf-size', default=100, type=int) 110 | 111 | return parser 112 | 113 | @property 114 | def state(self): 115 | """Returns a dict containing the state of this Environment.""" 116 | return {item: getattr(self, item).state_dict() 117 | for item in self._STATEFUL if hasattr(self, item)} 118 | 119 | @state.setter 120 | def state(self, state): 121 | for item_name, item_state in state.items(): 122 | item = getattr(self, item_name, None) 123 | if item is None: 124 | logging.warning(f'WARNING: missing {item_name}') 125 | continue # don't load missing modules/optimizers 126 | if (isinstance(item, torch.optim.Optimizer) and 127 | not item_state['state']): 128 | continue # ignore unstepped optimizers 129 | try: 130 | item.load_state_dict(item_state) 131 | except (RuntimeError, KeyError, ValueError): 132 | logging.warning(f'WARNING: could not load {item_name} state') 133 | self.optim_g.param_groups[0]['lr'] = self.opts.lr_g 134 | self.optim_d.param_groups[0]['lr'] = self.opts.lr_d 135 | 136 | def _compute_eval_metric(self): 137 | """Returns a number by which the generative model should be evaluated.""" 138 | raise NotImplementedError() 139 | 140 | def _forward_seq2seq(self, fwd_fn, batch, volatile=False): 141 | """ 142 | batch: (toks, labels); assumes toks is N*(seqlen + 1) with init toks 143 | 144 | Returns: (loss, *fwd_fn output) 145 | """ 146 | toks = Variable(batch[0], volatile=volatile).cuda() 147 | flat_tgts = toks[:, 1:].t().contiguous().view(-1) 148 | 149 | output = fwd_fn(toks) 150 | if not isinstance(output, (list, tuple)): 151 | output = (output,) 152 | gen_log_probs = output[0] 153 | flat_gen_log_probs = gen_log_probs.view(-1, gen_log_probs.size(-1)) 154 | loss = nnf.nll_loss(flat_gen_log_probs, flat_tgts, 155 | ignore_index=self.opts.padding_idx) 156 | return (loss, *output) 157 | 158 | def train_hasher(self, hook=None): 159 | """Train an auto-encoder on the dataset for use in hashing.""" 160 | self.hasher.train() 161 | train_loader = self._create_dataloader(self.train_dataset) 162 | test_loader = self._create_dataloader(self.test_dataset) 163 | 164 | for epoch in range(1, self.opts.train_hasher_epochs + 1): 165 | tick = time.time() 166 | train_loss = train_entropy = 0 167 | for batch in train_loader: 168 | loss, _, code_logits = self._forward_seq2seq(self.hasher, batch) 169 | train_loss += loss.data[0] 170 | 171 | entropy = self._get_entropy(code_logits)[0] 172 | loss -= entropy * self.opts.hasher_ent_reg 173 | train_entropy += entropy.data[0] 174 | 
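# Note: at this point the loss is reconstruction NLL minus an entropy bonus on the code distribution (weighted by hasher_ent_reg), which pushes the hasher toward high-entropy codes rather than letting them collapse onto a few buckets.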
175 | self.optim_hasher.zero_grad() 176 | loss.backward() 177 | self.optim_hasher.step() 178 | train_loss /= len(train_loader) 179 | train_entropy /= len(train_loader) 180 | 181 | test_loss = sum(self._forward_seq2seq(self.hasher, batch, 182 | volatile=True)[0].data[0] 183 | for batch in test_loader) / len(test_loader) 184 | logging.info( 185 | f'[{epoch:02d}] ' 186 | f'loss: train={train_loss:.3f} test={test_loss:.3f} ' 187 | f'H: {train_entropy:.3f} ' 188 | f'({time.time() - tick:.1f})') 189 | 190 | if callable(hook): 191 | hook(self, epoch) 192 | 193 | def pretrain_g(self, hook=None): 194 | """Pretrains G using maximum-likelihood.""" 195 | logging.info(f'[00] {self._EVAL_METRIC}: ' 196 | f'{self._compute_eval_metric():.3f}') 197 | train_loader = self._create_dataloader(self.train_dataset) 198 | for epoch in range(1, self.opts.pretrain_g_epochs + 1): 199 | tick = time.time() 200 | train_loss = entropy = gnorm = 0 201 | for batch in train_loader: 202 | loss, gen_log_probs = self._forward_g_ml(batch) 203 | entropy += self._get_entropy(gen_log_probs)[1].data[0] 204 | train_loss += loss.data[0] 205 | 206 | self.optim_g.zero_grad() 207 | loss.backward() 208 | gnorm += self._get_grad_norm(self.g).data[0] 209 | self.optim_g.step() 210 | train_loss /= len(train_loader) 211 | entropy /= len(train_loader) 212 | gnorm /= len(train_loader) 213 | 214 | logging.info( 215 | f'[{epoch:02d}] loss: {train_loss:.3f} ' 216 | f'{self._EVAL_METRIC}: {self._compute_eval_metric():.3f} ' 217 | f'H: {entropy:.2f} ' 218 | f'gnorm: {self._get_grad_norm(self.g).data[0]:.2f} ' 219 | f'({time.time() - tick:.1f})') 220 | 221 | if callable(hook): 222 | hook(self, epoch) 223 | 224 | def _forward_g_ml(self, batch, volatile=False): 225 | """ 226 | batch: (toks: N*T, labels: N) 227 | Returns: (tok_probs: T*N*V, next_state) 228 | """ 229 | return self._forward_seq2seq(lambda toks: self.g(toks[:, :-1])[0], 230 | batch, volatile=volatile) 231 | 232 | def pretrain_d(self, hook=None): 233 | """Pretrains D using pretrained G.""" 234 | for epoch in range(1, self.opts.pretrain_d_epochs+1): 235 | tick = time.time() 236 | gen_dataset = self._create_gen_dataset(self.g, LABEL_GEN, 237 | seed=self.opts.seed+epoch) 238 | dataloader = self._create_dataloader(torch.utils.data.ConcatDataset( 239 | (self.train_dataset, gen_dataset))) 240 | 241 | train_loss = 0 242 | for batch in dataloader: 243 | loss, _ = self._forward_d(batch) 244 | train_loss += loss.data[0] 245 | 246 | self.optim_d.zero_grad() 247 | loss.backward() 248 | self.optim_d.step() 249 | train_loss /= len(dataloader) 250 | 251 | acc_real, acc_gen = self._compute_d_test_acc() 252 | logging.info(f'[{epoch:02d}] loss: {train_loss:.3f} ' 253 | f'acc: real={acc_real:.2f} gen={acc_gen:.2f} ' 254 | f'({time.time() - tick:.1f})') 255 | 256 | if callable(hook): 257 | hook(self, epoch) 258 | 259 | def _forward_d(self, batch, volatile=False, has_init=True): 260 | toks, labels = batch 261 | toks = Variable(toks, volatile=volatile).cuda() 262 | labels = Variable(labels, volatile=volatile).cuda() 263 | d_log_probs = self.d(toks[:, has_init:]) 264 | return nnf.nll_loss(d_log_probs, labels), d_log_probs 265 | 266 | def _compute_d_test_acc(self, num_samples=256): 267 | num_test_batches = max(num_samples // len(self.init_toks), 1) 268 | 269 | test_loader = self._create_dataloader(self.test_dataset) 270 | acc_real = 0 271 | for i, (batch_toks, _) in enumerate(test_loader): 272 | if i == num_test_batches: 273 | break 274 | toks = Variable(batch_toks[:, 1:].cuda()) # no init toks 275 | acc_real += 
self.compute_acc(self.d(toks), LABEL_REAL) 276 | acc_real /= num_test_batches 277 | 278 | acc_gen = 0 279 | with common.rand_state(torch.cuda, -1): 280 | for _ in range(num_test_batches): 281 | gen_seqs, _ = self.g.rollout(self.init_toks, self.opts.seqlen) 282 | init_gen_seqs = torch.cat([self.init_toks] + gen_seqs, -1) 283 | dataset.GenDataset.mask_gen_seqs_(init_gen_seqs.data, 284 | self.opts.eos_idx) 285 | acc_gen += self.compute_acc(self.d(init_gen_seqs[:, 1:]), 286 | LABEL_GEN) 287 | acc_gen /= num_test_batches 288 | 289 | return acc_real, acc_gen 290 | 291 | def compute_acc(self, probs, label): 292 | """Computes the accuracy given a prob Variable and a label.""" 293 | self._labels.fill_(label) 294 | probs = probs.data if isinstance(probs, Variable) else probs 295 | return (probs.max(1)[1] == self._labels).float().mean() 296 | 297 | def train_adv(self): 298 | """Adversarially train G against D.""" 299 | 300 | self.optim_g = torch.optim.Adam(self.g.parameters(), lr=self.opts.lr_g) 301 | self.optim_d = torch.optim.Adam(self.d.parameters(), lr=self.opts.lr_d) 302 | if self.opts.exploration_bonus: 303 | self.hasher.eval() 304 | self._init_state_counter() 305 | 306 | real_dataloader = iter( 307 | self._create_dataloader(self.train_dataset, cycle=True)) 308 | 309 | replay_buffer = rbuf_loader = replay_buffer_iter = None 310 | if self.opts.rbuf_size: 311 | replay_buffer, rbuf_loader = self._create_replay_buffer( 312 | self.opts.rbuf_size, LABEL_GEN) 313 | 314 | for i in range(1, self.opts.adv_train_iters+1): 315 | tick = time.time() 316 | 317 | loss_g, gen_seqs, entropy_g = self._train_adv_g(replay_buffer) 318 | self.optim_g.zero_grad() 319 | loss_g.backward(create_graph=bool(self.opts.grad_reg)) 320 | if not self.opts.grad_reg: 321 | self.optim_g.step() 322 | 323 | if replay_buffer and replay_buffer_iter is None: 324 | replay_buffer_iter = iter(rbuf_loader) 325 | 326 | loss_d = self._train_adv_d(gen_seqs, real_dataloader, 327 | replay_buffer_iter) 328 | self.optim_d.zero_grad() 329 | loss_d.backward(create_graph=bool(self.opts.grad_reg)) 330 | 331 | gnormg, gnormd = map(self._get_grad_norm, (self.g, self.d)) 332 | gnorm = (gnormg * (self.opts.grad_reg * 50.)
+ 333 | gnormd * (self.opts.grad_reg * 0.1)) 334 | if self.opts.grad_reg: 335 | gnorm.backward() 336 | self.optim_g.step() 337 | self.optim_d.step() 338 | 339 | if (i-1) % self.opts.log_freq == 0: 340 | acc_oracle, acc_gen = self._compute_d_test_acc() 341 | logging.info( 342 | f'[{i:03d}] ' 343 | f'{self._EVAL_METRIC}: {self._compute_eval_metric():.3f} ' 344 | f'acc: o={acc_oracle:.2f} g={acc_gen:.2f} ' 345 | f'gnorm: g={gnormg.data[0]:.2f} d={gnormd.data[0]:.2f} ' 346 | f'H: {entropy_g.data[0]:.2f} ' 347 | f'({time.time() - tick:.1f})') 348 | 349 | def _init_state_counter(self): 350 | for toks, _ in self._create_dataloader(self.train_dataset): 351 | self.state_counter(Variable(toks.cuda(), volatile=True), 352 | 'counts_train') 353 | 354 | def _train_adv_g(self, replay_buffer=None): 355 | losses = [] 356 | entropies = [] 357 | for i in range(self.opts.adv_g_iters): 358 | # train G 359 | gen_seqs, gen_log_probs = self.g.rollout( 360 | self.init_toks, self.opts.seqlen, 361 | temperature=self.opts.temperature) 362 | gen_seqs = torch.cat(gen_seqs, -1) # N*T 363 | if i == 0 and replay_buffer is not None: 364 | replay_buffer.add_samples(gen_seqs) 365 | 366 | gen_log_probs = torch.stack(gen_log_probs) # T*N*V 367 | seq_log_probs = gen_log_probs.transpose(0, 1).gather( # N*T 368 | -1, gen_seqs.unsqueeze(-1)).squeeze(-1) 369 | 370 | advantages = self._get_advantages(gen_seqs) # N*T 371 | score = (seq_log_probs * advantages).sum(1).mean() 372 | 373 | disc_entropy, entropy = self._get_entropy( 374 | gen_log_probs, discount_rate=self.opts.discount) 375 | 376 | # _, roomtemp_lprobs = self.g.rollout( 377 | # self.init_toks, self.opts.seqlen, temperature=1) 378 | # roomtemp_lprobs = torch.stack(roomtemp_lprobs) 379 | # _, entropy = self._get_entropy(roomtemp_lprobs) 380 | 381 | entropies.append(entropy) 382 | losses.append(-score - disc_entropy * self.opts.g_ent_reg) 383 | 384 | loss = sum(losses) 385 | avg_entropy = sum(entropies) / len(entropies) 386 | return loss, gen_seqs, avg_entropy 387 | 388 | def _get_advantages(self, gen_seqs): 389 | rep_gen_seqs = gen_seqs.repeat(max(1, self.opts.num_rollouts), 1) 390 | qs_g = self._get_qs(self.g, rep_gen_seqs) # N*T 391 | 392 | advs = qs_g # something clever like PPO would be inserted here 393 | 394 | # advs = advs[:, self._inv_idx].cumsum(1)[:, self._inv_idx] # adv to go 395 | advs -= advs.mean() 396 | advs /= advs.std() 397 | return advs.detach() 398 | 399 | def _get_qs(self, g_ro, rep_gen_seqs): 400 | rep_gen_seqs.volatile = True 401 | 402 | qs = torch.cuda.FloatTensor( 403 | self.opts.seqlen, self.opts.batch_size).zero_() 404 | bonus = torch.cuda.FloatTensor(1, 1).zero_().expand_as(qs) 405 | 406 | gen_seqs = rep_gen_seqs[:self.opts.batch_size] 407 | qs[-1] = self.d(gen_seqs)[:, LABEL_REAL].data 408 | 409 | if self.opts.exploration_bonus: 410 | # bonus = self._get_exploration_bonus(gen_seqs).repeat( # T*N 411 | # self.opts.seqlen, 1) 412 | bonus = bonus.contiguous() 413 | bonus[-1] = self._get_exploration_bonus(gen_seqs) 414 | 415 | if self.opts.num_rollouts == 0: 416 | return Variable(qs.t().exp_().add_(bonus.t())) 417 | 418 | ro_rng = torch.cuda.get_rng_state() 419 | _, ro_hid = g_ro(self.ro_init_toks) 420 | for n in range(1, self.opts.seqlen): 421 | # ro_suff, _ = g_ro.rollout(rep_gen_seqs[:,:n], self.opts.seqlen-n) 422 | 423 | torch.cuda.set_rng_state(ro_rng) 424 | ro_state = (rep_gen_seqs[:, n-1].unsqueeze(-1), ro_hid) 425 | ro_suffix, _, (ro_hid, ro_rng) = g_ro.rollout( 426 | ro_state, self.opts.seqlen - n, return_first_state=True) 427 | ro = 
torch.cat([rep_gen_seqs[:, :n]] + ro_suffix, -1) 428 | assert ro.size(1) == self.opts.seqlen 429 | 430 | qs[n-1] = self._ro_mean(self.d(ro), (-1, 2))[:, LABEL_REAL].data 431 | # LABEL_G gives cost, LABEL_REAL gives reward 432 | # if self.opts.exploration_bonus: 433 | # bonus[n-1] = self._ro_mean(self._get_exploration_bonus(ro)) 434 | 435 | return Variable(qs.t().exp_().add_(bonus.t())) 436 | 437 | def _ro_mean(self, t, sizes=(-1,)): 438 | """Averages a tensor over rollouts. 439 | t: (num_rollouts*N)*sizes 440 | 441 | Returns: N*sizes 442 | """ 443 | return t.view(self.opts.num_rollouts, *sizes).mean(0) 444 | 445 | def _get_exploration_bonus(self, gen_seqs): 446 | seq_buckets = self.state_counter( 447 | Variable(gen_seqs.data, volatile=True)) 448 | reachable = self.state_counter.counts_train 449 | visit_counts = self.state_counter.counts[seq_buckets] 450 | bonus_weights = (reachable + 0.1).log() * self.opts.exploration_bonus 451 | return bonus_weights[seq_buckets] / visit_counts**0.5 452 | 453 | def _train_adv_d(self, gen_seqs, real_dataloader, replay_buffer_iter=None): 454 | REAL_W = 0.5 455 | GEN_W = (1 - REAL_W) 456 | RBUF_W = 0.5 457 | 458 | loss_d, d_log_probs = self._forward_d(next(real_dataloader)) 459 | loss_d *= REAL_W 460 | entropy_d = REAL_W * self._get_entropy(d_log_probs)[0] 461 | 462 | n_rbuf_batches = 0 463 | if replay_buffer_iter: 464 | n_rbuf_batches = min( 465 | len(replay_buffer_iter) // self.opts.batch_size, 4) 466 | cur_w = GEN_W 467 | if n_rbuf_batches: 468 | cur_w *= RBUF_W 469 | rbuf_batch_w = GEN_W * RBUF_W / n_rbuf_batches 470 | 471 | self._labels.fill_(LABEL_GEN) 472 | loss_d_g, d_log_probs = self._forward_d( 473 | (gen_seqs.data, self._labels), has_init=False) 474 | 475 | loss_d += cur_w * loss_d_g 476 | entropy_d += cur_w * self._get_entropy(d_log_probs)[0] 477 | 478 | for _ in range(n_rbuf_batches): 479 | loss_d_g, d_log_probs = self._forward_d( 480 | next(replay_buffer_iter), has_init=False) 481 | loss_d += rbuf_batch_w * loss_d_g 482 | entropy_d += rbuf_batch_w * self._get_entropy(d_log_probs)[0] 483 | 484 | return loss_d - entropy_d * self.opts.d_ent_reg 485 | 486 | def _create_gen_dataset(self, gen, label, num_samples=None, seed=None): 487 | num_samples = num_samples or len(self.train_dataset) 488 | seed = seed or self.opts.seed 489 | return dataset.GenDataset(generator=gen, 490 | label=label, 491 | seqlen=self.opts.seqlen, 492 | seed=seed, 493 | gen_init_toks=self.ro_init_toks, 494 | num_samples=num_samples, 495 | eos_idx=self.opts.eos_idx) 496 | 497 | def _create_dataloader(self, src_dataset, cycle=False): 498 | dl_opts = {'batch_size': self.opts.batch_size, 499 | 'num_workers': self.opts.nworkers, 500 | 'pin_memory': True, 501 | 'shuffle': not cycle} 502 | if cycle: 503 | dl_opts['sampler'] = samplers.InfiniteRandomSampler(src_dataset) 504 | return torch.utils.data.DataLoader(src_dataset, **dl_opts) 505 | 506 | def _create_replay_buffer(self, max_history, label): 507 | replay_buffer = dataset.ReplayBuffer(max_history, label) 508 | sampler = samplers.ReplayBufferSampler(replay_buffer, 509 | self.opts.batch_size) 510 | loader = torch.utils.data.DataLoader(replay_buffer, 511 | batch_sampler=sampler, 512 | num_workers=0, # TODO 513 | pin_memory=True) 514 | return replay_buffer, loader 515 | 516 | @staticmethod 517 | def _get_grad_norm(mod): 518 | return sum((p.grad**2).sum() for p in mod.parameters(dx2=True)) 519 | 520 | @staticmethod 521 | def _get_entropy(log_probs, discount_rate=None): 522 | # assumes distributions are along the last dimension 523 | infos 
= log_probs.exp() * log_probs 524 | entropy = entropy_undiscounted = -infos.sum(-1).mean() 525 | if discount_rate and discount_rate != 1: 526 | sz = [log_probs.size(0)] + [1]*(log_probs.ndimension() - 1) 527 | discount = log_probs.data.new(*sz).fill_(1) 528 | discount[1:] *= discount_rate 529 | discount.cumprod(0, out=discount) 530 | infos = infos * Variable(discount) 531 | entropy = -infos.sum(-1).mean() 532 | return entropy, entropy_undiscounted 533 | --------------------------------------------------------------------------------
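How these pieces are meant to be driven is not shown in this section, so the following is only a rough usage sketch. SynthEnvironment is a hypothetical stand-in for a concrete Environment subclass (presumably the real ones are defined in environ/real.py and environ/synth.py, which are not reproduced here); it is assumed to implement _compute_eval_metric() and to populate train_dataset/test_dataset, and the ordering of the calls is an assumption. Everything it invokes (get_opt_parser, train_hasher, pretrain_g, pretrain_d, train_adv, and the state property) is defined in environ/environment.py above, and a CUDA device is required because Environment.__init__ moves the models to the GPU.

# Hypothetical usage sketch -- not the project's actual entry point.
import torch

from environ.environment import Environment


class SynthEnvironment(Environment):
    """Hypothetical concrete Environment, shown only for illustration."""

    _EVAL_METRIC = 'eval-nll'  # placeholder metric name

    def _compute_eval_metric(self):
        return 0.0  # a real subclass would score G here and would also set the datasets


def main():
    # Model/training sizes (--vocab-size, --rnn-dim, --seqlen, --code-len, ...) are
    # supplied on the command line; several of them are needed by Environment.__init__.
    opts = SynthEnvironment.get_opt_parser().parse_args()
    torch.manual_seed(opts.seed)

    env = SynthEnvironment(opts)  # must also set env.train_dataset / env.test_dataset

    if opts.exploration_bonus:
        env.train_hasher()   # fit the AE hasher used for count-based exploration bonuses
    env.pretrain_g()         # maximum-likelihood pretraining of the generator
    env.pretrain_d()         # pretrain the discriminator on real vs. generated sequences
    env.train_adv()          # adversarial training with policy-gradient updates

    torch.save(env.state, 'checkpoint.pt')  # `state` gathers the module/optimizer state dicts


if __name__ == '__main__':
    main()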