├── scRNN ├── README.md ├── model.py ├── corrector.py ├── utils.py └── train.py ├── docker ├── Dockerfile └── environment.yml ├── codalab ├── construct_cluster.sh ├── reconstruct_clusters.sh ├── construct_agglom_cluster.sh ├── typo_corrector.sh ├── no_attk_run_experiment.sh ├── run_experiment_intprm.sh └── run_experiment.sh ├── LICENSE ├── preprocess_tc.py ├── reconstruct_clusters.py ├── augmentor.py ├── utils.py ├── preprocess_vocab.py ├── edit_dist_utils.py ├── README.md ├── environment.yml ├── word_embedding_model_runners.py ├── recoverer.py ├── construct_clusters.py ├── attacks.py ├── utils_glue.py ├── agglom_clusters.py ├── run_glue.py └── transformers.py /scRNN/README.md: -------------------------------------------------------------------------------- 1 | All code for the scRNN was adapted from https://github.com/danishpruthi/Adversarial-Misspellings. 2 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM infrarift/ubuntu16-py36:latest 2 | 3 | WORKDIR /app 4 | 5 | # Create the environment 6 | COPY environment.yml . 7 | 8 | RUN conda env create -f environment.yml 9 | 10 | # Make RUN commands use the new environment: 11 | SHELL ["conda", "run", "-n", "atenv", "/bin/bash", "-c"] -------------------------------------------------------------------------------- /codalab/construct_cluster.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Script to construct connected component clusters 3 | 4 | # Activate at-env conda environment 5 | source activate erik-cert 6 | # Missing package? 7 | 8 | # Set required environment variables 9 | export CLUSTERER_PATH=$HOME/clusterers 10 | 11 | # Make directory CLUSTERER_PATH 12 | mkdir $CLUSTERER_PATH 13 | 14 | # Needed for CodaLab as roben is mounted one level down 15 | cd roben || exit 1 16 | 17 | # Make clusterer 18 | echo 'Constructing connected component clusters. Attack type: '$1 19 | python construct_clusters.py --vocab_size 100000 --perturb_type $1 --data_dir $HOME/data --output_dir $CLUSTERER_PATH 20 | -------------------------------------------------------------------------------- /codalab/reconstruct_clusters.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Script to reconstruct agglomerative clusters 3 | 4 | # Activate at-env conda environment 5 | source activate atenv 6 | # Missing package? 7 | pip install query 8 | # conda list 9 | 10 | # Set required environment variables 11 | # Make directory where clusterers will be stored 12 | mkdir $HOME/clusterers 13 | 14 | # Needed for CodaLab as roben is mounted one level down 15 | cd roben || exit 1 16 | 17 | echo 'Reconstructing...' 
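# reconstruct_clusters.py (invoked below) merges the partial clusterer dicts written by the
# two parallel agglom_clusters.py jobs, offsetting cluster ids so they remain disjoint, and
# pickles the combined clusterer at --save_path.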
18 | python reconstruct_clusters.py --save_path $HOME/clusterers/vocab100000_ed1_gamma0.3.pkl --file_paths $HOME/agglom1/clusterers/vocab100000_ed1_gamma0.3/job0outof2 $HOME/agglom2/clusterers/vocab100000_ed1_gamma0.3/job1outof2 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Erik Jones 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /codalab/construct_agglom_cluster.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Script to construct agglomerative clusters 3 | 4 | echo 'job_id: '$1 5 | echo 'gamma: '$2 6 | 7 | # Activate at-env conda environment 8 | source activate atenv 9 | # Missing package? 10 | pip install query 11 | # conda list 12 | 13 | # Set required environment variables 14 | # Make directory where clusterers will be stored 15 | mkdir $HOME/clusterers 16 | 17 | # Needed for CodaLab as roben is mounted one level down 18 | cd roben || exit 1 19 | 20 | # Make clusterer 21 | echo 'Constructing a connected component clusters' 22 | python construct_clusters.py --vocab_size 100000 --perturb_type ed1 --data_dir $HOME/data --output_dir $HOME/clusterers 23 | 24 | #Now set the clusterer path 25 | export CLUSTERER_PATH=$HOME/clusterers/vocab100000_ed1.pkl 26 | 27 | #Make directory where partial agglomerative clusters will be saved 28 | mkdir $HOME/clusterers/vocab100000_ed1_gamma$2 29 | 30 | # We will now construct our more complicated clusters, the agglomerative clusters 31 | echo 'Building agglomerative clusters (one of two jobs)...' 32 | python agglom_clusters.py --gamma $2 --clusterer_path $CLUSTERER_PATH --job_id $1 --num_jobs 2 33 | -------------------------------------------------------------------------------- /codalab/typo_corrector.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Script to run typo corrector training 3 | # Possible tasks are RTE, MRPC, SST-2, QNLI, MNLI, QQP 4 | 5 | echo 'Running typo corrector training for task: '$1 6 | 7 | # Activate at-env conda environment 8 | source activate atenv 9 | # Missing package? 
10 | pip install query 11 | # conda list 12 | 13 | # Set required environment variables 14 | export TASK_NAME=$1 15 | export CLUSTERER_PATH=$HOME/clusterers/vocab100000_ed1.pkl 16 | export GLUE_DIR=$HOME/data/glue_data 17 | export TC_DIR=$HOME/tc_data 18 | 19 | # Make directory TC_DIR 20 | mkdir $TC_DIR 21 | mkdir $TC_DIR/glue_tc_preprocessed 22 | 23 | # Needed for CodaLab as roben is mounted one level down 24 | cd roben || exit 1 25 | 26 | # Store preprocessed data, vocabularies, and models 27 | echo 'Storing preprocessed data, vocabularies, and models...' 28 | python preprocess_tc.py --glue_dir $GLUE_DIR --save_dir $TC_DIR/glue_tc_preprocessed 29 | 30 | # Change directory to scRNN 31 | cd scRNN || exit 1 32 | 33 | # Train a typo-corrector based on random perturbations to the task data 34 | echo 'Training a typo-corrector based on random perturbations to the data...' 35 | python train.py --task-name $TASK_NAME --preprocessed_glue_dir $TC_DIR/glue_tc_preprocessed --tc_dir $TC_DIR -------------------------------------------------------------------------------- /preprocess_tc.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | from tqdm import tqdm 4 | 5 | from utils_glue import PROCESSORS 6 | import argparse 7 | 8 | 9 | #Script used to convert data to something the scRNN can train... 10 | 11 | def preprocess_for_typo_corrector(task, glue_data_dir, tc_preprocess_data_dir): 12 | print("Peprocessing for {}".format(task)) 13 | task_data_dir = os.path.join(glue_data_dir, task) 14 | task = task.lower() 15 | processor = PROCESSORS[task]() 16 | train_examples = processor.get_train_examples(task_data_dir) 17 | has_b = train_examples[0].text_b is not None 18 | example_dicts = [] 19 | for example in tqdm(train_examples): 20 | example_dict = {} 21 | example_dict['text_a'] = example.text_a 22 | if has_b: 23 | example_dict['text_b'] = example.text_b 24 | example_dicts.append(example_dict) 25 | data = pd.DataFrame(example_dicts) 26 | save_name = os.path.join(tc_preprocess_data_dir, '{}_train_preprocessed.tsv'.format(task)) 27 | if not os.path.exists(tc_preprocess_data_dir): 28 | os.mkdir(tc_preprocess_data_dir) 29 | data.to_csv(save_name, sep = '\t') 30 | print ("Saved at {}".format(save_name)) 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--glue_dir', type = str, default = 'data/glue_data', 35 | help = 'Directory where glue data is stored') 36 | parser.add_argument('--save_dir', type = str, default = 'data/glue_tc_preprocessed', 37 | help = 'Directory where preprocessed glue data will be stored.') 38 | return parser.parse_args() 39 | 40 | if __name__ == '__main__': 41 | args = parse_args() 42 | glue_data_dir = args.glue_dir 43 | tc_preprocess_data_dir = args.save_dir 44 | tasks = ['SST-2', 'MRPC', 'QQP', 'MNLI', 'QNLI', 'RTE'] 45 | for task in tasks: 46 | preprocess_for_typo_corrector(task, glue_data_dir, tc_preprocess_data_dir) 47 | -------------------------------------------------------------------------------- /codalab/no_attk_run_experiment.sh: -------------------------------------------------------------------------------- 1 | # Script to run standard training + attack in CodaLab 2 | # Use: bash run_experiment.sh task_name recoverer augmentor [clusterer_path/tc_dir] 3 | # Fourth argument is optional: use tc_dir when recoverer is scrnn, clusterer_path when recoverer starts with clust. 
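# Illustrative invocations (task names, recoverers, and augmentors as listed in the README):
#   bash no_attk_run_experiment.sh MRPC clust-rep identity
#   bash no_attk_run_experiment.sh MRPC scrnn k-aug tc_data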
4 | 5 | # Activate at-env conda environment 6 | source activate erik-cert 7 | 8 | # Missing package? 9 | 10 | # Set required environment variables 11 | export TASK_NAME=$1 12 | export RECOVERER=$2 13 | export AUGMENTOR=$3 14 | export DO_ROBUST='' 15 | export ATTACK_TYPE='ed1' 16 | 17 | # Add --do_robust to compute robust accuracy when using clusters to defend 18 | if [ "$RECOVERER" = 'clust-rep' ]; then 19 | export DO_ROBUST='--do_robust' 20 | elif [ "$RECOVERER" = 'clust-intprm' ]; then 21 | export DO_ROBUST='--do_robust' 22 | export ATTACK_TYPE='intprm' 23 | fi 24 | 25 | export TC_DIR=$HOME/tc_data 26 | # WARNING: to change the typo-corrector directory, recoverer should be scrnn (otherwise doesn't make a difference.) 27 | if [ "$#" = 4 ] && [ "$RECOVERER" = 'scrnn' ]; then 28 | export TC_DIR=$HOME/$4 29 | fi 30 | 31 | export CLUSTERER_PATH=$HOME/clusterers/vocab100000_ed1.pkl 32 | # When the fourth argument is present, change the cluster path to $HOME/$5 33 | # WARNING: to change clusterer path, recoverer can't be scrnn (wouldn't make a difference anyways') 34 | if [ "$#" = 4 ] && [ "$RECOVERER" != 'scrnn' ]; then 35 | export CLUSTERER_PATH=$HOME/$4 36 | fi 37 | 38 | export GLUE_DIR=$HOME/data/glue_data 39 | 40 | # Needed for CodaLab as roben is mounted one level down 41 | cd roben || exit 1 42 | 43 | # Training + attacking 44 | echo 'Training and then attacking...' 45 | python run_glue.py --log_stdout_only --tc_dir $TC_DIR --task_name $TASK_NAME --do_lower_case --do_train --do_eval --data_dir $GLUE_DIR/$TASK_NAME --output_dir $HOME/model_output/$TASK_NAME --overwrite_output_dir --save_results --save_dir $HOME/train --recoverer $RECOVERER --augmentor $AUGMENTOR --run_test $DO_ROBUST --clusterer_path $CLUSTERER_PATH 46 | -------------------------------------------------------------------------------- /codalab/run_experiment_intprm.sh: -------------------------------------------------------------------------------- 1 | # Script to run standard training + attack in CodaLab 2 | # Use: bash run_experiment.sh task_name recoverer augmentor [clusterer_path/tc_dir] 3 | # Fourth argument is optional: use tc_dir when recoverer is scrnn, clusterer_path when recoverer starts with clust. 4 | 5 | # Activate at-env conda environment 6 | source activate erik-cert 7 | 8 | # Missing package? 9 | 10 | # Set required environment variables 11 | export TASK_NAME=$1 12 | export RECOVERER=$2 13 | export AUGMENTOR=$3 14 | export DO_ROBUST='' 15 | export ATTACK_TYPE='intprm' 16 | export DO_ATTACK='true' 17 | 18 | # Add --do_robust to compute robust accuracy when using clusters to defend 19 | if [ "$RECOVERER" = 'clust-intprm' ]; then 20 | export DO_ROBUST='--do_robust' 21 | #Attack does not change performance, since can sort interior of string... 22 | export DO_ATTACK='false' 23 | fi 24 | 25 | export CLUSTERER_PATH=$HOME/$4 26 | # When the fourth argument is present, change the cluster path to $HOME/$5 27 | # WARNING: to change clusterer path, recoverer can't be scrnn (wouldn't make a difference anyways') 28 | export GLUE_DIR=$HOME/data/glue_data 29 | 30 | # Needed for CodaLab as roben is mounted one level down 31 | cd roben_intprm || exit 1 32 | 33 | # Training + attacking 34 | echo 'Training and then attacking...' 
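# The first run_glue.py call fine-tunes and evaluates on $TASK_NAME (with --do_robust set for
# clust-intprm, it also reports robust accuracy). The second call, skipped when DO_ATTACK is
# false, reloads the fine-tuned model and attacks it with a width-5 beam search over
# internal-permutation perturbations.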
35 | python run_glue.py --log_stdout_only --task_name $TASK_NAME --do_lower_case --do_train --do_eval --data_dir $GLUE_DIR/$TASK_NAME --output_dir $HOME/model_output/$TASK_NAME --overwrite_output_dir --save_results --save_dir $HOME/train --recoverer $RECOVERER --augmentor $AUGMENTOR --run_test $DO_ROBUST --clusterer_path $CLUSTERER_PATH 36 | if [ "$DO_ATTACK" = 'true' ]; then 37 | python run_glue.py --log_stdout_only --task_name $TASK_NAME --do_lower_case --do_eval --data_dir $GLUE_DIR/$TASK_NAME --output_dir $HOME/model_output/$TASK_NAME --save_results --save_dir $HOME/attack --recoverer $RECOVERER --augmentor $AUGMENTOR --run_test --clusterer_path $CLUSTERER_PATH --model_name_or_path $HOME/model_output/$TASK_NAME --attack --new_attack --attacker beam-search --beam_width 5 --attack_name RandomPerturbationAttack --attack_type $ATTACK_TYPE 38 | fi 39 | -------------------------------------------------------------------------------- /reconstruct_clusters.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import argparse 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--save_path', type = str, required = True, 8 | help = 'Directory where all the partial clusterers are stored...') 9 | parser.add_argument('--file_paths', nargs = '+', 10 | help = 'Files that will be combined') 11 | return parser.parse_args() 12 | 13 | def reconstruct_clusters(save_path, file_paths): 14 | num_clusters_added = 0 15 | clusters = {} 16 | word2cluster = {} 17 | cluster2representative = {} 18 | typo2cluster = {} 19 | word2freq = {} 20 | 21 | for file_path in file_paths: 22 | with open(file_path, 'rb') as f: 23 | job_clusterer_dict = pickle.load(f) 24 | 25 | new_clusters = job_clusterer_dict['cluster'] 26 | word2newcluster = job_clusterer_dict['word2cluster'] 27 | newcluster2representative = job_clusterer_dict['cluster2representative'] 28 | newtypo2cluster = job_clusterer_dict['typo2cluster'] 29 | newword2freq = job_clusterer_dict['word2freq'] 30 | 31 | for cluster_id in new_clusters: 32 | clusters[cluster_id + num_clusters_added] = new_clusters[cluster_id] 33 | cluster2representative[cluster_id + num_clusters_added] = newcluster2representative[cluster_id] 34 | 35 | for word in word2newcluster: 36 | word2cluster[word] = word2newcluster[word] + num_clusters_added 37 | word2freq[word] = newword2freq[word] 38 | 39 | for typo in newtypo2cluster: 40 | assert typo not in typo2cluster 41 | typo2cluster[typo] = newtypo2cluster[typo] + num_clusters_added 42 | 43 | num_clusters_added += len(new_clusters) 44 | 45 | save_dict = {'cluster': clusters, 'word2cluster': word2cluster, 46 | 'cluster2representative': cluster2representative, 'typo2cluster': typo2cluster, 'word2freq': word2freq} 47 | for key in save_dict: 48 | if key not in ['typo2cluster', 'word2freq']: 49 | print(key) 50 | print(save_dict[key]) 51 | 52 | with open(save_path, 'wb') as f: 53 | pickle.dump(save_dict, f) 54 | print("Saved!") 55 | 56 | if __name__ == '__main__': 57 | args = parse_args() 58 | reconstruct_clusters(args.save_path, args.file_paths) 59 | 60 | 61 | -------------------------------------------------------------------------------- /codalab/run_experiment.sh: -------------------------------------------------------------------------------- 1 | # Script to run standard training + attack in CodaLab 2 | # Use: bash run_experiment.sh task_name recoverer augmentor [clusterer_path/tc_dir] 3 | # Fourth argument is optional: use tc_dir when recoverer 
is scrnn, clusterer_path when recoverer starts with clust. 4 | 5 | # Activate at-env conda environment 6 | source activate erik-cert 7 | 8 | # Missing package? 9 | 10 | # Set required environment variables 11 | export TASK_NAME=$1 12 | export RECOVERER=$2 13 | export AUGMENTOR=$3 14 | export DO_ROBUST='' 15 | export ATTACK_TYPE='ed1' 16 | 17 | # Add --do_robust to compute robust accuracy when using clusters to defend 18 | if [ "$RECOVERER" = 'clust-rep' ]; then 19 | export DO_ROBUST='--do_robust' 20 | elif [ "$RECOVERER" = 'clust-intprm' ]; then 21 | export DO_ROBUST='--do_robust' 22 | export ATTACK_TYPE='intprm' 23 | fi 24 | 25 | export TC_DIR=$HOME/tc_data 26 | # WARNING: to change the typo-corrector directory, recoverer should be scrnn (otherwise doesn't make a difference.) 27 | if [ "$#" = 4 ] && [ "$RECOVERER" = 'scrnn' ]; then 28 | export TC_DIR=$HOME/$4 29 | fi 30 | 31 | export CLUSTERER_PATH=$HOME/clusterers/vocab100000_ed1.pkl 32 | # When the fourth argument is present, change the cluster path to $HOME/$5 33 | # WARNING: to change clusterer path, recoverer can't be scrnn (wouldn't make a difference anyways') 34 | if [ "$#" = 4 ] && [ "$RECOVERER" != 'scrnn' ]; then 35 | export CLUSTERER_PATH=$HOME/$4 36 | fi 37 | 38 | export GLUE_DIR=$HOME/data/glue_data 39 | 40 | # Needed for CodaLab as roben is mounted one level down 41 | cd roben || exit 1 42 | 43 | # Training + attacking 44 | echo 'Training and then attacking...' 45 | python run_glue.py --log_stdout_only --tc_dir $TC_DIR --task_name $TASK_NAME --do_lower_case --do_train --do_eval --data_dir $GLUE_DIR/$TASK_NAME --output_dir $HOME/model_output/$TASK_NAME --overwrite_output_dir --save_results --save_dir $HOME/train --recoverer $RECOVERER --augmentor $AUGMENTOR --run_test $DO_ROBUST --clusterer_path $CLUSTERER_PATH 46 | python run_glue.py --log_stdout_only --tc_dir $TC_DIR --task_name $TASK_NAME --do_lower_case --do_eval --data_dir $GLUE_DIR/$TASK_NAME --output_dir $HOME/model_output/$TASK_NAME --save_results --save_dir $HOME/attack --recoverer $RECOVERER --augmentor $AUGMENTOR --run_test --clusterer_path $CLUSTERER_PATH --model_name_or_path $HOME/model_output/$TASK_NAME --attack --new_attack --attacker beam-search --beam_width 5 --attack_name LongDeleteShortAll --attack_type $ATTACK_TYPE 47 | -------------------------------------------------------------------------------- /augmentor.py: -------------------------------------------------------------------------------- 1 | from random import sample 2 | 3 | from utils_glue import InputExample 4 | from attacks import ED1AttackSurface 5 | 6 | class Augmentor(): 7 | def __init__(self, attack_surface = None): 8 | if attack_surface is None: 9 | attack_surface = ED1AttackSurface() 10 | self.attack_surface = attack_surface 11 | 12 | def augment_dataset(self, dataset): 13 | augmented_examples = [] 14 | for example in dataset: 15 | augmented = self._augment_example(example) 16 | augmented_examples.extend(augmented) 17 | return augmented_examples 18 | 19 | def _augment_example(self, example): 20 | #Should return a list, to allow for multiple 21 | raise NotImplementedError 22 | 23 | class IdentityAugmentor(Augmentor): 24 | def _augment_example(self, example): 25 | return [example] 26 | 27 | class HalfAugmentor(Augmentor): 28 | #New training dataset is double the size, with half normal and half randomly augmented... 
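# Each token of text_a (and of text_b, if present) is replaced by a perturbation sampled
# uniformly from its attack surface (edit-distance-one by default), and the perturbed copy
# is returned alongside the unmodified example.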
29 | def _augment_example(self, example): 30 | tokens = example.text_a.split() 31 | a_len = len(tokens) 32 | if example.text_b: 33 | tokens.extend(example.text_b.split()) 34 | augmented_version = [] 35 | for token in tokens: 36 | possible_perturbations = self.attack_surface.get_perturbations(token) 37 | augmented_version.append(sample(possible_perturbations, 1)[0]) 38 | augmented_a = augmented_version[:a_len] 39 | a_aug = ' '.join(augmented_a) 40 | b_aug = None 41 | if example.text_b: 42 | augmented_b = augmented_version[a_len:] 43 | b_aug = ' '.join(augmented_b) 44 | augmented_example = InputExample('{}-AUG'.format(example.guid), a_aug, b_aug, example.label) 45 | return [example, augmented_example] 46 | 47 | class KAugmentor(Augmentor): 48 | #TODO, should allow changing of k outside... 49 | def _augment_example(self, example, k = 4): 50 | tokens = example.text_a.split() 51 | a_len = len(tokens) 52 | if example.text_b: 53 | tokens.extend(example.text_b.split()) 54 | augmented_examples = [] 55 | for i in range(k): 56 | augmented_version = [] 57 | for token in tokens: 58 | possible_perturbations = self.attack_surface.get_perturbations(token) 59 | augmented_version.append(sample(possible_perturbations, 1)[0]) 60 | augmented_a = augmented_version[:a_len] 61 | a_aug = ' '.join(augmented_a) 62 | b_aug = None 63 | if example.text_b: 64 | augmented_b = augmented_version[a_len:] 65 | b_aug = ' '.join(augmented_b) 66 | augmented_example = InputExample('{}-AUG{}'.format(example.guid, i), a_aug, b_aug, example.label) 67 | augmented_examples.append(augmented_example) 68 | return [example, *augmented_examples] 69 | 70 | 71 | 72 | AUGMENTORS = {'identity': IdentityAugmentor, 'half-aug': HalfAugmentor, 'k-aug': KAugmentor} 73 | 74 | -------------------------------------------------------------------------------- /scRNN/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | from allennlp.modules.elmo import Elmo 5 | 6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 7 | 8 | 9 | 10 | class ScRNN(nn.Module): 11 | def __init__(self, char_vocab_size, hdim, output_dim): 12 | super(ScRNN, self).__init__() 13 | """ layers """ 14 | self.lstm = nn.LSTM(3*char_vocab_size, hdim, 1, batch_first=True, 15 | bidirectional=True) 16 | self.linear = nn.Linear(2*hdim, output_dim) 17 | 18 | 19 | 20 | """ size(inp) --> BATCH_SIZE x MAX_SEQ_LEN x EMB_DIM 21 | """ 22 | def forward(self, inp, lens): 23 | packed_input = pack_padded_sequence(inp, lens, batch_first=True) 24 | packed_output, _ = self.lstm(packed_input) 25 | h, _ = pad_packed_sequence(packed_output, batch_first=True) 26 | out = self.linear(h) # out is batch_size x max_seq_len x class_size 27 | out = out.transpose(dim0=1, dim1=2) 28 | return out # out is batch_size x class_size x max_seq_len 29 | 30 | 31 | 32 | class ElmoScRNN(nn.Module): 33 | def __init__(self, char_vocab_size, hdim, output_dim): 34 | super(ElmoScRNN, self).__init__() 35 | self.elmo_hdim = 1024 36 | self.weight_file = "../elmo/weights/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5" 37 | self.options_file = "../elmo/weights/elmo_2x4096_512_2048cnn_2xhighway_options.json" 38 | """ layers """ 39 | self.elmo = Elmo(self.options_file, self.weight_file, 1) 40 | self.lstm = nn.LSTM(3*char_vocab_size, hdim, 1, batch_first=True, 41 | bidirectional=True) 42 | self.linear = nn.Linear(2*hdim + self.elmo_hdim, output_dim) 43 | 44 | 45 | 46 | #TODO: go away from the 
assumption that the batch size is 1. 47 | """ size(inp) --> BATCH_SIZE x MAX_SEQ_LEN x EMB_DIM 48 | """ 49 | def forward(self, inp, elmo_inp, lens): 50 | packed_input = pack_padded_sequence(inp, lens, batch_first=True) 51 | packed_output, _ = self.lstm(packed_input) 52 | h, _ = pad_packed_sequence(packed_output, batch_first=True) # h is BATCH_SIZE x MAX_SEQ_LEN x hdim 53 | h_e = self.elmo(elmo_inp)['elmo_representations'][0] # h_e is BATCH_SIZE X MAX_SEQ_LEN x 1024 54 | 55 | h = torch.cat((h, h_e), 2) # concat along the last dim 56 | 57 | out = self.linear(h) # out is batch_size x max_seq_len x class_size 58 | out = out.transpose(dim0=1, dim1=2) 59 | return out # out is batch_size x class_size x max_seq_len 60 | 61 | 62 | 63 | """ 64 | This is a vanilla model, which takes the ELMO representations for each token, 65 | and tries to reconstruct each word using the (oft manipulated) word 66 | """ 67 | class ElmoRNN(nn.Module): 68 | def __init__(self, output_dim): 69 | super(ElmoRNN, self).__init__() 70 | self.elmo_hdim = 1024 71 | self.weight_file = "../elmo/weights/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5" 72 | self.options_file = "../elmo/weights/elmo_2x4096_512_2048cnn_2xhighway_options.json" 73 | """ layers """ 74 | self.elmo = Elmo(self.options_file, self.weight_file, 1) 75 | self.linear = nn.Linear(self.elmo_hdim, output_dim) 76 | 77 | 78 | """ size(inp) --> BATCH_SIZE x MAX_SEQ_LEN x EMB_DIM 79 | """ 80 | def forward(self, inp): 81 | h = self.elmo(inp)['elmo_representations'][0] # h_e is BATCH_SIZE X MAX_SEQ_LEN x 1024 82 | 83 | out = self.linear(h) # out is batch_size x max_seq_len x class_size 84 | out = out.transpose(dim0=1, dim1=2) # flip the second and the third dimensions 85 | return out # out is batch_size x class_size x max_seq_len 86 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | #OOV_CLUSTER = '' 4 | OOV_CLUSTER = -1 #Chnaged for 5 | OOV_TOKEN = '' 6 | 7 | 8 | class ModelRunner(object): 9 | """Object that can run a model on a given dataset.""" 10 | def __init__(self, recoverer, output_mode, label_list, output_dir, device): 11 | self.recoverer = recoverer 12 | self.output_mode = output_mode 13 | self.label_list = label_list 14 | self.output_dir = output_dir 15 | self.device = device 16 | 17 | def train(self, train_data, args): 18 | """Given already-recovered data, train the model.""" 19 | raise NotImplementedError 20 | 21 | def query(self, examples, batch_size, do_evaluate=True, return_logits=False, 22 | do_recover=True, use_tqdm=True): 23 | """Run the recoverer on raw data and query the model on examples.""" 24 | raise NotImplementedError 25 | 26 | class Clustering(object): 27 | """Object representing an assignment of words to clusters. 28 | 29 | Provides some utilities for dealing with words, typos, and clusters. 
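The clusterer_dict is the pickle produced by construct_clusters.py (and, for agglomerative
clusters, merged by reconstruct_clusters.py), with keys 'cluster', 'word2cluster',
'cluster2representative', 'typo2cluster', and 'word2freq'.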
30 | """ 31 | def __init__(self, clusterer_dict, max_num_possibilities=None, passthrough=False): 32 | self.cluster2elements = clusterer_dict['cluster'] 33 | self.word2cluster = clusterer_dict['word2cluster'] 34 | self.cluster2representative = clusterer_dict['cluster2representative'] 35 | self.word2freq = clusterer_dict['word2freq'] 36 | self.typo2cluster = clusterer_dict['typo2cluster'] 37 | if max_num_possibilities: 38 | self.cluster2elements = self.filter_possibilities(max_num_possibilities) 39 | 40 | def filter_possibilities(self, max_num_possibilities): 41 | filtered_cluster2elements = {} 42 | for cluster in self.cluster2elements: 43 | elements = self.cluster2elements[cluster] 44 | frequency_list = [(elem, self.word2freq[elem]) for elem in elements] 45 | frequency_list.sort(key = lambda x: x[1], reverse = True) 46 | filtered_elements = [pair[0] for pair in frequency_list[:max_num_possibilities]] 47 | filtered_cluster2elements[cluster] = filtered_elements 48 | return filtered_cluster2elements 49 | 50 | @classmethod 51 | def from_pickle(cls, path, **kwargs): 52 | with open(path, 'rb') as f: 53 | clusterer_dict = pickle.load(f) 54 | return cls(clusterer_dict, **kwargs) 55 | 56 | def get_words(self, cluster): 57 | if cluster == OOV_CLUSTER: 58 | return [OOV_TOKEN] 59 | return self.cluster2elements[cluster] 60 | 61 | def in_vocab(self, word): 62 | return word in self.word2cluster 63 | 64 | def get_cluster(self, word): 65 | """Get cluster of a word, or OOV_CLUSTER if out of vocabulary.""" 66 | word = word.lower() 67 | if word in self.word2cluster: 68 | return self.word2cluster[word] 69 | return OOV_CLUSTER 70 | 71 | def get_rep(self, cluster): 72 | """Get representative for a cluster.""" 73 | if cluster == OOV_CLUSTER: 74 | return OOV_TOKEN 75 | return self.cluster2representative[cluster] 76 | 77 | def get_freq(self, word): 78 | return self.word2freq[word] 79 | 80 | def map_token(self, token, remap_vocab=True, passthrough = False): 81 | """Map a token (possibly a typo) to a cluster. 82 | 83 | Args: 84 | token: a token, possibly a typo 85 | remap_vocab: if False, always map vocab words to themselves, 86 | because perturbing vocab words has been disallowed. 87 | passthrough: Allow OOV to go to downstream model... 
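Returns the id of the matching cluster, OOV_CLUSTER if the token is neither a vocabulary
word nor a known typo, or the raw token itself when passthrough is set.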
88 | """ 89 | token = token.lower() 90 | if token in self.word2cluster and not remap_vocab: 91 | return self.get_cluster(token) 92 | if token in self.typo2cluster: 93 | return self.typo2cluster[token] 94 | if passthrough: 95 | return token 96 | return OOV_CLUSTER 97 | 98 | 99 | def pkl_save(obj, filename): 100 | with open(filename, 'wb') as f: 101 | pickle.dump(obj, f) 102 | 103 | def pkl_load(filename): 104 | with open(filename, 'rb') as f: 105 | obj = pickle.load(f) 106 | return obj 107 | 108 | -------------------------------------------------------------------------------- /scRNN/corrector.py: -------------------------------------------------------------------------------- 1 | """ class using Semi Character RNNs as a defense mechanism 2 | ScRNN paper: https://arxiv.org/abs/1608.02214 3 | """ 4 | import os 5 | from scRNN import utils 6 | from scRNN.utils import * 7 | from scRNN.model import ScRNN 8 | 9 | # torch related imports 10 | import torch 11 | from torch import nn 12 | from torch.autograd import Variable 13 | 14 | # elmo related imports 15 | from allennlp.modules.elmo import batch_to_ids 16 | 17 | class ScRNNChecker(object): 18 | def __init__(self, tc_dir, task_name='sst-2', vocab_size=9999,\ 19 | vocab_size_bg=78470, use_background=False, unk_output=False, \ 20 | use_elmo=False, use_elmo_bg=False): 21 | # TODO: causes problem - lower was causing problems 22 | task_name = task_name.upper() 23 | #MODEL_PATH = PWD + "/model_dumps/scrnn_TASK_NAME={}_VOCAB_SIZE=9999_REP_LIST=_REP_PROBS=".format(task_name) 24 | MODEL_PATH = os.path.join(tc_dir, 'model_dumps', 'scrnn_TASK_NAME={}'.format(task_name)) 25 | 26 | self.vocab_size_bg = vocab_size_bg 27 | self.vocab_size = vocab_size 28 | self.unk_output = unk_output 29 | 30 | # path to vocabs 31 | w2i_PATH = os.path.join(tc_dir, 'vocab', '{}w2i_{}.p'.format(task_name, vocab_size)) 32 | i2w_PATH = os.path.join(tc_dir, 'vocab', '{}i2w_{}.p'.format(task_name, vocab_size)) 33 | CHAR_VOCAB_PATH = os.path.join(tc_dir, 'vocab', '{}CHAR_VOCAB_ {}.p'.format(task_name, vocab_size)) 34 | 35 | set_word_limit(vocab_size, task_name) 36 | 37 | _, _, char_vocab = load_vocab_dicts(w2i_PATH, i2w_PATH, CHAR_VOCAB_PATH) 38 | print("Number of characters: ", len(char_vocab)) 39 | model = ScRNN(len(char_vocab), 50, 10000) 40 | model.load_state_dict(torch.load(MODEL_PATH)) 41 | self.model = model 42 | self.predicted_unks = 0.0 43 | self.predicted_unks_in_vocab = 0.0 44 | self.total_predictions = 0.0 45 | self.use_background = use_background 46 | self.use_elmo = use_elmo 47 | self.use_elmo_bg = use_elmo_bg 48 | print("Made it to the desired location!") 49 | return 50 | 51 | 52 | def correct_string(self, line): 53 | line = line.lower() 54 | Xtype = torch.FloatTensor 55 | ytype = torch.LongTensor 56 | is_cuda = torch.cuda.is_available() 57 | 58 | if is_cuda: 59 | self.model.cuda() 60 | Xtype = torch.cuda.FloatTensor 61 | ytype = torch.cuda.LongTensor 62 | if self.use_background: self.model_bg.cuda() 63 | 64 | X, _ = get_line_representation(line) 65 | tx = Variable(torch.from_numpy(np.array([X]))).type(Xtype) 66 | 67 | if self.use_elmo or self.use_elmo_bg: 68 | tx_elmo = Variable(batch_to_ids([line.split()])).type(ytype) 69 | 70 | 71 | SEQ_LEN = len(line.split()) 72 | 73 | if self.use_elmo: 74 | ty_pred = self.model(tx, tx_elmo, [SEQ_LEN]) 75 | else: 76 | ty_pred = self.model(tx, [SEQ_LEN]) 77 | 78 | y_pred = ty_pred.detach().cpu().numpy() 79 | y_pred = y_pred[0] # ypred now is NUM_CLASSES x SEQ_LEN 80 | 81 | if self.use_background: 82 | if self.use_elmo_bg: 83 | 
ty_pred_bg = self.model_bg(tx, tx_elmo, [SEQ_LEN]) 84 | else: 85 | ty_pred_bg = self.model_bg(tx, [SEQ_LEN]) 86 | y_pred_bg = ty_pred_bg.detach().cpu().numpy() 87 | y_pred_bg = y_pred_bg[0] 88 | 89 | output_words = [] 90 | 91 | self.total_predictions += SEQ_LEN 92 | 93 | for idx in range(SEQ_LEN): 94 | pred_idx = np.argmax(y_pred[:, idx]) 95 | if pred_idx == utils.WORD_LIMIT: 96 | word = line.split()[idx] 97 | if self.use_background: 98 | pred_idx_bg = np.argmax(y_pred_bg[:, idx]) 99 | if pred_idx_bg != self.vocab_size_bg: 100 | word = utils.i2w_bg[pred_idx_bg] 101 | if self.unk_output: 102 | word = "a" # choose a sentiment neutral word 103 | output_words.append(word) 104 | self.predicted_unks += 1.0 105 | if word in utils.w2i: 106 | self.predicted_unks_in_vocab += 1.0 107 | else: 108 | output_words.append(utils.i2w[pred_idx]) 109 | 110 | return " ".join(output_words) 111 | 112 | def reset_counters(self): 113 | self.predicted_unks = 0.0 114 | self.total_predictions = 0.0 115 | 116 | 117 | def report_statistics(self): 118 | print ("Total number of words predicted by background model = %0.2f " %(100. * self.predicted_unks/self.total_predictions)) 119 | print ("Total number of in vocab words predicted by background model = %0.2f " %(100. * self.predicted_unks_in_vocab/self.total_predictions)) 120 | -------------------------------------------------------------------------------- /preprocess_vocab.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import itertools 4 | from collections import defaultdict 5 | from tqdm import tqdm 6 | 7 | from utils import pkl_save, pkl_load 8 | from edit_dist_utils import get_all_edit_dist_one, get_all_internal_permutations, get_sorted_word 9 | 10 | TOY_VOCAB = ['cat', 'bat', 'car', 'bar', 'airplane!!!'] 11 | GLOVE_PATH = 'data/glove/glove.6B.50d.txt' 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--vocab-type', choices=['lm', 'glove'], default = 'glove', 16 | help = 'Where to get the vocab from') 17 | parser.add_argument('--save-root', type = str, default = '', 18 | help = 'Name used to format vocab preprocessing output') 19 | parser.add_argument('--filetype', type = int, default = 1111, 20 | help = 'insert, delete, substitute, swap') 21 | parser.add_argument('--perturb_type', default = 'ed1', type = str, help = 'Type of perturbation to make dict for.') 22 | return parser.parse_args() 23 | 24 | 25 | def load_glove_vocab(glove_path, num_lines = 400000): 26 | print("Reading GloVe vectors from {}...".format(glove_path)) 27 | vocab = [] 28 | with open(glove_path) as f: 29 | for i, line in tqdm(enumerate(f), total=num_lines): 30 | toks = line.strip().split(' ') 31 | word = toks[0] 32 | vocab.append(word) 33 | return vocab 34 | 35 | def vocab_from_lm(lm): 36 | print("Possible vocab size: ", len(lm.word_to_idx)) 37 | vocab = list(lm.word_to_idx) 38 | vocab = [word for word in vocab if word.isalpha() and word == word.lower()] 39 | print("Vocab size after flitering: ", len(vocab)) 40 | return vocab 41 | 42 | def preprocess_neighbors_intprm(vocab): 43 | neighbor_trans_map = None 44 | sorted2word = defaultdict(set) 45 | vocab = [word.lower() for word in vocab] 46 | print("Grouping by sorted word") 47 | for word in tqdm(vocab): 48 | sorted_word = get_sorted_word(word) 49 | sorted2word[sorted_word].add(word) 50 | 51 | neighbor_trans_map = None 52 | print("Constructing edges...") 53 | neighbor_map = defaultdict(set) 54 | for sorted_word in 
tqdm(sorted2word): 55 | permutations = itertools.permutations(sorted2word[sorted_word], r = 2) 56 | for src, dest in permutations: 57 | neighbor_map[src].add(dest) 58 | #Allow self-edges 59 | for src in sorted2word[sorted_word]: 60 | neighbor_map[src].add(src) 61 | return sorted2word, neighbor_map, neighbor_trans_map 62 | 63 | 64 | 65 | def preprocess_neighbors(vocab, filetype = 1111, sub_restrict = None): 66 | #For efficiency, assume edit distance 1 is symmetric. Not true for certain filetypes, so perturbations act accordingly... 67 | typo2vocab = defaultdict(set) 68 | print("Making typo dict...") 69 | for word in tqdm(vocab): 70 | perturbations = get_all_edit_dist_one(word, filetype = filetype, sub_restrict = sub_restrict) 71 | 72 | for typo in perturbations: 73 | typo2vocab[typo].add(word) 74 | 75 | print("Constructing edges...") 76 | neighbor_map = defaultdict(set) 77 | neighbor_trans_map = defaultdict(set) 78 | for typo in tqdm(typo2vocab): 79 | permutations = itertools.permutations(typo2vocab[typo], r = 2) 80 | for src, dest in permutations: 81 | neighbor_map[src].add(dest) 82 | neighbor_trans_map[(src, dest)].add(typo) 83 | #Allow self-edges 84 | for src in typo2vocab[typo]: 85 | neighbor_map[src].add(src) 86 | neighbor_trans_map[(src, src)].add(src) 87 | return typo2vocab, neighbor_map, neighbor_trans_map 88 | 89 | 90 | 91 | def preprocess_vocab(args): 92 | if args.vocab_type == 'glove': 93 | vocab = load_glove_vocab(GLOVE_PATH) 94 | elif args.vocab_type == 'lm': 95 | query_handler = load_language_model() 96 | vocab = vocab_from_lm(query_handler) 97 | else: 98 | raise ValueError("Invalid vocab type of {}".format(args.vocab_type)) 99 | print("Vocab sample: ", vocab[300:500]) 100 | sub_dict = None 101 | if args.modify_end: 102 | print("Modifying the end...") 103 | if args.perturb_type == 'ed1': 104 | typo2vocab, ed2_neighbors, neighbor_trans_map = preprocess_neighbors(vocab, filetype = args.filetype, 105 | sub_restrict = sub_dict) 106 | elif args.perturb_type == 'intprm': 107 | typo2vocab, ed2_neighbors, neighbor_trans_map = preprocess_neighbors_intprm(vocab) 108 | else: 109 | raise NotImplementedError 110 | pkl_save(ed2_neighbors, 'ed2_neighbors{}pt{}.pkl'.format(args.save_root, args.perturb_type)) 111 | print("Saved ed2") 112 | 113 | def get_neighbors(args): 114 | neighbor_path = 'ed2_neighbors{}pt{}.pkl'.format(args.save_root, args.perturb_type) 115 | neighbor_dict = pkl_load(neighbor_path) 116 | while True: 117 | print('broad' in neighbor_dict['bold']) 118 | inpt = input("Enter a word: ") 119 | inpt = inpt.lower() 120 | if inpt not in neighbor_dict: 121 | print("Word not preprocessed...") 122 | else: 123 | print("Neighbors for {}:".format(inpt)) 124 | print(neighbor_dict[inpt]) 125 | 126 | if __name__ == '__main__': 127 | args = parse_args() 128 | get_neighbors(args) 129 | -------------------------------------------------------------------------------- /edit_dist_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Edit distance utils... 
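Helpers for enumerating edit-distance-one perturbations and internal permutations of a word,
and for building the word-neighbor structures that the clustering code consumes.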
3 | """ 4 | from collections import defaultdict 5 | import numpy as np 6 | import random 7 | import string 8 | from itertools import permutations 9 | 10 | def process_filetype(filetype): 11 | insert = (filetype // 1000) % 2 == 1 12 | delete = (filetype // 100) % 2 == 1 13 | substitute = (filetype // 10) % 2 == 1 14 | swap = filetype % 2 == 1 15 | return insert, delete, substitute, swap 16 | 17 | def get_all_edit_dist_one(word, filetype = 1111, sub_restrict = None): 18 | """ 19 | Allowable edit_dist_one perturbations: 20 | 1. Insert any lowercase characer at any position other than the start 21 | 2. Delete any character other than the first one 22 | 3. Substitute any lowercase character for any other lowercase letter other than the start 23 | 4. Swap adjacent characters 24 | We also include the original word. Filetype determines which of the allowable perturbations to use. 25 | """ 26 | insert, delete, substitute, swap = process_filetype(filetype) 27 | #last_mod_pos is last thing you could insert before 28 | last_mod_pos = len(word) - 1 29 | ed1 = set() 30 | for pos in range(1, last_mod_pos + 1): #can add letters at the end 31 | if delete and pos < last_mod_pos: 32 | deletion = word[:pos] + word[pos + 1:] 33 | ed1.add(deletion) 34 | if swap and pos < last_mod_pos - 1: 35 | #swapping thing at pos with thing at pos + 1 36 | swaped = word[:pos] + word[pos + 1] + word[pos] + word[pos + 2:] 37 | ed1.add(swaped) 38 | for letter in string.ascii_lowercase: 39 | if insert: 40 | #Insert right after pos - 1 41 | insertion = word[:pos] + letter + word[pos:] 42 | ed1.add(insertion) 43 | can_substitute = sub_restrict is None or letter in sub_restrict[word[pos]] 44 | if substitute and pos < last_mod_pos and can_substitute: 45 | substitution = word[:pos] + letter + word[pos + 1:] 46 | ed1.add(substitution) 47 | #Include original word 48 | ed1.add(word) 49 | return ed1 50 | 51 | def get_all_internal_permutations(word): 52 | if len(word) > 10: 53 | return set([word]) 54 | first_char = word[0] 55 | last_char = word[-1] 56 | internal_chars = word[1:-1] 57 | internal_permutations = set() 58 | for int_perm in permutations(internal_chars): 59 | int_perm_str = ''.join(int_perm) 60 | perm = '{}{}{}'.format(first_char, int_perm_str, last_char) 61 | internal_permutations.add(perm) 62 | return internal_permutations 63 | 64 | def sample_random_internal_permutations(word, n_perts = 5): 65 | #We try swapping everything with the second character... 66 | if len(word) < 4: 67 | return set([word]) 68 | #iterate through positions between second and last 69 | perturbations = set() 70 | start = word[0] 71 | end = word[-1] 72 | middle = word[1:-1] 73 | for _ in range(n_perts): 74 | middle_list = list(middle) 75 | random.shuffle(middle_list) 76 | mixed_up_middle = ''.join(middle_list) 77 | perturbations.add('{}{}{}'.format(start, mixed_up_middle, end)) 78 | return perturbations 79 | 80 | def get_sorted_word(word): 81 | if len(word) < 3: 82 | sorted_word = word 83 | else: 84 | sorted_word = '{}{}{}'.format(word[0], ''.join(sorted(word[1:-1])), word[-1]) 85 | return sorted_word 86 | 87 | def get_sorted_word_set(word): 88 | if len(word) < 3: 89 | sorted_word = word 90 | else: 91 | sorted_word = '{}{}{}'.format(word[0], ''.join(sorted(word[1:-1])), word[-1]) 92 | return set([sorted_word]) 93 | 94 | 95 | #Used to create agglomerative clusters. 
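# Two (lowercased) vocabulary words count as neighbors when their edit-distance-one
# perturbation sets overlap, i.e. some single typo could have been produced from either word.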
96 | def preprocess_ed1_neighbors(vocab, sub_restrict = None, filetype = 1111): 97 | vocab = set([word.lower() for word in vocab]) 98 | typo2words = defaultdict(set) 99 | for word in vocab: 100 | ed1_typos = get_all_edit_dist_one(word, filetype = filetype, sub_restrict = sub_restrict) 101 | for typo in ed1_typos: 102 | typo2words[typo].add(word) 103 | 104 | word2neighbors = defaultdict(set) 105 | for typo in typo2words: 106 | for word in typo2words[typo]: 107 | word2neighbors[word] = word2neighbors[word].union(typo2words[typo]) 108 | return word2neighbors 109 | 110 | #Used to create agglomerative clusters. 111 | def ed1_neighbors_mat(vocab, sub_restrict = None, filetype = 1111): 112 | vocab = [word.lower() for word in vocab] 113 | word2idx = dict([(word, i) for i, word in enumerate(vocab)]) 114 | word2neighbors = preprocess_ed1_neighbors(vocab, sub_restrict = sub_restrict, filetype = filetype) 115 | edges = set() 116 | for word in word2neighbors: 117 | for neighbor in word2neighbors[word]: 118 | edge = [word, neighbor] 119 | edge.sort() 120 | edge = tuple(edge) 121 | edges.add(edge) 122 | edge_mat = np.zeros((len(vocab), len(vocab)), dtype = int) 123 | for edge in edges: 124 | vtx1, vtx2 = edge 125 | idx1, idx2 = word2idx[vtx1], word2idx[vtx2] 126 | edge_mat[idx1][idx2] = 1 127 | edge_mat[idx2][idx1] = 1 128 | return edge_mat 129 | 130 | 131 | 132 | if __name__ == '__main__': 133 | while True: 134 | word = input("Enter a word: ") 135 | print("Total number of possible perturbations: {}".format(len(get_all_edit_dist_one(word)))) 136 | 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Code for the following paper: 2 | > [Robust Encodings: A Framework for Combating Adversarial Typos](https://arxiv.org/abs/2005.01229) 3 | > 4 | > Erik Jones, Robin Jia, Aditi Raghunathan, and Percy Liang 5 | > 6 | > Association for Computational Linguistics (ACL), 2020 7 | 8 | ## Cluster Embeddings 9 | We will run experiments for six tasks: RTE, MRPC, SST-2, QNLI, MNLI, QQP. These are used as arguments whenever task name (or mrpc in the following code, which is used as an example) comes up. Data is available on codalab. 10 | ### Standard training 11 | The core element of our defense is a "clusterer" object, which we use to map tokens to a series of representatives, before inputting into a normal model. To create a clusterer, we use two different data sources: 12 | * Embeddings used to filter vocab words: ```data/glove/glove.6b.50d.txt``` 13 | * Word frequencies: ```data/COCA/coca-1grams.json``` 14 | Given these files, to make a clusterer, run: 15 | ```python construct_clusters.py --vocab_size 100000 --perturb_type ed1``` 16 | This will form a clusterer object with path ```clusterers/vocab100000_ed1.pkl```, which will be used in future experiments. 17 | 18 | Now, lots of the following code is adapted from an older version of https://github.com/huggingface/transformers. Data can be found there. We will first fine-tune and save uncased BERT on the MRPC task. 
To do so, we set the following variables: 19 | ``` 20 | export TASK_NAME=MRPC 21 | export CLUSTERER_PATH=clusterers/vocab100000_ed1.pkl 22 | export GLUE_DIR=data/glue_data 23 | ``` 24 | where the data for MRPC is stored under ```glue_data```. With these variables set, we run: 25 | ``` 26 | python run_glue.py --task_name $TASK_NAME --do_lower_case --do_train --do_eval --data_dir $GLUE_DIR/$TASK_NAME --output_dir model_output/$TASK_NAME --overwrite_output_dir --seed_output_dir --save_results --save_dir codalab --recoverer identity --augmentor identity --run_test 27 | ``` 28 | This gives us a normally trained model, which will be saved at model_output/MRPC_XXXXXX, where XXXXXX is a random six-digit number (this is the ```--seed_output_dir``` argument). Information (including the clean accuracy that we report, and later attack statistics) will be stored in results/codalab/MRPC_XXXXXX.json. To attack this model, we run: 29 | ``` 30 | python run_glue.py --task_name $TASK_NAME --do_lower_case --do_eval --data_dir $GLUE_DIR/$TASK_NAME --output_dir model_output/MRPC_XXXXXX --save_results --save_dir codalab --recoverer identity --augmentor identity --run_test --model_name_or_path model_output/MRPC_XXXXXX --attack --new_attack --attacker beam-search --beam_width 5 --attack_name LongDeleteShortAll --attack_type ed1 31 | ``` 32 | There are a lot of arguments here. ```attack``` means an adversary searches for a typo, and ```new_attack``` says to avoid the cache. ```attacker``` determines the style of heuristic attack, and ```attack_name``` gives the type of token-level perturbation space used for the attack. This is all the information we need for the identity recoverer. 33 | ### Data Augmentation 34 | To run this experiment with data augmentation, repeat both runs of python run_glue.py, but with the flag ```--augmentor k-aug```. 35 | 36 | ### Typo Corrector 37 | We'll now replicate the entire typo corrector training process, using a new environment variable: 38 | ``` 39 | export TC_DIR=$HOME/tc_data 40 | ``` 41 | This directory will have to be made if it does not exist; it will store preprocessed data, vocabularies, and models. First, we run: 42 | ``` 43 | python preprocess_tc.py --glue_dir $GLUE_DIR --save_dir $TC_DIR/glue_tc_preprocessed 44 | ``` 45 | This converts the data in ```$GLUE_DIR``` into the correct format to train the typo corrector. The output is saved in ``` 46 | $TC_DIR/glue_tc_preprocessed 47 | ```. Next, cd to `scRNN`, and run: 48 | ``` 49 | python train.py --task_name mrpc --preprocessed_glue_dir $TC_DIR/glue_tc_preprocessed --tc_dir $TC_DIR 50 | ``` 51 | This trains a typo corrector based on random perturbations to the MRPC data. The typo corrector is saved at `$TC_DIR/model_dumps` and the associated vocabulary (also required) is saved at `$TC_DIR/vocab` (both directories will likely have to be premade in CodaLab). Now, we can repeat the original runs, except with ```--recoverer scrnn``` and ```--tc_dir $TC_DIR```. 52 | 53 | ### Connected Component Clusters 54 | Finally, we're done with the baselines! To try using clusters as a defense, we run: 55 | ``` 56 | python run_glue.py --task_name $TASK_NAME --do_lower_case --do_train --do_eval --data_dir $GLUE_DIR/$TASK_NAME --output_dir model_output/$TASK_NAME --overwrite_output_dir --seed_output_dir --save_results --save_dir codalab --recoverer clust-rep --clusterer_path $CLUSTERER_PATH --augmentor identity --run_test --do_robust 57 | ``` 58 | Here, we include ```clusterer_path``` to load the mapping, and ```do_robust``` to compute the actual robust accuracy.
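For intuition, the mapping that the ```clust-rep``` recoverer applies to each token can be reproduced directly from the clusterer pickle with the `Clustering` helper in `utils.py`. The sketch below is not one of the provided commands; it assumes the example typos fall within the edit-distance-one ball of vocabulary words covered by the clusterer:
```
from utils import Clustering

# Load the connected-component clusterer built by construct_clusters.py above.
clustering = Clustering.from_pickle('clusterers/vocab100000_ed1.pkl')

sentence = 'the moive was graet'  # hypothetical typos of "movie" and "great"
clusters = [clustering.map_token(tok) for tok in sentence.split()]
print(' '.join(clustering.get_rep(c) for c in clusters))
```
Tokens that are not recognized as a vocabulary word or as a typo of one are mapped to the out-of-vocabulary cluster and encoded as the OOV token.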
59 |  60 | ### Agglomerative Clusters 61 | We will now construct our more complicated clusters, the agglomerative clusters. To satisfy computational constraints, we leverage the existing connected components and parallelize over two jobs. To do so, first make the directory where the two partial clusterers will be stored: `clusterers/vocab100000_ed1_gamma0.3`. Once the directory is made, run, in parallel: 62 | ``` 63 | python agglom_clusters.py --gamma 0.3 --clusterer_path $CLUSTERER_PATH --job_id 0 --num_jobs 2 64 | python agglom_clusters.py --gamma 0.3 --clusterer_path $CLUSTERER_PATH --job_id 1 --num_jobs 2 65 | ``` 66 | This will save two partial clusterers. To combine them (after both jobs are complete) run: 67 | ``` 68 | python reconstruct_clusters.py --save_path clusterers/vocab100000_ed1_gamma0.3.pkl --file_paths clusterers/vocab100000_ed1_gamma0.3/job0outof2 clusterers/vocab100000_ed1_gamma0.3/job1outof2 69 | ``` 70 | This will save the combined clusterer at ```clusterers/vocab100000_ed1_gamma0.3.pkl```. Finally, run the identical commands as for connected component clusters, but first set ```export CLUSTERER_PATH=clusterers/vocab100000_ed1_gamma0.3.pkl```. Other values of gamma (only needed for SST-2) are loaded from premade saved files (produced by exactly this process) in ```saved_clusterers```. 71 |  72 | ### Internal permutation experiments 73 | Much of the code remains the same for internal permutations. Just use ```--perturb_type intprm``` when constructing the clusters, ```--attack_type intprm``` when running an internal permutation attack, and ```--recoverer clust-intprm``` to use an internal permutation recoverer. 74 | 75 |  -------------------------------------------------------------------------------- /docker/environment.yml: --------------------------------------------------------------------------------  1 | # Copied from the root level environment.yml 2 | name: atenv 3 | channels: 4 | - pytorch 5 | - anaconda 6 | - conda-forge 7 | - soumith 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=main 11 | - _tflow_select=2.1.0=gpu 12 | - absl-py=0.7.0=py36_0 13 | - astor=0.7.1=py36_0 14 | - backcall=0.1.0=py36_0 15 | - blas=1.0=mkl 16 | - bleach=3.1.0=py36_0 17 | - c-ares=1.15.0=h7b6447c_1 18 | - ca-certificates=2020.1.1=0 19 | - certifi=2019.11.28=py36_0 20 | - cffi=1.12.1=py36h2e261b9_0 21 | - cudatoolkit=9.0=h13b8566_0 22 | - cudnn=7.1.2=cuda9.0_0 23 | - cupti=9.0.176=0 24 | - cycler=0.10.0=py_1 25 | - dbus=1.13.6=h746ee38_0 26 | - defusedxml=0.6.0=py_0 27 | - entrypoints=0.3=py36_0 28 | - expat=2.2.6=he6710b0_0 29 | - fontconfig=2.13.1=he4413a7_1000 30 | - freetype=2.9.1=h8a8886c_1 31 | - gast=0.2.2=py36_0 32 | - gettext=0.19.8.1=hc5be6a0_1002 33 | - glib=2.56.2=had28632_1001 34 | - gmp=6.1.2=h6c8ec71_1 35 | - grpcio=1.27.2=py36hf8bcb03_0 36 | - gst-plugins-base=1.14.0=hbbd80ab_1 37 | - gstreamer=1.14.0=hb453b48_1 38 | - h5py=2.9.0=py36h7918eee_0 39 | - hdf5=1.10.4=hb1b8bf9_0 40 | - icu=58.2=hf484d3e_1000 41 | - importlib_metadata=1.5.0=py36_0 42 | - intel-openmp=2019.1=144 43 | - ipykernel=5.1.4=py36h39e3cac_0 44 | - ipython=7.13.0=py36h5ca1d4c_0 45 | - ipython_genutils=0.2.0=py36_0 46 | - ipywidgets=7.5.1=py_0 47 | - jedi=0.16.0=py36_0 48 | - jpeg=9b=h024ee3a_2 49 | - jupyter=1.0.0=py36_7 50 | - jupyter_client=6.0.0=py_0 51 | - jupyter_console=6.1.0=py_0 52 | - jupyter_core=4.6.1=py36_0 53 | - keras-applications=1.0.6=py36_0 54 | - keras-preprocessing=1.0.5=py36_0 55 | - kiwisolver=1.1.0=py36he6710b0_0 56 | - libedit=3.1.20181209=hc058e9b_0 57 | - libffi=3.2.1=hd88cf55_4 58 | - libgcc-ng=8.2.0=hdf63c60_1 59 | - libgfortran-ng=7.3.0=hdf63c60_0 60 | - libiconv=1.15=h516909a_1005 61 | - 
libpng=1.6.36=hbc83047_0 62 | - libprotobuf=3.6.1=hd408876_0 63 | - libsodium=1.0.16=h1bed415_0 64 | - libstdcxx-ng=8.2.0=hdf63c60_1 65 | - libtiff=4.0.10=h2733197_2 66 | - libuuid=2.32.1=h14c3975_1000 67 | - libxcb=1.13=h14c3975_1002 68 | - libxml2=2.9.9=hea5a465_1 69 | - markdown=3.0.1=py36_0 70 | - markupsafe=1.1.1=py36h7b6447c_0 71 | - matplotlib=3.0.3=py36h5429711_0 72 | - mistune=0.8.4=py36h7b6447c_0 73 | - mkl=2019.1=144 74 | - mkl_fft=1.0.10=py36ha843d7b_0 75 | - mkl_random=1.0.2=py36hd81dba3_0 76 | - nbconvert=5.6.1=py36_0 77 | - nbformat=5.0.4=py_0 78 | - ncurses=6.1=he6710b0_1 79 | - ninja=1.8.2=py36h6bb024c_1 80 | - nltk=3.4.1=py36_0 81 | - notebook=6.0.3=py36_0 82 | - numpy=1.16.2=py36h7e9f1db_0 83 | - numpy-base=1.16.2=py36hde5b4d6_0 84 | - olefile=0.46=py36_0 85 | - openssl=1.1.1e=h7b6447c_0 86 | - pandas=0.24.2=py36he6710b0_0 87 | - pandoc=2.2.3.2=0 88 | - pandocfilters=1.4.2=py36_1 89 | - parso=0.6.2=py_0 90 | - pcre=8.43=he6710b0_0 91 | - pexpect=4.8.0=py36_0 92 | - pickleshare=0.7.5=py36_0 93 | - pillow=5.4.1=py36h34e0f95_0 94 | - pip=19.0.3=py36_0 95 | - prometheus_client=0.7.1=py_0 96 | - prompt_toolkit=3.0.3=py_0 97 | - protobuf=3.6.1=py36he6710b0_0 98 | - pthread-stubs=0.4=h14c3975_1001 99 | - ptyprocess=0.6.0=py36_0 100 | - pycparser=2.19=py36_0 101 | - pyparsing=2.4.2=py_0 102 | - pyqt=5.9.2=py36h05f1152_2 103 | - python=3.6.8=h0371630_0 104 | - python-dateutil=2.8.0=py36_0 105 | - pytorch=1.0.1=py3.6_cuda9.0.176_cudnn7.4.2_2 106 | - pytz=2019.2=py_0 107 | - pyzmq=18.1.1=py36he6710b0_0 108 | - qt=5.9.7=h5867ecd_1 109 | - qtconsole=4.7.1=py_0 110 | - qtpy=1.9.0=py_0 111 | - readline=7.0=h7b6447c_5 112 | - scipy=1.2.1=py36h7c811a0_0 113 | - send2trash=1.5.0=py36_0 114 | - setuptools=40.8.0=py36_0 115 | - sip=4.19.8=py36hf484d3e_1000 116 | - six=1.12.0=py36_0 117 | - spacy=2.1.4=py36hc9558a2_0 118 | - sqlite=3.26.0=h7b6447c_0 119 | - tensorboard=1.12.2=py36he6710b0_0 120 | - tensorflow=1.12.0=gpu_py36he68c306_0 121 | - tensorflow-base=1.12.0=gpu_py36h8e0ae2d_0 122 | - tensorflow-gpu=1.12.0=h0d30ee6_0 123 | - termcolor=1.1.0=py36_1 124 | - terminado=0.8.3=py36_0 125 | - testpath=0.4.4=py_0 126 | - tk=8.6.8=hbc83047_0 127 | - torchfile=0.1.0=py_0 128 | - torchvision=0.2.1=py_2 129 | - tornado=6.0.3=py36h7b6447c_0 130 | - tqdm=4.32.1=py_0 131 | - traitlets=4.3.3=py36_0 132 | - webencodings=0.5.1=py36_1 133 | - wheel=0.33.1=py36_0 134 | - widgetsnbextension=3.5.1=py36_0 135 | - xorg-libxau=1.0.9=h14c3975_0 136 | - xorg-libxdmcp=1.1.3=h516909a_0 137 | - xz=5.2.4=h14c3975_4 138 | - zeromq=4.3.1=he6710b0_3 139 | - zlib=1.2.11=h7b6447c_3 140 | - zstd=1.3.7=h0b5b093_0 141 | - pip: 142 | - alabaster==0.7.12 143 | - allennlp==0.8.4 144 | - atomicwrites==1.3.0 145 | - attrs==19.1.0 146 | - awscli==1.16.194 147 | - babel==2.7.0 148 | - blis==0.2.4 149 | - boto3==1.9.184 150 | - botocore==1.12.184 151 | - chardet==3.0.4 152 | - click==7.0 153 | - colorama==0.3.9 154 | - conllu==0.11 155 | - cvxpy==1.0.25 156 | - cymem==2.0.2 157 | - cython==0.29.12 158 | - decorator==4.4.0 159 | - dill==0.3.0 160 | - docutils==0.14 161 | - dynet==2.1 162 | - ecos==2.0.7.post1 163 | - editdistance==0.5.3 164 | # Can't find the following in pypi 165 | # - en-core-web-sm==2.1.0 166 | - flaky==3.6.0 167 | - flask==1.1.1 168 | - flask-cors==3.0.8 169 | - ftfy==5.5.1 170 | - future==0.17.1 171 | - gevent==1.4.0 172 | - greenlet==0.4.15 173 | - idna==2.8 174 | - imagesize==1.1.0 175 | - importlib-metadata==0.18 176 | - itsdangerous==1.1.0 177 | - jinja2==2.10.1 178 | - jmespath==0.9.4 179 | - 
joblib==0.13.2 180 | - jsondiff==1.2.0 181 | - jsonnet==0.13.0 182 | - jsonpickle==1.2 183 | - jsonschema==3.0.1 184 | - more-itertools==7.1.0 185 | - multiprocess==0.70.8 186 | - murmurhash==1.0.2 187 | - networkx==2.3 188 | - numpydoc==0.9.1 189 | - osqp==0.5.0 190 | - overrides==1.9 191 | - packaging==19.0 192 | - parsimonious==0.8.1 193 | - plac==0.9.6 194 | - pluggy==0.12.0 195 | - preshed==2.0.1 196 | - py==1.8.0 197 | - pyasn1==0.4.5 198 | - pygments==2.4.2 199 | - pymaxflow==1.2.12 200 | - pyrsistent==0.15.3 201 | - pytest==5.0.1 202 | - pytorch-pretrained-bert==0.6.2 203 | - pytorch-transformers==1.0.0 204 | - pyyaml==5.1 205 | - query==0.1.4 206 | - regex==2019.6.8 207 | - requests==2.22.0 208 | - responses==0.10.6 209 | - rsa==3.4.2 210 | - s3transfer==0.2.1 211 | - scikit-learn==0.21.2 212 | - scs==2.1.1-2 213 | - sentencepiece==0.1.82 214 | - snowballstemmer==1.9.0 215 | - sphinx==2.1.2 216 | - sphinxcontrib-applehelp==1.0.1 217 | - sphinxcontrib-devhelp==1.0.1 218 | - sphinxcontrib-htmlhelp==1.0.2 219 | - sphinxcontrib-jsmath==1.0.1 220 | - sphinxcontrib-qthelp==1.0.2 221 | - sphinxcontrib-serializinghtml==1.1.3 222 | - sqlparse==0.3.0 223 | - srsly==0.0.7 224 | - tensorboardx==1.8 225 | - torch==0.4.1 226 | - thinc==7.0.4 227 | - unidecode==1.1.1 228 | - urllib3==1.25.3 229 | - wasabi==0.2.2 230 | - wcwidth==0.1.7 231 | - werkzeug==0.15.4 232 | - word2number==1.1 233 | - zipp==0.5.2 234 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: erik-cert 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - soumith 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _tflow_select=2.1.0=gpu 11 | - absl-py=0.7.0=py36_0 12 | - astor=0.7.1=py36_0 13 | - backcall=0.1.0=py36_0 14 | - blas=1.0=mkl 15 | - bleach=3.1.0=py36_0 16 | - c-ares=1.15.0=h7b6447c_1 17 | - ca-certificates=2020.1.1=0 18 | - certifi=2019.11.28=py36_0 19 | - cffi=1.12.1=py36h2e261b9_0 20 | - cudatoolkit=9.0=h13b8566_0 21 | - cudnn=7.1.2=cuda9.0_0 22 | - cupti=9.0.176=0 23 | - cycler=0.10.0=py_1 24 | - dbus=1.13.6=h746ee38_0 25 | - defusedxml=0.6.0=py_0 26 | - entrypoints=0.3=py36_0 27 | - expat=2.2.6=he6710b0_0 28 | - fontconfig=2.13.1=he4413a7_1000 29 | - freetype=2.9.1=h8a8886c_1 30 | - gast=0.2.2=py36_0 31 | - gettext=0.19.8.1=hc5be6a0_1002 32 | - glib=2.56.2=had28632_1001 33 | - gmp=6.1.2=h6c8ec71_1 34 | - grpcio=1.27.2=py36hf8bcb03_0 35 | - gst-plugins-base=1.14.0=hbbd80ab_1 36 | - gstreamer=1.14.0=hb453b48_1 37 | - h5py=2.9.0=py36h7918eee_0 38 | - hdf5=1.10.4=hb1b8bf9_0 39 | - icu=58.2=hf484d3e_1000 40 | - importlib_metadata=1.5.0=py36_0 41 | - intel-openmp=2019.1=144 42 | - ipykernel=5.1.4=py36h39e3cac_0 43 | - ipython=7.13.0=py36h5ca1d4c_0 44 | - ipython_genutils=0.2.0=py36_0 45 | - ipywidgets=7.5.1=py_0 46 | - jedi=0.16.0=py36_0 47 | - jpeg=9b=h024ee3a_2 48 | - jupyter=1.0.0=py36_7 49 | - jupyter_client=6.0.0=py_0 50 | - jupyter_console=6.1.0=py_0 51 | - jupyter_core=4.6.1=py36_0 52 | - keras-applications=1.0.6=py36_0 53 | - keras-preprocessing=1.0.5=py36_0 54 | - kiwisolver=1.1.0=py36he6710b0_0 55 | - libedit=3.1.20181209=hc058e9b_0 56 | - libffi=3.2.1=hd88cf55_4 57 | - libgcc-ng=8.2.0=hdf63c60_1 58 | - libgfortran-ng=7.3.0=hdf63c60_0 59 | - libiconv=1.15=h516909a_1005 60 | - libpng=1.6.36=hbc83047_0 61 | - libprotobuf=3.6.1=hd408876_0 62 | - libsodium=1.0.16=h1bed415_0 63 | - libstdcxx-ng=8.2.0=hdf63c60_1 64 | - libtiff=4.0.10=h2733197_2 65 
| - libuuid=2.32.1=h14c3975_1000 66 | - libxcb=1.13=h14c3975_1002 67 | - libxml2=2.9.9=hea5a465_1 68 | - markdown=3.0.1=py36_0 69 | - markupsafe=1.1.1=py36h7b6447c_0 70 | - matplotlib=3.0.3=py36h5429711_0 71 | - mistune=0.8.4=py36h7b6447c_0 72 | - mkl=2019.1=144 73 | - mkl_fft=1.0.10=py36ha843d7b_0 74 | - mkl_random=1.0.2=py36hd81dba3_0 75 | - nbconvert=5.6.1=py36_0 76 | - nbformat=5.0.4=py_0 77 | - ncurses=6.1=he6710b0_1 78 | - ninja=1.8.2=py36h6bb024c_1 79 | - nltk=3.4.1=py36_0 80 | - notebook=6.0.3=py36_0 81 | - numpy=1.16.2=py36h7e9f1db_0 82 | - numpy-base=1.16.2=py36hde5b4d6_0 83 | - olefile=0.46=py36_0 84 | - openssl=1.1.1e=h7b6447c_0 85 | - pandas=0.24.2=py36he6710b0_0 86 | - pandoc=2.2.3.2=0 87 | - pandocfilters=1.4.2=py36_1 88 | - parso=0.6.2=py_0 89 | - pcre=8.43=he6710b0_0 90 | - pexpect=4.8.0=py36_0 91 | - pickleshare=0.7.5=py36_0 92 | - pillow=5.4.1=py36h34e0f95_0 93 | - pip=19.0.3=py36_0 94 | - prometheus_client=0.7.1=py_0 95 | - prompt_toolkit=3.0.3=py_0 96 | - protobuf=3.6.1=py36he6710b0_0 97 | - pthread-stubs=0.4=h14c3975_1001 98 | - ptyprocess=0.6.0=py36_0 99 | - pycparser=2.19=py36_0 100 | - pyparsing=2.4.2=py_0 101 | - pyqt=5.9.2=py36h05f1152_2 102 | - python=3.6.8=h0371630_0 103 | - python-dateutil=2.8.0=py36_0 104 | - pytorch=1.0.1=py3.6_cuda9.0.176_cudnn7.4.2_2 105 | - pytz=2019.2=py_0 106 | - pyzmq=18.1.1=py36he6710b0_0 107 | - qt=5.9.7=h5867ecd_1 108 | - qtconsole=4.7.1=py_0 109 | - qtpy=1.9.0=py_0 110 | - readline=7.0=h7b6447c_5 111 | - scipy=1.2.1=py36h7c811a0_0 112 | - send2trash=1.5.0=py36_0 113 | - setuptools=40.8.0=py36_0 114 | - sip=4.19.8=py36hf484d3e_1000 115 | - sqlite=3.26.0=h7b6447c_0 116 | - tensorboard=1.12.2=py36he6710b0_0 117 | - tensorflow=1.12.0=gpu_py36he68c306_0 118 | - tensorflow-base=1.12.0=gpu_py36h8e0ae2d_0 119 | - tensorflow-gpu=1.12.0=h0d30ee6_0 120 | - termcolor=1.1.0=py36_1 121 | - terminado=0.8.3=py36_0 122 | - testpath=0.4.4=py_0 123 | - tk=8.6.8=hbc83047_0 124 | - torchfile=0.1.0=py_0 125 | - torchvision=0.2.1=py_2 126 | - tornado=6.0.3=py36h7b6447c_0 127 | - tqdm=4.32.1=py_0 128 | - traitlets=4.3.3=py36_0 129 | - webencodings=0.5.1=py36_1 130 | - wheel=0.33.1=py36_0 131 | - widgetsnbextension=3.5.1=py36_0 132 | - xorg-libxau=1.0.9=h14c3975_0 133 | - xorg-libxdmcp=1.1.3=h516909a_0 134 | - xz=5.2.4=h14c3975_4 135 | - zeromq=4.3.1=he6710b0_3 136 | - zlib=1.2.11=h7b6447c_3 137 | - zstd=1.3.7=h0b5b093_0 138 | - pip: 139 | - alabaster==0.7.12 140 | - allennlp==0.8.4 141 | - argcomplete==1.9.4 142 | - argh==0.26.2 143 | - atomicwrites==1.3.0 144 | - attrs==19.1.0 145 | - awscli==1.16.194 146 | - babel==2.7.0 147 | - blis==0.2.4 148 | - boto3==1.9.184 149 | - botocore==1.12.184 150 | - bottle==0.12.9 151 | - chardet==3.0.4 152 | - click==7.0 153 | - codalab==0.5.13 154 | - colorama==0.3.9 155 | - conllu==0.11 156 | - cvxpy==1.0.25 157 | - cymem==2.0.2 158 | - cython==0.29.12 159 | - decorator==4.4.0 160 | - diffimg==0.2.3 161 | - dill==0.3.0 162 | - docker==3.7.0 163 | - docker-pycreds==0.4.0 164 | - docutils==0.14 165 | - dynet==2.1 166 | - ecos==2.0.7.post1 167 | - editdistance==0.5.3 168 | - flaky==3.6.0 169 | - flask==1.1.1 170 | - flask-cors==3.0.8 171 | - ftfy==5.5.1 172 | - fusepy==2.0.4 173 | - future==0.17.1 174 | - gevent==1.4.0 175 | - greenlet==0.4.15 176 | - idna==2.8 177 | - imagesize==1.1.0 178 | - importlib-metadata==0.18 179 | - itsdangerous==1.1.0 180 | - jinja2==2.10.1 181 | - jmespath==0.9.4 182 | - joblib==0.13.2 183 | - jsondiff==1.2.0 184 | - jsonnet==0.13.0 185 | - jsonpickle==1.2 186 | - jsonschema==3.0.1 187 | - 
marshmallow==2.15.1 188 | - marshmallow-jsonapi==0.15.1 189 | - more-itertools==7.1.0 190 | - multiprocess==0.70.8 191 | - murmurhash==1.0.2 192 | - networkx==2.3 193 | - numpydoc==0.9.1 194 | - osqp==0.5.0 195 | - overrides==1.9 196 | - packaging==19.0 197 | - parsimonious==0.8.1 198 | - pathtools==0.1.2 199 | - plac==0.9.6 200 | - pluggy==0.12.0 201 | - preshed==2.0.1 202 | - psutil==5.6.6 203 | - py==1.8.0 204 | - pyasn1==0.4.5 205 | - pygments==2.4.2 206 | - pymaxflow==1.2.12 207 | - pyrsistent==0.15.3 208 | - pytest==5.0.1 209 | - pytorch-pretrained-bert==0.6.2 210 | - pytorch-transformers==1.0.0 211 | - pyyaml==5.1 212 | - regex==2019.6.8 213 | - requests==2.22.0 214 | - responses==0.10.6 215 | - rsa==3.4.2 216 | - s3transfer==0.2.1 217 | - scikit-learn==0.21.2 218 | - scs==2.1.1-2 219 | - selenium==3.141.0 220 | - sentencepiece==0.1.82 221 | - six==1.11.0 222 | - snowballstemmer==1.9.0 223 | - spacy==2.1.4 224 | - sphinx==2.1.2 225 | - sphinxcontrib-applehelp==1.0.1 226 | - sphinxcontrib-devhelp==1.0.1 227 | - sphinxcontrib-htmlhelp==1.0.2 228 | - sphinxcontrib-jsmath==1.0.1 229 | - sphinxcontrib-qthelp==1.0.2 230 | - sphinxcontrib-serializinghtml==1.1.3 231 | - sqlalchemy==1.3.0 232 | - sqlparse==0.3.0 233 | - srsly==0.0.7 234 | - tensorboardx==1.8 235 | - thinc==7.0.4 236 | - unidecode==1.1.1 237 | - urllib3==1.25.3 238 | - wasabi==0.2.2 239 | - watchdog==0.8.3 240 | - wcwidth==0.1.7 241 | - websocket-client==0.57.0 242 | - werkzeug==0.15.4 243 | - word2number==1.1 244 | - zipp==0.5.2 245 | prefix: /sailhome/erjones/.conda/envs/erik-cert 246 | 247 | -------------------------------------------------------------------------------- /word_embedding_model_runners.py: -------------------------------------------------------------------------------- 1 | class TransformerRunner(ModelRunner): 2 | def __init__(self, recoverer, output_mode, label_list, output_dir, device, task_name, 3 | model_type, model_name_or_path, do_lower_case, max_seq_length): 4 | super(TransformerRunner, self).__init__(recoverer, output_mode, label_list, output_dir, device) 5 | self.task_name = task_name 6 | self.model_type = model_type 7 | self.do_lower_case = do_lower_case 8 | self.max_seq_length = max_seq_length 9 | config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type] 10 | self.model_class = model_class 11 | self.tokenizer_class = tokenizer_class 12 | config = config_class.from_pretrained(model_name_or_path, num_labels=len(label_list), 13 | finetuning_task=task_name) 14 | self.tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=do_lower_case) 15 | self.model = model_class.from_pretrained( 16 | model_name_or_path, from_tf=bool('.ckpt' in model_name_or_path), config=config) 17 | self.model.to(device) 18 | 19 | def _prep_examples(self, examples, verbose=False): 20 | features = convert_examples_to_features( 21 | examples, self.label_list, self.max_seq_length, self.tokenizer, self.output_mode, 22 | cls_token_at_end=bool(self.model_type in ['xlnet']), # xlnet has a cls token at the end 23 | cls_token=self.tokenizer.cls_token, 24 | sep_token=self.tokenizer.sep_token, 25 | cls_token_segment_id=2 if self.model_type in ['xlnet'] else 0, 26 | pad_on_left=bool(self.model_type in ['xlnet']), # pad on the left for xlnet 27 | pad_token_segment_id=4 if self.model_type in ['xlnet'] else 0, 28 | verbose=verbose) 29 | 30 | # Convert to Tensors and build dataset 31 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 32 | all_input_mask = 
torch.tensor([f.input_mask for f in features], dtype=torch.long) 33 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 34 | if self.output_mode == "classification": 35 | all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long) 36 | elif self.output_mode == "regression": 37 | all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float) 38 | all_text_ids = torch.tensor([f.example_idx for f in features], dtype = torch.long) 39 | 40 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_text_ids) 41 | return dataset 42 | 43 | def train(self, train_data, args): 44 | print("Preparing examples.") 45 | train_dataset = self._prep_examples(train_data, verbose=args.verbose) 46 | print("Starting training.") 47 | global_step, tr_loss, train_results = train(args, train_dataset, self.model, self.tokenizer) 48 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 49 | 50 | # Save a trained model, configuration and tokenizer using `save_pretrained()`. 51 | # They can then be reloaded using `from_pretrained()` 52 | logger.info("Saving model checkpoint to %s", self.output_dir) 53 | model_to_save = self.model.module if hasattr(self.model, 'module') else self.model # Take care of distributed/parallel training 54 | model_to_save.save_pretrained(self.output_dir) 55 | self.tokenizer.save_pretrained(self.output_dir) 56 | torch.save(args, os.path.join(self.output_dir, 'training_args.bin')) 57 | 58 | # Reload model 59 | self.load(self.output_dir, self.device) 60 | print("Finished training.") 61 | 62 | def load(self, output_dir, device): 63 | self.model = self.model_class.from_pretrained(output_dir) 64 | self.tokenizer = self.tokenizer_class.from_pretrained(output_dir) 65 | self.model.to(self.device) 66 | 67 | def query(self, examples, batch_size, do_evaluate=True, return_logits=False, 68 | do_recover=True, use_tqdm=True): 69 | if do_recover: 70 | examples = [self.recoverer.recover_example(x) for x in examples] 71 | dataset = self._prep_examples(examples) 72 | eval_sampler = SequentialSampler(dataset) # Makes sure order is correct 73 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=batch_size) 74 | 75 | # Eval!
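# The loop below queries the model batch by batch: the SequentialSampler keeps the
# accumulated outputs aligned with the input order, each forward pass runs under
# torch.no_grad(), and logits, gold label ids, and example indices are concatenated
# into numpy arrays before the argmax (classification) or squeeze (regression) step.
# Illustrative calls, with hypothetical example objects:
#   labels = runner.query(dev_examples, batch_size=32)                       # list of label strings
#   logits = runner.query(dev_examples, batch_size=32, return_logits=True)   # numpy array of logits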
76 | logger.info("***** Querying model *****") 77 | logger.info(" Num examples = %d", len(examples)) 78 | logger.info(" Batch size = %d", batch_size) 79 | eval_loss = 0.0 80 | nb_eval_steps = 0 81 | preds = None 82 | out_label_ids = None 83 | example_idxs = None 84 | self.model.eval() 85 | if use_tqdm: 86 | eval_dataloader = tqdm(eval_dataloader, desc="Querying") 87 | for batch in eval_dataloader: 88 | batch = tuple(t.to(self.device) for t in batch) 89 | 90 | with torch.no_grad(): 91 | inputs = {'input_ids': batch[0], 92 | 'attention_mask': batch[1], 93 | 'token_type_ids': batch[2] if self.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids 94 | 'labels': batch[3]} 95 | outputs = self.model(**inputs) 96 | inputs['example_idxs'] = batch[4] 97 | tmp_eval_loss, logits = outputs[:2] 98 | 99 | eval_loss += tmp_eval_loss.mean().item() 100 | 101 | nb_eval_steps += 1 102 | if preds is None: 103 | preds = logits.detach().cpu().numpy() 104 | out_label_ids = inputs['labels'].detach().cpu().numpy() 105 | example_idxs = inputs['example_idxs'].detach().cpu().numpy() 106 | else: 107 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 108 | out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0) 109 | example_idxs = np.append(example_idxs, inputs['example_idxs'].detach().cpu().numpy(), axis = 0) 110 | 111 | eval_loss = eval_loss / nb_eval_steps 112 | logger.info(' eval_loss = %.6f', eval_loss) 113 | incorrect_example_indices = None 114 | if self.output_mode == "classification": 115 | pred_argmax = np.argmax(preds, axis=1) 116 | pred_labels = [self.label_list[pred_argmax[i]] for i in range(len(examples))] 117 | incorrect_example_indices = set(example_idxs[np.not_equal(pred_argmax, out_label_ids)]) 118 | 119 | elif self.output_mode == "regression": 120 | preds = np.squeeze(preds) 121 | 122 | if do_evaluate: 123 | result = compute_metrics(self.task_name, pred_argmax, out_label_ids) 124 | output_eval_file = os.path.join(self.output_dir, "eval-{}.txt".format(self.task_name)) 125 | #print("Possible predictions: ", set(list(preds))) 126 | #priny("Model predictions: mean: {}, max: {}, min: {}".format(preds.mean(), preds.max(), preds.min())) 127 | with open(output_eval_file, "w") as writer: 128 | logger.info("***** Eval results *****") 129 | for key in sorted(result.keys()): 130 | logger.info(" %s = %s", key, str(result[key])) 131 | writer.write("%s = %s\n" % (key, str(result[key]))) 132 | 133 | if return_logits: 134 | return preds 135 | else: 136 | return pred_labels 137 | -------------------------------------------------------------------------------- /recoverer.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import itertools 3 | import json 4 | import numpy as np 5 | import os 6 | import pickle 7 | import random 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from scRNN.corrector import ScRNNChecker 12 | from utils import OOV_CLUSTER, OOV_TOKEN 13 | from utils_glue import InputExample 14 | from edit_dist_utils import get_sorted_word 15 | 16 | class Recoverer(object): 17 | """Clean up a possibly typo-ed string.""" 18 | def __init__(self, cache_dir): 19 | self.cache_dir = cache_dir 20 | self.cache = {} 21 | self.name = None # Subclasses should set this 22 | 23 | def _cache_path(self): 24 | return os.path.join(self.cache_dir, 'recoveryCache-{}.json'.format(self.name)) 25 | 26 | def load_cache(self): 27 | path = self._cache_path() 28 | if os.path.exists(path): 29 | with 
open(self._cache_path()) as f: 30 | self.cache = json.load(f) 31 | print('Recoverer: loaded {} values from cache.'.format(len(self.cache))) 32 | else: 33 | print('Recoverer: no cache at {}.'.format(path)) 34 | 35 | def save_cache(self, save = False): 36 | if save: 37 | cache_path = self._cache_path() 38 | print('Recoverer: saving {} cached values to {} .'.format(len(self.cache), cache_path)) 39 | with open(cache_path, 'w') as f: 40 | json.dump(self.cache, f) 41 | 42 | 43 | def recover(self, text): 44 | """Recover |text| to a new string. 45 | 46 | Used at test time to preprocess possibly typo-ed input. 47 | """ 48 | if text in self.cache: 49 | return self.cache[text] 50 | recovered = self._recover(text) 51 | self.cache[text] = recovered 52 | return recovered 53 | 54 | def _recover(self, text): 55 | """Actually do the recovery for self.recover().""" 56 | raise NotImplementedError 57 | 58 | def get_possible_recoveries(self, text, attack_surface, max_num, analyze_res_attacks = False, ret_ball_stats = False): 59 | """For a clean string, return list of possible recovered strings, or None if too many. 60 | 61 | Used at certification time to exactly compute robust accuracy. 62 | 63 | Returns tuple (list_of_possibilities, num_possibilities) 64 | where list_of_possibilities is None if num_possibilities > max_num. 65 | """ 66 | pass 67 | 68 | def recover_example(self, example): 69 | """Recover an InputExample |example| to a new InputExample. 70 | 71 | Used at test time to preprocess possibly typo-ed input. 72 | """ 73 | tokens = example.text_a.split() 74 | a_len = len(tokens) 75 | if example.text_b: 76 | tokens.extend(example.text_b.split()) 77 | recovered_tokens = self.recover(' '.join(tokens)).split() 78 | a_new = ' '.join(recovered_tokens[:a_len]) 79 | if example.text_b: 80 | b_new = ' '.join(recovered_tokens[a_len:]) 81 | else: 82 | b_new = None 83 | return InputExample(example.guid, a_new, b_new, example.label) 84 | 85 | def get_possible_examples(self, example, attack_surface, max_num, analyze_res_attacks = False): 86 | """For a clean InputExample, return list of InputExample's you could recover to. 87 | 88 | Used at certification time to exactly compute robust accuracy. 
89 | """ 90 | tokens = example.text_a.split() 91 | a_len = len(tokens) 92 | if example.text_b: 93 | tokens.extend(example.text_b.split()) 94 | possibilities, num_poss, perturb_counts = self.get_possible_recoveries(' '.join(tokens), attack_surface, max_num, 95 | analyze_res_attacks = analyze_res_attacks) 96 | if perturb_counts is not None: 97 | assert len(perturb_counts) == len(possibilities) 98 | if not possibilities: 99 | return (None, num_poss) 100 | out = [] 101 | example_num = 0 102 | for i in range(len(possibilities)): 103 | poss = possibilities[i] 104 | poss_tokens = poss.split() 105 | a = ' '.join(poss_tokens[:a_len]) 106 | if example.text_b: 107 | b = ' '.join(poss_tokens[a_len:]) 108 | else: 109 | b = None 110 | if not analyze_res_attacks: 111 | poss_guid = '{}-{}'.format(example.guid, example_num) 112 | else: 113 | poss_guid = '{}-{}-{}'.format(example.guid, example_num, perturb_counts[i]) 114 | out.append(InputExample('{}-{}'.format(poss_guid, example_num), a, b, example.label)) 115 | example_num += 1 116 | return (out, len(out)) 117 | 118 | 119 | class IdentityRecoverer(Recoverer): 120 | def __init__(self, cache_dir): 121 | super(IdentityRecoverer, self).__init__(cache_dir) 122 | self.name = 'IdentityRecoverer' 123 | 124 | def recover(self, text): 125 | """Override self.recover() rather than self._recover() to avoid cache.""" 126 | return text 127 | 128 | class ClusterRecoverer(Recoverer): 129 | def __init__(self, cache_dir, clustering): 130 | super(ClusterRecoverer, self).__init__(cache_dir) 131 | self.clustering = clustering 132 | self.passthrough = False 133 | 134 | def get_possible_recoveries(self, text, attack_surface, max_num, analyze_res_attacks = False, ret_ball_stats = False): 135 | tokens = text.split() 136 | possibilities = [] 137 | perturb_counts = [] 138 | standard_clustering = np.array([self.clustering.map_token(token) for token in tokens]) 139 | for token in tokens: 140 | cur_perturb = attack_surface.get_perturbations(token) 141 | perturb_counts.append(len(cur_perturb)) 142 | poss_clusters = set() 143 | for pert in cur_perturb: 144 | clust_id = self.clustering.map_token(pert) 145 | poss_clusters.add(clust_id) 146 | possibilities.append(sorted(poss_clusters, key=str)) # sort for deterministic order 147 | if ret_ball_stats: 148 | return [len(pos_clusters) for pos_clusters in possibilities], perturb_counts 149 | num_pos = reduce(lambda x, y: x * y, [len(x) for x in possibilities]) 150 | if num_pos > max_num: 151 | return (None, num_pos, None) 152 | poss_recoveries = [] 153 | perturb_counts = None 154 | if analyze_res_attacks: 155 | perturb_counts = [] 156 | num_zero = 0 157 | for clust_seq in itertools.product(*possibilities): 158 | if analyze_res_attacks: 159 | #print("Stand: ", standard_clustering) 160 | #print("Seq: ", clust_seq) 161 | #print("Lengths: {}, {}".format(len(standard_clustering), len(clust_seq))) 162 | #print("Types: {}, {}".format(type(np.array(clust_seq)[0]), type(standard_clustering[0]))) 163 | #print("Comparison: ", np.array(clust_seq) != standard_clustering) 164 | #print("Inv comparison: ", np.array(clust_seq) == standard_clustering) 165 | num_different = (np.array(clust_seq) != standard_clustering).sum() 166 | if num_different == 0: 167 | num_zero += 1 168 | #print(num_different) 169 | perturb_counts.append(num_different) 170 | poss_recoveries.append(self._recover_from_clusters(clust_seq)) 171 | assert num_zero == 1 or not analyze_res_attacks 172 | return (poss_recoveries, len(poss_recoveries), perturb_counts) 173 | 174 | def _recover(self, 
text): 175 | tokens = text.split() 176 | clust_ids = [self.clustering.map_token(w, passthrough = self.passthrough) for w in tokens] 177 | return self._recover_from_clusters(clust_ids) 178 | 179 | def _recover_from_clusters(self, clust_ids): 180 | raise NotImplementedError 181 | 182 | 183 | class ClusterRepRecoverer(ClusterRecoverer): 184 | def _recover_from_clusters(self, clust_ids): 185 | tokens = [] 186 | for c in clust_ids: 187 | if c == OOV_CLUSTER: 188 | tokens.append('[MASK]') 189 | else: 190 | tokens.append(self.clustering.get_rep(c)) 191 | return ' '.join(tokens) 192 | 193 | class ClusterIntprmRecoverer(ClusterRepRecoverer): 194 | def get_possible_recoveries(self, text, attack_surface, max_num, analyze_res_attacks = False, ret_ball_stats = False): 195 | if analyze_res_attacks: 196 | raise NotImplementedError 197 | tokens = text.split() 198 | clusters = [] 199 | for token in tokens: 200 | token_key = get_sorted_word(token) #Adversary can't modify sorted word, since attack is internal perturbations 201 | clust_id = self.clustering.map_token(token_key) 202 | clusters.append(clust_id) 203 | recovery = self._recover_from_clusters(clusters) 204 | return ([recovery], 1, None) #One possibility, and no perturb_counts since we don't analyze resticted attack yet. 205 | 206 | def _recover(self, text): 207 | tokens = text.split() 208 | clust_ids = [self.clustering.map_token(get_sorted_word(w), passthrough = False) for w in tokens] 209 | return self._recover_from_clusters(clust_ids) 210 | 211 | class ScRNNRecoverer(Recoverer): 212 | def __init__(self, cache_dir, tc_dir, task_name): 213 | super(ScRNNRecoverer, self).__init__(cache_dir) 214 | self.checker = ScRNNChecker(tc_dir, unk_output=True, task_name=task_name) 215 | 216 | def _recover(self, text): 217 | return self.checker.correct_string(text) 218 | 219 | 220 | 221 | RECOVERERS = { 222 | 'identity': IdentityRecoverer, 223 | 'clust-rep': ClusterRepRecoverer, 224 | 'clust-intprm': ClusterIntprmRecoverer, 225 | 'scrnn': ScRNNRecoverer, 226 | } 227 | -------------------------------------------------------------------------------- /construct_clusters.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from scipy.sparse import csr_matrix 4 | from scipy.sparse.csgraph import connected_components 5 | import os 6 | import json 7 | from collections import defaultdict 8 | import string 9 | import argparse 10 | 11 | import matplotlib as mpl 12 | import matplotlib.pyplot as plt 13 | import networkx as nx 14 | 15 | from edit_dist_utils import get_all_edit_dist_one, get_sorted_word 16 | from preprocess_vocab import preprocess_neighbors, preprocess_neighbors_intprm 17 | from utils import pkl_load, pkl_save 18 | 19 | 20 | RELATIVE_GLOVE_PATH = 'glove/glove.6B.50d.txt' 21 | RELATIVE_COCA_PATH = 'COCA/coca-1grams.json' 22 | 23 | #Read in frequencies from COCA 24 | def read_coca_freq(coca_path, sort = True): 25 | with open(coca_path, 'r', encoding="ISO-8859-1") as f: 26 | coca_freq_dict = json.load(f) 27 | frequencies = [(elem.split('_')[0], int(coca_freq_dict[elem])) for elem in coca_freq_dict] 28 | if sort: 29 | frequencies.sort(key = lambda x: x[1], reverse = True) 30 | frequencies = process_duplicates(frequencies) 31 | return frequencies 32 | 33 | def process_duplicates(frequencies): 34 | frequency_dict = {} 35 | duplicates = set() 36 | for elem, freq in frequencies: 37 | if elem in frequency_dict: 38 | duplicates.add(elem) 39 | frequency_dict[elem] += int(freq) 40 | else: 
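# First occurrence of this token: record its count; duplicate entries for the same
# surface form (e.g., the same word under different COCA tags) are summed into this
# entry by the branch above.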
41 | frequency_dict[elem] = int(freq) 42 | frequencies = [(elem, frequency_dict[elem]) for elem in frequency_dict] 43 | frequencies.sort(key = lambda x: x[1], reverse = True) 44 | return frequencies 45 | 46 | def get_glove_vocab(glove_path, num_lines = 400000): 47 | print("Reading GloVe vectors from {}...".format(glove_path)) 48 | embedding_map = {} 49 | glove_vocab = set() 50 | with open(glove_path, encoding = 'utf-8', mode = "r") as f: 51 | for i, line in tqdm(enumerate(f), total=num_lines): 52 | toks = line.strip().split(' ') 53 | word = toks[0] 54 | glove_vocab.add(word) 55 | return glove_vocab 56 | 57 | 58 | class Clusterer(): 59 | def __init__(self, vocab_size = 100000, perturb_type = 'ed1'): 60 | #max_num_ret should be greater than number of initial vertices in the graph 61 | print("Getting word frequencies...") 62 | self.num_verts = vocab_size 63 | self.frequencies = read_coca_freq(os.path.join(args.data_dir, RELATIVE_COCA_PATH)) 64 | self.word2freq = self._get_word2freq(self.frequencies) 65 | 66 | #verify frequencies are sorted 67 | assert self._is_sorted([float(f[1]) for f in self.frequencies], reverse = True) 68 | self.vertices = [] 69 | self.edges = [] 70 | 71 | #constraining vocab to be in glove 72 | glove_vocab = get_glove_vocab(os.path.join(args.data_dir, RELATIVE_GLOVE_PATH)) 73 | self.frequencies = [elem for elem in self.frequencies if elem[0] in glove_vocab] 74 | 75 | self.word2cluster = {} 76 | self.cluster2representative = {} 77 | self.clusters = {} 78 | self.perturb_type = perturb_type 79 | 80 | def _is_sorted(self, lst, reverse = False): 81 | start, end, incr = 0, len(lst), 1 82 | if reverse: 83 | start, end, incr = len(lst) - 1, -1, -1 84 | prev_elem = float('-inf') 85 | for i in range(start, end, incr): 86 | elem = lst[i] 87 | if elem < prev_elem: 88 | return False 89 | prev_elem = elem 90 | return True 91 | 92 | def _get_word2freq(self, frequencies): 93 | word2freq = defaultdict(lambda: 0) 94 | for word, freq in frequencies: 95 | word2freq[word] += int(freq) 96 | return word2freq 97 | 98 | def construct_graph(self, perturb_type = 'ed1'): 99 | """ 100 | Form a graph with nodes given by vocabulary in the constructor 101 | and edges between words that share a perturbation (using perturb_type) 102 | """ 103 | 104 | self.perturb_type = perturb_type 105 | #Constrain vertex set. 
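# (i.e., keep only the num_verts most frequent COCA words; filtering against the GloVe
# vocabulary already happened in __init__). The graph built below joins two vocabulary
# words with an edge whenever they share at least one perturbation of the chosen
# perturb_type: under 'ed1', for example, "dog" and "dig" both yield the deletion "dg",
# so they end up in the same connected component when construct_clusters() is called,
# and the most frequent member of each component becomes its representative.
# Illustrative check (hypothetical words):
#   assert clusterer.word2cluster['dog'] == clusterer.word2cluster['dig']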
106 | self.vertices = list([self.frequencies[i][0] for i in range(self.num_verts)]) 107 | self.word2freq = dict([(vtx, self.word2freq[vtx]) for vtx in self.vertices]) 108 | 109 | 110 | if perturb_type == 'ed1': 111 | typo2words, neighbor_map, _ = preprocess_neighbors(self.vertices) 112 | elif perturb_type == 'intprm': 113 | typo2words, neighbor_map, _ = preprocess_neighbors_intprm(self.vertices) 114 | else: 115 | raise ValueError("Unsupported perturbation type") 116 | self.typo2words = typo2words 117 | 118 | print("Computing edges...") 119 | self.edges = self._filter_edges(neighbor_map) 120 | print("Generating edge matrix...") 121 | self.edge_mat = self._edges_to_matrix(self.vertices, self.edges) 122 | print("Finished constructing the graph") 123 | 124 | def _filter_edges(self, neighbor_map): 125 | """ 126 | neighbor_map 127 | """ 128 | possible_edges = set() 129 | rejected_edges = set() 130 | similarities = [] 131 | for vtx in neighbor_map: 132 | for vtx2 in neighbor_map[vtx]: 133 | if vtx == vtx2: 134 | continue 135 | vtx_pair = [vtx, vtx2] 136 | vtx_pair.sort() #Graph is undirected, assume attack surface is symmetric 137 | vtx_pair = tuple(vtx_pair) 138 | if vtx_pair in possible_edges or vtx_pair in rejected_edges: 139 | continue 140 | else: 141 | possible_edges.add(vtx_pair) 142 | return possible_edges 143 | 144 | def _edges_to_matrix(self, vertices, edges): 145 | #exclusive_edges = [edge for edge in edges if edge[0] != edge[1]] 146 | #print("Num vertices: {}, edges: {}".format(vertices, exclusive_edges)) 147 | edge_mat = np.zeros(shape = (len(vertices), len(vertices)), dtype = bool) 148 | vert2idx = dict([(vertices[i], i) for i in range(len(vertices))]) 149 | #for edge in tqdm(edges): 150 | for edge in edges: 151 | vert1, vert2 = edge 152 | if vert1 == vert2: 153 | continue 154 | id1, id2 = vert2idx[vert1], vert2idx[vert2] 155 | edge_mat[id1][id2] = 1 156 | edge_mat[id2][id1] = 1 157 | return edge_mat 158 | 159 | 160 | def construct_clusters(self): 161 | if self.edge_mat is None: 162 | raise ValueError("Graph must already be computed...") 163 | self.cluster2elements = defaultdict(list) 164 | graph = csr_matrix(self.edge_mat) 165 | n_components, labels = connected_components(csgraph=graph, directed=False, return_labels=True) 166 | self.num_clusters = n_components 167 | for i in range(labels.shape[0]): 168 | label = labels[i] 169 | if label not in self.cluster2representative: 170 | representative = self.vertices[i] #Assumes vertices in sorted order of freq 171 | #print("Setting {} to representative".format(self.vertices[i])) 172 | self.cluster2representative[label] = representative 173 | self.clusters[label] = [] 174 | #print("Adding {} to {}'s cluster".format(self.vertices[i], self.cluster2representative[label])) 175 | self.word2cluster[self.vertices[i]] = label 176 | self.clusters[label].append(self.vertices[i]) 177 | self.typo2cluster = self._get_typo2cluster() 178 | 179 | def _get_typo2cluster(self): 180 | typo2cluster = {} 181 | typo2word = self._get_typo2word(self.vertices, self.word2freq) 182 | for typo in tqdm(typo2word): 183 | typo2cluster[typo] = self.word2cluster[typo2word[typo]] 184 | return typo2cluster 185 | 186 | def _get_typo2word(self, words, word2freq): 187 | typo2word = {} 188 | print("Getting typo2word") 189 | for typo in tqdm(self.typo2words): 190 | possible_words = self.typo2words[typo] 191 | typo_word_freq_list = [(word, word2freq[word]) for word in possible_words] 192 | typo_word_freq_list.sort(key = lambda x: x[1], reverse = True) 193 | most_frequent_word = 
typo_word_freq_list[0][0] 194 | typo2word[typo] = most_frequent_word 195 | 196 | #Word always recovers to it's own cluster 197 | for word in words: 198 | typo2word[word] = word 199 | 200 | return typo2word 201 | 202 | def save_clusterer(vocab_size = 100000, perturb_type = 'ed1', save_dir = 'clusterers', 203 | check_perturb_size = False): 204 | 205 | filename = 'vocab{}_{}.pkl'.format(vocab_size, perturb_type) 206 | if not os.path.isdir(save_dir): 207 | os.makedirs(save_dir) 208 | save_path = os.path.join(save_dir, filename) 209 | print("Will save at: {}".format(save_path)) 210 | 211 | #Initializing clusterer 212 | clusterer = Clusterer(vocab_size = vocab_size) 213 | #Initializing the graph 214 | clusterer.construct_graph(perturb_type = perturb_type) 215 | #Creating clusters. 216 | clusterer.construct_clusters() 217 | 218 | #Option to analyze number of perturbations, etc. 219 | #if check_perturb_size: 220 | # get_vocab_statistics(clusterer.vertices) 221 | # return 222 | save_dict = {'cluster': clusterer.clusters, 223 | 'word2cluster': clusterer.word2cluster, 224 | 'cluster2representative': clusterer.cluster2representative, 225 | 'word2freq': clusterer.word2freq, 226 | 'typo2cluster': clusterer.typo2cluster} 227 | 228 | print("Saving everything at: ", save_path) 229 | pkl_save(save_dict, save_path) 230 | print("Number of clusters: {}, vocab size: {}".format(len(clusterer.clusters), vocab_size)) 231 | 232 | def get_vocab_statistics(vertices): 233 | #vertices correspond to words in the vocabulary. 234 | #prints stats on the number of perturbations each word has. 235 | num_perturbations = [] 236 | print("Total number of vertices: ", len(vertices)) 237 | for vtx in tqdm(vertices): 238 | num_perturbations.append(len(get_all_edit_dist_one(vtx))) 239 | num_perturbations = np.array(num_perturbations) 240 | print("Mean: {} Min: {} Max: {}".format(num_perturbations.mean(), num_perturbations.min(), num_perturbations.max())) 241 | 242 | 243 | def parse_args(): 244 | parser = argparse.ArgumentParser() 245 | parser.add_argument("--data_dir", default=None, type=str, required=False, 246 | help="The input data dir.") 247 | parser.add_argument("--output_dir", default=None, type=str, required=False, 248 | help="The output dir for the clusterer.") 249 | parser.add_argument('--vocab_size', type = int, default = 100000, 250 | help = 'Size of the vocabulary used to make the clusters.') 251 | parser.add_argument('--perturb_type', choices=['ed1', 'intprm'], type = str, 252 | help = 'type of perturbation used to define clusters') 253 | return parser.parse_args() 254 | 255 | if __name__ == '__main__': 256 | print("Starting the run...") 257 | args = parse_args() 258 | save_clusterer(vocab_size = args.vocab_size, perturb_type = args.perturb_type, save_dir=args.output_dir) 259 | 260 | 261 | 262 | -------------------------------------------------------------------------------- /scRNN/utils.py: -------------------------------------------------------------------------------- 1 | """ helper functions for 2 | - data loading 3 | - representation building 4 | - vocabulary loading 5 | """ 6 | 7 | from collections import defaultdict 8 | import numpy as np 9 | import pandas as pd 10 | import pickle 11 | import os 12 | import random 13 | from random import shuffle 14 | 15 | CHAR_VOCAB = [] 16 | CHAR_VOCAB_BG = [] 17 | w2i = defaultdict(lambda: 0.0) 18 | w2i_bg = defaultdict(lambda: 0.0) 19 | i2w = defaultdict(lambda: "UNK") 20 | i2w_bg = defaultdict(lambda: "UNK") 21 | 22 | #TODO: think of an open vocabulary system 23 | WORD_LIMIT 
= 9999 # remaining 1 for (this is inclusive of UNK) 24 | task_name = "" 25 | TARGET_PAD_IDX = -1 26 | INPUT_PAD_IDX = 0 27 | 28 | keyboard_mappings = None 29 | 30 | def set_word_limit(word_limit, task=""): 31 | global WORD_LIMIT 32 | global task_name 33 | WORD_LIMIT = word_limit 34 | task_name = task 35 | 36 | 37 | def get_lines(filename, glue = False): 38 | if glue: 39 | return get_lines_glue(filename) #Hacky but less destructive to existing code... 40 | f = open(filename) 41 | lines = f.readlines() 42 | if "|||" in lines[0]: 43 | # remove the tag 44 | clean_lines = [line.split("|||")[1].strip().lower() for line in lines] 45 | else: 46 | clean_lines = [line.strip().lower() for line in lines] 47 | return clean_lines 48 | 49 | def get_lines_glue(glue_filename): 50 | data = pd.read_csv(glue_filename, sep = '\t') 51 | cols = list(data) 52 | lines = data['text_a'].tolist() 53 | if 'text_b' in cols: 54 | lines.extend(data['text_b'].tolist()) 55 | lines = [line.strip().lower() for line in lines if not isinstance(line, float)] 56 | return lines 57 | 58 | 59 | 60 | def create_vocab(filename, tc_dir, background_train=False, cv_path="", glue = False, char_limit = 70): 61 | global w2i, i2w, CHAR_VOCAB 62 | char2count = defaultdict(lambda: 0) 63 | lines = get_lines(filename, glue = glue) 64 | for line in lines: 65 | for word in line.split(): 66 | 67 | # add all its char in vocab 68 | for char in word: 69 | char2count[char] += 1 70 | if char not in CHAR_VOCAB: 71 | CHAR_VOCAB.append(char) 72 | 73 | w2i[word] += 1.0 74 | 75 | if background_train: 76 | CHAR_VOCAB = pickle.load(open(cv_path, 'rb')) 77 | word_list = sorted(w2i.items(), key=lambda x:x[1], reverse=True) 78 | word_list = word_list[:WORD_LIMIT] # only need top few words 79 | if char_limit is not None: 80 | char_freq_list = [(char, char2count[char]) for char in char2count] 81 | char_freq_list.sort(key = lambda x: x[1], reverse = True) 82 | CHAR_VOCAB = [char for (char, freq) in char_freq_list[:char_limit]] 83 | 84 | 85 | # remaining words are UNKs ... sorry! 
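# Index layout used below: id 0 is the input pad token (''), ids 1 .. WORD_LIMIT-1 hold
# the most frequent words from the training file, and the defaultdict default WORD_LIMIT
# is the UNK id, which is why the ScRNN in train.py is built with WORD_LIMIT + 1 output
# classes. Illustrative behaviour: a frequent word like 'the' maps to a small positive
# id, while any unseen word maps to WORD_LIMIT.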
86 | w2i = defaultdict(lambda: WORD_LIMIT) # default id is UNK ID 87 | w2i[''] = INPUT_PAD_IDX # INPUT_PAD_IDX is 0 88 | i2w[INPUT_PAD_IDX] = '' 89 | for idx in range(WORD_LIMIT-1): 90 | w2i[word_list[idx][0]] = idx+1 91 | i2w[idx+1] = word_list[idx][0] 92 | 93 | if not os.path.exists(os.path.join(tc_dir, 'vocab')): 94 | os.makedirs(os.path.join(tc_dir, 'vocab')) 95 | pickle.dump(dict(w2i), open(tc_dir + "/vocab/" + task_name + "w2i_" + str(WORD_LIMIT) + ".p", 'wb')) 96 | pickle.dump(dict(i2w), open(tc_dir + "/vocab/" + task_name + "i2w_" + str(WORD_LIMIT) + ".p", 'wb')) # don't think its needed 97 | pickle.dump(CHAR_VOCAB, open(tc_dir + "/vocab/" + task_name + "CHAR_VOCAB_ " + str(WORD_LIMIT) + ".p", 'wb')) 98 | return 99 | 100 | 101 | def load_vocab_dicts(wi_path, iw_path, cv_path, use_background=False): 102 | wi = pickle.load(open(wi_path, 'rb')) 103 | iw = pickle.load(open(iw_path, 'rb')) 104 | cv = pickle.load(open(cv_path, 'rb')) 105 | if use_background: 106 | convert_vocab_dicts_bg(wi, iw, cv) 107 | else: 108 | convert_vocab_dicts(wi, iw, cv) 109 | return wi, iw, cv 110 | 111 | """ converts vocabulary dictionaries into defaultdicts 112 | """ 113 | def convert_vocab_dicts(wi, iw, cv): 114 | global w2i, i2w, CHAR_VOCAB 115 | CHAR_VOCAB = cv 116 | w2i = defaultdict(lambda: WORD_LIMIT) 117 | for w in wi: 118 | w2i[w] = wi[w] 119 | 120 | for i in iw: 121 | i2w[i] = iw[i] 122 | return 123 | 124 | def convert_vocab_dicts_bg(wi, iw, cv): 125 | global w2i_bg, i2w_bg, CHAR_VOCAB_BG 126 | CHAR_VOCAB_BG = cv 127 | w2i_bg = defaultdict(lambda: WORD_LIMIT) 128 | for w in wi: 129 | w2i_bg[w] = wi[w] 130 | 131 | for i in iw: 132 | i2w_bg[i] = iw[i] 133 | return 134 | 135 | 136 | def get_target_representation(line): 137 | return [w2i[word] for word in line.split()] 138 | 139 | def pad_input_sequence(X, max_len): 140 | assert (len(X) <= max_len) 141 | while len(X) != max_len: 142 | X.append([INPUT_PAD_IDX for _ in range(len(X[0]))]) 143 | return X 144 | 145 | def pad_target_sequence(y, max_len): 146 | assert (len(y) <= max_len) 147 | while len(y) != max_len: 148 | y.append(TARGET_PAD_IDX) 149 | return y 150 | 151 | def get_batched_input_data(lines, batch_size, rep_list=['swap'], probs=[1.0]): 152 | #shuffle(lines) 153 | if len(rep_list) == 0 or 'all' in rep_list: 154 | rep_list = ['swap', 'drop', 'add', 'key'] 155 | probs=np.ones(len(rep_list)) / len(rep_list) 156 | total_len = len(lines) 157 | output = [] 158 | for batch_start in range(0, len(lines) - batch_size, batch_size): 159 | 160 | input_lines = [] 161 | modified_lines = [] 162 | X = [] 163 | y = [] 164 | lens = [] 165 | max_len = max([len(line.split()) \ 166 | for line in lines[batch_start: batch_start + batch_size]]) 167 | 168 | for line in lines[batch_start: batch_start + batch_size]: 169 | X_i, modified_line_i = get_line_representation(line, rep_list, probs) 170 | assert (len(line.split()) == len(modified_line_i.split())) 171 | y_i = get_target_representation(line) 172 | # pad X_i, and y_i 173 | X_i = pad_input_sequence(X_i, max_len) 174 | y_i = pad_target_sequence(y_i, max_len) 175 | # append input lines, modified lines, X_i, y_i, lens 176 | input_lines.append(line) 177 | modified_lines.append(modified_line_i) 178 | X.append(X_i) 179 | y.append(y_i) 180 | lens.append(len(modified_line_i.split())) 181 | 182 | output.append((input_lines, modified_lines, np.array(X), np.array(y), lens)) 183 | return output 184 | 185 | def get_line_representation(line, rep_list=['swap'], probs=[1.0]): 186 | rep = [] 187 | modified_words = [] 188 | for word in 
line.split(): 189 | rep_type = np.random.choice(rep_list, 1, p=probs)[0] 190 | if 'swap' in rep_type: 191 | word_rep, new_word = get_swap_word_representation(word) 192 | elif 'drop' in rep_type: 193 | word_rep, new_word = get_drop_word_representation(word, 1.0) 194 | elif 'add' in rep_type: 195 | word_rep, new_word = get_add_word_representation(word) 196 | elif 'key' in rep_type: 197 | word_rep, new_word = get_keyboard_word_representation(word) 198 | elif 'none' in rep_type or 'normal' in rep_type: 199 | word_rep, _ = get_swap_word_representation(word) 200 | new_word = word 201 | else: 202 | #TODO: give a more ceremonious error... 203 | raise NotImplementedError 204 | rep.append(word_rep) 205 | modified_words.append(new_word) 206 | return rep, " ".join(modified_words) 207 | 208 | 209 | """ word representation from individual chars 210 | one hot (first char) + bag of chars (middle chars) + one hot (last char) 211 | """ 212 | def get_swap_word_representation(word): 213 | 214 | # dirty case 215 | if len(word) == 1 or len(word) == 2: 216 | rep = one_hot(word[0]) + zero_vector() + one_hot(word[-1]) 217 | return rep, word 218 | 219 | rep = one_hot(word[0]) + bag_of_chars(word[1:-1]) + one_hot(word[-1]) 220 | if len(word) > 3: 221 | idx = random.randint(1, len(word)-3) 222 | word = word[:idx] + word[idx + 1] + word[idx] + word[idx+2:] 223 | 224 | return rep, word 225 | 226 | 227 | 228 | 229 | """ word representation from individual chars (except that one of the internal 230 | chars might be dropped with a probability prob 231 | """ 232 | def get_drop_word_representation(word, prob=0.5): 233 | p = random.random() 234 | if len(word) >= 5 and p < prob: 235 | idx = random.randint(1, len(word)-2) 236 | word = word[:idx] + word[idx+1:] 237 | rep, _ = get_swap_word_representation(word) # don't care about the returned word 238 | elif p > prob: 239 | rep, word = get_swap_word_representation(word) 240 | else: 241 | rep, _ = get_swap_word_representation(word) # don't care about the returned word 242 | return rep, word 243 | 244 | 245 | def get_add_word_representation(word): 246 | if len(word) >= 3: 247 | idx = random.randint(1, len(word)-1) 248 | random_char = _get_random_char() 249 | word = word[:idx] + random_char + word[idx:] 250 | rep, _ = get_swap_word_representation(word) # don't care about the returned word 251 | else: 252 | rep, _ = get_swap_word_representation(word) # don't care about the returned word 253 | return rep, word 254 | 255 | def get_keyboard_word_representation(word): 256 | if len(word) >=3: 257 | idx = random.randint(1, len(word)-2) 258 | keyboard_neighbor = _get_keyboard_neighbor(word[idx]) 259 | word = word[:idx] + keyboard_neighbor + word[idx+1:] 260 | rep, _ = get_swap_word_representation(word) # don't care about the returned word 261 | else: 262 | rep, _ = get_swap_word_representation(word) # don't care about the returned word 263 | return rep, word 264 | 265 | 266 | """ word representation from bag of chars 267 | """ 268 | def get_boc_word_representation(word): 269 | return zero_vector() + bag_of_chars(word) + zero_vector() 270 | 271 | 272 | def one_hot(char): 273 | return [1.0 if ch == char else 0.0 for ch in CHAR_VOCAB] 274 | 275 | 276 | def bag_of_chars(chars): 277 | return [float(chars.count(ch)) for ch in CHAR_VOCAB] 278 | 279 | 280 | def zero_vector(): 281 | return [0.0 for _ in CHAR_VOCAB] 282 | 283 | 284 | #TODO: is that all the characters we need?? 
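# The two helpers below back the 'add' and 'key' corruption types used by
# get_line_representation above: _get_random_char draws a random lowercase letter to
# insert into a word, and _get_keyboard_neighbor lazily builds an adjacency map over a
# QWERTY layout ('*' cells are padding and never used) and returns a random key directly
# above, below, or beside the given character, e.g. 'g' can become 'f', 'h', 't', or 'b'.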
285 | def _get_random_char(): 286 | alphabets = "abcdefghijklmnopqrstuvwxyz" 287 | alphabets = [i for i in alphabets] 288 | return np.random.choice(alphabets, 1)[0] 289 | 290 | 291 | def _get_keyboard_neighbor(ch): 292 | global keyboard_mappings 293 | if keyboard_mappings is None or len(keyboard_mappings) != 26: 294 | keyboard_mappings = defaultdict(lambda: []) 295 | keyboard = ["qwertyuiop", "asdfghjkl*", "zxcvbnm***"] 296 | row = len(keyboard) 297 | col = len(keyboard[0]) 298 | 299 | dx = [-1, 1, 0, 0] 300 | dy = [0, 0, -1, 1] 301 | 302 | for i in range(row): 303 | for j in range(col): 304 | for k in range(4): 305 | x_, y_ = i + dx[k], j + dy[k] 306 | if (x_ >= 0 and x_ < row) and (y_ >= 0 and y_ < col): 307 | if keyboard[x_][y_] == '*': continue 308 | if keyboard[i][j] == '*': continue 309 | keyboard_mappings[keyboard[i][j]].append(keyboard[x_][y_]) 310 | 311 | if ch not in keyboard_mappings: return ch 312 | return np.random.choice(keyboard_mappings[ch], 1)[0] 313 | -------------------------------------------------------------------------------- /attacks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, TensorDataset 3 | 4 | import numpy as np 5 | from nltk import word_tokenize 6 | import json 7 | import os 8 | from tqdm import tqdm 9 | from random import sample 10 | 11 | from edit_dist_utils import get_all_edit_dist_one, sample_random_internal_permutations 12 | from utils_glue import InputExample, PROCESSORS, OUTPUT_MODES 13 | 14 | class AttackSurface(object): 15 | def get_perturbations(self, word): 16 | raise NotImplementedError 17 | 18 | 19 | class ED1AttackSurface(AttackSurface): 20 | def get_perturbations(self, word): 21 | return get_all_edit_dist_one(word) 22 | 23 | 24 | class Attacker(): 25 | #Going to use attack to cache mapping from clean examples to attack, then run normally in their script. 
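# Search strategy implemented by _attack_example below: up to max_num_words word
# positions are chosen at random; at each position the Attack object proposes typo
# candidates, every candidate sentence is scored via model_runner.query with
# return_logits=True, and _process_preds (defined per subclass) either returns the first
# candidate that flips the prediction or selects which candidates survive to the next
# position (one for GreedyAttacker, up to beam_width for BeamSearchAttacker).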
26 | def __init__(self, attack_name, task, model_runner, save_dir, args, max_num_words = None, perturbations_per_word = 4): 27 | ATTACK2CLASS = {'DeleteSubFirstAttack': DeleteSubFirstAttack, 'DeleteAttack': DeleteAttack, 28 | 'RandomPerturbationAttack': RandomPerturbationAttack, 'LongDeleteShortAll': LongDeleteShortAll} 29 | if attack_name not in ATTACK2CLASS: 30 | raise ValueError("Invalid attack name: {}".format(attack_name)) 31 | attack = ATTACK2CLASS[attack_name]() 32 | if attack_name == 'RandomPerturbationAttack': 33 | #TODO, should do something better in constructor, but this should work 34 | attack.perturbations_per_word = perturbations_per_word 35 | if args.attack_type.lower() == 'intprm': 36 | attack.attack_type = 'intprm' 37 | elif args.attack_type.lower() == 'intprm': 38 | raise NotImplementedError 39 | 40 | self.label_map = {label : i for i, label in enumerate(model_runner.label_list)} 41 | self.task = task 42 | self.model_runner = model_runner 43 | self.attack = attack 44 | self.args = args 45 | self.max_num_words = max_num_words 46 | self.save_dir = args.attack_save_dir 47 | self.total_count = 0 48 | self.attacked_count = 0 49 | 50 | def attack_dataset(self, dataset): #, force_new = False): 51 | adv_dataset = [] 52 | for example in tqdm(dataset): 53 | num_a_words = len(example.text_a.split()) 54 | 55 | perturbed_example = self._attack_example(example, max_num_words = self.max_num_words, verbose = False) 56 | adv_dataset.append(perturbed_example) 57 | if self.total_count % 100 == 0: 58 | print("Performance so far: successfully attacked {}/{} total = {}".format(self.attacked_count, self.total_count, self.attacked_count / self.total_count)) 59 | 60 | return adv_dataset 61 | 62 | def _example_to_words(self, example): 63 | exists_b = example.text_b is not None 64 | split_a = example.text_a.split() 65 | words = split_a.copy() 66 | if exists_b: 67 | split_b = example.text_b.split() 68 | words.extend(split_b) 69 | return words, len(split_a) 70 | 71 | def _attack_example(self, clean_example, max_num_words = None, max_attack_attempts = 1, verbose = True): 72 | self.total_count += 1 73 | label = clean_example.label 74 | exists_b = clean_example.text_b is not None 75 | words, num_in_a = self._example_to_words(clean_example) 76 | if max_num_words is None or max_num_words > len(words): 77 | max_num_words = len(words) 78 | perturb_word_idxs = np.random.choice(len(words), size = max_num_words, replace = False) 79 | to_be_attacked = [words] 80 | for perturbed_word_idx in perturb_word_idxs: 81 | perturbed_examples = [] 82 | for words in to_be_attacked: 83 | word_to_perturb = words[perturbed_word_idx] 84 | word_perturbations = self.attack.get_perturbations(word_to_perturb) 85 | for prtbd_word in word_perturbations: 86 | og_copy = words.copy() 87 | og_copy[perturbed_word_idx] = prtbd_word 88 | new_guid = '{}-{}'.format(clean_example.guid, len(perturbed_examples)) 89 | if not exists_b: 90 | perturbed_examples.append(InputExample(new_guid, ' '.join(og_copy), label = label)) 91 | else: 92 | perturbed_examples.append(InputExample(new_guid, ' '.join(og_copy[:num_in_a]), label = label, text_b = ' '.join(og_copy[num_in_a:]))) 93 | #Labels should all be the same, sanity check 94 | 95 | preds = self.model_runner.query( 96 | perturbed_examples, self.args.eval_batch_size, do_evaluate=False, 97 | return_logits=True, use_tqdm=False) 98 | worst_performing_indices, found_incorrect_pred = self._process_preds(preds, self.label_map[label]) 99 | if found_incorrect_pred: 100 | assert len(worst_performing_indices) 
== 1 101 | worst_performing_idx = worst_performing_indices[0] 102 | self.attacked_count += 1 103 | if verbose: 104 | print('') 105 | og_example_str = 'Premise: {}\nHypothesis: {}'.format(clean_example.text_a, '' if not exists_b else clean_example.text_b) 106 | print(og_example_str) 107 | attacked_str = 'Premise: {}\nHypothesis: {}'.format(perturbed_examples[worst_performing_idx].text_a, '' if not exists_b else perturbed_examples[worst_performing_idx].text_b) 108 | print(attacked_str) 109 | print("Original label: {}".format(clean_example.label)) 110 | print("Attacked prediction: {}".format(self.model_runner.label_list[np.argmax(preds, axis = 1)[worst_performing_idx]])) 111 | 112 | return perturbed_examples[worst_performing_idx] 113 | else: 114 | to_be_attacked = [] 115 | for idx in worst_performing_indices: 116 | new_words, _ = self._example_to_words(perturbed_examples[idx]) 117 | to_be_attacked.append(new_words) 118 | #Didn't find a successful attack, but still going to do worst case thing... 119 | if verbose: 120 | print('') 121 | print("Could not attack the following: ") 122 | og_example_str = 'Premise: {}\nHypothesis: {}'.format(clean_example.text_a, '' if not exists_b else clean_example.text_b) 123 | print(og_example_str) 124 | return perturbed_examples[worst_performing_indices[0]] 125 | 126 | def _process_preds(self, preds, label): 127 | #Should return a list of predictions, and whether or not a label is found... 128 | raise NotImplementedError 129 | 130 | class BeamSearchAttacker(Attacker): 131 | def __init__(self, attack_name, task, model_runner, save_dir, args, max_num_words = None): 132 | super(BeamSearchAttacker, self).__init__(attack_name, task, model_runner, save_dir, args, max_num_words = max_num_words) 133 | self.beam_width = args.beam_width 134 | 135 | def _process_preds(self, preds, label): 136 | argmax_preds = np.argmax(preds, axis = 1) 137 | if not (argmax_preds == label).all(): 138 | incorrect_idx = np.where(argmax_preds != label)[0][0] 139 | return [incorrect_idx], True 140 | if preds.shape[0] <= self.beam_width: 141 | return list(range(preds.shape[0])), False 142 | worst_performing_indices = np.argpartition(preds[:, label], self.beam_width)[:self.beam_width] 143 | return list(worst_performing_indices), False 144 | 145 | class GreedyAttacker(Attacker): 146 | def _process_preds(self, preds, label): 147 | #Assumes if a pred changes the prediction, it's the only thing returned. Otherwise, returns list... 
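# Same contract as BeamSearchAttacker._process_preds above: given logits for all
# perturbed candidates and the gold label id, return (indices_to_keep, attack_succeeded).
# The greedy variant keeps a single candidate: the first misclassified one if any
# exists, otherwise the candidate with the smallest logit on the gold label.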
148 | argmax_preds = np.argmax(preds, axis = 1) 149 | if not (argmax_preds == label).all(): 150 | incorrect_idx = np.where(argmax_preds != label)[0][0] 151 | return [incorrect_idx], True 152 | worst_performing_idx = np.argmin(preds[:,label]) 153 | return [worst_performing_idx], False 154 | 155 | class Attack(): 156 | def get_perturbations(self, word): 157 | raise NotImplementedError() 158 | 159 | def name(self): 160 | raise NotImplementedError() 161 | 162 | 163 | class LongDeleteShortAll(Attack): 164 | def __init__(self, perturbations_per_word = 4, max_insert_len = 4): 165 | self.cache = {} 166 | self.perturbations_per_word = perturbations_per_word 167 | self.max_insert_len = max_insert_len 168 | 169 | def get_perturbations(self, word): 170 | if word in self.cache: 171 | return self.cache[word] 172 | if len(word) > self.max_insert_len: 173 | perturbations = get_all_edit_dist_one(word, filetype = 100) #Just deletions 174 | else: 175 | perturbations = get_all_edit_dist_one(word) 176 | if len(perturbations) > self.perturbations_per_word: 177 | perturbations = set(sample(perturbations, self.perturbations_per_word)) 178 | self.cache[word] = perturbations 179 | return perturbations 180 | 181 | def name(self): 182 | return 'LongDeleteShortAll' 183 | 184 | class RandomPerturbationAttack(Attack): 185 | def __init__(self, perturbations_per_word = 5, attack_type = 'ed1'): 186 | self.cache = {} 187 | self.perturbations_per_word = perturbations_per_word 188 | self.attack_type = attack_type 189 | 190 | def get_perturbations(self, word): 191 | if word in self.cache: 192 | return self.cache[word] 193 | if self.attack_type == 'ed1': 194 | perturbations = get_all_edit_dist_one(word) 195 | if len(perturbations) > self.perturbations_per_word: 196 | perturbations = set(sample(perturbations, self.perturbations_per_word)) 197 | elif self.attack_type == 'intprm': 198 | perturbations = sample_random_internal_permutations(word, n_perts = self.perturbations_per_word) 199 | else: 200 | raise NotImplementedError("Attack type: {} not implemented yet".format(self.attack_type)) 201 | self.cache[word] = perturbations 202 | return perturbations 203 | 204 | def name(self): 205 | return 'RandomPerturbationAttack' 206 | 207 | 208 | class DeleteSubFirstAttack(Attack): 209 | def __init__(self): 210 | self.cache = {} 211 | 212 | def get_perturbations(self, word): 213 | if len(word) < 3: #Min case where a substitution is possible 214 | return set([word]) 215 | if word in self.cache: 216 | return self.cache[word] 217 | deletions = get_all_edit_dist_one(word, filetype = 100) #Just deletions 218 | substitution_heads = get_all_edit_dist_one(word[:3], filetype = 10) #Just substitutions in pos 2 (can sub middle of first three letters) 219 | second_char_substitutions = set([head + word[3:] for head in substitution_heads]) 220 | perturbations = deletions.union(second_char_substitutions) 221 | self.cache[word] = perturbations 222 | return perturbations 223 | 224 | def name(self): 225 | return 'DeleteSubFirstAttack' 226 | 227 | class DeleteAttack(Attack): 228 | def __init__(self): 229 | self.cache = {} 230 | 231 | def get_perturbations(self, word): 232 | if len(word) < 3: 233 | return set([word]) 234 | if word in self.cache: 235 | return self.cache[word] 236 | deletions = get_all_edit_dist_one(word, filetype = 100) 237 | self.cache[word] = deletions 238 | return deletions 239 | 240 | def name(self): 241 | return 'DeleteAttack' 242 | 243 | 244 | -------------------------------------------------------------------------------- /scRNN/train.py: 
-------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import numpy as np 3 | import pickle 4 | import random 5 | from random import shuffle 6 | import utils 7 | from utils import * #FIXME: should not do this 8 | import argparse 9 | import time 10 | 11 | # torch related imports 12 | import torch 13 | from torch import nn 14 | from torch.autograd import Variable 15 | 16 | # elmo related imports 17 | from allennlp.modules.elmo import batch_to_ids 18 | 19 | # model related imports 20 | from model import ScRNN 21 | from model import ElmoScRNN 22 | from model import ElmoRNN 23 | 24 | 25 | parser = argparse.ArgumentParser( 26 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 27 | 28 | parser.add_argument('--train-rep', dest='train_rep_list', nargs='+', default=[], 29 | help = 'the type of the representation to train from') 30 | 31 | parser.add_argument('--val-rep', dest='val_rep_list', nargs='+', default=[], 32 | help = 'the type of the representation to validate on') 33 | 34 | parser.add_argument('--train-rep-probs', dest='train_rep_probs', nargs='+', default=[], 35 | help = 'the probs of the representation to train from') 36 | 37 | parser.add_argument('--val-rep-probs', dest='val_rep_probs', nargs='+', default=[], 38 | help = 'the probs of the representation to validate on') 39 | 40 | parser.add_argument('--save', dest='save_model', action='store_true') 41 | parser.add_argument('--background', dest='background', action='store_true') 42 | parser.add_argument('--background-train', dest='background_train', action='store_true') 43 | 44 | parser.add_argument('--new-vocab', dest='new_vocab', action='store_true') 45 | parser.add_argument('--model-type', dest='model_type', type=str, default="scrnn", 46 | help="choice between scrnn/elmo/elmo-plus-scrnn") 47 | parser.add_argument('--model-type-bg', dest='model_type_bg', type=str, default="scrnn", 48 | help="choice between scrnn/elmo/elmo-plus-scrnn") 49 | 50 | parser.add_argument('--no-train', dest='need_to_train', action='store_false') 51 | 52 | parser.add_argument('--model-path', dest='model_path', type=str, default="") 53 | parser.add_argument('--model-path-bg', dest='model_path_bg', type=str, default="") 54 | 55 | parser.add_argument('--unk-output', dest='unk_output', action='store_true') 56 | 57 | parser.add_argument('--batch-size', dest='batch_size', type=int, default=32) 58 | parser.add_argument('--num-epochs', dest='num_epochs', type=int, default=100) 59 | parser.add_argument('--vocab-size', dest='vocab_size', type=int, default=9999) 60 | parser.add_argument('--vocab-size-bg', dest='vocab_size_bg', type=int, default=78470) 61 | 62 | # char vocab path for training bg model to share vocab 63 | parser.add_argument('--common-cv-path', dest='common_cv_path', type=str, 64 | default="vocab/CHAR_VOCAB_ 16580.p") 65 | 66 | # train/dev/test files 67 | parser.add_argument('--train-file', dest='train_file', type=str, 68 | default="glue_tc_preprocessed/rte_train_preprocessed.tsv") 69 | parser.add_argument('--dev-file', dest='dev_file', type=str, 70 | default="../../../data/classes/dev.txt") 71 | parser.add_argument('--test-file', dest='test_file', type=str, 72 | default="../../../data/classes/test.txt") 73 | 74 | parser.add_argument('--task-name', dest='task_name', type=str, 75 | default="rte") 76 | parser.add_argument('--preprocessed_glue_dir', default = 'glue_tc_preprocessed', 77 | help = 'directory where preprocessed glue data is stored', type = str) 78 | 
parser.add_argument('--tc_dir', type = str, default = '.', 79 | help = 'location where vocab and models are to be saved.') 80 | 81 | params = vars(parser.parse_args()) 82 | 83 | # useful variables for representation type and strength... 84 | train_rep_list = params['train_rep_list'] 85 | val_rep_list = params['val_rep_list'] 86 | train_rep_probs = [float(i) for i in params['train_rep_probs']] 87 | val_rep_probs = [float(i) for i in params['val_rep_probs']] 88 | batch_size = params['batch_size'] 89 | model_type = params['model_type'].lower() 90 | model_type_bg = params['model_type_bg'].lower() 91 | vocab_size = params['vocab_size'] 92 | vocab_size_bg = params['vocab_size_bg'] 93 | task_name = params['task_name'] 94 | NUM_EPOCHS = params['num_epochs'] 95 | tc_dir = params['tc_dir'] 96 | set_word_limit(vocab_size, task_name) 97 | WORD_LIMIT = vocab_size 98 | STOP_AFTER = 25 99 | 100 | # shall we save the model? 101 | save = params['save_model'] 102 | 103 | # are we also using a background model? 104 | use_background = params['background'] 105 | 106 | # are we training the background model? 107 | background_train = params['background_train'] 108 | 109 | # paths to important stuff.. 110 | PWD = "/home/danish/git/break-it-build-it/src/defenses/scRNN/" 111 | 112 | # path to vocabs 113 | w2i_PATH = PWD + "vocab/" + task_name + "w2i_" + str(vocab_size) + ".p" 114 | i2w_PATH = PWD + "vocab/" + task_name + "i2w_" + str(vocab_size) + ".p" 115 | CHAR_VOCAB_PATH = PWD + "vocab/" + task_name + "CHAR_VOCAB_ " + str(vocab_size) + ".p" 116 | common_cv_path = params['common_cv_path'] 117 | 118 | # paths to background vocabs 119 | w2i_PATH_BG = PWD + "vocab/" + task_name + "w2i_" + str(vocab_size_bg) + ".p" 120 | i2w_PATH_BG = PWD + "vocab/" + task_name + "i2w_" + str(vocab_size_bg) + ".p" 121 | CHAR_VOCAB_PATH_BG = PWD + "vocab/" + task_name + "CHAR_VOCAB_ " + str(vocab_size_bg) + ".p" 122 | 123 | # model paths 124 | MODEL_PATH = PWD + params['model_path'] 125 | MODEL_PATH_BG = PWD + params['model_path_bg'] 126 | 127 | # train/dev/test files 128 | train_file = params['train_file'] 129 | dev_file = params['dev_file'] 130 | test_file = params['test_file'] 131 | 132 | # sanity check... 133 | print ("--- Parameters ----") 134 | print (params) 135 | 136 | """ 137 | [Takes in predictions (y_preds) in integers, outputs a human readable 138 | output line. In case when the prediction is UNK, it uses the input word as is. 139 | Hence, input_line is also needed to know the corresponding input word.] 
140 | """ 141 | def decode_line(input_line, y_preds, use_background, y_preds_bg): 142 | SEQ_LEN = len(input_line.split()) 143 | assert (SEQ_LEN == len(y_preds)) 144 | 145 | predicted_words = [] 146 | for idx in range(SEQ_LEN): 147 | if y_preds[idx] == WORD_LIMIT: 148 | word = input_line.split()[idx] 149 | if use_background: 150 | # the main model predicted unk ...backoff 151 | if y_preds_bg[idx] != vocab_size_bg: 152 | # the backoff model predicted non-unk 153 | word = utils.i2w_bg[y_preds_bg[idx]] 154 | # print ("Input: %s \n Backoff: %s -> %s\n" %(input_line, input_line.split()[idx], word)) 155 | if params['unk_output']: 156 | word = "a" 157 | else: 158 | word = utils.i2w[y_preds[idx]] 159 | predicted_words.append(word) 160 | 161 | return " ".join(predicted_words) 162 | 163 | 164 | """ 165 | [computes the word error rate] 166 | true_lines are what the model should have predicted, whereas 167 | output_lines are what the model ended up predicted 168 | """ 169 | def compute_WER(true_lines, output_lines): 170 | assert (len(true_lines) == len(output_lines)) 171 | size = len(output_lines) 172 | 173 | error = 0.0 174 | total_words = 0.0 175 | 176 | for i in range(size): 177 | true_words = true_lines[i].split() 178 | output_words = output_lines[i].split() 179 | assert (len(true_words) == len(output_words)) 180 | total_words += len(true_words) 181 | for j in range(len(output_words)): 182 | if true_words[j] != output_words[j]: 183 | error += 1.0 184 | 185 | return (100. * error/total_words) 186 | 187 | 188 | 189 | def iterate(model, optimizer, data_lines, need_to_train, rep_list, rep_probs, 190 | desc, iter_count, print_stuff=True, use_background=False, model_bg=None): 191 | data_lines = sorted(data_lines, key = lambda x:len(x.split()), reverse=True) 192 | Xtype = torch.FloatTensor 193 | ytype = torch.LongTensor 194 | criterion = nn.CrossEntropyLoss(size_average=True, ignore_index=TARGET_PAD_IDX) 195 | is_cuda = torch.cuda.is_available() 196 | if is_cuda: 197 | Xtype = torch.cuda.FloatTensor 198 | ytype = torch.cuda.LongTensor 199 | criterion.cuda() 200 | 201 | predicted_lines = [] 202 | true_lines = [] 203 | 204 | total_loss = 0.0 205 | 206 | for input_lines, modified_lines, X, y, lens in get_batched_input_data(data_lines, batch_size, \ 207 | rep_list, rep_probs): 208 | true_lines.extend(input_lines) 209 | tx = Variable(torch.from_numpy(X)).type(Xtype) 210 | ty_true = Variable(torch.from_numpy(y)).type(ytype) 211 | 212 | tokenized_modified_lines = [line.split() for line in modified_lines] 213 | if 'elmo' in model_type or 'elmo' in model_type_bg: 214 | tx_elmo = Variable(batch_to_ids(tokenized_modified_lines)).type(ytype) 215 | 216 | # forward pass 217 | if model_type == 'elmo': 218 | ty_pred = model(tx_elmo) 219 | #TODO: add the cases where the background model might 220 | # be other than an elmo-only model 221 | if use_background and model_type_bg =='elmo': 222 | ty_pred_bg = model_bg(tx_elmo) 223 | elif model_type == 'scrnn': 224 | ty_pred = model(tx, lens) 225 | #TODO: add the cases where the background model might 226 | # be other than an scrnn-only model 227 | if use_background and model_type_bg == 'scrnn': 228 | ty_pred_bg = model_bg(tx, lens) 229 | elif 'elmo' in model_type and 'scrnn' in model_type: 230 | ty_pred = model(tx, tx_elmo, lens) 231 | 232 | #TODO: add the cases where the background model might 233 | # be an elmo-only model 234 | if use_background and 'elmo' in model_type_bg and \ 235 | 'scrnn' in model_type_bg: 236 | ty_pred_bg = model_bg(tx, tx_elmo, lens) 237 | elif 
use_background and model_type_bg == 'scrnn': 238 | ty_pred_bg = model_bg(tx, lens) 239 | 240 | y_pred = ty_pred.detach().cpu().numpy() 241 | # ypred BATCH_SIZE x NUM_CLASSES x SEQ_LEN 242 | if use_background: 243 | y_pred_bg = ty_pred_bg.detach().cpu().numpy() 244 | 245 | 246 | for idx in range(batch_size): 247 | y_pred_i = [np.argmax(y_pred[idx][:, i]) for i in range(lens[idx])] 248 | 249 | y_pred_bg_i = None 250 | if use_background: 251 | y_pred_bg_i = [np.argmax(y_pred_bg[idx][:, i]) for i in range(lens[idx])] 252 | 253 | predicted_lines.append(decode_line(modified_lines[idx], y_pred_i, use_background, y_pred_bg_i)) 254 | 255 | 256 | # compute loss 257 | loss = criterion(ty_pred, ty_true) 258 | total_loss += loss.item() 259 | 260 | if need_to_train: 261 | # backprop the loss 262 | optimizer.zero_grad() 263 | loss.backward() 264 | optimizer.step() 265 | 266 | WER = compute_WER(true_lines, predicted_lines) 267 | 268 | if print_stuff: 269 | print ("Average %s loss after %d iteration = %0.4f" %(desc, iter_count, 270 | total_loss/len(true_lines))) 271 | print ("Total %s WER after %d iteration = %0.4f" %(desc, iter_count, WER)) 272 | 273 | 274 | return WER 275 | 276 | 277 | def main(glue = False, max_train_lines = 20000, num_val_lines = 500): 278 | train_file = params['train_file'] 279 | if glue: 280 | train_file = os.path.join(params['preprocessed_glue_dir'], '{}_train_preprocessed.tsv'.format(params['task_name'].lower())) 281 | 282 | tc_dir = params['tc_dir'] 283 | train_lines = get_lines(train_file, glue = glue) 284 | random.shuffle(train_lines) 285 | val_cutoff = len(train_lines) - num_val_lines 286 | val_lines = train_lines[val_cutoff:] 287 | train_lines = train_lines[:max_train_lines] 288 | 289 | params['new_vocab'] = True 290 | if params['new_vocab']: 291 | print ("creating new vocabulary") 292 | create_vocab(train_file, params['tc_dir'], background_train, common_cv_path, glue = glue) 293 | else: 294 | print ("loading existing vocabulary") 295 | load_vocab_dicts(w2i_PATH, i2w_PATH, CHAR_VOCAB_PATH) 296 | if use_background: 297 | print ("loading existing background vocabulary") 298 | load_vocab_dicts(w2i_PATH_BG, i2w_PATH_BG, CHAR_VOCAB_PATH_BG, use_background) 299 | 300 | 301 | print ("len of w2i ", len(utils.w2i)) 302 | print ("len of i2w ", len(utils.i2w)) 303 | print ("len of char vocab", len(utils.CHAR_VOCAB)) 304 | 305 | params['need_to_train'] = True 306 | assert model_type == 'scrnn' 307 | model = ScRNN(len(utils.CHAR_VOCAB), 50, WORD_LIMIT + 1) # +1 for UNK 308 | 309 | if not params['need_to_train']: 310 | model = ScRNN(len(utils.CHAR_VOCAB), 50, WORD_LIMIT) 311 | model = torch.load(MODEL_PATH) 312 | model_bg = None 313 | assert not use_background 314 | if use_background: 315 | print ("Loading background model") 316 | model_bg = torch.load(MODEL_PATH_BG) 317 | 318 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters())) 319 | criterion = nn.CrossEntropyLoss(size_average=True, ignore_index=TARGET_PAD_IDX) 320 | is_cuda = torch.cuda.is_available() 321 | 322 | if is_cuda: 323 | model.cuda() 324 | if use_background: 325 | model_bg.cuda() 326 | 327 | 328 | if params['need_to_train']: 329 | # begin training ... 
330 | print (" *** training the model *** ") 331 | best_val_WER = 100.0 332 | last_dumped_idx = 99999 333 | for ITER in range(NUM_EPOCHS): 334 | st_time = time.time() 335 | _ = iterate(model, optimizer, train_lines, True, train_rep_list, 336 | train_rep_probs, 'train', ITER+1) 337 | 338 | curr_val_WER = iterate(model, None, val_lines, False, val_rep_list, 339 | val_rep_probs, 'val', ITER+1) 340 | save = True #TODO 341 | if save: 342 | if not os.path.isdir(os.path.join(tc_dir, 'model_dumps')): 343 | os.makedirs(os.path.join(tc_dir, 'model_dumps')) 344 | model_save_path = os.path.join(tc_dir, 'model_dumps', '{}_TASK_NAME={}'.format(model_type, task_name)) 345 | # check if the val WER improved? 346 | if curr_val_WER < best_val_WER: 347 | last_dumped_idx = ITER+1 348 | best_val_WER = curr_val_WER 349 | # informative names for model dump files 350 | train_rep_names = "_".join(train_rep_list) 351 | train_probs_names = ":".join([str(i) for i in train_rep_probs]) 352 | 353 | print ("Dumping after ", ITER + 1) 354 | model_name = model_type 355 | torch.save(model.state_dict(), model_save_path) 356 | # report the time taken per iteration for train + val + test 357 | # (+ often save) 358 | en_time = time.time() 359 | print ("Time for the iteration %0.1f seconds" %(en_time - st_time)) 360 | 361 | # check if there hasn't been enough progress since last few iters 362 | if ITER > STOP_AFTER + last_dumped_idx: 363 | # i.e it is not improving since 'STOP_AFTER' number of iterations 364 | print ("Aborting since there hasn't been much progress") 365 | break 366 | 367 | else: 368 | # just run the model on validation and test... 369 | #print (" *** running the model on val and test set *** ") 370 | 371 | st_time = time.time() 372 | 373 | val_WER = iterate(model, None, val_lines, False, val_rep_list, 374 | val_rep_probs, 'val', 0, use_background=use_background, model_bg=model_bg) 375 | 376 | # report the time taken per iteration for val + test 377 | en_time = time.time() 378 | print ("Time for the testing process = %0.1f seconds" %(en_time - st_time)) 379 | val_rep_names = " ".join(val_rep_list) 380 | model_name = MODEL_PATH.split("/")[-1] 381 | print (val_rep_names + "\t" + model_name + "\t" + str(val_WER) + "\t"\ 382 | + str(test_WER)) 383 | 384 | 385 | return 386 | 387 | 388 | main(glue = True) 389 | -------------------------------------------------------------------------------- /utils_glue.py: -------------------------------------------------------------------------------- 1 | #Code adapted from https://github.com/huggingface/transformers 2 | """ BERT classification fine-tuning: utilities to work with GLUE tasks """ 3 | 4 | from __future__ import absolute_import, division, print_function 5 | 6 | import csv 7 | import logging 8 | import os 9 | import sys 10 | import pickle 11 | import string 12 | import numpy as np 13 | from io import open 14 | import itertools 15 | from tqdm import tqdm 16 | import random 17 | from collections import defaultdict 18 | import json 19 | 20 | from scipy.stats import pearsonr, spearmanr 21 | from sklearn.metrics import matthews_corrcoef, f1_score 22 | 23 | from edit_dist_utils import get_all_edit_dist_one 24 | from scRNN.corrector import ScRNNChecker 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class InputExample(object): 30 | """A single training/test example for simple sequence classification.""" 31 | 32 | def __init__(self, guid, text_a, text_b=None, label=None): 33 | """Constructs a InputExample. 34 | 35 | Args: 36 | guid: Unique id for the example. 
37 | text_a: string. The untokenized text of the first sequence. For single 38 | sequence tasks, only this sequence must be specified. 39 | text_b: (Optional) string. The untokenized text of the second sequence. 40 | Only must be specified for sequence pair tasks. 41 | label: (Optional) string. The label of the example. This should be 42 | specified for train and dev examples, but not for test examples. 43 | """ 44 | self.guid = guid 45 | self.text_a = text_a 46 | self.text_b = text_b 47 | self.label = label 48 | 49 | 50 | class DataProcessor(object): 51 | """Base class for data converters for sequence classification data sets.""" 52 | 53 | def get_train_examples(self, data_dir): 54 | """Gets a collection of `InputExample`s for the train set.""" 55 | raise NotImplementedError() 56 | 57 | def get_dev_examples(self, data_dir): 58 | """Gets a collection of `InputExample`s for the dev set.""" 59 | raise NotImplementedError() 60 | 61 | def get_labels(self): 62 | """Gets the list of labels for this data set.""" 63 | raise NotImplementedError() 64 | 65 | @classmethod 66 | def _read_tsv(cls, input_file, quotechar=None): 67 | """Reads a tab separated value file.""" 68 | with open(input_file, "r", encoding="utf-8-sig") as f: 69 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 70 | lines = [] 71 | for line in reader: 72 | if sys.version_info[0] == 2: 73 | line = list(unicode(cell, 'utf-8') for cell in line) 74 | lines.append(line) 75 | return lines 76 | 77 | 78 | class MrpcProcessor(DataProcessor): 79 | """Processor for the MRPC data set (GLUE version).""" 80 | 81 | def get_train_examples(self, data_dir): 82 | """See base class.""" 83 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) 84 | return self._create_examples( 85 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 86 | 87 | def get_dev_examples(self, data_dir): 88 | """See base class.""" 89 | return self._create_examples( 90 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 91 | 92 | def get_labels(self): 93 | """See base class.""" 94 | return ["0", "1"] 95 | 96 | def _create_examples(self, lines, set_type): 97 | """Creates examples for the training and dev sets.""" 98 | examples = [] 99 | for (i, line) in enumerate(lines): 100 | if i == 0: 101 | continue 102 | guid = "%s-%s" % (set_type, i) 103 | text_a = line[3] 104 | text_b = line[4] 105 | label = line[0] 106 | examples.append( 107 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 108 | return examples 109 | 110 | 111 | class MnliProcessor(DataProcessor): 112 | """Processor for the MultiNLI data set (GLUE version).""" 113 | 114 | def get_train_examples(self, data_dir): 115 | """See base class.""" 116 | return self._create_examples( 117 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 118 | 119 | def get_dev_examples(self, data_dir): 120 | """See base class.""" 121 | return self._create_examples( 122 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), 123 | "dev_matched") 124 | 125 | def get_labels(self): 126 | """See base class.""" 127 | return ["contradiction", "entailment", "neutral"] 128 | 129 | def _create_examples(self, lines, set_type): 130 | """Creates examples for the training and dev sets.""" 131 | examples = [] 132 | for (i, line) in enumerate(lines): 133 | if i == 0: 134 | continue 135 | guid = "%s-%s" % (set_type, line[0]) 136 | text_a = line[8] 137 | text_b = line[9] 138 | label = line[-1] 139 | examples.append( 140 | InputExample(guid=guid, text_a=text_a, text_b=text_b, 
label=label)) 141 | return examples 142 | 143 | 144 | class MnliMismatchedProcessor(MnliProcessor): 145 | """Processor for the MultiNLI Mismatched data set (GLUE version).""" 146 | 147 | def get_dev_examples(self, data_dir): 148 | """See base class.""" 149 | return self._create_examples( 150 | self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), 151 | "dev_matched") 152 | 153 | 154 | class ColaProcessor(DataProcessor): 155 | """Processor for the CoLA data set (GLUE version).""" 156 | 157 | def get_train_examples(self, data_dir): 158 | """See base class.""" 159 | return self._create_examples( 160 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 161 | 162 | def get_dev_examples(self, data_dir): 163 | """See base class.""" 164 | return self._create_examples( 165 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 166 | 167 | def get_labels(self): 168 | """See base class.""" 169 | return ["0", "1"] 170 | 171 | def _create_examples(self, lines, set_type): 172 | """Creates examples for the training and dev sets.""" 173 | examples = [] 174 | for (i, line) in enumerate(lines): 175 | guid = "%s-%s" % (set_type, i) 176 | text_a = line[3] 177 | label = line[1] 178 | examples.append( 179 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 180 | return examples 181 | 182 | 183 | class Sst2Processor(DataProcessor): 184 | """Processor for the SST-2 data set (GLUE version).""" 185 | 186 | def get_train_examples(self, data_dir): 187 | """See base class.""" 188 | return self._create_examples( 189 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 190 | 191 | def get_dev_examples(self, data_dir): 192 | """See base class.""" 193 | return self._create_examples( 194 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 195 | 196 | def get_labels(self): 197 | """See base class.""" 198 | return ["0", "1"] 199 | 200 | def _create_examples(self, lines, set_type): 201 | """Creates examples for the training and dev sets.""" 202 | examples = [] 203 | for (i, line) in enumerate(lines): 204 | if i == 0: 205 | continue 206 | guid = "%s-%s" % (set_type, i) 207 | text_a = line[0] 208 | label = line[1] 209 | examples.append( 210 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 211 | return examples 212 | 213 | 214 | class StsbProcessor(DataProcessor): 215 | """Processor for the STS-B data set (GLUE version).""" 216 | 217 | def get_train_examples(self, data_dir): 218 | """See base class.""" 219 | return self._create_examples( 220 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 221 | 222 | def get_dev_examples(self, data_dir): 223 | """See base class.""" 224 | return self._create_examples( 225 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 226 | 227 | def get_labels(self): 228 | """See base class.""" 229 | return [None] 230 | 231 | def _create_examples(self, lines, set_type): 232 | """Creates examples for the training and dev sets.""" 233 | examples = [] 234 | for (i, line) in enumerate(lines): 235 | if i == 0: 236 | continue 237 | guid = "%s-%s" % (set_type, line[0]) 238 | text_a = line[7] 239 | text_b = line[8] 240 | label = line[-1] 241 | examples.append( 242 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 243 | return examples 244 | 245 | 246 | class QqpProcessor(DataProcessor): 247 | """Processor for the QQP data set (GLUE version).""" 248 | 249 | def get_train_examples(self, data_dir): 250 | """See base class.""" 251 | return self._create_examples( 252 | 
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 253 | 254 | def get_dev_examples(self, data_dir): 255 | """See base class.""" 256 | return self._create_examples( 257 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 258 | 259 | def get_labels(self): 260 | """See base class.""" 261 | return ["0", "1"] 262 | 263 | def _create_examples(self, lines, set_type): 264 | """Creates examples for the training and dev sets.""" 265 | examples = [] 266 | for (i, line) in enumerate(lines): 267 | if i == 0: 268 | continue 269 | guid = "%s-%s" % (set_type, line[0]) 270 | try: 271 | text_a = line[3] 272 | text_b = line[4] 273 | label = line[5] 274 | except IndexError: 275 | continue 276 | examples.append( 277 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 278 | return examples 279 | 280 | 281 | class QnliProcessor(DataProcessor): 282 | """Processor for the QNLI data set (GLUE version).""" 283 | 284 | def get_train_examples(self, data_dir): 285 | """See base class.""" 286 | return self._create_examples( 287 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 288 | 289 | def get_dev_examples(self, data_dir): 290 | """See base class.""" 291 | return self._create_examples( 292 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), 293 | "dev_matched") 294 | 295 | def get_labels(self): 296 | """See base class.""" 297 | return ["entailment", "not_entailment"] 298 | 299 | def _create_examples(self, lines, set_type): 300 | """Creates examples for the training and dev sets.""" 301 | examples = [] 302 | for (i, line) in enumerate(lines): 303 | if i == 0: 304 | continue 305 | guid = "%s-%s" % (set_type, line[0]) 306 | text_a = line[1] 307 | text_b = line[2] 308 | label = line[-1] 309 | examples.append( 310 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 311 | return examples 312 | 313 | 314 | class RteProcessor(DataProcessor): 315 | """Processor for the RTE data set (GLUE version).""" 316 | 317 | def get_train_examples(self, data_dir): 318 | """See base class.""" 319 | return self._create_examples( 320 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 321 | 322 | def get_dev_examples(self, data_dir): 323 | """See base class.""" 324 | return self._create_examples( 325 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 326 | 327 | def get_labels(self): 328 | """See base class.""" 329 | return ["entailment", "not_entailment"] 330 | 331 | def _create_examples(self, lines, set_type): 332 | """Creates examples for the training and dev sets.""" 333 | examples = [] 334 | for (i, line) in enumerate(lines): 335 | if i == 0: 336 | continue 337 | guid = "%s-%s" % (set_type, line[0]) 338 | text_a = line[1] 339 | text_b = line[2] 340 | label = line[-1] 341 | examples.append( 342 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 343 | return examples 344 | 345 | 346 | class WnliProcessor(DataProcessor): 347 | """Processor for the WNLI data set (GLUE version).""" 348 | 349 | def get_train_examples(self, data_dir): 350 | """See base class.""" 351 | return self._create_examples( 352 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 353 | 354 | def get_dev_examples(self, data_dir): 355 | """See base class.""" 356 | return self._create_examples( 357 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 358 | 359 | def get_labels(self): 360 | """See base class.""" 361 | return ["0", "1"] 362 | 363 | def _create_examples(self, lines, set_type): 364 | """Creates examples for the training and dev 
sets.""" 365 | examples = [] 366 | for (i, line) in enumerate(lines): 367 | if i == 0: 368 | continue 369 | guid = "%s-%s" % (set_type, line[0]) 370 | text_a = line[1] 371 | text_b = line[2] 372 | label = line[-1] 373 | examples.append( 374 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 375 | return examples 376 | 377 | def simple_accuracy(preds, labels): 378 | return (preds == labels).mean() 379 | 380 | 381 | def acc_and_f1(preds, labels): 382 | acc = simple_accuracy(preds, labels) 383 | f1 = f1_score(y_true=labels, y_pred=preds) 384 | return { 385 | "acc": acc, 386 | "f1": f1, 387 | "acc_and_f1": (acc + f1) / 2, 388 | } 389 | 390 | 391 | def pearson_and_spearman(preds, labels): 392 | pearson_corr = pearsonr(preds, labels)[0] 393 | spearman_corr = spearmanr(preds, labels)[0] 394 | return { 395 | "pearson": pearson_corr, 396 | "spearmanr": spearman_corr, 397 | "corr": (pearson_corr + spearman_corr) / 2, 398 | } 399 | 400 | 401 | def compute_metrics(task_name, preds, labels): 402 | assert len(preds) == len(labels) 403 | if task_name == "cola": 404 | return {"mcc": matthews_corrcoef(labels, preds)} 405 | elif task_name == "sst-2": 406 | return {"acc": simple_accuracy(preds, labels)} 407 | elif task_name == "mrpc": 408 | return acc_and_f1(preds, labels) 409 | elif task_name == "sts-b": 410 | return pearson_and_spearman(preds, labels) 411 | elif task_name == "qqp": 412 | return acc_and_f1(preds, labels) 413 | elif task_name == "mnli": 414 | return {"acc": simple_accuracy(preds, labels)} 415 | elif task_name == "mnli-mm": 416 | return {"acc": simple_accuracy(preds, labels)} 417 | elif task_name == "qnli": 418 | return {"acc": simple_accuracy(preds, labels)} 419 | elif task_name == "rte": 420 | return {"acc": simple_accuracy(preds, labels)} 421 | elif task_name == "wnli": 422 | return {"acc": simple_accuracy(preds, labels)} 423 | else: 424 | raise KeyError(task_name) 425 | 426 | PROCESSORS = { 427 | "cola": ColaProcessor, 428 | "mnli": MnliProcessor, 429 | "mnli-mm": MnliMismatchedProcessor, 430 | "mrpc": MrpcProcessor, 431 | "sst-2": Sst2Processor, 432 | "sts-b": StsbProcessor, 433 | "qqp": QqpProcessor, 434 | "qnli": QnliProcessor, 435 | "rte": RteProcessor, 436 | "wnli": WnliProcessor, 437 | } 438 | 439 | OUTPUT_MODES = { 440 | "cola": "classification", 441 | "mnli": "classification", 442 | "mnli-mm": "classification", 443 | "mrpc": "classification", 444 | "sst-2": "classification", 445 | "sts-b": "regression", 446 | "qqp": "classification", 447 | "qnli": "classification", 448 | "rte": "classification", 449 | "wnli": "classification", 450 | } 451 | 452 | GLUE_TASKS_NUM_LABELS = { 453 | "cola": 2, 454 | "mnli": 3, 455 | "mrpc": 2, 456 | "sst-2": 2, 457 | "sts-b": 1, 458 | "qqp": 2, 459 | "qnli": 2, 460 | "rte": 2, 461 | "wnli": 2, 462 | } 463 | 464 | # Task names with standard case 465 | GLUE_TASK_NAMES = ['CoLA', 'MNLI', 'MRPC', 'QNLI', 'QQP', 'RTE', 'SNLI', 'SST-2', 'STS-B', 'WNLI'] 466 | -------------------------------------------------------------------------------- /agglom_clusters.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 4 | from collections import defaultdict 5 | from itertools import combinations 6 | import pickle 7 | import argparse 8 | from edit_dist_utils import get_all_edit_dist_one, ed1_neighbors_mat 9 | from datetime import datetime 10 | import random 11 | import os 12 | 13 | def update_mu(C, f): 14 | weighted_cluster_assignments = C * f.reshape(-1, 1) 15 | 
cluster_weights = weighted_cluster_assignments.sum(axis = 0) 16 | zero_weights = np.where(cluster_weights == 0)[0] 17 | cluster_weights[zero_weights] = 1 18 | updated_mu = weighted_cluster_assignments / cluster_weights.reshape(1,-1) #Divide component-wise 19 | #Transpose to K x N so that mu[i] gives the centroid of cluster i 20 | return updated_mu.T 21 | 22 | def compute_centroid_distances(X, mu): 23 | #Returns an array of shape n_vocab_words x n_clusters. centroid_distances[i][j] is distance of word i to centroid j. 24 | n_vocab_words, vec_size = X.shape 25 | n_clusters, mu_vec_size = mu.shape 26 | assert mu_vec_size == vec_size 27 | centroid_distances = np.zeros(shape = (n_vocab_words, n_clusters)) 28 | for i in range(n_clusters): 29 | centroid = mu[i] 30 | distances = np.linalg.norm(X - centroid, axis = 1) 31 | assert len(distances.shape) == 1 32 | assert distances.shape[0] == n_vocab_words 33 | centroid_distances[:, i] = distances 34 | return centroid_distances 35 | 36 | 37 | def get_typo2cluster(words, word2freq, word2cluster): 38 | typo2cluster = {} 39 | typo2word = get_typo2word(words, word2freq) 40 | for typo in tqdm(typo2word): 41 | typo2cluster[typo] = word2cluster[typo2word[typo]] 42 | return typo2cluster 43 | 44 | def get_typo2word(words, word2freq): 45 | typo2words = defaultdict(list) 46 | for word in words: 47 | word = word.lower() 48 | typos = get_all_edit_dist_one(word) 49 | for typo in typos: 50 | typo2words[typo].append((word, word2freq[word])) 51 | typo2word = {} 52 | for typo in typo2words: 53 | #Hard coding that vocab words always have to map to themselves. 54 | if typo in words: 55 | typo2word[typo] = typo 56 | continue 57 | typo2words[typo].sort(key = lambda x: x[1], reverse = True) 58 | most_frequent_word = typo2words[typo][0][0] 59 | typo2word[typo] = most_frequent_word 60 | return typo2word 61 | 62 | def compute_W(words, word2idx, word2freq): 63 | #TODO, want some way to verify clusterer is only ED2 clusters.
64 | typo2word = get_typo2word(words, word2freq) 65 | word2recoverwords = defaultdict(set) 66 | for word in words: 67 | word = word.lower() 68 | typos = get_all_edit_dist_one(word) 69 | for typo in typos: 70 | recover_word = typo2word[typo] 71 | word2recoverwords[word].add(recover_word) 72 | 73 | n_words = len(words) 74 | W = np.zeros(shape = (n_words, n_words)) 75 | for word in word2recoverwords: 76 | word_idx = word2idx[word] 77 | for recovered_word in word2recoverwords[word]: 78 | recovered_idx = word2idx[recovered_word] 79 | W[word_idx][recovered_idx] = 1 80 | return W 81 | 82 | def apply_objective(X, mu, f, C, W, gamma = 0.5): 83 | n_vocab_words, n_clusters = C.shape 84 | centroid_distances = compute_centroid_distances(X, mu) 85 | A = f.T @ (C * centroid_distances).sum(axis = 1) 86 | 87 | num_words, num_words2 = W.shape 88 | assert num_words == num_words2 89 | assert num_words == n_vocab_words 90 | word_num_clusters = np.zeros(num_words) 91 | for i in range(num_words): 92 | dom_indices = W[i] 93 | relevant_clusters = C * dom_indices.reshape(-1, 1) 94 | num_relevant_clusters = relevant_clusters.max(axis = 0).sum() 95 | word_num_clusters[i] = num_relevant_clusters 96 | B = f.T @ word_num_clusters 97 | return gamma * A + (1 - gamma) * B, word_num_clusters 98 | 99 | def clusters_from_verts(C_verts, C_fixed): 100 | n_verts, n_clusters = C_verts.shape 101 | n_words, n_fixed_clusters = C_fixed.shape 102 | assert n_fixed_clusters == n_clusters 103 | C = np.zeros(shape = C_fixed.shape) 104 | cluster_assignments = C_verts.argmax(axis = 1) 105 | for vert in range(n_verts): 106 | C[C_fixed[:, vert], cluster_assignments[vert]] = 1 107 | return C 108 | 109 | def get_optimal_merge_efficient(X, mu, f, C, W, previous_objective, gamma = 0.5, combs = None): 110 | n_vocab_words, n_clusters = C.shape 111 | nonzero_clusters = np.where(np.count_nonzero(C, axis = 0) != 0)[0] 112 | cluster_freqs = np.array([np.dot(f, C[:, cluster_id]) for cluster_id in range(n_clusters)]) 113 | 114 | current_best_clusters = C.copy() 115 | current_best_change = 0 116 | 117 | found_good_merge = False 118 | current_centroid_distances = compute_centroid_distances(X, mu) 119 | cluster_dominates = [(W * C[:, cluster_id]).max(axis = 1) for cluster_id in range(n_clusters)] 120 | combination_tuple = None 121 | 122 | for combination in combinations(nonzero_clusters, 2): 123 | 124 | cluster1, cluster2 = combination 125 | if combs is not None and (cluster1, cluster2) not in combs: 126 | continue 127 | 128 | cluster1_elems = C[:, cluster1] 129 | cluster2_elems = C[:, cluster2] 130 | 131 | combined_cluster_elems = np.logical_or(cluster1_elems, cluster2_elems) 132 | 133 | cluster1_centroid = mu[cluster1] 134 | cluster2_centroid = mu[cluster2] 135 | 136 | cluster1_freq = cluster_freqs[cluster1] 137 | cluster2_freq = cluster_freqs[cluster2] 138 | 139 | new_mu = (cluster1_centroid * cluster1_freq + cluster2_centroid * cluster2_freq) / (cluster1_freq + cluster2_freq) 140 | 141 | new_centroid_dist = compute_centroid_distances(X, new_mu.reshape(1, -1)) 142 | 143 | new_weighted_centroid_dist = np.dot(f, new_centroid_dist[:, 0] * combined_cluster_elems) 144 | old_weighted_c1_dist = np.dot(f, current_centroid_distances[:, cluster1] * cluster1_elems) 145 | old_weighted_c2_dist = np.dot(f, current_centroid_distances[:, cluster2] * cluster2_elems) 146 | 147 | 148 | A_change = new_weighted_centroid_dist - old_weighted_c1_dist - old_weighted_c2_dist 149 | assert A_change >= 0 150 | 151 | dom_by_both = np.logical_and(cluster_dominates[cluster1], 
cluster_dominates[cluster2]) 152 | B_change = -np.dot(f, dom_by_both) 153 | assert B_change <= 0 154 | # Gain back ambiguity from things that were previously dominated by original clusters 155 | objective_change = gamma * A_change + (1 - gamma) * B_change 156 | 157 | if objective_change < current_best_change: 158 | current_best_change = objective_change 159 | og_C = C.copy() 160 | combination_tuple = (cluster1, cluster2, cluster1) 161 | C[:,cluster1] = combined_cluster_elems 162 | C[:,cluster2] = np.zeros(n_vocab_words) 163 | current_best_clusters = C.copy() 164 | C = og_C 165 | found_good_merge = True 166 | 167 | return current_best_clusters, previous_objective + current_best_change, found_good_merge, combination_tuple 168 | 169 | 170 | def get_allowable_combinations(edge_mat): 171 | allowable_combinations = set() 172 | num_vertices, num_vertices2 = edge_mat.shape 173 | assert num_vertices == num_vertices2 174 | for vtx in range(num_vertices): 175 | neighbors = np.where(edge_mat[vtx] != 0)[0] 176 | for neighbor in neighbors: 177 | if neighbor != vtx: 178 | allowable_combinations.add((vtx, neighbor)) 179 | return allowable_combinations 180 | 181 | def update_allowable_combinations(combination_tuple, prev_allowable_combinations): 182 | allowable_combinations = set() 183 | c1, c2, comb = combination_tuple 184 | assert comb == c1 or comb == c2 185 | for combination in prev_allowable_combinations: 186 | cluster1, cluster2 = combination 187 | if cluster1 in combination_tuple: 188 | if cluster2 in combination_tuple: 189 | #Two have been combined, don't need to read them 190 | continue 191 | else: 192 | allowable_combinations.add((comb, cluster2)) 193 | elif cluster2 in combination_tuple: 194 | #Implies cluster one is not... 195 | allowable_combinations.add((cluster1, comb)) 196 | else: 197 | allowable_combinations.add((cluster1, cluster2)) 198 | return allowable_combinations 199 | 200 | 201 | def merge_then_ilp(words, word2freq, gamma = 0.5, edge_mat = None, word2idx = None): 202 | if word2idx is None: 203 | word2idx = {word: i for i, word in enumerate(words)} 204 | n_words = len(words) 205 | f = np.zeros(n_words) 206 | for word in words: 207 | f[word2idx[word]] = word2freq[word] 208 | f = f / f.sum() 209 | 210 | X = np.identity(n_words) 211 | C = np.identity(n_words) 212 | W = compute_W(words, word2idx, word2freq) 213 | 214 | current_num_clusters = len(words) 215 | found_good_merge = True 216 | mu = update_mu(C, f) 217 | 218 | best_objective, word_num_clusters = apply_objective(X, mu, f, C, W, gamma = gamma) 219 | allowable_combinations = get_allowable_combinations(edge_mat) 220 | #Update while combining clusters still lowers the objective... 
221 | while found_good_merge: 222 | mu = update_mu(C, f) 223 | og_C = C.copy() 224 | C, current_min_objective, found_good_merge, combination_tuple = get_optimal_merge_efficient(X, mu, f, C, W, best_objective, gamma = gamma, combs = allowable_combinations) 225 | if found_good_merge: 226 | allowable_combinations = update_allowable_combinations(combination_tuple, allowable_combinations) 227 | current_num_clusters -= 1 228 | best_objective = current_min_objective 229 | return C, best_objective, word2idx 230 | 231 | def process_ilp_output(new_cluster_assignment, word2idx, word2freq): 232 | 233 | new_clusters = defaultdict(set) 234 | word2newcluster = {} 235 | cluster2newrepresentative = {} 236 | idx2word = dict([(word2idx[word], word) for word in word2idx]) 237 | nonzero_clusters = np.where(new_cluster_assignment.max(axis = 0) != 0)[0] 238 | relevant_cluster_assignments = new_cluster_assignment[:, nonzero_clusters] 239 | new_cluster_assignments = relevant_cluster_assignments.argmax(axis = 1) 240 | assert len(new_cluster_assignments.shape) == 1 241 | for i in range(new_cluster_assignments.shape[0]): 242 | word = idx2word[i] 243 | cluster = new_cluster_assignments[i] 244 | new_clusters[cluster].add(word) 245 | word2newcluster[word] = cluster 246 | for cluster in new_clusters: 247 | word_freq_pairs = [(word, word2freq[word]) for word in new_clusters[cluster]] 248 | word_freq_pairs.sort(key = lambda x: x[1], reverse = True) 249 | cluster2newrepresentative[cluster] = word_freq_pairs[0][0] #Take the most frequent element 250 | return new_clusters, word2newcluster, cluster2newrepresentative 251 | 252 | 253 | def new_cluster_assignments(clusterer_path, gamma = 0.3, 254 | toy = False, save = True, job_num = 0, total_jobs = 1): 255 | print("Loading clusterer dict") 256 | if toy: 257 | clusterer_dict = {'cluster': {0: ['stop', 'step'], 1: ['plain', 'pin', 'pun'], 2: ['ham']}, 258 | 'word2cluster': {'stop': 0, 'step': 0, 'plain': 1, 'pin': 1, 'pun': 1, 'ham': 2}, 259 | 'word2freq': {'stop': 100, 'step': 50, 'plain': 75, 'pin': 15, 'pun': 10, 'ham': 5}, 260 | 'cluster2representative': {0: 'stop', 1: 'plain', 2: 'ham'}} 261 | else: 262 | with open(clusterer_path, 'rb') as f: 263 | clusterer_dict = pickle.load(f) 264 | word2cluster = clusterer_dict['word2cluster'] 265 | clusters = clusterer_dict['cluster'] 266 | word2freq = clusterer_dict['word2freq'] 267 | cluster2representative = clusterer_dict['cluster2representative'] 268 | words = list(word2cluster) 269 | 270 | cluster_id_iter = clusters 271 | if total_jobs > 1: 272 | num_cluster_elems = [(cluster_id, len(clusters[cluster_id])) for cluster_id in clusters] 273 | num_cluster_elems.sort(key = lambda x: x[1], reverse = True) 274 | sorted_cluster_ids = np.array([cluster_id for (cluster_id, n_elems) in num_cluster_elems]) 275 | if total_jobs == 2: 276 | if job_num == 0: 277 | job_cluster_ids = sorted_cluster_ids[:2] 278 | elif job_num == 1: 279 | job_cluster_ids = sorted_cluster_ids[2:] 280 | else: 281 | raise ValueError("Invalid job id for total jobs 2") 282 | else: 283 | job_cluster_ids = sorted_cluster_ids[job_num::total_jobs] 284 | cluster_id_iter = list(job_cluster_ids) 285 | words = [] 286 | for cluster_id in cluster_id_iter: 287 | words.extend(clusters[cluster_id]) 288 | 289 | split_clusters = {} 290 | word2split_cluster = {} 291 | split_cluster2representative = {} 292 | num_clusters_added = 0 293 | 294 | #Will use different word2freqs for each cluster to speed up computation 295 | cluster_word2freqs = defaultdict(dict) 296 | 297 | for word in words: 
298 | cluster_id = word2cluster[word] 299 | cluster_word2freqs[cluster_id][word] = word2freq[word] 300 | 301 | print("Starting the preprocessing") 302 | for cluster_id in tqdm(cluster_id_iter, desc = 'Ed2 Clusters'): 303 | cluster_words = clusters[cluster_id] 304 | cluster_word2freq = cluster_word2freqs[cluster_id] 305 | print("Starting preprocessing for cluster {}, which contains {} words: ".format(cluster2representative[cluster_id], len(cluster_words))) 306 | start = datetime.now() 307 | edge_mat = ed1_neighbors_mat(cluster_words) 308 | 309 | new_cluster_assignment, loss, word2idx = merge_then_ilp(cluster_words, cluster_word2freq, gamma = gamma, edge_mat = edge_mat) 310 | print("Finished preprocessing for cluster {}. Total time: {}".format(cluster2representative[cluster_id], str(datetime.now() - start))) 311 | 312 | new_clusters, word2newcluster, newcluster2representative = process_ilp_output(new_cluster_assignment, word2idx, cluster_word2freq) 313 | print("New clusters: ", new_clusters) 314 | for new_local_cluster_id in new_clusters: 315 | new_global_cluster_id = new_local_cluster_id + num_clusters_added 316 | split_clusters[new_global_cluster_id] = new_clusters[new_local_cluster_id] 317 | split_cluster2representative[new_global_cluster_id] = newcluster2representative[new_local_cluster_id] 318 | for word in word2newcluster: 319 | word2split_cluster[word] = word2newcluster[word] + num_clusters_added 320 | num_clusters_added += len(new_clusters) 321 | 322 | print("Getting typo2cluster") 323 | typo2cluster = get_typo2cluster(words, word2freq, word2split_cluster) 324 | 325 | save_dict = {'cluster': split_clusters, 'word2cluster': word2split_cluster, 326 | 'cluster2representative': split_cluster2representative, 327 | 'word2freq': word2freq, 'typo2cluster': typo2cluster} 328 | # print("About to save: ", save_dict) 329 | 330 | 331 | print("Num clusters: ", len(split_clusters)) 332 | print("Num words with clusters: ", len(word2split_cluster)) 333 | print("Num cluster representatives: ", len(split_cluster2representative)) 334 | if total_jobs == 1: 335 | split_clusterer_path = '{}_gamma{}.pkl'.format(clusterer_path.strip('.pkl'), gamma) 336 | #if save and not toy: 337 | if save: 338 | print("Saving clusters for gamma = {} at {}".format(gamma, split_clusterer_path)) 339 | with open(split_clusterer_path, 'wb') as f: 340 | pickle.dump(save_dict, f) 341 | print("Saved!") 342 | else: 343 | split_clusterer_path_dir = '{}_gamma{}'.format(clusterer_path.strip('.pkl'), gamma) 344 | #if save and not toy: 345 | if save: 346 | if not os.path.isdir(split_clusterer_path_dir): 347 | os.mkdir(split_clusterer_path_dir) 348 | print("Saving clusters for gamma = {} at {}".format(gamma, split_clusterer_path_dir)) 349 | split_clusterer_path = os.path.join(split_clusterer_path_dir, 'job{}outof{}'.format(job_num, total_jobs)) 350 | with open(split_clusterer_path, 'wb') as f: 351 | pickle.dump(save_dict, f) 352 | print("Saved!") 353 | 354 | def parse_args(): 355 | parser = argparse.ArgumentParser() 356 | parser.add_argument('--gamma', type = float, required = True, 357 | help = 'How to weight the different objective functions') 358 | parser.add_argument('--clusterer_path', type = str, required = True, 359 | help = 'Connected component clusterer that is to be split') 360 | parser.add_argument('--no_save', action = 'store_true', 361 | help = 'Whether or not to avoid saving...') 362 | parser.add_argument('--toy', action = 'store_true', help = 'Use toy clusters for testing') 363 | parser.add_argument('--job_id', default = 
0, type = int, help = 'Job number for parallelization') 364 | parser.add_argument('--num_jobs', default = 1, type = int, help ='Total number of parallelization jobs') 365 | 366 | args = parser.parse_args() 367 | return args 368 | 369 | 370 | 371 | if __name__ == '__main__': 372 | #toy() 373 | args = parse_args() 374 | new_cluster_assignments(args.clusterer_path, gamma = args.gamma, toy = args.toy, 375 | save = not args.no_save, job_num = args.job_id, total_jobs = args.num_jobs) 376 | 377 | -------------------------------------------------------------------------------- /run_glue.py: -------------------------------------------------------------------------------- 1 | #Code adapted from pytorch_transformers: https://github.com/huggingface/transformers 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import argparse 5 | import glob 6 | import logging 7 | import os 8 | import random 9 | import json 10 | import sys 11 | import time 12 | import pickle 13 | from collections import defaultdict 14 | 15 | import numpy as np 16 | import torch 17 | from tqdm import tqdm, trange 18 | 19 | from attacks import GreedyAttacker, BeamSearchAttacker, ED1AttackSurface 20 | from recoverer import ClusterRepRecoverer, IdentityRecoverer, ScRNNRecoverer, ClusterIntprmRecoverer, RECOVERERS 21 | from transformers import TransformerRunner, ALL_MODELS, MODEL_CLASSES 22 | from utils import Clustering 23 | from utils_glue import compute_metrics, GLUE_TASK_NAMES, OUTPUT_MODES, PROCESSORS 24 | from augmentor import AUGMENTORS, IdentityAugmentor, HalfAugmentor, KAugmentor 25 | 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | def set_seed(args): 31 | random.seed(args.seed) 32 | np.random.seed(args.seed) 33 | torch.manual_seed(args.seed) 34 | if args.n_gpu > 0: 35 | torch.cuda.manual_seed_all(args.seed) 36 | if args.seed_output_dir: 37 | print("Original output dir: ", args.output_dir) 38 | seed = int(time.clock() * 100000) 39 | if args.output_dir.endswith('/'): 40 | args.output_dir = args.output_dir[:-1] 41 | print("Using seed {}".format(seed)) 42 | args.output_dir = args.output_dir + '_' + str(seed) 43 | print("New output dir!: ", args.output_dir) 44 | 45 | 46 | def save_results(args, results): 47 | save_dir = args.save_dir 48 | output_dirname = os.path.basename(os.path.normpath(args.output_dir)) 49 | #Updated for codalab." 50 | filename = 'results.json' 51 | #filename = '{}_{}.json'.format(output_dirname, args.recoverer) 52 | results_dict = {'results': results, 'clusterer_path': args.clusterer_path, 'output_dir': args.output_dir, 'recoverer': args.recoverer, 53 | 'num_epochs': args.num_train_epochs, 'attack_info': [args.attack, args.attacker, args.attack_name, args.beam_width], 54 | 'augmentor': args.augmentor, 'run_test': args.run_test} 55 | results_of_interest = ['acc', 'adv_acc', 'robust_acc'] 56 | for res in results_of_interest: 57 | if res in results: 58 | results_dict[res] = results[res] 59 | print("About to save into: ", save_dir) 60 | if not os.path.exists(save_dir): 61 | os.mkdir(save_dir) 62 | save_path = os.path.join(save_dir, filename) 63 | with open(save_path, 'w') as f: 64 | json.dump(results_dict, f) 65 | 66 | def parse_args(): 67 | parser = argparse.ArgumentParser() 68 | 69 | ## Required parameters 70 | parser.add_argument("--data_dir", default=None, type=str, required=True, 71 | help="The input data dir. 
Should contain the .tsv files (or other data files) for the task.") 72 | parser.add_argument("--task_name", default=None, type=str, required=True, 73 | help="The name of the task to train selected in the list: " + ", ".join(GLUE_TASK_NAMES)) 74 | parser.add_argument("--output_dir", default=None, type=str, required=True, 75 | help="The output directory where the model predictions and checkpoints will be written.") 76 | 77 | ## Other parameters 78 | parser.add_argument("--model_type", default='bert', type=str, 79 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 80 | parser.add_argument("--model_name_or_path", type=str, default = 'bert-base-uncased', 81 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) 82 | parser.add_argument("--cache_dir", default="", type=str, 83 | help="Where to store the pre-trained models downloaded from s3") 84 | parser.add_argument("--max_seq_length", default=128, type=int, 85 | help="The maximum total input sequence length after tokenization. Sequences longer " 86 | "than this will be truncated, sequences shorter will be padded.") 87 | parser.add_argument("--do_train", action='store_true', 88 | help="Whether to run training.") 89 | parser.add_argument("--do_eval", action='store_true', 90 | help="Whether to run eval on the dev set.") 91 | parser.add_argument("--evaluate_during_training", action='store_true', 92 | help="Run evaluation during training at each logging step.") 93 | parser.add_argument("--do_lower_case", action='store_true', 94 | help="Set this flag if you are using an uncased model.") 95 | 96 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, 97 | help="Batch size per GPU/CPU for training.") 98 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, 99 | help="Batch size per GPU/CPU for evaluation.") 100 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 101 | help="Number of update steps to accumulate before performing a backward/update pass.") 102 | parser.add_argument("--learning_rate", default=2e-5, type=float, 103 | help="The initial learning rate for Adam.") 104 | parser.add_argument("--weight_decay", default=0.0, type=float, 105 | help="Weight decay if we apply some.") 106 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 107 | help="Epsilon for Adam optimizer.") 108 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 109 | help="Max gradient norm.") 110 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 111 | help="Total number of training epochs to perform.") 112 | parser.add_argument("--max_steps", default=-1, type=int, 113 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.") 114 | parser.add_argument("--warmup_steps", default=0, type=int, 115 | help="Linear warmup over warmup_steps.") 116 | 117 | parser.add_argument('--logging_steps', type=int, default=50, 118 | help="Log every X update steps.") 119 | parser.add_argument('--log_stdout_only', action='store_true', 120 | help="Whether to log to stdout only") 121 | parser.add_argument('--verbose', action='store_true', help='Log verbosely') 122 | parser.add_argument('--save_steps', type=int, default=float('inf'), 123 | help="Save checkpoint every X update steps. 
Will save at the end regardless.") 124 | parser.add_argument("--eval_all_checkpoints", action='store_true', 125 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number") 126 | parser.add_argument("--no_cuda", action='store_true', 127 | help="Avoid using CUDA when available") 128 | parser.add_argument('--overwrite_output_dir', action='store_true', 129 | help="Overwrite the content of the output directory") 130 | parser.add_argument('--overwrite_cache', action='store_true', 131 | help="Overwrite the cached training and evaluation sets") 132 | parser.add_argument('--seed', type=int, default=42, 133 | help="random seed for initialization") 134 | 135 | parser.add_argument('--fp16', action='store_true', 136 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 137 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 138 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 139 | "See details at https://nvidia.github.io/apex/amp.html") 140 | parser.add_argument('--clusterer_path', type = str, default = None, help = 141 | 'Location of clusterer to be used, if clusterer is not None') 142 | 143 | parser.add_argument('--do_robust', action = 'store_true', 144 | help = 'Compute robust accuracy (only tractable with clusters).') 145 | parser.add_argument('--robust_max_examples', type=int, default=10000, 146 | help='Give up on certifying robustness if more than this many possibilities.') 147 | parser.add_argument('--print_examples', action = 'store_true', 148 | help = 'Whether or not to print the adversarial worst case possibilities') 149 | parser.add_argument('--seed_output_dir', action = 'store_true', 150 | help = 'Whether or not we should append something to the end of the output dir.') 151 | parser.add_argument('--save_results', action = 'store_true', help = 'Additionally save detailed results.') 152 | parser.add_argument('--save_dir', type = str, default = 'full_glue_run', 153 | help = 'Directory to save the results of the run (within larger results_dir)') 154 | parser.add_argument('--recoverer', choices=RECOVERERS, default='identity', 155 | help='Which recovery strategy to use (default: do nothing)') 156 | parser.add_argument('--augmentor', choices=AUGMENTORS, default = 'identity', help = 'How to augment data for training...') 157 | parser.add_argument('--lm_num_possibilities', type = int, default = 10, 158 | help = 'Number of highest frequency options to consider per cluster in LM recovery (for efficiency)') 159 | parser.add_argument('--lm_ngram_size', type = int, default = 2, 160 | help = 'Max size of n-grams in n-gram NgramLMRecoverer.') 161 | parser.add_argument('--error_analysis', action = 'store_true', help = 'Print out error analysis') 162 | parser.add_argument('--tc_dir', type = str, default = 'scRNN', 163 | help = 'Directory where typo-correctors and vocabs are stored (when recoverer is scrnn)') 164 | 165 | #Attack parameters 166 | parser.add_argument('--attack', action = 'store_true', help = 'Attack the clean model') 167 | parser.add_argument('--new_attack', action = 'store_true', help = 'Avoid loading from cached attack if it exists.') 168 | parser.add_argument('--attack_save_dir', type = str, default = 'attack_cache', 169 | help = 'Location where the preprocessed attack files will be stored') 170 | parser.add_argument('--attack_name', type = str, default = 'DeleteAttack', 171 | help = 'Name of the attack.') 172 | parser.add_argument('--attacker', 
type = str, choices = ['greedy', 'beam-search'], default = 'greedy', 173 | help = 'Type of attack search strategy to use.') 174 | parser.add_argument('--attack_type', type = str, choices = ['ed1', 'intprm'], default = 'ed1', 175 | help = 'Attack with edit distance one typos, or internal perturbations') 176 | parser.add_argument('--beam_width', type = int, default = 5, help = 'Width for beam search if used...') 177 | parser.add_argument('--analyze_res_attacks', action = 'store_true', 178 | help = 'Consider worst-case accuracy for different numbers of perturbations') 179 | parser.add_argument('--save_every_epoch', action = 'store_true', 180 | help = 'Save checkpoints after every epoch, as opposed to just the last epoch...') 181 | parser.add_argument('--run_test', action = 'store_true', 182 | help = 'Evaluate on GLUE dev data, as opposed to a held out fraction of the training set.') 183 | parser.add_argument('--compute_ball_stats', action = 'store_true', help = 'Save the statistics about B_alpha') 184 | parser.add_argument('--compute_pred_stats', action = 'store', help = 'Store predictions for each rep...') 185 | return parser.parse_args() 186 | 187 | 188 | def get_data(args): 189 | # Prepare GLUE task 190 | if args.task_name not in PROCESSORS: 191 | raise ValueError("Task not found: %s" % (args.task_name)) 192 | processor = PROCESSORS[args.task_name]() 193 | augmentor = AUGMENTORS[args.augmentor]() 194 | output_mode = OUTPUT_MODES[args.task_name] 195 | label_list = processor.get_labels() 196 | train_data = processor.get_train_examples(args.data_dir) 197 | if args.run_test: 198 | dev_data = processor.get_dev_examples(args.data_dir) 199 | else: 200 | num_train_examples = int(len(train_data) * 0.8) 201 | dev_data = train_data[num_train_examples:] 202 | train_data = train_data[:num_train_examples] 203 | #Augmenting dataset 204 | train_data = augmentor.augment_dataset(train_data) 205 | args.output_mode = output_mode 206 | print("Train data len: {}, dev data len: {}".format(len(train_data), len(dev_data))) 207 | return train_data, dev_data, label_list 208 | 209 | def get_recoverer(args): 210 | cache_dir = args.output_dir 211 | if args.recoverer == 'identity': 212 | return IdentityRecoverer(cache_dir) 213 | elif args.recoverer == 'scrnn': 214 | return ScRNNRecoverer(cache_dir, args.tc_dir, args.task_name) 215 | elif args.recoverer.startswith('clust'): 216 | clustering = Clustering.from_pickle( 217 | args.clusterer_path, max_num_possibilities=args.lm_num_possibilities) 218 | if args.recoverer == 'clust-rep': 219 | return ClusterRepRecoverer(cache_dir, clustering) 220 | elif args.recoverer == 'clust-intprm': 221 | return ClusterIntprmRecoverer(cache_dir, clustering) 222 | raise ValueError(args.recoverer) 223 | 224 | def get_model_runner(args, recoverer, label_list): 225 | return TransformerRunner( 226 | recoverer, args.output_mode, label_list, args.output_dir, args.device, 227 | args.task_name, args.model_type, args.model_name_or_path, 228 | args.do_lower_case, args.max_seq_length) 229 | 230 | def get_attacker(args, model_runner): 231 | if args.attacker == 'greedy': 232 | return GreedyAttacker(args.attack_name, args.task_name, model_runner, 233 | args.attack_save_dir, args,) 234 | elif args.attacker == 'beam-search': 235 | print("Returning beam search attacker...") 236 | return BeamSearchAttacker(args.attack_name, args.task_name, model_runner, 237 | args.attack_save_dir, args) 238 | else: 239 | raise ValueError(args.attacker) 240 | 241 | def compute_ball_stats(dataset, model_runner, args, 
robust_max_examples = float('inf')): 242 | assert args.recoverer == 'clust-rep' 243 | assert args.run_test 244 | attack_surface = ED1AttackSurface() 245 | id2sizes = {} 246 | for ex in tqdm(dataset, desc = 'Getting example statistics'): 247 | text = ex.text_a 248 | if ex.text_b: 249 | text = '{} {}'.format(text, ex.text_b) 250 | clust_sizes, perturb_sizes = model_runner.recoverer.get_possible_recoveries( 251 | text, attack_surface, robust_max_examples, ret_ball_stats = True) 252 | id2sizes[ex.guid] = clust_sizes, perturb_sizes 253 | return id2sizes 254 | 255 | def save_ball_stats(args, ball_stats): 256 | task = args.task_name 257 | ids = list(ball_stats) 258 | clusterer_name = os.path.basename(args.clusterer_path) 259 | if not os.path.exists('results/stats'): 260 | os.mkdir('results/stats') 261 | task_dir = os.path.join('results/stats', task) 262 | if not os.path.exists(task_dir): 263 | os.mkdir(task_dir) 264 | stats_dir = os.path.join(task_dir, clusterer_name) 265 | if not os.path.exists(stats_dir): 266 | os.mkdir(stats_dir) 267 | clust_stats_fn = os.path.join(stats_dir, 'ball_stats.txt') 268 | pert_stats_fn = os.path.join(stats_dir, 'perturbation_stats.txt') 269 | clust_lines = [] 270 | perturb_lines = [] 271 | for guid in ids: 272 | clust_sizes, pert_sizes = ball_stats[guid] 273 | clust_size_str = ','.join([str(size) for size in clust_sizes]) 274 | pert_size_str = ','.join([str(size) for size in pert_sizes]) 275 | clust_line = '{}:{}\n'.format(guid, clust_size_str) 276 | pert_line = '{}:{}\n'.format(guid, pert_size_str) 277 | clust_lines.append(clust_line) 278 | perturb_lines.append(pert_line) 279 | with open(clust_stats_fn, 'w') as f: 280 | f.writelines(clust_lines) 281 | with open(pert_stats_fn, 'w') as f: 282 | f.writelines(perturb_lines) 283 | 284 | def evaluate(model_runner, dataset, batch_size, do_robust=False, robust_max_examples=10000, analyze_res_attacks = False): 285 | # Need to run recoverer manually on queries 286 | # Because we're using recoverer.get_possible_examples 287 | all_examples = [model_runner.recoverer.recover_example(ex) for ex in tqdm(dataset, desc='Recovering dev')] 288 | 289 | if do_robust: 290 | num_reps = defaultdict(lambda: 0) 291 | attack_surface = ED1AttackSurface() 292 | id_to_poss_ids = {} 293 | num_poss_list = [] 294 | num_exceed_max = 0 295 | for ex in tqdm(dataset, desc='Getting possible recoveries'): 296 | poss_exs, num_poss = model_runner.recoverer.get_possible_examples( 297 | ex, attack_surface, robust_max_examples, analyze_res_attacks = analyze_res_attacks) 298 | num_reps[num_poss] += 1 299 | num_poss_list.append(num_poss) 300 | if poss_exs: 301 | all_examples.extend(poss_exs) 302 | id_to_poss_ids[ex.guid] = [x.guid for x in poss_exs] 303 | else: 304 | num_exceed_max += 1 305 | median = sorted(num_poss_list)[int(len(num_poss_list) / 2)] 306 | num_poss_under_thresh = [x for x in num_poss_list if x <= robust_max_examples] 307 | avg_under_thresh = sum(num_poss_under_thresh) / len(num_poss_under_thresh) 308 | print('Robust eval: %d median poss/ex; %d/%d exceed max of %d; %.1f avg poss/ex on remainder' % ( 309 | median, num_exceed_max, len(dataset), robust_max_examples, avg_under_thresh)) 310 | num_correct = 0 311 | if do_robust: 312 | num_robust = 0 313 | preds = model_runner.query(all_examples, batch_size, do_evaluate=not do_robust, do_recover=False) 314 | id_to_pred = {all_examples[i].guid: preds[i] for i in range(len(all_examples))} 315 | num_with_incorrect = 0 316 | assert(len(id_to_pred) == len(all_examples)) 317 | if analyze_res_attacks: 318 | 
maxp2ncorrect = defaultdict(lambda: 0) 319 | for ex in dataset: 320 | pred = id_to_pred[ex.guid] 321 | was_correct = False 322 | if pred == ex.label: 323 | was_correct = True 324 | num_correct += 1 325 | if do_robust: 326 | if ex.guid not in id_to_poss_ids: 327 | continue 328 | poss_ids = id_to_poss_ids[ex.guid] 329 | cur_preds = set([id_to_pred[pid] for pid in poss_ids]) 330 | if analyze_res_attacks: 331 | incorrect_preds = [int(pid.split('-')[3]) for pid in poss_ids if id_to_pred[pid] != ex.label] 332 | zero_count = len([pred for pred in incorrect_preds if pred == 0]) 333 | assert zero_count < 2 334 | if len(incorrect_preds) != 0: 335 | num_with_incorrect += 1 336 | #Only adding things that were wrong initially 337 | min_required_changes = min(incorrect_preds) 338 | if was_correct: 339 | if min_required_changes == 0: 340 | print(incorrect_preds) 341 | print(poss_ids) 342 | print([id_to_pred[poss_id] for poss_id in poss_ids]) 343 | print(ex.guid, id_to_pred[ex.guid]) 344 | print("Label: ", ex.label) 345 | print("Cur preds: ", cur_preds) 346 | assert False 347 | 348 | #Adds things where an incorrect prediction was found... 349 | for n_changes in range(min_required_changes): 350 | maxp2ncorrect[n_changes] += 1 351 | 352 | if all(x == ex.label for x in cur_preds): 353 | num_robust += 1 354 | print('Normal accuracy: %d/%d = %.2f%%' % (num_correct, len(dataset), 100.0 * num_correct / len(dataset))) 355 | results = {'acc': num_correct / len(dataset)} 356 | if do_robust: 357 | print('Robust accuracy: %d/%d = %.2f%%' % (num_robust, len(dataset), 100.0 * num_robust / len(dataset))) 358 | results['robust_acc'] = num_robust / len(dataset) 359 | if analyze_res_attacks: 360 | print(maxp2ncorrect) 361 | for n_changes in maxp2ncorrect: 362 | maxp2ncorrect[n_changes] = (maxp2ncorrect[n_changes] + num_robust) / len(dataset) 363 | maxp2ncorrect[len(maxp2ncorrect)] = num_robust / len(dataset) 364 | maxp2ncorrect = dict(maxp2ncorrect) 365 | results['restricted_acc_dict'] = maxp2ncorrect 366 | print(maxp2ncorrect) 367 | results['avg_under_thresh'] = avg_under_thresh 368 | results['median_under_thresh'] = median 369 | results['num_exceed_max'] = num_exceed_max 370 | #results['num_rep_analysis'] = dict(num_reps) 371 | print(dict(num_reps)) 372 | 373 | return results 374 | 375 | def main(): 376 | args = parse_args() 377 | args.task_name = args.task_name.lower() 378 | if args.save_results: 379 | if not args.do_eval: 380 | raise ValueError("Must evaluate to save results (i.e. use --do_eval)") 381 | 382 | # Setup 383 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 384 | args.n_gpu = torch.cuda.device_count() 385 | args.device = device 386 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 387 | set_seed(args) 388 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: 389 | raise ValueError("Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format(args.output_dir)) 390 | if not os.path.exists(args.output_dir): 391 | os.makedirs(args.output_dir) 392 | if args.log_stdout_only: 393 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 394 | else: 395 | logging.basicConfig(filename=os.path.join(args.output_dir, 'log.txt'), level=logging.DEBUG) 396 | 397 | # Get data and model 398 | train_data, dev_data, label_list = get_data(args) 399 | recoverer = get_recoverer(args) 400 | model_runner = get_model_runner(args, recoverer, label_list) 401 | logger.info("Training/evaluation parameters %s", args) 402 | 403 | # Run training and evaluation 404 | if args.do_train: 405 | train_recovered = [recoverer.recover_example(x) for x in tqdm(train_data, desc='Recovering train')] 406 | model_runner.train(train_recovered, args) 407 | if args.compute_ball_stats: 408 | ball_stats_dict = compute_ball_stats(dev_data, model_runner, args, robust_max_examples = 10000) 409 | save_ball_stats(args, ball_stats_dict) 410 | if args.do_eval: 411 | results = evaluate(model_runner, dev_data, args.eval_batch_size, 412 | do_robust=args.do_robust, robust_max_examples=args.robust_max_examples, 413 | analyze_res_attacks = args.analyze_res_attacks) 414 | if args.attack: 415 | attacker = get_attacker(args, model_runner) 416 | adv_data = attacker.attack_dataset(dev_data) 417 | print('Running adversarial evaluation.') 418 | adv_results = evaluate(model_runner, adv_data, args.eval_batch_size) 419 | for k, v in adv_results.items(): 420 | results['adv_{}'.format(k)] = v 421 | print('Results: {}'.format(json.dumps(results))) 422 | with open(os.path.join(args.output_dir, 'results.json'), 'w') as f: 423 | json.dump(results, f) 424 | if args.save_results: 425 | save_results(args, results) 426 | #recoverer.save_cache() 427 | 428 | if __name__ == "__main__": 429 | main() 430 | -------------------------------------------------------------------------------- /transformers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """Models based on pytorch-transformers repo for GLUE (Bert, XLM, XLNet).""" 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import collections 20 | import glob 21 | import logging 22 | import os 23 | import random 24 | import json 25 | import time 26 | import pickle 27 | 28 | import numpy as np 29 | import torch 30 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 31 | TensorDataset) 32 | from torch.utils.data.distributed import DistributedSampler 33 | from tensorboardX import SummaryWriter 34 | from tqdm import tqdm, trange 35 | 36 | from pytorch_transformers import (WEIGHTS_NAME, BertConfig, 37 | BertForSequenceClassification, BertTokenizer, 38 | XLMConfig, XLMForSequenceClassification, 39 | XLMTokenizer, XLNetConfig, 40 | XLNetForSequenceClassification, 41 | XLNetTokenizer) 42 | 43 | from pytorch_transformers import AdamW, WarmupLinearSchedule 44 | 45 | from utils import ModelRunner 46 | from utils_glue import compute_metrics 47 | 48 | logger = logging.getLogger(__name__) 49 | 50 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()) 51 | 52 | MODEL_CLASSES = { 53 | 'bert': (BertConfig, BertForSequenceClassification, BertTokenizer), 54 | 'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer), 55 | 'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer), 56 | } 57 | 58 | 59 | class InputFeatures(object): 60 | """A single set of features of data.""" 61 | 62 | def __init__(self, input_ids, input_mask, segment_ids, label_id, og_text = None, input_text = None, example_idx = None): 63 | self.input_ids = input_ids 64 | self.input_mask = input_mask 65 | self.segment_ids = segment_ids 66 | self.label_id = label_id 67 | self.og_text = og_text 68 | self.input_text = input_text 69 | self.example_idx = example_idx 70 | 71 | 72 | class TransformerRunner(ModelRunner): 73 | def __init__(self, recoverer, output_mode, label_list, output_dir, device, task_name, 74 | model_type, model_name_or_path, do_lower_case, max_seq_length): 75 | super(TransformerRunner, self).__init__(recoverer, output_mode, label_list, output_dir, device) 76 | self.task_name = task_name 77 | self.model_type = model_type 78 | self.do_lower_case = do_lower_case 79 | self.max_seq_length = max_seq_length 80 | config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type] 81 | self.model_class = model_class 82 | self.tokenizer_class = tokenizer_class 83 | config = config_class.from_pretrained(model_name_or_path, num_labels=len(label_list), 84 | finetuning_task=task_name) 85 | self.tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=do_lower_case) 86 | self.model = model_class.from_pretrained( 87 | model_name_or_path, from_tf=bool('.ckpt' in model_name_or_path), config=config) 88 | self.model.to(device) 89 | 90 | def _prep_examples(self, examples, verbose=False): 91 | features = convert_examples_to_features( 92 | examples, self.label_list, self.max_seq_length, self.tokenizer, self.output_mode, 93 | cls_token_at_end=bool(self.model_type in ['xlnet']), # xlnet has a cls token at the end 94 | cls_token=self.tokenizer.cls_token, 95 | sep_token=self.tokenizer.sep_token, 96 | cls_token_segment_id=2 if self.model_type in ['xlnet'] else 0, 97 | pad_on_left=bool(self.model_type in ['xlnet']), # pad on the left for xlnet 98 | pad_token_segment_id=4 if self.model_type in ['xlnet'] else 0, 99 | verbose=verbose) 100 | 101 | # Convert to Tensors and build dataset 
102 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 103 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 104 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 105 | if self.output_mode == "classification": 106 | all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long) 107 | elif self.output_mode == "regression": 108 | all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float) 109 | all_text_ids = torch.tensor([f.example_idx for f in features], dtype = torch.long) 110 | 111 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_text_ids) 112 | return dataset 113 | 114 | def train(self, train_data, args): 115 | print("Preparing examples.") 116 | train_dataset = self._prep_examples(train_data, verbose=args.verbose) 117 | print("Starting training.") 118 | global_step, tr_loss, train_results = train(args, train_dataset, self.model, self.tokenizer) 119 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 120 | 121 | # Save a trained model, configuration and tokenizer using `save_pretrained()`. 122 | # They can then be reloaded using `from_pretrained()` 123 | logger.info("Saving model checkpoint to %s", self.output_dir) 124 | model_to_save = self.model.module if hasattr(self.model, 'module') else self.model # Take care of distributed/parallel training 125 | model_to_save.save_pretrained(self.output_dir) 126 | self.tokenizer.save_pretrained(self.output_dir) 127 | torch.save(args, os.path.join(self.output_dir, 'training_args.bin')) 128 | 129 | # Reload model 130 | self.load(self.output_dir, self.device) 131 | print("Finished training.") 132 | 133 | def load(self, output_dir, device): 134 | self.model = self.model_class.from_pretrained(output_dir) 135 | self.tokenizer = self.tokenizer_class.from_pretrained(output_dir) 136 | self.model.to(self.device) 137 | 138 | def query(self, examples, batch_size, do_evaluate=True, return_logits=False, 139 | do_recover=True, use_tqdm=True): 140 | if do_recover: 141 | examples = [self.recoverer.recover_example(x) for x in examples] 142 | dataset = self._prep_examples(examples) 143 | eval_sampler = SequentialSampler(dataset) # Makes sure order is correct 144 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=batch_size) 145 | 146 | # Eval!
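# The loop below runs the model in eval mode over sequential batches, accumulates the detached
# logits (plus gold label ids and example indices) into numpy arrays, and finally returns either
# the raw logits (return_logits=True) or the argmax-decoded label strings from self.label_list.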
147 | logger.info("***** Querying model *****") 148 | logger.info(" Num examples = %d", len(examples)) 149 | logger.info(" Batch size = %d", batch_size) 150 | eval_loss = 0.0 151 | nb_eval_steps = 0 152 | preds = None 153 | out_label_ids = None 154 | example_idxs = None 155 | self.model.eval() 156 | if use_tqdm: 157 | eval_dataloader = tqdm(eval_dataloader, desc="Querying") 158 | for batch in eval_dataloader: 159 | batch = tuple(t.to(self.device) for t in batch) 160 | 161 | with torch.no_grad(): 162 | inputs = {'input_ids': batch[0], 163 | 'attention_mask': batch[1], 164 | 'token_type_ids': batch[2] if self.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids 165 | 'labels': batch[3]} 166 | outputs = self.model(**inputs) 167 | inputs['example_idxs'] = batch[4] 168 | tmp_eval_loss, logits = outputs[:2] 169 | 170 | eval_loss += tmp_eval_loss.mean().item() 171 | 172 | nb_eval_steps += 1 173 | if preds is None: 174 | preds = logits.detach().cpu().numpy() 175 | out_label_ids = inputs['labels'].detach().cpu().numpy() 176 | example_idxs = inputs['example_idxs'].detach().cpu().numpy() 177 | else: 178 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 179 | out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0) 180 | example_idxs = np.append(example_idxs, inputs['example_idxs'].detach().cpu().numpy(), axis = 0) 181 | 182 | eval_loss = eval_loss / nb_eval_steps 183 | logger.info(' eval_loss = %.6f', eval_loss) 184 | incorrect_example_indices = None 185 | if self.output_mode == "classification": 186 | pred_argmax = np.argmax(preds, axis=1) 187 | pred_labels = [self.label_list[pred_argmax[i]] for i in range(len(examples))] 188 | incorrect_example_indices = set(example_idxs[np.not_equal(pred_argmax, out_label_ids)]) 189 | 190 | elif self.output_mode == "regression": 191 | preds = np.squeeze(preds) 192 | 193 | if do_evaluate: 194 | result = compute_metrics(self.task_name, pred_argmax, out_label_ids) 195 | output_eval_file = os.path.join(self.output_dir, "eval-{}.txt".format(self.task_name)) 196 | #print("Possible predictions: ", set(list(preds))) 197 | #priny("Model predictions: mean: {}, max: {}, min: {}".format(preds.mean(), preds.max(), preds.min())) 198 | with open(output_eval_file, "w") as writer: 199 | logger.info("***** Eval results *****") 200 | for key in sorted(result.keys()): 201 | logger.info(" %s = %s", key, str(result[key])) 202 | writer.write("%s = %s\n" % (key, str(result[key]))) 203 | 204 | if return_logits: 205 | return preds 206 | else: 207 | return pred_labels 208 | 209 | 210 | def convert_examples_to_features(examples, label_list, max_seq_length, 211 | tokenizer, output_mode, 212 | cls_token_at_end=False, pad_on_left=False, 213 | cls_token='[CLS]', sep_token='[SEP]', pad_token=0, 214 | sequence_a_segment_id=0, sequence_b_segment_id=1, 215 | cls_token_segment_id=1, pad_token_segment_id=0, 216 | mask_padding_with_zero=True, 217 | verbose = False): 218 | """ Loads a data file into a list of `InputBatch`s 219 | `cls_token_at_end` define the location of the CLS token: 220 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] 221 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] 222 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) 223 | """ 224 | 225 | label_map = {label : i for i, label in enumerate(label_list)} 226 | 227 | features = [] 228 | no_oov_examples = 0 229 | total_examples = 0 230 | label_distribution = collections.defaultdict(lambda: 
0) 231 | for (ex_index, example) in enumerate(examples): 232 | total_examples += 1 233 | if ex_index % 10000 == 0 and verbose: 234 | logger.info("Writing example %d of %d" % (ex_index, len(examples))) 235 | exists_oov = False 236 | a_text = example.text_a 237 | tokens_a = tokenizer.tokenize(a_text) 238 | tokens_b = None 239 | if example.text_b: 240 | b_text = example.text_b 241 | tokens_b = tokenizer.tokenize(b_text) 242 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 243 | else: 244 | # Account for [CLS] and [SEP] with "- 2" 245 | if len(tokens_a) > max_seq_length - 2: 246 | tokens_a = tokens_a[:(max_seq_length - 2)] 247 | if not exists_oov: 248 | no_oov_examples += 1 249 | 250 | # The convention in BERT is: 251 | # (a) For sequence pairs: 252 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 253 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 254 | # (b) For single sequences: 255 | # tokens: [CLS] the dog is hairy . [SEP] 256 | # type_ids: 0 0 0 0 0 0 0 257 | # 258 | # Where "type_ids" are used to indicate whether this is the first 259 | # sequence or the second sequence. The embedding vectors for `type=0` and 260 | # `type=1` were learned during pre-training and are added to the wordpiece 261 | # embedding vector (and position vector). This is not *strictly* necessary 262 | # since the [SEP] token unambiguously separates the sequences, but it makes 263 | # it easier for the model to learn the concept of sequences. 264 | # 265 | # For classification tasks, the first vector (corresponding to [CLS]) is 266 | # used as as the "sentence vector". Note that this only makes sense because 267 | # the entire model is fine-tuned. 268 | tokens = tokens_a + [sep_token] 269 | segment_ids = [sequence_a_segment_id] * len(tokens) 270 | 271 | if tokens_b: 272 | tokens += tokens_b + [sep_token] 273 | segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1) 274 | 275 | if cls_token_at_end: 276 | tokens = tokens + [cls_token] 277 | segment_ids = segment_ids + [cls_token_segment_id] 278 | else: 279 | tokens = [cls_token] + tokens 280 | segment_ids = [cls_token_segment_id] + segment_ids 281 | 282 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 283 | 284 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 285 | # tokens are attended to. 286 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 287 | 288 | # Zero-pad up to the sequence length. 
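# Illustrative example: with max_seq_length = 8 and 5 real wordpiece ids, padding_length = 3;
# BERT-style right padding appends three pad_token ids and the mask becomes [1, 1, 1, 1, 1, 0, 0, 0],
# whereas XLNet-style left padding (pad_on_left=True) prepends them, giving [0, 0, 0, 1, 1, 1, 1, 1].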
289 | padding_length = max_seq_length - len(input_ids) 290 | if pad_on_left: 291 | input_ids = ([pad_token] * padding_length) + input_ids 292 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask 293 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids 294 | else: 295 | input_ids = input_ids + ([pad_token] * padding_length) 296 | input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 297 | segment_ids = segment_ids + ([pad_token_segment_id] * padding_length) 298 | 299 | assert len(input_ids) == max_seq_length 300 | assert len(input_mask) == max_seq_length 301 | assert len(segment_ids) == max_seq_length 302 | 303 | if output_mode == "classification": 304 | label_id = label_map[example.label] 305 | elif output_mode == "regression": 306 | label_id = float(example.label) 307 | else: 308 | raise KeyError(output_mode) 309 | 310 | if ex_index < 5 and verbose: 311 | logger.info("*** Example ***") 312 | logger.info("guid: %s" % (example.guid)) 313 | logger.info("tokens: %s" % " ".join( 314 | [str(x) for x in tokens])) 315 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 316 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 317 | logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 318 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 319 | 320 | original_text = example.text_a if not example.text_b else (example.text_a, example.text_b) 321 | input_text = a_text if not example.text_b else (a_text, b_text) 322 | 323 | features.append( 324 | InputFeatures(input_ids=input_ids, 325 | input_mask=input_mask, 326 | segment_ids=segment_ids, 327 | label_id=label_id, 328 | og_text = original_text, 329 | input_text = input_text, 330 | example_idx = ex_index)) 331 | label_distribution[label_id] += 1 332 | if verbose: 333 | for label in label_distribution: 334 | print("Label: {} Percentage: {}".format(label, label_distribution[label] / total_examples)) 335 | #print("Number of examples without oov: {}/{} = {}".format(no_oov_examples, total_examples, no_oov_examples / total_examples)) 336 | return features 337 | 338 | 339 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 340 | """Truncates a sequence pair in place to the maximum length.""" 341 | 342 | # This is a simple heuristic which will always truncate the longer sequence 343 | # one token at a time. This makes more sense than truncating an equal percent 344 | # of tokens from each, since if one sequence is very short then each token 345 | # that's truncated likely contains more information than a longer sequence. 
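# Illustration: with max_length = 6, len(tokens_a) = 5 and len(tokens_b) = 4, the loop pops from
# the longer list on each pass (a, then b, then a) until the total is 6, leaving 3 tokens in each.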
346 | while True: 347 | total_length = len(tokens_a) + len(tokens_b) 348 | if total_length <= max_length: 349 | break 350 | if len(tokens_a) > len(tokens_b): 351 | tokens_a.pop() 352 | else: 353 | tokens_b.pop() 354 | 355 | def train(args, train_dataset, model, tokenizer): 356 | """ Train the model """ 357 | tb_writer = SummaryWriter(os.path.join(args.output_dir, 'runs')) 358 | 359 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 360 | train_sampler = RandomSampler(train_dataset) 361 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 362 | 363 | if args.max_steps > 0: 364 | t_total = args.max_steps 365 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 366 | else: 367 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 368 | 369 | # Prepare optimizer and schedule (linear warmup and decay) 370 | no_decay = ['bias', 'LayerNorm.weight'] 371 | optimizer_grouped_parameters = [ 372 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 373 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 374 | ] 375 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 376 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) 377 | if args.fp16: 378 | try: 379 | from apex import amp 380 | except ImportError: 381 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 382 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 383 | 384 | # multi-gpu training (should be after apex fp16 initialization) 385 | if args.n_gpu > 1: 386 | model = torch.nn.DataParallel(model) 387 | 388 | # Train! 389 | logger.info("***** Running training *****") 390 | logger.info(" Num examples = %d", len(train_dataset)) 391 | logger.info(" Num Epochs = %d", args.num_train_epochs) 392 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 393 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 394 | logger.info(" Total optimization steps = %d", t_total) 395 | 396 | global_step = 0 397 | tr_loss, logging_loss = 0.0, 0.0 398 | model.zero_grad() 399 | #train_iterator = trange(int(args.num_train_epochs), desc="Epoch") 400 | #set_seed(args) # Added here for reproductibility keeping the seed the same... 401 | # TODO(robinjia): does calling set_seed a second time matter? 
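# The loop below iterates over epochs and batches: forward pass, loss averaging for multi-GPU and
# gradient accumulation, optional fp16 loss scaling via apex, gradient clipping, an optimizer and
# scheduler step every gradient_accumulation_steps batches, periodic TensorBoard logging and
# checkpointing, and per-epoch training metrics collected into train_results.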
402 | train_results = {} 403 | for epoch in range(int(args.num_train_epochs)): 404 | preds = None 405 | out_label_ids = None 406 | epoch_iterator = tqdm(train_dataloader, desc="Iteration") 407 | for step, batch in enumerate(epoch_iterator): 408 | model.train() 409 | batch = tuple(t.to(args.device) for t in batch) 410 | inputs = {'input_ids': batch[0], 411 | 'attention_mask': batch[1], 412 | 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids 413 | 'labels': batch[3]} 414 | outputs = model(**inputs) 415 | loss, logits = outputs[:2] # model outputs are always tuple in pytorch-transformers (see doc) 416 | 417 | if preds is None: 418 | preds = logits.detach().cpu().numpy() 419 | out_label_ids = inputs['labels'].detach().cpu().numpy() 420 | else: 421 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 422 | out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0) 423 | 424 | 425 | if args.n_gpu > 1: 426 | loss = loss.mean() # mean() to average on multi-gpu parallel training 427 | if args.gradient_accumulation_steps > 1: 428 | loss = loss / args.gradient_accumulation_steps 429 | 430 | if args.fp16: 431 | with amp.scale_loss(loss, optimizer) as scaled_loss: 432 | scaled_loss.backward() 433 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 434 | else: 435 | loss.backward() 436 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 437 | 438 | tr_loss += loss.item() 439 | if (step + 1) % args.gradient_accumulation_steps == 0: 440 | optimizer.step() 441 | scheduler.step() # Update learning rate schedule 442 | model.zero_grad() 443 | global_step += 1 444 | 445 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 446 | # Log metrics 447 | if args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well 448 | raise NotImplementedError 449 | # TODO: make evaluation happen below 450 | #results = evaluate(args, model, tokenizer) 451 | #for key, value in results.items(): 452 | # tb_writer.add_scalar('eval_{}'.format(key), value, global_step) 453 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) 454 | tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step) 455 | logging_loss = tr_loss 456 | 457 | if args.save_steps > 0 and global_step % args.save_steps == 0: 458 | # Save model checkpoint 459 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) 460 | if not os.path.exists(output_dir): 461 | os.makedirs(output_dir) 462 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 463 | model_to_save.save_pretrained(output_dir) 464 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 465 | logger.info("Saving model checkpoint to %s", output_dir) 466 | 467 | if args.max_steps > 0 and global_step > args.max_steps: 468 | epoch_iterator.close() 469 | break 470 | if args.save_every_epoch: 471 | output_dir = os.path.join(args.output_dir, 'checkpoint-epoch{}'.format(epoch)) 472 | if not os.path.exists(output_dir): 473 | os.makedirs(output_dir) 474 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 475 | model_to_save.save_pretrained(output_dir) 476 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 477 | logger.info("Saving model checkpoint to %s", output_dir) 478 | if args.output_mode == 
"classification": 479 | preds = np.argmax(preds, axis=1) 480 | elif args.output_mode == "regression": 481 | preds = np.squeeze(preds) 482 | results = compute_metrics(args.task_name, preds, out_label_ids) 483 | train_results[epoch] = results 484 | print("Train results: ", train_results) 485 | if args.max_steps > 0 and global_step > args.max_steps: 486 | train_iterator.close() 487 | break 488 | 489 | tb_writer.close() 490 | #TODO, hacky but saves more significant restructuring... 491 | args.train_results = train_results 492 | return global_step, tr_loss / global_step, train_results 493 | --------------------------------------------------------------------------------