├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── data ├── dictionary.pkl ├── download.sh ├── glove6b_init_300d.npy ├── preprocess-features.py └── preprocess-vocab.py ├── eval-acc.py ├── fig └── overview.png ├── log.txt ├── main.py ├── merge_trainval_adv.py ├── requirements.txt ├── run.sh ├── sea_flip_rate.py ├── seada ├── adversarial_vqa.py ├── attacks.py ├── butd │ ├── baseline_model.py │ ├── reuse_modules.py │ └── word_embedding.py ├── data.py ├── sea │ ├── onmt_model.py │ ├── paraphrase_scorer.py │ ├── replace_rules.py │ └── translation_models │ │ └── .gitignore └── utils.py ├── sort_para.py └── view-log.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | *.pth 4 | *.json 5 | *.jpg 6 | *.pt 7 | *.h5 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | cache/ 35 | 36 | # Code Editors 37 | .vscode 38 | .idea 39 | 40 | # Code linters 41 | .mypy_caches 42 | *.model 43 | *.tsv 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Ruixue Tang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic Equivalent Adversarial Data Augmentation for Visual Question Answering-Pytorch 2 | 3 | #### *Working in Progress* 4 | 5 | This repository corresponds to the ECCV 2020 paper *Semantic Equivalent Adversarial Data Augmentation for Visual Question Answering*. 6 | 7 | ![](fig/overview.png) 8 | 9 | ## Dependencies 10 | 11 | You may need at least **1 GPU** with **11GB memory** for training, and **200GB free disk space** for storing VQAv2 dataset. We strongly recommend to use a SSD drive to guarantee high-speed I/O. 12 | 13 | - Python 3.6 14 | - pytorch = 1.0 15 | - torchvision 0.2 16 | - h5py 2.7 17 | - tqdm 4.19 18 | 19 | ## Installation 20 | 21 | 1. 
Clone this repository:

```
22 | git clone https://github.com/zaynmi/seada-vqa.git
23 | ```
24 | 
25 | 2. We recommend installing everything in an Anaconda environment.
26 | 
27 | ```
28 | conda create -n seada python=3.6
29 | source activate seada
30 | ```
31 | 
32 | 3. Install PyTorch 1.0 and torchvision:
33 | 
34 | ```
35 | conda install pytorch=1.0 torchvision cudatoolkit=10.0 -c pytorch
36 | ```
37 | 
38 | 4. Install the other dependencies as follows:
39 | 
40 | ```
41 | pip install -r requirements.txt
42 | python -m spacy download en
43 | ```
44 | 
45 | 5. Install [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py) for generating paraphrases; this installs the `onmt` package into your environment:
46 | 
47 | ```
48 | git clone https://github.com/zaynmi/OpenNMT-py.git
49 | cd OpenNMT-py
50 | python setup.py install
51 | cd ..
52 | ```
53 | 
54 | 6. Download and unpack [the translation models](https://drive.google.com/open?id=1b2upZvq5kM0lN0T7YaAY30xRdbamuk9y) into the `seada/sea/translation_models` folder. You'll get four `.pt` models.
55 | 
56 | ## Prepare Dataset (Follow [Cyanogenoid/vqa-counting](https://github.com/Cyanogenoid/vqa-counting))
57 | 
58 | - In the `data` directory, execute `./download.sh` to download VQA v2 and the bottom-up-top-down features.
59 | - Prepare the data by running:
60 | 
61 | ```
62 | python data/preprocess-features.py
63 | python data/preprocess-vocab.py
64 | ```
65 | 
66 | This creates an `h5py` database (95 GiB) containing the object proposal features, and a vocabulary for questions and answers, at the locations specified in `config.py`. It is strongly recommended to put the database on an SSD.
67 | 
68 | ## Training
69 | 
70 | ### Step 1: Generating paraphrases of the questions
71 | 
72 | ```
73 | python main.py --attack_only --attack_mode q --attack_al sea --attacked_checkpoint {your_trained_model}.pth --fliprate 0.3 --topk 2 --paraphrase_data train
74 | ```
75 | 
76 | This generates paraphrases of the train set with the top-2 semantic-similarity scores and a 30% flip rate with respect to `{your_trained_model}.pth` ([a BUTD model](https://drive.google.com/file/d/1mXm9R968zxzWz8GYkpRnn3k4yzgwcXdz/view?usp=sharing)), and stores them in `config.paraphrase_save_path`. Similarly, setting `--paraphrase_data val` generates paraphrases of the val set; don't forget to change `config.paraphrase_save_path` accordingly.
77 | 
78 | In our paper, we did not specify the flip rate, topk, or attacked_checkpoint (`--fliprate 0 --topk 1`), which means we simply use the paraphrases with the top-1 semantic-similarity score.
79 | 
80 | One step remains: the generated paraphrases must be sorted into the same order as the annotations file. The script for this is `sort_para.py`; a sketch of the idea is shown below.
81 | 
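The reordering is essentially index-and-reorder over `question_id`. A minimal sketch, assuming the standard VQA v2 JSON layout and the file names used in `config.py` (the actual `sort_para.py` may differ in details):

```
import json

# Assumed paths for illustration: the VQA v2 train annotations and the
# paraphrase file written to config.paraphrase_save_path in Step 1.
with open('data/v2_mscoco_train2014_annotations.json') as f:
    annotations = json.load(f)['annotations']
with open('data/v2_OpenEnded_mscoco_train2014_questions_adv.json') as f:
    paraphrases = json.load(f)['questions']

# Index the generated paraphrases by question_id, then emit them in the
# same order as the annotations so that question i lines up with answer i.
by_qid = {q['question_id']: q for q in paraphrases}
sorted_questions = [by_qid[a['question_id']] for a in annotations]

with open('data/v2_OpenEnded_mscoco_train2014_questions_adv.json', 'w') as f:
    json.dump({'questions': sorted_questions}, f)
```

The sketch assumes every question id has a paraphrase; if some are missing, fall back to the original question for those ids.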
82 | ### Step 2: Adversarial training
83 | 
84 | - **Option-1**. Use both visual adversarial examples and paraphrases to augment the data.
85 | 
86 | ```
87 | python main.py --advtrain --attack_al ifgsm,sea --attack_mode vq --attacked_checkpoint {checkpoint_you_attack_when_eval}.pth --resume {your_partial_trained_model}.pth
88 | ```
89 | 
90 | - **Option-2**. Use visual adversarial examples to augment the data.
91 | 
92 | ```
93 | python main.py --advtrain --attack_al ifgsm --attack_mode v --attacked_checkpoint {checkpoint_you_attack_when_eval}.pth --resume {your_partial_trained_model}.pth
94 | ```
95 | 
96 | - **Option-3**. Use paraphrases to augment the data.
97 | 
98 | ```
99 | python main.py --advtrain --attack_al sea --attack_mode q --attacked_checkpoint {checkpoint_you_attack_when_eval}.pth --resume {your_partial_trained_model}.pth
100 | ```
101 | 
102 | `--attacked_checkpoint` is optional; when given, it lets you evaluate how well the adversarially trained model defends against adversarial examples generated by `{checkpoint_you_attack_when_eval}.pth`.
103 | 
104 | If you want to train with both the train and val sets, add `--advtrain_data trainval`.
105 | 
106 | ## Evaluation
107 | 
108 | - Generate a `.json` file to upload to the online evaluation server. The result file path is specified in `config.result_json_path`.
109 | 
110 | ```
111 | python main.py --test_advtrain --checkpoint {your_trained_model}.pth
112 | ```
113 | 
114 | - Or evaluate on the val set. `--attacked_checkpoint` is optional; if it is given, the defense performance is reported as well.
115 | 
116 | ```
117 | python main.py --eval_advtrain --checkpoint {your_trained_model}.pth --attack_al ifgsm --attack_mode v --attacked_checkpoint {checkpoint_you_attack_when_eval}.pth
118 | ```
119 | 
120 | ## Performance of the model when being attacked
121 | 
122 | How the model behaves while under attack is also of interest. You can use
123 | 
124 | ```
125 | python main.py --attack_only --attack_mode v --attack_al pgd --alpha 0.5 --iteration 6 --epsilon 5 --attacked_checkpoint {checkpoint_being_attacked}.pth
126 | ```
127 | 
128 | All attackers operate in a white-box setting.
129 | 
130 | ## License
131 | 
132 | The code is released under the [MIT License](https://github.com/zaynmi/semantic-equivalent-da-for-vqa/blob/master/LICENSE).
133 | 
134 | ## Citing
135 | 
136 | If this repository is helpful for your research, we'd really appreciate it if you could cite the following paper:
137 | 
138 | ```
139 | @inproceedings{tang2020semantic,
140 |   title={Semantic Equivalent Adversarial Data Augmentation for Visual Question Answering},
141 |   author={Tang, Ruixue and Ma, Chao and Zhang, Wei Emma and Wu, Qi and Yang, Xiaokang},
142 |   booktitle={European Conference on Computer Vision (ECCV)},
143 |   year={2020}
144 | }
145 | ```
146 | 
147 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # paths
2 | qa_path = '/home/tang/attack_on_VQA2.0-Recent-Approachs-2018/data' # directory containing the question and annotation jsons
3 | bottom_up_trainval_path = '/home/tang/trainval' # directory containing the .tsv file(s) with bottom-up features
4 | bottom_up_test_path = '/home/tang/test2015' # directory containing the .tsv file(s) with bottom-up features
5 | preprocessed_trainval_path = 'data/genome-trainval.h5' # path where preprocessed features from the trainval split are saved to and loaded from
6 | preprocessed_test_path = '/media/tang/新加卷/VQAv2/genome-test.h5' # path where preprocessed features from the test split are saved to and loaded from
7 | vocabulary_path = '/home/tang/attack_on_VQA2.0-Recent-Approachs-2018/data/vocab.json' # path where the used vocabularies for question and answers are saved to
8 | glove_index = 'data/dictionary.pkl'
9 | result_json_path = 'results.json' # the path to save the test json that can be uploaded to the VQA 2.0 online evaluation server
10 | paraphrase_save_path = 'data/v2_OpenEnded_mscoco_train2014_questions_adv.json'
11 | 
12 | task = 'OpenEnded'
13 | dataset = 'mscoco'
14 | 
15 | test_split = 'test2015' # always 'test2015' since from
2018, vqa online evaluation server requires to upload entire test2015 result even for test-dev split 16 | 17 | # preprocess config 18 | output_size = 100 # max number of object proposals per image 19 | output_features = 2048 # number of features in each object proposal 20 | 21 | ################################################################### 22 | # Default Setting for All Model 23 | ################################################################### 24 | # training config 25 | epochs = 23 26 | batch_size = 256 27 | initial_lr = 1e-3 28 | lr_decay_step = 2 29 | lr_decay_rate = 0.25 30 | lr_halflife = 50000 # for scheduler (counting) 31 | data_workers = 4 32 | max_answers = 3129 33 | max_q_length = 666 # question_length = min(max_q_length, max_length_in_dataset) 34 | clip_value = 0.25 35 | v_feat_norm = False # Only useful in learning to count 36 | print_gradient = False 37 | normalize_box = False 38 | seed = 5225 39 | weight_decay = 0.0 40 | 41 | model_type = 'baseline' # "Bottom-up top-down" 42 | 43 | optim_method = 'Adamax' # used in "Bottom-up top-down" 44 | #optim_method = 'Adam' # used in "Learning to count objects", set initial_lr to 1.5e-3 45 | 46 | schedule_method = 'warm_up' 47 | #schedule_method = 'batch_decay' 48 | 49 | loss_method = 'binary_cross_entropy_with_logits' 50 | #loss_method = 'soft_cross_entropy' 51 | #loss_method = 'KL_divergence' 52 | #loss_method = 'multi_label_soft_margin' 53 | 54 | gradual_warmup_steps = [1.0 * initial_lr, 1.0 * initial_lr, 2.0 * initial_lr, 2.0 * initial_lr] 55 | lr_decay_epochs = range(10, 100, lr_decay_step) 56 | 57 | ################################################################### 58 | # Detailed Setting for Each Model 59 | ################################################################### 60 | 61 | # "Bottom-up top-down" 62 | # baseline Setting 63 | if model_type == 'baseline': 64 | loss_method = 'binary_cross_entropy_with_logits' 65 | gradual_warmup_steps = [0.5 * initial_lr, 1.0 * initial_lr, 1.5 * initial_lr, 2.0 * initial_lr] 66 | 67 | 68 | def print_param(): 69 | print('--------------------------------------------------') 70 | print('Num obj: ', output_size) 71 | print('Num epochs: ', epochs) 72 | print('Batch size: ', batch_size) 73 | print('Model type: ', model_type) 74 | print('Optimization Method: ', optim_method) 75 | print('Schedule Method: ', schedule_method) 76 | print('Loss Method: ', loss_method) 77 | print('Clip Value: ', clip_value) 78 | print('Init LR: ', initial_lr) 79 | print('LR decay step: ', lr_decay_step) 80 | print('LR decay rate: ', lr_decay_rate) 81 | print('LR half life: ', lr_halflife) 82 | print('Normalize visual feature: ', v_feat_norm) 83 | print('Print Gradient: ', print_gradient) 84 | print('Normalize Box Size: ', normalize_box) 85 | print('Max answer choice: ', max_answers) 86 | print('Manually set max question lenght: ', max_q_length) 87 | print('Random Seed: ', seed) 88 | print('gradual_warmup_steps: ', gradual_warmup_steps) 89 | print('Weight Decay: ', weight_decay) 90 | print('--------------------------------------------------') 91 | -------------------------------------------------------------------------------- /data/dictionary.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zaynmi/seada-vqa/f121fb3e8fee8af5f1935a7526f19e0d884bd95b/data/dictionary.pkl -------------------------------------------------------------------------------- /data/download.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # questions 4 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip 5 | 6 | # answers 7 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip 8 | 9 | # balanced pairs 10 | wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Complementary_Pairs_Train_mscoco.zip https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Complementary_Pairs_Val_mscoco.zip 11 | 12 | # bottom up features (https://github.com/peteanderson80/bottom-up-attention) 13 | wget https://imagecaption.blob.core.windows.net/imagecaption/trainval.zip https://imagecaption.blob.core.windows.net/imagecaption/test2015.zip 14 | ## alternative bottom-up features: 36 fixed proposals per image instead of 10--100 adaptive proposals per image. 15 | #wget https://imagecaption.blob.core.windows.net/imagecaption/trainval_36.zip https://imagecaption.blob.core.windows.net/imagecaption/test2015_36.zip 16 | 17 | unzip "*.zip" 18 | -------------------------------------------------------------------------------- /data/glove6b_init_300d.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zaynmi/seada-vqa/f121fb3e8fee8af5f1935a7526f19e0d884bd95b/data/glove6b_init_300d.npy -------------------------------------------------------------------------------- /data/preprocess-features.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import base64 4 | import os 5 | import csv 6 | import itertools 7 | 8 | csv.field_size_limit(sys.maxsize) 9 | 10 | import h5py 11 | import torch.utils.data 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | import config 16 | import data 17 | import utils 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--test', action='store_true') 23 | args = parser.parse_args() 24 | 25 | FIELDNAMES = ['image_id', 'image_w','image_h','num_boxes', 'boxes', 'features'] 26 | 27 | features_shape = ( 28 | 82783 + 40504 if not args.test else 81434, # number of images in trainval or in test 29 | config.output_features, 30 | config.output_size, 31 | ) 32 | boxes_shape = ( 33 | features_shape[0], 34 | 4, 35 | config.output_size, 36 | ) 37 | 38 | if not args.test: 39 | path = config.preprocessed_trainval_path 40 | else: 41 | path = config.preprocessed_test_path 42 | with h5py.File(path, libver='latest') as fd: 43 | features = fd.create_dataset('features', shape=features_shape, dtype='float32') 44 | boxes = fd.create_dataset('boxes', shape=boxes_shape, dtype='float32') 45 | coco_ids = fd.create_dataset('ids', shape=(features_shape[0],), dtype='int32') 46 | widths = fd.create_dataset('widths', shape=(features_shape[0],), dtype='int32') 47 | heights = fd.create_dataset('heights', shape=(features_shape[0],), dtype='int32') 48 | 49 | readers = [] 50 | if not args.test: 51 | path = config.bottom_up_trainval_path 52 | else: 53 | path = config.bottom_up_test_path 54 | for filename in os.listdir(path): 55 | if not '.tsv' in filename: 56 | continue 57 | full_filename = os.path.join(path, filename) 58 | fd = open(full_filename, 'r') 59 | reader = csv.DictReader(fd, delimiter='\t', 
fieldnames=FIELDNAMES) 60 | readers.append(reader) 61 | 62 | reader = itertools.chain.from_iterable(readers) 63 | for i, item in enumerate(tqdm(reader, total=features_shape[0])): 64 | coco_ids[i] = int(item['image_id']) 65 | widths[i] = int(item['image_w']) 66 | heights[i] = int(item['image_h']) 67 | 68 | buf = base64.decodestring(item['features'].encode('utf8')) 69 | array = np.frombuffer(buf, dtype='float32') 70 | array = array.reshape((-1, config.output_features)).transpose() 71 | features[i, :, :array.shape[1]] = array 72 | 73 | buf = base64.decodestring(item['boxes'].encode('utf8')) 74 | array = np.frombuffer(buf, dtype='float32') 75 | array = array.reshape((-1, 4)).transpose() 76 | boxes[i, :, :array.shape[1]] = array 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /data/preprocess-vocab.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from collections import Counter 4 | import itertools 5 | 6 | import config 7 | import data 8 | import utils 9 | 10 | 11 | def extract_vocab(iterable, top_k=None, start=0): 12 | """ Turns an iterable of list of tokens into a vocabulary. 13 | These tokens could be single answers or word tokens in questions. 14 | """ 15 | all_tokens = itertools.chain.from_iterable(iterable) 16 | counter = Counter(all_tokens) 17 | if top_k: 18 | most_common = counter.most_common(top_k) 19 | most_common = (t for t, c in most_common) 20 | else: 21 | most_common = counter.keys() 22 | # descending in count, then lexicographical order 23 | tokens = sorted(most_common, key=lambda x: (counter[x], x), reverse=True) 24 | vocab = {t: i for i, t in enumerate(tokens, start=start)} 25 | return vocab 26 | 27 | 28 | def main(): 29 | questions = utils.path_for(train=True, question=True) 30 | answers = utils.path_for(train=True, answer=True) 31 | 32 | with open(questions, 'r') as fd: 33 | questions = json.load(fd) 34 | with open(answers, 'r') as fd: 35 | answers = json.load(fd) 36 | 37 | questions = list(data.prepare_questions(questions)) 38 | answers = list(data.prepare_answers(answers)) 39 | 40 | question_vocab = extract_vocab(questions, start=1) 41 | answer_vocab = extract_vocab(answers, top_k=config.max_answers) 42 | 43 | vocabs = { 44 | 'question': question_vocab, 45 | 'answer': answer_vocab, 46 | } 47 | with open(config.vocabulary_path, 'w') as fd: 48 | json.dump(vocabs, fd) 49 | 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /eval-acc.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import os.path 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | import torch 8 | 9 | import utils 10 | import config 11 | 12 | 13 | q_path = utils.path_for(val=True, question=True) 14 | with open(q_path, 'r') as fd: 15 | q_json = json.load(fd) 16 | a_path = utils.path_for(val=True, answer=True) 17 | with open(a_path, 'r') as fd: 18 | a_json = json.load(fd) 19 | with open(os.path.join(config.qa_path, 'v2_mscoco_val2014_complementary_pairs.json')) as fd: 20 | pairs = json.load(fd) 21 | 22 | question_list = q_json['questions'] 23 | question_ids = [q['question_id'] for q in question_list] 24 | questions = [q['question'] for q in question_list] 25 | answer_list = a_json['annotations'] 26 | categories = [a['answer_type'] for a in answer_list] # {'yes/no', 'other', 
'number'} 27 | accept_condition = { 28 | 'number': (lambda x: id_to_cat[x] == 'number'), 29 | 'yes/no': (lambda x: id_to_cat[x] == 'yes/no'), 30 | 'other': (lambda x: id_to_cat[x] == 'other'), 31 | 'count': (lambda x: id_to_question[x].lower().startswith('how many')), 32 | 'all': (lambda x: True), 33 | } 34 | 35 | statistics = defaultdict(list) 36 | for path in sys.argv[1:]: 37 | log = torch.load(path) 38 | ans = log['eval'] 39 | d = [(acc, ans) for (acc, ans, _) in sorted(zip(ans['accuracies'], ans['answers'], ans['idx']), key=lambda x: x[-1])] 40 | accs = map(lambda x: x[0], d) 41 | id_to_cat = dict(zip(question_ids, categories)) 42 | id_to_acc = dict(zip(question_ids, accs)) 43 | id_to_question = dict(zip(question_ids, questions)) 44 | 45 | for name, f in accept_condition.items(): 46 | for on_pairs in [False, True]: 47 | acc = [] 48 | if on_pairs: 49 | for a, b in pairs: 50 | if not (f(a) and f(b)): 51 | continue 52 | if id_to_acc[a] == id_to_acc[b] == 1: 53 | acc.append(1) 54 | else: 55 | acc.append(0) 56 | else: 57 | for x in question_ids: 58 | if not f(x): 59 | continue 60 | acc.append(id_to_acc[x]) 61 | acc = np.mean(acc) 62 | statistics[name, 'pair' if on_pairs else 'single'].append(acc) 63 | 64 | for (name, pairness), accs in statistics.items(): 65 | mean = np.mean(accs) 66 | std = np.std(accs, ddof=1) 67 | print('{} ({})\t: {:.2f}% +- {}'.format(name, pairness, 100 * mean, 100 * std)) 68 | -------------------------------------------------------------------------------- /fig/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zaynmi/seada-vqa/f121fb3e8fee8af5f1935a7526f19e0d884bd95b/fig/overview.png -------------------------------------------------------------------------------- /log.txt: -------------------------------------------------------------------------------- 1 | baseline_train_pgd,sea_vq_e0.3_it2_a0.5_w1_ad10_ld15_ade15_fr1 2 | 0.6507568 3 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import config 3 | import torch.backends.cudnn as cudnn 4 | import torch 5 | from seada.adversarial_vqa import AdversarialAttackVQA 6 | import warnings 7 | 8 | warnings.filterwarnings('ignore') 9 | 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('name', nargs='*') 14 | parser.add_argument('--attack_only', action='store_true') 15 | parser.add_argument('--generate_adv_example', action='store_true') 16 | parser.add_argument('--attacked_checkpoint', type=str, help='must be announced when attack only') 17 | parser.add_argument('--attack_al', type=str, default='ifgsm', help='attack algorithm') 18 | parser.add_argument('--checkpoint', type=str) 19 | parser.add_argument('--resume', type=str) 20 | parser.add_argument('--attack_mode', default='v', choices=['v', 'q', 'vq', 'no']) 21 | parser.add_argument('--advtrain', action='store_true') 22 | parser.add_argument('--vqacp', action='store_true') 23 | parser.add_argument('--advtrain_data', default='train', choices=['train', 'trainval']) 24 | parser.add_argument('--eval_advtrain', action='store_true') 25 | parser.add_argument('--test_advtrain', action='store_true') 26 | parser.add_argument('--advloss_w', type=int, default=1) 27 | parser.add_argument('--samples_frac', type=float, default=1) 28 | parser.add_argument('--adv_delay', type=int, default=10) 29 | parser.add_argument('--adv_end', 
type=int, default=15) 30 | parser.add_argument('--epsilon', type=float, default=0.3) 31 | parser.add_argument('--alpha',type=float, default=0.5) 32 | parser.add_argument('--iteration', type=int, default=2) 33 | parser.add_argument('--lr_decay', type=int, default=15) 34 | parser.add_argument('--topk', type=int, default=1) 35 | parser.add_argument('--fliprate', type=float, default=0) 36 | parser.add_argument('--paraphrase_data', type=str, default='train', choices=['train', 'val', 'test']) 37 | parser.add_argument('--describe', type=str, default='describe your setting') 38 | args = parser.parse_args() 39 | if args.attack_only: 40 | args.generate_adv_example = True 41 | 42 | if args.eval_advtrain or args.advtrain: 43 | if args.attacked_checkpoint: 44 | args.generate_adv_example = True 45 | 46 | if args.test_advtrain: 47 | args.attacked_checkpoint = False 48 | args.generate_adv_example = False 49 | 50 | print('-' * 50) 51 | print(args) 52 | config.print_param() 53 | 54 | # set mannual seed 55 | torch.manual_seed(config.seed) 56 | torch.cuda.manual_seed(config.seed) 57 | # ----------Tasks------------------- 58 | attackvqa = AdversarialAttackVQA(args) 59 | if args.attack_only: 60 | attackvqa.attack(attackvqa.val_loader) 61 | if args.advtrain: 62 | attackvqa.advsarial_training() 63 | 64 | if args.eval_advtrain: 65 | #r = attackvqa.evaluate(attackvqa.val_loader) 66 | # you can save result by calling: 67 | attackvqa.save_result_json(attackvqa.val_loader) 68 | if args.test_advtrain: 69 | # r = attackvqa.evaluate(attackvqa.val_loader, has_answers=False) 70 | attackvqa.save_result_json(attackvqa.val_loader, has_answers=False) 71 | 72 | 73 | if __name__ == '__main__': 74 | main() -------------------------------------------------------------------------------- /merge_trainval_adv.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | qtrainadv = json.load(open('data/v2_OpenEnded_mscoco_train2014_questions_adv.json', 'r')) 4 | qvaladv = json.load(open('data/v2_OpenEnded_mscoco_val2014_questions_adv.json', 'r')) 5 | qtrainval = json.load(open('data/v2_OpenEnded_mscoco_trainval2014_questions.json', 'r')) 6 | 7 | qmerge = {'questions': qtrainadv['questions'] + qvaladv['questions']} 8 | 9 | print(len(qtrainval['questions'])) 10 | print(len(qmerge['questions'])) 11 | 12 | with open('data/v2_OpenEnded_mscoco_trainval2014_questions_adv.json', 'w') as f: 13 | json.dump(qmerge, f) 14 | 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict==1.9 2 | editdistance==0.5.3 3 | h5py==2.8.0 4 | keras 5 | regex==2017.11.9 6 | requests==2.22.0 7 | retrying==1.3.3 8 | six==1.14.0 9 | spacy==2.2.2 10 | tensorboard==1.14.0 11 | tensorboardX==1.8 12 | tensorflow==1.14.0 13 | tensorflow-estimator==1.14.0 14 | torchtext==0.1.1 15 | tornado==6.0.3 16 | tqdm==4.34.0 17 | yacs==0.1.6 18 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #python new_main.py --attack_only --attack_mode v --attack_al pgd --alpha 0.5 --iteration 6 --epsilon 5 --attacked_checkpoint logs/bs256.pth 2 | #python main.py --advtrain --attack_mode vq --attack_al pgd,sea --resume /home/tang/attack_on_VQA2.0-Recent-Approachs-2018/logs/baseline_10.pth 3 | #python main.py --attack_only --attack_mode q --attack_al sea --attacked_checkpoint 
/home/tang/attack_on_VQA2.0-Recent-Approachs-2018/logs/bs256.pth --paraphrase_data train 4 | 5 | python main.py --eval_advtrain --checkpoint logs/baseline_train_pgd,sea_vq_e0.3_it2_a0.5_w1_ad10_ld15_ade15_fr1.pth --attack_al ifgsm --attack_mode v --attacked_checkpoint /home/tang/attack_on_VQA2.0-Recent-Approachs-2018/logs/bs256.pth -------------------------------------------------------------------------------- /sea_flip_rate.py: -------------------------------------------------------------------------------- 1 | from seada.butd import baseline_model as model 2 | import os.path 3 | import operator 4 | import copy 5 | import json 6 | import torch 7 | from torch.autograd import Variable 8 | import torch.nn as nn 9 | from tqdm import tqdm 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim as optim 12 | import torch.optim.lr_scheduler as lr_scheduler 13 | from torch.nn.utils import clip_grad_norm_ 14 | import config 15 | from seada import data 16 | from seada import utils 17 | 18 | val_loader = data.get_loader(val=True, sea=True) 19 | logs = torch.load('logs/bs256.pth') 20 | question_keys = logs['vocab']['question'].keys() 21 | model = model.Net(question_keys) 22 | model = nn.DataParallel(model).cuda() 23 | model.module.load_state_dict(logs['weights']) 24 | 25 | model.eval() 26 | 27 | 28 | def sort_sample(order, *args): 29 | var_params = { 30 | 'requires_grad': False, 31 | } 32 | args = [[arg[q_len_a[1]] for q_len_a in order] for arg in args] 33 | args = [Variable(torch.stack(arg, dim=0).cuda(), **var_params) for arg in args] 34 | return args 35 | 36 | flips = 0 37 | total = 0 38 | tracker = utils.Tracker() 39 | loader = tqdm(val_loader, desc='{}'.format('val'), ncols=0) 40 | tracker_class, tracker_params = tracker.MeanMonitor, {} 41 | perturbed_acc_tracker = tracker.track('{}_advacc'.format('val'), tracker_class(**tracker_params)) 42 | acc_tracker = tracker.track('{}_acc'.format('val'), tracker_class(**tracker_params)) 43 | for v, q, q_adv, q_str, a, b, idx, v_mask, q_mask, q_mask_adv, image_id, q_id, q_len_adv, q_len in loader: 44 | var_params = { 45 | 'requires_grad': False, 46 | } 47 | v = Variable(v.cuda(), **var_params) 48 | q = Variable(q.cuda(), **var_params) 49 | a = Variable(a.cuda(), **var_params) 50 | b = Variable(b.cuda(), **var_params) 51 | q_len = Variable(q_len.cuda(), **var_params) 52 | v_mask = Variable(v_mask.cuda(), **var_params) 53 | q_mask = Variable(q_mask.cuda(), **var_params) 54 | answer = utils.process_answer(a) 55 | 56 | clean_out = model(v, b, q, v_mask, q_mask, q_len) 57 | clean_loss = utils.calculate_loss(answer, clean_out, method=config.loss_method) 58 | clean_acc, _ = utils.batch_accuracy(clean_out, answer) 59 | acc_tracker.append(clean_acc.mean()) 60 | 61 | q_lens = [(q_len_adv[i], i) for i in range(q_len_adv.shape[0])] 62 | q_lens = sorted(q_lens, key=lambda x: x[0], reverse=True) 63 | q_len_adv = [q_len_a[0] for q_len_a in q_lens] 64 | q_len_adv = Variable(torch.stack(q_len_adv, dim=0).cuda(), **var_params) 65 | v_sorted, b_sorted, q_adv, v_mask_sorted, q_mask_adv, answer, clean_out_sorted = sort_sample(q_lens, v, b, 66 | q_adv, v_mask, 67 | q_mask_adv, 68 | answer, clean_out) 69 | 70 | perturbed_out = model(v_sorted, b_sorted, q_adv, v_mask_sorted, q_mask_adv, q_len_adv) 71 | clean_logits = torch.max(clean_out_sorted, 1)[1].cpu().numpy() 72 | perturbed_logits = torch.max(perturbed_out, 1)[1].cpu().numpy() 73 | flips += sum(clean_logits != perturbed_logits) 74 | total += 256 75 | flip_rate = flips / total 76 | 77 | perturbed_loss = 
utils.calculate_loss(answer, perturbed_out, method=config.loss_method) 78 | perturbed_acc, _ = utils.batch_accuracy(perturbed_out, answer) 79 | 80 | perturbed_acc_tracker.append(perturbed_acc.mean()) 81 | fmt = '{:.4f}'.format 82 | loader.set_postfix(flip=fmt(flip_rate), advacc=fmt(perturbed_acc_tracker.mean.value), acc=fmt(acc_tracker.mean.value)) 83 | -------------------------------------------------------------------------------- /seada/adversarial_vqa.py: -------------------------------------------------------------------------------- 1 | 2 | import os.path 3 | import operator 4 | import copy 5 | import json 6 | import torch 7 | from torch.autograd import Variable 8 | import torch.nn as nn 9 | from tqdm import tqdm 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim as optim 12 | import torch.optim.lr_scheduler as lr_scheduler 13 | from torch.nn.utils import clip_grad_norm_ 14 | import config 15 | from . import data 16 | from .attacks import FGSMAttack, IFGSMAttack, RandomNoise, SEA 17 | if config.model_type == 'baseline': 18 | from .butd import baseline_model as model 19 | from . import utils 20 | 21 | 22 | class AdversarialAttackVQA: 23 | def __init__(self, args): 24 | self.args = args 25 | if args.name: 26 | self.name = ' '.join(args.name) 27 | else: 28 | self.name = '%s_%s_%s_%s_e%s_it%d_a%s_w%s_ad%s_ld%s_ade%s_fr%s' % \ 29 | (config.model_type, args.advtrain_data, args.attack_al, args.attack_mode, args.epsilon, args.iteration, args.alpha, 30 | args.advloss_w, args.adv_delay, args.lr_decay, args.adv_end, args.samples_frac) 31 | self.target_name = os.path.join('logs', '{}.pth'.format(self.name)) 32 | self.src = open(os.path.join('seada/butd', config.model_type + '_model.py')).read() 33 | self.config_as_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')} 34 | 35 | self.attack_al = args.attack_al.split(',') 36 | 37 | self.attack_dict = {'fgsm': FGSMAttack(args.epsilon), 38 | 'ifgsm': IFGSMAttack(args.epsilon, args.iteration, args.alpha, False), 39 | 'pgd': IFGSMAttack(args.epsilon, args.iteration, args.alpha, True), 40 | 'noise': RandomNoise(args.epsilon), 41 | 'sea': SEA(fliprate=args.fliprate, topk=args.topk) if (not args.advtrain) and 'sea' in self.attack_al and len(self.attack_al) == 1 else None} 42 | if len(self.attack_al) == 1: 43 | self.adversarial = self.attack_dict[self.attack_al[0]] 44 | else: 45 | if self.attack_al[1] == 'sea': 46 | self.adversarial = self.attack_dict[self.attack_al[0]] 47 | else: 48 | pass 49 | 50 | #### generate adversarial example setting #### 51 | if args.generate_adv_example: 52 | if not args.attacked_checkpoint: 53 | raise ValueError('checkpoint must be provided when generate adversarial examples') 54 | logs = torch.load(args.attacked_checkpoint) 55 | self.question_keys = logs['vocab']['question'].keys() 56 | self.base_model = model.Net(self.question_keys) 57 | self.base_model = nn.DataParallel(self.base_model).cuda() 58 | 59 | self.base_model.module.load_state_dict(logs['weights']) 60 | if args.attack_only: 61 | if self.attack_dict['sea'] is None and 'sea' in self.attack_al: 62 | self.val_loader = data.get_loader(val=True, sea=True) 63 | elif 'sea' in self.attack_al: 64 | if self.args.paraphrase_data == 'train': 65 | self.val_loader = data.get_loader(train=True) 66 | elif self.args.paraphrase_data == 'val': 67 | self.val_loader = data.get_loader(val=True) 68 | else: 69 | self.val_loader = data.get_loader(test=True) 70 | self.adversarial.dataset = self.val_loader.dataset 71 | self.questions_adv_saver = [] 72 | else: 
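                # attacks without sea (fgsm / ifgsm / pgd / random noise) only perturb the visual features, so the plain val split is enough here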
73 | self.val_loader = data.get_loader(val=True) 74 | for param in self.base_model.parameters(): 75 | param.requires_grad = False 76 | # if not args.advtrain: 77 | # self.adversarial.model = self.base_model 78 | # if args.attack_al == 'fgsm': 79 | # self.adversary = FGSMAttack(config.epsilon, self.base_model) 80 | # elif args.attack_al == 'ifgsm': 81 | # self.adversary = IFGSMAttack(config.epsilon, config.ifgsm_iteration, config.alpha, False, self.base_model) 82 | # elif args.attack_al == 'pgd': 83 | # self.adversary = IFGSMAttack(config.epsilon, config.ifgsm_iteration, config.alpha, True, 84 | # self.base_model) 85 | # self.fgsm = FGSMAttack(config.epsilon, self.base_model) 86 | # self.ifgsm = IFGSMAttack(config.epsilon, config.ifgsm_iteration, config.alpha, False, self.base_model) 87 | # self.tfgsm = TargetedFGSM(config.epsilon, self.base_model) 88 | if args.advtrain: 89 | print('will save to {}'.format(self.target_name)) 90 | cudnn.benchmark = True 91 | if args.resume: 92 | logs = torch.load(args.resume) 93 | # hacky way to tell the VQA classes that they should use the vocab without passing more params around 94 | data.preloaded_vocab = logs['vocab'] 95 | if args.advtrain_data == 'trainval': 96 | if 'sea' in self.attack_al: 97 | self.train_loader = data.get_loader(trainval=True, sea=True, vqacp=self.args.vqacp) 98 | else: 99 | self.train_loader = data.get_loader(trainval=True, vqacp=self.args.vqacp) 100 | else: 101 | if 'sea' in self.attack_al: 102 | self.train_loader = data.get_loader(train=True, sea=True, frac=args.samples_frac, vqacp=self.args.vqacp) 103 | self.val_loader = data.get_loader(val=True, sea=True, vqacp=self.args.vqacp) 104 | else: 105 | self.train_loader = data.get_loader(train=True, frac=args.samples_frac, vqacp=self.args.vqacp) 106 | self.val_loader = data.get_loader(val=True, vqacp=self.args.vqacp) 107 | if self.attack_dict['sea'] is not None: 108 | self.adversarial.dataset = self.train_loader.dataset 109 | self.question_keys = self.train_loader.dataset.vocab['question'].keys() if args.advtrain_data == 'trainval' else \ 110 | self.val_loader.dataset.vocab[ 111 | 'question'].keys() 112 | self.model = model.Net(self.question_keys) 113 | self.model = nn.DataParallel(self.model).cuda() 114 | # if args.resume: 115 | # print('loading weights from %s' % args.resume) 116 | # self.model.module.load_state_dict(logs['weights']) 117 | self.start_epoch = 0 118 | self.select_optim = optim.Adamax if (config.optim_method == 'Adamax') else optim.Adam 119 | self.optimizer = self.select_optim([p for p in self.model.parameters() if p.requires_grad], lr=config.initial_lr, 120 | weight_decay=config.weight_decay) 121 | self.scheduler = lr_scheduler.ExponentialLR(self.optimizer, 0.5 ** (1 / config.lr_halflife)) 122 | if args.resume: 123 | print('loading weights from %s' % args.resume) 124 | if config.model_type == 'counting': 125 | self.model.load_state_dict(logs['weights']) 126 | else: 127 | self.model.module.load_state_dict(logs['weights']) 128 | self.optimizer.load_state_dict(logs['optimizer']) 129 | if config.model_type == 'counting': 130 | self.scheduler.load_state_dict(logs['scheduler']) 131 | self.start_epoch = logs['epoch'] 132 | 133 | if args.eval_advtrain or args.test_advtrain: 134 | if args.checkpoint: 135 | logs = torch.load(args.checkpoint) 136 | # hacky way to tell the VQA classes that they should use the vocab without passing more params around 137 | data.preloaded_vocab = logs['vocab'] 138 | self.val_loader = data.get_loader(val=True, sea=True if 'sea' in self.attack_al else 
False) if args.eval_advtrain else data.get_loader(test=True) 139 | self.question_keys = self.val_loader.dataset.vocab['question'].keys() 140 | self.model = model.Net(self.question_keys) 141 | self.model = nn.DataParallel(self.model).cuda() 142 | if args.checkpoint: 143 | print('loading weights from %s' % args.checkpoint) 144 | self.model.module.load_state_dict(logs['weights']) 145 | 146 | self.tracker = utils.Tracker() 147 | 148 | def attack(self, loader): 149 | tracker_class, tracker_params = self.tracker.MeanMonitor, {} 150 | loader = tqdm(loader, desc='{} '.format(self.args.attack_al), ncols=0) 151 | loss_tracker = self.tracker.track('{}_loss'.format('attack'), tracker_class(**tracker_params)) 152 | acc_tracker = self.tracker.track('{}_acc'.format('before attack'), tracker_class(**tracker_params)) 153 | perturbed_acc_tracker = self.tracker.track('{}_acc'.format('after attack'), tracker_class(**tracker_params)) 154 | dist_tracker = self.tracker.track('{}_dist'.format('dist'), tracker_class(**tracker_params)) 155 | if len(self.attack_al) == 2: 156 | vqc_q_tracker = self.tracker.track('{}_acc'.format('after attack'), tracker_class(**tracker_params)) 157 | vqadv_q_tracker = self.tracker.track('{}_acc'.format('after attack'), tracker_class(**tracker_params)) 158 | vqc_qadv_tracker = self.tracker.track('{}_acc'.format('after attack'), tracker_class(**tracker_params)) 159 | vqadv_qadv_tracker = self.tracker.track('{}_acc'.format('after attack'), tracker_class(**tracker_params)) 160 | self.adversarial.model = self.base_model 161 | for v, q, q_adv, q_str, a, b, idx, v_mask, q_mask, q_mask_adv, image_id, q_id, q_len_adv, q_len in loader: 162 | var_params = { 163 | 'requires_grad': False, 164 | } 165 | v = v.cuda() 166 | q = Variable(q.cuda()) 167 | a = Variable(a.cuda()) 168 | b = Variable(b.cuda()) 169 | q_len = Variable(q_len.cuda()) 170 | v_mask = Variable(v_mask.cuda()) 171 | q_mask = Variable(q_mask.cuda()) 172 | answer = utils.process_answer(a) 173 | 174 | if self.args.attack_mode == 'q': 175 | if self.attack_al[0] == 'sea': 176 | clean_out = self.base_model(v, b, q, v_mask, q_mask, q_len) 177 | clean_logits = torch.max(clean_out, 1)[1].cpu().numpy() 178 | v, b, v_mask, q_adv, q_len_adv, q_mask_adv, answer, q_str_adv, image_id_adv, q_id_adv = self.adversarial.perturb((v, b, q, q_str, v_mask, q_mask, image_id, q_id, q_len), y=answer, oripred=clean_logits) 179 | perturbed_out = self.base_model(v, b, q_adv, v_mask, q_mask_adv, q_len_adv) 180 | self.save_q_adv(q_str_adv, image_id_adv, q_id_adv) 181 | dist = 0 182 | dist_tracker.append(dist) 183 | else: 184 | q_adv = self.adversarial.perturb((v, b, q, v_mask, q_mask, q_len), answer, perturb_q=True) 185 | perturbed_out = self.base_model(v, b, q, v_mask, q_mask, q_len, q_adv) 186 | dist = self.distance(q_adv, self.base_model.module.text.embedded.detach()) 187 | dist_tracker.append(dist.data) 188 | elif self.args.attack_mode == 'v': 189 | v_adv, acc, loss = self.adversarial.perturb((v, b, q, v_mask, q_mask, q_len), answer) 190 | perturbed_out = self.base_model(v_adv, b, q, v_mask, q_mask, q_len) 191 | dist = self.distance(v, v_adv) 192 | dist_tracker.append(dist.data) 193 | 194 | elif self.args.attack_mode == 'vq': # todo: v cooperate q 195 | q_lens = [(q_len_adv[i], i) for i in range(q_len_adv.shape[0])] 196 | q_lens = sorted(q_lens, key=lambda x: x[0], reverse=True) 197 | q_len_adv = [q_len_a[0] for q_len_a in q_lens] 198 | q_len_adv = Variable(torch.stack(q_len_adv, dim=0).cuda(), **var_params) 199 | v_sorted, b_sorted, q_adv, v_mask_sorted, 
q_mask_adv, answer_sorted = self.sort_sample(q_lens, v, b, 200 | q_adv, v_mask, 201 | q_mask_adv, 202 | answer) 203 | 204 | v_qadv, _, _ = self.adversarial.perturb((v_sorted, b_sorted, q_adv, v_mask_sorted, q_mask_adv, q_len_adv), answer_sorted) 205 | v_qc, _, _ = self.adversarial.perturb((v, b, q, v_mask, q_mask, q_len), answer) 206 | v_qadv = Variable(v_qadv.data, requires_grad=False) 207 | v_qc = Variable(v_qc.data, requires_grad=False) 208 | 209 | q_sorted, q_mask_sorted, q_len_sorted, v_qc_sorted = self.sort_sample(q_lens, q, q_mask, q_len, v_qc) 210 | out_vqadv_qadv = self.base_model(v_qadv, b_sorted, q_adv, v_mask_sorted, q_mask_adv, q_len_adv) 211 | vqadv_qadv_acc, _ = utils.batch_accuracy(out_vqadv_qadv, answer_sorted) 212 | vqadv_qadv_tracker.append(vqadv_qadv_acc.data.cpu().mean()) 213 | 214 | out_vqc_q = self.base_model(v_qc, b, q, v_mask, q_mask, q_len) 215 | vqc_q_acc, _ = utils.batch_accuracy(out_vqc_q, answer) 216 | vqc_q_tracker.append(vqc_q_acc.data.cpu().mean()) 217 | 218 | v_qadv_re = self.restore_order(q_lens, v_qadv)[0] 219 | out_vqadv_q = self.base_model(v_qadv_re, b, q, v_mask, q_mask, q_len) 220 | vqadv_q_acc, _ = utils.batch_accuracy(out_vqadv_q, answer) 221 | vqadv_q_tracker.append(vqadv_q_acc.data.cpu().mean()) 222 | 223 | out_vqc_qadv = self.base_model(v_qc_sorted, b_sorted, q_adv, v_mask_sorted, q_mask_adv, q_len_adv) 224 | vqc_qadv_acc, _ = utils.batch_accuracy(out_vqc_qadv, answer_sorted) 225 | vqc_qadv_tracker.append(vqc_qadv_acc.data.cpu().mean()) 226 | 227 | fmt = '{:.4f}'.format 228 | loader.set_postfix(vqc_q_acc=fmt(vqc_q_tracker.mean.value), 229 | vqadv_q_acc=fmt(vqadv_q_tracker.mean.value), 230 | vqc_qadv_acc=fmt(vqc_qadv_tracker.mean.value), 231 | vqadv_qadv_acc=fmt(vqadv_qadv_tracker.mean.value)) 232 | continue 233 | else: 234 | perturbed_out = self.base_model(v, b, q, v_mask, q_mask, q_len) 235 | dist = 0 236 | dist_tracker.append(dist) 237 | 238 | perturbed_acc, _ = utils.batch_accuracy(perturbed_out, answer) 239 | loss = utils.calculate_loss(answer, perturbed_out, method=config.loss_method) 240 | 241 | loss_tracker.append(loss.item()) 242 | # acc_tracker.append(acc.mean()) 243 | 244 | perturbed_acc_tracker.append(perturbed_acc.data.cpu().mean()) 245 | fmt = '{:.4f}'.format 246 | loader.set_postfix(loss=fmt(loss_tracker.mean.value),# acc=fmt(acc_tracker.mean.value), 247 | acc_after_attack=fmt(perturbed_acc_tracker.mean.value), 248 | distance=fmt(dist_tracker.mean.value)) 249 | if self.args.attack_al == 'sea': 250 | with open(config.paraphrase_save_path, 'w') as f: 251 | json.dump({'questions': self.questions_adv_saver}, f) 252 | if len(self.attack_al) == 1: 253 | f = open('attack_log.txt', 'a') 254 | f.write(self.name + '\n') 255 | f.write(str(fmt(perturbed_acc_tracker.mean.value))) 256 | f.write('\n') 257 | f.write(str(fmt(dist_tracker.mean.value))) 258 | f.write('\n') 259 | else: 260 | f = open('attack_log.txt', 'a') 261 | f.write(self.name + '\n') 262 | f.write('vqc_q: ' + str(fmt(vqc_q_tracker.mean.value)) + ' vqadv_q: ' + str(fmt(vqadv_q_tracker.mean.value)) 263 | + ' vqc_qadv: ' + str(fmt(vqc_qadv_tracker.mean.value)) + ' vqadv_qadv: ' + str( 264 | fmt(vqadv_qadv_tracker.mean.value))) 265 | f.write('\n') 266 | 267 | def advsarial_training(self): 268 | best_valid = 0 269 | lr_decay_epochs = range(self.args.lr_decay, 100, 2) 270 | for epoch in range(self.start_epoch, config.epochs): 271 | self.model.train() 272 | tracker_class, tracker_params = self.tracker.MovingMeanMonitor, {'momentum': 0.99} 273 | if epoch < len(config.gradual_warmup_steps) 
and config.schedule_method == 'warm_up': 274 | utils.set_lr(self.optimizer, config.gradual_warmup_steps[epoch]) 275 | utils.print_lr(self.optimizer, 'train', epoch) 276 | elif (epoch in lr_decay_epochs) and config.schedule_method == 'warm_up': 277 | utils.decay_lr(self.optimizer, config.lr_decay_rate) 278 | utils.print_lr(self.optimizer, 'train', epoch) 279 | else: 280 | utils.print_lr(self.optimizer, 'train', epoch) 281 | loader = tqdm(self.train_loader, desc='{} E{:03d}'.format('train', epoch), ncols=0) 282 | loss_tracker = self.tracker.track('{}_loss'.format('train'), tracker_class(**tracker_params)) 283 | acc_tracker = self.tracker.track('{}_acc'.format('train'), tracker_class(**tracker_params)) 284 | 285 | for v, q, q_adv, q_str, a, b, idx, v_mask, q_mask, q_mask_adv, image_id, q_id, q_len_adv, q_len in loader: 286 | var_params = { 287 | 'requires_grad': False, 288 | } 289 | v = Variable(v.cuda(), **var_params) 290 | q = Variable(q.cuda(), **var_params) 291 | a = Variable(a.cuda(), **var_params) 292 | b = Variable(b.cuda(), **var_params) 293 | q_len = Variable(q_len.cuda(), **var_params) 294 | v_mask = Variable(v_mask.cuda(), **var_params) 295 | q_mask = Variable(q_mask.cuda(), **var_params) 296 | 297 | out = self.model(v, b, q, v_mask, q_mask, q_len) 298 | answer = utils.process_answer(a) 299 | loss = utils.calculate_loss(answer, out, method=config.loss_method) 300 | acc, y_pred = utils.batch_accuracy(out, answer) 301 | if self.args.adv_delay < epoch + 1 < self.args.adv_end: 302 | # use predicted label to prevent label leaking 303 | if 'sea' in self.attack_al: 304 | q_lens = [(q_len_adv[i], i) for i in range(q_len_adv.shape[0])] 305 | q_lens = sorted(q_lens, key=lambda x: x[0], reverse=True) 306 | q_len_adv = [q_len_a[0] for q_len_a in q_lens] 307 | q_len_adv = Variable(torch.stack(q_len_adv, dim=0).cuda(), **var_params) 308 | v_sorted, b_sorted, q_adv, v_mask_sorted, q_mask_adv, answer_sorted = self.sort_sample(q_lens, v, b, q_adv, v_mask, q_mask_adv, answer) 309 | 310 | if self.args.attack_mode == 'q': 311 | if self.attack_al[0] == 'sea': 312 | out_adv = self.model(v_sorted, b_sorted, q_adv, v_mask_sorted, q_mask_adv, q_len_adv) 313 | loss_adv = [utils.calculate_loss(answer_sorted, out_adv, method=config.loss_method)] 314 | else: 315 | v_adv = self.advtrain_step((v, b, q, v_mask, q_mask, q_len), y_pred, self.model, 316 | self.adversarial, True) 317 | v_adv = Variable(v_adv.data, requires_grad=False) 318 | out_adv = self.model(v, b, q, v_mask, q_mask, q_len, v_adv) 319 | loss_adv = [utils.calculate_loss(answer, out_adv, method=config.loss_method)] 320 | 321 | elif self.args.attack_mode == 'v': 322 | v_adv = self.advtrain_step((v, b, q, v_mask, q_mask, q_len), y_pred, self.model, 323 | self.adversarial, False) 324 | v_adv = Variable(v_adv.data, requires_grad=False) 325 | out_adv = self.model(v_adv, b, q, v_mask, q_mask, q_len) 326 | loss_adv = [utils.calculate_loss(answer, out_adv, method=config.loss_method)] 327 | else: 328 | if self.attack_al[1] == 'sea': # todo: v & q 329 | y_pred_sorted = self.sort_sample(q_lens, y_pred)[0] 330 | v_qadv = self.advtrain_step((v_sorted, b_sorted, q_adv, v_mask_sorted, q_mask_adv, q_len_adv), y_pred_sorted, self.model, 331 | self.adversarial, False) 332 | v_qc = self.advtrain_step((v, b, q, v_mask, q_mask, q_len), y_pred, self.model, 333 | self.adversarial, False) 334 | 335 | v_qadv = Variable(v_qadv.data, requires_grad=False) 336 | v_qc = Variable(v_qc.data, requires_grad=False) 337 | q_sorted, q_mask_sorted, q_len_sorted, v_qc_sorted = 
self.sort_sample(q_lens, q, q_mask, q_len, v_qc) 338 | out_vqadv_qadv = self.model(v_qadv, b_sorted, q_adv, v_mask_sorted, q_mask_adv, q_len_adv) 339 | loss_vqadv_qadv = utils.calculate_loss(answer_sorted, out_vqadv_qadv, method=config.loss_method) 340 | out_vqc_q = self.model(v_qc, b, q, v_mask, q_mask, q_len) 341 | loss_vqc_q = utils.calculate_loss(answer, out_vqc_q, method=config.loss_method) 342 | v_qadv_re = self.restore_order(q_lens, v_qadv)[0] 343 | out_vqadv_q = self.model(v_qadv_re, b, q, v_mask, q_mask, q_len) 344 | loss_vqadv_q = utils.calculate_loss(answer, out_vqadv_q, method=config.loss_method) 345 | out_vqc_qadv = self.model(v_qc_sorted, b_sorted, q_adv, v_mask_sorted, q_mask_adv, q_len_adv) 346 | loss_vqc_qadv = utils.calculate_loss(answer_sorted, out_vqc_qadv, method=config.loss_method) 347 | loss_adv = [loss_vqadv_qadv, loss_vqc_q, loss_vqadv_q, loss_vqc_qadv] 348 | 349 | loss = (self.args.advloss_w * sum(loss_adv) + loss) / (len(loss_adv) + 1) 350 | 351 | self.optimizer.zero_grad() 352 | loss.backward() 353 | # print gradient 354 | if config.print_gradient: 355 | utils.print_grad([(n, p) for n, p in self.model.named_parameters() if p.grad is not None]) 356 | # clip gradient 357 | clip_grad_norm_(self.model.parameters(), config.clip_value) 358 | self.optimizer.step() 359 | if (config.schedule_method == 'batch_decay'): 360 | self.scheduler.step() 361 | loss_tracker.append(loss.item()) 362 | acc_tracker.append(acc.data.cpu().mean()) 363 | fmt = '{:.4f}'.format 364 | loader.set_postfix(loss=fmt(loss_tracker.mean.value), acc=fmt(acc_tracker.mean.value)) 365 | if self.args.advtrain_data != 'trainval': 366 | r = self.evaluate(self.val_loader) 367 | if epoch == self.args.adv_delay: 368 | best_valid = 0 369 | if sum(r[1]) / len(r[1]) > best_valid: 370 | best_valid = sum(r[1]) / len(r[1]) 371 | print('best valid') 372 | results = { 373 | 'name': self.name, 374 | 'tracker': self.tracker.to_dict(), 375 | 'config': self.config_as_dict, 376 | 'weights': self.model.module.state_dict(), 377 | 'eval': { 378 | 'clean_answers': r[0], 379 | 'clean_accuracies': r[1], 380 | 'adv_answers': r[3], 381 | 'adv_accuracies': r[4], 382 | 'idx': r[2], 383 | }, 384 | 'vocab': self.val_loader.dataset.vocab if self.args.advtrain_data == 'train' else self.train_loader.dataset.vocab, 385 | 'src': self.src, 386 | 'optimizer': self.optimizer.state_dict(), 387 | 'scheduler': self.scheduler.state_dict() if config.model_type == 'counting' else [], 388 | 'epoch': epoch + 1, 389 | } 390 | torch.save(results, self.target_name) 391 | else: 392 | r = [[-1], [-1], [-1], [-1], [-1]] 393 | results = { 394 | 'name': self.name, 395 | 'tracker': self.tracker.to_dict(), 396 | 'config': self.config_as_dict, 397 | 'weights': self.model.module.state_dict(), 398 | 'eval': { 399 | 'clean_answers': r[0], 400 | 'clean_accuracies': r[1], 401 | 'adv_answers': r[3], 402 | 'adv_accuracies': r[4], 403 | 'idx': r[2], 404 | }, 405 | 'vocab': self.val_loader.dataset.vocab if self.args.advtrain_data == 'train' else self.train_loader.dataset.vocab, 406 | 'src': self.src, 407 | 'optimizer': self.optimizer.state_dict(), 408 | 'scheduler': self.scheduler.state_dict() if config.model_type == 'counting' else [], 409 | 'epoch': epoch + 1, 410 | } 411 | torch.save(results, self.target_name) 412 | 413 | f = open('log.txt', 'a') 414 | f.write(self.name + '\n') 415 | f.write(str(best_valid.data.cpu().numpy())) 416 | f.write('\n') 417 | if self.args.attacked_checkpoint: 418 | 
f.write(str((sum(results['eval']['adv_accuracies'])/len(results['eval']['adv_accuracies'])).data.cpu())) 419 | f.write('\n') 420 | 421 | def advtrain_step(self, X, y, net, adversary, perturb_q): 422 | # If adversarial training, need a snapshot of 423 | # the model at each batch to compute grad, so 424 | # as not to mess up with the optimization step 425 | # model_cp = copy.deepcopy(net) 426 | model_cp = model.Net(self.question_keys) 427 | model_cp = nn.DataParallel(model_cp).cuda() 428 | model_cp.load_state_dict(net.state_dict()) 429 | for p in model_cp.parameters(): 430 | p.requires_grad = False 431 | # model_cp.eval() 432 | 433 | adversary.model = model_cp 434 | 435 | if perturb_q: 436 | X_adv = adversary.perturb(X, y, perturb_q=True) 437 | else: 438 | X_adv, _, _ = adversary.perturb(X, y) 439 | 440 | return X_adv 441 | 442 | def sort_sample(self, order, *args): 443 | var_params = { 444 | 'requires_grad': False, 445 | } 446 | args = [[arg[q_len_a[1]] for q_len_a in order] for arg in args] 447 | args = [Variable(torch.stack(arg, dim=0).cuda(), **var_params) for arg in args] 448 | return args 449 | 450 | def restore_order(self, order, *args): 451 | var_params = { 452 | 'requires_grad': False, 453 | } 454 | args = [[(arg[i], q_len_a[1]) for i, q_len_a in enumerate(order)] for arg in args] 455 | args = [sorted(arg, key=lambda x: x[1]) for arg in args] 456 | args = [[ar[0] for ar in arg] for arg in args] 457 | args = [Variable(torch.stack(arg, dim=0).cuda(), **var_params) for arg in args] 458 | return args 459 | 460 | def vq_loss(self, loss, loss_adv): 461 | return (sum(loss_adv) + loss) / (len(loss_adv) + 1) 462 | 463 | def evaluate(self, loader, has_answers=True): 464 | self.model.eval() 465 | tracker_class, tracker_params = self.tracker.MeanMonitor, {} 466 | answ = [] 467 | idxs = [] 468 | accs = [] 469 | perturbed_answ = [] 470 | perturbed_accs = [] 471 | if self.args.attacked_checkpoint and self.attack_dict['sea'] is None: 472 | self.adversarial.model = self.base_model 473 | loader = tqdm(loader, desc='{}'.format('val'), ncols=0) 474 | loss_tracker = self.tracker.track('{}_loss'.format('val'), tracker_class(**tracker_params)) 475 | acc_tracker = self.tracker.track('{}_acc'.format('val'), tracker_class(**tracker_params)) 476 | perturbed_loss_tracker = self.tracker.track('{}_advloss'.format('val'), tracker_class(**tracker_params)) 477 | perturbed_acc_tracker = self.tracker.track('{}_advacc'.format('val'), tracker_class(**tracker_params)) 478 | 479 | for v, q, q_adv, q_str, a, b, idx, v_mask, q_mask, q_mask_adv, image_id, q_id, q_len_adv, q_len in loader: 480 | var_params = { 481 | 'requires_grad': False, 482 | } 483 | v = Variable(v.cuda(), **var_params) 484 | q = Variable(q.cuda(), **var_params) 485 | a = Variable(a.cuda(), **var_params) 486 | b = Variable(b.cuda(), **var_params) 487 | q_len = Variable(q_len.cuda(), **var_params) 488 | v_mask = Variable(v_mask.cuda(), **var_params) 489 | q_mask = Variable(q_mask.cuda(), **var_params) 490 | 491 | clean_out = self.model(v, b, q, v_mask, q_mask, q_len) 492 | if has_answers: 493 | answer = utils.process_answer(a) # answer must be known when generate adversarial example 494 | clean_loss = utils.calculate_loss(answer, clean_out, method=config.loss_method) 495 | clean_acc, _ = utils.batch_accuracy(clean_out, answer) 496 | accs.append(clean_acc.data.cpu().view(-1)) 497 | loss_tracker.append(clean_loss.item()) 498 | acc_tracker.append(clean_acc.mean()) 499 | 500 | if self.args.attacked_checkpoint: 501 | # if self.args.attack_al == 'fgsm': 502 | # 
v_adv, acc, loss = self.fgsm.perturb((v, b, q, v_mask, q_mask, q_len), answer) 503 | # elif self.args.attack_al == 'ifgsm': 504 | # v_adv, acc, loss = self.ifgsm.perturb((v, b, q, v_mask, q_mask, q_len), answer) 505 | if 'sea' in self.attack_al: 506 | q_lens = [(q_len_adv[i], i) for i in range(q_len_adv.shape[0])] 507 | q_lens = sorted(q_lens, key=lambda x: x[0], reverse=True) 508 | q_len_adv = [q_len_a[0] for q_len_a in q_lens] 509 | q_len_adv = Variable(torch.stack(q_len_adv, dim=0).cuda(), **var_params) 510 | v_sorted, b_sorted, q_adv, v_mask_sorted, q_mask_adv, answer = self.sort_sample(q_lens, v, b, 511 | q_adv, v_mask, 512 | q_mask_adv, 513 | answer) 514 | if self.args.attack_mode == 'q': 515 | if self.attack_al[0] == 'sea': 516 | perturbed_out = self.model(v_sorted, b_sorted, q_adv, v_mask_sorted, q_mask_adv, q_len_adv) 517 | else: 518 | q_adv = self.adversarial.perturb((v, b, q, v_mask, q_mask, q_len), answer, perturb_q=True) 519 | q_adv = Variable(q_adv.data, requires_grad=False) 520 | perturbed_out = self.model(v, b, q, v_mask, q_mask, q_len, q_adv) 521 | elif self.args.attack_mode == 'v': 522 | v_adv, acc, loss = self.adversarial.perturb((v, b, q, v_mask, q_mask, q_len), answer) 523 | perturbed_out = self.model(v_adv, b, q, v_mask, q_mask, q_len) 524 | else: # todo: eval v & q 525 | if 'sea' in self.attack_al: 526 | v_adv, acc, loss = self.adversarial.perturb((v_sorted, b_sorted, q_adv, v_mask_sorted, q_mask_adv, q_len_adv), answer) 527 | perturbed_out = self.model(v_sorted, b_sorted, q_adv, v_mask_sorted, q_mask_adv, q_len_adv) 528 | 529 | perturbed_loss = utils.calculate_loss(answer, perturbed_out, method=config.loss_method) 530 | perturbed_acc, _ = utils.batch_accuracy(perturbed_out, answer) 531 | _, perturbed_answer = perturbed_out.data.cpu().max(dim=1) 532 | perturbed_answ.append(perturbed_answer.view(-1)) 533 | perturbed_accs.append(perturbed_acc.data.cpu().view(-1)) 534 | perturbed_loss_tracker.append(perturbed_loss.item()) 535 | perturbed_acc_tracker.append(perturbed_acc.mean()) 536 | fmt = '{:.4f}'.format 537 | loader.set_postfix(advloss=fmt(perturbed_loss_tracker.mean.value), 538 | advacc=fmt(perturbed_acc_tracker.mean.value), 539 | loss=fmt(loss_tracker.mean.value), acc=fmt(acc_tracker.mean.value)) 540 | else: 541 | fmt = '{:.4f}'.format 542 | loader.set_postfix(loss=fmt(loss_tracker.mean.value), acc=fmt(acc_tracker.mean.value)) 543 | 544 | _, clean_answer = clean_out.data.cpu().max(dim=1) 545 | answ.append(clean_answer.view(-1)) 546 | idxs.append(idx.view(-1).clone()) 547 | 548 | answ = list(torch.cat(answ, dim=0)) 549 | if has_answers: 550 | accs = list(torch.cat(accs, dim=0)) 551 | if self.args.attacked_checkpoint: 552 | perturbed_accs = list(torch.cat(perturbed_accs, dim=0)) 553 | perturbed_answ = list(torch.cat(perturbed_answ, dim=0)) 554 | idxs = list(torch.cat(idxs, dim=0)) 555 | 556 | return answ, accs, idxs, perturbed_answ, perturbed_accs 557 | 558 | def save_result_json(self, loader, has_answers=True): 559 | r = self.evaluate(loader, has_answers) 560 | answer_index_to_string = {a: s for s, a in loader.dataset.answer_to_index.items()} 561 | results = [] 562 | for answer, index in zip(r[0], r[2]): 563 | answer = answer_index_to_string[answer.item()] 564 | qid = loader.dataset.question_ids[index] 565 | entry = { 566 | 'question_id': qid, 567 | 'answer': answer, 568 | } 569 | results.append(entry) 570 | with open(config.result_json_path, 'w') as fd: 571 | json.dump(results, fd) 572 | 573 | def load_checkpoint(self, path): 574 | logs = torch.load(' '.join(path)) 
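        # checkpoints written by this repo store the model parameters under 'weights' and the question/answer vocab under 'vocab'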
575 | # hacky way to tell the VQA classes that they should use the vocab without passing more params around 576 | data.preloaded_vocab = logs['vocab'] 577 | self.model.module.load_state_dict(logs['weights']) 578 | 579 | def distance(self, x, x_adv): 580 | dist = torch.norm(x - x_adv, 2, 2) / x.shape[2] ** 0.5 581 | return torch.mean(dist) 582 | 583 | def save_q_adv(self, q_str, image_id, q_id): 584 | assert len(q_str) == len(image_id) == len(q_id) 585 | for i in range(len(q_str)): 586 | f = {'image_id': int(image_id[i]), 'question': q_str[i], 'question_id': int(q_id[i])} 587 | self.questions_adv_saver.append(f) 588 | -------------------------------------------------------------------------------- /seada/attacks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import utils 3 | import config 4 | from torch.autograd import Variable 5 | from . import data 6 | import numpy as np 7 | import spacy 8 | 9 | from .sea.paraphrase_scorer import ParaphraseScorer 10 | from .sea import onmt_model, replace_rules 11 | 12 | # --- White-box attacks --- 13 | inter_feature = {} 14 | inter_gradient = {} 15 | 16 | 17 | def make_hook(name, flag): 18 | if flag == 'forward': 19 | def hook(m, input, output): 20 | inter_feature[name] = input 21 | return hook 22 | elif flag == 'backward': 23 | def hook(m, input, output): 24 | inter_gradient[name] = output 25 | return hook 26 | else: 27 | assert False 28 | 29 | 30 | class FGSMAttack(object): 31 | def __init__(self, epsilon=None, model=None): 32 | """ 33 | One step fast gradient sign method 34 | """ 35 | self.model = model 36 | self.epsilon = epsilon 37 | 38 | def perturb(self, X_nat, y, epsilon=None, k=None, alpha=None, perturb_q=False, targeted=False): 39 | """ 40 | Given examples (X_nat, y), returns their adversarial 41 | counterparts with an attack length of epsilon. 42 | """ 43 | v, b, q, v_mask, q_mask, q_len = X_nat 44 | v = Variable(v, requires_grad=True) 45 | if not perturb_q: 46 | # Providing epsilons in batch 47 | if epsilon is not None: 48 | self.epsilon = epsilon 49 | out = self.model(v, b, q, v_mask, q_mask, q_len) 50 | 51 | loss = utils.calculate_loss(y, out, method=config.loss_method) 52 | acc, _ = utils.batch_accuracy(out, y) 53 | self.model.zero_grad() 54 | loss.backward() 55 | 56 | if targeted: 57 | data_grad = -v.grad.data.sign() 58 | else: 59 | data_grad = v.grad.data.sign() 60 | perturbed_v = v + self.epsilon * data_grad 61 | 62 | return perturbed_v, acc.data.cpu(), loss 63 | else: 64 | out = self.model(v, b, q, v_mask, q_mask, q_len) 65 | loss = utils.calculate_loss(y, out, method=config.loss_method) 66 | # self.model.module.text.embed.register_backward_hook(make_hook('int_emb', 'backward')) 67 | self.model.zero_grad() 68 | loss.backward() 69 | if targeted: 70 | data_grad = -self.model.module.text.embedded.grad.data.sign() 71 | else: 72 | data_grad = self.model.module.text.embedded.grad.data.sign() 73 | perturbed_q = self.model.module.text.embedded.detach() 74 | perturbed_q = perturbed_q + config.epsilon * data_grad 75 | return perturbed_q 76 | 77 | 78 | class IFGSMAttack(object): 79 | def __init__(self, epsilon=0.3, k=40, alpha=0.01, 80 | random_start=True, model=None): 81 | """ 82 | Attack parameter initialization. The attack performs k steps of 83 | size a, while always staying within epsilon from the initial 84 | point. 
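A sketch of the per-step update that perturb() implements below (assuming
the standard I-FGSM / PGD formulation):

    v_adv <- clamp(v_adv + alpha * sign(grad_v loss),
                   v_nat - epsilon, v_nat + epsilon)

With random_start, v_adv is first initialised uniformly inside the
epsilon-ball around v_nat, i.e. the PGD variant.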
85 | https://github.com/MadryLab/mnist_challenge/blob/master/pgd_attack.py 86 | """ 87 | self.model = model 88 | self.epsilon = epsilon 89 | self.k = k 90 | self.alpha = alpha 91 | self.rand = random_start # if true, it is PGD attack 92 | 93 | def perturb(self, X_nat, y, epsilon=None, k=None, alpha=None, perturb_q=False, targeted=False): 94 | """ 95 | Given examples (X_nat, y), returns adversarial 96 | examples within epsilon of X_nat in l_infinity norm. 97 | """ 98 | v_nat, b, q_nat, v_mask, q_mask, q_len = X_nat 99 | if epsilon is not None: 100 | self.epsilon = epsilon 101 | if k is not None: 102 | self.k = k 103 | if alpha is not None: 104 | self.alpha = alpha 105 | if not perturb_q: 106 | if self.rand: 107 | v = v_nat + torch.FloatTensor(v_nat.size()).uniform_(-self.epsilon, self.epsilon).cuda() 108 | else: 109 | v = v_nat.clone() 110 | v_adv = Variable(v, requires_grad=True) 111 | 112 | for i in range(self.k): 113 | out = self.model(v_adv, b, q_nat, v_mask, q_mask, q_len) 114 | loss = utils.calculate_loss(y, out, method=config.loss_method) 115 | acc, pred_out = utils.batch_accuracy(out, y) 116 | self.model.zero_grad() 117 | loss.backward() 118 | if targeted: 119 | data_grad = -v_adv.grad.data.sign() 120 | else: 121 | data_grad = v_adv.grad.data.sign() 122 | v_adv = v_adv + self.alpha * data_grad 123 | v_adv = utils.where(v_adv > v_nat + self.epsilon, v_nat + self.epsilon, v_adv) 124 | v_adv = utils.where(v_adv < v_nat - self.epsilon, v_nat - self.epsilon, v_adv) 125 | # v_adv = torch.clamp(v_adv, v - config.epsilon, v + config.epsilon) 126 | v_adv = Variable(v_adv.data, requires_grad=True) 127 | 128 | return v_adv, acc.data.cpu(), loss 129 | else: 130 | out = self.model(v_nat, b, q_nat, v_mask, q_mask, q_len) 131 | loss = utils.calculate_loss(y, out, method=config.loss_method) 132 | # self.model.module.text.embed.register_backward_hook(make_hook('int_emb', 'backward')) 133 | self.model.zero_grad() 134 | loss.backward() 135 | data_grad = self.model.module.text.embedded.grad.data.sign() 136 | origin_q = self.model.module.text.embedded.detach() 137 | perturbed_q = origin_q + self.alpha * data_grad 138 | for i in range(1, self.k): 139 | out = self.model(v_nat, b, q_nat, v_mask, q_mask, q_len, perturbed_q) 140 | loss = utils.calculate_loss(y, out, method=config.loss_method) 141 | # acc, pred_out = utils.batch_accuracy(out, y) 142 | self.model.zero_grad() 143 | loss.backward() 144 | if targeted: 145 | data_grad = -self.model.module.text.embedded.grad.data.sign() 146 | else: 147 | data_grad = self.model.module.text.embedded.grad.data.sign() 148 | perturbed_q = self.model.module.text.embedded.detach() + self.alpha * data_grad 149 | perturbed_q = utils.where(perturbed_q > origin_q + self.epsilon, origin_q + self.epsilon, perturbed_q) 150 | perturbed_q = utils.where(perturbed_q < origin_q - self.epsilon, origin_q - self.epsilon, perturbed_q) 151 | # v_adv = torch.clamp(v_adv, v - config.epsilon, v + config.epsilon) 152 | perturbed_q = Variable(perturbed_q.data) 153 | 154 | return perturbed_q 155 | 156 | class RandomNoise(object): 157 | def __init__(self, epsilon, model=None): 158 | self.epsilon = epsilon 159 | self.model = model 160 | 161 | def perturb(self, X_nat, y, epsilon=None, k=None, alpha=None, perturb_q=False, targeted=False): 162 | v, b, q, v_mask, q_mask, q_len = X_nat 163 | 164 | v_adv = v + self.epsilon * torch.randn_like(v).cuda() 165 | out = self.model(v_adv, b, q, v_mask, q_mask, q_len) 166 | loss = utils.calculate_loss(y, out, method=config.loss_method) 167 | acc, _ = 
utils.batch_accuracy(out, y) 168 | 169 | return v_adv, acc.data.cpu(), loss 170 | 171 | 172 | class SEA(object): 173 | def __init__(self, dataset=None, model=None, fliprate=0, topk=None): 174 | self.dataset = dataset 175 | self.model = model 176 | self.ps = ParaphraseScorer(gpu_id=0) 177 | self.nlp = spacy.load('en') 178 | self.fliprate = fliprate 179 | #self.ratetemp = fliprate 180 | self.topk = topk 181 | 182 | def perturb(self, X_nat, y=None, oripred=None, epsilon=None, k=None, alpha=None, perturb_q=False, targeted=False): 183 | v, b, q, q_str, v_mask, q_mask, image_id, q_id, q_len = X_nat 184 | q_advs = [] 185 | q_len_clue = [] 186 | q_mask_advs = [] 187 | q_str_advs = [] 188 | flips = int(v.shape[0] * self.fliprate) 189 | topk = self.topk 190 | fliprate = self.fliprate 191 | nflip = 0 192 | for i in range(v.shape[0]): 193 | q_adv, q_len_adv, q_mask_adv, q_str_adv, flipsign = self.find_flips(q_str[i], visual=(v[i], b[i], v_mask[i]), topk=topk, fliprate=fliprate, threshold=-10, oripred=oripred[i]) 194 | if q_adv is None: # support top1 right now todo: support topk 195 | q_adv = q[i].unsqueeze(0) 196 | q_len_adv = q_len[i].unsqueeze(0) 197 | q_mask_adv = q_mask[i].unsqueeze(0) 198 | q_str_adv = q_str[i] 199 | if flipsign: 200 | nflip += 1 201 | if nflip > flips: 202 | fliprate = 0 203 | topk = 1 204 | q_advs.append(q_adv) 205 | q_len_clue.append((q_len_adv, i)) 206 | q_mask_advs.append(q_mask_adv) 207 | q_str_advs.append(q_str_adv) 208 | q_len_clue = sorted(q_len_clue, key=lambda x: x[0], reverse=True) 209 | q_len_advs = [clue[0] for clue in q_len_clue] 210 | sort_id = [clue[1] for clue in q_len_clue] 211 | 212 | q_advs = [q_advs[idx] for idx in sort_id] 213 | q_mask_advs = [q_mask_advs[idx] for idx in sort_id] 214 | q_str_advs = [q_str_advs[idx] for idx in sort_id] 215 | v = [v[idx] for idx in sort_id] 216 | b = [b[idx] for idx in sort_id] 217 | v_mask = [v_mask[idx] for idx in sort_id] 218 | y = [y[idx] for idx in sort_id] 219 | image_id = [image_id[idx] for idx in sort_id] 220 | q_id = [q_id[idx] for idx in sort_id] 221 | 222 | q_adv = torch.cat(q_advs, dim=0) 223 | q_len_adv = torch.cat(q_len_advs, dim=0) 224 | q_mask_adv = torch.cat(q_mask_advs, dim=0) 225 | v = torch.stack(v, dim=0) 226 | b = torch.stack(b, dim=0) 227 | v_mask = torch.stack(v_mask, dim=0) 228 | y = torch.stack(y, dim=0) 229 | return v, b, v_mask, q_adv, q_len_adv, q_mask_adv, y, q_str_advs, image_id, q_id 230 | 231 | def find_flips(self, instance, visual=None, topk=1, fliprate=0, threshold=-10, oripred=None): 232 | instance_for_onmt = onmt_model.clean_text(' '.join([x.text for x in self.nlp.tokenizer(instance)]), only_upper=False) 233 | paraphrases = self.ps.generate_paraphrases(instance_for_onmt, topk=topk+1, edit_distance_cutoff=4, threshold=threshold) 234 | if len(paraphrases) == 0: 235 | return None, None, None, None, False 236 | for para in paraphrases: 237 | if para[0] == instance_for_onmt: 238 | paraphrases.remove(para) 239 | if len(paraphrases) == 0: 240 | return None, None, None, None, False 241 | paraphrases = paraphrases[:topk] 242 | prepared_paraphrases = data.prepare_questions_from_para(paraphrases) 243 | questions = [self.dataset.encode_question(paraphrase) for paraphrase in prepared_paraphrases] 244 | questions = [(q_tuple[0], q_tuple[1], i) for i, q_tuple in enumerate(questions)] 245 | sorted_questions = sorted(questions, key=lambda x: x[1], reverse=True) 246 | q_len = torch.cat([torch.tensor([q[1]]) for q in sorted_questions], 0).cuda() 247 | q = torch.stack([q[0] for q in sorted_questions], 
0).cuda() 248 | q_m = [torch.from_numpy((np.arange(self.dataset.max_question_length) < q[1]).astype(int)) for q in 249 | sorted_questions] 250 | q_mask = torch.stack(q_m, 0).float().cuda() 251 | if fliprate == 0: 252 | return q, q_len, q_mask, paraphrases[0][0], False 253 | else: 254 | v, b, v_mask = visual 255 | v = v.unsqueeze(0).repeat(q.shape[0], 1, 1) 256 | b = b.unsqueeze(0).repeat(q.shape[0], 1, 1) 257 | v_mask = v_mask.unsqueeze(0).repeat(q.shape[0], 1, 1) 258 | # # discard the flip or not, compute adv for every example 259 | perturbed_out = self.model(v, b, q, v_mask, q_mask, q_len) 260 | perturbed_logits = torch.max(perturbed_out, 1)[1].cpu().numpy() 261 | p = np.where(perturbed_logits != oripred)[0].tolist() 262 | sorted_para = [paraphrases[qs[2]] for qs in sorted_questions] 263 | flipsign = False 264 | if len(p) == 0: 265 | return q[0].unsqueeze(0), q_len[0].unsqueeze(0), q_mask[0].unsqueeze(0), sorted_para[0][0], flipsign 266 | else: 267 | flipsign = True 268 | return q[p[0]].unsqueeze(0), q_len[p[0]].unsqueeze(0), q_mask[p[0]].unsqueeze(0), sorted_para[p[0]][0], flipsign 269 | 270 | # perturbed_acc, _ = utils.batch_accuracy(perturbed_out, orig_pred.unsqueeze(0).repeat(q.shape[0], 1)) 271 | 272 | 273 | 274 | -------------------------------------------------------------------------------- /seada/butd/baseline_model.py: -------------------------------------------------------------------------------- 1 | ########################## 2 | # Implementation of Bottom-Up and Top-Down Attention for Image Captioning and Visual Question Answering 3 | # Paper Link: https://arxiv.org/abs/1707.07998 4 | # Code Author: Kaihua Tang 5 | # Environment: Python 3.6, Pytorch 1.0 6 | ########################## 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.nn.init as init 11 | from torch.autograd import Variable 12 | from torch.nn.utils import weight_norm 13 | from torch.nn.utils.rnn import pack_padded_sequence 14 | 15 | import config 16 | from . 
import word_embedding 17 | 18 | from .reuse_modules import Fusion, FCNet 19 | 20 | class Net(nn.Module): 21 | def __init__(self, words_list): 22 | super(Net, self).__init__() 23 | question_features = 1024 24 | vision_features = config.output_features 25 | glimpses = 2 26 | 27 | self.text = word_embedding.TextProcessor( 28 | classes=words_list, 29 | embedding_features=300, 30 | lstm_features=question_features, 31 | drop=0.0, 32 | ) 33 | 34 | self.attention = Attention( 35 | v_features=vision_features, 36 | q_features=question_features, 37 | mid_features=1024, 38 | glimpses=glimpses, 39 | drop=0.2,) 40 | 41 | self.classifier = Classifier( 42 | in_features=(glimpses * vision_features, question_features), 43 | mid_features=1024, 44 | out_features=config.max_answers, 45 | drop=0.5,) 46 | 47 | def forward(self, v, b, q, v_mask, q_mask, q_len, pred_from_emb=None): 48 | ''' 49 | v: visual feature [batch, num_obj, 2048] 50 | b: bounding box [batch, num_obj, 4] 51 | q: question [batch, max_q_len] 52 | v_mask: number of obj [batch, max_obj] 1 is obj, 0 is none 53 | q_mask: question length [batch, max_len] 1 is word, 0 is none 54 | answer: predict logits [batch, config.max_answers] 55 | ''' 56 | q = self.text(q, list(q_len.data), pred_from_emb) # [batch, 1024] 57 | if config.v_feat_norm: 58 | v = v / (v.norm(p=2, dim=2, keepdim=True) + 1e-12).expand_as(v) # [batch, num_obj, 2048] 59 | 60 | a = self.attention(v, q) # [batch, 36, num_glimpse] 61 | v = apply_attention(v.transpose(1,2), a) # [batch, 2048 * num_glimpse] 62 | answer = self.classifier(v, q) 63 | 64 | return answer 65 | 66 | 67 | class Classifier(nn.Module): 68 | def __init__(self, in_features, mid_features, out_features, drop=0.0): 69 | super(Classifier, self).__init__() 70 | self.lin11 = FCNet(in_features[0], mid_features, activate='relu') 71 | self.lin12 = FCNet(in_features[1], mid_features, activate='relu') 72 | self.lin2 = FCNet(mid_features, mid_features, activate='relu') 73 | self.lin3 = FCNet(mid_features, out_features, drop=drop) 74 | 75 | def forward(self, v, q): 76 | #x = self.fusion(self.lin11(v), self.lin12(q)) 77 | x = self.lin11(v) * self.lin12(q) 78 | x = self.lin2(x) 79 | x = self.lin3(x) 80 | return x 81 | 82 | class Attention(nn.Module): 83 | def __init__(self, v_features, q_features, mid_features, glimpses, drop=0.0): 84 | super(Attention, self).__init__() 85 | self.lin_v = FCNet(v_features, mid_features, activate='relu') # let self.lin take care of bias 86 | self.lin_q = FCNet(q_features, mid_features, activate='relu') 87 | self.lin = FCNet(mid_features, glimpses, drop=drop) 88 | 89 | def forward(self, v, q): 90 | """ 91 | v = batch, num_obj, dim 92 | q = batch, dim 93 | """ 94 | v = self.lin_v(v) 95 | q = self.lin_q(q) 96 | batch, num_obj, _ = v.shape 97 | _, q_dim = q.shape 98 | q = q.unsqueeze(1).expand(batch, num_obj, q_dim) 99 | 100 | x = v * q 101 | x = self.lin(x) # batch, num_obj, glimps 102 | x = F.softmax(x, dim=1) 103 | return x 104 | 105 | 106 | def apply_attention(input, attention): 107 | """ 108 | input = batch, dim, num_obj 109 | attention = batch, num_obj, glimps 110 | """ 111 | batch, dim, _ = input.shape 112 | _, _, glimps = attention.shape 113 | x = input @ attention # batch, dim, glimps 114 | assert(x.shape[1] == dim) 115 | assert(x.shape[2] == glimps) 116 | return x.view(batch, -1) 117 | -------------------------------------------------------------------------------- /seada/butd/reuse_modules.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 
import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | from torch.autograd import Variable 7 | from torch.nn.utils import weight_norm 8 | from torch.nn.utils.rnn import pack_padded_sequence 9 | 10 | import config 11 | 12 | 13 | class Fusion(nn.Module): 14 | """ Crazy multi-modal fusion: negative squared difference minus relu'd sum 15 | """ 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def forward(self, x, y): 20 | # found through grad student descent ;) 21 | return - (x - y)**2 + F.relu(x + y) 22 | 23 | class FCNet(nn.Module): 24 | def __init__(self, in_size, out_size, activate=None, drop=0.0): 25 | super(FCNet, self).__init__() 26 | self.lin = weight_norm(nn.Linear(in_size, out_size), dim=None) 27 | 28 | self.drop_value = drop 29 | self.drop = nn.Dropout(drop) 30 | 31 | # in case of using upper character by mistake 32 | self.activate = activate.lower() if (activate is not None) else None 33 | if activate == 'relu': 34 | self.ac_fn = nn.ReLU() 35 | elif activate == 'sigmoid': 36 | self.ac_fn = nn.Sigmoid() 37 | elif activate == 'tanh': 38 | self.ac_fn = nn.Tanh() 39 | 40 | 41 | def forward(self, x): 42 | if self.drop_value > 0: 43 | x = self.drop(x) 44 | 45 | x = self.lin(x) 46 | 47 | if self.activate is not None: 48 | x = self.ac_fn(x) 49 | return x 50 | -------------------------------------------------------------------------------- /seada/butd/word_embedding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from PyTorch's text library. 3 | """ 4 | 5 | import array 6 | import os 7 | import zipfile 8 | 9 | import six 10 | import torch 11 | from six.moves.urllib.request import urlretrieve 12 | from tqdm import tqdm 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.init as init 17 | from torch.autograd import Variable 18 | from torch.nn.utils.rnn import pack_padded_sequence 19 | 20 | import config 21 | from config import qa_path 22 | 23 | class TextProcessor(nn.Module): 24 | def __init__(self, classes, embedding_features, lstm_features, drop=0.0, use_hidden=True, use_tanh=False, only_embed=False): 25 | super(TextProcessor, self).__init__() 26 | self.use_hidden = use_hidden # return last layer hidden, else return all the outputs for each words 27 | self.use_tanh = use_tanh 28 | self.only_embed = only_embed 29 | classes = list(classes) 30 | 31 | self.embed = nn.Embedding(len(classes)+1, embedding_features, padding_idx=len(classes)) 32 | weight_init = torch.from_numpy(np.load(qa_path+'/glove6b_init_300d.npy')) 33 | assert weight_init.shape == (len(classes), embedding_features) 34 | # print('glove weight shape: ', weight_init.shape) 35 | self.embed.weight.data[:len(classes)] = weight_init 36 | # print('word embed shape: ', self.embed.weight.shape) 37 | 38 | self.drop = nn.Dropout(drop) 39 | 40 | if self.use_tanh: 41 | self.tanh = nn.Tanh() 42 | 43 | if not self.only_embed: 44 | self.lstm = nn.GRU(input_size=embedding_features, 45 | hidden_size=lstm_features, 46 | num_layers=1, 47 | batch_first=not use_hidden,) 48 | 49 | def forward(self, q, q_len, pred_from_emb=None): 50 | if pred_from_emb is not None: 51 | self.embedded = pred_from_emb.requires_grad_() 52 | else: 53 | self.embedded = self.embed(q).requires_grad_() 54 | 55 | # embedded = self.embed(q) 56 | embedded = self.drop(self.embedded) 57 | 58 | if self.use_tanh: 59 | embedded = self.tanh(embedded) 60 | 61 | if self.only_embed: 62 | return embedded 63 | 64 | self.lstm.flatten_parameters() 65 | 
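# Note: the data loader's collate_fn sorts each batch by question length in
# descending order, so when use_hidden is set the embeddings can be packed
# below and the GRU's final hidden state serves as the question representation;
# otherwise the per-token outputs are returned.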
if self.use_hidden: 66 | packed = pack_padded_sequence(embedded, q_len, batch_first=True) 67 | _, hid = self.lstm(packed) 68 | return hid.squeeze(0) 69 | else: 70 | out, _ = self.lstm(embedded) 71 | return out 72 | 73 | 74 | #embed_vecs = obj_edge_vectors(classes, wv_dim=embedding_features) 75 | #self.embed.weight.data = embed_vecs.clone() 76 | def obj_edge_vectors(names, wv_type='glove.6B', wv_dir=qa_path, wv_dim=300): 77 | wv_dict, wv_arr, wv_size = load_word_vectors(wv_dir, wv_type, wv_dim) 78 | 79 | vectors = torch.Tensor(len(names), wv_dim) 80 | vectors.normal_(0,1) 81 | failed_token = [] 82 | for i, token in enumerate(names): 83 | wv_index = wv_dict.get(token, None) 84 | if wv_index is not None: 85 | vectors[i] = wv_arr[wv_index] 86 | else: 87 | # Try the longest word (hopefully won't be a preposition 88 | lw_token = sorted(token.split(' '), key=lambda x: len(x), reverse=True)[0] 89 | #print("{} -> {} ".format(token, lw_token)) 90 | wv_index = wv_dict.get(lw_token, None) 91 | if wv_index is not None: 92 | vectors[i] = wv_arr[wv_index] 93 | else: 94 | failed_token.append(token) 95 | if (len(failed_token) > 0): 96 | print('Num of failed tokens: ', len(failed_token)) 97 | #print(failed_token) 98 | return vectors 99 | 100 | URL = { 101 | 'glove.42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', 102 | 'glove.840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip', 103 | 'glove.twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip', 104 | 'glove.6B': 'http://nlp.stanford.edu/data/glove.6B.zip', 105 | } 106 | 107 | 108 | def load_word_vectors(root, wv_type, dim): 109 | """Load word vectors from a path, trying .pt, .txt, and .zip extensions.""" 110 | if isinstance(dim, int): 111 | dim = str(dim) + 'd' 112 | fname = os.path.join(root, wv_type + '.' 
+ dim) 113 | if os.path.isfile(fname + '.pt'): 114 | fname_pt = fname + '.pt' 115 | print('loading word vectors from', fname_pt) 116 | return torch.load(fname_pt) 117 | if os.path.isfile(fname + '.txt'): 118 | fname_txt = fname + '.txt' 119 | cm = open(fname_txt, 'rb') 120 | cm = [line for line in cm] 121 | elif os.path.basename(wv_type) in URL: 122 | url = URL[wv_type] 123 | print('downloading word vectors from {}'.format(url)) 124 | filename = os.path.basename(fname) 125 | if not os.path.exists(root): 126 | os.makedirs(root) 127 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=filename) as t: 128 | fname, _ = urlretrieve(url, fname, reporthook=reporthook(t)) 129 | with zipfile.ZipFile(fname, "r") as zf: 130 | print('extracting word vectors into {}'.format(root)) 131 | zf.extractall(root) 132 | if not os.path.isfile(fname + '.txt'): 133 | raise RuntimeError('no word vectors of requested dimension found') 134 | return load_word_vectors(root, wv_type, dim) 135 | else: 136 | raise RuntimeError('unable to load word vectors') 137 | 138 | wv_tokens, wv_arr, wv_size = [], array.array('d'), None 139 | if cm is not None: 140 | for line in tqdm(range(len(cm)), desc="loading word vectors from {}".format(fname_txt)): 141 | entries = cm[line].strip().split(b' ') 142 | word, entries = entries[0], entries[1:] 143 | if wv_size is None: 144 | wv_size = len(entries) 145 | try: 146 | if isinstance(word, six.binary_type): 147 | word = word.decode('utf-8') 148 | except: 149 | print('non-UTF8 token', repr(word), 'ignored') 150 | continue 151 | wv_arr.extend(float(x) for x in entries) 152 | wv_tokens.append(word) 153 | 154 | wv_dict = {word: i for i, word in enumerate(wv_tokens)} 155 | wv_arr = torch.Tensor(wv_arr).view(-1, wv_size) 156 | ret = (wv_dict, wv_arr, wv_size) 157 | torch.save(ret, fname + '.pt') 158 | return ret 159 | 160 | def reporthook(t): 161 | """https://github.com/tqdm/tqdm""" 162 | last_b = [0] 163 | 164 | def inner(b=1, bsize=1, tsize=None): 165 | """ 166 | b: int, optionala 167 | Number of blocks just transferred [default: 1]. 168 | bsize: int, optional 169 | Size of each block (in tqdm units) [default: 1]. 170 | tsize: int, optional 171 | Total size (in tqdm units). If [default: None] remains unchanged. 172 | """ 173 | if tsize is not None: 174 | t.total = tsize 175 | t.update((b - last_b[0]) * bsize) 176 | last_b[0] = b 177 | return inner 178 | -------------------------------------------------------------------------------- /seada/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path 4 | import re 5 | 6 | import _pickle as cPickle 7 | from PIL import Image 8 | import h5py 9 | import torch 10 | import torch.utils.data as data 11 | import torchvision.transforms as transforms 12 | import numpy as np 13 | 14 | import config 15 | from . 
import utils 16 | 17 | 18 | preloaded_vocab = None 19 | 20 | 21 | def get_loader(train=False, val=False, test=False, trainval=False, sea=False, frac=1, iq=False, vqacp=False): 22 | """ Returns a data loader for the desired split """ 23 | split = VQA( 24 | utils.path_for(train=train, val=val, test=test, trainval=trainval, question=True, iq=iq, vqacp=vqacp), 25 | utils.path_for(train=train, val=val, test=test, trainval=trainval, answer=True, iq=iq, vqacp=vqacp), 26 | config.preprocessed_trainval_path if not test else config.preprocessed_test_path, 27 | utils.path_for(train=train, val=val, test=test, trainval=trainval, question=True, sea=sea, iq=iq), 28 | answerable_only=train or trainval, 29 | frac=frac, 30 | dummy_answers=test, 31 | ) 32 | loader = torch.utils.data.DataLoader( 33 | split, 34 | batch_size=64 if config.model_type == 'ban' and val else config.batch_size, 35 | shuffle=train or trainval, # only shuffle the data in training 36 | pin_memory=True, 37 | num_workers=config.data_workers, 38 | collate_fn=collate_fn, 39 | ) 40 | return loader 41 | 42 | 43 | def collate_fn(batch): 44 | # put question lengths in descending order so that we can use packed sequences later 45 | batch.sort(key=lambda x: x[-1], reverse=True) 46 | return data.dataloader.default_collate(batch) 47 | 48 | 49 | class VQA(data.Dataset): 50 | """ VQA dataset, open-ended """ 51 | def __init__(self, questions_path, answers_path, image_features_path, questions_adv_path=None, answerable_only=False, frac=1, dummy_answers=False): 52 | super(VQA, self).__init__() 53 | with open(questions_path, 'r') as fd: 54 | questions_json = json.load(fd) 55 | with open(answers_path, 'r') as fd: 56 | answers_json = json.load(fd) 57 | if 'adv' in questions_adv_path: 58 | with open(questions_adv_path, 'r') as fd: 59 | questions_adv_json = json.load(fd) 60 | if preloaded_vocab: 61 | vocab_json = preloaded_vocab 62 | else: 63 | with open(config.vocabulary_path, 'r') as fd: 64 | vocab_json = json.load(fd) 65 | word2idx, idx2word = cPickle.load(open(config.glove_index, 'rb')) 66 | vocab_json['question'] = word2idx 67 | 68 | self.question_ids = [q['question_id'] for q in questions_json['questions']] 69 | 70 | # vocab 71 | self.vocab = vocab_json 72 | self.token_to_index = self.vocab['question'] 73 | self.answer_to_index = self.vocab['answer'] 74 | 75 | # q and a 76 | self.q_id = [q['question_id'] for q in questions_json['questions']] 77 | self.question_str = [q['question'] for q in questions_json['questions']] # for sea 78 | self.questions = list(prepare_questions(questions_json, self.q_id)) 79 | self.questions_adv = None 80 | if 'adv' in questions_adv_path: 81 | self.questions_adv = list(prepare_questions(questions_adv_json, self.q_id)) 82 | self.questions_adv = [self.encode_question(q) for q in self.questions_adv] 83 | self.answers = list(prepare_answers(answers_json, self.q_id)) 84 | self.questions = [self.encode_question(q) for q in self.questions] 85 | self.answers = [self._encode_answers(a) for a in self.answers] 86 | 87 | # v 88 | self.image_features_path = image_features_path 89 | self.coco_id_to_index = self._create_coco_id_to_index() 90 | self.coco_ids = [q['image_id'] for q in questions_json['questions']] 91 | 92 | self.dummy_answers= dummy_answers 93 | 94 | # only use questions that have at least one answer? 
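# (when answerable_only is set, _find_answerable() below collects the indices
# of questions whose ground-truth answers appear in the answer vocab, and
# `frac` keeps only the first fraction of those indices)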
95 | self.answerable_only = answerable_only 96 | if self.answerable_only: 97 | self.answerable = self._find_answerable(not self.answerable_only) 98 | self.answerable = self.answerable[:int(len(self.answerable) * frac)] 99 | 100 | @property 101 | def max_question_length(self): 102 | if not hasattr(self, '_max_length'): 103 | data_max_length = max(map(len, self.questions)) 104 | self._max_length = min(config.max_q_length, data_max_length) 105 | return self._max_length 106 | 107 | @property 108 | def num_tokens(self): 109 | return len(self.token_to_index) 110 | 111 | def _create_coco_id_to_index(self): 112 | """ Create a mapping from a COCO image id into the corresponding index into the h5 file """ 113 | with h5py.File(self.image_features_path, 'r') as features_file: 114 | coco_ids = features_file['ids'][()] 115 | coco_id_to_index = {id: i for i, id in enumerate(coco_ids)} 116 | return coco_id_to_index 117 | 118 | def _check_integrity(self, questions, answers): 119 | """ Verify that we are using the correct data """ 120 | qa_pairs = list(zip(questions['questions'], answers['annotations'])) 121 | assert all(q['question_id'] == a['question_id'] for q, a in qa_pairs), 'Questions not aligned with answers' 122 | assert all(q['image_id'] == a['image_id'] for q, a in qa_pairs), 'Image id of question and answer don\'t match' 123 | assert questions['data_type'] == answers['data_type'], 'Mismatched data types' 124 | assert questions['data_subtype'] == answers['data_subtype'], 'Mismatched data subtypes' 125 | 126 | def _find_answerable(self, count=False): 127 | """ Create a list of indices into questions that will have at least one answer that is in the vocab """ 128 | answerable = [] 129 | if count: 130 | number_indices = torch.LongTensor([self.answer_to_index[str(i)] for i in range(0, 8)]) 131 | for i, answers in enumerate(self.answers): 132 | # store the indices of anything that is answerable 133 | if count: 134 | answers = answers[number_indices] 135 | answer_has_index = len(answers.nonzero()) > 0 136 | if answer_has_index: 137 | answerable.append(i) 138 | return answerable 139 | 140 | def encode_question(self, question): 141 | """ Turn a question into a vector of indices and a question length """ 142 | vec = torch.zeros(self.max_question_length).long().fill_(self.num_tokens) 143 | for i, token in enumerate(question): 144 | if i >= self.max_question_length: 145 | break 146 | index = self.token_to_index.get(token, self.num_tokens - 1) 147 | vec[i] = index 148 | return vec, min(len(question), self.max_question_length) 149 | 150 | def _encode_answers(self, answers): 151 | """ Turn an answer into a vector """ 152 | # answer vec will be a vector of answer counts to determine which answers will contribute to the loss. 
153 | # this should be multiplied with 0.1 * negative log-likelihoods that a model produces and then summed up 154 | # to get the loss that is weighted by how many humans gave that answer 155 | answer_vec = torch.zeros(len(self.answer_to_index)) 156 | for answer in answers: 157 | index = self.answer_to_index.get(answer) 158 | if index is not None: 159 | answer_vec[index] += 1 160 | return answer_vec 161 | 162 | def _load_image(self, image_id): 163 | """ Load an image """ 164 | if not hasattr(self, 'features_file'): 165 | # Loading the h5 file has to be done here and not in __init__ because when the DataLoader 166 | # forks for multiple works, every child would use the same file object and fail 167 | # Having multiple readers using different file objects is fine though, so we just init in here. 168 | self.features_file = h5py.File(self.image_features_path, 'r') 169 | index = self.coco_id_to_index[image_id] 170 | img = self.features_file['features'][index] 171 | boxes = self.features_file['boxes'][index] 172 | widths = self.features_file['widths'][index] 173 | heights = self.features_file['heights'][index] 174 | obj_mask = (img.sum(0) > 0).astype(int) 175 | return torch.from_numpy(img).transpose(0,1), torch.from_numpy(boxes).transpose(0,1), torch.from_numpy(obj_mask), widths, heights 176 | 177 | def __getitem__(self, item): 178 | if self.answerable_only: 179 | item = self.answerable[item] 180 | q, q_length = self.questions[item] 181 | q_adv = 0 182 | q_adv_mask = 0 183 | q_adv_length = 0 184 | if self.questions_adv is not None: 185 | q_adv, q_adv_length = self.questions_adv[item] 186 | q_adv_mask = torch.from_numpy((np.arange(self.max_question_length) < q_adv_length).astype(int)).float() 187 | q_str = self.question_str[item] 188 | q_mask = torch.from_numpy((np.arange(self.max_question_length) < q_length).astype(int)) 189 | if not self.dummy_answers: 190 | a = self.answers[item] 191 | else: 192 | # just return a dummy answer, it's not going to be used anyway 193 | a = 0 194 | image_id = self.coco_ids[item] 195 | q_id = self.q_id[item] 196 | v, b, obj_mask, width, height = self._load_image(image_id) 197 | # since batches are re-ordered for PackedSequence's, the original question order is lost 198 | # we return `item` so that the order of (v, q, a) triples can be restored if desired 199 | # without shuffling in the dataloader, these will be in the order that they appear in the q and a json's. 
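# If config.normalize_box is set, the block below rescales the box coordinates
# to [0, 1] by dividing x coordinates by the image width and y coordinates by
# the image height.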
200 | if config.normalize_box: 201 | assert b.shape[1] == 4 202 | b[:, 0] = b[:, 0] / float(width) 203 | b[:, 1] = b[:, 1] / float(height) 204 | b[:, 2] = b[:, 2] / float(width) 205 | b[:, 3] = b[:, 3] / float(height) 206 | 207 | return v, q, q_adv, q_str, a, b, item, obj_mask.float(), q_mask.float(), q_adv_mask, image_id, q_id, q_adv_length, q_length 208 | 209 | def __len__(self): 210 | if self.answerable_only: 211 | return len(self.answerable) 212 | else: 213 | return len(self.questions) 214 | 215 | 216 | # this is used for normalizing questions 217 | _special_chars = re.compile('[^a-z0-9 ]*') 218 | 219 | # these try to emulate the original normalization scheme for answers 220 | _period_strip = re.compile(r'(?!<=\d)(\.)(?!\d)') 221 | _comma_strip = re.compile(r'(\d)(,)(\d)') 222 | _punctuation_chars = re.escape(r';/[]"{}()=+\_-><@`,?!') 223 | _punctuation = re.compile(r'([{}])'.format(re.escape(_punctuation_chars))) 224 | _punctuation_with_a_space = re.compile(r'(?<= )([{0}])|([{0}])(?= )'.format(_punctuation_chars)) 225 | 226 | 227 | def prepare_questions(questions_json, q_id): 228 | """ Tokenize and normalize questions from a given question json in the usual VQA format. """ 229 | questions = [q['question'] for q in questions_json['questions']] 230 | #ques_dict = {} 231 | #for q in questions_json['questions']: 232 | # ques_dict[q['question_id']] = q 233 | # questions = [ques_dict[i]['question'] for i in q_id] 234 | for question in questions: 235 | question = question.lower()[:-1] 236 | question = _special_chars.sub('', question) 237 | yield question.split(' ') 238 | 239 | def prepare_questions_from_para(paraphrases): 240 | for paraphrase in paraphrases: 241 | question = paraphrase[0].lower()[:-2] 242 | question = _special_chars.sub('', question) 243 | yield question.split(' ') 244 | 245 | def prepare_answers(answers_json, q_id): 246 | """ Normalize answers from a given answer json in the usual VQA format. """ 247 | answers = [[a['answer'] for a in ans_dict['answers']] for ans_dict in answers_json['annotations']] 248 | # new_ans = {} 249 | #for ans_dict in answers_json['annotations']: 250 | # new_ans[ans_dict['question_id']] = ans_dict 251 | #answers = [[a['answer'] for a in new_ans[i]['answers']] for i in q_id] 252 | # The only normalization that is applied to both machine generated answers as well as 253 | # ground truth answers is replacing most punctuation with space (see [0] and [1]). 254 | # Since potential machine generated answers are just taken from most common answers, applying the other 255 | # normalizations is not needed, assuming that the human answers are already normalized. 
256 | # [0]: http://visualqa.org/evaluation.html 257 | # [1]: https://github.com/VT-vision-lab/VQA/blob/3849b1eae04a0ffd83f56ad6f70ebd0767e09e0f/PythonEvaluationTools/vqaEvaluation/vqaEval.py#L96 258 | 259 | def process_punctuation(s): 260 | # the original is somewhat broken, so things that look odd here might just be to mimic that behaviour 261 | # this version should be faster since we use re instead of repeated operations on str's 262 | if _punctuation.search(s) is None: 263 | return s 264 | s = _punctuation_with_a_space.sub('', s) 265 | if re.search(_comma_strip, s) is not None: 266 | s = s.replace(',', '') 267 | s = _punctuation.sub(' ', s) 268 | s = _period_strip.sub('', s) 269 | return s.strip() 270 | 271 | for answer_list in answers: 272 | yield list(map(process_punctuation, answer_list)) 273 | 274 | 275 | class CocoImages(data.Dataset): 276 | """ Dataset for MSCOCO images located in a folder on the filesystem """ 277 | def __init__(self, path, transform=None): 278 | super(CocoImages, self).__init__() 279 | self.path = path 280 | self.id_to_filename = self._find_images() 281 | self.sorted_ids = sorted(self.id_to_filename.keys()) # used for deterministic iteration order 282 | print('found {} images in {}'.format(len(self), self.path)) 283 | self.transform = transform 284 | 285 | def _find_images(self): 286 | id_to_filename = {} 287 | for filename in os.listdir(self.path): 288 | if not filename.endswith('.jpg'): 289 | continue 290 | id_and_extension = filename.split('_')[-1] 291 | id = int(id_and_extension.split('.')[0]) 292 | id_to_filename[id] = filename 293 | return id_to_filename 294 | 295 | def __getitem__(self, item): 296 | id = self.sorted_ids[item] 297 | path = os.path.join(self.path, self.id_to_filename[id]) 298 | img = Image.open(path).convert('RGB') 299 | 300 | if self.transform is not None: 301 | img = self.transform(img) 302 | return id, img 303 | 304 | def __len__(self): 305 | return len(self.sorted_ids) 306 | -------------------------------------------------------------------------------- /seada/sea/onmt_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import onmt 5 | import numpy as np 6 | import re 7 | import sys 8 | import torchtext 9 | from torch.autograd import Variable 10 | 11 | from collections import Counter, defaultdict 12 | 13 | PYTHON3 = sys.version_info > (3, 0) 14 | 15 | def repeat(repeat_numbers, tensor): 16 | cat = [] 17 | for i, x in enumerate(repeat_numbers): 18 | if x == 0: 19 | continue 20 | cat.append(tensor[:, i:i+1, :].repeat(1, x, 1)) 21 | return torch.cat(cat, 1) 22 | 23 | 24 | def transform_dec_states(decStates, repeat_numbers): 25 | assert len(repeat_numbers) == decStates._all[0].data.shape[1] 26 | vars = [Variable(repeat(repeat_numbers, e.data)) 27 | for e in decStates._all] 28 | decStates.hidden = tuple(vars[:-1]) 29 | decStates.input_feed = vars[-1] 30 | 31 | def clean_text(text, only_upper=False): 32 | # should there be a str here?` 33 | text = '%s%s' % (text[0].upper(), text[1:]) 34 | if only_upper: 35 | return text 36 | text = text.replace('|', 'UNK') 37 | text = re.sub('(^|\s)-($|\s)', r'\1@-@\2', text) 38 | # text = re.sub(' (n?\'.) ', r'\1 ', text) 39 | # fix apostrophe stuff according to tokenizer 40 | text = re.sub(' (n)(\'.) 
', r'\1 \2 ', text) 41 | return text 42 | 43 | class OnmtModel(object): 44 | def __init__(self, model_path, gpu_id=1): 45 | parser = argparse.ArgumentParser(description='translate.py') 46 | parser.add_argument('-model', required=True, 47 | help='Path to model .pt file') 48 | parser.add_argument( 49 | '-src', required=True, 50 | help='Source sequence to decode (one line per sequence)') 51 | parser.add_argument('-src_img_dir', default="", 52 | help='Source image directory') 53 | parser.add_argument('-tgt', 54 | help='True target sequence (optional)') 55 | parser.add_argument('-output', default='pred.txt', 56 | help="""Path to output the predictions (each line will 57 | be the decoded sequence""") 58 | parser.add_argument('-beam_size', type=int, default=5, 59 | help='Beam size') 60 | parser.add_argument('-batch_size', type=int, default=30, 61 | help='Batch size') 62 | parser.add_argument('-max_sent_length', type=int, default=100, 63 | help='Maximum sentence length.') 64 | parser.add_argument('-replace_unk', action="store_true", 65 | help="""Replace the generated UNK tokens with the 66 | source token that had highest attention weight. If 67 | phrase_table is provided, it will lookup the 68 | identified source token and give the corresponding 69 | target token. If it is not provided (or the 70 | identified source token does not exist in the 71 | table) then it will copy the source token""") 72 | parser.add_argument( 73 | '-verbose', action="store_true", 74 | help='Print scores and predictions for each sentence') 75 | parser.add_argument('-attn_debug', action="store_true", 76 | help='Print best attn for each word') 77 | parser.add_argument('-dump_beam', type=str, default="", 78 | help='File to dump beam information to.') 79 | 80 | parser.add_argument('-n_best', type=int, default=1, 81 | help="""If verbose is set, will output the n_best 82 | decoded sentences""") 83 | 84 | parser.add_argument('-gpu', type=int, default=-1, 85 | help="Device to run on") 86 | # options most relevant to summarization 87 | parser.add_argument('-dynamic_dict', action='store_true', 88 | help="Create dynamic dictionaries") 89 | parser.add_argument('-share_vocab', action='store_true', 90 | help="Share source and target vocabulary") 91 | # Alpha and Beta values for Google Length + Coverage penalty 92 | # Described here: https://arxiv.org/pdf/1609.08144.pdf, Section 7 93 | parser.add_argument('-alpha', type=float, default=0.0, 94 | help="""Google NMT length penalty parameter 95 | (higher = longer generation)""") 96 | parser.add_argument('-beta', type=float, default=0.0, 97 | help="""Coverage penalty parameter""") 98 | 99 | opt = parser.parse_args(( '-model %s -src /tmp/a -tgt /tmp/b -output /tmp/c -gpu %d -verbose -beam_size 5 -batch_size 1 -n_best 5 -replace_unk' % (model_path, gpu_id)).split()) # noqa 100 | opt.cuda = opt.gpu > -1 101 | if opt.cuda: 102 | torch.cuda.set_device(opt.gpu) 103 | self.translator = onmt.Translator(opt) 104 | 105 | 106 | def get_init_states(self, sentence): 107 | sentence = clean_text(sentence) 108 | data = ONMTDataset2([sentence], None, self.translator.fields, 109 | None) 110 | opt = self.translator.opt 111 | self.translator.opt.tgt = None 112 | testData = onmt.IO.OrderedIterator( 113 | dataset=data, device=opt.gpu, 114 | batch_size=opt.batch_size, train=False, sort=False, 115 | shuffle=False) 116 | if PYTHON3: 117 | batch = next(testData.__iter__()) 118 | else: 119 | batch = testData.__iter__().next() 120 | _, src_lengths = batch.src 121 | src = onmt.IO.make_features(batch, 'src') 122 | 
encStates, context = self.translator.model.encoder(src, src_lengths) 123 | decStates = self.translator.model.decoder.init_decoder_state( 124 | src, context, encStates) 125 | src_example = batch.dataset.examples[batch.indices[0].data].src 126 | return encStates, context, decStates, src_example 127 | 128 | def advance_states(self, encStates, context, decStates, new_idxs, 129 | new_sizes): 130 | # new_idxs is a list of new inputs 131 | # new_sizes indicates how duplicates to make of each decStates in the 132 | # previous round 133 | # Returns predict_proba, decStates(updated) 134 | tt = torch.cuda if self.translator.opt.cuda else torch 135 | 136 | 137 | def var(a): return Variable(a) 138 | 139 | def rvar(a, l): return var(a.repeat(1, l, 1)) 140 | current_state = tt.LongTensor(new_idxs) 141 | inp = var(torch.stack([current_state]).t().contiguous().view(1, -1)) 142 | inp = inp.unsqueeze(2) 143 | n_context = rvar(context.data, len(new_idxs)) 144 | transform_dec_states(decStates, new_sizes) 145 | decOut, decStates, attn = self.translator.model.decoder(inp, n_context, 146 | decStates) 147 | decOut = decOut.squeeze(0) 148 | out = self.translator.model.generator.forward(decOut).data 149 | out_np = out.cpu().numpy() 150 | return out_np, decStates, attn 151 | 152 | 153 | def vocab(self): 154 | return self.translator.fields['tgt'].vocab 155 | 156 | def translate(self, sentences, n_best=1, return_from_mapping=False): 157 | # Returns a 2d list (len(sentences), n(best)) of pairs, where each 158 | # is a translation and a score 159 | sentences = [clean_text(x) for x in sentences] 160 | data = ONMTDataset2(sentences, None, self.translator.fields, 161 | None) 162 | opt = self.translator.opt 163 | self.translator.opt.tgt = None 164 | testData = onmt.IO.OrderedIterator( 165 | dataset=data, device=opt.gpu, 166 | batch_size=opt.batch_size, train=False, sort=False, 167 | shuffle=False) 168 | out = [] 169 | scores = [] 170 | mappings = [] 171 | # gold = [] 172 | self.translator.opt.n_best = n_best 173 | prev_beam_size = self.translator.opt.beam_size 174 | vocab = self.translator.fields['tgt'].vocab 175 | if n_best > self.translator.opt.beam_size: 176 | self.translator.opt.beam_size = n_best 177 | for batch in testData: 178 | _, lens = batch.src 179 | # This only works if batch_size is one 180 | 181 | predBatch, goldBatch, predScore, goldScore, attn, src = ( 182 | self.translator.translate(batch, data)) 183 | # This is doing replace_unk 184 | if self.translator.opt.replace_unk: 185 | src_example = batch.dataset.examples[batch.indices[0].data].src 186 | for i, x in enumerate(predBatch): 187 | for j, sentence in enumerate(x): 188 | for k, word in enumerate(sentence): 189 | if word == vocab.itos[onmt.IO.UNK]: 190 | _, maxIndex = attn[i][j][k].max(0) 191 | m = int(maxIndex) 192 | predBatch[i][j][k] = src_example[m] 193 | # print 'ae', word, src_example[m] 194 | if return_from_mapping: 195 | this_mappings = [] 196 | src_example = batch.dataset.examples[batch.indices[0].data].src 197 | for i, x in enumerate(predBatch): 198 | for j, sentence in enumerate(x): 199 | mapping = {} 200 | for k, word in enumerate(sentence): 201 | _, maxIndex = attn[i][j][k].max(0) 202 | m = int(maxIndex) 203 | mapping[k] = src_example[m] 204 | this_mappings.append(mapping) 205 | 206 | mappings.append(this_mappings) 207 | out.extend([[' '.join(x) for x in y] for y in predBatch]) 208 | # print predScore 209 | # print goldScore 210 | scores.extend([x[:self.translator.opt.n_best] for x in predScore]) 211 | # gold.extend([x for x in goldScore]) 212 
| self.translator.opt.beam_size = prev_beam_size 213 | if return_from_mapping: 214 | return [list(zip(x, y, z)) for x, y, z in zip(out, scores, mappings)] 215 | return [list(zip(x, y)) for x, y in zip(out, scores)] 216 | 217 | def score(self, original_sentence, other_sentences): 218 | original_sentence = clean_text(original_sentence) 219 | other_sentences = [clean_text(x) for x in other_sentences] 220 | # print(original_sentence, other_sentences) 221 | # print other_sentences 222 | sentences = [original_sentence] * len(other_sentences) 223 | self.translator.opt.tgt = 'yes' 224 | data = ONMTDataset2(sentences, other_sentences, self.translator.fields, 225 | None) 226 | opt = self.translator.opt 227 | testData = onmt.IO.OrderedIterator( 228 | dataset=data, device=opt.gpu, 229 | batch_size=opt.batch_size, train=False, sort=False, 230 | shuffle=False) 231 | gold = [] 232 | # print(original_sentence, other_sentences) 233 | for batch in testData: 234 | # print('a') 235 | scores = self.translator._runTarget(batch, data) 236 | gold.extend([x for x in scores.cpu().numpy()[0]]) 237 | return np.array(gold) 238 | 239 | 240 | def extractFeatures(tokens): 241 | "Given a list of token separate out words and features (if any)." 242 | words = [] 243 | features = [] 244 | numFeatures = None 245 | 246 | for t in range(len(tokens)): 247 | field = tokens[t].split(u"|") 248 | word = field[0] 249 | if len(word) > 0: 250 | words.append(word) 251 | if numFeatures is None: 252 | numFeatures = len(field) - 1 253 | else: 254 | assert (len(field) - 1 == numFeatures), \ 255 | "all words must have the same number of features" 256 | 257 | if len(field) > 1: 258 | for i in range(1, len(field)): 259 | if len(features) <= i-1: 260 | features.append([]) 261 | features[i - 1].append(field[i]) 262 | assert (len(features[i - 1]) == len(words)) 263 | return words, features, numFeatures if numFeatures else 0 264 | 265 | 266 | class ONMTDataset2(torchtext.data.Dataset): 267 | """Defines a dataset for machine translation.""" 268 | 269 | @staticmethod 270 | def sort_key(ex): 271 | "Sort in reverse size order" 272 | return -len(ex.src) 273 | 274 | def __init__(self, src_path, tgt_path, fields, opt, 275 | src_img_dir=None, **kwargs): 276 | "Create a TranslationDataset given paths and fields." 277 | if src_img_dir: 278 | self.type_ = "img" 279 | else: 280 | self.type_ = "text" 281 | 282 | examples = [] 283 | src_words = [] 284 | self.src_vocabs = [] 285 | for i, src_line in enumerate(src_path): 286 | src_line = src_line.split() 287 | # if len(src_line) == 0: 288 | # skip[i] = True 289 | # continue 290 | if self.type_ == "text": 291 | # Check truncation condition. 
292 | if opt is not None and opt.src_seq_length_trunc != 0: 293 | src_line = src_line[:opt.src_seq_length_trunc] 294 | src, src_feats, _ = extractFeatures(src_line) 295 | d = {"src": src, "indices": i} 296 | self.nfeatures = len(src_feats) 297 | for j, v in enumerate(src_feats): 298 | d["src_feat_"+str(j)] = v 299 | examples.append(d) 300 | src_words.append(src) 301 | 302 | # Create dynamic dictionaries 303 | if opt is None or opt.dynamic_dict: 304 | # a temp vocab of a single source example 305 | src_vocab = torchtext.vocab.Vocab(Counter(src)) 306 | 307 | # mapping source tokens to indices in the dynamic dict 308 | src_map = torch.LongTensor(len(src)).fill_(0) 309 | for j, w in enumerate(src): 310 | src_map[j] = src_vocab.stoi[w] 311 | 312 | self.src_vocabs.append(src_vocab) 313 | examples[i]["src_map"] = src_map 314 | 315 | if tgt_path is not None: 316 | for i, tgt_line in enumerate(tgt_path): 317 | # if i in skip: 318 | # continue 319 | tgt_line = tgt_line.split() 320 | 321 | # Check truncation condition. 322 | if opt is not None and opt.tgt_seq_length_trunc != 0: 323 | tgt_line = tgt_line[:opt.tgt_seq_length_trunc] 324 | 325 | tgt, _, _ = extractFeatures(tgt_line) 326 | examples[i]["tgt"] = tgt 327 | 328 | if opt is None or opt.dynamic_dict: 329 | src_vocab = self.src_vocabs[i] 330 | # Map target tokens to indices in the dynamic dict 331 | mask = torch.LongTensor(len(tgt)+2).fill_(0) 332 | for j in range(len(tgt)): 333 | mask[j+1] = src_vocab.stoi[tgt[j]] 334 | examples[i]["alignment"] = mask 335 | assert i + 1 == len(examples), "Len src and tgt do not match" 336 | keys = examples[0].keys() 337 | fields = [(k, fields[k]) for k in keys] 338 | examples = list([torchtext.data.Example.fromlist([ex[k] for k in keys], 339 | fields) 340 | for ex in examples]) 341 | 342 | def filter_pred(example): 343 | return 0 < len(example.src) <= opt.src_seq_length \ 344 | and 0 < len(example.tgt) <= opt.tgt_seq_length 345 | 346 | super(ONMTDataset2, self).__init__(examples, fields, 347 | filter_pred if opt is not None 348 | else None) 349 | 350 | def __getstate__(self): 351 | return self.__dict__ 352 | 353 | def __setstate__(self, d): 354 | self.__dict__.update(d) 355 | 356 | def __reduce_ex__(self, proto): 357 | "This is a hack. Something is broken with torch pickle." 358 | return super(ONMTDataset2, self).__reduce_ex__() 359 | 360 | def collapseCopyScores(self, scores, batch, tgt_vocab): 361 | """Given scores from an expanded dictionary 362 | corresponeding to a batch, sums together copies, 363 | with a dictionary word when it is ambigious. 364 | """ 365 | offset = len(tgt_vocab) 366 | for b in range(batch.batch_size): 367 | index = batch.indices.data[b] 368 | src_vocab = self.src_vocabs[index] 369 | for i in range(1, len(src_vocab)): 370 | sw = src_vocab.itos[i] 371 | ti = tgt_vocab.stoi[sw] 372 | if ti != 0: 373 | scores[:, b, ti] += scores[:, b, offset + i] 374 | scores[:, b, offset + i].fill_(1e-20) 375 | return scores 376 | 377 | @staticmethod 378 | def load_fields(vocab): 379 | vocab = dict(vocab) 380 | fields = ONMTDataset2.get_fields( 381 | len(ONMTDataset2.collect_features(vocab))) 382 | for k, v in vocab.items(): 383 | # Hack. 
Can't pickle defaultdict :( 384 | v.stoi = defaultdict(lambda: 0, v.stoi) 385 | fields[k].vocab = v 386 | return fields 387 | 388 | @staticmethod 389 | def save_vocab(fields): 390 | vocab = [] 391 | for k, f in fields.items(): 392 | if 'vocab' in f.__dict__: 393 | f.vocab.stoi = dict(f.vocab.stoi) 394 | vocab.append((k, f.vocab)) 395 | return vocab 396 | 397 | @staticmethod 398 | def collect_features(fields): 399 | feats = [] 400 | j = 0 401 | while True: 402 | key = "src_feat_" + str(j) 403 | if key not in fields: 404 | break 405 | feats.append(key) 406 | j += 1 407 | return feats 408 | 409 | @staticmethod 410 | def get_fields(nFeatures=0): 411 | fields = {} 412 | fields["src"] = torchtext.data.Field( 413 | pad_token=PAD_WORD, 414 | include_lengths=True) 415 | 416 | # fields = [("src_img", torchtext.data.Field( 417 | # include_lengths=True))] 418 | 419 | for j in range(nFeatures): 420 | fields["src_feat_"+str(j)] = \ 421 | torchtext.data.Field(pad_token=PAD_WORD) 422 | 423 | fields["tgt"] = torchtext.data.Field( 424 | init_token=BOS_WORD, eos_token=EOS_WORD, 425 | pad_token=PAD_WORD) 426 | 427 | def make_src(data, _): 428 | src_size = max([t.size(0) for t in data]) 429 | src_vocab_size = max([t.max() for t in data]) + 1 430 | alignment = torch.FloatTensor(src_size, len(data), 431 | src_vocab_size).fill_(0) 432 | for i in range(len(data)): 433 | for j, t in enumerate(data[i]): 434 | alignment[j, i, t] = 1 435 | return alignment 436 | 437 | fields["src_map"] = torchtext.data.Field( 438 | use_vocab=False, tensor_type=torch.FloatTensor, 439 | postprocessing=make_src, sequential=False) 440 | 441 | def make_tgt(data, _): 442 | tgt_size = max([t.size(0) for t in data]) 443 | alignment = torch.LongTensor(tgt_size, len(data)).fill_(0) 444 | for i in range(len(data)): 445 | alignment[:data[i].size(0), i] = data[i] 446 | return alignment 447 | 448 | fields["alignment"] = torchtext.data.Field( 449 | use_vocab=False, tensor_type=torch.LongTensor, 450 | postprocessing=make_tgt, sequential=False) 451 | 452 | fields["indices"] = torchtext.data.Field( 453 | use_vocab=False, tensor_type=torch.LongTensor, 454 | sequential=False) 455 | 456 | return fields 457 | 458 | @staticmethod 459 | def build_vocab(train, opt): 460 | fields = train.fields 461 | fields["src"].build_vocab(train, max_size=opt.src_vocab_size, 462 | min_freq=opt.src_words_min_frequency) 463 | for j in range(train.nfeatures): 464 | fields["src_feat_" + str(j)].build_vocab(train) 465 | fields["tgt"].build_vocab(train, max_size=opt.tgt_vocab_size, 466 | min_freq=opt.tgt_words_min_frequency) 467 | 468 | # Merge the input and output vocabularies. 469 | if opt.share_vocab: 470 | # `tgt_vocab_size` is ignored when sharing vocabularies 471 | merged_vocab = merge_vocabs( 472 | [fields["src"].vocab, fields["tgt"].vocab], 473 | vocab_size=opt.src_vocab_size) 474 | fields["src"].vocab = merged_vocab 475 | fields["tgt"].vocab = merged_vocab 476 | -------------------------------------------------------------------------------- /seada/sea/paraphrase_scorer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import time 3 | import os 4 | import copy 5 | import numpy as np 6 | from . 
import onmt_model 7 | import onmt 8 | import collections 9 | import operator 10 | import editdistance 11 | import sys 12 | import itertools 13 | 14 | PYTHON3 = sys.version_info > (3, 0) 15 | if PYTHON3: 16 | from itertools import zip_longest as zip_longest 17 | else: 18 | from itertools import izip_longest as zip_longest 19 | 20 | # DEFAULT_TO_PATHS = ['/home/marcotcr/OpenNMT-py/trained_models/english_french_model_acc_70.61_ppl_3.73_e13.pt', '/home/marcotcr/OpenNMT-py/trained_models/english_german_model_acc_58.34_ppl_7.82_e13.pt', '/home/marcotcr/OpenNMT-py/trained_models/english_portuguese_model_acc_70.90_ppl_4.28_e13.pt'] 21 | # DEFAULT_BACK_PATHS = ['/home/marcotcr/OpenNMT-py/trained_models/french_english_model_acc_68.83_ppl_4.43_e13.pt', '/home/marcotcr/OpenNMT-py/trained_models/german_english_model_acc_57.23_ppl_10.00_e13.pt', '/home/marcotcr/OpenNMT-py/trained_models/portuguese_english_model_acc_69.78_ppl_5.05_e13.pt'] 22 | DEFAULT_TO_PATHS = ['seada/sea/translation_models/english_french_model_acc_71.05_ppl_3.71_e13.pt', 23 | 'seada/sea/translation_models/english_portuguese_model_acc_70.75_ppl_4.32_e13.pt'] 24 | DEFAULT_BACK_PATHS = ['seada/sea/translation_models/french_english_model_acc_68.51_ppl_4.43_e13.pt', 25 | 'seada/sea/translation_models/portuguese_english_model_acc_69.93_ppl_5.04_e13.pt'] 26 | 27 | def choose_forward_translation(sentence, to_translator, back_translator, n=5): 28 | # chooses the to_translation that gives the best back_score to 29 | # sentence given back_translation 30 | translations = to_translator.translate([sentence], n_best=n, 31 | return_from_mapping=True)[0] 32 | mappings = [x[2] for x in translations if x[0]] 33 | translations = [x[0] for x in translations if x[0]] 34 | # translations = [x[0] for x in 35 | # to_translator.translate([sentence], n_best=n)[0] if x[0]] 36 | scores = [back_translator.score(x, [sentence])[0] for x in translations] 37 | return translations[np.argmax(scores)], mappings[np.argmax(scores)] 38 | 39 | def normalize_ll(x): 40 | # normalizes vector of log likelihoods 41 | max_ = x.max() 42 | b = np.exp(x - max_) 43 | return b / b.sum() 44 | 45 | def largest_indices(ary, n): 46 | """Returns the n largest indices from a numpy array.""" 47 | flat = ary.flatten() 48 | if n > flat.shape[0]: 49 | indices = np.array(range(flat.shape[0]), dtype='int') 50 | return np.unravel_index(indices, ary.shape) 51 | indices = np.argpartition(flat, -n)[-n:] 52 | indices = indices[np.argsort(-flat[indices])] 53 | return np.unravel_index(indices, ary.shape) 54 | 55 | 56 | class ParaphraseScorer(object): 57 | def __init__(self, 58 | to_paths=DEFAULT_TO_PATHS, 59 | back_paths=DEFAULT_BACK_PATHS, 60 | gpu_id=1): 61 | print('GPU ID', gpu_id) 62 | self.to_translators = [] 63 | # self.to_scorers = [] 64 | self.back_translators = [] 65 | for f in to_paths: 66 | translator = onmt_model.OnmtModel(f, gpu_id) 67 | self.to_translators.append(translator) 68 | # self.to_scorers.append(translator) 69 | for f in back_paths: 70 | translator = onmt_model.OnmtModel(f, gpu_id) 71 | self.back_translators.append(translator) 72 | self.build_common_vocabs() 73 | self.last = None 74 | 75 | def build_common_vocabs(self): 76 | self.global_itos = [] 77 | self.global_stoi = {} 78 | self.vocab_mappers = [] 79 | self.back_vocab_mappers = [] 80 | self.vocab_unks = [] 81 | back_vocab_mappers = [] 82 | for t in self.back_translators: 83 | vocab = t.translator.fields['tgt'].vocab 84 | mapper = [] 85 | back_mapper = {} 86 | for i, w in enumerate(vocab.itos): 87 | if w not in 
self.global_stoi: 88 | self.global_stoi[w] = len(self.global_stoi) 89 | self.global_itos.append(w) 90 | mapper.append(self.global_stoi[w]) 91 | back_mapper[self.global_stoi[w]] = i 92 | self.vocab_mappers.append(np.array(mapper)) 93 | back_vocab_mappers.append(back_mapper) 94 | for t, m, back_mapper in zip(self.back_translators, self.vocab_mappers, 95 | back_vocab_mappers): 96 | unks = np.array( 97 | list(set(range(len(self.global_itos))).difference(m))) 98 | for u in unks: 99 | back_mapper[u] = onmt.IO.UNK 100 | bm = np.zeros(len(self.global_itos), dtype=int) 101 | for b, v in back_mapper.items(): 102 | bm[b] = v 103 | self.back_vocab_mappers.append(bm) 104 | self.vocab_unks.append(unks) 105 | 106 | 107 | def nearby_distribution(self, sentence, weight_by_edit_distance=False, **kwargs): 108 | paraphrases = self.generate_paraphrases(sentence, **kwargs) 109 | if not paraphrases: 110 | return paraphrases 111 | others = [x[0] for x in paraphrases] 112 | n_scores = normalize_ll(np.array([x[1] for x in paraphrases])) 113 | if weight_by_edit_distance: 114 | return self.weight_by_edit_distance(sentence, list(zip(others, n_scores))) 115 | return sorted(zip(others, n_scores), key=lambda x:x[1], reverse=True) 116 | 117 | def weight_by_edit_distance(self, sentence, distribution): 118 | # Distribution is a list of (text, weight) tuples, unnormalized 119 | others = [x[0] for x in distribution] 120 | n_scores = np.array([x[1] for x in distribution]) 121 | import editdistance 122 | orig = onmt_model.clean_text(sentence) 123 | orig_score = self.score_sentences(orig, [orig])[0] 124 | orig = orig.split() 125 | # print(orig_score) 126 | n_scores = np.minimum(0, n_scores - orig_score) 127 | # n_scores = n_scores - orig_score 128 | distances = np.array([editdistance.eval(orig, x.split()) for x in others]) 129 | 130 | logkernel = lambda d, k: .5 * -(d**2) / (k**2) 131 | 132 | # This is equivalent to multiplying the prediction probability by the exponential kernel on the distance 133 | n_scores = n_scores + logkernel(distances, 3) 134 | # zeros = np.where(distances == 0)[0] 135 | # print(n_scores) 136 | # print(distances[np.argsort(distances)]) 137 | # print(distances) 138 | # print(n_scores) 139 | # w = np.log2(distances + 1) 140 | # w = (distances + 1) ** (1./2) 141 | # print(n_scores.argmax()) 142 | # print(w) 143 | # n_scores = w * n_scores 144 | # This is equivalent to dividing the predict proba by the distance + 1 145 | # n_scores = n_scores - np.log(distances + 1) 146 | # print(distances[n_scores.argmax()]) 147 | # n_scores[zeros] = -99999999 148 | # print(distances[n_scores.argmax()]) 149 | 150 | n_scores = normalize_ll(n_scores) 151 | # print(n_scores.argmax()) 152 | 153 | # TODO: achar funcao pra weight by edit distance 154 | # n_scores = n_scores / (distances + 1) 155 | # n_scores = n_scores / n_scores.sum() 156 | return sorted(zip(others, n_scores), key=lambda x: x[1], reverse=True) 157 | 158 | def suggest_next(self, words, idx, in_between=False, run_through=False, topk=10, threshold=None, 159 | original_sentence=None, only_local_score=False): 160 | # TODO: This is outdated 161 | 162 | to_add = -10000 163 | memoized_stuff = self.last == original_sentence and original_sentence is not None 164 | # print('suggest_next', words[:idx], memoized_stuff) 165 | if not memoized_stuff: 166 | self.last = original_sentence 167 | self.memoized = {} 168 | self.memoized['translation'] = [] 169 | sentence = (' '.join(words) if original_sentence is None 170 | else original_sentence) 171 | words_after = words[idx:] if 
in_between else words[idx + 1:] 172 | words = words[:idx] 173 | print(words) 174 | print(words_after) 175 | orig_ids = np.array([self.global_stoi[onmt.IO.BOS_WORD]] + [self.global_stoi[x] if x in self.global_stoi else onmt.IO.UNK for x in words]) 176 | global_scores = np.zeros(len(self.global_itos)) 177 | last_scores = np.zeros(len(self.global_itos)) 178 | unk_scores = [] 179 | if threshold: 180 | threshold *= len(self.back_translators) 181 | attns = [] 182 | src_examples = [] 183 | dec_states = [] 184 | enc_states = [] 185 | contexts = [] 186 | mappings = [] 187 | for k, (to, back, mapper, back_mapper, unks) in enumerate( 188 | zip(self.to_translators, self.back_translators, 189 | self.vocab_mappers, self.back_vocab_mappers, 190 | self.vocab_unks)): 191 | if memoized_stuff: 192 | translation, mapping = self.memoized['translation'][k] 193 | mappings.append(mapping) 194 | else: 195 | translation, mapping = choose_forward_translation(sentence, to, back, 196 | n=5) 197 | mappings.append(mapping) 198 | self.memoized['translation'].append((translation, mapping)) 199 | encStates, context, decStates, src_example = ( 200 | back.get_init_states(translation)) 201 | src_examples.append(src_example) 202 | # print() 203 | # print(k) 204 | a = 0 205 | for i, n in zip(orig_ids, orig_ids[1:]): 206 | idx = int(back_mapper[i]) 207 | n = int(back_mapper[n]) 208 | out, decStates, attn = back.advance_states(encStates, context, 209 | decStates, [idx], [1]) 210 | attenz = attn['std'].data[0].cpu().numpy() 211 | chosen = np.argmax(attenz, axis=1) 212 | for r, ch in enumerate(chosen): 213 | ch = mapping[ch] 214 | if ch in back.vocab().stoi: 215 | # print("YOO") 216 | ind = back.vocab().stoi[ch] 217 | # print("prev", out[r, ind]) 218 | out[r, ind] = max(out[r, ind], out[r, onmt.IO.UNK]) 219 | # print("aft", out[r, ind]) 220 | elif ch in self.global_stoi: 221 | # print(ch) 222 | ind = self.global_stoi[ch] 223 | global_scores[ind] -= to_add 224 | global_scores[mapper] += out[0][n] 225 | a += out[0][n] 226 | # print(n, out[0][n]) 227 | if unks.shape[0]: 228 | global_scores[unks] += to_add + out[0, n] 229 | # print(np.argsort(out[0])[-5:]) 230 | # print( 'g1', global_scores[63441]) 231 | print('a', a) 232 | idx = int(back_mapper[orig_ids[-1]]) 233 | out, decStates, attn = back.advance_states(encStates, context, 234 | decStates, [idx], [1]) 235 | attenz = attn['std'].data[0].cpu().numpy() 236 | chosen = np.argmax(attenz, axis=1) 237 | for r, ch in enumerate(chosen): 238 | ch = mapping[ch] 239 | if ch in back.vocab().stoi: 240 | # print("YOO") 241 | ind = back.vocab().stoi[ch] 242 | # print("prev", out[r, ind]) 243 | out[r, ind] = max(out[r, ind], out[r, onmt.IO.UNK]) 244 | # print("aft", out[r, ind]) 245 | elif ch in self.global_stoi: 246 | # print(ch) 247 | ind = self.global_stoi[ch] 248 | global_scores[ind] -= to_add 249 | last_scores[ind] -= to_add 250 | if unks.shape[0]: 251 | global_scores[unks] += to_add + out[0, onmt.IO.UNK] 252 | last_scores[unks] += to_add + out[0, onmt.IO.UNK] 253 | global_scores[mapper] += out[0] 254 | last_scores[mapper] += out[0] 255 | unk_scores.append(out[0, onmt.IO.UNK]) 256 | attns.append(attn) 257 | dec_states.append(decStates) 258 | contexts.append(context) 259 | enc_states.append(encStates) 260 | # print( 'g2', global_scores[63441]) 261 | unk_scores = normalize_ll(np.array(unk_scores)) 262 | new_unk_scores = collections.defaultdict(lambda: 0) 263 | for x, src, mapping, score_weight in zip(attns, src_examples, mappings, unk_scores): 264 | # print(src) 265 | attn = x['std'].data[0][0] 
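            # Map <unk> probability mass back onto concrete source words: every word the
            # decoder attends to, weighted by this model's normalized UNK likelihood
            # (score_weight), accumulates a vote, and the best-voted word is kept as
            # new_unk, i.e. the surface form suggested in place of a generic <unk>.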
266 | # TODO: Should we only allow unks here? We are 267 | # currently weighting based on the original score, but 268 | # this makes it so one always chooses the unk. 269 | for zidx, (word, score) in enumerate(zip(src, attn)): 270 | word = mapping[zidx] 271 | new_unk_scores[word] += score * score_weight 272 | # print(sorted(new_unk_scores.items(), key=lambda x:x[1], reverse=True)) 273 | new_unk = max(new_unk_scores.items(), 274 | key=operator.itemgetter(1))[0] 275 | # if new_unk in self.global_stoi: 276 | # global_scores[onmt.IO.UNK] = global_scores[self.global_stoi[new_unk]] 277 | # last_scores[onmt.IO.UNK] = last_scores[self.global_stoi[new_unk]] 278 | picked = largest_indices(global_scores, topk)[0] 279 | # if new_unk in self.global_stoi and onmt.IO.UNK in picked: 280 | # picked[picked == onmt.IO.UNK] = self.global_stoi[new_unk] 281 | if threshold: 282 | to_keep = np.where(global_scores[picked] >= threshold)[0] 283 | to_delete = np.where(global_scores[picked] < threshold)[0] 284 | picked = picked[to_keep] 285 | print(picked.shape) 286 | topk = picked.shape[0] 287 | if not topk: 288 | return [] 289 | # global_scores = np.repeat(global_scores[np.newaxis, :], len(picked), 290 | # axis=0) 291 | # print(last_scores[picked]) 292 | # print( 'g', global_scores[picked]) 293 | if run_through: 294 | orig_ids = np.array([self.global_stoi[x] if x in self.global_stoi else onmt.IO.UNK for x in words_after] + 295 | [self.global_stoi[onmt.IO.EOS_WORD]]) 296 | for to, back, mapper, back_mapper, encStates, context, decStates, mapping in zip( 297 | self.to_translators, self.back_translators, self.vocab_mappers, self.back_vocab_mappers, 298 | enc_states, contexts, dec_states, mappings): 299 | if not picked.shape[0]: 300 | break 301 | idx = [int(back_mapper[x]) for x in picked] 302 | # print(idx) 303 | # print(idx) 304 | # print([self.global_itos[x] for x in picked]) 305 | # idx = int(back_mapper[x]) 306 | n = int(back_mapper[orig_ids[0]]) 307 | # print(n, back.vocab().itos[n]) 308 | out, decStates, attn = back.advance_states(encStates, context, 309 | decStates, idx, [len(idx)]) 310 | 311 | attenz = attn['std'].data[0].cpu().numpy() 312 | chosen = np.argmax(attenz, axis=1) 313 | for r, ch in enumerate(chosen): 314 | ch = mapping[ch] 315 | if ch in back.vocab().stoi: 316 | # print("YOO") 317 | ind = back.vocab().stoi[ch] 318 | # print("prev", out[r, ind]) 319 | out[r, ind] = max(out[r, ind], out[r, onmt.IO.UNK]) 320 | # print(n, out[:, n]) 321 | global_scores[picked] += out[:, n] 322 | sizes = [1 for _ in range(topk)] 323 | if threshold: 324 | to_keep = np.where(global_scores[picked] >= threshold)[0] 325 | to_delete = np.where(global_scores[picked] < threshold)[0] 326 | picked = picked[to_keep] 327 | print(picked.shape) 328 | for x in to_delete: 329 | sizes[x] = 0 330 | topk = picked.shape[0] 331 | if not topk: 332 | break 333 | # global_scores[back_mapper] += out[:, n] 334 | for i, next_ in zip(orig_ids, orig_ids[1:]): 335 | idx = [int(back_mapper[i]) for _ in range(topk)] 336 | n = int(back_mapper[next_]) 337 | out, decStates, attn = back.advance_states(encStates, context, 338 | decStates, idx, sizes) 339 | attenz = attn['std'].data[0].cpu().numpy() 340 | chosen = np.argmax(attenz, axis=1) 341 | for r, ch in enumerate(chosen): 342 | ch = mapping[ch] 343 | if ch in back.vocab().stoi: 344 | # print("YOO") 345 | ind = back.vocab().stoi[ch] 346 | # print("prev", out[r, ind]) 347 | out[r, ind] = max(out[r, ind], out[r, onmt.IO.UNK]) 348 | # print("aft", out[r, ind]) 349 | # print(np.argsort(out[0])[-5:]) 350 | 
global_scores[picked] += out[:, n] 351 | # print(n, out[:, n]) 352 | sizes = [1 for _ in range(topk)] 353 | if threshold: 354 | to_keep = np.where(global_scores[picked] >= threshold)[0] 355 | to_delete = np.where(global_scores[picked] < threshold)[0] 356 | # print(picked.shape) 357 | picked = picked[to_keep] 358 | # print(picked.shape) 359 | topk = picked.shape[0] 360 | for x in to_delete: 361 | sizes[x] = 0 362 | # print(topk) 363 | if not topk: 364 | break 365 | global_scores /= len(self.back_translators) 366 | last_scores /= len(self.back_translators) 367 | # TODO: there may be duplicates because of new_Unk 368 | if only_local_score: 369 | ret = [(self.global_itos[z], last_scores[z]) if z != onmt.IO.UNK else (new_unk, last_scores[z]) for z in picked if self.global_itos[z] != onmt.IO.EOS_WORD] 370 | else: 371 | ret = [(self.global_itos[z], global_scores[z]) if z != onmt.IO.UNK else (new_unk, global_scores[z]) for z in picked if self.global_itos[z] != onmt.IO.EOS_WORD] 372 | return sorted(ret, key=lambda x: x[1], reverse=True) 373 | return global_scores, new_unk 374 | print() 375 | print(list(reversed([self.global_itos[x] for x in np.argsort(global_scores)[-100:]]))) 376 | pass 377 | def suggest_in_between(self, words, idxs_middle, topk=10, threshold=None, 378 | original_sentence=None, max_inserts=4, ignore_set=set(), 379 | return_full_texts=False, orig_score=0, verbose=False): 380 | # TODO: This is outdated 381 | 382 | run_through = True 383 | to_add = -10000 384 | memoized_stuff = self.last == original_sentence and original_sentence is not None 385 | # print('suggest_next', words[:idx], memoized_stuff) 386 | if not memoized_stuff: 387 | self.last = original_sentence 388 | self.memoized = {} 389 | self.memoized['translation'] = [] 390 | sentence = (' '.join(words) if original_sentence is None 391 | else original_sentence) 392 | words_after = words[idxs_middle[-1] + 1:] 393 | words_between = words[idxs_middle[0]:idxs_middle[1] + 1] 394 | words = words[:idxs_middle[0]] 395 | words_before = words 396 | # print(words) 397 | # print(words_between) 398 | # print(words_after) 399 | max_iters = max_inserts + idxs_middle[1] - idxs_middle[0] + 1 400 | out_scores = {} 401 | orig_ids = np.array([self.global_stoi[onmt.IO.BOS_WORD]] + [self.global_stoi[x] if x in self.global_stoi else onmt.IO.UNK for x in words]) 402 | after_ids = np.array([self.global_stoi[x] if x in self.global_stoi else onmt.IO.UNK for x in words_after] + 403 | [self.global_stoi[onmt.IO.EOS_WORD]]) 404 | mid_ids = np.array([self.global_stoi[x] if x in self.global_stoi else onmt.IO.UNK for x in words_between]) 405 | unk_scores = [] 406 | if threshold: 407 | orig_threshold = threshold 408 | attns = [] 409 | src_examples = [] 410 | decoder_states = [] 411 | encoder_states = [] 412 | contexts = [] 413 | mappings = [] 414 | prev_scores = 0 415 | feed_original = 0 416 | in_between = 0 417 | mid_score = 0 418 | for k, (to, back, mapper, back_mapper, unks) in enumerate( 419 | zip(self.to_translators, self.back_translators, 420 | self.vocab_mappers, self.back_vocab_mappers, 421 | self.vocab_unks)): 422 | if memoized_stuff: 423 | translation, mapping = self.memoized['translation'][k] 424 | mappings.append(mapping) 425 | else: 426 | translation, mapping = choose_forward_translation(sentence, to, back, 427 | n=5) 428 | mappings.append(mapping) 429 | self.memoized['translation'].append((translation, mapping)) 430 | encStates, context, decStates, src_example = ( 431 | back.get_init_states(translation)) 432 | src_examples.append(src_example) 433 
| # print() 434 | # Feed in the original input 435 | tz = time.time() 436 | for i, n in zip(orig_ids, orig_ids[1:]): 437 | idx = int(back_mapper[i]) 438 | n = int(back_mapper[n]) 439 | out, decStates, attn = back.advance_states(encStates, context, 440 | decStates, [idx], [1]) 441 | attenz = attn['std'].data[0].cpu().numpy() 442 | chosen = np.argmax(attenz, axis=1) 443 | for r, ch in enumerate(chosen): 444 | ch = mapping[ch] 445 | if ch in back.vocab().stoi: 446 | # print("YOO") 447 | ind = back.vocab().stoi[ch] 448 | # print("prev", out[r, ind]) 449 | out[r, ind] = max(out[r, ind], out[r, onmt.IO.UNK]) 450 | # print("aft", out[r, ind]) 451 | prev_scores += out[0][n] 452 | mid_score += prev_scores 453 | feed_original += time.time() - tz 454 | decoder_states.append(decStates) 455 | contexts.append(context) 456 | encoder_states.append(encStates) 457 | # print("MID IDS", mid_ids) 458 | onmt_model.transform_dec_states(decStates, [1]) 459 | decStates = copy.deepcopy(decStates) 460 | for i, n in zip([orig_ids[-1]] + list(mid_ids), list(mid_ids) + [after_ids[0]]): 461 | # print('mid', i, n) 462 | idx = int(back_mapper[i]) 463 | n = int(back_mapper[n]) 464 | out, decStates, attn = back.advance_states(encStates, context, 465 | decStates, [idx], [1]) 466 | attenz = attn['std'].data[0].cpu().numpy() 467 | chosen = np.argmax(attenz, axis=1) 468 | for r, ch in enumerate(chosen): 469 | ch = mapping[ch] 470 | if ch in back.vocab().stoi: 471 | # print("YOO") 472 | ind = back.vocab().stoi[ch] 473 | # print("prev", out[r, ind]) 474 | out[r, ind] = max(out[r, ind], out[r, onmt.IO.UNK]) 475 | # print("aft", out[r, ind]) 476 | mid_score += out[0][n] 477 | # print("INcreasing mid") 478 | prev = [[]] 479 | prev_scores = [prev_scores / float(len(self.back_translators))] 480 | mid_score = mid_score / float(len(self.back_translators)) 481 | if verbose: 482 | print('MID', mid_score) 483 | if threshold: 484 | threshold = mid_score + threshold 485 | # print(prev_scores) 486 | prev_unks = [[]] 487 | new_sizes = [1] 488 | idxs = [orig_ids[-1]] 489 | current_iter = 0 490 | # print(list(reversed([(self.global_itos[x], global_scores[0][x]) for x in np.argsort(global_scores[0])[-10:]]))) 491 | going_after = 0 492 | while prev and current_iter < max_iters + 1: 493 | if verbose: 494 | print('iter', current_iter, topk) 495 | current_iter += 1 496 | global_scores = np.zeros((len(prev), (len(self.global_itos)))) 497 | all_stuff = zip( 498 | self.back_translators, self.vocab_mappers, 499 | self.back_vocab_mappers, self.vocab_unks, contexts, 500 | decoder_states, encoder_states, src_examples, mappings) 501 | new_decoder_states = [] 502 | new_attns = [] 503 | unk_scores = [] 504 | tz = time.time() 505 | for (b, mapper, back_mapper, unks, context, 506 | decStates, encStates, srcz, mapping) in all_stuff: 507 | idx = [int(back_mapper[i]) for i in idxs] 508 | out, decStates, attn = b.advance_states( 509 | encStates, context, decStates, idx, new_sizes) 510 | new_decoder_states.append(decStates) 511 | new_attns.append(attn) 512 | attenz = attn['std'].data[0].cpu().numpy() 513 | chosen = np.argmax(attenz, axis=1) 514 | for r, ch in enumerate(chosen): 515 | ch = mapping[ch] 516 | if ch in b.vocab().stoi: 517 | ind = b.vocab().stoi[ch] 518 | out[r, ind] = max(out[r, ind], out[r, onmt.IO.UNK]) 519 | elif ch in self.global_stoi: 520 | ind = self.global_stoi[ch] 521 | global_scores[r, ind] -= to_add 522 | unk_scores.append(out[:, onmt.IO.UNK]) 523 | global_scores[:, mapper] += out 524 | if unks.shape[0]: 525 | global_scores[:, unks] += to_add + 
out[:, onmt.IO.UNK][:, np.newaxis] 526 | decoder_states = new_decoder_states 527 | global_scores /= float(len(self.back_translators)) 528 | unk_scores = [normalize_ll(x) for x in np.array(unk_scores).T] 529 | 530 | new_prev = [] 531 | new_prev_unks = [] 532 | new_prev_scores = [] 533 | new_sizes = [] 534 | new_origins = [] 535 | idxs = [] 536 | new_scores = global_scores + np.array(prev_scores)[:, np.newaxis] 537 | # best = new_scores.max() 538 | # if threshold: 539 | # threshold = mid_score + orig_threshold 540 | # threshold = best + orig_threshold 541 | # threshold = orig_score + orig_threshold 542 | # print('best', best) 543 | # print('new thresh', threshold) 544 | # print(threshold == best + orig_threshold) 545 | if threshold: 546 | # print(threshold) 547 | where = np.where(new_scores > threshold) 548 | if topk: 549 | largest = largest_indices(new_scores[where], topk)[0] 550 | where = (where[0][largest], where[1][largest]) 551 | else: 552 | where = largest_indices(new_scores, topk) 553 | # print('best', new_scores[where[0][0], where[1][0]], new_scores.max()) 554 | tmp = np.argsort(where[0]) 555 | where = (where[0][tmp], where[1][tmp]) 556 | # print(where) 557 | new_this_round = [] 558 | new_origins_this_round = [] 559 | to_add = time.time() - tz 560 | in_between += time.time() - tz 561 | if verbose: 562 | print('in', to_add, in_between, threshold) 563 | print(where[0].shape) 564 | for i, j in zip(*where): 565 | if j == after_ids[0]: 566 | words = [self.global_itos[x] if x != onmt.IO.UNK 567 | else prev_unks[i][k] 568 | for k, x in enumerate(prev[i], start=0)] 569 | new_full = ' '.join(words_before + words + words_after) 570 | new = ' '.join(words) 571 | if return_full_texts: 572 | new = new_full 573 | if new_full in ignore_set: 574 | continue 575 | # return 576 | if new not in out_scores or new_scores[i, j] > out_scores[new]: 577 | out_scores[new] = new_scores[i, j] 578 | new_this_round.append(new) 579 | new_origins_this_round.append(i) 580 | # if topk: 581 | # topk -= 1 582 | continue 583 | if j == self.global_stoi[onmt.IO.EOS_WORD]: 584 | continue 585 | new_origins.append(i) 586 | new_unk = '' 587 | if j == onmt.IO.UNK: 588 | new_unk_scores = collections.defaultdict(lambda: 0) 589 | for x, src, mapping, score_weight in zip(new_attns, src_examples, mappings, unk_scores[i]): 590 | attn = x['std'].data[0][i] 591 | for zidx, (word, score) in enumerate(zip(src, attn)): 592 | word = mapping[zidx] 593 | new_unk_scores[word] += score * score_weight 594 | new_unk = max(new_unk_scores.items(), 595 | key=operator.itemgetter(1))[0] 596 | # print (' '.join(self.global_itos[x] for x in prev[i][1:])) 597 | new_prev.append(prev[i] + [j]) 598 | new_prev_unks.append(prev_unks[i] + [new_unk]) 599 | new_prev_scores.append(new_scores[i, j]) 600 | # print(i, j, new_scores[i,j]) 601 | idxs.append(j) 602 | # print('newog', new_origins_this_round) 603 | # print(new_sizes) 604 | # print('idxs') 605 | # print(idxs) 606 | # for i, p in enumerate(prev): 607 | # print(i, end= ' ') 608 | # print([self.global_itos[x] for x in p], end=' ') 609 | # print(list(reversed([(self.global_itos[x], new_scores[i][x]) for x in np.argsort(new_scores[i])[-10:]]))) 610 | new_sizes = np.bincount(new_origins, minlength=len(prev)) 611 | new_sizes = [int(x) for x in new_sizes] 612 | nsizes_this_round = np.bincount(new_origins_this_round, minlength=len(prev)) 613 | nsizes_this_round = [int(x) for x in nsizes_this_round] 614 | # global_scores = np.zeros((len(prev), (len(self.global_itos)))) 615 | zaaa = time.time() 616 | ndec_states = 
copy.deepcopy(decoder_states) 617 | all_stuff = zip( 618 | self.back_translators, self.vocab_mappers, 619 | self.back_vocab_mappers, self.vocab_unks, contexts, 620 | ndec_states, encoder_states, mappings) 621 | if len(new_this_round): 622 | # print(out_scores) 623 | for (b, mapper, back_mapper, unks, context, 624 | decStates, encStates, mapping) in all_stuff: 625 | nsizes = nsizes_this_round 626 | # print('new b') 627 | for i, next_ in zip(after_ids, after_ids[1:]): 628 | # print(self.global_itos[i], self.global_itos[next_]) 629 | idx = [int(back_mapper[i]) for _ in new_this_round] 630 | # print(len(nsizes_this_round)) 631 | # print(len(idx), sum(nsizes_this_round)) 632 | # print(nsizes) 633 | n = int(back_mapper[next_]) 634 | # decStates = copy.deepcopy(decStates) 635 | out, decStates, attn = b.advance_states( 636 | encStates, context, decStates, idx, nsizes) 637 | attenz = attn['std'].data[0].cpu().numpy() 638 | chosen = np.argmax(attenz, axis=1) 639 | for r, ch in enumerate(chosen): 640 | ch = mapping[ch] 641 | if ch in b.vocab().stoi: 642 | ind = b.vocab().stoi[ch] 643 | out[r, ind] = max(out[r, ind], out[r, onmt.IO.UNK]) 644 | nsizes = [1 for _ in new_this_round] 645 | for r in range(out.shape[0]): 646 | out_scores[new_this_round[r]] += out[r, n] / float(len(self.back_translators)) 647 | # print('ae') 648 | # print(nsizes) 649 | going_after += time.time() - zaaa 650 | 651 | prev = new_prev 652 | prev_unks = new_prev_unks 653 | # print('prev', prev) 654 | prev_scores = new_prev_scores 655 | # print("HIHFSD", prev_scores[2]) 656 | 657 | # new_sizes = [] 658 | # idxs = [] 659 | # return [] 660 | if threshold: 661 | threshold = orig_threshold + orig_score 662 | if verbose: 663 | print('first ', feed_original ) 664 | print('between ', in_between) 665 | print('going after', going_after) 666 | print('total after', feed_original + in_between + going_after) 667 | # return [x for x in sorted(out_scores.items(), key=lambda x: x[1], reverse=True)] 668 | # threshold = -99999999 669 | return [x for x in sorted(out_scores.items(), key=lambda x: x[1], reverse=True) if x[1] > threshold] 670 | return [] 671 | key_order = list(out_scores.keys()) 672 | best = -9999999 673 | for dec_idx, (to, back, mapper, back_mapper, encStates, context, mapping) in enumerate(zip( 674 | self.to_translators, self.back_translators, self.vocab_mappers, 675 | self.back_vocab_mappers, encoder_states, contexts, mappings)): 676 | for i, next_ in zip(after_ids[1:], after_ids[2:]): 677 | idx = [int(back_mapper[i])] 678 | n = int(back_mapper[next_]) 679 | for key in key_order: 680 | decStates = out_dec_states[key][dec_idx] 681 | new_sizes = out_new_sizes[key] 682 | out, decStates, attn = back.advance_states(encStates, context, 683 | decStates, idx, new_sizes) 684 | attenz = attn['std'].data[0].cpu().numpy() 685 | chosen = np.argmax(attenz, axis=1) 686 | for r, ch in enumerate(chosen): 687 | ch = mapping[ch] 688 | if ch in back.vocab().stoi: 689 | # print("YOO") 690 | ind = back.vocab().stoi[ch] 691 | # print("prev", out[r, ind]) 692 | out[r, ind] = max(out[r, ind], out[r, onmt.IO.UNK]) 693 | out_scores[key] += out[0, n] 694 | best = max(out[0, n], best) 695 | out_dec_states[key][dec_idx] = decStates 696 | if threshold: 697 | threshold = best + orig_threshold 698 | key_order = [k for k, v in out_scores if v > threshold] 699 | 700 | print(sorted(out_scores.items(), key=lambda x: x[1], reverse=True)) 701 | return [] 702 | 703 | if run_through: 704 | orig_ids = np.array([self.global_stoi[x] if x in self.global_stoi else onmt.IO.UNK 
for x in words_after] + 705 | [self.global_stoi[onmt.IO.EOS_WORD]]) 706 | for to, back, mapper, back_mapper, encStates, context, decStates, mapping in zip( 707 | self.to_translators, self.back_translators, self.vocab_mappers, self.back_vocab_mappers, 708 | enc_states, contexts, dec_states, mappings): 709 | if not picked.shape[0]: 710 | break 711 | idx = [int(back_mapper[x]) for x in picked] 712 | # print(idx) 713 | # print(idx) 714 | # print([self.global_itos[x] for x in picked]) 715 | # idx = int(back_mapper[x]) 716 | n = int(back_mapper[orig_ids[0]]) 717 | # print(n, back.vocab().itos[n]) 718 | out, decStates, attn = back.advance_states(encStates, context, 719 | decStates, idx, [len(idx)]) 720 | 721 | attenz = attn['std'].data[0].cpu().numpy() 722 | chosen = np.argmax(attenz, axis=1) 723 | for r, ch in enumerate(chosen): 724 | ch = mapping[ch] 725 | if ch in back.vocab().stoi: 726 | # print("YOO") 727 | ind = back.vocab().stoi[ch] 728 | # print("prev", out[r, ind]) 729 | out[r, ind] = max(out[r, ind], out[r, onmt.IO.UNK]) 730 | # print(n, out[:, n]) 731 | global_scores[picked] += out[:, n] 732 | sizes = [1 for _ in range(topk)] 733 | if threshold: 734 | to_keep = np.where(global_scores[picked] >= threshold)[0] 735 | to_delete = np.where(global_scores[picked] < threshold)[0] 736 | picked = picked[to_keep] 737 | print(picked.shape) 738 | for x in to_delete: 739 | sizes[x] = 0 740 | topk = picked.shape[0] 741 | if not topk: 742 | break 743 | # global_scores[back_mapper] += out[:, n] 744 | for i, next_ in zip(orig_ids, orig_ids[1:]): 745 | idx = [int(back_mapper[i]) for _ in range(topk)] 746 | n = int(back_mapper[next_]) 747 | out, decStates, attn = back.advance_states(encStates, context, 748 | decStates, idx, sizes) 749 | attenz = attn['std'].data[0].cpu().numpy() 750 | chosen = np.argmax(attenz, axis=1) 751 | for r, ch in enumerate(chosen): 752 | ch = mapping[ch] 753 | if ch in back.vocab().stoi: 754 | # print("YOO") 755 | ind = back.vocab().stoi[ch] 756 | # print("prev", out[r, ind]) 757 | out[r, ind] = max(out[r, ind], out[r, onmt.IO.UNK]) 758 | # print("aft", out[r, ind]) 759 | # print(np.argsort(out[0])[-5:]) 760 | global_scores[picked] += out[:, n] 761 | # print(n, out[:, n]) 762 | sizes = [1 for _ in range(topk)] 763 | if threshold: 764 | to_keep = np.where(global_scores[picked] >= threshold)[0] 765 | to_delete = np.where(global_scores[picked] < threshold)[0] 766 | # print(picked.shape) 767 | picked = picked[to_keep] 768 | # print(picked.shape) 769 | topk = picked.shape[0] 770 | for x in to_delete: 771 | sizes[x] = 0 772 | # print(topk) 773 | if not topk: 774 | break 775 | global_scores /= len(self.back_translators) 776 | last_scores /= len(self.back_translators) 777 | # TODO: there may be duplicates because of new_Unk 778 | ret = [(self.global_itos[z], global_scores[z]) if z != onmt.IO.UNK else (new_unk, global_scores[z]) for z in picked if self.global_itos[z] != onmt.IO.EOS_WORD] 779 | return sorted(ret, key=lambda x: x[1], reverse=True) 780 | return global_scores, new_unk 781 | print() 782 | print(list(reversed([self.global_itos[x] for x in np.argsort(global_scores)[-100:]]))) 783 | pass 784 | def generate_paraphrases(self, sentence, topk=10, threshold=None, edit_distance_cutoff=None, penalize_unks=True, frequent_ngrams=None): 785 | # returns a list of (sentence, score). 
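        # Roughly: the input is translated once into each pivot language (keeping the
        # forward translation that back-scores the original best), then the back
        # translators are advanced token by token over the shared global vocabulary.
        # `prev` holds the beam of candidate prefixes; their scores are log-likelihoods
        # averaged over the back translators, pruned by `topk` and/or `threshold`
        # (interpreted relative to the original sentence's own score), and a candidate
        # is added to `output` whenever it generates EOS.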
786 | assert threshold or topk 787 | encoder_states = [] 788 | contexts = [] 789 | decoder_states = [] 790 | new_sizes = [] 791 | src_examples = [] 792 | PROFILING = False 793 | mappings = [] 794 | to_add = -10000 if penalize_unks else 0 795 | 796 | for to, back in zip(self.to_translators, self.back_translators): 797 | translation, mapping = choose_forward_translation(sentence, to, back, n=5) 798 | mappings.append(mapping) 799 | encStates, context, decStates, src_example = back.get_init_states(translation) 800 | src_examples.append(src_example) 801 | encoder_states.append(encStates) 802 | contexts.append(context) 803 | decoder_states.append(decStates) 804 | new_sizes.append([1]) 805 | orig_score = self.score_sentences(sentence, [sentence])[0] 806 | if threshold: 807 | threshold = threshold + orig_score 808 | 809 | # Always include original sentence in this todo: no!!! 810 | orig = onmt_model.clean_text(sentence) 811 | output = {} 812 | orig_ids = np.array([self.global_stoi[onmt.IO.BOS_WORD]] + [self.global_stoi[x] if x in self.global_stoi else onmt.IO.UNK for x in orig.split()]) 813 | orig_words = [onmt.IO.BOS_WORD] + orig.split() 814 | orig_itoi = {} 815 | orig_stoi = {} 816 | for i, w in zip(orig_ids, orig_words): 817 | # if i not in orig_itoi: 818 | # idx = len(orig_itoi) 819 | # orig_itoi[i] = idx 820 | if w not in orig_stoi: 821 | idx = len(orig_stoi) 822 | orig_stoi[w] = idx 823 | if i not in orig_itoi: 824 | orig_itoi[i] = idx 825 | # print(sorted([(x, k) for x, k in orig_stoi.items()])) 826 | # print() 827 | # print(sorted([(self.global_itos[x], k) for x, k in orig_itoi.items()])) 828 | not_in_sentence = np.array( 829 | list(set(self.global_stoi.values()).difference( 830 | set(list(orig_itoi.keys()) + [onmt.IO.UNK])))) 831 | mapped_orig = [orig_stoi[x] for x in orig_words] 832 | if frequent_ngrams is not None: 833 | import difflib 834 | new_f = set() 835 | new_f.add(tuple()) 836 | for f, v in frequent_ngrams.items(): 837 | for t in v: 838 | new_f.add(tuple(sorted([orig_stoi[x] for x in t]))) 839 | prev = [[self.global_stoi[onmt.IO.BOS_WORD]]] 840 | prev_scores = [0] 841 | prev_distance_rep = [[orig_itoi[prev[0][0]]]] 842 | idxs = [self.global_stoi[onmt.IO.BOS_WORD]] 843 | new_sizes = [1] 844 | prev_unks = [['']] 845 | import time 846 | while prev: 847 | orig_time = time.time() 848 | global_scores = np.zeros((len(prev), (len(self.global_itos)))) 849 | # print(global_scores.shape) 850 | all_stuff = zip( 851 | self.back_translators, self.vocab_mappers, 852 | self.back_vocab_mappers, self.vocab_unks, contexts, 853 | decoder_states, encoder_states, src_examples, mappings) 854 | new_decoder_states = [] 855 | new_attns = [] 856 | unk_scores = [] 857 | # print() 858 | for (b, mapper, back_mapper, unks, context, 859 | decStates, encStates, srcz, mapping) in all_stuff: 860 | idx = [int(back_mapper[i]) for i in idxs] 861 | out, decStates, attn = b.advance_states( 862 | encStates, context, decStates, idx, new_sizes) 863 | # print(list(reversed([(b.vocab().itos[x], out[0, x]) for x in np.argsort(out[0])[-5:]]))) 864 | new_decoder_states.append(decStates) 865 | new_attns.append(attn) 866 | attenz = attn['std'].data[0].cpu().numpy() 867 | chosen = np.argmax(attenz, axis=1) 868 | for r, ch in enumerate(chosen): 869 | ch = mapping[ch] 870 | if ch in b.vocab().stoi: 871 | # print("YOO") 872 | ind = b.vocab().stoi[ch] 873 | # print("prev", out[r, ind]) 874 | out[r, ind] = max(out[r, ind], out[r, onmt.IO.UNK]) 875 | # print("aft", out[r, ind]) 876 | elif ch in self.global_stoi: 877 | # print(ch) 878 | 
ind = self.global_stoi[ch] 879 | global_scores[r, ind] -= to_add 880 | # if ch == 'giraffes': 881 | # print(ind in unks, 30027 in unks, ind ==30027) 882 | # break 883 | # print(list(reversed([(b.vocab().itos[x], out[0, x]) for x in np.argsort(out[0])[-5:]]))) 884 | # print('ya', [mapping[bb] for bb in chosen]) 885 | # print('ya', attenz) 886 | # print('ya', np.argmax(attenz, axis=1)) 887 | # print(out[:, onmt.IO.UNK]) 888 | # print("AEEE", 30027 in unks) 889 | unk_scores.append(out[:, onmt.IO.UNK]) 890 | global_scores[:, mapper] += out 891 | # print(global_scores[:, 30027]) 892 | if unks.shape[0]: 893 | # global_scores[:, unks] += out[:, onmt.IO.UNK][:, np.newaxis] 894 | global_scores[:, unks] += to_add + out[:, onmt.IO.UNK][:, np.newaxis] 895 | # print(global_scores[:, 30027]) 896 | # print() 897 | if PROFILING: 898 | print(time.time() - orig_time, 'decoding') 899 | orig_time = time.time() 900 | decoder_states = new_decoder_states 901 | global_scores /= float(len(self.back_translators)) 902 | # TODO: Is this right? 903 | unk_scores = [normalize_ll(x) for x in np.array(unk_scores).T] 904 | if PROFILING: 905 | print(time.time() - orig_time, 'normalizing unk scoers') 906 | orig_time = time.time() 907 | # print(unk_scores) 908 | # print(global_scores[0, 0], global_scores[0, 7109], global_scores.max()) 909 | # print(sorted([(self.global_itos[x], global_scores[0, x]) for x in np.argpartition(global_scores[0], -5)[-5:]], key=lambda x:x[1], reverse=True)) 910 | # break 911 | # for b, back_mapper in zip(self.back_translators, self.back_vocab_mappers): 912 | # print([(b.vocab().itos[back_mapper[x]], global_scores[x]) for x in np.argsort(global_scores)[-5:]]) 913 | 914 | new_prev = [] 915 | new_prev_distance_rep = [] 916 | new_prev_unks = [] 917 | new_prev_scores = [] 918 | new_sizes = [] 919 | new_origins = [] 920 | idxs = [] 921 | if PROFILING: 922 | print(time.time() - orig_time, 'before adding scores') 923 | orig_time = time.time() 924 | new_scores = global_scores + np.array(prev_scores)[:, np.newaxis] 925 | if PROFILING: 926 | print(time.time() - orig_time, 'adding scores') 927 | orig_time = time.time() 928 | # print(new_scores.shape) 929 | def get_possibles(opcodes): 930 | possibles = [tuple()] 931 | for tag, i1, i2, j1, j2 in opcodes: 932 | if tag == 'equal': 933 | continue 934 | if tag == 'insert': 935 | cha = range(j1, j2) 936 | if len(cha) == 2: 937 | possibles.append([i1 - 1]) 938 | possibles.append([i1]) 939 | if len(cha) > 2: 940 | possibles = [] 941 | break 942 | if len(cha) == 1: 943 | possibles.append([i1 - 1, i1]) 944 | if tag == 'replace': 945 | for i1, j1 in zip_longest(range(i1, i2), range(j1, j2)): 946 | if i1 is None: 947 | i1 = i2# - 1 948 | possibles.append([i1 - 1, i1]) 949 | elif j1 is None: 950 | possibles.append([i1]) 951 | else: 952 | possibles.append([i1]) 953 | if tag == 'delete': 954 | for i1 in range(i1, i2): 955 | possibles.append([i1]) 956 | if len(possibles) > 1: 957 | # print(possibles) 958 | possibles.pop(0) 959 | # print(possibles) 960 | return possibles 961 | if frequent_ngrams is not None: 962 | for i, p_rep in enumerate(prev_distance_rep): 963 | for idx, v in orig_itoi.items(): 964 | # I'm ignoring UNKs here and letting them be fixed in the next iteration 965 | if idx == onmt.IO.UNK: 966 | continue 967 | candidate = p_rep + [v] 968 | # import difflib 969 | a = difflib.SequenceMatcher(a = mapped_orig[:len(candidate)], b=candidate) 970 | possibles = get_possibles(a.get_opcodes()) 971 | if len(possibles) == 1 and possibles[0] == tuple(): 972 | continue 973 | if 
not np.any([x in new_f for x in itertools.product(*possibles)]): 974 | # if distance > edit_distance_cutoff: 975 | # pass 976 | # print (possibles) 977 | # print("not allowing", [orig_words[x] if x != -1 else 'unk' for x in candidate]) 978 | new_scores[i, idx] = -100000 979 | candidate = p_rep + [-1] 980 | a = difflib.SequenceMatcher(a = mapped_orig[:len(candidate)], b=candidate) 981 | possibles = get_possibles(a.get_opcodes()) 982 | if not np.any([x in new_f for x in itertools.product(*possibles)]): 983 | new_scores[i, not_in_sentence] = -10000 984 | if edit_distance_cutoff is not None: 985 | for i, p_rep in enumerate(prev_distance_rep): 986 | for idx, v in orig_itoi.items(): 987 | # I'm ignoring UNKs here and letting them be fixed in the next iteration 988 | if idx == onmt.IO.UNK: 989 | continue 990 | candidate = p_rep + [v] 991 | distance = editdistance.eval(candidate, mapped_orig[:len(candidate)]) 992 | 993 | if distance > edit_distance_cutoff: 994 | new_scores[i, idx] = -100000 995 | candidate = p_rep + [-1] 996 | distance = editdistance.eval(candidate, mapped_orig[:len(candidate)]) 997 | if distance > edit_distance_cutoff: 998 | new_scores[i, not_in_sentence] = -10000 999 | if PROFILING: 1000 | print(time.time() - orig_time, 'edit distance cutoff') 1001 | orig_time = time.time() 1002 | if threshold: 1003 | where = np.where(new_scores > threshold) 1004 | if PROFILING: 1005 | print(time.time() - orig_time, 'thresholding') 1006 | orig_time = time.time() 1007 | if topk: 1008 | # print(where) 1009 | # print(new_scores[where]) 1010 | # print(threshold) 1011 | # print(new_scores[where].shape) 1012 | largest = largest_indices(new_scores[where], topk)[0] 1013 | where = (where[0][largest], where[1][largest]) 1014 | # print(where) 1015 | # print(new_scores[where]) 1016 | # print(where) 1017 | else: 1018 | # print(new_scores.shape) 1019 | where = largest_indices(new_scores, topk) 1020 | 1021 | if PROFILING: 1022 | print(time.time() - orig_time, 'topk') 1023 | orig_time = time.time() 1024 | tmp = np.argsort(where[0]) 1025 | where = (where[0][tmp], where[1][tmp]) 1026 | # TODO: Is this right? 1027 | if (edit_distance_cutoff is not None and 1028 | len(prev[0]) < len(orig_ids) and 1029 | orig_ids[len(prev[0])] not in where[1][where[0] == 0]): 1030 | where = (np.hstack(([0], where[0])), 1031 | np.hstack(([orig_ids[len(prev[0])]], where[1]))) 1032 | # print(where[0].shape) 1033 | # print(where) 1034 | # Where needs to be sorted by i, since idxs must be in order of 1035 | # where stuff came from 1036 | for i, j in zip(*where): 1037 | if j == self.global_stoi[onmt.IO.EOS_WORD]: 1038 | words = [self.global_itos[x] if x != onmt.IO.UNK 1039 | else prev_unks[i][k] 1040 | for k, x in enumerate(prev[i][1:], start=1)] 1041 | new = ' '.join(words) 1042 | if new not in output: 1043 | output[new] = new_scores[i, j] 1044 | else: 1045 | output[new] = max(output[new], new_scores[i, j]) 1046 | # if topk: 1047 | # topk -= 1 1048 | continue 1049 | new_origins.append(i) 1050 | # print (' '.join(self.global_itos[x] for x in prev[i][1:] + [j])) 1051 | new_unk = '' 1052 | if j == onmt.IO.UNK: 1053 | # print(i, j, new_attns[0]['std'].data.shape) 1054 | new_unk_scores = collections.defaultdict(lambda: 0) 1055 | for x, src, mapping, score_weight in zip(new_attns, src_examples, mappings, unk_scores[i]): 1056 | # print(src) 1057 | attn = x['std'].data[0][i] 1058 | # TODO: Should we only allow unks here? We are 1059 | # currently weighting based on the original score, but 1060 | # this makes it so one always chooses the unk. 
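                    # As in suggest_next: each attended source word, weighted by this
                    # candidate's normalized UNK score for the corresponding model,
                    # votes for a surface form; the argmax is kept as new_unk and is
                    # substituted for <unk> when the finished candidate is joined.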
1061 | for zidx, (word, score) in enumerate(zip(src, attn)): 1062 | # if b.vocab().stoi[word] == onmt.IO.UNK: 1063 | word = mapping[zidx] 1064 | new_unk_scores[word] += score * score_weight 1065 | # print(word, score, score_weight ) 1066 | # print(sorted(new_unk_scores.items(), key=lambda x:x[1], reverse=True)) 1067 | new_unk = max(new_unk_scores.items(), 1068 | key=operator.itemgetter(1))[0] 1069 | # _, max_index = attn.max(0) 1070 | # # print(new_attns[0]['std'].data[0][i] + new_attns[1]['std'].data[0][i]) 1071 | # max_index = int(max_index[0]) 1072 | # # print(src_examples[0][max_index]) 1073 | # # print(len(src_examples[0])) 1074 | # new_unk = src_examples[0][max_index] 1075 | # print (' '.join(self.global_itos[x] for x in prev[i][1:]) + ' ' + new_j) 1076 | # new_unk = 'unk' 1077 | 1078 | if edit_distance_cutoff is not None: 1079 | distance_rep = orig_itoi[j] if j in orig_itoi else -1 1080 | if j == onmt.IO.UNK: 1081 | distance_rep = (orig_stoi[new_unk] if new_unk in orig_stoi 1082 | else -1) 1083 | # print('d', distance_rep) 1084 | new_prev_distance_rep.append(prev_distance_rep[i] + 1085 | [distance_rep]) 1086 | new_prev.append(prev[i] + [j]) 1087 | new_prev_unks.append(prev_unks[i] + [new_unk]) 1088 | new_prev_scores.append(new_scores[i, j]) 1089 | idxs.append(j) 1090 | # print(idxs) 1091 | 1092 | # print(mapped_orig[:len(prev[0])+ 1]) 1093 | # print(new_prev_distance_rep) 1094 | if PROFILING: 1095 | print(time.time() - orig_time, 'processing where') 1096 | orig_time = time.time() 1097 | new_sizes = np.bincount(new_origins, minlength=len(prev)) 1098 | new_sizes = [int(x) for x in new_sizes] 1099 | prev = new_prev 1100 | prev_unks = new_prev_unks 1101 | prev_distance_rep = new_prev_distance_rep 1102 | # print('prev', prev) 1103 | prev_scores = new_prev_scores 1104 | if topk and len(output) == topk: 1105 | break 1106 | # for z, unks, dr in zip(prev, prev_unks, prev_distance_rep): 1107 | for z, s, unks in zip(prev, prev_scores, prev_unks): 1108 | # import editdistance 1109 | # # TODO: Must ignore unks here - it's fine if I think it's an unk if the text is the same 1110 | # z = np.array(z) 1111 | # non_unk = np.where(z != onmt.IO.UNK)[0] 1112 | # # print(list(zip(z[non_unk], orig_ids[non_unk]))) 1113 | # distance = editdistance.eval(z[non_unk], orig_ids[non_unk]) 1114 | # bla = list(zip(*([(unks[i], orig_words[i]) for i in range(len(z)) if z[i] == onmt.IO.UNK]))) 1115 | # unk_distance = 0 1116 | # if bla: 1117 | # unks1, unks2 = bla 1118 | # unk_distance = editdistance.eval(unks1, unks2) 1119 | words = [self.global_itos[x] if x != onmt.IO.UNK else 'UNK'+ unks[k] for k, x in enumerate(z)] 1120 | # print(words, s, prev_unks) 1121 | # d2 = editdistance.eval(dr, mapped_orig[:len(dr)]) 1122 | # # print(words, distance, unk_distance, 'd2', d2, dr, mapped_orig[:len(dr)]) 1123 | # print() 1124 | if PROFILING: 1125 | print(time.time() - orig_time, 'rest') 1126 | orig_time = time.time() 1127 | 1128 | return sorted(output.items(), key=lambda x:x[1], reverse=True) 1129 | # print 'generate_paraphrases', n 1130 | # all_generated = [] 1131 | # for to, back in zip(self.to_translators, self.back_translators): 1132 | # translation = choose_forward_translation(sentence, to, back, n=5) 1133 | # all_generated.extend(back.translate([translation], n_best=n)[0]) 1134 | # all_generated = list(set([x[0].encode('ascii', 'ignore').decode() for x in all_generated if x[0]])) 1135 | # scores = self.score_sentences(sentence, all_generated) 1136 | # return sorted(zip(all_generated, scores), key=lambda x: x[1], 1137 | # 
reverse=True) 1138 | 1139 | def test_translators(self, sentence): 1140 | print('original:', sentence) 1141 | print() 1142 | for to, back in zip(self.to_translators, self.back_translators): 1143 | # translation = to.translate([sentence], n_best=1)[0][0][0] 1144 | translation, mapping = choose_forward_translation(sentence, to, back, n=5) 1145 | print(translation) 1146 | b = back.translate([translation], n_best=1)[0][0] 1147 | print(b) 1148 | print(sentence) 1149 | print('score_original:', back.score(translation, [sentence])) 1150 | print() 1151 | 1152 | def score_sentences(self, original_sentence, other_sentences, relative_to_original=False, verbose=False): 1153 | memoized_stuff = self.last == original_sentence 1154 | if relative_to_original: 1155 | other_sentences = [original_sentence] + other_sentences 1156 | all_scores = [] 1157 | if not memoized_stuff: 1158 | self.last = original_sentence 1159 | self.memoized = {} 1160 | self.memoized['translation'] = [] 1161 | if verbose: 1162 | score_to_print = [] 1163 | for k, (to, back, mapper, back_mapper, unks) in enumerate( 1164 | zip(self.to_translators, self.back_translators, 1165 | self.vocab_mappers, self.back_vocab_mappers, 1166 | self.vocab_unks)): 1167 | if memoized_stuff: 1168 | translation, mapping = self.memoized['translation'][k] 1169 | else: 1170 | translation, mapping = choose_forward_translation(original_sentence, to, back, 1171 | n=5) 1172 | self.memoized['translation'].append((translation, mapping)) 1173 | this_scores = [] 1174 | if verbose: 1175 | scorezz = [] 1176 | for s in other_sentences: 1177 | s = onmt_model.clean_text(s) 1178 | orig_ids = np.array([self.global_stoi[onmt.IO.BOS_WORD]] + [self.global_stoi[x] if x in self.global_stoi else onmt.IO.UNK for x in s.split()] + [self.global_stoi[onmt.IO.EOS_WORD]]) 1179 | score = 0. 
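                # Force-decode this candidate through the back translator: feed the
                # previous token, add the log-probability of the next one to `score`,
                # and raise the attended source word's probability to at least the
                # <unk> probability (a cheap copy-style fallback for OOV words).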
1180 | encStates, context, decStates, src_example = ( 1181 | back.get_init_states(translation)) 1182 | for i, n in zip(orig_ids, orig_ids[1:]): 1183 | idx = int(back_mapper[i]) 1184 | n = int(back_mapper[n]) 1185 | out, decStates, attn = back.advance_states(encStates, context, 1186 | decStates, [idx], [1]) 1187 | attenz = attn['std'].data[0].cpu().numpy() 1188 | ch = np.argmax(attenz, axis=1)[0] 1189 | ch = mapping[ch] 1190 | if ch in back.vocab().stoi: 1191 | ind = back.vocab().stoi[ch] 1192 | out[0, ind] = max(out[0, ind], out[0, onmt.IO.UNK]) 1193 | score += out[0][n] 1194 | if verbose: 1195 | scorezz.append( (self.global_itos[n], out[0][n])) 1196 | this_scores.append(score) 1197 | if verbose: 1198 | scorezz.append(('\n', score)) 1199 | if verbose: 1200 | score_to_print.append(scorezz) 1201 | all_scores.append(this_scores) 1202 | scores = np.mean(all_scores, axis=0) 1203 | if verbose: 1204 | for z in zip(*score_to_print): 1205 | print('%-10s'% z[0][0], end=' ') 1206 | for a in z: 1207 | print('%.2f' % a[1], end=' ') 1208 | print() 1209 | if relative_to_original: 1210 | scores = (scores - scores[0])[1:] 1211 | return scores 1212 | def score_sentences_old(self, original_sentence, other_sentences, 1213 | relative_to_original=False): 1214 | # returns a numpy array of scores, one for each sentence in 1215 | # other_sentences 1216 | memoized_stuff = self.last == original_sentence 1217 | if not memoized_stuff: 1218 | self.last = original_sentence 1219 | self.memoized = {} 1220 | self.memoized['translation'] = [] 1221 | all_scores = [] 1222 | if relative_to_original: 1223 | other_sentences = [original_sentence] + other_sentences 1224 | for k, (to, back) in enumerate(zip(self.to_translators, self.back_translators)): 1225 | if memoized_stuff: 1226 | translation, mapping = self.memoized['translation'][k] 1227 | else: 1228 | translation, mapping = choose_forward_translation(original_sentence, to, back, 1229 | n=5) 1230 | self.memoized['translation'].append((translation, mapping)) 1231 | # if I want to pivot over multiple translations, this is how to do it: 1232 | # trs = to.translate([original_sentence], n_best=5)[0] 1233 | # translations = [x[0] for x in trs] 1234 | # weights = np.array([x[1] for x in trs]) 1235 | # # normalizing using exp-normalize trick 1236 | # weights = np.exp(weights - weights.max()) 1237 | # weights = weights / weights.sum() 1238 | # # print 'w',weights 1239 | # temp_scores = [] 1240 | # for t in translations: 1241 | # scores = back.score(t, other_sentences) 1242 | # temp_scores.append(scores) 1243 | # scores = np.average(temp_scores, axis=0, weights=weights) 1244 | 1245 | scores = back.score(translation, other_sentences) 1246 | all_scores.append(scores) 1247 | scores = np.mean(all_scores, axis=0) 1248 | if relative_to_original: 1249 | scores = (scores - scores[0])[1:] 1250 | return scores 1251 | 1252 | def weighted_scores(self, original_sentence, other_sentences, 1253 | relative_to_original=False): 1254 | scores = self.score_sentences(original_sentence, other_sentences) 1255 | self_scores = [] 1256 | for s in other_sentences: 1257 | self_scores.append(self.score_sentences(s, [s])[0]) 1258 | self_scores = np.array(self_scores) 1259 | elementwise_max = np.maximum(scores, self_scores) 1260 | n_scores = np.exp(scores - elementwise_max) 1261 | n_self_scores = np.exp(self_scores - elementwise_max) 1262 | return np.log(n_scores / (n_scores + n_self_scores)) 1263 | -------------------------------------------------------------------------------- /seada/sea/replace_rules.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import difflib 3 | import itertools 4 | import numpy as np 5 | import re 6 | import enum 7 | import collections 8 | import copy 9 | import sys 10 | 11 | PYTHON3 = sys.version_info > (3, 0) 12 | if PYTHON3: 13 | from itertools import zip_longest as zip_longest 14 | unicode = lambda x: x 15 | else: 16 | from itertools import izip_longest as zip_longest 17 | 18 | # add_after(x, a): z x b -> z x a b 19 | # remove_after(x, a): z x a -> z x 20 | # replace_after(x, y, r): a x y -> a x r 21 | # add before (x, a): z x -> z a x 22 | # remove_before(x, y): y x -> x 23 | # replace_before(x, y, r) : y x z -> r x z 24 | # replace (x, r): a x -> a r 25 | # remove(x): a x y -> a y 26 | # add between (x, y, a): z x y -> z x a y 27 | # remove_between(x, y, z): x z y a -> x y a 28 | # replace_between(x, y, z, r): x z y a -> x r y a 29 | 30 | def clean_text(text, only_upper=False): 31 | # should there be a str here?` 32 | text = '%s%s' % (text[0].upper(), text[1:]) 33 | if only_upper: 34 | return text 35 | text = text.replace('|', 'UNK') 36 | text = re.sub('(^|\s)-($|\s)', r'\1@-@\2', text) 37 | # text = re.sub(' (n?\'.) ', r'\1 ', text) 38 | # fix apostrophe stuff according to tokenizer 39 | text = re.sub(' (n)(\'.) ', r'\1 \2 ', text) 40 | return text 41 | 42 | 43 | def largest_indices(ary, n): 44 | """Returns the n largest indices from a numpy array.""" 45 | flat = ary.flatten() 46 | if n > flat.shape[0]: 47 | indices = np.array(range(flat.shape[0]), dtype='int') 48 | return np.unravel_index(indices, ary.shape) 49 | indices = np.argpartition(flat, -n)[-n:] 50 | indices = indices[np.argsort(-flat[indices])] 51 | return np.unravel_index(indices, ary.shape) 52 | 53 | class OpToken: 54 | def __init__(self, type_, value): 55 | self.type = type_ 56 | self.value = value 57 | 58 | def test(self, token): 59 | if self.type == 'text': 60 | return token.text == self.value 61 | if self.type == 'pos': 62 | return token.pos == self.value 63 | if self.type == 'tag': 64 | return token.tag == self.value 65 | 66 | def hash(self): 67 | return self.type + '_' + self.value 68 | 69 | Token = collections.namedtuple('Token', ['text', 'pos', 'tag']) 70 | 71 | def capitalize(text): 72 | if len(text) == 0: 73 | return text 74 | if len(text) == 1: 75 | return text.upper() 76 | else: 77 | return '%s%s' % (text[0].upper(), text[1:]) 78 | 79 | class Tokenizer: 80 | def __init__(self, nlp): 81 | self.nlp = nlp 82 | 83 | def tokenize(self, texts): 84 | ret = [] 85 | processed = self.nlp.pipe(texts) 86 | for text, p in zip(texts, processed): 87 | token_sequence = [Token(x.text, x.pos_, x.tag_) for x in p] 88 | ret.append(token_sequence) 89 | return ret 90 | def tokenize_text(self, texts): 91 | return [' '.join([a.text for a in x]) for x in self.nlp.tokenizer.pipe(texts)] 92 | 93 | def clean_for_model(self, texts): 94 | fn = lambda x: re.sub(r'\s+', ' ', re.sub(r'\s\'(\w{1, 3})', r"'\1", x).replace('@-@', '-').strip()) 95 | return self.tokenize_text([fn(capitalize(x)) for x in texts]) 96 | 97 | def clean_for_humans(self, texts): 98 | return [re.sub("\s(n')", r'\1', re.sub(r'\s\'(\w)', r"'\1", capitalize(x))) for x in texts] 99 | 100 | 101 | 102 | 103 | class ReplaceRule: 104 | def __init__(self, op_sequence, replace_sequence): 105 | self.op_sequence = op_sequence 106 | self.replace_sequence = replace_sequence 107 | 108 | def apply(self, token_sequence, status_only=False, return_position=False, fix_apostrophe=True): 
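        # Scan the token sequence (padded with empty boundary tokens) for a contiguous
        # match of op_sequence, where each OpToken tests the text, POS or tag; on a
        # match the span is replaced by replace_sequence, with POS/tag slots filled by
        # the words they originally matched.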
109 | # Returns (status, [new_texts]), where status can be False (doesn't apply) 110 | token_sequence = [Token('', '', '')] + token_sequence + [Token('', '', '')] 111 | match_idx = 0 112 | size_seq = len(self.op_sequence) 113 | matched = -1 114 | matched_pos = collections.defaultdict(lambda: []) 115 | for i, t in enumerate(token_sequence): 116 | if self.op_sequence[match_idx].test(t): 117 | if self.op_sequence[match_idx].type == 'pos': 118 | matched_pos[t.pos].append(token_sequence[i].text) 119 | if self.op_sequence[match_idx].type == 'tag': 120 | matched_pos[t.tag].append(token_sequence[i].text) 121 | match_idx += 1 122 | if match_idx == size_seq: 123 | matched = i 124 | break 125 | else: 126 | match_idx = 0 127 | matched_pos = collections.defaultdict(lambda: []) 128 | status = matched > 0 129 | if status_only: 130 | return status 131 | if not status: 132 | if return_position: 133 | return status, '', -1 134 | return status, '' 135 | match_start = matched - size_seq + 1 136 | t_before = [x.text for x in token_sequence[1:match_start]] 137 | t_after = [x.text for x in token_sequence[matched + 1:-1]] 138 | t_mid = [] 139 | for x in self.replace_sequence: 140 | if x.type == 'text': 141 | t_mid.append(x.value) 142 | else: 143 | text = matched_pos[x.value].pop(0) 144 | t_mid.append(text) 145 | # t_mid = [x.text for x in self.replace_sequence] 146 | ret_text = ' '.join(t_before + t_mid + t_after) 147 | if fix_apostrophe: 148 | ret_text = ret_text.replace(' \'', '\'') 149 | if return_position: 150 | return True, ret_text, match_start - 1 151 | return True, ret_text 152 | 153 | def apply_to_texts(self, token_sequences, idxs_only=False, fix_apostrophe=True): 154 | # returns (idxs, new_texts), where 155 | # idxs is the indices where rule applies, and texts is the results text 156 | idxs = [] 157 | new_texts = [] 158 | for i, token_seq in enumerate(token_sequences): 159 | status, ntext = self.apply(token_seq, status_only=idxs_only, fix_apostrophe=fix_apostrophe) 160 | if status: 161 | idxs.append(i) 162 | new_texts.append(ntext) 163 | return np.array(idxs), new_texts 164 | 165 | def hash(self): 166 | return ' '.join([op.hash() for op in self.op_sequence]) + ' -> ' + ' '.join([op.hash() for op in self.replace_sequence]) 167 | 168 | class TextToReplaceRules: 169 | def __init__(self, nlp, from_dataset, flip_dataset=[], min_freq=.01, min_flip=0.01, ngram_size=4): 170 | if len(flip_dataset) != 0: 171 | assert len(from_dataset) == len(flip_dataset) 172 | self.tokenizer = Tokenizer(nlp) 173 | self.min_freq = min_freq * len(from_dataset) 174 | self.min_flip = min_flip * len(flip_dataset) 175 | self.ngram_size = ngram_size 176 | 177 | self.ngram_freq = collections.defaultdict(lambda: 0.) 178 | token_sequences = self.tokenizer.tokenize(from_dataset) 179 | ngram_idxs = collections.defaultdict(lambda: []) 180 | for i, s in enumerate(token_sequences): 181 | positions = self.get_positions(s, ngram_size) 182 | for p in positions: 183 | self.ngram_freq[p] += 1 184 | ngram_idxs[p].append(i) 185 | all_ngrams = list(self.ngram_freq.keys()) 186 | self.ngram_idxs = {} 187 | for ngram in all_ngrams: 188 | self.ngram_idxs[ngram] = set(ngram_idxs[ngram]) 189 | 190 | self.ngram_flip_freq = collections.defaultdict(lambda: 0.) 
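        # Count, for each n-gram, in how many entries of flip_dataset (paraphrases that
        # flipped the model's prediction) it occurs at least once; this count is later
        # compared against min_flip in is_param_ngram_frequent(..., flip=True).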
191 | for i, others in enumerate(flip_dataset): 192 | if i % 1000 == 0: 193 | print(i) 194 | token_sequences = self.tokenizer.tokenize(others) 195 | ngrams_flipped = set() 196 | for s in token_sequences: 197 | positions = self.get_positions(s, ngram_size) 198 | for p in positions: 199 | ngrams_flipped.add(p) 200 | for n in ngrams_flipped: 201 | self.ngram_flip_freq[n] += 1 202 | 203 | 204 | # def is_ngram_frequent(self, ngram): 205 | 206 | def is_param_ngram_frequent(self, ngram, flip=False): 207 | if type(ngram) != list: 208 | ngram = tuple([ngram.hash()]) 209 | else: 210 | ngram = tuple([x.hash() for x in ngram]) 211 | # TODO: This won't work for word groups 212 | if flip: 213 | return self.ngram_flip_freq[ngram] >= self.min_flip 214 | else: 215 | return self.ngram_freq[ngram] >= self.min_freq 216 | 217 | def get_rule_idxs(self, rule): 218 | ngram = rule.op_sequence 219 | ngram = tuple([x.hash() for x in ngram]) 220 | return self.ngram_idxs[ngram] 221 | 222 | 223 | def get_positions(self, tokenized_sentence, ngram_size): 224 | def get_params(token): 225 | return (OpToken('text', token.text), 226 | OpToken('pos', token.pos), 227 | OpToken('tag', token.tag)) 228 | 229 | positions = {} 230 | prev = Token('', '', '') 231 | for i, current in enumerate(tokenized_sentence + [Token('', '', '')]): 232 | for j in range(0, ngram_size): 233 | if i - j < 0: 234 | continue 235 | to_consider = tokenized_sentence[i - j:i + 1] 236 | tokens = [get_params(x) for x in to_consider] 237 | ngrams = [tuple([y.hash() for y in x]) for x in itertools.product(*tokens)] 238 | for ngram in ngrams: 239 | if ngram == tuple(): 240 | continue 241 | positions.setdefault(ngram, i) 242 | return positions 243 | 244 | 245 | def compute_rules(self, sentence, others, use_words=True, use_pos=True, use_tags=False, max_rule_length=3): 246 | # print(sentence) 247 | # print() 248 | # print('\n\n'.join(others)) 249 | # if require_all is false, assume rule independence 250 | def get_params(token): 251 | to_ret = [OpToken('text', token.text)] 252 | if use_pos: 253 | to_ret.append(OpToken('pos', token.pos)) 254 | if use_tags: 255 | to_ret.append(OpToken('tag', token.tag)) 256 | return tuple(to_ret) 257 | 258 | # sentence = clean_text(sentence, only_upper=False) 259 | # others = [clean_text(x) if x else x for x in others] 260 | doc = self.tokenizer.tokenize([unicode(sentence)])[0] 261 | other_docs = self.tokenizer.tokenize([unicode(x) for x in others]) 262 | # print([x.text for x in doc], [x.text for x in other_docs[0]]) 263 | # positions = self.get_positions(doc, self.ngram_size) 264 | 265 | # doc = self.nlp(sentence) 266 | sentence = [x.text for x in doc] 267 | # others = [x.split() for x in others] 268 | # other_docs = list(self.nlp.pipe(others)) 269 | others = [[x.text for x in d] for d in other_docs] 270 | # fns = [] 271 | # if use_words: 272 | # fns.append(self.word_rep_fn) 273 | # if use_pos: 274 | # fns.append(self.pos_rep_fn) 275 | all_rules = [] 276 | n_doc = [Token('', '', '')] + doc + [Token('', '', '')] 277 | for other, other_doc in zip(others, other_docs): 278 | n_other = [Token('', '', '')] + other_doc + [Token('', '', '')] 279 | matcher = difflib.SequenceMatcher(a=sentence, b=other) 280 | ops = ([x for x in matcher.get_opcodes() if x[0] != 'equal']) 281 | if len(ops) == 0: 282 | all_rules.append([]) 283 | continue 284 | start = ops[0][1] + 1 285 | end = ops[-1][2] + 1 286 | start_o = ops[0][3] + 1 287 | end_o = ops[-1][4] + 1 288 | reps = [n_doc[start:end], n_doc[start -1:end], n_doc[start: end + 1], n_doc[start - 1: end 
+ 1]] 289 | withs = [ n_other[start_o: end_o], n_other[start_o - 1: end_o], n_other[start_o: end_o + 1], n_other[start_o - 1: end_o + 1]] 290 | # new = doc[:start - 1] + other_doc[start_o - 1:end_o - 1] + doc[end - 1:] 291 | # if ' '.join([x.text for x in new]) != ' '.join([x.text for x in other_doc]): 292 | # print 'ERROR' 293 | # quit() 294 | # print ' '.join(sentence), ' '.join(other) 295 | # print number_ops 296 | # for tag, i1, i2, j1, j2 in matcher.get_opcodes(): 297 | # si1, si2, sj1, sj2 = i1, i2, j1, j2 298 | # if tag == 'equal': 299 | # continue 300 | # # reps_prev = [fn(prev, is_start=is_start) for fn in fns] 301 | # print('{:7} a[{}:{}] --> b[{}:{}] {!r:>8} --> {!r}'.format(tag, si1, si2, sj1, sj2, sentence[si1:si2], other[sj1:sj2])) 302 | # print() 303 | def check_pos(ops1, ops2): 304 | counter = collections.Counter() 305 | for o in ops2: 306 | if o.type != 'text': 307 | counter[o.value] += 1 308 | for o in ops1: 309 | if o.type != 'text': 310 | counter[o.value] -= 1 311 | most_common = counter.most_common(1) 312 | if len(most_common) == 0 or most_common[0][1] <= 0: 313 | return True 314 | return False 315 | 316 | 317 | 318 | this_rules = [] 319 | other_sentence = ' '.join(other) 320 | for rep, withe in zip(reps, withs): 321 | if len(rep) > self.ngram_size or len(rep) == 0: 322 | continue 323 | tokens = [get_params(x) for x in rep] 324 | ngrams = [[y for y in x] for x in itertools.product(*tokens)] 325 | tokens_o = [get_params(x) if x in rep else (OpToken('text', x.text),) for x in withe] 326 | ngrams_o = [[y for y in x] for x in itertools.product(*tokens_o)] 327 | 328 | # print(ngrams) 329 | # print 330 | # print(ngrams_o) 331 | frequent = [x for x in ngrams if self.is_param_ngram_frequent(x)] 332 | # print 'frequent other' 333 | frequent_other = [x for x in ngrams_o if self.is_param_ngram_frequent(x, flip=True) or x == []] 334 | rules = [ReplaceRule(a, b) for a, b in (itertools.product(frequent, frequent_other)) if check_pos(a, b)] 335 | # if len(rules): 336 | # print() 337 | # print('rep(', ' '.join([x.text for x in rep]), ',', ' '.join([x.text for x in withe]), ')') 338 | # print('frequent') 339 | # print('\n'.join([r.hash() for r in rules])) 340 | # print("YO") 341 | # print(rules[0].apply(doc)[1], other_sentence) 342 | rules = [r for r in rules if r.apply(doc, fix_apostrophe=False)[1] == other_sentence] 343 | # print('\n'.join([r.hash() for r in rules])) 344 | # for r in rules: 345 | # if r.apply(doc)[1] != ' '.join(other): 346 | # print 'ERROR' 347 | # print r.hash() 348 | # print doc 349 | # print r.apply(doc)[1] 350 | # print ' '.join(other) 351 | # return r 352 | ngrams = [[y for y in x] for x in itertools.product(*tokens)] 353 | # for fr, to in zip(ng 354 | # print [list(y.hash() for y in x) for x in ngrams if self.is_param_ngram_frequent(x)] 355 | # print 356 | # print 357 | this_rules.extend(rules) 358 | all_rules.append(this_rules) 359 | return all_rules 360 | # a = difflib.SequenceMatcher(a=t1.split(), b=t2.split()) 361 | # for tag, i1, i2, j1, j2 in a.get_opcodes(): 362 | # if tag != 'equal': 363 | # print('{:7} a[{}:{}] --> b[{}:{}] {!r:>8} --> {!r}'.format(tag, i1, i2, j1, j2, t1.split()[i1:i2], t2.split()[j1:j2])) 364 | -------------------------------------------------------------------------------- /seada/sea/translation_models/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | *.pth 4 | *.json 5 | *.jpg 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | 
*$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | cache/ 33 | 34 | # Code Editors 35 | .vscode 36 | .idea 37 | 38 | # Code linters 39 | .mypy_caches 40 | *.model 41 | *.tsv 42 | *.txt 43 | -------------------------------------------------------------------------------- /seada/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchvision.transforms as transforms 8 | 9 | import config 10 | 11 | def process_answer(answer): 12 | """ 13 | follow Bilinear Attention Networks 14 | and https://github.com/hengyuan-hu/bottom-up-attention-vqa 15 | """ 16 | answer = answer.float() * 0.3 17 | answer = torch.clamp(answer, 0, 1) 18 | return answer 19 | 20 | def batch_accuracy(logits, labels): 21 | """ 22 | follow Bilinear Attention Networks https://github.com/jnhwkim/ban-vqa.git 23 | """ 24 | logits = torch.max(logits, 1)[1].data # argmax 25 | one_hots = torch.zeros(*labels.size()).cuda() 26 | one_hots.scatter_(1, logits.view(-1, 1), 1) 27 | scores = (one_hots * labels) 28 | 29 | return scores.sum(1), one_hots 30 | 31 | def calculate_loss(answer, pred, method): 32 | """ 33 | answer = [batch, 3129] 34 | pred = [batch, 3129] 35 | """ 36 | if method == 'binary_cross_entropy_with_logits': 37 | loss = F.binary_cross_entropy_with_logits(pred, answer) * config.max_answers 38 | elif method == 'soft_cross_entropy': 39 | nll = -F.log_softmax(pred, dim=1) 40 | loss = (nll * answer).sum(dim=1).mean() # this is worse than binary_cross_entropy_with_logits 41 | elif method == 'KL_divergence': 42 | pred = F.softmax(pred, dim=1) 43 | kl = ((answer / (pred + 1e-12)) + 1e-12).log() 44 | loss = (kl * answer).sum(1).mean() 45 | elif method == 'multi_label_soft_margin': 46 | loss = F.multilabel_soft_margin_loss(pred, answer) 47 | else: 48 | print('Error, pls define loss function') 49 | return loss 50 | 51 | 52 | def path_for(train=False, val=False, test=False, question=False, trainval=False, answer=False, vqacp=False, sea=False, eda=False, iq=False): 53 | assert train + val + test + trainval == 1 54 | assert question + answer == 1 55 | if not vqacp: 56 | if train: 57 | split = 'train2014' 58 | elif val: 59 | split = 'val2014' 60 | elif trainval: 61 | split = 'trainval2014' 62 | else: 63 | split = config.test_split 64 | 65 | if question: 66 | fmt = 'v2_{0}_{1}_{2}_questions.json' 67 | if sea: 68 | fmt = 'v2_{0}_{1}_{2}_questions_adv.json' 69 | if eda: 70 | fmt = 'v2_{0}_{1}_{2}_questions_eda.json' 71 | if iq: 72 | fmt = 'v2_{0}_{1}_{2}_questions_iq.json' 73 | else: 74 | if test: 75 | # just load validation data in the test=answer=True case, will be ignored anyway 76 | split = 'val2014' 77 | if eda: 78 | fmt = 'v2_{1}_{2}_annotations_eda.json' 79 | elif iq: 80 | fmt = 'v2_{1}_{2}_annotations_iq.json' 81 | else: 82 | fmt = 'v2_{1}_{2}_annotations.json' 83 | s = fmt.format(config.task, config.dataset, split) 84 | else: 85 | if train: 86 | split = 'train' 87 | elif val: 88 | split = 'test' 89 | else: 90 | raise ValueError 91 | 92 | if question: 93 | fmt = 'vqacp/vqacp_v2_{0}_questions.json' 94 | if sea: 95 | fmt = 'vqacp/vqacp_v2_{0}_questions_adv.json' 96 | if eda: 97 | fmt = 
'vqacp/vqacp_v2_{0}_questions_eda.json' 98 | if iq: 99 | fmt = 'vqacp/vqacp_v2_{0}_questions_iq.json' 100 | else: 101 | if test: 102 | # just load validation data in the test=answer=True case, will be ignored anyway 103 | split = 'test' 104 | if eda: 105 | fmt = 'vqacp/vqacp_v2_{0}_annotations_eda.json' 106 | elif iq: 107 | fmt = 'vqacp/vqacp_v2_{0}_annotations_iq.json' 108 | else: 109 | fmt = 'vqacp/vqacp_v2_{0}_annotations.json' 110 | 111 | s = fmt.format(split) 112 | return os.path.join(config.qa_path, s) 113 | 114 | 115 | def print_lr(optimizer, prefix, epoch): 116 | all_rl = [] 117 | for p in optimizer.param_groups: 118 | all_rl.append(p['lr']) 119 | print('{} E{:03d}:'.format(prefix, epoch), ' Learning Rate: ', set(all_rl)) 120 | 121 | def set_lr(optimizer, value): 122 | for p in optimizer.param_groups: 123 | p['lr'] = value 124 | 125 | def decay_lr(optimizer, rate): 126 | for p in optimizer.param_groups: 127 | p['lr'] *= rate 128 | 129 | 130 | def print_grad(named_parameters): 131 | """ 132 | visualize grad 133 | """ 134 | 135 | total_norm = 0 136 | param_to_norm = {} 137 | param_to_shape = {} 138 | for n, p in named_parameters: 139 | if p.grad is not None: 140 | param_norm = p.grad.data.norm(2) 141 | total_norm += param_norm ** 2 142 | param_to_norm[n] = param_norm 143 | param_to_shape[n] = p.size() 144 | 145 | total_norm = total_norm ** (1. / 2) 146 | 147 | print('---Total norm {:.3f} -----------------'.format(total_norm)) 148 | for name, norm in sorted(param_to_norm.items(), key=lambda x: -x[1]): 149 | print("{:<50s}: {:.3f}, ({})".format(name, norm, param_to_shape[name])) 150 | print('-------------------------------', flush=True) 151 | 152 | return total_norm 153 | 154 | 155 | def where(cond, x, y): 156 | """ 157 | code from : 158 | https://discuss.pytorch.org/t/how-can-i-do-the-operation-the-same-as-np-where/1329/8 159 | """ 160 | cond = cond.float() 161 | return (cond*x) + ((1-cond)*y) 162 | 163 | 164 | class Tracker: 165 | """ Keep track of results over time, while having access to monitors to display information about them. """ 166 | def __init__(self): 167 | self.data = {} 168 | 169 | def track(self, name, *monitors): 170 | """ Track a set of results with given monitors under some name (e.g. 'val_acc'). 171 | When appending to the returned list storage, use the monitors to retrieve useful information. 
172 | """ 173 | l = Tracker.ListStorage(monitors) 174 | self.data.setdefault(name, []).append(l) 175 | return l 176 | 177 | def to_dict(self): 178 | # turn list storages into regular lists 179 | return {k: list(map(list, v)) for k, v in self.data.items()} 180 | 181 | 182 | class ListStorage: 183 | """ Storage of data points that updates the given monitors """ 184 | def __init__(self, monitors=[]): 185 | self.data = [] 186 | self.monitors = monitors 187 | for monitor in self.monitors: 188 | setattr(self, monitor.name, monitor) 189 | 190 | def append(self, item): 191 | for monitor in self.monitors: 192 | monitor.update(item) 193 | self.data.append(item) 194 | 195 | def __iter__(self): 196 | return iter(self.data) 197 | 198 | class MeanMonitor: 199 | """ Take the mean over the given values """ 200 | name = 'mean' 201 | 202 | def __init__(self): 203 | self.n = 0 204 | self.total = 0 205 | 206 | def update(self, value): 207 | self.total += value 208 | self.n += 1 209 | 210 | @property 211 | def value(self): 212 | return self.total / self.n 213 | 214 | class MovingMeanMonitor: 215 | """ Take an exponentially moving mean over the given values """ 216 | name = 'mean' 217 | 218 | def __init__(self, momentum=0.9): 219 | self.momentum = momentum 220 | self.first = True 221 | self.value = None 222 | 223 | def update(self, value): 224 | if self.first: 225 | self.value = value 226 | self.first = False 227 | else: 228 | m = self.momentum 229 | self.value = m * self.value + (1 - m) * value 230 | -------------------------------------------------------------------------------- /sort_para.py: -------------------------------------------------------------------------------- 1 | import json 2 | ori_q = json.load(open('data/v2_OpenEnded_mscoco_train2014_questions.json', 'r')) 3 | q_adv = json.load(open('data/v2_OpenEnded_mscoco_train2014_questions_adv07.json', 'r')) 4 | 5 | ques_dict = {} 6 | for q in q_adv['questions']: 7 | ques_dict[q['question_id']] = q 8 | sorted_q = [] 9 | for q in ori_q['questions']: 10 | if q['question_id'] in ques_dict.keys(): 11 | sorted_q.append(ques_dict[q['question_id']]) 12 | print('#'*10) 13 | print('Original: %s' % q['question']) 14 | print('Adv: %s' % ques_dict[q['question_id']]['question']) 15 | else: 16 | sorted_q.append(q) 17 | sorted_q = {'questions': sorted_q} 18 | with open('data/v2_OpenEnded_mscoco_train2014_questions_adv7.json', 'w') as f: 19 | json.dump(sorted_q, f) 20 | 21 | print(len(sorted_q['questions'])) 22 | -------------------------------------------------------------------------------- /view-log.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import matplotlib; matplotlib.use('agg') 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def main(): 8 | path = sys.argv[1] 9 | results = torch.load(path) 10 | 11 | val_acc = torch.FloatTensor(results['tracker']['val_acc']) 12 | val_acc = val_acc.mean(dim=1).numpy() 13 | for i, v in enumerate(val_acc): 14 | print(i, v) 15 | 16 | plt.figure() 17 | plt.plot(val_acc) 18 | plt.savefig('val_acc.png') 19 | 20 | 21 | if __name__ == '__main__': 22 | main() 23 | --------------------------------------------------------------------------------
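
To make the helper functions in `seada/utils.py` above easier to follow, here is a minimal, illustrative sketch (not part of the repository) of how `process_answer`, `calculate_loss`, `batch_accuracy`, and the `Tracker` monitors could be wired together for a single step. It assumes a CUDA device (because `batch_accuracy` allocates on the GPU), that the repository root is on `PYTHONPATH` so the `import config` inside `utils.py` resolves, and that the monitor classes are nested under `Tracker` as in the upstream vqa-counting utilities; random tensors stand in for real model outputs and annotator answer counts.

```
# Illustrative sketch only -- not part of the repository.
import torch
from seada import utils

batch_size, num_answers = 4, 3129                        # 3129 matches config.max_answers
logits = torch.randn(batch_size, num_answers).cuda()     # stand-in for BUTD model output
counts = torch.randint(0, 11, (batch_size, num_answers)).float().cuda()  # stand-in for raw answer counts

targets = utils.process_answer(counts)                   # soft scores clamped to [0, 1]
loss = utils.calculate_loss(targets, logits, 'binary_cross_entropy_with_logits')
acc, _ = utils.batch_accuracy(logits, targets)           # per-sample VQA accuracy

tracker = utils.Tracker()
# MeanMonitor / MovingMeanMonitor are assumed to be nested attributes of Tracker,
# following the upstream vqa-counting utils this file is based on.
loss_log = tracker.track('train_loss', utils.Tracker.MovingMeanMonitor(momentum=0.99))
acc_log = tracker.track('train_acc', utils.Tracker.MeanMonitor())
loss_log.append(loss.item())
acc_log.append(acc.mean().item())
print('loss {:.4f}, acc {:.4f}'.format(loss_log.mean.value, acc_log.mean.value))
```

Note that `calculate_loss` scales the binary cross-entropy by `config.max_answers`, following Bilinear Attention Networks, so the printed loss is considerably larger than a raw BCE value; the `Tracker.to_dict()` output is what `view-log.py` later plots as `val_acc`.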