├── .gitignore ├── LICENSE ├── README.md ├── data ├── counterfitted_neighbors.json └── splits │ ├── imdb_dev_files.txt │ ├── imdb_test_files.txt │ └── imdb_train_files.txt ├── download_deps.sh ├── scripts ├── analyze_wordvec.py ├── compute_p_values.py ├── error_analysis.py ├── find_neighbors_connected_components.py ├── latex_errors.py └── split_imdb.py └── src ├── attacks.py ├── data_util.py ├── entailment.py ├── ibp.py ├── precompute_lm_scores.py ├── test.py ├── text_classification.py ├── train.py └── vocabulary.py /.gitignore: -------------------------------------------------------------------------------- 1 | #Big data files 2 | data/aclImdb 3 | data/counter-fitted-vectors.txt 4 | data/glove 5 | data/lm_scores 6 | data/snli 7 | 8 | # LM 9 | windweller-l2w 10 | 11 | #Compiled files 12 | *.pyc 13 | 14 | # Swap files 15 | *.sw* 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Robin Jia 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Certified Robustness to Adversarial Word Substitutions 2 | 3 | This is the official GitHub repository for the following paper: 4 | 5 | > [**Certified Robustness to Adversarial Word Substitutions.**](https://arxiv.org/abs/1909.00986) 6 | > Robin Jia, Aditi Raghunathan, Kerem Göksel, and Percy Liang. 7 | > _Empirical Methods in Natural Language Processing (EMNLP)_, 2019. 8 | 9 | For full details on reproducing the results, see this [Codalab worksheet](https://worksheets.codalab.org/worksheets/0x79feda5f1998497db75422eca8fcd689), 10 | which contains all code, data, and experiments from the paper. 11 | This GitHub repository serves as an easy way to get started with the code, and has some additional instructions and documentation. 12 | 13 | # Setup 14 | This code has been tested with python3.6, pytorch 1.3.1, numpy 1.15.4, and NLTK 3.4. 
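The repository does not ship a requirements file (none appears in the directory listing above), so as one possible setup — an assumption on our part, not part of the repository — you can install matching versions of the Python dependencies with pip and fetch the NLTK tokenizer data used by the SNLI preprocessing and analysis scripts:
```
pip install torch==1.3.1 numpy==1.15.4 nltk==3.4 scipy tqdm termcolor
python -m nltk.downloader punkt
```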
15 | 16 | Download data dependencies by running the provided script: 17 | ``` 18 | ./download_deps.sh 19 | ``` 20 | 21 | If you already have GloVe vectors on your system, 22 | it may be more convenient to comment out the part of `download_deps.sh` that downloads GloVe, 23 | and instead add a symlink to the directory containing the GloVe vectors at `data/glove`. 24 | 25 | # Interval Bound Propagation library 26 | 27 | We have implemented many primitives for Interval Bound Propagation (IBP), 28 | which can be found in `src/ibp.py`. This code should be reusable and intuitive for anyone familiar with pytorch. 29 | When designing this library, our goal was to make it possible to write code that looks like standard pytorch code, but can be trained with IBP. 30 | Below, we give an overview of the code. 31 | 32 | ## BoundedTensor 33 | `BoundedTensor` is our version of `torch.Tensor`. It represents a tensor that additionally has some bounded set of possible values. The two most important subclasses of `BoundedTensor` are `IntervalBoundedTensor` and `DiscreteChoiceTensor`. 34 | 35 | ### IntervalBoundedTensor 36 | An `IntervalBoundedTensor` keeps track of three instance variables: an actual value, a coordinate-wise upper bound on the value, and a coordinate-wise lower bound on the value. All three of these are `torch.Tensor` objects. 37 | It also implements many standard methods of `torch.Tensor`. 38 | 39 | ### DiscreteChoiceTensor 40 | A `DiscreteChoiceTensor` represents a tensor that can take a discrete set of values. 41 | We use `DiscreteChoiceTensor` to represent the set of possible word vectors 42 | that can appear at each slice of the input. 43 | Importantly, `DiscreteChoiceTensor.to_interval_bounded()` converts a `DiscreteChoiceTensor` to an `IntervalBoundedTensor` by taking a coordinate-wise min/max. 44 | 45 | ### NormBallTensor 46 | We also provide `NormBallTensor`, which represents a p-norm ball of a given radius around a value. 47 | 48 | ## Functions and layers 49 | To go with `BoundedTensor`, we include functions and layers that know how to take `BoundedTensor` objects as inputs 50 | and return `BoundedTensor` objects as outputs. 51 | Most of these should be straightforward to use for folks familiar with their standard `torch`, `torch.nn`, and `torch.nn.functional` equivalents (with a caveat that not all flags in the standard library are necessarily supported). 52 | 53 | ### Functions 54 | Available implementations of basic `torch` functions include: 55 | 56 | * `add` 57 | * `mul` 58 | * `div` 59 | * `bmm` 60 | * `cat` 61 | * `stack` 62 | * `sum` 63 | 64 | In many cases, we directly call the `torch` counterpart if the inputs are `torch.Tensor` objects. 65 | A few additional cases are described below. 66 | 67 | #### Activation functions 68 | Since monotonic functions all use the same IBP formula, 69 | we export a single function `ibp.activation` 70 | which can apply elementwise ReLU, sigmoid, tanh, or exp to an `IntervalBoundedTensor`. 71 | 72 | #### Logsoftmax 73 | We include a `log_softmax()` function that is equivalent to `torch.nn.functional.log_softmax()`. 74 | We strongly advise users to use this implementation rather than implementing their own softmax operation, as numerical instability can easily arise with a naive implementation. 75 | 76 | #### Nonnegative matrix multiplication 77 | We include `matmul_nneg()` function that handles matrix multiplication between two non-negative matrices, as this is simpler than the general case. 
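To make the bound propagation behind these functions concrete, here is a minimal, self-contained sketch of the interval arithmetic for monotonic activations and non-negative matrix multiplication. It uses plain `torch.Tensor` bounds and illustrative names (`SimpleIntervalTensor`, `monotone_activation`, `matmul_nonneg`); it is not the actual implementation in `src/ibp.py`.

```
import torch

class SimpleIntervalTensor:
    """Toy stand-in for an interval-bounded tensor: a value plus
    coordinate-wise lower and upper bounds (all plain torch.Tensors)."""
    def __init__(self, val, lb, ub):
        assert bool(torch.all(lb <= val)) and bool(torch.all(val <= ub))
        self.val, self.lb, self.ub = val, lb, ub

def monotone_activation(fn, x):
    # IBP rule for an elementwise, monotonically increasing fn
    # (relu, sigmoid, tanh, exp): apply fn to the value and to both bounds.
    return SimpleIntervalTensor(fn(x.val), fn(x.lb), fn(x.ub))

def matmul_nonneg(a, b):
    # IBP rule for matrix multiplication when every entry (including the
    # bounds) is non-negative: each output entry is a sum of products that is
    # monotone in every input, so the extremes come from multiplying lower
    # bounds with lower bounds and upper bounds with upper bounds.
    return SimpleIntervalTensor(a.val @ b.val, a.lb @ b.lb, a.ub @ b.ub)

# Tiny usage example: perturb x by +/- 0.1 and push the interval through.
x = torch.rand(2, 3)
bounded_x = SimpleIntervalTensor(x, x - 0.1, x + 0.1)
hidden = monotone_activation(torch.relu, bounded_x)          # bounds stay valid
weights = torch.rand(3, 4)                                   # non-negative matrix
bounded_w = SimpleIntervalTensor(weights, weights, weights)  # exact, zero-width interval
out = matmul_nonneg(hidden, bounded_w)
print(out.lb <= out.val, out.val <= out.ub)                  # all True
```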
78 | 79 | ### Layers (`nn.Module` objects) 80 | Many basic layers are implemented by extending their `torch.nn` counterparts, including 81 | * `Linear` 82 | * `Embedding` 83 | * `Conv1d` 84 | * `MaxPool1d` 85 | * `LSTM` 86 | * `Dropout` 87 | 88 | #### RNNs 89 | Our library also includes `LSTM` and `GRU` classes, which extend `nn.Module` directly. 90 | These are unfortunately slower than their `torch.nn` counterparts, 91 | because the `torch.nn` RNN's use cuDNN. 92 | 93 | ## Examples 94 | If you want to see this library in action, a good place to start is `BOWModel` in `src/text_classification.py`. This implements a simple bag-of-words model for text classification. 95 | Note that in `forward()`, we accept a flag called `compute_bounds` which lets the user decide whether to run IBP or not. 96 | 97 | # Paper experiments 98 | In this repository, we include a minimal set of commands and instructions to reproduce a few key results from our EMNLP 2019 paper. 99 | We will focus on the CNN model results on the IMDB dataset. 100 | To see other available command line flags, you can run `python src/train.py -h`. 101 | 102 | If you are interested in reproducing our experiments, we recommend looking at the 103 | aforementioned [Codalab worksheet](https://worksheets.codalab.org/worksheets/0x79feda5f1998497db75422eca8fcd689), which shows how to reproduce all results in our paper. 104 | Note that the commands on Codalab include some extra flags (`--neighbor-file`, `--glove-dir`, `--imdb-dir`, and `--snli-dir`) that are used to specify non-default paths to files. 105 | These flags are unnecessary when following the instructions in this repository. 106 | 107 | ## Training 108 | Here are commands to train the CNN model on IMDB with standard training, certifiably robust training, and data augmentation. 109 | 110 | **Standard training** 111 | 112 | To train the baseline model without IBP, run the following: 113 | ``` 114 | python src/train.py classification cnn outdir_cnn_normal -d 100 --pool mean -T 10 --dropout-prob 0.2 -b 32 --save-best-only 115 | ``` 116 | 117 | This should get about 88% accuracy on dev (but 0% certified accuracy). 118 | `outdir_cnn_normal` is an output directory where model parameters and stats will be saved. 119 | 120 | **Certifiably robust training** 121 | 122 | To use certifiably robust training with IBP, run the following: 123 | ``` 124 | python src/train.py classification cnn outdir_cnn_cert -d 100 --pool mean -T 60 --full-train-epochs 20 -c 0.8 --dropout-prob 0.2 -b 32 --save-best-only 125 | ``` 126 | 127 | This should get about 81% accuracy and 66% certified accuracy on dev. 128 | Note that these results do not include language model constraints on the attack surface, 129 | and therefore the certified accuracy is a bit too low. 130 | These constraints will be enforced in the testing commands below. 131 | 132 | **Training with data augmentation** 133 | 134 | To train with data augmentation, run the following: 135 | ``` 136 | python src/train.py classification cnn outdir_cnn_aug -d 100 --pool mean -T 60 --augment-by 4 --dropout-prob 0.2 -b 32 --save-best-only 137 | ``` 138 | 139 | This should get about 85% accuracy and 84% augmented accuracy on dev (but 0% certified accuracy). 140 | 141 | ## Testing 142 | Next, we will show how to test the trained models using the genetic attack. 143 | The genetic attack heuristically searches for a perturbation that causes an error. 
144 | In this phase, we also incorporate pre-computed language model scores that determine which perturbations are valid. 145 | 146 | For example, let's say we want to use the trained model inside the `outdir_cnn_cert` directory. 147 | First, we choose a checkpoint based on the best certified accuracy on the dev set, say checkpoint 57. 148 | (Note: the training code with `--save-best-only` will save only the best model and the final model; 149 | stats on all checkpoints are logged in `/all_epoch_stats.json`.) 150 | 151 | This command will run the genetic attack: 152 | ``` 153 | python src/train.py classification cnn eval_cnn_cert -L outdir_cnn_cert --load-ckpt 57 -d 100 --pool mean -T 0 -b 1 -a genetic --adv-num-epochs 40 --adv-pop-size 60 --use-lm --downsample-to 1000 154 | ``` 155 | It should get about 80% standard accuracy, 72.5% certified accuracy, 156 | and 73% adversarial accuracy (i.e., accuracy against the genetic attack). 157 | For all models, you should find that adversarial accuracy is between standard accuracy and certified accuracy. 158 | For IMDB, we downsample to 1000 examples, as the genetic attack is pretty slow; 159 | the provided precomputed LM scores (in `lm_scores`) are only for the first 1000 examples in the train, development, and test sets. 160 | For SNLI, we use the entire development and test sets for evaluation. 161 | 162 | **Note:** This code is sensitive to the version of NLTK you use. 163 | The LM prediction files provided here should work if you are using the current version of NLTK and have updated your `nltk_data` directory recently. 164 | The experiments on Codalab use an older NLTK version; 165 | you can download the LM files from Codalab if you need compatibility with older NLTK versions. 166 | NLTK version issues will result in a `KeyError` with an `Unrecognized sentence` message. 167 | 168 | ## Running the language model yourself 169 | If you want to precompute language model scores on other data, use the following instructions. 170 | 171 | 1. Clone the following git repository: 172 | 173 | ``` 174 | git clone https://github.com/robinjia/l2w windweller-l2w 175 | ``` 176 | 177 | 2. Obtain pre-trained parameters and put them in a 178 | directory named `l2w-params` within that repository. 179 | Please contact us if you need a copy of the parameters. 180 | 181 | 3. Adapt `src/precompute_lm_scores.py` for your dataset. 
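If you write your own scoring script, note how `LMConstrainedAttackSurface.from_files` in `src/attacks.py` parses the score file: it splits each line on tabs, treats a two-field line as a new sentence header (the second field is the sentence), and otherwise reads the word position, candidate word, and log-probability score from the second, third, and fourth fields. Candidates scoring more than 5 log-probability points below the original word are discarded by default. The snippet below is an illustrative example of that tab-separated layout (the first column and all scores are made up; only the fields the parser reads matter):
```
0	the movie was great
q	0	the	-1.23
q	0	teh	-9.87
q	1	movie	-0.45
q	1	film	-0.52
```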
182 | -------------------------------------------------------------------------------- /download_deps.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p data 3 | cd data 4 | 5 | # LM scores 6 | wget https://nlp.stanford.edu/data/robinjia/jia2019_cert_lm_scores.zip 7 | unzip jia2019_cert_lm_scores.zip 8 | rm jia2019_cert_lm_scores.zip 9 | 10 | # GloVe 11 | mkdir glove 12 | cd glove 13 | wget http://nlp.stanford.edu/data/glove.840B.300d.zip 14 | unzip glove.840B.300d.zip 15 | rm glove.840B.300d.zip 16 | cd - 17 | 18 | # Counterfitted vectors 19 | wget https://github.com/nmrksic/counter-fitting/raw/master/word_vectors/counter-fitted-vectors.txt.zip 20 | unzip counter-fitted-vectors.txt.zip 21 | rm counter-fitted-vectors.txt.zip 22 | 23 | # IMDB 24 | wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz 25 | tar -xvzf aclImdb_v1.tar.gz 26 | rm -f aclImdb_v1.tar.gz 27 | cp splits/imdb_train_files.txt aclImdb/train 28 | cp splits/imdb_dev_files.txt aclImdb/train 29 | cp splits/imdb_test_files.txt aclImdb/test 30 | 31 | # SNLI 32 | mkdir -p snli 33 | cd snli 34 | wget https://nlp.stanford.edu/projects/snli/snli_1.0.zip 35 | unzip snli_1.0.zip 36 | mv snli_1.0/*.jsonl . 37 | rm -r snli_1.0.zip snli_1.0 __MACOSX 38 | -------------------------------------------------------------------------------- /scripts/analyze_wordvec.py: -------------------------------------------------------------------------------- 1 | """Analyze word vectors.""" 2 | import argparse 3 | import os 4 | from scipy.stats import special_ortho_group 5 | import sys 6 | import torch 7 | from tqdm import tqdm 8 | 9 | sys.path.append(os.path.join( 10 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src')) 11 | import data_util 12 | import text_classification 13 | import vocabulary 14 | 15 | OPTS = None 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('model_dir') 20 | parser.add_argument('checkpoint', type=int) 21 | parser.add_argument('--data_cache_dir') 22 | parser.add_argument('--epsilon', '-e', type=float, default=1.0) 23 | parser.add_argument('--dimension', '-d', type=int, default=100) 24 | parser.add_argument('--num-matrices', '-T', type=int, default=20) 25 | if len(sys.argv) == 1: 26 | parser.print_help() 27 | sys.exit(1) 28 | return parser.parse_args() 29 | 30 | def log_box_volume(vecs): 31 | lb, _ = torch.min(vecs, dim=0) 32 | ub, _ = torch.max(vecs, dim=0) 33 | return torch.sum(torch.log(ub - lb)) 34 | 35 | def find_within_box(mat, indices, vocab, epsilon=1.0): 36 | cur_vecs = torch.stack([mat[i,:] for i in indices]) # n, d 37 | lb = torch.min(cur_vecs, dim=0)[0] 38 | ub = torch.max(cur_vecs, dim=0)[0] 39 | within_box = (torch.min(lb <= mat, dim=1)[0] & torch.min(mat <= ub, dim=1)[0]).nonzero() 40 | return sorted([vocab.get_word(i.item()) for i in within_box]) 41 | 42 | def measure_size(mat, indices, stdevs): 43 | cur_vecs = torch.stack([mat[i,:] for i in indices]) # n, d 44 | lb = torch.min(cur_vecs, dim=0)[0] 45 | ub = torch.max(cur_vecs, dim=0)[0] 46 | num_stdevs = (ub - lb) / (stdevs + 1e-7) 47 | return torch.mean(num_stdevs) 48 | 49 | def main(): 50 | OPTS.use_toy_data = False 51 | OPTS.use_lm = False 52 | OPTS.neighbor_file = data_util.NEIGHBOR_FILE 53 | OPTS.imdb_dir = text_classification.IMDB_DIR 54 | OPTS.test = False 55 | OPTS.glove_dir = vocabulary.GLOVE_DIR 56 | OPTS.glove = '840B.300d' 57 | OPTS.downsample_to = None 58 | OPTS.downsample_shard = 0 59 | OPTS.truncate_to = None 60 | 
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 61 | print('Loading model', file=sys.stderr) 62 | model_state = dict(torch.load(os.path.join(OPTS.model_dir, 'model-checkpoint-%d.pth' % OPTS.checkpoint))) 63 | input_layer = model_state['linear_input.weight'] 64 | print('Loading data', file=sys.stderr) 65 | train_data, dev_data, glove_mat, attack_surface = text_classification.load_datasets(device, OPTS) 66 | learned_mat = torch.matmul(glove_mat, torch.t(input_layer)) 67 | vocab = train_data.vocab 68 | neighbors = attack_surface.neighbors 69 | words = vocab.word_list 70 | glove_scale = torch.std(glove_mat, dim=0) 71 | learned_scale = torch.std(learned_mat, dim=0) 72 | learned_smaller = 0 73 | total = 0 74 | for w in tqdm(words): 75 | if w not in neighbors: continue 76 | if not neighbors[w]: continue 77 | print('Word: %s' % w) 78 | cur_words = [w] + neighbors[w] 79 | indices = [vocab.get_index(x) for x in cur_words] 80 | glove_contained = find_within_box(glove_mat, indices, vocab, OPTS.epsilon) 81 | learned_contained = find_within_box(learned_mat, indices, vocab, OPTS.epsilon) 82 | print(' Neighborhood is %d words: [%s]' % (len(cur_words), ', '.join(sorted(cur_words)))) 83 | print(' GloVe contains %d words: [%s]' % (len(glove_contained), ', '.join(glove_contained))) 84 | print(' Learned contains %d words: [%s]' % (len(learned_contained), ', '.join(learned_contained))) 85 | glove_size = measure_size(glove_mat, indices, glove_scale) 86 | learned_size = measure_size(learned_mat, indices, learned_scale) 87 | print(' GloVe size: %.2f' % glove_size) 88 | print(' Learned size: %.2f' % learned_size) 89 | print() 90 | if learned_size < glove_size: 91 | learned_smaller += 1 92 | total += 1 93 | print(' So far: Learned smaller on %d/%d = %.2f%% words' % ( 94 | learned_smaller, total, 100 * learned_smaller / total)) 95 | 96 | 97 | if __name__ == '__main__': 98 | OPTS = parse_args() 99 | main() 100 | 101 | -------------------------------------------------------------------------------- /scripts/compute_p_values.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from scipy import stats 3 | import sys 4 | 5 | OPTS = None 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('scores_1') 10 | parser.add_argument('scores_2') 11 | if len(sys.argv) == 1: 12 | parser.print_help() 13 | sys.exit(1) 14 | return parser.parse_args() 15 | 16 | def read_scores(fn): 17 | with open(fn) as f: 18 | return [float(x) for x in f] 19 | 20 | def main(): 21 | x1 = read_scores(OPTS.scores_1) 22 | x2 = read_scores(OPTS.scores_2) 23 | print('Avg1 = %d/%d = %.2f%%' % (sum(x1), len(x1), 100 * sum(x1) / len(x1))) 24 | print('Avg2 = %d/%d = %.2f%%' % (sum(x2), len(x2), 100 * sum(x2) / len(x2))) 25 | s_ttest_unpaired, p_ttest_unpaired = stats.ttest_ind(x1, x2) 26 | print('Unpaired t-test: p=%.2e' % p_ttest_unpaired) 27 | s_ttest_paired, p_ttest_paired = stats.ttest_rel(x1, x2) 28 | print('Paired t-test: p=%.2e' % p_ttest_paired) 29 | s_mwu, p_mwu = stats.mannwhitneyu(x1, x2) 30 | print('Mann-Whitney U test (unpaired): p=%.2e' % p_mwu) 31 | s_wilcoxon, p_wilcoxon = stats.wilcoxon(x1, x2) 32 | print('Wilcoxon signed-rank test (paired): p=%.2e' % p_wilcoxon) 33 | 34 | if __name__ == '__main__': 35 | OPTS = parse_args() 36 | main() 37 | 38 | -------------------------------------------------------------------------------- /scripts/error_analysis.py: 
-------------------------------------------------------------------------------- 1 | """Error analysis on adversary successes.""" 2 | import argparse 3 | import collections 4 | import math 5 | import re 6 | import sys 7 | import termcolor 8 | 9 | OPTS = None 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('in_file') 14 | parser.add_argument('--out-score-file', help='Print 0-1 loss on adversary (for p-values)') 15 | parser.add_argument('--quiet', '-q', action='store_true') 16 | parser.add_argument('--snli', action='store_true') 17 | if len(sys.argv) == 1: 18 | parser.print_help() 19 | sys.exit(1) 20 | return parser.parse_args() 21 | 22 | def count_and_frac(x, f): 23 | count = sum(1 for e in x if f(e)) 24 | return count, len(x), 100 * count / len(x) 25 | 26 | def main(): 27 | bound_pattern = re.compile(r'.*Logit bounds: (.*) <= (.*) <= (.*), cert_correct=(.*)') 28 | success_pattern = re.compile(r'ADVERSARY SUCCESS on \("(.*)", ([^)]*)\): Found "(.*)" with margin (.*)') 29 | orig_wrong_pattern = re.compile(r'ORIGINAL PREDICTION WAS WRONG') 30 | failure_pattern = re.compile(r'ADVERSARY FAILURE .*') 31 | bounds = None 32 | num_ex = 1 33 | diffs = [] 34 | lens = [] 35 | diff_percs = [] 36 | probs = [] 37 | adv_scores = [] 38 | with open(OPTS.in_file) as f: 39 | for line in f: 40 | m_bound = re.match(bound_pattern, line.strip()) 41 | m_success = re.match(success_pattern, line.strip()) 42 | m_orig = re.match(orig_wrong_pattern, line.strip()) 43 | m_failure = re.match(failure_pattern, line.strip()) 44 | if line.startswith('ADVERSARY SUCCESS') and not m_success: 45 | print(line) 46 | if m_bound: 47 | cur_bounds = (float(m_bound.group(1)), float(m_bound.group(2)), 48 | float(m_bound.group(3))) 49 | elif m_success: 50 | orig_toks = m_success.group(1).split(' ') 51 | if OPTS.snli: 52 | y = m_success.group(2) 53 | else: 54 | y = int(m_success.group(2)) 55 | perturbed_toks = m_success.group(3).split(' ') 56 | margin = float(m_success.group(4)) 57 | orig_colored = [termcolor.colored(w1, 'cyan') if w1 != w2 else w1 58 | for w1, w2 in zip(orig_toks, perturbed_toks)] 59 | perturbed_colored = [termcolor.colored(w2, 'red') if w1 != w2 else w2 60 | for w1, w2 in zip(orig_toks, perturbed_toks)] 61 | num_diff = sum(1 for w1, w2 in zip(orig_toks, perturbed_toks) if w1 != w2) 62 | diff_perc = 100.0 * num_diff / len(orig_toks) 63 | if not OPTS.snli: 64 | orig_prob = 1 / (1 + math.exp(-(2 * y - 1) * cur_bounds[1])) 65 | if not OPTS.quiet: 66 | print('Case %d' % num_ex) 67 | print(' x_orig: %s' % ' '.join(orig_colored)) 68 | print(' x_pert: %s' % ' '.join(perturbed_colored)) 69 | print(' y : %d' % y) 70 | print(' diff : %d' % num_diff) 71 | print(' len : %d' % len(orig_toks)) 72 | print(' diff %%: %.2f%%' % diff_perc) 73 | if not OPTS.snli: 74 | print(' orig prob : %.6f' % orig_prob) 75 | print(' orig logits: %.6f <= %.6f <= %.6f' % cur_bounds) 76 | print(' new margin : %.2f' % margin) 77 | print() 78 | num_ex += 1 79 | diffs.append(num_diff) 80 | lens.append(len(orig_toks)) 81 | diff_percs.append(diff_perc) 82 | if not OPTS.snli: 83 | probs.append(orig_prob) 84 | adv_scores.append(0) 85 | elif m_orig: 86 | adv_scores.append(0) 87 | elif m_failure: 88 | adv_scores.append(1) 89 | print('Adversarial accuracy: %d/%d = %.2f%%' % ( 90 | sum(adv_scores), len(adv_scores), 100 * sum(adv_scores) / len(adv_scores))) 91 | print('Overall averages') 92 | print(' diff : %.2f' % (sum(diffs) / len(diffs))) 93 | print(' len : %.2f' % (sum(lens) / len(lens))) 94 | print(' diff %%: %.2f%%' % 
(sum(diff_percs) / len(diff_percs))) 95 | if not OPTS.snli: 96 | print(' probs : %.6f' % (sum(probs) / len(probs))) 97 | print(' p > 0.9: %d/%d = %.2f%%' % count_and_frac(probs, lambda p: p > 0.9)) 98 | print(' p > 0.8: %d/%d = %.2f%%' % count_and_frac(probs, lambda p: p > 0.8)) 99 | print(' p > 0.7: %d/%d = %.2f%%' % count_and_frac(probs, lambda p: p > 0.7)) 100 | print(' p > 0.6: %d/%d = %.2f%%' % count_and_frac(probs, lambda p: p > 0.6)) 101 | print(' diff <= 3 : %d/%d = %.2f%%' % count_and_frac(diffs, lambda d: d <= 3)) 102 | print(' diff >= 10: %d/%d = %.2f%%' % count_and_frac(diffs, lambda d: d >= 10)) 103 | print(' diff histogram:') 104 | diff_histo = collections.Counter(diffs) 105 | for k in sorted(diff_histo): 106 | print(' %02d: %d' % (k, diff_histo[k])) 107 | print(' histogram in list form: [%s]' % [diff_histo[i] for i in range(1, 31)]) 108 | if OPTS.out_score_file: 109 | with open(OPTS.out_score_file, 'w') as f: 110 | for s in adv_scores: 111 | print(s, file=f) 112 | 113 | if __name__ == '__main__': 114 | OPTS = parse_args() 115 | main() 116 | 117 | -------------------------------------------------------------------------------- /scripts/find_neighbors_connected_components.py: -------------------------------------------------------------------------------- 1 | """Find connected components of word vector neighbors.""" 2 | import collections 3 | import json 4 | import sys 5 | 6 | def main(): 7 | with open('data/counterfitted_neighbors.json') as f: 8 | neighbors = json.load(f) 9 | print('Read %d words' % len(neighbors)) 10 | for w in list(neighbors): 11 | if not neighbors[w]: 12 | del neighbors[w] 13 | words = list(neighbors) 14 | # Directed edges -> Undirected edges 15 | edges = collections.defaultdict(set) 16 | for w in words: 17 | for v in neighbors[w]: 18 | edges[w].add(v) 19 | edges[v].add(w) 20 | # DFS 21 | visited = set() 22 | comps = [] 23 | for start_word in words: 24 | if start_word in visited: continue 25 | cur_comp = set() 26 | stack = [start_word] 27 | while stack: 28 | w = stack.pop() 29 | if w in visited: continue 30 | visited.add(w) 31 | cur_comp.add(w) 32 | for v in edges[w]: 33 | if v not in visited: 34 | stack.append(v) 35 | comps.append(cur_comp) 36 | print('Found %d components' % len(comps)) 37 | print('Sum of component sizes is %d' % sum(len(c) for c in comps)) 38 | comps.sort(key=lambda x: len(x), reverse=True) 39 | print('Largest components: %s' % ', '.join(str(len(c)) for c in comps[:10])) 40 | for w in ('good', 'bad', 'awesome', 'excellent', 'terrible', 'una'): 41 | print('%s: visited==%s, in biggest component==%s' % (w, w in visited, w in comps[0])) 42 | 43 | if __name__ == '__main__': 44 | main() 45 | 46 | -------------------------------------------------------------------------------- /scripts/latex_errors.py: -------------------------------------------------------------------------------- 1 | """Print errors in LaTeX format.""" 2 | import argparse 3 | import random 4 | import re 5 | import sys 6 | 7 | OPTS = None 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('error_analysis_file') 12 | parser.add_argument('name') 13 | if len(sys.argv) == 1: 14 | parser.print_help() 15 | sys.exit(1) 16 | return parser.parse_args() 17 | 18 | def main(): 19 | random.seed(0) 20 | examples = [] 21 | cur_lines = [] 22 | with open(OPTS.error_analysis_file) as f: 23 | for line in f: 24 | line = line.strip() 25 | if line.startswith('x_orig') or line.startswith('x_pert'): 26 | line = re.sub(r'\x1B\[3[16]m', r'\\textbf{', line) 27 | 
line = re.sub(r'\x1B\[0m', r'}', line) 28 | line = re.sub('&', r'\\&', line) 29 | line = re.sub('x_orig', 'Original', line) 30 | line = re.sub('x_pert', 'Perturbed', line) 31 | cur_lines.append(r'\fbox{ \begin{minipage}{\textwidth}') 32 | cur_lines.append(line) 33 | cur_lines.append(r'\end{minipage} }') 34 | elif line.startswith('y : '): 35 | num = int(line[len('y : '):]) 36 | label = 'positive' if num else 'negative' 37 | cur_lines.append(r'Correct label: %s. \\' % label) 38 | elif line.startswith('orig prob'): 39 | prob = float(line[len('orig prob : '):]) 40 | cur_lines.append('Model confidence on original example: %.1f.' % (prob * 100)) 41 | cur_lines.append('\\end{figure*}') 42 | examples.append(cur_lines) 43 | cur_lines = [] 44 | random.shuffle(examples) 45 | for i, ex in enumerate(examples): 46 | print(r'\begin{figure*}') 47 | print(r'%s, example %d \\' % (OPTS.name, i + 1)) 48 | for line in ex: 49 | print(line) 50 | print('') 51 | 52 | if __name__ == '__main__': 53 | OPTS = parse_args() 54 | main() 55 | 56 | -------------------------------------------------------------------------------- /scripts/split_imdb.py: -------------------------------------------------------------------------------- 1 | """Split IMDB based on movies.""" 2 | import argparse 3 | import collections 4 | import glob 5 | import os 6 | import random 7 | import sys 8 | 9 | OPTS = None 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('imdb_dir') 14 | parser.add_argument('out_prefix') 15 | parser.add_argument('-s', '--rng_seed', type=int, default=314159) 16 | parser.add_argument('-f', '--train-frac', type=float, default=0.8) 17 | if len(sys.argv) == 1: 18 | parser.print_help() 19 | sys.exit(1) 20 | return parser.parse_args() 21 | 22 | def read_urls(label): 23 | with open(os.path.join(OPTS.imdb_dir, 'train', 'urls_%s.txt' % label)) as f: 24 | return [line.strip() for line in f] 25 | 26 | def read_files(label): 27 | files = glob.glob(os.path.join(OPTS.imdb_dir, 'train', label, '*.txt')) 28 | files.sort(key=lambda x: int(os.path.basename(x).split('_')[0])) 29 | return files 30 | 31 | def add_data(urls, files, url_to_file): 32 | for u, f in zip(urls, files): 33 | url_to_file[u].append(f) 34 | 35 | def write_data(split, files): 36 | out_fn = os.path.join(OPTS.imdb_dir, 'train', '%s_%s_files.txt' % (OPTS.out_prefix, split)) 37 | with open(out_fn, 'w') as f: 38 | for fn in files: 39 | prefix_len = len(os.path.join(OPTS.imdb_dir, 'train')) + 1 40 | print(fn[prefix_len:], file=f) 41 | 42 | def main(): 43 | random.seed(OPTS.rng_seed) 44 | pos_urls = read_urls('pos') 45 | neg_urls = read_urls('neg') 46 | pos_files = read_files('pos') 47 | neg_files = read_files('neg') 48 | url_to_file = collections.defaultdict(list) 49 | add_data(pos_urls, pos_files, url_to_file) 50 | add_data(neg_urls, neg_files, url_to_file) 51 | urls = sorted(list(url_to_file)) 52 | random.shuffle(urls) 53 | total_files = sum(len(x) for x in url_to_file.values()) 54 | print('Found %d urls, %d files' % (len(urls), total_files)) 55 | num_train = int(OPTS.train_frac * total_files) 56 | train_files = [] 57 | dev_files = [] 58 | for url in urls: 59 | files = url_to_file[url] 60 | if len(train_files) + len(files) <= num_train: 61 | train_files.extend(url_to_file[url]) 62 | else: 63 | dev_files.extend(url_to_file[url]) 64 | write_data('train', train_files) 65 | write_data('dev', dev_files) 66 | 67 | if __name__ == '__main__': 68 | OPTS = parse_args() 69 | main() 70 | 71 | 
-------------------------------------------------------------------------------- /src/attacks.py: -------------------------------------------------------------------------------- 1 | """Defines an attack surface.""" 2 | import collections 3 | import json 4 | import sys 5 | 6 | OPTS = None 7 | 8 | DEFAULT_MAX_LOG_P_DIFF = -5.0 # Maximum difference in log p for swaps. 9 | 10 | class AttackSurface(object): 11 | def get_swaps(self, words): 12 | """Return valid substitutions for each position in input |words|.""" 13 | raise NotImplementedError 14 | 15 | class WordSubstitutionAttackSurface(AttackSurface): 16 | def __init__(self, neighbors): 17 | self.neighbors = neighbors 18 | 19 | @classmethod 20 | def from_file(cls, neighbors_file): 21 | with open(neighbors_file) as f: 22 | return cls(json.load(f)) 23 | 24 | def get_swaps(self, words): 25 | swaps = [] 26 | for i in range(len(words)): 27 | if words[i] in self.neighbors: 28 | swaps.append(self.neighbors[words[i]]) 29 | else: 30 | swaps.append([]) 31 | return swaps 32 | 33 | class LMConstrainedAttackSurface(AttackSurface): 34 | """WordSubstitutionAttackSurface with language model constraint.""" 35 | def __init__(self, neighbors, lm_scores, min_log_p_diff=DEFAULT_MAX_LOG_P_DIFF): 36 | self.neighbors = neighbors 37 | self.lm_scores = lm_scores 38 | self.min_log_p_diff = min_log_p_diff 39 | 40 | @classmethod 41 | def from_files(cls, neighbors_file, lm_file): 42 | with open(neighbors_file) as f: 43 | neighbors = json.load(f) 44 | with open(lm_file) as f: 45 | lm_scores = {} 46 | cur_sent = None 47 | for line in f: 48 | toks = line.strip().split('\t') 49 | if len(toks) == 2: 50 | cur_sent = toks[1].lower() 51 | lm_scores[cur_sent] = collections.defaultdict(dict) 52 | else: 53 | word_idx, word, score = int(toks[1]), toks[2], float(toks[3]) 54 | lm_scores[cur_sent][word_idx][word] = score 55 | return cls(neighbors, lm_scores) 56 | 57 | def get_swaps(self, words): 58 | swaps = [] 59 | words = [word.lower() for word in words] 60 | s = ' '.join(words) 61 | if s not in self.lm_scores: 62 | raise KeyError('Unrecognized sentence "%s"' % s) 63 | for i in range(len(words)): 64 | if i in self.lm_scores[s]: 65 | cur_swaps = [] 66 | orig_score = self.lm_scores[s][i][words[i]] 67 | for swap, score in self.lm_scores[s][i].items(): 68 | if swap == words[i]: continue 69 | if swap not in self.neighbors[words[i]]: continue 70 | if score - orig_score >= self.min_log_p_diff: 71 | cur_swaps.append(swap) 72 | swaps.append(cur_swaps) 73 | else: 74 | swaps.append([]) 75 | return swaps 76 | -------------------------------------------------------------------------------- /src/data_util.py: -------------------------------------------------------------------------------- 1 | """Data handler classes and methods""" 2 | 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader, BatchSampler 5 | import torch.nn.functional as F 6 | import random 7 | 8 | 9 | NEIGHBOR_FILE = 'data/counterfitted_neighbors.json' 10 | 11 | 12 | def dict_batch_to_device(batch, device): 13 | """ 14 | Moves a batch of data to device 15 | Args: 16 | - batch: Can be a Torch tensor or a dict where the values are torch tensors 17 | - device: A Torch device to move all the tensors to 18 | Returns: 19 | - a batch of the same type as input batch but on the device 20 | If a dict, also a dict with same keys 21 | """ 22 | try: 23 | return batch.to(device) 24 | except AttributeError: 25 | # don't have a to function, must be a dict, recursively move to device 26 | return {k: dict_batch_to_device(v, device) 
for k, v in batch.items()} 27 | 28 | 29 | class RawDataset(Dataset): 30 | """ 31 | Dataset that only holds unprocessed text values 32 | Subsequent tasks should implement the get_word_set method for vocab picking 33 | """ 34 | def __init__(self, train_data, dev_data): 35 | self.train_data = train_data 36 | self.dev_data = dev_data 37 | self.data = train_data + dev_data 38 | 39 | def __getitem__(self, index): 40 | return self.data[index] 41 | 42 | def __len__(self): 43 | return len(self.data) 44 | 45 | def get_word_set(self, neighbors): 46 | """ 47 | Returns all the words found in this dataset 48 | and all their neighbors in a set 49 | """ 50 | raise NotImplementedError 51 | 52 | 53 | class ProcessedDataset(Dataset): 54 | """ 55 | Dataset that holds processed examples 56 | Subsequent tasks should implement the methods defined below 57 | to interface with the rest of the training module 58 | """ 59 | def __init__(self, raw_data, vocab, examples): 60 | self.raw_data = raw_data 61 | self.vocab = vocab 62 | self.examples = examples 63 | 64 | @classmethod 65 | def from_raw_data(self, raw_data, vocab, neighbors=None, truncate_to=None): 66 | """ 67 | Given a RawDataset of examples, a vocab set and potentially dict of neigbors, 68 | initializes the dataset such that self.examples[i] corresponds to the ith example 69 | that can be passed to the relevant model 70 | """ 71 | raise NotImplementedError 72 | 73 | @classmethod 74 | def get_raw_data(cls, *args, **kwargs): 75 | """ 76 | Method that returns a RawClassificationDataset object 77 | that holds the entire dataset. This would later be passed to __init__ 78 | """ 79 | raise NotImplementedError 80 | 81 | @staticmethod 82 | def collate_examples(examples): 83 | """ 84 | Method that takes a list of examples from self.examples and collates them into a single example 85 | with batched tensors 86 | """ 87 | raise NotImplementedError 88 | 89 | @staticmethod 90 | def example_len(example): 91 | """ 92 | Given an example returns the length of its principal sequence(s). Used for pooling samples of similar 93 | lengths into same batches to reduce the amount of padding necessary 94 | """ 95 | raise NotImplementedError 96 | 97 | def __len__(self): 98 | return len(self.examples) 99 | 100 | def __getitem__(self, index): 101 | return self.examples[index] 102 | 103 | def get_loader(self, batch_size): 104 | batch_sampler = PooledBatchSampler(self, batch_size, sort_key=self.example_len) 105 | return DataLoader(self, pin_memory=True, collate_fn=self.collate_examples, batch_sampler=batch_sampler) 106 | 107 | 108 | def multi_dim_padded_cat(tensors, dim, padding_value=0): 109 | """ 110 | Concatenates tensors along dim, padding elements to the largest length at 111 | each dimension. 
Assumes all tensors have the same dimensionality but makes no 112 | other assumptions about their size 113 | """ 114 | if dim == 0: 115 | original_ordering = dim_first_ordering = list(range(len(tensors[0].shape))) 116 | else: 117 | # If dim is not 0, we make it so for ease later and re-permute at the end 118 | dims = list(range(len(tensors[0].shape))) 119 | dims.pop(dim) 120 | dim_first_ordering = [dim] + dims 121 | original_ordering = [] 122 | for dim_idx in range(len(dim_first_ordering)): 123 | if dim_idx < dim: 124 | original_ordering.append(dim_idx + 1) 125 | elif dim_idx == dim: 126 | original_ordering.append(0) 127 | else: 128 | original_ordering.append(dim_idx) 129 | out_shape = [] 130 | for in_dim in dim_first_ordering: 131 | out_shape.append(max(tensor.shape[in_dim] for tensor in tensors)) 132 | out_shape[0] = sum(tensor.shape[dim] for tensor in tensors) 133 | out_tensor = tensors[0].new_empty(out_shape) 134 | cur_idx = 0 135 | for tensor in tensors: 136 | out_shape[0] = tensor.shape[dim] 137 | pad = [] 138 | # see torch.nn.functional.pad documentation for why we need this format 139 | for tensor_dim, out_dim in list(zip(tensor.shape, out_shape))[:0:-1]: 140 | pad = pad + [0, out_dim - tensor_dim] 141 | out_tensor[cur_idx:cur_idx+out_shape[0], ...] = F.pad(tensor.permute(*dim_first_ordering), pad) 142 | cur_idx += out_shape[0] 143 | if dim != 0: 144 | out_tensor = out_tensor.permute(*original_ordering) 145 | return out_tensor 146 | 147 | 148 | class PooledBatchSampler(BatchSampler): 149 | def __init__(self, dataset, batch_size, sort_within_batch=True, sort_key=len): 150 | self.dataset_lens = [sort_key(el) for el in dataset] 151 | self.batch_size = batch_size 152 | self.sort_within_batch = sort_within_batch 153 | 154 | def __iter__(self): 155 | """ 156 | 1- Partitions data indices into chunks of batch_size * 100 157 | 2- Sorts each chunk by the sort_key 158 | 3- Batches sorted chunks sequentially 159 | 4- Shuffles the batches 160 | 5- Yields each batch 161 | """ 162 | idx_chunks = torch.split(torch.randperm(len(self.dataset_lens)), self.batch_size * 100) 163 | for idx_chunk in idx_chunks: 164 | sorted_chunk = torch.tensor(sorted(idx_chunk.tolist(), key=lambda idx: self.dataset_lens[idx])) 165 | chunk_batches = [chunk.tolist() for chunk in torch.split(sorted_chunk, self.batch_size)] 166 | random.shuffle(chunk_batches) 167 | for batch in chunk_batches: 168 | if self.sort_within_batch: 169 | batch.reverse() 170 | yield batch 171 | 172 | def __len__(self): 173 | return (len(self.dataset_lens) + self.batch_size - 1) // self.batch_size 174 | 175 | 176 | class DataAugmenter(object): 177 | def __init__(self, augment_by): 178 | self.augment_by = augment_by 179 | 180 | def augment(self, dataset): 181 | raise NotImplementedError 182 | -------------------------------------------------------------------------------- /src/entailment.py: -------------------------------------------------------------------------------- 1 | """IBP textual entailment model.""" 2 | 3 | from enum import Enum 4 | import glob 5 | import itertools 6 | import json 7 | import os 8 | import pickle 9 | import random 10 | 11 | from nltk import word_tokenize 12 | import numpy as np 13 | import torch 14 | from torch import nn 15 | import torch.nn.functional as F 16 | from torch.nn.utils.rnn import pad_sequence 17 | from tqdm import tqdm 18 | 19 | import attacks 20 | import data_util 21 | import ibp 22 | import vocabulary 23 | 24 | 25 | SNLI_DIR = 'data/snli' 26 | LM_FILE = 'data/lm_scores/snli_all.txt' 27 | LOSS_FUNC = 
nn.BCEWithLogitsLoss() 28 | 29 | 30 | class AdversarialModel(nn.Module): 31 | def __init__(self): 32 | super(AdversarialModel, self).__init__() 33 | 34 | def query(self, dataset, device, batch_size=1, return_bounds=False): 35 | """Query the model on a Dataset. 36 | 37 | Args: 38 | dataset: a Dataset. 39 | device: torch device. 40 | neighbors: if provided, pass this to Dataset(). 41 | batch_size: batch size (default=1). 42 | 43 | Returns: Tensor of logits & gold labels 44 | """ 45 | data = dataset.get_loader(batch_size) 46 | output = [] 47 | gold = [] 48 | with torch.no_grad(): 49 | for batch in data: 50 | batch = data_util.dict_batch_to_device(batch, device) 51 | output.append(self.forward(batch, compute_bounds=return_bounds)) 52 | gold.append(batch['y']) 53 | return ibp.cat(output, dim=0), ibp.cat(gold, dim=0) 54 | 55 | 56 | class EntailmentLabels(Enum): 57 | contradiction = 0 58 | neutral = 1 59 | entailment = 2 60 | 61 | 62 | class BOWModel(AdversarialModel): 63 | """Bag of words + MLP""" 64 | 65 | def __init__(self, word_vec_size, hidden_size, word_mat, 66 | dropout_prob=0.1, num_layers=3, no_wordvec_layer=False): 67 | super(BOWModel, self).__init__() 68 | self.embs = ibp.Embedding.from_pretrained(word_mat) 69 | self.rotation = ibp.Linear(word_vec_size, hidden_size) 70 | self.sum_drop = ibp.Dropout(dropout_prob) if dropout_prob else None 71 | layers = [] 72 | for i in range(num_layers): 73 | layers.append(ibp.Linear(2*hidden_size, 2*hidden_size)) 74 | layers.append(ibp.Activation(F.relu)) 75 | if dropout_prob: 76 | layers.append(ibp.Dropout(dropout_prob)) 77 | layers.append(ibp.Linear(2*hidden_size, len(EntailmentLabels))) 78 | layers.append(ibp.LogSoftmax(dim=1)) 79 | self.layers = nn.Sequential(*layers) 80 | 81 | def forward(self, batch, compute_bounds=True, cert_eps=1.0): 82 | """ 83 | Forward pass of BOWModel. 84 | Args: 85 | batch: A batch dict from an EntailmentDataset with the following keys: 86 | - prem: tensor of word vector indices for premise (B, p, 1) 87 | - hypo: tensor of word vector indices for hypothesis (B, h, 1) 88 | - prem_mask: binary mask over premise words (1 for real, 0 for pad), size (B, p) 89 | - hypo_mask: binary mask over hypothesis words (1 for real, 0 for pad), size (B, h) 90 | - prem_lengths: lengths of premises, size (B,) 91 | - hypo_lengths: lengths of hypotheses, size (B,) 92 | compute_bounds: If True compute the interval bounds and reutrn an IntervalBoundedTensor as logits. Otherwise just use the values 93 | cert_eps: float, scaling factor for the interval bounds. 
94 | """ 95 | def encode(sequence, mask): 96 | vecs = self.embs(sequence) 97 | vecs = self.rotation(vecs) 98 | if isinstance(vecs, ibp.DiscreteChoiceTensor): 99 | vecs = vecs.to_interval_bounded(eps=cert_eps) 100 | z1 = ibp.activation(F.relu, vecs) 101 | z1_masked = z1 * mask.unsqueeze(-1) 102 | z1_pooled = ibp.sum(z1_masked, -2) 103 | return z1_pooled 104 | if not compute_bounds: 105 | batch['prem']['x'] = batch['prem']['x'].val 106 | batch['hypo']['x'] = batch['hypo']['x'].val 107 | prem_encoded = encode(batch['prem']['x'], batch['prem']['mask']) 108 | hypo_encoded = encode(batch['hypo']['x'], batch['hypo']['mask']) 109 | input_encoded = ibp.cat([prem_encoded, hypo_encoded], -1) 110 | logits = self.layers(input_encoded) 111 | return logits 112 | 113 | 114 | class DecompAttentionModel(AdversarialModel): 115 | """Decomposable Attention model from Parikh et al""" 116 | 117 | def __init__(self, word_vec_size, hidden_size, word_mat, 118 | dropout_prob=0.1, num_layers=2, no_wordvec_layer=False): 119 | super(DecompAttentionModel, self).__init__() 120 | self.embs = ibp.Embedding.from_pretrained(word_mat) 121 | self.null = nn.Parameter(torch.normal(mean=torch.zeros(word_vec_size))) 122 | self.rotation = None 123 | hidden_size = word_vec_size 124 | self.rotation = ibp.Linear(word_vec_size, hidden_size) 125 | 126 | def get_feedforward_layers(num_layers, input_size, hidden_size, output_size): 127 | layers = [] 128 | for i in range(num_layers): 129 | layer_in_size = input_size if i == 0 else hidden_size 130 | layer_out_size = output_size if i == num_layers - 1 else hidden_size 131 | if dropout_prob: 132 | layers.append(ibp.Dropout(dropout_prob)) 133 | layers.append(ibp.Linear(layer_in_size, layer_out_size)) 134 | if i < num_layers - 1: 135 | layers.append(ibp.Activation(F.relu)) 136 | return layers 137 | 138 | ff_layers = get_feedforward_layers(num_layers, hidden_size, hidden_size, 1) 139 | self.feedforward = nn.Sequential(*ff_layers) 140 | 141 | compare_layers = get_feedforward_layers(num_layers, 2 * hidden_size, hidden_size, hidden_size) 142 | self.compare_ff = nn.Sequential(*compare_layers) 143 | 144 | output_layers = get_feedforward_layers(num_layers, 2 * hidden_size, hidden_size, hidden_size) 145 | output_layers.append(ibp.Linear(hidden_size, len(EntailmentLabels))) 146 | output_layers.append(ibp.LogSoftmax(dim=1)) 147 | self.output_layer = nn.Sequential(*output_layers) 148 | 149 | def forward(self, batch, compute_bounds=True, cert_eps=1.0): 150 | """ 151 | Forward pass of DecompAttentionModel. 152 | Args: 153 | batch: A batch dict from an EntailmentDataset with the following keys: 154 | - prem: tensor of word vector indices for premise (B, p, 1) 155 | - hypo: tensor of word vector indices for hypothesis (B, h, 1) 156 | - prem_mask: binary mask over premise words (1 for real, 0 for pad), size (B, p) 157 | - hypo_mask: binary mask over hypothesis words (1 for real, 0 for pad), size (B, h) 158 | - prem_lengths: lengths of premises, size (B,) 159 | - hypo_lengths: lengths of hypotheses, size (B,) 160 | compute_bounds: If True compute the interval bounds and reutrn an IntervalBoundedTensor as logits. Otherwise just use the values 161 | cert_eps: float, scaling factor for the interval bounds. 
162 | """ 163 | def encode(sequence, mask): 164 | vecs = self.embs(sequence) 165 | if isinstance(vecs, ibp.DiscreteChoiceTensor): 166 | null = torch.zeros_like(vecs.val[0]) 167 | null_choice = torch.zeros_like(vecs.choice_mat[0]) 168 | null[0] = self.null 169 | null_choice[0, 0] = self.null 170 | vecs.val = vecs.val + null 171 | vecs.choice_mat = vecs.choice_mat + null_choice 172 | else: 173 | null = torch.zeros_like(vecs[0]) 174 | null[0] = self.null 175 | vecs = vecs + null 176 | vecs = self.rotation(vecs) 177 | if isinstance(vecs, ibp.DiscreteChoiceTensor): 178 | vecs = vecs.to_interval_bounded(eps=cert_eps) 179 | return ibp.activation(F.relu, vecs) * mask.unsqueeze(-1) 180 | 181 | if not compute_bounds: 182 | batch['prem']['x'] = batch['prem']['x'].val 183 | batch['hypo']['x'] = batch['hypo']['x'].val 184 | prem_encoded = encode(batch['prem']['x'], batch['prem']['mask']) # (bXpXe) 185 | hypo_encoded = encode(batch['hypo']['x'], batch['hypo']['mask']) # (bXhXe) 186 | prem_weights = self.feedforward(prem_encoded) * batch['prem']['mask'].unsqueeze(-1) # (bXpX1) 187 | hypo_weights = self.feedforward(hypo_encoded) * batch['hypo']['mask'].unsqueeze(-1) # (bXhX1) 188 | attention = ibp.bmm(prem_weights, hypo_weights.permute(0,2,1)) # (bXpX1) X (bX1Xh) => (bXpXh) 189 | attention_mask = batch['prem']['mask'].unsqueeze(-1) * batch['hypo']['mask'].unsqueeze(1) 190 | attention_masked = ibp.add(attention, (1 - attention_mask) * -1e20) 191 | attended_prem = self.attend_on(hypo_encoded, prem_encoded, attention_masked) # (bXpX2e) 192 | attended_hypo = self.attend_on(prem_encoded, hypo_encoded, attention_masked.permute(0,2,1)) # (bXhX2e) 193 | compared_prem = self.compare_ff(attended_prem) * batch['prem']['mask'].unsqueeze(-1) # (bXpXhid) 194 | compared_hypo = self.compare_ff(attended_hypo) * batch['hypo']['mask'].unsqueeze(-1) # (bXhXhid) 195 | prem_aggregate = ibp.pool(torch.sum, compared_prem, dim=1) # (bXhid) 196 | hypo_aggregate = ibp.pool(torch.sum, compared_hypo, dim=1) # (bXhid) 197 | aggregate = ibp.cat([prem_aggregate, hypo_aggregate], dim=-1) # (bX2hid) 198 | return self.output_layer(aggregate) # (b) 199 | 200 | def attend_on(self, source, target, attention): 201 | """ 202 | Args: 203 | - source: (bXsXe) 204 | - target: (bXtXe) 205 | - attention: (bXtXs) 206 | """ 207 | attention_logsoftmax = ibp.log_softmax(attention, 1) 208 | attention_normalized = ibp.activation(torch.exp, attention_logsoftmax) 209 | attended_target = ibp.matmul_nneg(attention_normalized, source) # (bXtXe) 210 | return ibp.cat([target, attended_target], dim=-1) 211 | 212 | def load_model(word_mat, device, opts): 213 | """ 214 | Try to load a model on the device given the word_mat and opts. 215 | Tries to load a model from the given or latest checkpoint if specified in the opts. 216 | Otherwise instantiates a new model on the device. 
217 | """ 218 | vec_size = vocabulary.GLOVE_CONFIGS[opts.glove]['size'] 219 | if opts.model == 'bow': 220 | model = BOWModel( 221 | vec_size, vec_size, word_mat, num_layers=opts.num_layers, dropout_prob=opts.dropout_prob, no_wordvec_layer=opts.no_wordvec_layer).to(device) 222 | if opts.model == 'decomp-attn': 223 | model = DecompAttentionModel( 224 | vec_size, opts.hidden_size, word_mat, dropout_prob=opts.dropout_prob, num_layers=opts.num_layers, no_wordvec_layer=opts.no_wordvec_layer).to(device) 225 | if opts.load_dir: 226 | try: 227 | if opts.load_ckpt is None: 228 | load_fn = sorted(glob.glob(os.path.join(opts.load_dir, 'model-checkpoint-[0-9]+.pth')))[-1] 229 | else: 230 | load_fn = os.path.join(opts.load_dir, 'model-checkpoint-%d.pth' % opts.load_ckpt) 231 | print('Loading model from %s.' % load_fn) 232 | # Cache the word vectors before loading to avoid size mismatches 233 | state_dict = dict(torch.load(load_fn)) 234 | if opts.prepend_null: 235 | null_vec = state_dict['embs.weight'][vocabulary.NULL_INDEX] 236 | model.embs.weight[vocabulary.NULL_INDEX] = null_vec 237 | state_dict['embs.weight'] = model.embs.weight 238 | model.load_state_dict(state_dict, strict=False) 239 | print('Finished loading model.') 240 | except Exception as ex: 241 | print("Couldn't load model, starting anew: {}".format(ex)) 242 | return model 243 | 244 | 245 | def load_datasets(device, opts): 246 | """ 247 | Loads entailment datasets given opts on the device and returns the dataset. 248 | If a data cache is specified in opts and the cached data there is of the same class 249 | as the one specified in opts, uses the cache. Otherwise reads from the raw dataset 250 | files specified in OPTS. 251 | Returns: 252 | - train_data: EntailmentDataset - Processed training dataset 253 | - dev_data: Optional[EntailmentDataset] - Processed dev dataset if raw dev data was found or 254 | dev_frac was specified in opts 255 | - word_mat: torch.Tensor 256 | - attack_surface: AttackSurface - defines the adversarial attack surface 257 | """ 258 | data_class = ToyEntailmentDataset if opts.use_toy_data else SNLIDataset 259 | try: 260 | if opts.adv_only: 261 | train_data = None 262 | train_attack_surface = None 263 | else: 264 | with open(os.path.join(opts.data_cache_dir, 'train_data.pkl'), 'rb') as infile: 265 | train_data = pickle.load(infile) 266 | if not isinstance(train_data, data_class): 267 | raise Exception("Cached dataset of wrong class: {}".format(type(train_data))) 268 | with open(os.path.join(opts.data_cache_dir, 'train_attack_surface.pkl'), 'rb') as infile: 269 | train_attack_surface = pickle.load(infile) 270 | with open(os.path.join(opts.data_cache_dir, 'dev_data.pkl'), 'rb') as infile: 271 | dev_data = pickle.load(infile) 272 | if not isinstance(dev_data, data_class): 273 | raise Exception("Cached dataset of wrong class: {}".format(type(train_data))) 274 | with open(os.path.join(opts.data_cache_dir, 'word_mat.pkl'), 'rb') as infile: 275 | word_mat = pickle.load(infile) 276 | with open(os.path.join(opts.data_cache_dir, 'dev_attack_surface.pkl'), 'rb') as infile: 277 | dev_attack_surface = pickle.load(infile) 278 | print("Loaded data from {}.".format(opts.data_cache_dir)) 279 | except Exception as ex: 280 | print('Couldn\'t load data from cache: {}, reading from raw files'.format(ex)) 281 | if opts.adv_only: 282 | train_data = None 283 | train_attack_surface = None 284 | else: 285 | train_attack_surface = attacks.WordSubstitutionAttackSurface.from_file(opts.neighbor_file) 286 | if opts.use_lm: 287 | dev_attack_surface = 
attacks.LMConstrainedAttackSurface.from_files( 288 | opts.neighbor_file, opts.snli_lm_file) 289 | else: 290 | dev_attack_surface = attacks.WordSubstitutionAttackSurface.from_file(opts.neighbor_file) 291 | raw_data = data_class.get_raw_data(opts) 292 | word_set = raw_data.get_word_set(train_attack_surface, dev_attack_surface=dev_attack_surface) 293 | vocab, word_mat = vocabulary.Vocabulary.read_word_vecs(word_set, opts.glove_dir, opts.glove, device, prepend_null=opts.prepend_null, normalize=opts.normalize_word_vecs) 294 | if not opts.adv_only: 295 | train_data = data_class.from_raw_data(raw_data.train_data, vocab, attack_surface=train_attack_surface, 296 | downsample_to=opts.downsample_to, prepend_null=opts.prepend_null, use_tqdm=True, downsample_shard=opts.downsample_shard) 297 | dev_data = data_class.from_raw_data(raw_data.dev_data, vocab, attack_surface=dev_attack_surface, 298 | downsample_to=opts.downsample_to, prepend_null=opts.prepend_null, use_tqdm=True, downsample_shard=opts.downsample_shard) 299 | if opts.data_cache_dir: 300 | if not opts.adv_only: 301 | with open(os.path.join(opts.data_cache_dir, 'train_data.pkl'), 'wb') as outfile: 302 | pickle.dump(train_data, outfile) 303 | with open(os.path.join(opts.data_cache_dir, 'train_attack_surface.pkl'), 'wb') as outfile: 304 | pickle.dump(train_attack_surface, outfile) 305 | with open(os.path.join(opts.data_cache_dir, 'dev_data.pkl'), 'wb') as outfile: 306 | pickle.dump(dev_data, outfile) 307 | with open(os.path.join(opts.data_cache_dir, 'word_mat.pkl'), 'wb') as outfile: 308 | pickle.dump(word_mat, outfile) 309 | with open(os.path.join(opts.data_cache_dir, 'dev_attack_surface.pkl'), 'wb') as outfile: 310 | pickle.dump(dev_attack_surface, outfile) 311 | return train_data, dev_data, word_mat, dev_attack_surface 312 | 313 | 314 | def get_margins(model_output, gold_labels): 315 | if isinstance(model_output, ibp.IntervalBoundedTensor): 316 | logits = model_output.val 317 | w_true_class_pred = (model_output.lb * gold_labels).sum(dim=1) 318 | w_highest_false_pred = (model_output.ub + (gold_labels * -1e20)).max(dim=1)[0] 319 | w_value_margin = w_true_class_pred - w_highest_false_pred 320 | else: 321 | logits = model_output 322 | w_value_margin = None 323 | true_class_pred = (logits * gold_labels).sum(dim=1) 324 | highest_false_pred = (logits + (gold_labels * -1e20)).max(dim=1)[0] 325 | value_margin = true_class_pred - highest_false_pred 326 | return value_margin, w_value_margin 327 | 328 | 329 | def compute_is_correct(model_output, gold_labels): 330 | if isinstance(model_output, ibp.IntervalBoundedTensor): 331 | logits = model_output.val 332 | # Worst case pred. 
is the LB of correct class 333 | # combined with the UBs of the other classes 334 | worst_case_pred = (gold_labels * model_output.lb + (1 - gold_labels) * model_output.ub).argmax(dim=1) 335 | gold_labels = gold_labels.argmax(dim=1) 336 | cert_correct = ((worst_case_pred - gold_labels) == 0) 337 | else: 338 | gold_labels = gold_labels.argmax(dim=1) 339 | logits = model_output 340 | cert_correct = None 341 | predictions = logits.argmax(dim=1) 342 | correct = ((predictions - gold_labels) == 0) 343 | return correct, cert_correct 344 | 345 | 346 | def num_correct(model_output, gold_labels): 347 | """ 348 | Given the output of model and gold labels returns number of correct and certified correct 349 | predictions 350 | Args: 351 | - model_output: output of the model, could be ibp.IntervalBoundedTensor or torch.Tensor 352 | - gold_labels: torch.Tensor 353 | Returns: 354 | - num_correct: int - number of correct predictions from the actual model output 355 | - num_cert_correct - number of bounds-certified correct predictions if the model_output was an 356 | IntervalBoundedTensor, 0 otherwise. 357 | """ 358 | is_correct, is_cert_correct = compute_is_correct(model_output, gold_labels) 359 | num_correct = is_correct.sum().item() 360 | if is_cert_correct is not None: 361 | num_cert_correct = is_cert_correct.sum().item() 362 | else: 363 | num_cert_correct = 0 364 | return num_correct, num_cert_correct 365 | 366 | 367 | class RawEntailmentDataset(data_util.RawDataset): 368 | """ 369 | Dataset that only holds (prem, hypo) ,y as ((str, str), str) tuples 370 | """ 371 | def get_word_set(self, train_attack_surface, dev_attack_surface=None): 372 | if dev_attack_surface is None: 373 | dev_attack_surface = train_attack_surface 374 | word_set = set() 375 | for x, y in self.train_data: 376 | prem, hypo = x 377 | prem_words = prem.split() 378 | hypo_words = hypo.split() 379 | words = prem_words + hypo_words 380 | for w in words: 381 | word_set.add(w) 382 | try: 383 | swaps = train_attack_surface.get_swaps(hypo_words) 384 | for cur_swaps in swaps: 385 | for w in cur_swaps: 386 | word_set.add(w) 387 | except KeyError: 388 | pass 389 | for x, y in self.dev_data: 390 | prem, hypo = x 391 | prem_words = prem.split() 392 | hypo_words = hypo.split() 393 | words = prem_words + hypo_words 394 | for w in words: 395 | word_set.add(w) 396 | try: 397 | swaps = dev_attack_surface.get_swaps(hypo_words) 398 | for cur_swaps in swaps: 399 | for w in cur_swaps: 400 | word_set.add(w) 401 | except KeyError: 402 | pass 403 | return word_set 404 | 405 | 406 | class EntailmentDataset(data_util.ProcessedDataset): 407 | """ 408 | Dataset that holds processed example dicts 409 | """ 410 | @classmethod 411 | def from_raw_data(cls, raw_data, vocab, attack_surface=None, truncate_to=None, downsample_to=None, prepend_null=False, use_tqdm=False, downsample_shard=None): 412 | if truncate_to: 413 | raise NotImplementedError # Probably never needed since SNLI sentences are so short 414 | if downsample_to: 415 | if downsample_shard is None: 416 | downsample_shard = 0 417 | raw_data = raw_data[downsample_shard * downsample_to:(downsample_shard + 1) * downsample_to] 418 | examples = [] 419 | iteration = tqdm(raw_data) if use_tqdm else raw_data 420 | for inpt, y in iteration: 421 | try: 422 | examples.append(cls.process_example(inpt, y, vocab, attack_surface, prepend_null=prepend_null)) 423 | except ValueError as err: 424 | print('Error processing example: {} Skipping.'.format(err)) 425 | return cls(raw_data, vocab, examples) 426 | 427 | @classmethod 
428 | def process_example(cls, inpt, y, vocab, attack_surface, skip_prem=True, prepend_null=False): 429 | example = {} 430 | for idx,sequence in enumerate(['prem', 'hypo']): 431 | x = inpt[idx] 432 | all_words = x.split() 433 | if prepend_null: 434 | all_words = [''] + all_words 435 | words = [w for w in all_words if w in vocab] # Delete UNK words 436 | word_idxs = [vocab.get_index(w) for w in words] 437 | if len(word_idxs) < 1: 438 | raise ValueError("Sequence:\n\t{}\n is all UNK words in sample:\n \t{}\n".format(x, inpt)) 439 | x_torch = torch.tensor(word_idxs).view(1, -1, 1) # (1, T, d) 440 | if attack_surface and not (skip_prem and sequence == 'prem'): 441 | swap_words = all_words[1:] if prepend_null else all_words # Don't try to swap NULL 442 | all_swaps = attack_surface.get_swaps(swap_words) 443 | if prepend_null: 444 | all_swaps = [[]] + all_swaps # Add an empty swaps list at index 0 for NULL 445 | swaps = [s for w, s in zip(all_words, all_swaps) if w in vocab] 446 | choices = [[w] + cur_swaps for w, cur_swaps in zip(words, swaps)] 447 | choices_word_idxs = [ 448 | torch.tensor([vocab.get_index(c) for c in c_list], dtype=torch.long) for c_list in choices 449 | ] 450 | if any(0 in choices.view(-1).tolist() for choices in choices_word_idxs): 451 | raise ValueError("UNK tokens found") 452 | choices_torch = pad_sequence(choices_word_idxs, batch_first=True).unsqueeze(2).unsqueeze(0) # (1, T, C, 1) 453 | choices_mask = (choices_torch.squeeze(-1) != 0).long() # (1, T, C) 454 | else: 455 | choices_torch = x_torch.view(1, -1, 1, 1) # (1, T, 1, 1) 456 | choices_mask = torch.ones_like(x_torch.view(1, -1, 1)) 457 | mask_torch = torch.ones((1, len(word_idxs))) 458 | x_bounded = ibp.DiscreteChoiceTensor(x_torch, choices_torch, choices_mask, mask_torch) 459 | lengths_torch = torch.tensor(len(word_idxs)).view(1) 460 | example[sequence] = dict(x=x_bounded, mask=mask_torch, lengths=lengths_torch) 461 | example['y'] = torch.zeros((1, len(EntailmentLabels)), dtype=torch.float) 462 | example['y'][0, y.value] = 1 463 | return example 464 | 465 | @staticmethod 466 | def example_len(example): 467 | """ 468 | Should sort by first hypothesis length then premise length 469 | """ 470 | return (example['hypo']['x'].shape[1], example['prem']['x'].shape[1]) 471 | 472 | @staticmethod 473 | def collate_examples(examples): 474 | """ 475 | Turns a list of examples into a workable batch: 476 | """ 477 | if len(examples) == 1: 478 | return examples[0] 479 | B = len(examples) 480 | 481 | max_prem_len = max(ex['prem']['x'].shape[1] for ex in examples) 482 | prem_vals = [] 483 | prem_choice_mats = [] 484 | prem_choice_masks = [] 485 | prem_lengths = torch.zeros((B, ), dtype=torch.long) 486 | prem_masks = torch.zeros((B, max_prem_len)) 487 | 488 | max_hypo_len = max(ex['hypo']['x'].shape[1] for ex in examples) 489 | hypo_vals = [] 490 | hypo_choice_mats = [] 491 | hypo_choice_masks = [] 492 | hypo_lengths = torch.zeros((B, ), dtype=torch.long) 493 | hypo_masks = torch.zeros((B, max_hypo_len)) 494 | 495 | gold_ys = [] 496 | for i, ex in enumerate(examples): 497 | prem_vals.append(ex['prem']['x'].val) 498 | prem_choice_mats.append(ex['prem']['x'].choice_mat) 499 | prem_choice_masks.append(ex['prem']['x'].choice_mask) 500 | cur_prem_len = ex['prem']['x'].shape[1] 501 | prem_masks[i, :cur_prem_len] = 1 502 | prem_lengths[i] = ex['prem']['lengths'][0] 503 | 504 | hypo_vals.append(ex['hypo']['x'].val) 505 | hypo_choice_mats.append(ex['hypo']['x'].choice_mat) 506 | hypo_choice_masks.append(ex['hypo']['x'].choice_mask) 507 | 
cur_hypo_len = ex['hypo']['x'].shape[1] 508 | hypo_masks[i, :cur_hypo_len] = 1 509 | hypo_lengths[i] = ex['hypo']['lengths'][0] 510 | 511 | gold_ys.append(ex['y']) 512 | prem_vals = data_util.multi_dim_padded_cat(prem_vals, 0).long() 513 | prem_choice_mats = data_util.multi_dim_padded_cat(prem_choice_mats, 0).long() 514 | prem_choice_masks = data_util.multi_dim_padded_cat(prem_choice_masks, 0).long() 515 | 516 | hypo_vals = data_util.multi_dim_padded_cat(hypo_vals, 0).long() 517 | hypo_choice_mats = data_util.multi_dim_padded_cat(hypo_choice_mats, 0).long() 518 | hypo_choice_masks = data_util.multi_dim_padded_cat(hypo_choice_masks, 0).long() 519 | 520 | y = torch.cat(gold_ys, 0) 521 | return { 522 | 'prem': { 523 | 'x': ibp.DiscreteChoiceTensor(prem_vals, prem_choice_mats, prem_choice_masks, prem_masks), 524 | 'mask': prem_masks, 'lengths': prem_lengths}, 525 | 'hypo': { 526 | 'x': ibp.DiscreteChoiceTensor(hypo_vals, hypo_choice_mats, hypo_choice_masks, hypo_masks), 527 | 'mask': hypo_masks, 'lengths': hypo_lengths}, 528 | 'y': y} 529 | 530 | 531 | class ToyEntailmentDataset(EntailmentDataset): 532 | """ 533 | Dataset that holds toy entailment data 534 | """ 535 | @classmethod 536 | def get_raw_data(cls, *args, **kwargs): 537 | data = [ 538 | (('man running', 'human moving'), EntailmentLabels.entailment), 539 | (('man running', 'man running'), EntailmentLabels.entailment), 540 | (('man running', 'human moving'), EntailmentLabels.entailment), 541 | ] 542 | return RawEntailmentDataset(data, data) 543 | 544 | 545 | class SNLIDataset(EntailmentDataset): 546 | """ 547 | Dataset that holds the SNLI sentiment classification data 548 | """ 549 | @classmethod 550 | def get_raw_data(cls, opts): 551 | splits = {} 552 | # Number of examples in each split for better tqdm 553 | totals = {'train': 550152, 'dev': 10000, 'test': 10000} 554 | if opts.test: 555 | split_names = ['train', 'test'] 556 | else: 557 | split_names = ['train', 'dev'] 558 | for split in split_names: 559 | if opts.adv_only and split == 'train': 560 | splits['train'] = [] 561 | continue 562 | data = [] 563 | fn = os.path.join(opts.snli_dir, 'snli_1.0_{}.jsonl'.format(split)) 564 | with open(fn) as f: 565 | for line in tqdm(f, total=totals[split]): 566 | example = json.loads(line) 567 | prem, hypo, gold_label = example['sentence1'], example['sentence2'], example['gold_label'] 568 | prem_tokenized = ' '.join(word_tokenize(prem)) 569 | hypo_tokenized = ' '.join(word_tokenize(hypo)) 570 | try: 571 | gold_label = EntailmentLabels[gold_label] 572 | except KeyError: 573 | # Encountered gold label '-', can't use so skip it 574 | continue 575 | data.append(((prem_tokenized, hypo_tokenized), gold_label)) 576 | random.shuffle(data) 577 | splits[split] = data 578 | return RawEntailmentDataset(*[splits[n] for n in split_names]) 579 | 580 | 581 | class DataAugmenter(data_util.DataAugmenter): 582 | def augment(self, dataset): 583 | new_examples = [] 584 | for ex in tqdm(dataset.examples): 585 | new_examples.append(ex) 586 | x_orig = ex['hypo']['x'] # (1, T, 1) 587 | choices = [] 588 | for i in range(x_orig.shape[1]): 589 | cur_choices = torch.masked_select( 590 | x_orig.choice_mat[0,i,:,0], x_orig.choice_mask[0,i,:].type(torch.uint8)) 591 | choices.append(cur_choices) 592 | for t in range(self.augment_by): 593 | x_new = torch.stack([choices[i][random.choice(range(len(choices[i])))] 594 | for i in range(len(choices))]).view(1, -1, 1) 595 | x_bounded = ibp.DiscreteChoiceTensor( 596 | x_new, x_orig.choice_mat, x_orig.choice_mask, x_orig.sequence_mask) 
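Aside (illustrative, not part of the repository): the augmentation loop above boils down to sampling one allowed substitution per hypothesis position. A minimal standalone sketch of that idea, with a hypothetical helper name (`augment_randomly`) and toy inputs:
```
import random

def augment_randomly(words, swaps, k=4, seed=0):
    """Return k randomly perturbed copies of `words`.

    swaps[i] lists the allowed substitutions for words[i]; the original
    word is always kept as one of the options, mirroring the choice
    lists built from the attack surface above.
    """
    rng = random.Random(seed)
    choices = [[w] + list(s) for w, s in zip(words, swaps)]
    return [[rng.choice(c) for c in choices] for _ in range(k)]

# Toy example: one synonym allowed at each of the two positions.
print(augment_randomly(['man', 'running'], [['human'], ['moving']], k=3))
```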
597 | ex_new = dict(ex) 598 | ex_new['hypo'] = dict(ex['hypo']) 599 | ex_new['hypo']['x'] = x_bounded 600 | new_examples.append(ex_new) 601 | return EntailmentDataset(None, dataset.vocab, new_examples) 602 | 603 | 604 | class Adversary(object): 605 | """An Adversary tries to fool a model on a given example.""" 606 | def __init__(self, attack_surface): 607 | self.attack_surface = attack_surface 608 | 609 | def run(self, model, dataset, device, opts=None): 610 | """Run adversary on a dataset. 611 | 612 | Args: 613 | model: a TextClassificationModel. 614 | dataset: a TextClassificationDataset. 615 | device: torch device. 616 | Returns: pair of 617 | - list of 0-1 adversarial loss of same length as |dataset| 618 | - list of list of adversarial examples (each is just a text string) 619 | """ 620 | raise NotImplementedError 621 | 622 | 623 | class ExhaustiveAdversary(Adversary): 624 | """An Adversary that exhaustively tries all allowed perturbations. 625 | 626 | Only practical for short sentences. 627 | """ 628 | def run(self, model, dataset, device, opts=None): 629 | model.eval() 630 | is_correct = [] 631 | adv_exs = [] 632 | for x, y in dataset.raw_data: 633 | prem, hypo = x 634 | words = hypo.split() 635 | swaps = self.attack_surface.get_swaps(words) 636 | choices = [[w] + cur_swaps for w, cur_swaps in zip(words, swaps)] 637 | prod = 1 638 | for c in choices: 639 | prod *= len(c) 640 | print('ExhaustiveAdversary: "%s" -> %d options' % (hypo, prod)) 641 | all_raw = [(prem, (' '.join(hypo_new)), y) for hypo_new in itertools.product(*choices)] 642 | cur_dataset = EntailmentDataset.from_raw_data(all_raw, dataset.vocab) 643 | preds, gold = model.query(cur_dataset, device) 644 | model_correct, model_cert_correct = compute_is_correct(preds, gold) 645 | cur_adv_exs = [all_raw[i][0] for i, p in enumerate(model_correct) 646 | if p.item()] 647 | print(cur_adv_exs) 648 | adv_exs.append(cur_adv_exs) 649 | is_correct.append(int(len(cur_adv_exs) > 0)) 650 | return is_correct, adv_exs 651 | 652 | 653 | class GreedyAdversary(Adversary): 654 | """An adversary that picks a random word and greedily tries perturbations.""" 655 | def __init__(self, attack_surface, num_epochs=10, num_tries=2, margin_goal=0.0): 656 | super(GreedyAdversary, self).__init__(attack_surface) 657 | self.num_epochs = num_epochs 658 | self.num_tries = num_tries 659 | self.margin_goal = margin_goal 660 | 661 | def run(self, model, dataset, device, opts=None): 662 | model.eval() 663 | is_correct = [] 664 | adv_exs = [] 665 | for x, y in tqdm(dataset.raw_data): 666 | prem, hypo = x 667 | # First query the example itself 668 | orig_pred, orig_gold = model.query(EntailmentDataset.from_raw_data( 669 | [(x, y)], dataset.vocab, attack_surface=self.attack_surface), device, return_bounds=True) 670 | model_correct, model_cert_correct = compute_is_correct(orig_pred, orig_gold) 671 | if model_correct.sum().item() == 0: 672 | print('ORIGINAL PREDICTION WAS WRONG') 673 | is_correct.append(0) 674 | adv_exs.append(x) 675 | continue 676 | 677 | # Now run adversarial search 678 | words = hypo.split() 679 | swaps = self.attack_surface.get_swaps(words) 680 | choices = [[w] + cur_swaps for w, cur_swaps in zip(words, swaps)] 681 | found = False 682 | for try_idx in range(self.num_tries): 683 | cur_words = list(words) 684 | for epoch in range(self.num_epochs): 685 | word_idxs = list(range(len(choices))) 686 | random.shuffle(word_idxs) 687 | for i in word_idxs: 688 | cur_raw = [] 689 | for w_new in choices[i]: 690 | cur_raw.append(((prem, ' '.join(cur_words[:i] + 
[w_new] + cur_words[i+1:])), y))
691 | cur_dataset = EntailmentDataset.from_raw_data(cur_raw, dataset.vocab)
692 | preds, gold = model.query(cur_dataset, device)
693 | margins, _ = get_margins(preds, gold)
694 | # Pick the substitution that minimizes the margin of the true class
695 | best_idx = min(enumerate(margins), key=lambda x: x[1])[0]
696 | cur_words[i] = choices[i][best_idx]
697 | if margins[best_idx] < self.margin_goal:
698 | found = True
699 | is_correct.append(0)
700 | adv_exs.append([' '.join(cur_words)])
701 | print('ADVERSARY SUCCESS on ("%s", %s): Found "%s" with margin %.2f' % (x, y, adv_exs[-1], margins[best_idx]))
702 | if model_cert_correct.sum().item() > 0:
703 | print('^^ CERT CORRECT THOUGH')
704 | break
705 | if found: break
706 | if found: break
707 | else:
708 | is_correct.append(1)
709 | adv_exs.append([])
710 | print('ADVERSARY FAILURE on ("%s", %s)' % (x, y))
711 | return is_correct, adv_exs
712 | 
713 | 
714 | class GeneticAdversary(Adversary):
715 | """An adversary that runs a genetic attack."""
716 | def __init__(self, attack_surface, num_iters=20, pop_size=60, margin_goal=0.0):
717 | super(GeneticAdversary, self).__init__(attack_surface)
718 | self.num_iters = num_iters
719 | self.pop_size = pop_size
720 | self.margin_goal = margin_goal
721 | 
722 | def perturb(self, prem, hypo, choices, model, y, vocab, device, prepend_null=False):
723 | if all(len(c) == 1 for c in choices):
724 | value_margin, _ = get_margins(*(model.query(EntailmentDataset.from_raw_data([((prem, ' '.join(hypo)), y)], vocab, prepend_null=prepend_null), device)))
725 | return hypo, value_margin.item()
726 | good_idxs = [i for i, c in enumerate(choices) if len(c) > 1]
727 | idx = random.sample(good_idxs, 1)[0]
728 | best_replacement = None
729 | worst_margin = float('inf')
730 | for w_new in choices[idx]:
731 | cur_raw = [((prem, ' '.join(hypo[:idx] + [w_new] + hypo[idx+1:])), y)]
732 | cur_dataset = EntailmentDataset.from_raw_data(cur_raw, vocab, prepend_null=prepend_null)
733 | model_output, gold_labels = model.query(cur_dataset, device)
734 | value_margins, worst_case_margins = get_margins(model_output, gold_labels)
735 | if best_replacement is None or value_margins[0].item() < worst_margin:
736 | best_replacement = w_new
737 | worst_margin = value_margins[0].item()
738 | cur_words = list(hypo)
739 | cur_words[idx] = best_replacement
740 | return cur_words, worst_margin
741 | 
742 | def run(self, model, dataset, device, opts=None):
743 | prepend_null = opts.prepend_null if opts is not None else False
744 | model.eval()
745 | is_correct = []
746 | adv_exs = []
747 | for x, y in tqdm(dataset.raw_data):
748 | # First query the example itself
749 | prem, hypo = x
750 | orig_pred, orig_gold = model.query(EntailmentDataset.from_raw_data(
751 | [(x, y)], dataset.vocab, attack_surface=self.attack_surface, prepend_null=prepend_null), device, return_bounds=True)
752 | model_correct, model_cert_correct = compute_is_correct(orig_pred, orig_gold)
753 | cert_correct = model_cert_correct.sum().item()
754 | value_margins, worst_case_margins = get_margins(orig_pred, orig_gold)
755 | print('Margin: %.6f, lower bound: %.6f, cert_correct=%s' % (
756 | value_margins[0].item(), worst_case_margins[0].item(), cert_correct))
757 | if model_correct.sum().item() <= 0:
758 | print('ORIGINAL PREDICTION WAS WRONG')
759 | is_correct.append(0)
760 | adv_exs.append(x)
761 | continue
762 | # Now run adversarial search
763 | hypo_words = hypo.split()
764 | swaps = self.attack_surface.get_swaps(hypo_words)
765 | choices = [[w] + cur_swaps for w, cur_swaps in
zip(hypo_words, swaps)] 766 | found = False 767 | population = [self.perturb(prem, hypo_words, choices, model, y, dataset.vocab, device, prepend_null=prepend_null) 768 | for i in range(self.pop_size)] 769 | for g in range(self.num_iters): 770 | best_idx = min(enumerate(population), key=lambda x: x[1][1])[0] 771 | print('Iteration %d: %.6f' % (g, population[best_idx][1])) 772 | if population[best_idx][1] < self.margin_goal: 773 | found = True 774 | is_correct.append(0) 775 | adv_exs.append(' '.join(population[best_idx][0])) 776 | print('ADVERSARY SUCCESS on ("%s", %s): Found "%s" with margin %.2f' % (x, y, adv_exs[-1], population[best_idx][1])) 777 | if cert_correct: 778 | print('^^ CERT CORRECT THOUGH') 779 | break 780 | new_population = [population[best_idx]] 781 | margins = np.array([m for c, m in population]) 782 | adv_probs = 1 / (1 + np.exp(margins)) + 1e-6 783 | # Sigmoid of negative margin, for probabilty of wrong class 784 | # Add 1e-6 for numerical stability 785 | sample_probs = adv_probs / np.sum(adv_probs) 786 | for i in range(1, self.pop_size): 787 | parent1 = population[np.random.choice(range(len(population)), p=sample_probs)][0] 788 | parent2 = population[np.random.choice(range(len(population)), p=sample_probs)][0] 789 | child = [random.sample([w1, w2], 1)[0] for (w1, w2) in zip(parent1, parent2)] 790 | child_mut, new_margin = self.perturb(prem, child, choices, model, y, 791 | dataset.vocab, device, prepend_null=prepend_null) 792 | new_population.append((child_mut, new_margin)) 793 | population = new_population 794 | else: 795 | is_correct.append(1) 796 | adv_exs.append([]) 797 | print('ADVERSARY FAILURE on ("%s", %s)' % (x, y)) 798 | return is_correct, adv_exs 799 | -------------------------------------------------------------------------------- /src/ibp.py: -------------------------------------------------------------------------------- 1 | """Interval bound propagation layers in pytorch.""" 2 | import sys 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn.parameter import Parameter 7 | import torch.nn.functional as F 8 | 9 | #DEBUG = False 10 | DEBUG = True 11 | TOLERANCE = 1e-5 12 | 13 | ##### BoundedTensor and subclasses ##### 14 | 15 | class BoundedTensor(object): 16 | """Contains a torch.Tensor plus bounds on it.""" 17 | @property 18 | def shape(self): 19 | return self.val.shape 20 | 21 | def __add__(self, other): 22 | return add(self, other) 23 | 24 | def __mul__(self, other): 25 | return mul(self, other) 26 | 27 | def __truediv__(self, other): 28 | return div(self, other) 29 | 30 | 31 | class IntervalBoundedTensor(BoundedTensor): 32 | """A tensor with elementwise upper and lower bounds. 33 | This is the main building BoundedTensor subclass. 34 | All layers in this library accept IntervalBoundedTensor as input, 35 | and when handed one will generate another IntervalBoundedTensor as output. 36 | """ 37 | def __init__(self, val, lb, ub): 38 | self.val = val 39 | self.lb = lb 40 | self.ub = ub 41 | if DEBUG: 42 | # Sanity check lower and upper bounds when creating this 43 | # Note that there may be small violations on the order of 1e-5, 44 | # due to floating point rounding/non-associativity issues. 45 | # e.g. 
https://github.com/pytorch/pytorch/issues/9146 46 | max_lb_violation = torch.max(lb - val) 47 | if max_lb_violation > TOLERANCE: 48 | print('WARNING: Lower bound wrong (max error = %g)' % max_lb_violation.item(), file=sys.stderr) 49 | max_ub_violation = torch.max(val - ub) 50 | if max_ub_violation > TOLERANCE: 51 | print('WARNING: Upper bound wrong (max error = %g)' % max_ub_violation.item(), file=sys.stderr) 52 | 53 | ### Reimplementations of torch.Tensor methods 54 | def __neg__(self): 55 | return IntervalBoundedTensor(-self.val, -self.ub, -self.lb) 56 | 57 | def permute(self, *dims): 58 | return IntervalBoundedTensor(self.val.permute(*dims), 59 | self.lb.permute(*dims), 60 | self.ub.permute(*dims)) 61 | 62 | def squeeze(self, dim=None): 63 | return IntervalBoundedTensor(self.val.squeeze(dim=dim), 64 | self.lb.squeeze(dim=dim), 65 | self.ub.squeeze(dim=dim)) 66 | 67 | def unsqueeze(self, dim): 68 | return IntervalBoundedTensor(self.val.unsqueeze(dim), 69 | self.lb.unsqueeze(dim), 70 | self.ub.unsqueeze(dim)) 71 | 72 | def to(self, device): 73 | self.val = self.val.to(device) 74 | self.lb = self.lb.to(device) 75 | self.ub = self.ub.to(device) 76 | return self 77 | 78 | # For slicing 79 | def __getitem__(self, key): 80 | return IntervalBoundedTensor(self.val.__getitem__(key), 81 | self.lb.__getitem__(key), 82 | self.ub.__getitem__(key)) 83 | 84 | def __setitem__(self, key, value): 85 | if not isinstance(value, IntervalBoundedTensor): 86 | raise TypeError(value) 87 | self.val.__setitem__(key, value.val) 88 | self.lb.__setitem__(key, value.lb) 89 | self.ub.__setitem__(key, value.ub) 90 | 91 | def __delitem__(self, key): 92 | self.val.__delitem__(key) 93 | self.lb.__delitem__(key) 94 | self.ub.__delitem__(key) 95 | 96 | 97 | class DiscreteChoiceTensor(BoundedTensor): 98 | """A tensor for which each row can take a discrete set of values. 99 | 100 | More specifically, each slice along the first d-1 dimensions of the tensor 101 | is allowed to take on values within some discrete set. 102 | The overall tensor's possible values are the direct product of all these 103 | individual choices. 104 | 105 | Recommended usage is only as an input tensor, passed to Linear() layer. 106 | Only some layers accept this tensor. 107 | """ 108 | def __init__(self, val, choice_mat, choice_mask, sequence_mask): 109 | """Create a DiscreteChoiceTensor. 110 | 111 | Args: 112 | val: value, dimension (*, d). Let m = product of first d-1 dimensions. 113 | choice_mat: all choices-padded with 0 where fewer than max choices are available, size (*, C, d) 114 | choice_mask: mask tensor s.t. choice_maks[i,j,k]==1 iff choice_mat[i,j,k,:] is a valid choice, size (*, C) 115 | sequence_mask: mask tensor s.t. sequence_mask[i,j,k]==1 iff choice_mat[i,j] is a valid word in a sequence and not padding, size (*) 116 | """ 117 | self.val = val 118 | self.choice_mat = choice_mat 119 | self.choice_mask = choice_mask 120 | self.sequence_mask = sequence_mask 121 | 122 | def to_interval_bounded(self, eps=1.0): 123 | """ 124 | Convert to an IntervalBoundedTensor. 
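A quick usage sketch (illustrative, not part of the repository) of constructing and manipulating an `IntervalBoundedTensor`, along the lines of `src/test.py`; it assumes the working directory is `src/` so that `import ibp` resolves:
```
import torch
import ibp  # assumes the working directory is src/

x = torch.tensor([[0.0, 1.0, 2.0]])
# Bounds must satisfy lb <= val <= ub (sanity-checked when DEBUG is on).
x_b = ibp.IntervalBoundedTensor(x, x - 0.1, x + 0.1)

y = -x_b              # negation flips the value and swaps the bounds
z = ibp.add(x_b, x)   # mixing a BoundedTensor with a plain torch.Tensor is fine
print(z.val, z.lb, z.ub)
print(x_b[0, :2].shape)  # slicing returns another IntervalBoundedTensor
```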
125 | Args: 126 | - eps: float, scaling factor for the interval bounds 127 | """ 128 | choice_mask_mat = (((1 - self.choice_mask).float() * 1e16)).unsqueeze(-1) # *, C, 1 129 | seq_mask_mat = self.sequence_mask.unsqueeze(-1).unsqueeze(-1).float() 130 | lb = torch.min((self.choice_mat + choice_mask_mat) * seq_mask_mat, -2)[0] # *, d 131 | ub = torch.max((self.choice_mat - choice_mask_mat) * seq_mask_mat, -2)[0] # *, d 132 | val = self.val * self.sequence_mask.unsqueeze(-1) 133 | if eps != 1.0: 134 | lb = val - (val - lb) * eps 135 | ub = val + (ub - val) * eps 136 | return IntervalBoundedTensor(val, lb, ub) 137 | 138 | def to(self, device): 139 | """Moves the Tensor to the given device""" 140 | self.val = self.val.to(device) 141 | self.choice_mat = self.choice_mat.to(device) 142 | self.choice_mask = self.choice_mask.to(device) 143 | self.sequence_mask = self.sequence_mask.to(device) 144 | return self 145 | 146 | 147 | class NormBallTensor(BoundedTensor): 148 | """A tensor for which each is within some norm-ball of the original value.""" 149 | def __init__(self, val, radius, p_norm): 150 | self.val = val 151 | self.radius = radius 152 | self.p_norm = p_norm 153 | 154 | 155 | ##### nn.Module's for BoundedTensor ##### 156 | 157 | class Linear(nn.Linear): 158 | """Linear layer.""" 159 | def forward(self, x): 160 | if isinstance(x, torch.Tensor): 161 | return super(Linear, self).forward(x) 162 | if isinstance(x, IntervalBoundedTensor): 163 | z = F.linear(x.val, self.weight, self.bias) 164 | weight_abs = torch.abs(self.weight) 165 | mu_cur = (x.ub + x.lb) / 2 166 | r_cur = (x.ub - x.lb) / 2 167 | mu_new = F.linear(mu_cur, self.weight, self.bias) 168 | r_new = F.linear(r_cur, weight_abs) 169 | return IntervalBoundedTensor(z, mu_new - r_new, mu_new + r_new) 170 | elif isinstance(x, DiscreteChoiceTensor): 171 | new_val = F.linear(x.val, self.weight, self.bias) 172 | new_choices = F.linear(x.choice_mat, self.weight, self.bias) 173 | return DiscreteChoiceTensor(new_val, new_choices, x.choice_mask, x.sequence_mask) 174 | elif isinstance(x, NormBallTensor): 175 | q = 1.0 / (1.0 - 1.0 / x.p_norm) # q from Holder's inequality 176 | z = F.linear(x.val, self.weight, self.bias) 177 | q_norm = torch.norm(self.weight, p=q, dim=1) # Norm along in_dims axis 178 | delta = x.radius * q_norm 179 | return IntervalBoundedTensor(z, z - delta, z + delta) # Broadcast out_dims 180 | else: 181 | raise TypeError(x) 182 | 183 | 184 | class LinearOutput(Linear): 185 | """Linear output layer. 186 | 187 | A linear layer, but instead of computing interval bounds, computes 188 | 189 | max_{z feasible} c^T z + d 190 | 191 | where z is the output of this layer, for given vector(s) c and scalar(s) d. 192 | Following Gowal et al. (2018), we can get a slightly better bound here 193 | than by doing normal bound propagation. 194 | """ 195 | def forward(self, x_ibp, c_list=None, d_list=None): 196 | """Compute linear output layer and bound on adversarial objective. 
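The interval arithmetic in `Linear.forward` above is the standard IBP rule for affine layers: push the interval midpoint through the weights and bias, and the interval radius through the absolute weights. A standalone check of that rule with plain torch (toy shapes, no repo code):
```
import torch

torch.manual_seed(0)
W, b = torch.randn(3, 2), torch.randn(3)
lb, ub = torch.tensor([[-1.0, 0.0]]), torch.tensor([[1.0, 0.5]])

mu, r = (ub + lb) / 2, (ub - lb) / 2
out_mu = mu @ W.t() + b
out_r = r @ W.abs().t()
out_lb, out_ub = out_mu - out_r, out_mu + out_r

# For an affine map the extremes are attained at corners of the input box,
# so every corner's output must land inside [out_lb, out_ub].
for cx in (lb[0, 0], ub[0, 0]):
    for cy in (lb[0, 1], ub[0, 1]):
        z = torch.tensor([[cx.item(), cy.item()]]) @ W.t() + b
        assert (z >= out_lb - 1e-6).all() and (z <= out_ub + 1e-6).all()
print('corner check passed')
```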
197 | 198 | Args: 199 | x_ibp: an ibp.Tensor of shape (batch_size, in_dims) 200 | c_list: list of torch.Tensor, each of shape (batch_size, out_dims) 201 | d_list: list of torch.Tensor, each of shape (batch_size,) 202 | Returns: 203 | x: ibp.Tensor of shape (batch_size, out_dims) 204 | bounds: if c_list and d_list, torch.Tensor of shape (batch_size,) 205 | """ 206 | x, x_lb, x_ub = x_ibp 207 | z = F.linear(x, self.weight, self.bias) 208 | if c_list and d_list: 209 | bounds = [] 210 | mu_cur = ((x_lb + x_ub) / 2).unsqueeze(1) # B, 1, in_dims 211 | r_cur = ((x_ub - x_lb) / 2).unsqueeze(1) # B, 1, in_dims 212 | for c, d in zip(c_list, d_list): 213 | c_prime = c.matmul(self.weight).unsqueeze(2) # B, in_dims, 1 214 | d_prime = c.matmul(self.bias) + d # B, 215 | c_prime_abs = torch.abs(c_prime) # B, in_dims, 1 216 | mu_new = mu_cur.matmul(c_prime).view(-1) # B, 217 | r_cur = r_cur.matmul(c_prime_abs).view(-1) # B, 218 | bounds.append(mu_new + r_cur + d) 219 | return z, bounds 220 | else: 221 | return z 222 | 223 | 224 | class Embedding(nn.Embedding): 225 | """nn.Embedding for DiscreteChoiceTensor. 226 | 227 | Note that unlike nn.Embedding, this module requires that the last dimension 228 | of the input is size 1, and will squeeze it before calling F.embedding. 229 | This requirement is due to how DiscreteChoiceTensor requires a dedicated 230 | dimension to represent the dimension along which values can change. 231 | """ 232 | def forward(self, x): 233 | if isinstance(x, torch.Tensor): 234 | return super(Embedding, self).forward(x.squeeze(-1)) 235 | if isinstance(x, DiscreteChoiceTensor): 236 | if x.val.shape[-1] != 1: 237 | raise ValueError('Input tensor has shape %s, where last dimension != 1' % x.shape) 238 | new_val = F.embedding( 239 | x.val.squeeze(-1), self.weight, self.padding_idx, self.max_norm, 240 | self.norm_type, self.scale_grad_by_freq, self.sparse) 241 | new_choices = F.embedding( 242 | x.choice_mat.squeeze(-1), self.weight, self.padding_idx, self.max_norm, 243 | self.norm_type, self.scale_grad_by_freq, self.sparse) 244 | return DiscreteChoiceTensor(new_val, new_choices, x.choice_mask, x.sequence_mask) 245 | else: 246 | raise TypeError(x) 247 | 248 | 249 | class Conv1d(nn.Conv1d): 250 | """One-dimensional convolutional layer. 251 | 252 | Works the same as a linear layer. 
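To make the `DiscreteChoiceTensor`/`Embedding` contract concrete, here is a small sketch (illustrative, not part of the repository) that embeds a two-token input where the second token has one allowed substitution; the toy vocabulary and vectors are made up, and it assumes running from `src/`:
```
import torch
import ibp  # assumes the working directory is src/

# Toy embedding matrix: index 0 is padding, indices 1-3 are real words.
word_mat = torch.tensor([[0.0, 0.0],
                         [1.0, 0.0],
                         [0.0, 1.0],
                         [1.0, 1.0]])
emb = ibp.Embedding.from_pretrained(word_mat)

val = torch.tensor([[[1], [2]]])                # (1, T=2, 1)
choice_mat = torch.tensor([[[[1], [0]],         # (1, T, C=2, 1); position 0 has no swaps
                            [[2], [3]]]])       # position 1 may also be word 3
choice_mask = torch.tensor([[[1, 0], [1, 1]]])  # (1, T, C)
seq_mask = torch.ones((1, 2))                   # (1, T)

x = ibp.DiscreteChoiceTensor(val, choice_mat, choice_mask, seq_mask)
vecs = emb(x)                         # still a DiscreteChoiceTensor, now of word vectors
bounded = vecs.to_interval_bounded()  # coordinate-wise min/max over the choices
print(bounded.val.shape, bounded.lb, bounded.ub)
```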
253 | """ 254 | def forward(self, x): 255 | if isinstance(x, torch.Tensor): 256 | return super(Conv1d, self).forward(x) 257 | if isinstance(x, IntervalBoundedTensor): 258 | z = F.conv1d(x.val, self.weight, self.bias, self.stride, 259 | self.padding, self.dilation, self.groups) 260 | weight_abs = torch.abs(self.weight) 261 | mu_cur = (x.ub + x.lb) / 2 262 | r_cur = (x.ub - x.lb) / 2 263 | mu_new = F.conv1d(mu_cur, self.weight, self.bias, self.stride, 264 | self.padding, self.dilation, self.groups) 265 | r_new = F.conv1d(r_cur, weight_abs, None, self.stride, 266 | self.padding, self.dilation, self.groups) 267 | return IntervalBoundedTensor(z, mu_new - r_new, mu_new + r_new) 268 | else: 269 | raise TypeError(x) 270 | 271 | 272 | class MaxPool1d(nn.MaxPool1d): 273 | """One-dimensional max-pooling layer.""" 274 | def forward(self, x): 275 | if isinstance(x, torch.Tensor): 276 | return super(MaxPool1d, self).forward(x) 277 | elif isinstance(x, IntervalBoundedTensor): 278 | z = F.max_pool1d(x.val, self.kernel_size, self.stride, self.padding, 279 | self.dilation, self.ceil_mode, self.return_indices) 280 | lb = F.max_pool1d(x.lb, self.kernel_size, self.stride, self.padding, 281 | self.dilation, self.ceil_mode, self.return_indices) 282 | ub = F.max_pool1d(x.ub, self.kernel_size, self.stride, self.padding, 283 | self.dilation, self.ceil_mode, self.return_indices) 284 | return IntervalBoundedTensor(z, lb, ub) 285 | else: 286 | raise TypeError(x) 287 | 288 | 289 | class LSTM(nn.Module): 290 | """An LSTM.""" 291 | def __init__(self, input_size, hidden_size, bidirectional=False): 292 | super(LSTM, self).__init__() 293 | self.input_size = input_size 294 | self.hidden_size = hidden_size 295 | self.bidirectional = bidirectional 296 | self.i2h = Linear(input_size, 4 * hidden_size) 297 | self.h2h = Linear(hidden_size, 4 * hidden_size) 298 | if bidirectional: 299 | self.back_i2h = Linear(input_size, 4 * hidden_size) 300 | self.back_h2h = Linear(hidden_size, 4 * hidden_size) 301 | 302 | def _step(self, h, c, x_t, i2h, h2h, analysis_mode=False): 303 | preact = add(i2h(x_t), h2h(h)) 304 | g_t = activation(torch.tanh, preact[:, 3 * self.hidden_size:]) 305 | gates = activation(torch.sigmoid, preact[:, :3 * self.hidden_size]) 306 | i_t = gates[:, :self.hidden_size] 307 | f_t = gates[:, self.hidden_size:2 * self.hidden_size] 308 | o_t = gates[:, 2 * self.hidden_size:] 309 | c_t = add(mul(c, f_t), mul(i_t, g_t)) 310 | h_t = mul(o_t, activation(torch.tanh, c_t)) 311 | if analysis_mode: 312 | return h_t, c_t, i_t, f_t, o_t 313 | return h_t, c_t 314 | 315 | def _process(self, h, c, x, i2h, h2h, reverse=False, mask=None, analysis_mode=False): 316 | B, T, d = x.shape # batch_first=True 317 | idxs = range(T) 318 | if reverse: 319 | idxs = idxs[::-1] 320 | h_seq = [] 321 | c_seq = [] 322 | if analysis_mode: 323 | i_seq = [] 324 | f_seq = [] 325 | o_seq = [] 326 | for i in idxs: 327 | x_t = x[:,i,:] # B, d_in 328 | if analysis_mode: 329 | h_t, c_t, i_t, f_t, o_t = self._step(h, c, x_t, i2h, h2h, analysis_mode=True) 330 | i_seq.append(i_t) 331 | f_seq.append(f_t) 332 | o_seq.append(o_t) 333 | else: 334 | h_t, c_t = self._step(h, c, x_t, i2h, h2h) 335 | if mask is not None: 336 | # Don't update h or c when mask is 0 337 | mask_t = mask[:,i].unsqueeze(1) # B,1 338 | h = h_t * mask_t + h * (1.0 - mask_t) 339 | c = c_t * mask_t + c * (1.0 - mask_t) 340 | h_seq.append(h) 341 | c_seq.append(c) 342 | if reverse: 343 | h_seq = h_seq[::-1] 344 | c_seq = c_seq[::-1] 345 | if analysis_mode: 346 | i_seq = i_seq[::-1] 347 | f_seq = 
f_seq[::-1] 348 | o_seq = o_seq[::-1] 349 | if analysis_mode: 350 | return h_seq, c_seq, i_seq, f_seq, o_seq 351 | return h_seq, c_seq 352 | 353 | def forward(self, x, s0, mask=None, analysis_mode=False): 354 | """Forward pass of LSTM 355 | 356 | Args: 357 | x: word vectors, size (B, T, d) 358 | s0: tuple of (h0, x0) where each is (B, d), or (B, 2d) if bidirectional=True 359 | mask: If provided, 0-1 mask of size (B, T) 360 | """ 361 | h0, c0 = s0 # Each is (B, d), or (B, 2d) if bidirectional=True 362 | if self.bidirectional: 363 | h0_back = h0[:,self.hidden_size:] 364 | h0 = h0[:,:self.hidden_size] 365 | c0_back = c0[:,self.hidden_size:] 366 | c0 = c0[:,:self.hidden_size] 367 | if analysis_mode: 368 | h_seq, c_seq, i_seq, f_seq, o_seq = self._process( 369 | h0, c0, x, self.i2h, self.h2h, mask=mask, analysis_mode=True) 370 | else: 371 | h_seq, c_seq = self._process(h0, c0, x, self.i2h, self.h2h, mask=mask) 372 | if self.bidirectional: 373 | if analysis_mode: 374 | h_back_seq, c_back_seq, i_back_seq, f_back_seq, o_back_seq = self._process( 375 | h0_back, c0_back, x, self.back_i2h, self.back_h2h, reverse=True, mask=mask, 376 | analysis_mode=True) 377 | i_seq = [cat((f, b), dim=1) for f, b in zip(i_seq, i_back_seq)] 378 | f_seq = [cat((f, b), dim=1) for f, b in zip(f_seq, f_back_seq)] 379 | o_seq = [cat((f, b), dim=1) for f, b in zip(o_seq, o_back_seq)] 380 | else: 381 | h_back_seq, c_back_seq = self._process( 382 | h0_back, c0_back, x, self.back_i2h, self.back_h2h, reverse=True, mask=mask) 383 | h_seq = [cat((hf, hb), dim=1) for hf, hb in zip(h_seq, h_back_seq)] 384 | c_seq = [cat((cf, cb), dim=1) for cf, cb in zip(c_seq, c_back_seq)] 385 | h_mat = stack(h_seq, dim=1) # list of (B, d) -> (B, T, d) 386 | c_mat = stack(c_seq, dim=1) # list of (B, d) -> (B, T, d) 387 | if analysis_mode: 388 | i_mat = stack(i_seq, dim=1) 389 | f_mat = stack(f_seq, dim=1) 390 | o_mat = stack(o_seq, dim=1) 391 | return h_mat, c_mat, (i_mat, f_mat, o_mat) 392 | return h_mat, c_mat 393 | 394 | 395 | class GRU(nn.Module): 396 | """A GRU.""" 397 | def __init__(self, input_size, hidden_size, bidirectional=False): 398 | super(GRU, self).__init__() 399 | self.input_size = input_size 400 | self.hidden_size = hidden_size 401 | self.bidirectional = bidirectional 402 | self.i2h = Linear(input_size, 3 * hidden_size) 403 | self.h2h = Linear(hidden_size, 3 * hidden_size) 404 | if bidirectional: 405 | self.back_i2h = Linear(input_size, 3 * hidden_size) 406 | self.back_h2h = Linear(hidden_size, 3 * hidden_size) 407 | 408 | def _step(self, h, x_t, i2h, h2h): 409 | i_out = i2h(x_t) 410 | h_out = h2h(h) 411 | preact = add(i_out[:, :2*self.hidden_size], h_out[:, :2*self.hidden_size]) 412 | gates = activation(torch.sigmoid, preact) 413 | r_t = gates[:, :self.hidden_size] 414 | z_t = gates[:, self.hidden_size:] 415 | i_state = i_out[:, 2*self.hidden_size:] 416 | h_state = h_out[:, 2*self.hidden_size:] 417 | n_t = activation(torch.tanh, i_state + mul(r_t, h_state)) 418 | if isinstance(z_t, torch.Tensor): 419 | ones = torch.ones_like(z_t) 420 | else: 421 | ones = torch.ones_like(z_t.val) 422 | h_t = add(mul(add(ones, - z_t), n_t), mul(z_t, h)) 423 | return h_t 424 | 425 | def _process(self, h, x, i2h, h2h, reverse=False, mask=None): 426 | B, T, d = x.shape # batch_first=True 427 | idxs = range(T) 428 | if reverse: 429 | idxs = idxs[::-1] 430 | h_seq = [] 431 | for i in idxs: 432 | x_t = x[:,i,:] # B, d_in 433 | h_t = self._step(h, x_t, i2h, h2h) 434 | if mask is not None: 435 | # Don't update h when mask is 0 436 | mask_t = 
mask[:,i].unsqueeze(1) # B,1 437 | h = h_t * mask_t + h * (1.0 - mask_t) 438 | h_seq.append(h) 439 | if reverse: 440 | h_seq = h_seq[::-1] 441 | return h_seq 442 | 443 | def forward(self, x, h0, mask=None): 444 | """Forward pass of GRU 445 | 446 | Args: 447 | x: word vectors, size (B, T, d) 448 | h0: tuple of (h0, x0) where each is (B, d), or (B, 2d) if bidirectional=True 449 | mask: If provided, 0-1 mask of size (B, T) 450 | """ 451 | if self.bidirectional: 452 | h0_back = h0[:,self.hidden_size:] 453 | h0 = h0[:,:self.hidden_size] 454 | h_seq = self._process(h0, x, self.i2h, self.h2h, mask=mask) 455 | if self.bidirectional: 456 | h_back_seq = self._process( 457 | h0_back, x, self.back_i2h, self.back_h2h, reverse=True, mask=mask) 458 | h_seq = [cat((hf, hb), dim=1) for hf, hb in zip(h_seq, h_back_seq)] 459 | h_mat = stack(h_seq, dim=1) # list of (B, d) -> (B, T, d) 460 | return h_mat 461 | 462 | 463 | class Dropout(nn.Dropout): 464 | def forward(self, x): 465 | if isinstance(x, torch.Tensor): 466 | return super(Dropout, self).forward(x) 467 | elif isinstance(x, IntervalBoundedTensor): 468 | if self.training: 469 | probs = torch.full_like(x.val, 1.0 - self.p) 470 | mask = torch.distributions.Bernoulli(probs).sample() / (1.0 - self.p) 471 | return IntervalBoundedTensor(mask * x.val, mask * x.lb, mask * x.ub) 472 | else: 473 | return x 474 | else: 475 | raise TypeError(x) 476 | 477 | 478 | def add(x1, x2): 479 | """Sum two tensors.""" 480 | # I think we have to do it this way and not as operator overloading, 481 | # to catch the case of torch.Tensor.__add__(IntervalBoundedTensor) 482 | if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor): 483 | return x1 + x2 484 | elif isinstance(x1, torch.Tensor) or isinstance(x2, torch.Tensor): 485 | if isinstance(x2, torch.Tensor): 486 | x1, x2 = x2, x1 # WLOG x1 is torch.Tensor 487 | if isinstance(x2, IntervalBoundedTensor): 488 | return IntervalBoundedTensor(x2.val + x1, x2.lb + x1, x2.ub + x1) 489 | else: 490 | raise TypeError(x1, x2) 491 | else: 492 | if isinstance(x1, IntervalBoundedTensor) and isinstance(x2, IntervalBoundedTensor): 493 | return IntervalBoundedTensor(x1.val + x2.val, x1.lb + x2.lb, x1.ub + x2.ub) 494 | else: 495 | raise TypeError(x1, x2) 496 | 497 | 498 | def mul(x1, x2): 499 | """Elementwise multiplication of two tensors.""" 500 | if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor): 501 | return torch.mul(x1, x2) 502 | elif isinstance(x1, torch.Tensor) or isinstance(x2, torch.Tensor): 503 | if isinstance(x2, torch.Tensor): 504 | x1, x2 = x2, x1 # WLOG x1 is torch.Tensor 505 | if isinstance(x2, IntervalBoundedTensor): 506 | z = torch.mul(x2.val, x1) 507 | lb_mul = torch.mul(x2.lb, x1) 508 | ub_mul = torch.mul(x2.ub, x1) 509 | lb_new = torch.min(lb_mul, ub_mul) 510 | ub_new = torch.max(lb_mul, ub_mul) 511 | return IntervalBoundedTensor(z, lb_new, ub_new) 512 | else: 513 | raise TypeError(x1, x2) 514 | else: 515 | if isinstance(x1, IntervalBoundedTensor) and isinstance(x2, IntervalBoundedTensor): 516 | z = torch.mul(x1.val, x2.val) 517 | ll = torch.mul(x1.lb, x2.lb) 518 | lu = torch.mul(x1.lb, x2.ub) 519 | ul = torch.mul(x1.ub, x2.lb) 520 | uu = torch.mul(x1.ub, x2.ub) 521 | stack = torch.stack((ll, lu, ul, uu)) 522 | lb_new = torch.min(stack, dim=0)[0] 523 | ub_new = torch.max(stack, dim=0)[0] 524 | return IntervalBoundedTensor(z, lb_new, ub_new) 525 | else: 526 | raise TypeError(x1, x2) 527 | 528 | def div(x1, x2): 529 | """Elementwise division of two tensors.""" 530 | if isinstance(x1, torch.Tensor) and 
isinstance(x2, torch.Tensor): 531 | return torch.div(x1, x2) 532 | if isinstance(x1, IntervalBoundedTensor) and isinstance(x2, torch.Tensor): 533 | z = torch.div(x1.val, x2) 534 | lb_div = torch.div(x1.lb, x2) 535 | ub_div = torch.div(x1.ub, x2) 536 | lb_new = torch.min(lb_div, ub_div) 537 | ub_new = torch.max(lb_div, ub_div) 538 | return IntervalBoundedTensor(z, lb_new, ub_new) 539 | else: 540 | raise TypeError(x1, x2) 541 | 542 | 543 | def bmm(x1, x2): 544 | """Batched matrix multiply. 545 | 546 | Args: 547 | x1: tensor of shape (B, m, p) 548 | x2: tensor of shape (B, p, n) 549 | Returns: 550 | tensor of shape (B, m, n) 551 | """ 552 | if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor): 553 | return torch.matmul(x1, x2) 554 | elif isinstance(x1, torch.Tensor) or isinstance(x2, torch.Tensor): 555 | swap = False 556 | if isinstance(x2, torch.Tensor): 557 | swap = True 558 | x1, x2 = x2.permute(0, 2, 1), x1.permute(0, 2, 1) # WLOG x1 is torch.Tensor 559 | if isinstance(x2, IntervalBoundedTensor): 560 | z = torch.matmul(x1, x2.val) 561 | x1_abs = torch.abs(x1) 562 | mu_cur = (x2.ub + x2.lb) / 2 563 | r_cur = (x2.ub - x2.lb) / 2 564 | mu_new = torch.matmul(x1, mu_cur) 565 | r_new = torch.matmul(x1_abs, r_cur) 566 | if swap: 567 | z = z.permute(0, 2, 1) 568 | mu_new = mu_new.permute(0, 2, 1) 569 | r_new = r_new.permute(0, 2, 1) 570 | return IntervalBoundedTensor(z, mu_new - r_new, mu_new + r_new) 571 | else: 572 | raise TypeError(x1, x2) 573 | else: 574 | if isinstance(x1, IntervalBoundedTensor) and isinstance(x2, IntervalBoundedTensor): 575 | z = torch.matmul(x1.val, x2.val) 576 | ll = torch.einsum('ijk,ikl->ijkl', x1.lb, x2.lb) # B, m, p, n 577 | lu = torch.einsum('ijk,ikl->ijkl', x1.lb, x2.ub) # B, m, p, n 578 | ul = torch.einsum('ijk,ikl->ijkl', x1.ub, x2.lb) # B, m, p, n 579 | uu = torch.einsum('ijk,ikl->ijkl', x1.ub, x2.ub) # B, m, p, n 580 | stack = torch.stack([ll, lu, ul, uu]) 581 | mins = torch.min(stack, dim=0)[0] # B, m, p, n 582 | maxs = torch.max(stack, dim=0)[0] # B, m, p, n 583 | lb_new = torch.sum(mins, dim=2) # B, m, n 584 | ub_new = torch.sum(maxs, dim=2) # B, m, n 585 | return IntervalBoundedTensor(z, lb_new, ub_new) 586 | else: 587 | raise TypeError(x1, x2) 588 | 589 | 590 | def matmul_nneg(x1, x2): 591 | """Matrix multiply for non-negative matrices (easier than the general case).""" 592 | if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor): 593 | if (x1 < 0).any(): raise ValueError('x1 has negative entries') 594 | if (x2 < 0).any(): raise ValueError('x2 has negative entries') 595 | return torch.matmul(x1, x2) 596 | elif isinstance(x1, torch.Tensor) or isinstance(x2, torch.Tensor): 597 | swap = False 598 | if isinstance(x2, torch.Tensor): 599 | swap = True 600 | x1, x2 = x2.permute(0, 2, 1), x1.permute(0, 2, 1) # WLOG x1 is torch.Tensor 601 | if isinstance(x2, IntervalBoundedTensor): 602 | if (x1 < 0).any(): raise ValueError('x1 has negative entries') 603 | if (x2.lb < 0).any(): raise ValueError('x2 has negative lower bounds') 604 | z = torch.matmul(x1, x2.val) 605 | lb_new = torch.matmul(x1, x2.lb) 606 | ub_new = torch.matmul(x1, x2.ub) 607 | if swap: 608 | lb_new = lb_new.permute(0, 2, 1) 609 | ub_new = ub_new.permute(0, 2, 1) 610 | return IntervalBoundedTensor(z, lb_new, ub_new) 611 | else: 612 | raise TypeError(x1, x2) 613 | else: 614 | if isinstance(x1, IntervalBoundedTensor) and isinstance(x2, IntervalBoundedTensor): 615 | if (x1.lb < 0).any(): raise ValueError('x1 has negative lower bounds') 616 | if (x2.lb < 0).any(): raise ValueError('x2 
has negative lower bounds') 617 | z = torch.matmul(x1.val, x2.val) 618 | lb_new = torch.matmul(x1.lb, x2.lb) 619 | ub_new = torch.matmul(x1.ub, x2.ub) 620 | return IntervalBoundedTensor(z, lb_new, ub_new) 621 | else: 622 | raise TypeError(x1, x2) 623 | 624 | 625 | def cat(tensors, dim=0): 626 | if all(isinstance(x, torch.Tensor) for x in tensors): 627 | return torch.cat(tensors, dim=dim) 628 | tensors_ibp = [] 629 | for x in tensors: 630 | if isinstance(x, IntervalBoundedTensor): 631 | tensors_ibp.append(x) 632 | elif isinstance(x, torch.Tensor): 633 | tensors_ibp.append(IntervalBoundedTensor(x, x, x)) 634 | else: 635 | raise TypeError(x) 636 | return IntervalBoundedTensor(torch.cat([x.val for x in tensors_ibp], dim=dim), 637 | torch.cat([x.lb for x in tensors_ibp], dim=dim), 638 | torch.cat([x.ub for x in tensors_ibp], dim=dim)) 639 | 640 | def stack(tensors, dim=0): 641 | if all(isinstance(x, torch.Tensor) for x in tensors): 642 | return torch.stack(tensors, dim=dim) 643 | tensors_ibp = [] 644 | for x in tensors: 645 | if isinstance(x, IntervalBoundedTensor): 646 | tensors_ibp.append(x) 647 | elif isinstance(x, torch.Tensor): 648 | tensors_ibp.append(IntervalBoundedTensor(x, x, x)) 649 | else: 650 | raise TypeError(x) 651 | return IntervalBoundedTensor( 652 | torch.stack([x.val for x in tensors_ibp], dim=dim), 653 | torch.stack([x.lb for x in tensors_ibp], dim=dim), 654 | torch.stack([x.ub for x in tensors_ibp], dim=dim)) 655 | 656 | 657 | def pool(func, x, dim): 658 | """Pooling operations (e.g. mean, min, max). 659 | 660 | For all of these, the pooling passes straight through the bounds. 661 | """ 662 | if func not in (torch.mean, torch.min, torch.max, torch.sum): 663 | raise ValueError(func) 664 | if func in (torch.min, torch.max): 665 | func_copy = func 666 | func = lambda *args: func_copy(*args)[0] # Grab first return value for min/max 667 | if isinstance(x, torch.Tensor): 668 | return func(x, dim) 669 | elif isinstance(x, IntervalBoundedTensor): 670 | return IntervalBoundedTensor(func(x.val, dim), func(x.lb, dim), 671 | func(x.ub, dim)) 672 | else: 673 | raise TypeError(x) 674 | 675 | 676 | def sum(x, *args, **kwargs): 677 | if isinstance(x, torch.Tensor): 678 | return torch.sum(x, *args) 679 | elif isinstance(x, IntervalBoundedTensor): 680 | return IntervalBoundedTensor( 681 | torch.sum(x.val, *args, **kwargs), 682 | torch.sum(x.lb, *args, **kwargs), 683 | torch.sum(x.ub, *args, **kwargs)) 684 | else: 685 | raise TypeError(x) 686 | 687 | 688 | class Activation(nn.Module): 689 | def __init__(self, func): 690 | super(Activation, self).__init__() 691 | self.func = func 692 | 693 | def forward(self, x): 694 | return activation(self.func, x) 695 | 696 | 697 | def activation(func, x): 698 | """Monotonic elementwise activation functions (e.g. ReLU, sigmoid). 699 | 700 | Due to monotonicity, it suffices to evaluate the activation at the endpoints. 701 | """ 702 | if func not in (F.relu, torch.sigmoid, torch.tanh, torch.exp): 703 | raise ValueError(func) 704 | if isinstance(x, torch.Tensor): 705 | return func(x) 706 | elif isinstance(x, IntervalBoundedTensor): 707 | return IntervalBoundedTensor(func(x.val), func(x.lb), func(x.ub)) 708 | else: 709 | raise TypeError(x) 710 | 711 | 712 | class LogSoftmax(nn.Module): 713 | def __init__(self, dim): 714 | super(LogSoftmax, self).__init__() 715 | self.dim = dim 716 | 717 | def forward(self, x): 718 | return log_softmax(x, self.dim) 719 | 720 | 721 | def log_softmax(x, dim): 722 | """logsoftmax operation, requires |dim| to be provided. 
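A short usage sketch (illustrative, not part of the repository) of the elementwise and pooling helpers (`activation`, `sum`, `pool`, `cat`) defined above; plain `torch.Tensor` inputs are also accepted by all of them, and it assumes running from `src/`:
```
import torch
import torch.nn.functional as F
import ibp  # assumes the working directory is src/

x = torch.tensor([[-1.0, 0.5], [2.0, -0.5]])
xb = ibp.IntervalBoundedTensor(x, x - 0.25, x + 0.25)

relu_b = ibp.activation(F.relu, xb)  # monotone, so the bounds map through directly
summed = ibp.sum(xb, dim=1)          # sums value, lower bound, and upper bound
pooled = ibp.pool(torch.max, xb, 1)  # pooling passes the bounds straight through
both = ibp.cat([xb, relu_b], dim=1)  # concatenation of two bounded tensors
print(pooled.val, pooled.lb, pooled.ub)
```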
723 | 724 | Have to do some weird gymnastics to get vectorization and stability. 725 | """ 726 | if isinstance(x, torch.Tensor): 727 | return F.log_softmax(x, dim=dim) 728 | elif isinstance(x, IntervalBoundedTensor): 729 | out = F.log_softmax(x.val, dim) 730 | # Upper-bound on z_i is u_i - log(sum_j(exp(l_j)) + (exp(u_i) - exp(l_i))) 731 | ub_lb_logsumexp = torch.logsumexp(x.lb, dim, keepdim=True) 732 | ub_relu = F.relu(x.ub - x.lb) # ReLU just to prevent cases where lb > ub due to rounding 733 | # Compute log(exp(u_i) - exp(l_i)) = u_i + log(1 - exp(l_i - u_i)) in 2 different ways 734 | # See https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf for further discussion 735 | # (1) When u_i - l_i <= log(2), use expm1 736 | ub_log_diff_expm1 = torch.log(-torch.expm1(-ub_relu)) 737 | # (2) When u_i - l_i > log(2), use log1p 738 | use_log1p = (ub_relu > 0.693) 739 | ub_relu_log1p = torch.masked_select(ub_relu, use_log1p) 740 | ub_log_diff_log1p = torch.log1p(-torch.exp(-ub_relu_log1p)) 741 | # NOTE: doing the log1p and then masked_select creates NaN's 742 | # I think this is likely to be a subtle pytorch bug that unnecessarily 743 | # propagates NaN gradients. 744 | ub_log_diff_expm1.masked_scatter_(use_log1p, ub_log_diff_log1p) 745 | ub_log_diff = x.ub + ub_log_diff_expm1 746 | 747 | ub_scale = torch.max(ub_lb_logsumexp, ub_log_diff) 748 | ub_log_partition = ub_scale + torch.log( 749 | torch.exp(ub_lb_logsumexp - ub_scale) 750 | + torch.exp(ub_log_diff - ub_scale)) 751 | ub_out = x.ub - ub_log_partition 752 | 753 | # Lower-bound on z_i is l_i - log(sum_{j != i}(exp(u_j)) + exp(l_i)) 754 | # Normalizing scores by max_j u_j works except when i = argmax_j u_j, u_i >> argmax_{j != i} u_j, and u_i >> l_i. 755 | # In this case we normalize by the second value 756 | lb_ub_max, lb_ub_argmax = torch.max(x.ub, dim, keepdim=True) 757 | 758 | # Make `dim` the last dim for easy argmaxing along it later 759 | dims = np.append(np.delete(np.arange(len(x.shape)), dim), dim).tolist() 760 | # Get indices to place `dim` back where it was originally 761 | rev_dims = np.insert(np.arange(len(x.shape) - 1), dim, len(x.shape) - 1).tolist() 762 | # Flatten x.ub except for `dim` 763 | ub_max_masked = x.ub.clone().permute(dims).contiguous().view(-1, x.shape[dim]) 764 | # Get argmax along `dim` and set max indices to -inf 765 | ub_max_masked[np.arange(np.prod(x.shape) / x.shape[dim]), ub_max_masked.argmax(1)] = -float('inf') 766 | # Reshape to make it look like x.ub again 767 | ub_max_masked = ub_max_masked.view(np.array(x.shape).take(dims).tolist()).permute(rev_dims) 768 | 769 | lb_logsumexp_without_argmax = ub_max_masked.logsumexp(dim, keepdim=True) 770 | 771 | lb_ub_exp = torch.exp(x.ub - lb_ub_max) 772 | lb_cumsum_fwd = torch.cumsum(lb_ub_exp, dim) 773 | lb_cumsum_bwd = torch.flip(torch.cumsum(torch.flip(lb_ub_exp, [dim]), dim), [dim]) 774 | # Shift the cumulative sums so that i-th element is sum of things before i (after i for bwd) 775 | pad_fwd = [0] * (2 * len(x.shape)) 776 | pad_fwd[-2*dim - 2] = 1 777 | pad_bwd = [0] * (2 * len(x.shape)) 778 | pad_bwd[-2*dim - 1] = 1 779 | lb_cumsum_fwd = torch.narrow(F.pad(lb_cumsum_fwd, pad_fwd), dim, 0, x.shape[dim]) 780 | lb_cumsum_bwd = torch.narrow(F.pad(lb_cumsum_bwd, pad_bwd), dim, 1, x.shape[dim]) 781 | lb_logsumexp_without_i = lb_ub_max + torch.log(lb_cumsum_fwd + lb_cumsum_bwd) # logsumexp over everything except i 782 | lb_logsumexp_without_i.scatter_(dim, lb_ub_argmax, lb_logsumexp_without_argmax) 783 | lb_scale = torch.max(lb_logsumexp_without_i, 
x.lb) 784 | lb_log_partition = lb_scale + torch.log( 785 | torch.exp(lb_logsumexp_without_i - lb_scale) 786 | + torch.exp(x.lb - lb_scale)) 787 | lb_out = x.lb - lb_log_partition 788 | return IntervalBoundedTensor(out, lb_out, ub_out) 789 | 790 | else: 791 | raise TypeError(x) 792 | -------------------------------------------------------------------------------- /src/precompute_lm_scores.py: -------------------------------------------------------------------------------- 1 | """Precompute language model scores on dev data.""" 2 | import argparse 3 | import json 4 | import os 5 | import sys 6 | import torch 7 | from tqdm import tqdm 8 | 9 | import data_util 10 | import entailment 11 | import text_classification 12 | from train import TASK_CLASSES 13 | import vocabulary 14 | 15 | sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 16 | 'windweller-lw2/adaptive_softmax')) 17 | import query as lmquery 18 | 19 | OPTS = None 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser('Precompute language model scores.') 23 | parser.add_argument('task', choices=TASK_CLASSES.keys()) 24 | parser.add_argument('split', choices=['train', 'dev', 'test']) 25 | parser.add_argument('out_file') 26 | parser.add_argument('--num-examples', '-n', type=int) 27 | parser.add_argument('--shard', '-s', type=int, help='Shard index', default=None) 28 | parser.add_argument('--shard-size', type=int, default=5000) 29 | parser.add_argument('--window-radius', '-w', type=int, default=6) 30 | parser.add_argument('--neighbor-file', type=str, default=data_util.NEIGHBOR_FILE) 31 | parser.add_argument('--imdb-dir', type=str, default=text_classification.IMDB_DIR) 32 | parser.add_argument('--snli-dir', type=str, default=entailment.SNLI_DIR) 33 | if len(sys.argv) == 1: 34 | parser.print_help() 35 | sys.exit(1) 36 | return parser.parse_args() 37 | 38 | def main(): 39 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 40 | query_handler = lmquery.load_model(device) 41 | with open(OPTS.neighbor_file) as f: 42 | neighbors = json.load(f) 43 | if OPTS.task == 'classification': 44 | raw_data = text_classification.IMDBDataset.get_raw_data( 45 | OPTS.imdb_dir, test=(OPTS.split == 'test')) 46 | elif OPTS.task == 'entailment': 47 | OPTS.test = (OPTS.split == 'test') 48 | OPTS.adv_only = False 49 | raw_data = entailment.SNLIDataset.get_raw_data(OPTS) 50 | else: 51 | raise NotImplementedError 52 | if OPTS.split == 'train': 53 | data = raw_data.train_data 54 | else: # dev or test 55 | data = raw_data.dev_data 56 | if OPTS.num_examples: 57 | data = data[:OPTS.num_examples] 58 | if OPTS.shard is not None: 59 | print('Restricting to shard %d' % OPTS.shard) 60 | data = data[OPTS.shard * OPTS.shard_size:(OPTS.shard + 1) * OPTS.shard_size] 61 | with open(OPTS.out_file, 'w') as f: 62 | for sent_idx, example in enumerate(tqdm(data)): 63 | if OPTS.task == 'classification': 64 | sentence = example[0] 65 | elif OPTS.task == 'entailment': 66 | sentence = example[0][1] # Only look at hypothesis 67 | print('%d\t%s' % (sent_idx, sentence), file=f) 68 | words = sentence.split(' ') 69 | for i, w in enumerate(words): 70 | if w in neighbors: 71 | options = [w] + neighbors[w] 72 | start = max(0, i - OPTS.window_radius) 73 | end = min(len(words), i + 1 + OPTS.window_radius) 74 | # Remove OOV words from prefix and suffix 75 | prefix = [x for x in words[start:i] if x in query_handler.word_to_idx] 76 | suffix = [x for x in words[i+1:end] if x in query_handler.word_to_idx] 77 | queries = 
[] 78 | in_vocab_options = [] 79 | for opt in options: 80 | if opt in query_handler.word_to_idx: 81 | queries.append(prefix + [opt] + suffix) 82 | in_vocab_options.append(opt) 83 | else: 84 | print('%d\t%d\t%s\t%s' % (sent_idx, i, opt, float('-inf')), file=f) 85 | if queries: 86 | log_probs = query_handler.query(queries, batch_size=16) 87 | for x, lp in zip(in_vocab_options, log_probs): 88 | print('%d\t%d\t%s\t%s' % (sent_idx, i, x, lp), file=f) 89 | f.flush() 90 | 91 | if __name__ == '__main__': 92 | OPTS = parse_args() 93 | main() 94 | 95 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | """Some tests.""" 2 | import argparse 3 | import sys 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import ibp 8 | 9 | def test_logsoftmax(): 10 | x = torch.tensor([[0.0, 1.0, 2.0], [0.0, 2.0, 4.0]]) 11 | print('x: ', x) 12 | print('F.logsoftmax(x, dim=0): ', F.log_softmax(x, dim=0)) 13 | x_ibp = ibp.IntervalBoundedTensor(x, x, x) 14 | ls = ibp.log_softmax(x_ibp, dim=0) 15 | print('ibp.log_softmax(x, lb=x, ub=x): ', ls.val, ls.lb, ls.ub) 16 | 17 | lb = x - torch.tensor(0.1) 18 | ub = x + torch.tensor(0.1) 19 | print('lb: ', lb) 20 | print('ub: ', lb) 21 | x_ibp = ibp.IntervalBoundedTensor(x, lb, ub) 22 | ls = ibp.log_softmax(x_ibp, dim=0) 23 | print('ibp.log_softmax(x, lb, ub): ', ls.val, ls.lb, ls.ub) 24 | 25 | def test_bmm(): 26 | m1 = torch.tensor([[-1, 2], [3, 2], [-3, 1]], dtype=torch.float).view(1, 3, 2) 27 | m2 = torch.tensor([[4, 5], [-4, -5]], dtype=torch.float).view(1, 2, 2) 28 | z = ibp.bmm(ibp.IntervalBoundedTensor(m1, m1, m1), ibp.IntervalBoundedTensor(m2, m2, m2)) 29 | m1_bounded = ibp.IntervalBoundedTensor(m1, m1 - torch.tensor(0.1), m1 + torch.tensor(0.1)) 30 | m2_bounded = ibp.IntervalBoundedTensor(m2, m2 - torch.tensor(0.1), m2 + torch.tensor(0.1)) 31 | print('ibp.bmm, exact:', z.val, z.lb, z.ub) 32 | z2 = ibp.bmm(m1_bounded, m2_bounded) 33 | print('ibp.bmm, bound both:', z2.val, z2.lb, z2.ub) 34 | z3 = ibp.bmm(m1_bounded, m2) 35 | print('ibp.bmm, bound first:', z3.val, z3.lb, z3.ub) 36 | z4 = ibp.bmm(m1, m2_bounded) 37 | print('ibp.bmm, bound second:', z4.val, z4.lb, z4.ub) 38 | 39 | 40 | 41 | def main(): 42 | test_bmm() 43 | test_logsoftmax() 44 | 45 | if __name__ == '__main__': 46 | main() 47 | 48 | -------------------------------------------------------------------------------- /src/text_classification.py: -------------------------------------------------------------------------------- 1 | """IBP text classification model.""" 2 | import itertools 3 | import glob 4 | import json 5 | import numpy as np 6 | import os 7 | import pickle 8 | import random 9 | 10 | from nltk import word_tokenize 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | from torch.nn.utils.rnn import pad_sequence 15 | from tqdm import tqdm 16 | 17 | import attacks 18 | import data_util 19 | import ibp 20 | import vocabulary 21 | 22 | 23 | LOSS_FUNC = nn.BCEWithLogitsLoss() 24 | IMDB_DIR = 'data/aclImdb' 25 | LM_FILE = 'data/lm_scores/imdb_all.txt' 26 | COUNTER_FITTED_FILE = 'data/counter-fitted-vectors.txt' 27 | 28 | 29 | class AdversarialModel(nn.Module): 30 | def __init__(self): 31 | super(AdversarialModel, self).__init__() 32 | 33 | def query(self, x, vocab, device, return_bounds=False, attack_surface=None): 34 | """Query the model on a Dataset. 35 | 36 | Args: 37 | x: a string 38 | vocab: vocabulary 39 | device: torch device. 
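For reference, `precompute_lm_scores.py` writes one header line per sentence (`sent_idx<TAB>sentence`) followed by one line per candidate substitution (`sent_idx<TAB>word_idx<TAB>candidate<TAB>log_prob`, with `-inf` for candidates outside the LM vocabulary). A sketch of a parser for that format (the helper name and example path are illustrative):
```
from collections import defaultdict

def read_lm_scores(path):
    """Parse the tab-separated output of precompute_lm_scores.py.

    Returns {sent_idx: (sentence, {word_idx: {candidate: log_prob}})}.
    """
    sentences = {}
    scores = defaultdict(lambda: defaultdict(dict))
    with open(path) as f:
        for line in f:
            parts = line.rstrip('\n').split('\t')
            if len(parts) == 2:    # sentence header line
                sent_idx, sentence = parts
                sentences[int(sent_idx)] = sentence
            elif len(parts) == 4:  # per-candidate score line
                sent_idx, word_idx, cand, lp = parts
                scores[int(sent_idx)][int(word_idx)][cand] = float(lp)
    return {i: (s, dict(scores[i])) for i, s in sentences.items()}

# e.g. read_lm_scores('data/lm_scores/imdb_all.txt')
```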
40 | 41 | Returns: list of logits of same length as |examples|. 42 | """ 43 | dataset = TextClassificationDataset.from_raw_data( 44 | [(x, 0)], vocab, attack_surface=attack_surface) 45 | data = dataset.get_loader(1) 46 | with torch.no_grad(): 47 | batch = data_util.dict_batch_to_device(next(iter(data)), device) 48 | logits = self.forward(batch, compute_bounds=return_bounds) 49 | if return_bounds: 50 | return logits.val[0].item(), (logits.lb[0].item(), logits.ub[0].item()) 51 | else: 52 | return logits[0].item() 53 | 54 | 55 | def attention_pool(x, mask, layer): 56 | """Attention pooling 57 | 58 | Args: 59 | x: batch of inputs, shape (B, n, h) 60 | mask: binary mask, shape (B, n) 61 | layer: Linear layer mapping h -> 1 62 | Returns: 63 | pooled version of x, shape (B, h) 64 | """ 65 | attn_raw = layer(x).squeeze(2) # B, n, 1 -> B, n 66 | attn_raw = ibp.add(attn_raw, (1 - mask) * -1e20 ) 67 | attn_logsoftmax = ibp.log_softmax(attn_raw, 1) 68 | attn_probs = ibp.activation(torch.exp, attn_logsoftmax) # B, n 69 | return ibp.bmm(attn_probs.unsqueeze(1), x).squeeze(1) # B, 1, n x B, n, h -> B, h 70 | 71 | 72 | class BOWModel(AdversarialModel): 73 | """Bag of word vectors + MLP.""" 74 | def __init__(self, word_vec_size, hidden_size, word_mat, 75 | pool='max', dropout=0.2, no_wordvec_layer=False): 76 | super(BOWModel, self).__init__() 77 | self.pool = pool 78 | self.no_wordvec_layer = no_wordvec_layer 79 | self.embs = ibp.Embedding.from_pretrained(word_mat) 80 | if no_wordvec_layer: 81 | self.linear_hidden = ibp.Linear(word_vec_size, hidden_size) 82 | else: 83 | self.linear_input = ibp.Linear(word_vec_size, hidden_size) 84 | self.linear_hidden = ibp.Linear(hidden_size, hidden_size) 85 | self.linear_output = ibp.Linear(hidden_size, 1) 86 | self.dropout = ibp.Dropout(dropout) 87 | if self.pool == 'attn': 88 | self.attn_pool = ibp.Linear(hidden_size, 1) 89 | 90 | def forward(self, batch, compute_bounds=True, cert_eps=1.0): 91 | """Forward pass of BOWModel. 92 | 93 | Args: 94 | batch: A batch dict from a TextClassificationDataset with the following keys: 95 | - x: tensor of word vector indices, size (B, n, 1) 96 | - mask: binary mask over words (1 for real, 0 for pad), size (B, n) 97 | - lengths: lengths of sequences, size (B,) 98 | compute_bounds: If True compute the interval bounds and reutrn an IntervalBoundedTensor as logits. 
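`attention_pool` above is ordinary additive attention: score each position, push padded positions to a large negative value, softmax, then take the attention-weighted sum. A plain-torch sketch of the same computation without the interval-bound machinery (toy shapes, not part of the repository):
```
import torch

torch.manual_seed(0)
B, n, h = 2, 4, 3
x = torch.randn(B, n, h)
mask = torch.tensor([[1., 1., 1., 0.],   # last position of example 0 is padding
                     [1., 1., 0., 0.]])
layer = torch.nn.Linear(h, 1)

scores = layer(x).squeeze(2) + (1 - mask) * -1e20     # (B, n); padding scores ~ -inf
probs = torch.softmax(scores, dim=1)                  # padding gets ~0 weight
pooled = torch.bmm(probs.unsqueeze(1), x).squeeze(1)  # (B, h)
print(pooled.shape, probs)
```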
Otherwise just use the values 99 | cert_eps: Scaling factor for interval bounds of the input 100 | """ 101 | if compute_bounds: 102 | x = batch['x'] 103 | else: 104 | x= batch['x'].val 105 | mask = batch['mask'] 106 | lengths = batch['lengths'] 107 | 108 | x_vecs = self.embs(x) # B, n, d 109 | if not self.no_wordvec_layer: 110 | x_vecs = self.linear_input(x_vecs) # B, n, h 111 | if isinstance(x_vecs, ibp.DiscreteChoiceTensor): 112 | x_vecs = x_vecs.to_interval_bounded(eps=cert_eps) 113 | if self.no_wordvec_layer: 114 | z1 = x_vecs 115 | else: 116 | z1 = ibp.activation(F.relu, x_vecs) 117 | z1_masked = z1 * mask.unsqueeze(-1) # B, n, h 118 | if self.pool == 'mean': 119 | z1_pooled = ibp.sum(z1_masked / lengths.to(dtype=torch.float).view(-1, 1, 1), 1) # B, h 120 | elif self.pool == 'attn': 121 | z1_pooled = attention_pool(z1_masked, mask, self.attn_pool) 122 | else: # max 123 | # zero-masking works b/c ReLU guarantees that everything is >= 0 124 | z1_pooled = ibp.pool(torch.max, z1_masked, 1) # B, h 125 | z1_pooled = self.dropout(z1_pooled) 126 | z2 = ibp.activation(F.relu, self.linear_hidden(z1_pooled)) # B, h 127 | z2 = self.dropout(z2) 128 | output = self.linear_output(z2) # B, 1 129 | return output 130 | 131 | 132 | class CNNModel(AdversarialModel): 133 | """Convolutional neural network. 134 | 135 | Here is the overall architecture: 136 | 1) Rotate word vectors 137 | 2) One convolutional layer 138 | 3) Max/mean pool across all time 139 | 4) Predict with MLP 140 | 141 | """ 142 | def __init__(self, word_vec_size, hidden_size, kernel_size, word_mat, 143 | pool='max', dropout=0.2, no_wordvec_layer=False, 144 | early_ibp=False, relu_wordvec=True, unfreeze_wordvec=False): 145 | super(CNNModel, self).__init__() 146 | cnn_padding = (kernel_size - 1) // 2 # preserves size 147 | self.pool = pool 148 | # Ablations 149 | self.no_wordvec_layer = no_wordvec_layer 150 | self.early_ibp = early_ibp 151 | self.relu_wordvec = relu_wordvec 152 | self.unfreeze_wordvec=False 153 | # End ablations 154 | self.embs = ibp.Embedding.from_pretrained(word_mat, freeze=not self.unfreeze_wordvec) 155 | if no_wordvec_layer: 156 | self.conv1 = ibp.Conv1d(word_vec_size, hidden_size, kernel_size, 157 | padding=cnn_padding) 158 | else: 159 | self.linear_input = ibp.Linear(word_vec_size, hidden_size) 160 | self.conv1 = ibp.Conv1d(hidden_size, hidden_size, kernel_size, 161 | padding=cnn_padding) 162 | if self.pool == 'attn': 163 | self.attn_pool = ibp.Linear(hidden_size, 1) 164 | self.dropout = ibp.Dropout(dropout) 165 | self.fc_hidden = ibp.Linear(hidden_size, hidden_size) 166 | self.fc_output = ibp.Linear(hidden_size, 1) 167 | 168 | def forward(self, batch, compute_bounds=True, cert_eps=1.0): 169 | """ 170 | Args: 171 | batch: A batch dict from a TextClassificationDataset with the following keys: 172 | - x: tensor of word vector indices, size (B, n, 1) 173 | - mask: binary mask over words (1 for real, 0 for pad), size (B, n) 174 | - lengths: lengths of sequences, size (B,) 175 | compute_bounds: If True compute the interval bounds and reutrn an IntervalBoundedTensor as logits. 
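The `(kernel_size - 1) // 2` padding in `CNNModel` keeps the sequence length unchanged for odd kernel sizes, which is what the masking logic relies on. A quick plain-torch check (toy shapes, not part of the repository):
```
import torch
from torch import nn

B, h, n = 2, 8, 11
x = torch.randn(B, h, n)
for kernel_size in (3, 5, 7):
    conv = nn.Conv1d(h, h, kernel_size, padding=(kernel_size - 1) // 2)
    assert conv(x).shape == (B, h, n)  # length n is preserved
print('same-length padding holds for odd kernel sizes')
```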
Otherwise just use the values 176 | cert_eps: Scaling factor for interval bounds of the input 177 | """ 178 | if compute_bounds: 179 | x = batch['x'] 180 | else: 181 | x = batch['x'].val 182 | mask = batch['mask'] 183 | lengths = batch['lengths'] 184 | 185 | x_vecs = self.embs(x) # B, n, d 186 | if self.early_ibp and isinstance(x_vecs, ibp.DiscreteChoiceTensor): 187 | x_vecs = x_vecs.to_interval_bounded(eps=cert_eps) 188 | if not self.no_wordvec_layer: 189 | x_vecs = self.linear_input(x_vecs) # B, n, h 190 | if isinstance(x_vecs, ibp.DiscreteChoiceTensor): 191 | x_vecs = x_vecs.to_interval_bounded(eps=cert_eps) 192 | if self.no_wordvec_layer or not self.relu_wordvec: 193 | z = x_vecs 194 | else: 195 | z = ibp.activation(F.relu, x_vecs) # B, n, h 196 | z_masked = z * mask.unsqueeze(-1) # B, n, h 197 | z_cnn_in = z_masked.permute(0, 2, 1) # B, h, n 198 | c1 = ibp.activation(F.relu, self.conv1(z_cnn_in)) # B, h, n 199 | c1_masked = c1 * mask.unsqueeze(1) # B, h, n 200 | if self.pool == 'mean': 201 | fc_in = ibp.sum(c1_masked / lengths.to(dtype=torch.float).view(-1, 1, 1), 2) # B, h 202 | elif self.pool == 'attn': 203 | fc_in = attention_pool(c1_masked.permute(0, 2, 1), mask, self.attn_pool) # B, h 204 | else: 205 | # zero-masking works b/c ReLU guarantees that everything is >= 0 206 | fc_in = ibp.pool(torch.max, c1_masked, 2) # B, h 207 | fc_in = self.dropout(fc_in) 208 | fc_hidden = ibp.activation(F.relu, self.fc_hidden(fc_in)) # B, h 209 | fc_hidden = self.dropout(fc_hidden) 210 | output = self.fc_output(fc_hidden) # B, 1 211 | return output 212 | 213 | 214 | class LSTMModel(AdversarialModel): 215 | """LSTM text classification model. 216 | 217 | Here is the overall architecture: 218 | 1) Rotate word vectors 219 | 2) Feed to bi-LSTM 220 | 3) Max/mean pool across all time 221 | 4) Predict with MLP 222 | 223 | """ 224 | def __init__(self, word_vec_size, hidden_size, word_mat, device, pool='max', dropout=0.2, 225 | no_wordvec_layer=False): 226 | super(LSTMModel, self).__init__() 227 | self.hidden_size = hidden_size 228 | self.pool = pool 229 | self.no_wordvec_layer = no_wordvec_layer 230 | self.device = device 231 | self.embs = ibp.Embedding.from_pretrained(word_mat) 232 | if no_wordvec_layer: 233 | self.lstm = ibp.LSTM(word_vec_size, hidden_size, bidirectional=True) 234 | else: 235 | self.linear_input = ibp.Linear(word_vec_size, hidden_size) 236 | self.lstm = ibp.LSTM(hidden_size, hidden_size, bidirectional=True) 237 | self.dropout = ibp.Dropout(dropout) 238 | self.fc_hidden = ibp.Linear(2 * hidden_size, hidden_size) 239 | self.fc_output = ibp.Linear(hidden_size, 1) 240 | 241 | def forward(self, batch, compute_bounds=True, cert_eps=1.0, analysis_mode=False): 242 | """ 243 | Args: 244 | batch: A batch dict from a TextClassificationDataset with the following keys: 245 | - x: tensor of word vector indices, size (B, n, 1) 246 | - mask: binary mask over words (1 for real, 0 for pad), size (B, n) 247 | - lengths: lengths of sequences, size (B,) 248 | compute_bounds: If True compute the interval bounds and reutrn an IntervalBoundedTensor as logits. 
Otherwise just use the values 249 | cert_eps: Scaling factor for interval bounds of the input 250 | """ 251 | if compute_bounds: 252 | x = batch['x'] 253 | else: 254 | x = batch['x'].val 255 | mask = batch['mask'] 256 | lengths = batch['lengths'] 257 | 258 | B = x.shape[0] 259 | x_vecs = self.embs(x) # B, n, d 260 | if not self.no_wordvec_layer: 261 | x_vecs = self.linear_input(x_vecs) # B, n, h 262 | if isinstance(x_vecs, ibp.DiscreteChoiceTensor): 263 | x_vecs = x_vecs.to_interval_bounded(eps=cert_eps) 264 | if self.no_wordvec_layer: 265 | z = x_vecs 266 | else: 267 | z = ibp.activation(F.relu, x_vecs) # B, n, h 268 | h0 = torch.zeros((B, 2 * self.hidden_size), device=self.device) # B, 2*h 269 | c0 = torch.zeros((B, 2 * self.hidden_size), device=self.device) # B, 2*h 270 | if analysis_mode: 271 | h_mat, c_mat, lstm_analysis = self.lstm(z, (h0, c0), mask=mask, analysis_mode=True) # B, n, 2*h each 272 | else: 273 | h_mat, c_mat = self.lstm(z, (h0, c0), mask=mask) # B, n, 2*h each 274 | h_masked = h_mat * mask.unsqueeze(2) 275 | if self.pool == 'mean': 276 | fc_in = ibp.sum(h_masked / lengths.to(dtype=torch.float).view(-1, 1, 1), 1) # B, 2*h 277 | else: 278 | raise NotImplementedError() 279 | fc_in = self.dropout(fc_in) 280 | fc_hidden = ibp.activation(F.relu, self.fc_hidden(fc_in)) # B, h 281 | fc_hidden = self.dropout(fc_hidden) 282 | output = self.fc_output(fc_hidden) # B, 1 283 | if analysis_mode: 284 | return output, h_mat, c_mat, lstm_analysis 285 | return output 286 | 287 | 288 | class LSTMFinalStateModel(AdversarialModel): 289 | """LSTM text classification model that uses final hidden state.""" 290 | def __init__(self, word_vec_size, hidden_size, word_mat, device, dropout=0.2, 291 | no_wordvec_layer=False): 292 | super(LSTMFinalStateModel, self).__init__() 293 | self.hidden_size = hidden_size 294 | self.no_wordvec_layer = no_wordvec_layer 295 | self.device = device 296 | self.embs = ibp.Embedding.from_pretrained(word_mat) 297 | if no_wordvec_layer: 298 | self.lstm = ibp.LSTM(word_vec_size, hidden_size, bidirectional=True) 299 | else: 300 | self.linear_input = ibp.Linear(word_vec_size, hidden_size) 301 | self.lstm = ibp.LSTM(hidden_size, hidden_size) 302 | self.dropout = ibp.Dropout(dropout) 303 | self.fc_hidden = ibp.Linear(hidden_size, hidden_size) 304 | self.fc_output = ibp.Linear(hidden_size, 1) 305 | 306 | def forward(self, batch, compute_bounds=True, cert_eps=1.0, analysis_mode=False): 307 | """ 308 | Args: 309 | batch: A batch dict from a TextClassificationDataset with the following keys: 310 | - x: tensor of word vector indices, size (B, n, 1) 311 | - mask: binary mask over words (1 for real, 0 for pad), size (B, n) 312 | - lengths: lengths of sequences, size (B,) 313 | compute_bounds: If True compute the interval bounds and reutrn an IntervalBoundedTensor as logits. 
Otherwise just use the values 314 | cert_eps: Scaling factor for interval bounds of the input 315 | """ 316 | if compute_bounds: 317 | x = batch['x'] 318 | else: 319 | x = batch['x'].val 320 | mask = batch['mask'] 321 | lengths = batch['lengths'] 322 | 323 | B = x.shape[0] 324 | x_vecs = self.embs(x) # B, n, d 325 | if not self.no_wordvec_layer: 326 | x_vecs = self.linear_input(x_vecs) # B, n, h 327 | if isinstance(x_vecs, ibp.DiscreteChoiceTensor): 328 | x_vecs = x_vecs.to_interval_bounded(eps=cert_eps) 329 | if self.no_wordvec_layer: 330 | z = x_vecs 331 | else: 332 | z = ibp.activation(F.relu, x_vecs) # B, n, h 333 | h0 = torch.zeros((B, self.hidden_size), device=self.device) # B, h 334 | c0 = torch.zeros((B, self.hidden_size), device=self.device) # B, h 335 | if analysis_mode: 336 | h_mat, c_mat, lstm_analysis = self.lstm(z, (h0, c0), mask=mask, analysis_mode=True) # B, n, h each 337 | else: 338 | h_mat, c_mat = self.lstm(z, (h0, c0), mask=mask) # B, n, h each 339 | h_final = h_mat[:,-1,:] # B, h 340 | fc_in = self.dropout(h_final) 341 | fc_hidden = ibp.activation(F.relu, self.fc_hidden(fc_in)) # B, h 342 | fc_hidden = self.dropout(fc_hidden) 343 | output = self.fc_output(fc_hidden) # B, 1 344 | if analysis_mode: 345 | return output, h_mat, c_mat, lstm_analysis 346 | return output 347 | 348 | 349 | class Adversary(object): 350 | """An Adversary tries to fool a model on a given example.""" 351 | def __init__(self, attack_surface): 352 | self.attack_surface = attack_surface 353 | 354 | def run(self, model, dataset, device, opts=None): 355 | """Run adversary on a dataset. 356 | 357 | Args: 358 | model: a TextClassificationModel. 359 | dataset: a TextClassificationDataset. 360 | device: torch device. 361 | Returns: pair of 362 | - list of 0-1 adversarial loss of same length as |dataset| 363 | - list of list of adversarial examples (each is just a text string) 364 | """ 365 | raise NotImplementedError 366 | 367 | 368 | class ExhaustiveAdversary(Adversary): 369 | """An Adversary that exhaustively tries all allowed perturbations. 370 | 371 | Only practical for short sentences. 
372 | """ 373 | def run(self, model, dataset, device, opts=None): 374 | is_correct = [] 375 | adv_exs = [] 376 | for x, y in dataset.raw_data: 377 | words = x.split() 378 | swaps = self.attack_surface.get_swaps(words) 379 | choices = [[w] + cur_swaps for w, cur_swaps in zip(words, swaps)] 380 | prod = 1 381 | for c in choices: 382 | prod *= len(c) 383 | print('ExhaustiveAdversary: "%s" -> %d options' % (x, prod)) 384 | all_raw = [(' '.join(x_new), y) for x_new in itertools.product(*choices)] 385 | cur_dataset = TextClassificationDataset.from_raw_data(all_raw, dataset.vocab) 386 | preds = model.query(cur_dataset, device) 387 | cur_adv_exs = [all_raw[i][0] for i, p in enumerate(preds) 388 | if p * (2 * y - 1) <= 0] 389 | print(cur_adv_exs) 390 | adv_exs.append(cur_adv_exs) 391 | is_correct.append(int(len(cur_adv_exs) > 0)) 392 | return is_correct, adv_exs 393 | 394 | 395 | class GreedyAdversary(Adversary): 396 | """An adversary that picks a random word and greedily tries perturbations.""" 397 | def __init__(self, attack_surface, num_epochs=10, num_tries=2, margin_goal=0.0): 398 | super(GreedyAdversary, self).__init__(attack_surface) 399 | self.num_epochs = num_epochs 400 | self.num_tries = num_tries 401 | self.margin_goal = margin_goal 402 | 403 | def run(self, model, dataset, device, opts=None): 404 | is_correct = [] 405 | adv_exs = [] 406 | for x, y in tqdm(dataset.raw_data): 407 | # First query the example itself 408 | orig_pred, bounds = model.query(TextClassificationDataset.from_raw_data( 409 | [(x, y)], dataset.vocab, self.attack_surface), device, return_bounds=True) 410 | orig_pred, (orig_lb, orig_ub) = orig_pred[0], bounds[0] 411 | cert_correct = (orig_lb * (2 * y - 1) > 0) and (orig_ub * (2 * y - 1) > 0) 412 | print('Logit bounds: %.6f <= %.6f <= %.6f, cert_correct=%s' % ( 413 | orig_lb, orig_pred, orig_ub, cert_correct)) 414 | if orig_pred * (2 * y - 1) <= 0: 415 | print('ORIGINAL PREDICTION WAS WRONG') 416 | is_correct.append(0) 417 | adv_exs.append(x) 418 | continue 419 | # Now run adversarial search 420 | words = x.split() 421 | swaps = self.attack_surface.get_swaps(words) 422 | choices = [[w] + cur_swaps for w, cur_swaps in zip(words, swaps)] 423 | found = False 424 | for try_idx in range(self.num_tries): 425 | cur_words = list(words) 426 | for epoch in range(self.num_epochs): 427 | word_idxs = list(range(len(choices))) 428 | random.shuffle(word_idxs) 429 | for i in word_idxs: 430 | cur_raw = [] 431 | for w_new in choices[i]: 432 | cur_raw.append((' '.join(cur_words[:i] + [w_new] + cur_words[i+1:]), y)) 433 | cur_dataset = TextClassificationDataset.from_raw_data(cur_raw, dataset.vocab) 434 | preds = model.query(cur_dataset, device) 435 | margins = [p * (2 * y - 1) for p in preds] 436 | best_idx = min(enumerate(margins), key=lambda x: x[1])[0] 437 | cur_words[i] = choices[i][best_idx] 438 | if margins[best_idx] < self.margin_goal: 439 | found = True 440 | is_correct.append(0) 441 | adv_exs.append([' '.join(cur_words)]) 442 | print('ADVERSARY SUCCESS on ("%s", %d): Found "%s" with margin %.2f' % (x, y, adv_exs[-1], margins[best_idx])) 443 | if cert_correct: 444 | print('^^ CERT CORRECT THOUGH') 445 | break 446 | if found: break 447 | if found: break 448 | else: 449 | is_correct.append(1) 450 | adv_exs.append([]) 451 | print('ADVERSARY FAILURE on ("%s", %d)' % (x, y)) 452 | return is_correct, adv_exs 453 | 454 | 455 | class GeneticAdversary(Adversary): 456 | """An adversary that runs a genetic attack.""" 457 | def __init__(self, attack_surface, num_iters=20, pop_size=60, 
margin_goal=0.0): 458 | super(GeneticAdversary, self).__init__(attack_surface) 459 | self.num_iters = num_iters 460 | self.pop_size = pop_size 461 | self.margin_goal = margin_goal 462 | 463 | def perturb(self, words, choices, model, y, vocab, device): 464 | if all(len(c) == 1 for c in choices): return words 465 | good_idxs = [i for i, c in enumerate(choices) if len(c) > 1] 466 | idx = random.sample(good_idxs, 1)[0] 467 | x_list = [' '.join(words[:idx] + [w_new] + words[idx+1:]) 468 | for w_new in choices[idx]] 469 | preds = [model.query(x, vocab, device) for x in x_list] 470 | margins = [p * (2 * y - 1) for p in preds] 471 | best_idx = min(enumerate(margins), key=lambda x: x[1])[0] 472 | cur_words = list(words) 473 | cur_words[idx] = choices[idx][best_idx] 474 | return cur_words, margins[best_idx] 475 | 476 | def run(self, model, dataset, device, opts=None): 477 | is_correct = [] 478 | adv_exs = [] 479 | for x, y in tqdm(dataset.raw_data): 480 | # First query the example itself 481 | orig_pred, (orig_lb, orig_ub) = model.query( 482 | x, dataset.vocab, device, return_bounds=True, 483 | attack_surface=self.attack_surface) 484 | cert_correct = (orig_lb * (2 * y - 1) > 0) and (orig_ub * (2 * y - 1) > 0) 485 | print('Logit bounds: %.6f <= %.6f <= %.6f, cert_correct=%s' % ( 486 | orig_lb, orig_pred, orig_ub, cert_correct)) 487 | if orig_pred * (2 * y - 1) <= 0: 488 | print('ORIGINAL PREDICTION WAS WRONG') 489 | is_correct.append(0) 490 | adv_exs.append(x) 491 | continue 492 | # Now run adversarial search 493 | words = x.split() 494 | swaps = self.attack_surface.get_swaps(words) 495 | choices = [[w] + cur_swaps for w, cur_swaps in zip(words, swaps)] 496 | found = False 497 | population = [self.perturb(words, choices, model, y, dataset.vocab, device) 498 | for i in range(self.pop_size)] 499 | for g in range(self.num_iters): 500 | best_idx = min(enumerate(population), key=lambda x: x[1][1])[0] 501 | print('Iteration %d: %.6f' % (g, population[best_idx][1])) 502 | if population[best_idx][1] < self.margin_goal: 503 | found = True 504 | is_correct.append(0) 505 | adv_exs.append(' '.join(population[best_idx][0])) 506 | print('ADVERSARY SUCCESS on ("%s", %d): Found "%s" with margin %.2f' % (x, y, adv_exs[-1], population[best_idx][1])) 507 | if cert_correct: 508 | print('^^ CERT CORRECT THOUGH') 509 | break 510 | new_population = [population[best_idx]] 511 | margins = np.array([m for c, m in population]) 512 | adv_probs = 1 / (1 + np.exp(margins)) + 1e-6 513 | # Sigmoid of negative margin, for probabilty of wrong class 514 | # Add 1e-6 for numerical stability 515 | sample_probs = adv_probs / np.sum(adv_probs) 516 | for i in range(1, self.pop_size): 517 | parent1 = population[np.random.choice(range(len(population)), p=sample_probs)][0] 518 | parent2 = population[np.random.choice(range(len(population)), p=sample_probs)][0] 519 | child = [random.sample([w1, w2], 1)[0] for (w1, w2) in zip(parent1, parent2)] 520 | child_mut, new_margin = self.perturb(child, choices, model, y, 521 | dataset.vocab, device) 522 | new_population.append((child_mut, new_margin)) 523 | population = new_population 524 | else: 525 | is_correct.append(1) 526 | adv_exs.append([]) 527 | print('ADVERSARY FAILURE on ("%s", %d)' % (x, y)) 528 | return is_correct, adv_exs 529 | 530 | 531 | def load_datasets(device, opts): 532 | """ 533 | Loads text classification datasets given opts on the device and returns the dataset. 
534 | If a data cache is specified in opts and the cached data there is of the same class 535 | as the one specified in opts, uses the cache. Otherwise reads from the raw dataset 536 | files specified in opts. 537 | Returns: 538 | - train_data: TextClassificationDataset - Processed training dataset 539 | - dev_data: Optional[TextClassificationDataset] - Processed dev dataset if raw dev data was found or 540 | dev_frac was specified in opts 541 | - word_mat: torch.Tensor 542 | - attack_surface: AttackSurface - defines the adversarial attack surface 543 | """ 544 | data_class = ToyClassificationDataset if opts.use_toy_data else IMDBDataset 545 | try: 546 | with open(os.path.join(opts.data_cache_dir, 'train_data.pkl'), 'rb') as infile: 547 | train_data = pickle.load(infile) 548 | if not isinstance(train_data, data_class): 549 | raise Exception("Cached dataset of wrong class: {}".format(type(train_data))) 550 | with open(os.path.join(opts.data_cache_dir, 'dev_data.pkl'), 'rb') as infile: 551 | dev_data = pickle.load(infile) 552 | if not isinstance(dev_data, data_class): 553 | raise Exception("Cached dataset of wrong class: {}".format(type(dev_data))) 554 | with open(os.path.join(opts.data_cache_dir, 'word_mat.pkl'), 'rb') as infile: 555 | word_mat = pickle.load(infile) 556 | with open(os.path.join(opts.data_cache_dir, 'attack_surface.pkl'), 'rb') as infile: 557 | attack_surface = pickle.load(infile) 558 | print("Loaded data from {}.".format(opts.data_cache_dir)) 559 | except Exception: 560 | if opts.use_toy_data: 561 | attack_surface = ToyClassificationAttackSurface(ToyClassificationDataset.VOCAB_LIST) 562 | elif opts.use_lm: 563 | attack_surface = attacks.LMConstrainedAttackSurface.from_files( 564 | opts.neighbor_file, opts.imdb_lm_file) 565 | else: 566 | attack_surface = attacks.WordSubstitutionAttackSurface.from_file(opts.neighbor_file) 567 | print('Reading dataset.') 568 | raw_data = data_class.get_raw_data(opts.imdb_dir, test=opts.test) 569 | word_set = raw_data.get_word_set(attack_surface) 570 | vocab, word_mat = vocabulary.Vocabulary.read_word_vecs(word_set, opts.glove_dir, opts.glove, device) 571 | train_data = data_class.from_raw_data(raw_data.train_data, vocab, attack_surface, 572 | downsample_to=opts.downsample_to, 573 | downsample_shard=opts.downsample_shard, 574 | truncate_to=opts.truncate_to) 575 | dev_data = data_class.from_raw_data(raw_data.dev_data, vocab, attack_surface, 576 | downsample_to=opts.downsample_to, 577 | downsample_shard=opts.downsample_shard, 578 | truncate_to=opts.truncate_to) 579 | if opts.data_cache_dir: 580 | with open(os.path.join(opts.data_cache_dir, 'train_data.pkl'), 'wb') as outfile: 581 | pickle.dump(train_data, outfile) 582 | with open(os.path.join(opts.data_cache_dir, 'dev_data.pkl'), 'wb') as outfile: 583 | pickle.dump(dev_data, outfile) 584 | with open(os.path.join(opts.data_cache_dir, 'word_mat.pkl'), 'wb') as outfile: 585 | pickle.dump(word_mat, outfile) 586 | with open(os.path.join(opts.data_cache_dir, 'attack_surface.pkl'), 'wb') as outfile: 587 | pickle.dump(attack_surface, outfile) 588 | return train_data, dev_data, word_mat, attack_surface 589 | 590 | 591 | def num_correct(model_output, gold_labels): 592 | """ 593 | Given the output of the model and gold labels, returns the number of correct and certified-correct 594 | predictions. 595 | Args: 596 | - model_output: output of the model, could be ibp.IntervalBoundedTensor or torch.Tensor 597 | - gold_labels: torch.Tensor, should be of size 1 per sample, 1 for positive 0 for negative 598 | Returns: 599 | - num_correct:
int - number of correct predictions from the actual model output 600 | - num_cert_correct - number of bounds-certified correct predictions if the model_output was an 601 | IntervalBoundedTensor, 0 otherwise. 602 | """ 603 | if isinstance(model_output, ibp.IntervalBoundedTensor): 604 | logits = model_output.val 605 | num_cert_correct = sum( 606 | all((b * (2 * y - 1)).item() > 0 for b in (model_output.lb[i], model_output.ub[i])) 607 | for i, y in enumerate(gold_labels) 608 | ) 609 | else: 610 | logits = model_output 611 | num_cert_correct = 0 612 | num_correct = sum( 613 | (logits[i] * (2 * y - 1)).item() > 0 for i, y in enumerate(gold_labels) 614 | ) 615 | return num_correct, num_cert_correct 616 | 617 | 618 | def load_model(word_mat, device, opts): 619 | """ 620 | Try to load a model on the device given the word_mat and opts. 621 | Tries to load a model from the given or latest checkpoint if specified in the opts. 622 | Otherwise instantiates a new model on the device. 623 | """ 624 | if opts.model == 'bow': 625 | model = BOWModel( 626 | vocabulary.GLOVE_CONFIGS[opts.glove]['size'], opts.hidden_size, word_mat, 627 | pool=opts.pool, dropout=opts.dropout_prob, no_wordvec_layer=opts.no_wordvec_layer).to(device) 628 | elif opts.model == 'cnn': 629 | model = CNNModel( 630 | vocabulary.GLOVE_CONFIGS[opts.glove]['size'], opts.hidden_size, opts.kernel_size, 631 | word_mat, pool=opts.pool, dropout=opts.dropout_prob, no_wordvec_layer=opts.no_wordvec_layer, 632 | early_ibp=opts.early_ibp, relu_wordvec=not opts.no_relu_wordvec, unfreeze_wordvec=opts.unfreeze_wordvec).to(device) 633 | elif opts.model == 'lstm': 634 | model = LSTMModel( 635 | vocabulary.GLOVE_CONFIGS[opts.glove]['size'], opts.hidden_size, 636 | word_mat, device, pool=opts.pool, dropout=opts.dropout_prob, no_wordvec_layer=opts.no_wordvec_layer).to(device) 637 | elif opts.model == 'lstm-final-state': 638 | model = LSTMFinalStateModel( 639 | vocabulary.GLOVE_CONFIGS[opts.glove]['size'], opts.hidden_size, 640 | word_mat, device, dropout=opts.dropout_prob, no_wordvec_layer=opts.no_wordvec_layer).to(device) 641 | if opts.load_dir: 642 | try: 643 | if opts.load_ckpt is None: 644 | load_fn = sorted(glob.glob(os.path.join(opts.load_dir, 'model-checkpoint-[0-9]+.pth')))[-1] 645 | else: 646 | load_fn = os.path.join(opts.load_dir, 'model-checkpoint-%d.pth' % opts.load_ckpt) 647 | print('Loading model from %s.' 
% load_fn) 648 | state_dict = dict(torch.load(load_fn)) 649 | state_dict['embs.weight'] = model.embs.weight 650 | model.load_state_dict(state_dict) 651 | print('Finished loading model.') 652 | except Exception as ex: 653 | print("Couldn't load model, starting anew: {}".format(ex)) 654 | return model 655 | 656 | 657 | class RawClassificationDataset(data_util.RawDataset): 658 | """ 659 | Dataset that only holds x,y as (str, str) tuples 660 | """ 661 | def get_word_set(self, attack_surface): 662 | with open(COUNTER_FITTED_FILE) as f: 663 | counter_vocab = set([line.split(' ')[0] for line in f]) 664 | word_set = set() 665 | for x, y in self.data: 666 | words = [w.lower() for w in x.split(' ')] 667 | for w in words: 668 | word_set.add(w) 669 | try: 670 | swaps = attack_surface.get_swaps(words) 671 | for cur_swaps in swaps: 672 | for w in cur_swaps: 673 | word_set.add(w) 674 | except KeyError: 675 | # For now, ignore things not in attack surface 676 | # If we really need them, later code will throw an error 677 | pass 678 | return word_set & counter_vocab 679 | 680 | 681 | class TextClassificationDataset(data_util.ProcessedDataset): 682 | """ 683 | Dataset that holds processed example dicts 684 | """ 685 | @classmethod 686 | def from_raw_data(cls, raw_data, vocab, attack_surface=None, truncate_to=None, 687 | downsample_to=None, downsample_shard=0): 688 | if downsample_to: 689 | raw_data = raw_data[downsample_shard * downsample_to:(downsample_shard+1) * downsample_to] 690 | examples = [] 691 | for x, y in raw_data: 692 | all_words = [w.lower() for w in x.split()] 693 | if attack_surface: 694 | all_swaps = attack_surface.get_swaps(all_words) 695 | words = [w for w in all_words if w in vocab] 696 | swaps = [s for w, s in zip(all_words, all_swaps) if w in vocab] 697 | choices = [[w] + cur_swaps for w, cur_swaps in zip(words, swaps)] 698 | else: 699 | words = [w for w in all_words if w in vocab] # Delete UNK words 700 | if truncate_to: 701 | words = words[:truncate_to] 702 | word_idxs = [vocab.get_index(w) for w in words] 703 | x_torch = torch.tensor(word_idxs).view(1, -1, 1) # (1, T, d) 704 | if attack_surface: 705 | choices_word_idxs = [ 706 | torch.tensor([vocab.get_index(c) for c in c_list], dtype=torch.long) for c_list in choices 707 | ] 708 | if any(0 in c.view(-1).tolist() for c in choices_word_idxs): 709 | raise ValueError("UNK tokens found") 710 | choices_torch = pad_sequence(choices_word_idxs, batch_first=True).unsqueeze(2).unsqueeze(0) # (1, T, C, 1) 711 | choices_mask = (choices_torch.squeeze(-1) != 0).long() # (1, T, C) 712 | else: 713 | choices_torch = x_torch.view(1, -1, 1, 1) # (1, T, 1, 1) 714 | choices_mask = torch.ones_like(x_torch.view(1, -1, 1)) 715 | mask_torch = torch.ones((1, len(word_idxs))) 716 | x_bounded = ibp.DiscreteChoiceTensor(x_torch, choices_torch, choices_mask, mask_torch) 717 | y_torch = torch.tensor(y, dtype=torch.float).view(1, 1) 718 | lengths_torch = torch.tensor(len(word_idxs)).view(1) 719 | examples.append(dict(x=x_bounded, y=y_torch, mask=mask_torch, lengths=lengths_torch)) 720 | return cls(raw_data, vocab, examples) 721 | 722 | @staticmethod 723 | def example_len(example): 724 | return example['x'].shape[1] 725 | 726 | @staticmethod 727 | def collate_examples(examples): 728 | """ 729 | Turns a list of examples into a workable batch: 730 | """ 731 | if len(examples) == 1: 732 | return examples[0] 733 | B = len(examples) 734 | max_len = max(ex['x'].shape[1] for ex in examples) 735 | x_vals = [] 736 | choice_mats = [] 737 | choice_masks = [] 738 | y = 
torch.zeros((B, 1)) 739 | lengths = torch.zeros((B, ), dtype=torch.long) 740 | masks = torch.zeros((B, max_len)) 741 | for i, ex in enumerate(examples): 742 | x_vals.append(ex['x'].val) 743 | choice_mats.append(ex['x'].choice_mat) 744 | choice_masks.append(ex['x'].choice_mask) 745 | cur_len = ex['x'].shape[1] 746 | masks[i, :cur_len] = 1 747 | y[i, 0] = ex['y'] 748 | lengths[i] = ex['lengths'][0] 749 | x_vals = data_util.multi_dim_padded_cat(x_vals, 0).long() 750 | choice_mats = data_util.multi_dim_padded_cat(choice_mats, 0).long() 751 | choice_masks = data_util.multi_dim_padded_cat(choice_masks, 0).long() 752 | return {'x': ibp.DiscreteChoiceTensor(x_vals, choice_mats, choice_masks, masks), 753 | 'y': y, 'mask': masks, 'lengths': lengths} 754 | 755 | 756 | class ToyClassificationDataset(TextClassificationDataset): 757 | """ 758 | Dataset that holds a toy sentiment classification data 759 | """ 760 | VOCAB_LIST = [ 761 | 'cat', 'dog', 'fish', 'tiger', 'chicken', 762 | 'hamster', 'bear', 'lion', 'dragon', 'horse', 763 | 'monkey', 'goat', 'sheep', 'goose', 'duck'] 764 | @classmethod 765 | def get_raw_data(cls, ignore_dir, data_size=5000, max_len=10, *args, **kwargs): 766 | data = [] 767 | for t in range(data_size): 768 | seq_len = random.randint(3, max_len) 769 | words = [random.sample(cls.VOCAB_LIST, 1)[0] for i in range(seq_len - 1)] 770 | if random.random() > 0.5: 771 | words.append(words[0]) 772 | y = 1 773 | else: 774 | other_words = list(cls.VOCAB_LIST) 775 | other_words.remove(words[0]) 776 | words.append(random.sample(other_words, 1)[0]) 777 | y = 0 778 | data.append((' '.join(words), y)) 779 | num_train = int(round(data_size * 0.8)) 780 | train_data = data[:num_train] 781 | dev_data = data[num_train:] 782 | print(dev_data[:10]) 783 | return RawClassificationDataset(train_data, dev_data) 784 | 785 | 786 | class ToyClassificationAttackSurface(attacks.AttackSurface): 787 | """Attack surface for ToyClassificationDataset.""" 788 | def __init__(self, vocab_list): 789 | self.vocab_list = vocab_list 790 | 791 | def get_swaps(self, words): 792 | swaps = [] 793 | s = ' '.join(words) 794 | for i in range(len(words)): 795 | if i == 0 or i == len(words) - 1: 796 | swaps.append([]) 797 | else: 798 | swaps.append(self.vocab_list) 799 | return swaps 800 | 801 | 802 | class IMDBDataset(TextClassificationDataset): 803 | """ 804 | Dataset that holds the IMDB sentiment classification data 805 | """ 806 | @classmethod 807 | def read_text(cls, imdb_dir, split): 808 | if split == 'test': 809 | subdir = 'test' 810 | else: 811 | subdir = 'train' 812 | with open(os.path.join(imdb_dir, subdir, 'imdb_%s_files.txt' % split)) as f: 813 | filenames = [line.strip() for line in f] 814 | data = [] 815 | num_words = 0 816 | for fn in tqdm(filenames): 817 | label = 1 if fn.startswith('pos') else 0 818 | with open(os.path.join(imdb_dir, subdir, fn)) as f: 819 | x_raw = f.readlines()[0].strip().replace('
', ' ') 820 | x_toks = word_tokenize(x_raw) 821 | num_words += len(x_toks) 822 | data.append((' '.join(x_toks), label)) 823 | num_pos = sum(y for x, y in data) 824 | num_neg = sum(1 - y for x, y in data) 825 | avg_words = num_words / len(data) 826 | print('Read %d examples (+%d, -%d), average length %d words' % ( 827 | len(data), num_pos, num_neg, avg_words)) 828 | return data 829 | 830 | @classmethod 831 | def get_raw_data(cls, imdb_dir, test=False): 832 | train_data = cls.read_text(imdb_dir, 'train') 833 | if test: 834 | dev_data = cls.read_text(imdb_dir, 'test') 835 | else: 836 | dev_data = cls.read_text(imdb_dir, 'dev') 837 | return RawClassificationDataset(train_data, dev_data) 838 | 839 | 840 | class DataAugmenter(data_util.DataAugmenter): 841 | def augment(self, dataset): 842 | new_examples = [] 843 | for ex in tqdm(dataset.examples): 844 | new_examples.append(ex) 845 | x_orig = ex['x'] # (1, T, 1) 846 | choices = [] 847 | for i in range(x_orig.shape[1]): 848 | cur_choices = torch.masked_select( 849 | x_orig.choice_mat[0,i,:,0], x_orig.choice_mask[0,i,:].type(torch.uint8)) 850 | choices.append(cur_choices) 851 | for t in range(self.augment_by): 852 | x_new = torch.stack([choices[i][random.choice(range(len(choices[i])))] 853 | for i in range(len(choices))]).view(1, -1, 1) 854 | x_bounded = ibp.DiscreteChoiceTensor( 855 | x_new, x_orig.choice_mat, x_orig.choice_mask, x_orig.sequence_mask) 856 | ex_new = dict(ex) 857 | ex_new['x'] = x_bounded 858 | new_examples.append(ex_new) 859 | return TextClassificationDataset(None, dataset.vocab, new_examples) 860 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import json 4 | import os 5 | import random 6 | import sys 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from tqdm import tqdm 12 | 13 | import data_util 14 | import entailment 15 | import text_classification 16 | import vocabulary 17 | 18 | 19 | # Maps string keys to modules that hold the relevant functions for training against 20 | # their tasks 21 | TASK_CLASSES = { 22 | 'classification': text_classification, 23 | 'entailment': entailment 24 | } 25 | 26 | 27 | def train(task_class, model, train_data, num_epochs, lr, device, dev_data=None, 28 | cert_frac=0.0, initial_cert_frac=0.0, cert_eps=1.0, initial_cert_eps=0.0, non_cert_train_epochs=0, full_train_epochs=0, 29 | batch_size=1, epochs_per_save=1, augmenter=None, clip_grad_norm=0, weight_decay=0, 30 | save_best_only=False): 31 | print('Training model') 32 | sys.stdout.flush() 33 | loss_func = task_class.LOSS_FUNC 34 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 35 | zero_stats = {'epoch': 0, 'clean_acc': 0.0, 'cert_acc': 0.0} 36 | if augmenter: 37 | zero_stats['aug_acc'] = 0.0 38 | all_epoch_stats = { 39 | "loss": {"total": [], 40 | "clean": [], 41 | "cert": []}, 42 | "cert": {"frac": [], 43 | "eps": []}, 44 | "acc": { 45 | "train": { 46 | "clean": [], 47 | "cert": []}, 48 | "dev": { 49 | "clean": [], 50 | "cert": []}, 51 | "best_dev": { 52 | "clean": [zero_stats], 53 | "cert": [zero_stats]}}, 54 | "total_epochs": num_epochs, 55 | } 56 | aug_dev_data = None 57 | if augmenter: 58 | all_epoch_stats['acc']['dev']['aug'] = [] 59 | all_epoch_stats['acc']['best_dev']['aug'] = [zero_stats] 60 | print('Augmenting training data') 61 | aug_train_data = augmenter.augment(train_data) 62 | data = aug_train_data.get_loader(batch_size) 
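# Note: augmentation happens once, up front. DataAugmenter.augment (defined in text_classification.py above) keeps every original example and adds `augment_by` randomly perturbed copies, each sampled from that example's precomputed substitution choices.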
63 | if dev_data: 64 | print('Augmenting dev data') 65 | aug_dev_data = augmenter.augment(dev_data) # Augment dev set now, for early stopping 66 | else: 67 | data = train_data.get_loader(batch_size) # Create all batches now and pin them in memory 68 | # Linearly increase the weight of adversarial loss over all the epochs to end up at the final desired fraction 69 | cert_schedule = torch.tensor(np.linspace(initial_cert_frac, cert_frac, num_epochs - full_train_epochs - non_cert_train_epochs), dtype=torch.float, device=device) 70 | eps_schedule = torch.tensor(np.linspace(initial_cert_eps, cert_eps, num_epochs - full_train_epochs - non_cert_train_epochs), dtype=torch.float, device=device) 71 | for t in range(num_epochs): 72 | model.train() 73 | if t < non_cert_train_epochs: 74 | cur_cert_frac = 0.0 75 | cur_cert_eps = 0.0 76 | else: 77 | cur_cert_frac = cert_schedule[t - non_cert_train_epochs] if t - non_cert_train_epochs < len(cert_schedule) else cert_schedule[-1] 78 | cur_cert_eps = eps_schedule[t - non_cert_train_epochs] if t - non_cert_train_epochs < len(eps_schedule) else eps_schedule[-1] 79 | epoch = { 80 | "total_loss": 0.0, 81 | "clean_loss": 0.0, 82 | "cert_loss": 0.0, 83 | "num_correct": 0, 84 | "num_cert_correct": 0, 85 | "num": 0, 86 | "clean_acc": 0, 87 | "cert_acc": 0, 88 | "dev": {}, 89 | "best_dev": {}, 90 | "cert_frac": cur_cert_frac if isinstance(cur_cert_frac, float) else cur_cert_frac.item(), 91 | "cert_eps": cur_cert_eps if isinstance(cur_cert_eps, float) else cur_cert_eps.item(), 92 | "epoch": t, 93 | } 94 | with tqdm(data) as batch_loop: 95 | for batch_idx, batch in enumerate(batch_loop): 96 | batch = data_util.dict_batch_to_device(batch, device) 97 | optimizer.zero_grad() 98 | if cur_cert_frac > 0.0: 99 | out = model.forward(batch, cert_eps=cur_cert_eps) 100 | logits = out.val 101 | loss = loss_func(logits, batch['y']) 102 | epoch["clean_loss"] += loss.item() 103 | cert_loss = torch.max(loss_func(out.lb, batch['y']), 104 | loss_func(out.ub, batch['y'])) 105 | loss = cur_cert_frac * cert_loss + (1.0 - cur_cert_frac) * loss 106 | epoch["cert_loss"] += cert_loss.item() 107 | else: 108 | # Bypass computing bounds during training 109 | logits = out = model.forward(batch, compute_bounds=False) 110 | loss = loss_func(logits, batch['y']) 111 | epoch["total_loss"] += loss.item() 112 | epoch["num"] += len(batch['y']) 113 | num_correct, num_cert_correct = task_class.num_correct(out, batch['y']) 114 | epoch["num_correct"] += num_correct 115 | epoch["num_cert_correct"] += num_cert_correct 116 | loss.backward() 117 | if any(p.grad is not None and torch.isnan(p.grad).any() for p in model.parameters()): 118 | nan_params = [p.name for p in model.parameters() if p.grad is not None and torch.isnan(p.grad).any()] 119 | print('NaN found in gradients: %s' % nan_params, file=sys.stderr) 120 | else: 121 | if clip_grad_norm: 122 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) 123 | optimizer.step() 124 | if cert_frac > 0.0: 125 | print("Epoch {epoch:>3}: train loss: {total_loss:.6f}, clean_loss: {clean_loss:.6f}, cert_loss: {cert_loss:.6f}".format(**epoch)) 126 | else: 127 | print("Epoch {epoch:>3}: train loss: {total_loss:.6f}".format(**epoch)) 128 | sys.stdout.flush() 129 | 130 | epoch["clean_acc"] = 100.0 * epoch["num_correct"] / epoch["num"] 131 | acc_str = " Train accuracy: {num_correct}/{num} = {clean_acc:.2f}".format(**epoch) 132 | if cert_frac > 0.0: 133 | epoch["cert_acc"] = 100.0 * epoch["num_cert_correct"] / epoch["num"] 134 | acc_str += ", certified 
{num_cert_correct}/{num} = {cert_acc:.2f}".format(**epoch) 135 | print(acc_str) 136 | is_best = False 137 | if dev_data: 138 | dev_results = test(task_class, model, "Dev", dev_data, device, batch_size=batch_size, 139 | aug_dataset=aug_dev_data) 140 | epoch['dev'] = dev_results 141 | all_epoch_stats['acc']['dev']['clean'].append(dev_results['clean_acc']) 142 | all_epoch_stats['acc']['dev']['cert'].append(dev_results['cert_acc']) 143 | if augmenter: 144 | all_epoch_stats['acc']['dev']['aug'].append(dev_results['aug_acc']) 145 | dev_stats = { 146 | 'epoch': t, 147 | 'loss': dev_results['loss'], 148 | 'clean_acc': dev_results['clean_acc'], 149 | 'cert_acc': dev_results['cert_acc'] 150 | } 151 | if augmenter: 152 | dev_stats['aug_acc'] = dev_results['aug_acc'] 153 | if dev_results['clean_acc'] > all_epoch_stats['acc']['best_dev']['clean'][-1]['clean_acc']: 154 | all_epoch_stats['acc']['best_dev']['clean'].append(dev_stats) 155 | if cert_frac == 0.0 and not augmenter: 156 | is_best = True 157 | if dev_results['cert_acc'] > all_epoch_stats['acc']['best_dev']['cert'][-1]['cert_acc']: 158 | all_epoch_stats['acc']['best_dev']['cert'].append(dev_stats) 159 | if cert_frac > 0.0: 160 | is_best = True 161 | if augmenter and dev_results['aug_acc'] > all_epoch_stats['acc']['best_dev']['aug'][-1]['aug_acc']: 162 | all_epoch_stats['acc']['best_dev']['aug'].append(dev_stats) 163 | if cert_frac == 0.0 and augmenter: 164 | is_best = True 165 | epoch['best_dev'] = { 166 | 'clean': all_epoch_stats['acc']['best_dev']['clean'][-1], 167 | 'cert': all_epoch_stats['acc']['best_dev']['cert'][-1]} 168 | if augmenter: 169 | epoch['best_dev']['aug'] = all_epoch_stats['acc']['best_dev']['aug'][-1] 170 | all_epoch_stats["loss"]['total'].append(epoch["total_loss"]) 171 | all_epoch_stats["loss"]['clean'].append(epoch["clean_loss"]) 172 | all_epoch_stats["loss"]['cert'].append(epoch["cert_loss"]) 173 | all_epoch_stats['cert']['frac'].append(epoch["cert_frac"]) 174 | all_epoch_stats['cert']['eps'].append(epoch["cert_eps"]) 175 | all_epoch_stats["acc"]['train']['clean'].append(epoch["clean_acc"]) 176 | all_epoch_stats["acc"]['train']['cert'].append(epoch["cert_acc"]) 177 | with open(os.path.join(OPTS.out_dir, "run_stats.json"), "w") as outfile: 178 | json.dump(epoch, outfile) 179 | with open(os.path.join(OPTS.out_dir, "all_epoch_stats.json"), "w") as outfile: 180 | json.dump(all_epoch_stats, outfile) 181 | if ((save_best_only and is_best) 182 | or (not save_best_only and epochs_per_save and (t+1) % epochs_per_save == 0) 183 | or t == num_epochs - 1): 184 | if save_best_only and is_best: 185 | for fn in glob.glob(os.path.join(OPTS.out_dir, 'model-checkpoint*.pth')): 186 | os.remove(fn) 187 | model_save_path = os.path.join(OPTS.out_dir, "model-checkpoint-{}.pth".format(t)) 188 | print('Saving model to %s' % model_save_path) 189 | torch.save(model.state_dict(), model_save_path) 190 | 191 | return model 192 | 193 | 194 | def test(task_class, model, name, dataset, device, show_certified=False, batch_size=1, 195 | adversary=None, aug_dataset=None): 196 | model.eval() 197 | loss_func = task_class.LOSS_FUNC 198 | results = { 199 | 'name': name, 200 | 'num_total': 0, 201 | 'num_correct': 0, 202 | 'num_cert_correct': 0, 203 | 'clean_acc': 0.0, 204 | 'cert_acc': 0.0, 205 | 'loss': 0.0 206 | } 207 | data = dataset.get_loader(batch_size) 208 | with torch.no_grad(): 209 | for batch in tqdm(data): 210 | batch = data_util.dict_batch_to_device(batch, device) 211 | out = model.forward(batch, cert_eps=1.0) 212 | results['loss'] += 
loss_func(out.val, batch['y']).item() 213 | num_correct, num_cert_correct = task_class.num_correct(out, batch['y']) 214 | results["num_correct"] += num_correct 215 | results["num_cert_correct"] += num_cert_correct 216 | results['num_total'] += len(batch['y']) 217 | if aug_dataset: 218 | results['aug_loss'] = results['loss'] 219 | results['aug_total'] = results['num_total'] 220 | results['aug_correct'] = results['num_correct'] 221 | aug_data = aug_dataset.get_loader(batch_size) 222 | for batch in tqdm(aug_data): 223 | batch = data_util.dict_batch_to_device(batch, device) 224 | out = model.forward(batch, cert_eps=1.0) 225 | results['aug_loss'] += loss_func(out.val, batch['y']).item() 226 | num_correct, num_cert_correct = task_class.num_correct(out, batch['y']) 227 | results["aug_correct"] += num_correct 228 | results['aug_total'] += len(batch['y']) 229 | results['clean_acc'] = 100 * results['num_correct'] / results['num_total'] 230 | results['cert_acc'] = 100 * results['num_cert_correct'] / results['num_total'] 231 | out_str = " {name} loss = {loss:.2f}; accuracy: {num_correct}/{num_total} = {clean_acc:.2f}, certified {num_cert_correct}/{num_total} = {cert_acc:.2f}".format(**results) 232 | if aug_dataset: 233 | results['aug_acc'] = 100 * results['aug_correct'] / results['aug_total'] 234 | out_str += ', augmented %d/%d = %.2f' % ( 235 | results['aug_correct'], results['aug_total'], results['aug_acc']) 236 | if adversary: 237 | adv_correct, adv_exs = adversary.run(model, dataset, device, opts=OPTS) 238 | results['num_adv_correct'] = sum(adv_correct) 239 | results['adv_acc'] = 100 * results['num_adv_correct'] / len(dataset) 240 | out_str += ', adversarial %d/%d = %.2f' % ( 241 | results['num_adv_correct'], len(dataset), results['adv_acc']) 242 | print(out_str) 243 | return results 244 | 245 | 246 | def parse_args(): 247 | parser = argparse.ArgumentParser() 248 | parser.add_argument('task', choices=TASK_CLASSES.keys()) 249 | parser.add_argument('model', choices=['bow', 'cnn', 'lstm', 'decomp-attn', 'lstm-final-state']) 250 | parser.add_argument('out_dir', help='Directory to store and load output') 251 | # Model 252 | parser.add_argument('--hidden-size', '-d', type=int, default=100) 253 | parser.add_argument('--kernel-size', '-k', type=int, default=3, 254 | help='Kernel size, for CNN convolutions and pooling') 255 | parser.add_argument('--pool', choices=['max', 'mean', 'attn'], default='max') 256 | parser.add_argument('--num-layers', type=int, default=3, help='Num layers for SNLI baseline BOW model') 257 | parser.add_argument('--no-wordvec-layer', action='store_true', help="Don't apply linear transform to word vectors") 258 | parser.add_argument('--early-ibp', action='store_true', help="Do to_interval_bounded directly on base word vectors") 259 | parser.add_argument('--no-relu-wordvec', action='store_true', help="Don't do ReLU after word vector transform") 260 | parser.add_argument('--unfreeze-wordvec', action='store_true', help="Don't freeze word vectors") 261 | parser.add_argument('--glove', '-g', choices=vocabulary.GLOVE_CONFIGS, default='840B.300d') 262 | # Adversary 263 | parser.add_argument('--adversary', '-a', choices=['exhaustive', 'greedy', 'genetic'], 264 | default=None, help='Which adversary to test on') 265 | parser.add_argument('--adv-num-epochs', type=int, default=10) 266 | parser.add_argument('--adv-num-tries', type=int, default=2) 267 | parser.add_argument('--adv-pop-size', type=int, default=60) 268 | parser.add_argument('--use-lm', action='store_true', help='Use LM scores to 
define attack surface') 269 | # Training 270 | parser.add_argument('--num-epochs', '-T', type=int, default=1) 271 | parser.add_argument('--learning-rate', '-r', type=float, default=1e-3) 272 | parser.add_argument('--dropout-prob', type=float, default=0.1) 273 | parser.add_argument('--batch-size', '-b', type=int, default=1) 274 | parser.add_argument('--clip-grad-norm', type=float, default=0.25) 275 | parser.add_argument('--weight-decay', type=float, default=1e-4) 276 | parser.add_argument('--cert-frac', '-c', type=float, default=0.0, 277 | help='Fraction of loss devoted to certified loss term.') 278 | parser.add_argument('--initial-cert-frac', type=float, default=0.0, 279 | help='If certified loss is being used, where the linear scale for it begins') 280 | parser.add_argument('--cert-eps', type=float, default=1.0, 281 | help='Max scaling factor for the interval bounds of the attack words to be used') 282 | parser.add_argument('--initial-cert-eps', type=float, default=0.0, 283 | help='If certified loss is being used, where the linear scale for its epsilon begins') 284 | parser.add_argument('--full-train-epochs', type=int, default=0, 285 | help='If specified use full cert_frac and cert_eps for this many epochs at the end') 286 | parser.add_argument('--non-cert-train-epochs', type=int, default=0, 287 | help='If specified train this many epochs regularly in beginning') 288 | parser.add_argument('--epochs-per-save', type=int, default=1, 289 | help='How often to save model; 0 to only save final model') 290 | parser.add_argument('--save-best-only', action='store_true', 291 | help='Only save best dev epochs (based on cert acc if cert_frac > 0, clean acc else)') 292 | parser.add_argument('--augment-by', type=int, default=0, 293 | help='How many augmented examples per real example') 294 | # Data and files 295 | parser.add_argument('--adv-only', action='store_true', help='Only run the adversary against the model on the given evaluation set') 296 | parser.add_argument('--test', action='store_true', help='Evaluate on test set') 297 | parser.add_argument('--data-cache-dir', '-D', help='Where to load cached dataset and glove') 298 | parser.add_argument('--neighbor-file', type=str, default=data_util.NEIGHBOR_FILE) 299 | parser.add_argument('--glove-dir', type=str, default=vocabulary.GLOVE_DIR) 300 | parser.add_argument('--imdb-dir', type=str, default=text_classification.IMDB_DIR) 301 | parser.add_argument('--snli-dir', type=str, default=entailment.SNLI_DIR) 302 | parser.add_argument('--imdb-lm-file', type=str, default=text_classification.LM_FILE) 303 | parser.add_argument('--snli-lm-file', type=str, default=entailment.LM_FILE) 304 | parser.add_argument('--prepend-null', action='store_true', help='If true add UNK token to sequences') 305 | parser.add_argument('--normalize-word-vecs', action='store_true', help='If true normalize word vectors') 306 | parser.add_argument('--downsample-to', type=int, default=None, 307 | help='Downsample train and dev data to this many examples') 308 | parser.add_argument('--downsample-shard', type=int, default=0, 309 | help='Downsample starting at this multiple of downsample_to') 310 | parser.add_argument('--use-toy-data', action='store_true') 311 | parser.add_argument('--truncate-to', type=int, default=None, 312 | help='Truncate examples to this max length') 313 | # Loading 314 | parser.add_argument('--load-dir', '-L', help='Where to load checkpoint') 315 | parser.add_argument('--load-ckpt', type=int, default=None, 316 | help='Which checkpoint to load') 317 | # Other 318 | 
parser.add_argument('--rng-seed', type=int, default=123456) 319 | parser.add_argument('--torch-seed', type=int, default=1234567) 320 | 321 | if len(sys.argv) == 1: 322 | parser.print_help() 323 | sys.exit(1) 324 | return parser.parse_args() 325 | 326 | 327 | def main(): 328 | random.seed(OPTS.rng_seed) 329 | np.random.seed(OPTS.rng_seed) 330 | torch.manual_seed(OPTS.torch_seed) 331 | torch.backends.cudnn.deterministic = True 332 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 333 | task_class = TASK_CLASSES[OPTS.task] 334 | print('Loading dataset.') 335 | if not os.path.exists(OPTS.out_dir): 336 | os.makedirs(OPTS.out_dir) 337 | with open(os.path.join(OPTS.out_dir, 'log.txt'), 'w') as f: 338 | print(sys.argv, file=f) 339 | print(OPTS, file=f) 340 | if OPTS.data_cache_dir: 341 | if not os.path.exists(OPTS.data_cache_dir): 342 | os.makedirs(OPTS.data_cache_dir) 343 | train_data, dev_data, word_mat, attack_surface = task_class.load_datasets(device, OPTS) 344 | print('Initializing model.') 345 | model = task_class.load_model(word_mat, device, OPTS) 346 | if OPTS.num_epochs > 0: 347 | augmenter = None 348 | if OPTS.augment_by: 349 | augmenter = task_class.DataAugmenter(OPTS.augment_by) 350 | train(task_class, model, train_data, OPTS.num_epochs, OPTS.learning_rate, device, 351 | dev_data=dev_data, cert_frac=OPTS.cert_frac, initial_cert_frac=OPTS.initial_cert_frac, 352 | cert_eps=OPTS.cert_eps, initial_cert_eps=OPTS.initial_cert_eps, batch_size=OPTS.batch_size, 353 | epochs_per_save=OPTS.epochs_per_save, augmenter=augmenter, clip_grad_norm=OPTS.clip_grad_norm, 354 | weight_decay=OPTS.weight_decay, full_train_epochs=OPTS.full_train_epochs, non_cert_train_epochs=OPTS.non_cert_train_epochs, save_best_only=OPTS.save_best_only) 355 | print('Training finished.') 356 | print('Testing model.') 357 | if not OPTS.adv_only: 358 | train_results = test(task_class, model, 'Train', train_data, device, 359 | batch_size=OPTS.batch_size) 360 | adversary = None 361 | if OPTS.adversary == 'exhaustive': 362 | adversary = task_class.ExhaustiveAdversary(attack_surface) 363 | elif OPTS.adversary == 'greedy': 364 | adversary = task_class.GreedyAdversary(attack_surface, num_epochs=OPTS.adv_num_epochs, 365 | num_tries=OPTS.adv_num_tries) 366 | elif OPTS.adversary == 'genetic': 367 | adversary = task_class.GeneticAdversary(attack_surface, num_iters=OPTS.adv_num_epochs, 368 | pop_size=OPTS.adv_pop_size) 369 | dev_results = test(task_class, model, 'Dev', dev_data, device, 370 | adversary=adversary, batch_size=OPTS.batch_size) 371 | results = { 372 | 'train': train_results, 373 | 'dev': dev_results 374 | } 375 | with open(os.path.join(OPTS.out_dir, 'test_results.json'), 'w') as f: 376 | json.dump(results, f) 377 | else: 378 | adversary = None 379 | if OPTS.adversary == 'exhaustive': 380 | adversary = task_class.ExhaustiveAdversary(attack_surface) 381 | elif OPTS.adversary == 'greedy': 382 | adversary = task_class.GreedyAdversary(attack_surface, num_epochs=OPTS.adv_num_epochs, 383 | num_tries=OPTS.adv_num_tries) 384 | elif OPTS.adversary == 'genetic': 385 | adversary = task_class.GeneticAdversary(attack_surface, num_iters=OPTS.adv_num_epochs, 386 | pop_size=OPTS.adv_pop_size) 387 | test(task_class, model, 'Dev', dev_data, device, adversary=adversary, batch_size=OPTS.batch_size) 388 | 389 | if __name__ == '__main__': 390 | OPTS = parse_args() 391 | main() 392 | -------------------------------------------------------------------------------- /src/vocabulary.py: 
-------------------------------------------------------------------------------- 1 | """A basic vocabulary class.""" 2 | import collections 3 | import os 4 | import numpy as np 5 | import torch 6 | from tqdm import tqdm 7 | 8 | UNK_TOKEN = '<UNK>' 9 | UNK_INDEX = 0 10 | 11 | NULL_TOKEN = '<NULL>' 12 | NULL_INDEX = 1 13 | 14 | GLOVE_DIR = 'data/glove' 15 | GLOVE_CONFIGS = { 16 | '6B.50d': {'size': 50, 'lines': 400000}, 17 | '840B.300d': {'size': 300, 'lines': 2196017} 18 | } 19 | 20 | 21 | class Vocabulary(object): 22 | @classmethod 23 | def read_word_vecs(cls, word_set, glove_dir, glove_name, device, normalize=False, prepend_null=False): 24 | vocab = cls(prepend_null=prepend_null) 25 | glove_config = GLOVE_CONFIGS[glove_name] 26 | vecs = [np.zeros(glove_config['size'])] # UNK embedding, won't be used 27 | if prepend_null: 28 | vecs.append(np.zeros(glove_config['size'])) # NULL embedding 29 | found = 0 30 | fn = os.path.join(glove_dir, 'glove.%s.txt' % glove_name) 31 | print('Reading GloVe vectors from %s...' % fn) 32 | with open(fn) as f: 33 | for i, line in tqdm(enumerate(f), total=glove_config['lines']): 34 | toks = line.strip().split(' ') 35 | word = toks[0] 36 | if word in word_set and word not in vocab: 37 | found += 1 38 | vocab.add_word_hard(word) 39 | vecs.append(np.array([float(x) for x in toks[1:]])) 40 | print('Found %d/%d words in %s' % (found, len(word_set), fn)) 41 | word_mat = torch.tensor(vecs, dtype=torch.float, device=device) 42 | if normalize: 43 | word_mat = word_mat / word_mat.norm(dim=-1, keepdim=True) 44 | return vocab, word_mat 45 | 46 | def __init__(self, unk_threshold=0, prepend_null=False): 47 | """Initialize the vocabulary. 48 | 49 | Args: 50 | unk_threshold: words with <= this many counts will be considered <UNK>. 51 | prepend_null: if True index 1 will be <NULL> 52 | """ 53 | self.unk_threshold = unk_threshold 54 | self.counts = collections.Counter() 55 | self.word2index = {UNK_TOKEN: UNK_INDEX} 56 | self.word_list = [UNK_TOKEN] 57 | if prepend_null: 58 | self.word2index[NULL_TOKEN] = NULL_INDEX 59 | self.word_list.append(NULL_TOKEN) 60 | 61 | def add_word(self, word, count=1): 62 | """Add a word (may still map to UNK if it doesn't pass unk_threshold).""" 63 | self.counts[word] += count 64 | if word not in self.word2index and self.counts[word] > self.unk_threshold: 65 | index = len(self.word_list) 66 | self.word2index[word] = index 67 | self.word_list.append(word) 68 | 69 | def add_words(self, words): 70 | for w in words: 71 | self.add_word(w) 72 | 73 | def add_sentence(self, sentence): 74 | self.add_words(sentence.split(' ')) 75 | 76 | def add_sentences(self, sentences): 77 | for s in sentences: 78 | self.add_sentence(s) 79 | 80 | def add_word_hard(self, word): 81 | """Add word, make sure it is not UNK.""" 82 | self.add_word(word, count=(self.unk_threshold+1)) 83 | 84 | def get_word(self, index): 85 | return self.word_list[index] 86 | 87 | def get_index(self, word): 88 | if word in self.word2index: 89 | return self.word2index[word] 90 | return UNK_INDEX 91 | 92 | def indexify_sentence(self, sentence): 93 | return [self.get_index(w) for w in sentence.split(' ')] 94 | 95 | def indexify_list(self, elems): 96 | return [self.get_index(w) for w in elems] 97 | 98 | def recover_sentence(self, indices): 99 | return ' '.join(self.get_word(i) for i in indices) 100 | 101 | def has_word(self, word): 102 | return word in self.word2index 103 | 104 | def __contains__(self, word): 105 | return self.has_word(word) 106 | 107 | def size(self): 108 | # Report number of words that have been assigned an index 109 |
return len(self.word2index) 110 | 111 | def __len__(self): 112 | return self.size() 113 | 114 | def __iter__(self): 115 | return iter(self.word_list) 116 | --------------------------------------------------------------------------------
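As a quick sanity check of the vocabulary utilities above, here is a minimal usage sketch. It is illustrative only and not part of the repository; it assumes `src/` is on your `PYTHONPATH`, and the example words ('good', 'movie', 'unseen') are arbitrary.

```
# Minimal usage sketch for src/vocabulary.py (illustrative, not part of the repo).
# Assumes src/ is on PYTHONPATH so that `vocabulary` can be imported directly.
from vocabulary import Vocabulary, UNK_INDEX

vocab = Vocabulary()                 # unk_threshold=0, no <NULL> token
vocab.add_word_hard('good')          # force the word past the UNK threshold
vocab.add_word_hard('movie')
print(len(vocab))                               # 3: <UNK>, 'good', 'movie'
print(vocab.indexify_sentence('good movie'))    # [1, 2]
print(vocab.get_index('unseen') == UNK_INDEX)   # True: OOV words map to UNK
print(vocab.recover_sentence([1, 2]))           # 'good movie'
```

In the full pipeline, `Vocabulary.read_word_vecs` plays this role automatically: it builds the vocabulary and GloVe matrix from the dataset's words plus their substitution neighbors (see `load_datasets` in `src/text_classification.py`).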