├── data ├── example.png ├── pivot_analysis_camera_ready.pdf └── stop_words.pkl ├── .gitignore ├── src ├── pivot_classifier.py ├── to_model_format.py ├── nlp_pipeline.py ├── config.py ├── main.py ├── classifiers.py ├── model_analysis.py └── data_utils.py └── readme.md /data/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FranxYao/pivot_analysis/HEAD/data/example.png -------------------------------------------------------------------------------- /data/pivot_analysis_camera_ready.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FranxYao/pivot_analysis/HEAD/data/pivot_analysis_camera_ready.pdf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/.vscode 2 | data/.DS_Store 3 | data/amazon 4 | data/amazon_transfer 5 | data/caption 6 | data/gender 7 | data/gender_transfer 8 | data/paper 9 | data/politics 10 | data/reddit 11 | data/twitter 12 | data/yelp 13 | data/yelp_transfer 14 | pivot_analysis.code-workspace 15 | .DS_Store 16 | src/*.txt 17 | src/*.pyc 18 | src/__pycache__ 19 | outputs/transfer 20 | outputs/prec_recl_outputs 21 | outputs/*.pivot 22 | outputs/*.sent 23 | outputs/*.sent_hard 24 | outputs/prec_recl_outputs.zip 25 | outputs/*.npy 26 | local/ 27 | .vscode/ 28 | data.zip 29 | -------------------------------------------------------------------------------- /src/pivot_classifier.py: -------------------------------------------------------------------------------- 1 | """The pivot classifier 2 | 3 | Yao Fu, Columbia University 4 | yao.fu@columbia.edu 5 | Tue Jun 18, 2019 6 | """ 7 | 8 | import numpy as np 9 | 10 | from sklearn.metrics import classification_report 11 | from tqdm import tqdm 12 | from pprint import pprint 13 | 14 | class PivotClassifier(object): 15 | 16 | def __init__(self, pos_index, neg_index): 17 | """ 18 | Args: 19 | pos_index: the list of positive word index 20 | neg_index: the list of negative word index 21 | """ 22 | # print('pos index:') 23 | # print(pos_index[:20]) 24 | # print('neg index:') 25 | # print(neg_index[:20]) 26 | self.pos = set(pos_index) 27 | self.neg = set(neg_index) 28 | return 29 | 30 | def classify(self, sent): 31 | """Classify a sentence 32 | 33 | Args: 34 | sent: a list of word index 35 | """ 36 | sent_ = set(sent) 37 | # print(sent_) 38 | pos_cnt = len(self.pos & sent_) 39 | neg_cnt = len(self.neg & sent_) 40 | # print(pos_cnt, neg_cnt, pos_cnt > neg_cnt) 41 | if(pos_cnt >= neg_cnt): 42 | # p = float(pos_cnt) / len(sent_) 43 | p = 0 44 | return 1, p 45 | else: 46 | # p = float(neg_cnt) / len(sent_) 47 | p = 0 48 | return 0, p 49 | 50 | def classify_dataset(self, sentences, word2id): 51 | """classify a dataset 52 | 53 | Args: 54 | sentences: the list of sentences, a sentence is a list of words 55 | 56 | Returns: 57 | outputs: a numpy array with value 0=negative, 1=positive 58 | """ 59 | outputs = [] 60 | # for s in tqdm(sentences): 61 | for s in sentences: 62 | s_ = [] 63 | for w in s: 64 | if(w in word2id): s_.append(word2id[w]) 65 | outputs.append(self.classify(s_)[0]) 66 | return np.array(outputs) 67 | 68 | def test(self, sentences, labels, word2id, setname): 69 | """Classify a dataset and measure the performance 70 | 71 | Args: 72 | sentences: the list of sentences, a sentence is a list of word index 73 | labels: the list of labels 74 | """ 75 | labels = 
np.array(labels) 76 | pred = self.classify_dataset(sentences, word2id) 77 | print(np.sum(pred == 0)) 78 | print(np.sum(pred == 1)) 79 | 80 | # performance 81 | # results = classification_report(labels, pred, output_dict=True) 82 | # pprint(results) 83 | results = np.sum(pred == labels) / float(len(labels)) 84 | print('%s accuracy: %.4f' % (setname, results)) 85 | return results -------------------------------------------------------------------------------- /src/to_model_format.py: -------------------------------------------------------------------------------- 1 | """Convert dataset format to different model requirement""" 2 | 3 | import argparse 4 | 5 | from config import Config 6 | from data_utils import build_vocab 7 | 8 | def add_arguments(config): 9 | parser = argparse.ArgumentParser(description='Command line arguments') 10 | 11 | parser.add_argument('--model', type=str, default='cmu', 12 | help='The model') 13 | parser.add_argument('--dataset', type=str, default='amazon', 14 | help='The model') 15 | return parser.parse_args() 16 | 17 | def main(): 18 | config = Config() 19 | args = add_arguments(config) 20 | 21 | # read sentences 22 | data_path = config.dataset_base_path + args.dataset + '/' 23 | 24 | train_neg = open(data_path + 'train.0').readlines() 25 | train_pos = open(data_path + 'train.1').readlines() 26 | train_labels = [0] * len(train_neg) + [1] * len(train_pos) 27 | train = train_neg + train_pos 28 | train = [s.split() for s in train] 29 | train = [s[:30] if len(s) > 30 else s for s in train] 30 | 31 | dev_neg = open(data_path + 'dev.0').readlines() 32 | dev_pos = open(data_path + 'dev.1').readlines() 33 | dev_labels = [0] * len(dev_neg) + [1] * len(dev_pos) 34 | dev = dev_neg + dev_pos 35 | dev = [s.split() for s in dev] 36 | dev = [s[:30] if len(s) > 30 else s for s in dev] 37 | 38 | test_neg = open(data_path + 'test.0').readlines() 39 | test_pos = open(data_path + 'test.1').readlines() 40 | test_labels = [0] * len(test_neg) + [1] * len(test_pos) 41 | test = test_neg + test_pos 42 | test = [s.split() for s in test] 43 | test = [s[:30] if len(s) > 30 else s for s in test] 44 | 45 | word2id, _, _, _, _, _ = build_vocab(train, filter_stop_words=0) 46 | if(args.model == 'cmu'): 47 | with open(args.dataset + '.train.text', 'w') as fd: 48 | for l in train: fd.write(' '.join(l) + '\n') 49 | with open(args.dataset + '.train.labels', 'w') as fd: 50 | for l in train_labels: fd.write(str(l) + '\n') 51 | 52 | with open(args.dataset + '.dev.text', 'w') as fd: 53 | for l in dev: fd.write(' '.join(l) + '\n') 54 | with open(args.dataset + '.dev.labels', 'w') as fd: 55 | for l in dev_labels: fd.write(str(l) + '\n') 56 | 57 | with open(args.dataset + '.test.text', 'w') as fd: 58 | for l in test: fd.write(' '.join(l) + '\n') 59 | with open(args.dataset + '.test.labels', 'w') as fd: 60 | for l in test_labels: fd.write(str(l) + '\n') 61 | 62 | with open('vocab', 'w') as fd: 63 | for w in word2id: fd.write('%s\n' % w) 64 | return 65 | 66 | if __name__ == '__main__': 67 | main() -------------------------------------------------------------------------------- /src/nlp_pipeline.py: -------------------------------------------------------------------------------- 1 | """The NLP data cleaning pipeline 2 | Yao Fu, Columbia University 3 | yao.fu@columabia.edu 4 | THU MAY 09TH 2019 5 | """ 6 | import numpy as np 7 | 8 | import nltk 9 | from nltk.corpus import stopwords 10 | from collections import Counter 11 | from tqdm import tqdm 12 | 13 | def normalize( 14 | sentences, word2id, start, end, unk, 
pad, max_src_len, max_tgt_len): 15 | """Normalize the sentences by the following procedure 16 | - word to index 17 | - add unk 18 | - add start, end 19 | - pad/ cut the sentence length 20 | - record the sentence length 21 | Returns: 22 | sent_normalized: the normalized sentences, a list of (src, tgt) pairs 23 | sent_lens: the sentence length, a list of (src_len, tgt_len) pairs 24 | """ 25 | sent_normalized, sent_lens = [], [] 26 | 27 | def _pad(s, max_len, pad): 28 | s_ = list(s[: max_len]) 29 | lens = len(s_) 30 | for i in range(max_len - lens): 31 | s_.append(pad) 32 | return s_ 33 | 34 | for (s, t) in tqdm(sentences): 35 | s_ = [] 36 | s_.extend([word2id[w] if w in word2id else unk for w in s]) 37 | s_.append(end) 38 | slen = min(len(s) + 1, max_src_len) 39 | s_ = _pad(s_, max_src_len, pad) 40 | 41 | t_ = [start] 42 | t_.extend([word2id[w] if w in word2id else unk for w in t]) 43 | t_.append(end) 44 | tlen = min(len(t) + 1, max_tgt_len) 45 | t_ = _pad(t_, max_tgt_len, pad) 46 | 47 | sent_normalized.append((s_, t_)) 48 | sent_lens.append((slen, tlen)) 49 | 50 | return sent_normalized, sent_lens 51 | 52 | def corpus_statistics(sentences, vocab_size_threshold=5): 53 | """Calculate basic corpus statistics""" 54 | print("Calculating basic corpus statistics .. ") 55 | 56 | # sentence length 57 | sentence_lens = [] 58 | for s in sentences: sentence_lens.append(len(s)) 59 | sent_len_percentile = np.percentile(sentence_lens, [50, 80, 90, 95, 100]) 60 | print("sentence length percentile:") 61 | for i, percent in enumerate([50, 80, 90, 95, 100]): 62 | print('%d: %d' % (percent, sent_len_percentile[i])) 63 | 64 | # vocabulary 65 | vocab = [] 66 | for s in sentences: 67 | vocab.extend(s) 68 | vocab = Counter(vocab) 69 | print("vocabulary size: %d" % len(vocab)) 70 | for th in range(1, vocab_size_threshold + 1): 71 | vocab_truncate = [w for w in vocab if vocab[w] >= th] 72 | print("vocabulary size, occurance >= %d: %d" % (th, len(vocab_truncate))) 73 | return 74 | 75 | def get_vocab(training_set, word2id, id2word, vocab_size_threshold=3): 76 | """Get the vocabulary from the training set""" 77 | vocab = [] 78 | for s in training_set: 79 | vocab.extend(s) 80 | 81 | vocab = Counter(vocab) 82 | print('%d words in total' % len(vocab)) 83 | vocab_truncate = [w for w in vocab if vocab[w] >= vocab_size_threshold] 84 | 85 | i = len(word2id) 86 | for w in vocab_truncate: 87 | word2id[w] = i 88 | id2word[i] = w 89 | i += 1 90 | 91 | assert(len(word2id) == len(id2word)) 92 | print("vocabulary size: %d" % len(word2id)) 93 | return word2id, id2word -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | """The Pivot Analysis Configuration""" 2 | 3 | class Config(object): 4 | dataset_name = 'yelp' # ['yelp', 'amazon', 'caption', 'gender', 'paper', 5 | # 'politics', 'reddit', 'twitter'] 6 | 7 | model = 'cmu' # 'cmu', 'mit' 8 | 9 | show_statistics = True 10 | dataset_base_path = '../data/' 11 | output_path = '../outputs/' 12 | set_to_test = 'test' 13 | test_epoch = '' 14 | 15 | vocab_cnt_thres = 5 16 | vocab_size = -1 17 | pad = '_PAD' 18 | unk = '_UNK' 19 | 20 | is_bigram = False 21 | is_trigram = False 22 | filter_stop_words = 0 # do not filter stop words in building vocabulary 23 | 24 | pivot_thres_cnt = 1 # larger = higher confidence 25 | prec_thres = 0.7 # larger = higher confidence 26 | recl_thres = 0.0 # larger = higher confidence 27 | 28 | classifier = 'none' # 'cnn', 'fc' or 'none' 29 
| max_training_case = 80000 30 | max_test_case = 10000 31 | num_epoch = 5 32 | batch_size = 200 33 | 34 | max_slen = {'yelp': 20, 35 | 'amazon': 30, 36 | 'caption': 30, 37 | 'gender': 30, 38 | 'paper': 30, 39 | 'politics': 30, 40 | 'reddit': 100, 41 | 'twitter': 35} 42 | style2id = {'yelp': {0: 'negative', 1: 'positive'}, 43 | 'amazon': {0: 'negative', 1: 'positive'}, 44 | 'caption': {0: 'humorous', 1: 'romantic'}, 45 | 'gender': {0: 'male', 1: 'female'}, 46 | 'paper': {0: 'academic', 1: 'journalism'}, 47 | 'politics': {0: 'democratic', 1: 'republican'}, 48 | 'reddit': {0: 'impolite', 1: 'polite'}, 49 | 'twitter': {0: 'impolite', 1: 'polite'}} 50 | id2style = {'yelp': {'negative': 0, 'positive': 1}, 51 | 'amazon': {'negative': 0, 'positive': 1}, 52 | 'caption': {'humorous': 0, 'romantic': 1}, 53 | 'gender': {'male': 0, 'female': 1}, 54 | 'paper': {'academic': 0, 'journalism': 1}, 55 | 'politics': {'democratic': 0, 'republican': 1}, 56 | 'reddit': {'impolite': 0, 'polite': 1}, 57 | 'twitter': {'impolite': 0, 'polite': 1}} 58 | 59 | 60 | def parse_arg(self, args): 61 | print(args) 62 | # dataset 63 | self.dataset_name = args.dataset 64 | self.model = args.model 65 | self.filter_stop_words = args.filter_stop_words 66 | self.set_to_test = args.set_to_test 67 | self.test_epoch = args.test_epoch 68 | if(args.bigram): self.is_bigram = True 69 | if(args.trigram): self.is_trigram = True 70 | if(args.vocab_cnt_thres != -1): self.vocab_cnt_thres = args.vocab_cnt_thres 71 | if(args.pivot_thres_cnt != -1): self.pivot_thres_cnt = args.pivot_thres_cnt 72 | if(args.prec_thres != -1): self.prec_thres = args.prec_thres 73 | if(args.recl_thres != -1): self.recl_thres = args.recl_thres 74 | if(args.max_training_case != -1): 75 | self.max_training_case = args.max_training_case 76 | if(args.num_epoch != -1): self.num_epoch = args.num_epoch 77 | 78 | 79 | --------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------

# The Pivot Analysis

The implementation of the paper: Yao Fu, Hao Zhou, Jiaze Chen, and Lei Li, _Rethinking Text Attribute Transfer: A Lexical Analysis_. INLG 2019 (oral). [link](https://arxiv.org/abs/1909.12335)

In this paper, we discuss the observation that in many text style transfer datasets and models, only a few style-related words are changed during the transfer process, while the higher-level sentence structure remains unchanged. E.g., to change the negative Yelp sentence "The food is awful" to positive, one only needs to substitute the word "awful" with "awesome": "The food is awesome".

![example](data/example.png)

How can we quantitatively identify, measure, and visualize the influence of these words? We propose three algorithms for this purpose: the **pivot word discovery**, **pivot classifier**, and **precision-recall histogram** algorithms. They are all implemented in this repo.

We gather 8 major style-transfer datasets, standardize them (so in your future work you can **use them from this repo with minimal modification** :), and analyze the pivot effects in these datasets. All analytical results from the paper can be reproduced and found in the `outputs/` folder.


## Download the data
The datasets used in the paper are:
* yelp
* amazon
* caption
* gender
* paper
* politics
* reddit
* twitter

All splits are organized as train.0, train.1 / dev.0, dev.1 / test.0, test.1. Download from [here](https://drive.google.com/open?id=1ZtDIfHKc_GhNElRwHdDvk7tiCkv5_wJa)

But note that the caption dataset does not have the right test data (the authors made a mistake [in their release](https://github.com/lijuncen/Sentiment-and-Style-Transfer): the positive and negative sentences in the test set are the same).

The other datasets are from the corresponding papers, renamed and re-organized to fit our code.
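For a quick sanity check after unpacking the data, here is a minimal sketch of how one split in this format can be read. It mirrors the logic of `read_data` in `src/data_utils.py`; the `load_split` helper and the `../data/` default are illustrative assumptions based on `src/config.py`, not part of the repo's API:

```python
# Minimal sketch (not the repo's API): read one split stored as <set>.0 / <set>.1.
def load_split(dataset, setname, base_path='../data/'):
    sentences, labels = [], []
    for style in (0, 1):
        path = '%s%s/%s.%d' % (base_path, dataset, setname, style)
        with open(path, errors='ignore') as fd:
            for line in fd:
                words = line.split()
                if words:                    # skip empty lines, as read_data does
                    sentences.append(words)
                    labels.append(style)
    return sentences, labels

# e.g.: train_sents, train_labels = load_split('yelp', 'train')
```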
## Run it

```bash
mkdir outputs
python main.py --dataset=yelp --pivot_thres_cnt=1 --prec_thres=0.5 --recl_thres=0.0
```

and the output should look something like:

```
...
Pivot word discovery:
class 0, 4929 pivots, pivot recall: 0.3348
class 1, 4129 pivots, pivot recall: 0.3435
...
Pivot classifier:
train accuracy: 0.8401
dev accuracy: 0.8313
test accuracy: 0.8333
...
output stored in
../outputs/yelp_1.pivot
```

Sample outputs:
```
yelp_0.pivot: word/ precision/ recall (negative sentiment)
sadly 0.9924 0.0002
mistaken 0.7778 0.0000
general 0.6285 0.0001
run 0.6795 0.0003
mill 0.6226 0.0000

yelp_1.pivot: word/ precision/ recall (positive sentiment)
hoagies 0.7903 0.0000
italian 0.7029 0.0004
ton 0.7260 0.0001
really 0.5998 0.0038
worthy 0.6548 0.0000

yelp_0.sent: (pivot words are annotated with their precision)
ok(0.927) never(0.897) going(0.680) back(0.616) to this place again .
easter(0.786) day(0.502) nothing(0.918) open(0.516) , heard(0.778) about this place figured(0.781) it would ok(0.927) .

yelp_1.sent: (pivot words are annotated with their precision)
staff(0.791) behind the deli(0.696) counter were super(0.845) nice(0.907) and efficient(0.943) !
the staff(0.791) are always(0.918) very nice(0.907) and helpful(0.890) .
```

Parameter tuning:

`prec_thres` sets how strongly a word must indicate one class before it counts as a pivot. To find strong pivot words, increase it (e.g. [0.7, 1.0]); to get better classification performance, decrease it (e.g. [0.5, 0.7]).

`recl_thres` and `pivot_thres_cnt` prevent overfitting on single words. To increase the confidence of the pivot words, increase them; to increase classification performance, decrease them.
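To make the two thresholds concrete, here is a small illustrative sketch of the selection rule they control. The count matrix, toy vocabulary, and numbers are made up; the actual computation lives in `prec_recl_f1_dist` and `get_pivot_words` in `src/data_utils.py`:

```python
import numpy as np

# style_words[s][w]: occurrence count of word w in sentences of style s
# toy vocabulary: 0 = "awful", 1 = "awesome", 2 = "food"
style_words = np.array([[90.,  5., 40.],    # style 0 (negative)
                        [10., 95., 45.]])   # style 1 (positive)

prec = style_words / style_words.sum(axis=0, keepdims=True)  # P(style | word)
recl = style_words / style_words.sum(axis=1, keepdims=True)  # share of the style's tokens

prec_thres, recl_thres = 0.7, 0.0
for s in range(2):
    pivots = [w for w in range(style_words.shape[1])
              if prec[s][w] > prec_thres and recl[s][w] > recl_thres]
    print('style %d pivots:' % s, pivots)  # -> [0] for style 0, [1] for style 1
```

With `prec_thres=0.7` only the strongly class-indicative words ("awful", "awesome") survive; lowering the threshold lets weaker cues such as "food" back in, which tends to help classification accuracy but dilutes the pivot list.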
86 | 87 | 88 | ## Contact 89 | Yao Fu, yao.fu@columbia.edu 90 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | """The Main Function 2 | 3 | Yao Fu, Columbia University 4 | yao.fu@columbia.edu 5 | Fri Jun 21st 2019 6 | """ 7 | 8 | import argparse 9 | import os 10 | import numpy as np 11 | 12 | from config import Config 13 | from data_utils import Dataset 14 | from classifiers import LogisticClassifier, FFClassifier, CNNClassifier 15 | 16 | def add_arguments(config): 17 | parser = argparse.ArgumentParser(description='Command line arguments') 18 | 19 | parser.add_argument('--dataset', type=str, default=config.dataset_name, 20 | help='The dataset') 21 | parser.add_argument('--model', type=str, default=config.model, 22 | help='The model') 23 | parser.add_argument('--set_to_test', type=str, default=config.set_to_test, 24 | help='The model') 25 | parser.add_argument('--test_epoch', type=str, default=config.test_epoch, 26 | help='The model') 27 | parser.add_argument('--bigram', action='store_true', 28 | help='If use bigram feature') 29 | parser.add_argument('--trigram', action='store_true', 30 | help='If use trigram feature') 31 | parser.add_argument('--vocab_cnt_thres', type=int, default=config.vocab_cnt_thres, 32 | help='The occurrence threshold of the vocabulary') 33 | parser.add_argument('--pivot_thres_cnt', type=float, default=config.pivot_thres_cnt, 34 | help='The occurrence threshold of pivot words') 35 | parser.add_argument('--prec_thres', type=float, default=config.prec_thres, 36 | help='The threshold of precision') 37 | parser.add_argument('--recl_thres', type=float, default=config.recl_thres, 38 | help='The threshold of recall') 39 | parser.add_argument('--filter_stop_words', type=int, default=config.filter_stop_words, 40 | help='If use stop words') 41 | parser.add_argument('--max_training_case', type=int, default=-1, 42 | help='The maximum cases used in training the classifier') 43 | parser.add_argument('--num_epoch', type=int, default=-1, 44 | help='The number of epoches') 45 | parser.add_argument('--classifier', type=str, default=config.classifier, 46 | help='The classifier') 47 | parser.add_argument('--gpu_id', type=str, default='0', 48 | help='The index of gpu') 49 | return parser.parse_args() 50 | 51 | def main(): 52 | config = Config() 53 | args = add_arguments(config) 54 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 55 | 56 | # sample the dataset 57 | config.parse_arg(args) 58 | dset = Dataset(config) 59 | print('------------------------------------------------------------') 60 | print('Pivot word discovery:') 61 | dset.build() 62 | config.vocab_size = len(dset.word2id) 63 | 64 | print('------------------------------------------------------------') 65 | print('Pivot classifier:') 66 | dset.classify() 67 | 68 | print('------------------------------------------------------------') 69 | print('Precision-recall histogram:') 70 | dset.get_prec_recl() 71 | 72 | print('------------------------------------------------------------') 73 | print('Storing the pivot outputs') 74 | dset.store_pivots() 75 | 76 | # the logistic classifier 77 | if(args.classifier == 'ff'): 78 | classifier = FFClassifier(config) 79 | x_train, y_train = dset.to_bow_numpy('train') 80 | classifier.train(x_train, y_train) 81 | 82 | x_dev, y_dev = dset.to_bow_numpy('dev') 83 | classifier.test(x_dev, y_dev) 84 | 85 | x_test, y_test = dset.to_bow_numpy('test') 86 | 
classifier.test(x_test, y_test) 87 | elif(args.classifier == 'cnn'): 88 | cnn = CNNClassifier(config) 89 | x_train, y_train = dset.to_sent_numpy('train') 90 | cnn.train(x_train, y_train) 91 | 92 | x_dev, y_dev = dset.to_sent_numpy('dev') 93 | cnn.test(x_dev, y_dev) 94 | 95 | x_test, y_test = dset.to_sent_numpy('test') 96 | cnn.test(x_test, y_test) 97 | else: 98 | pass 99 | 100 | # correlation between the pivot words and logistic classifier words 101 | return 102 | 103 | 104 | if __name__ == '__main__': 105 | main() -------------------------------------------------------------------------------- /data/stop_words.pkl: -------------------------------------------------------------------------------- 1 | c__builtin__ 2 | set 3 | p0 4 | ((lp1 5 | Vall 6 | p2 7 | aVshe'll 8 | p3 9 | aVdon't 10 | p4 11 | aVbeing 12 | p5 13 | aVwhen 14 | p6 15 | aVover 16 | p7 17 | aVisnt 18 | p8 19 | aVthrough 20 | p9 21 | aVyourselves 22 | p10 23 | aVits 24 | p11 25 | aVbefore 26 | p12 27 | aV$ 28 | p13 29 | aVhe's 30 | p14 31 | aVwhen's 32 | p15 33 | aV( 34 | p16 35 | aVhad 36 | p17 37 | aV, 38 | p18 39 | aVshould 40 | p19 41 | aVhe'd 42 | p20 43 | aVto 44 | p21 45 | aVonly 46 | p22 47 | aVthere's 48 | p23 49 | aVthose 50 | p24 51 | aVunder 52 | p25 53 | aVhas 54 | p26 55 | aV< 56 | p27 57 | aVhaven't 58 | p28 59 | aV@ 60 | p29 61 | aVthem 62 | p30 63 | aVhis 64 | p31 65 | aVreturn 66 | p32 67 | aVthey'll 68 | p33 69 | aVvery 70 | p34 71 | aVwho's 72 | p35 73 | aVthey'd 74 | p36 75 | aVcannot 76 | p37 77 | aVyou've 78 | p38 79 | aVthey 80 | p39 81 | aVwerent 82 | p40 83 | aVnot 84 | p41 85 | aVduring 86 | p42 87 | aVyourself 88 | p43 89 | aVhim 90 | p44 91 | aVnor 92 | p45 93 | aVwont 94 | p46 95 | aV` 96 | p47 97 | aVwe'll 98 | p48 99 | aVdid 100 | p49 101 | aVi'm 102 | p50 103 | aVthey've 104 | p51 105 | aVshant 106 | p52 107 | aVtheyre 108 | p53 109 | aVthis 110 | p54 111 | aV-lsb- 112 | p55 113 | aVshe 114 | p56 115 | aVeach 116 | p57 117 | aVwon't 118 | p58 119 | aVhavent 120 | p59 121 | aVwhere 122 | p60 123 | aVmustn't 124 | p61 125 | aVisn't 126 | p62 127 | aVi'll 128 | p63 129 | aVwhy's 130 | p64 131 | aVbecause 132 | p65 133 | aVyou'd 134 | p66 135 | aVdoing 136 | p67 137 | aVthere 138 | p68 139 | aVsome 140 | p69 141 | aVwhys 142 | p70 143 | aVwhens 144 | p71 145 | aVup 146 | p72 147 | aVare 148 | p73 149 | aVcant 150 | p74 151 | aVfurther 152 | p75 153 | aV-rsb- 154 | p76 155 | aVourselves 156 | p77 157 | aVout 158 | p78 159 | aV# 160 | p79 161 | aV' 162 | p80 163 | aVheres 164 | p81 165 | aV+ 166 | p82 167 | aV### 168 | p83 169 | aVwhile 170 | p84 171 | aVwasn't 172 | p85 173 | aVdoes 174 | p86 175 | aVshouldn't 176 | p87 177 | aVabove 178 | p88 179 | aVbetween 180 | p89 181 | aV; 182 | p90 183 | aVyoull 184 | p91 185 | aV? 186 | p92 187 | aVought 188 | p93 189 | aVbe 190 | p94 191 | aVwe 192 | p95 193 | aVwho 194 | p96 195 | aVdo 196 | p97 197 | aVyou're 198 | p98 199 | aVwere 200 | p99 201 | aVhere 202 | p100 203 | aVdidnt 204 | p101 205 | aVhadn't 206 | p102 207 | aV[ 208 | p103 209 | aVaren't 210 | p104 211 | aVby 212 | p105 213 | aVboth 214 | p106 215 | aVabout 216 | p107 217 | aVher 218 | p108 219 | aVwouldnt 220 | p109 221 | aVof 222 | p110 223 | aVcould 224 | p111 225 | aVhes 226 | p112 227 | aVagainst 228 | p113 229 | aVi'd 230 | p114 231 | aVweren't 232 | p115 233 | aVwe've 234 | p116 235 | aV{ 236 | p117 237 | aVtheres 238 | p118 239 | aVor 240 | p119 241 | aVcan't 242 | p120 243 | aVthats 244 | p121 245 | aV!! 
246 | p122 247 | aVown 248 | p123 249 | aVwhats 250 | p124 251 | aVdont 252 | p125 253 | aVinto 254 | p126 255 | aVyoud 256 | p127 257 | aVwhom 258 | p128 259 | aVdown 260 | p129 261 | aVhers 262 | p130 263 | aVdoesnt 264 | p131 265 | aVcouldn't 266 | p132 267 | aVsuch 268 | p133 269 | aVyouv 270 | p134 271 | aVcouldnt 272 | p135 273 | aVwhos 274 | p136 275 | aVyour 276 | p137 277 | aV!? 278 | p138 279 | aVdoesn't 280 | p139 281 | aVhe 282 | p140 283 | aV" 284 | p141 285 | aVfrom 286 | p142 287 | aVhow's 288 | p143 289 | aV.. 290 | p144 291 | aV& 292 | p145 293 | aVit's 294 | p146 295 | aV* 296 | p147 297 | aVlets 298 | p148 299 | aVbeen 300 | p149 301 | aV. 302 | p150 303 | aVfew 304 | p151 305 | aVtoo 306 | p152 307 | aVthen 308 | p153 309 | aVthemselves 310 | p154 311 | aV: 312 | p155 313 | aVwas 314 | p156 315 | aVuntil 316 | p157 317 | aV> 318 | p158 319 | aV`` 320 | p159 321 | aVhimself 322 | p160 323 | aVwhere's 324 | p161 325 | aVi've 326 | p162 327 | aVwith 328 | p163 329 | aVdidn't 330 | p164 331 | aVwhat's 332 | p165 333 | aVtheyll 334 | p166 335 | aVbut 336 | p167 337 | aV... 338 | p168 339 | aVhadnt 340 | p169 341 | aV-lrb- 342 | p170 343 | aVmustnt 344 | p171 345 | aVherself 346 | p172 347 | aVthan 348 | p173 349 | aVhere's 350 | p174 351 | aV^ 352 | p175 353 | aVme 354 | p176 355 | aVthey're 356 | p177 357 | aVmyself 358 | p178 359 | aVtheyve 360 | p179 361 | aVthese 362 | p180 363 | aVhasn't 364 | p181 365 | aVshes 366 | p182 367 | aVbelow 368 | p183 369 | aVcan 370 | p184 371 | aVtheirs 372 | p185 373 | aVmore 374 | p186 375 | aVmy 376 | p187 377 | aVwouldn't 378 | p188 379 | aVwe'd 380 | p189 381 | aVand 382 | p190 383 | aVwould 384 | p191 385 | aV-rrb- 386 | p192 387 | aVwasnt 388 | p193 389 | aVis 390 | p194 391 | aVam 392 | p195 393 | aVit 394 | p196 395 | aVan 396 | p197 397 | aV'' 398 | p198 399 | aVas 400 | p199 401 | aVitself 402 | p200 403 | aVim 404 | p201 405 | aVat 406 | p202 407 | aVhave 408 | p203 409 | aVin 410 | p204 411 | aVany 412 | p205 413 | aVif 414 | p206 415 | aV! 416 | p207 417 | aVagain 418 | p208 419 | aVhasnt 420 | p209 421 | aV% 422 | p210 423 | aVno 424 | p211 425 | aV) 426 | p212 427 | aVthat 428 | p213 429 | aV- 430 | p214 431 | aVsame 432 | p215 433 | aVhow 434 | p216 435 | aVother 436 | p217 437 | aVwhich 438 | p218 439 | aVwheres 440 | p219 441 | aVyou 442 | p220 443 | aVshan't 444 | p221 445 | aVarent 446 | p222 447 | aV?? 448 | p223 449 | aVshouldnt 450 | p224 451 | aV's 452 | p225 453 | aVour 454 | p226 455 | aVwhy 456 | p227 457 | aVafter 458 | p228 459 | aVwhat 460 | p229 461 | aVlet's 462 | p230 463 | aVmost 464 | p231 465 | aVours 466 | p232 467 | aV'll 468 | p233 469 | aV'm 470 | p234 471 | aVon 472 | p235 473 | aV] 474 | p236 475 | aVhe'll 476 | p237 477 | aV?! 478 | p238 479 | aVa 480 | p239 481 | aVhows 482 | p240 483 | aVoff 484 | p241 485 | aVfor 486 | p242 487 | aVi 488 | p243 489 | aVyoure 490 | p244 491 | aVshe'd 492 | p245 493 | aVyours 494 | p246 495 | aVtheir 496 | p247 497 | aVyou'll 498 | p248 499 | aVso 500 | p249 501 | aVwe're 502 | p250 503 | aVshe's 504 | p251 505 | aVthe 506 | p252 507 | aVthat's 508 | p253 509 | aV} 510 | p254 511 | aVhaving 512 | p255 513 | aVonce 514 | p256 515 | atp257 516 | Rp258 517 | . 
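The block above is a protocol-0 (plain-text) pickle of a Python `set` of stop words. For reference, a minimal sketch of how it is loaded, mirroring the module-level setup in `src/data_utils.py` (NLTK's stopword corpus must be available, e.g. via `nltk.download('stopwords')`):

```python
import pickle
from nltk.corpus import stopwords

with open('../data/stop_words.pkl', 'rb') as fd:
    stop_words = set(pickle.load(fd))

# data_utils.py unions the pickled set with NLTK's English stop words
stop_words = stop_words | set(stopwords.words('english'))
print('%d stop words' % len(stop_words))
```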
-------------------------------------------------------------------------------- /src/classifiers.py: -------------------------------------------------------------------------------- 1 | """The Logistic Baseline Classifier for Pivot Analysis 2 | 3 | Yao Fu, Columbia University 4 | yao.fu@columbia.edu 5 | Sat Jun 22nd 2019 6 | """ 7 | 8 | import numpy as np 9 | import keras.layers as layers 10 | 11 | from sklearn.linear_model import LogisticRegression 12 | from keras.models import Sequential 13 | from keras.layers import Dense, Dropout, Activation, Reshape 14 | from keras.layers import Embedding 15 | from keras.layers import Input, Flatten 16 | from keras.layers import Conv1D, MaxPooling1D, Conv2D, MaxPooling2D, concatenate 17 | from keras import optimizers 18 | from keras.models import Model 19 | 20 | 21 | class LogisticClassifier(object): 22 | """The logistic classifier, the baseline model""" 23 | 24 | def __init__(self, config): 25 | self.model = LogisticRegression(solver='lbfgs') 26 | 27 | self.max_training_case = config.max_training_case 28 | return 29 | 30 | def train(self, train_data, labels): 31 | if(len(train_data) >= self.max_training_case): 32 | sample_id = np.random.choice( 33 | len(train_data), self.max_training_case, False) 34 | train_data = train_data[sample_id] 35 | labels = labels[sample_id] # subsample the labels together with the data (was an undefined `train_labels`) 36 | self.model.fit(train_data, labels) 37 | return 38 | 39 | def test(self, test_data, labels): 40 | pred = self.model.predict(test_data) 41 | acc = np.sum(pred == labels) / float(len(labels)) 42 | print('accuracy: %.4f' % acc) 43 | return 44 | 45 | def word_saliency(self): 46 | """Word saliency analysis, output strong words""" 47 | return 48 | 49 | class CNNClassifier(object): 50 | def __init__(self, config): 51 | """ 52 | Convolutional neural network model for sentence classification 53 | (the sentence CNN of Y. Kim). 54 | Parameters 55 | ---------- 56 | EMBEDDING_DIM: Dimension of the embedding space. 57 | MAX_SEQUENCE_LENGTH: Maximum length of the sentence. 58 | MAX_NB_WORDS: Maximum number of words in the vocabulary. 59 | embeddings_index: A dict containing words and their embeddings. 60 | word_index: A dict containing words and their indices. 61 | labels_index: A dict containing the labels and their indices. 62 | Returns 63 | ------- 64 | compiled keras model 65 | """ 66 | self.batch_size = config.batch_size 67 | self.num_epoch = config.num_epoch 68 | 69 | EMBEDDING_DIM = 300 70 | MAX_SEQUENCE_LENGTH = config.max_slen[config.dataset_name] 71 | # embedding_matrix = np.zeros((config.vocab_size, EMBEDDING_DIM)) 72 | embedding_layer = Embedding(config.vocab_size, 73 | EMBEDDING_DIM, 74 | input_length=MAX_SEQUENCE_LENGTH, 75 | trainable=True) 76 | 77 | sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32') 78 | embedded_sequences = embedding_layer(sequence_input) 79 | 80 | # add first conv filter 81 | embedded_sequences = Reshape( 82 | (MAX_SEQUENCE_LENGTH, EMBEDDING_DIM, 1))(embedded_sequences) 83 | 84 | x = Conv2D(100, (5, EMBEDDING_DIM), activation='relu')(embedded_sequences) 85 | x = MaxPooling2D((MAX_SEQUENCE_LENGTH - 5 + 1, 1))(x) 86 | 87 | # add second conv filter. 88 | y = Conv2D(100, (4, EMBEDDING_DIM), activation='relu')(embedded_sequences) 89 | y = MaxPooling2D((MAX_SEQUENCE_LENGTH - 4 + 1, 1))(y) 90 | 91 | # add third conv filter.
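# Note on shapes (added comment): each of the three conv branches (filter heights
# 5, 4, 3) follows the Kim (2014) sentence-CNN pattern. The reshaped input is
# (MAX_SEQUENCE_LENGTH, EMBEDDING_DIM, 1); Conv2D(100, (k, EMBEDDING_DIM)) with
# valid padding yields (MAX_SEQUENCE_LENGTH - k + 1, 1, 100), i.e. one 100-dim
# feature per k-gram window, and MaxPooling2D((MAX_SEQUENCE_LENGTH - k + 1, 1))
# keeps the strongest window, so every branch contributes a (1, 1, 100) tensor
# to the concatenation.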
92 | z = Conv2D(100, (3, EMBEDDING_DIM), activation='relu')(embedded_sequences) 93 | z = MaxPooling2D((MAX_SEQUENCE_LENGTH - 3 + 1, 1))(z) 94 | 95 | # concate the conv layers 96 | alpha = concatenate([x,y,z]) 97 | # flatted the pooled features. 98 | alpha = Flatten()(alpha) 99 | 100 | # dropout 101 | alpha = Dropout(0.5)(alpha) 102 | # predictions 103 | preds = Dense(1, activation='sigmoid')(alpha) 104 | 105 | # build model 106 | model = Model(sequence_input, preds) 107 | opt = optimizers.Adam(lr=0.0001) 108 | 109 | model.compile(loss='binary_crossentropy', 110 | optimizer=opt, 111 | metrics=['acc']) 112 | 113 | self.model = model 114 | return 115 | 116 | def train(self, train_data, train_labels): 117 | history = self.model.fit(train_data, train_labels, batch_size=self.batch_size, 118 | epochs=self.num_epoch, verbose=1, validation_split=0.1) 119 | return 120 | 121 | def test(self, test_data, labels): 122 | score = self.model.evaluate(test_data, labels, verbose=1) 123 | acc = score[1] 124 | print('accuracy: %.4f' % acc) 125 | return 126 | 127 | class FFClassifier(object): 128 | """The CNN classifier""" 129 | 130 | def __init__(self, config): 131 | self.num_epoch = config.num_epoch 132 | self.batch_size = config.batch_size 133 | self.max_training_case = config.max_training_case 134 | 135 | model = Sequential() 136 | # model.add(Dense(200, input_shape=(config.vocab_size, ))) 137 | # model.add(Activation('relu')) 138 | # model.add(Dropout(0.2)) 139 | model.add(Dense(1, input_shape=(config.vocab_size, ))) 140 | model.add(Activation('sigmoid')) 141 | model.compile( 142 | loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) 143 | 144 | self.model = model 145 | return 146 | 147 | def train(self, train_data, train_labels): 148 | if(len(train_data) >= self.max_training_case): 149 | sample_id = np.random.choice( 150 | len(train_data), self.max_training_case, False) 151 | train_data = train_data[sample_id] 152 | train_labels = train_labels[sample_id] 153 | history = self.model.fit(train_data, train_labels, batch_size=self.batch_size, 154 | epochs=self.num_epoch, verbose=1, validation_split=0.1) 155 | return 156 | 157 | def test(self, test_data, labels): 158 | score = self.model.evaluate(test_data, labels, verbose=1) 159 | acc = score[1] 160 | print('accuracy: %.4f' % acc) 161 | return 162 | 163 | 164 | -------------------------------------------------------------------------------- /src/model_analysis.py: -------------------------------------------------------------------------------- 1 | """Model output analysis 2 | 3 | Yao Fu, Columbia University 4 | yao.fu@columbia.edu 5 | Sat Jul 06th 2019 6 | """ 7 | 8 | import argparse 9 | import os 10 | import numpy as np 11 | 12 | from config import Config 13 | from data_utils import Dataset 14 | from main import add_arguments 15 | from collections import Counter 16 | from editdistance import eval as editdist 17 | 18 | def _sent_raw_to_id(sentences, word2id, setname): 19 | sentences_ = [] 20 | num_unk = 0 21 | total_words = 0 22 | unknown_words = [] 23 | for s in sentences: 24 | s_ = [] 25 | for w in s: 26 | if(w in word2id): s_.append(word2id[w]) 27 | else: 28 | num_unk += 1 29 | s_.append(word2id['_UNK']) 30 | unknown_words.append(w) 31 | total_words += 1 32 | sentences_.append(s_) 33 | ratio = float(num_unk) / total_words 34 | print('%s,%d unk words, %d total, %.4f ratio' % 35 | (setname, num_unk, total_words, ratio)) 36 | unknown_words = Counter(unknown_words) 37 | with open(setname + '_unknown_words.txt', 'w') as fd: 38 | for w, c in 
unknown_words.most_common(): fd.write('%s %d\n' % (w, c)) 39 | return sentences_ 40 | 41 | def _transfer_in_pivot(src, tsf, pivots): 42 | total_modified = [] 43 | total_in_pivot = [] 44 | sent_lens = [] 45 | for s, t in zip(src, tsf): 46 | s_ = set(s) 47 | t_ = set(t) 48 | sent_lens.append(len(s)) 49 | modified = (s_ | t_) - (s_ & t_) 50 | num_modified = len(modified) 51 | num_in_pivot = len(modified & pivots) 52 | total_modified.append(num_modified) 53 | total_in_pivot.append(num_in_pivot) 54 | total_modified = np.sum(total_modified) 55 | total_in_pivot = np.sum(total_in_pivot) 56 | avg_modified = total_modified / float(2 * len(src)) 57 | ratio = float(total_in_pivot) / total_modified 58 | avg_sent_lens = np.average(sent_lens) 59 | return total_modified, total_in_pivot, avg_modified, ratio, avg_sent_lens 60 | 61 | def _format_sentence(s, id2word, pivots): 62 | s_ = [] 63 | for w in s: 64 | if(w in pivots[0]): s_.append('[0 ' + id2word[w] + ']') 65 | elif(w in pivots[1]): s_.append('[1 ' + id2word[w] + ']') 66 | else: 67 | # print(type(w)) 68 | s_.append(id2word[w]) 69 | return ' '.join(s_) 70 | 71 | def _masked_edit_dist(src, tsf, pivots, id2word, output_path): 72 | distances = [] 73 | i = 0 74 | print('output write to:\n%s' % output_path) 75 | # print(output_path) 76 | fd = open(output_path, 'w') 77 | for s, t in zip(src, tsf): 78 | s_ = set(s) 79 | t_ = set(t) 80 | modified = (s_ | t_) - (s_ & t_) 81 | # s_masked = [w if w not in modified else 0 for w in s] # 0 = '_PAD' 82 | # t_masked = [w if w not in modified else 0 for w in t] # 0 = '_PAD' 83 | 84 | pivot_set = pivots[0] | pivots[1] 85 | s_masked = [w if w not in pivot_set else 0 for w in s] # 0 = '_PAD' 86 | t_masked = [w if w not in pivot_set else 0 for w in t] # 0 = '_PAD' 87 | s_masked_ = ' '.join([str(w) for w in s_masked]) 88 | t_masked_ = ' '.join([str(w) for w in t_masked]) 89 | ed = editdist(s_masked_, t_masked_) 90 | distances.append(ed) 91 | 92 | fd.write('s: %s\n' % _format_sentence(s, id2word, pivots)) 93 | fd.write('t: %s\n' % _format_sentence(t, id2word, pivots)) 94 | # debug 95 | # if(i < 5): 96 | # print('modified:', [id2word[w] for w in modified]) 97 | # print('s:', _format_sentence(s, id2word, pivots)) 98 | # print('t:', _format_sentence(t, id2word, pivots)) 99 | # print('s_masked:', _format_sentence(s_masked, id2word, pivots)) 100 | # print('t_masked:', _format_sentence(t_masked, id2word, pivots)) 101 | # print('ed %d' % ed) 102 | # i += 1 103 | avg_dist = np.average(distances) 104 | distances = Counter(distances) 105 | 106 | dist_distribution = np.zeros(8) 107 | for i in range(8): 108 | if(i < 7): dist_distribution[i] = float(distances[i]) / len(src) 109 | else: dist_distribution[i] = 1 - dist_distribution[: i].sum() 110 | return avg_dist, distances, dist_distribution 111 | 112 | class PivotTransferAnalysis(object): 113 | """Pivot analysis of the transfered dataset""" 114 | 115 | def __init__(self, config): 116 | self.data_base_path = config.dataset_base_path + config.dataset_name +\ 117 | '_transfer/' + config.model + '/' + config.set_to_test + '.' 118 | if(config.model == 'cmu'): 119 | if(config.test_epoch != ''): 120 | self.data_base_path = self.data_base_path + config.test_epoch + '.' 121 | if(config.model == 'mit'): 122 | self.data_base_path_tsf = self.data_base_path + 'epoch' + config.test_epoch + '.' 
123 | self.output_path = config.output_path + 'transfer/' + config.model +\ 124 | '_' + config.dataset_name + '_transfer' 125 | return 126 | 127 | def pipeline_w_cmu(self, dset): 128 | src = open(self.data_base_path + 'src').readlines() 129 | src = [s.split() for s in src] 130 | src = _sent_raw_to_id(src, dset.word2id, 'src') 131 | 132 | tsf = open(self.data_base_path + 'tsf').readlines() 133 | tsf = [s.split() for s in tsf] 134 | tsf = _sent_raw_to_id(tsf, dset.word2id, 'tsf') 135 | 136 | pivot_words = set(dset.pivot_words[0]) | set(dset.pivot_words[1]) 137 | 138 | modified, in_pivot, avg_modified, ratio, avg_sent_len =\ 139 | _transfer_in_pivot(src, tsf, pivot_words) 140 | print('%d modified, %d in pivot' % (modified, in_pivot)) 141 | print('%.2f avg sentence length %.2f average modified, %.4f ratio' % 142 | (avg_sent_len, avg_modified, ratio)) 143 | 144 | pivot_words_class = [set(dset.pivot_words[0]), set(dset.pivot_words[1])] 145 | avg_dist, distances, dist_distribution =\ 146 | _masked_edit_dist(src, tsf, pivot_words_class, dset.id2word, self.output_path) 147 | print('%d different distances in total, avg %.2f' % 148 | (len(distances), avg_dist)) 149 | print('distribution:', np.sum(dist_distribution)) 150 | for i, di in enumerate(dist_distribution): 151 | print('%d: %.4f' % (i, di)) 152 | print(distances.most_common(10)) 153 | return 154 | 155 | def pipeline(self, dset): 156 | """Pivot analysis pipeline """ 157 | print('reading data from:\n %s' % self.data_base_path) 158 | # Read the transfered sentences 159 | neg_src = open(self.data_base_path + '0.src').readlines() 160 | neg_src = [s.split() for s in neg_src] 161 | neg_src = _sent_raw_to_id(neg_src, dset.word2id, 'neg_src') 162 | 163 | neg_tsf = open(self.data_base_path_tsf + '0.tsf').readlines() 164 | neg_tsf = [s.split() for s in neg_tsf] 165 | neg_tsf = _sent_raw_to_id(neg_tsf, dset.word2id, 'neg_tsf') 166 | 167 | pos_src = open(self.data_base_path + '1.src').readlines() 168 | pos_src = [s.split() for s in pos_src] 169 | pos_src = _sent_raw_to_id(pos_src, dset.word2id, 'pos_src') 170 | 171 | pos_tsf = open(self.data_base_path_tsf + '1.tsf').readlines() 172 | pos_tsf = [s.split() for s in pos_tsf] 173 | pos_tsf = _sent_raw_to_id(pos_tsf, dset.word2id, 'pos_tsf') 174 | 175 | # calculate how many modified words are pivots 176 | pivot_words = set(dset.pivot_words[0]) | set(dset.pivot_words[1]) 177 | 178 | # neg_modified, neg_in_pivot, neg_avg_modified, neg_ratio =\ 179 | # _transfer_in_pivot(neg_src, neg_tsf, pivot_words) 180 | # print('neg to pos, %d modified, %d in pivot' % (neg_modified, neg_in_pivot)) 181 | # print('%.2f average modified, %.4f ratio' % (neg_avg_modified, neg_ratio)) 182 | 183 | # pos_modified, pos_in_pivot, pos_avg_modified, pos_ratio =\ 184 | # _transfer_in_pivot(pos_src, pos_tsf, pivot_words) 185 | # print('pos to neg, %d modified, %d in pivot' % (pos_modified, pos_in_pivot)) 186 | # print('%.2f average modified, %.4f ratio' % (pos_avg_modified, pos_ratio)) 187 | 188 | modified, in_pivot, avg_modified, ratio, avg_sent_lens =\ 189 | _transfer_in_pivot(neg_src + pos_src, neg_tsf + pos_tsf, pivot_words) 190 | print('%d modified, %d in pivot' % (modified, in_pivot)) 191 | print('%.2f avg len, %.2f average modified, %.4f in pivots' % 192 | (avg_sent_lens, avg_modified, ratio)) 193 | 194 | # mask the modified words, calculate the sentence distances 195 | pivot_words_class = [set(dset.pivot_words[0]), set(dset.pivot_words[1])] 196 | # avg_dist, distances, _ = _masked_edit_dist( 197 | # neg_src, neg_tsf, 
pivot_words_class, dset.id2word, self.output_path) 198 | # print('neg to pos, %d different distances in total, avg %.2f' % 199 | # (len(distances), avg_dist)) 200 | # print(distances.most_common(10)) 201 | 202 | # avg_dist, distances, _ = _masked_edit_dist( 203 | # pos_src, pos_tsf, pivot_words_class, dset.id2word, self.output_path) 204 | # print('pos to neg, %d different distances in total, avg %.2f' % 205 | # (len(distances), avg_dist)) 206 | # print(distances.most_common(10)) 207 | 208 | avg_dist, distances, dist_distribution = _masked_edit_dist( 209 | neg_src + pos_src, neg_tsf + pos_tsf, pivot_words_class, dset.id2word, self.output_path) 210 | print('%d different distances in total, avg %.2f' % 211 | (len(distances), avg_dist)) 212 | print('distribution:', np.sum(dist_distribution)) 213 | for i, di in enumerate(dist_distribution): 214 | print('%d: %.4f' % (i, di)) 215 | print(distances.most_common(10)) 216 | 217 | # mask the pivot words, calculate the sentence distances 218 | return 219 | 220 | def main(): 221 | config = Config() 222 | args = add_arguments(config) 223 | config.parse_arg(args) 224 | dset = Dataset(config) 225 | dset.build() 226 | # print('debug:') 227 | # print(dset.id2word[1]) 228 | config.vocab_size = len(dset.word2id) 229 | 230 | # read the transfered sentences 231 | transfer_analysis = PivotTransferAnalysis(config) 232 | 233 | if(config.model == 'cmu'): 234 | transfer_analysis.pipeline_w_cmu(dset) 235 | else: 236 | transfer_analysis.pipeline(dset) 237 | return 238 | 239 | if __name__ == '__main__': 240 | main() -------------------------------------------------------------------------------- /src/data_utils.py: -------------------------------------------------------------------------------- 1 | """Unified Dataset Utilities for Pivot Analysis 2 | 3 | Yao Fu, Columbia University 4 | yao.fu@columbia.edu 5 | Fri Jun 21st 2019 6 | """ 7 | 8 | import pickle 9 | import time 10 | import tqdm 11 | import numpy as np 12 | from nltk.corpus import stopwords 13 | from pivot_classifier import PivotClassifier 14 | from keras.preprocessing.sequence import pad_sequences 15 | from keras.utils.np_utils import to_categorical 16 | 17 | 18 | # STOPWORDS_PATH = '/home/francis/hdd/pivot_analysis/data/stop_words.pkl' 19 | STOPWORDS_PATH = '../data/stop_words.pkl' 20 | STOPWORDS = set(pickle.load(open(STOPWORDS_PATH, 'rb'))) 21 | STOPWORDS = STOPWORDS | set(stopwords.words('english')) 22 | 23 | def upsample(sentences, label_id, target_size): 24 | """Upsample the sentences to the target size 25 | sample size = target size - original size 26 | Sample [sample size] sentences from the given sentence set. 27 | Also modify the labels. 28 | We assume all sentences are of the sample label 29 | 30 | Args: 31 | sentences: a list of the sentences to be upsampled 32 | label_id: the label of the sentences, an integer 33 | target_size: the size to sample to, an integer 34 | 35 | Returns: 36 | sampled_sentences: a list of sampled sentences 37 | labels: a list of extended labels 38 | """ 39 | sample_size = target_size - len(sentences) 40 | sample_id = np.random.choice(len(sentences), sample_size) 41 | 42 | sampled_sentences = list(sentences) 43 | for i in sample_id: sampled_sentences.append(sentences[i]) 44 | labels = [label_id] * len(sampled_sentences) 45 | return sampled_sentences, labels 46 | 47 | def downsample(sentences, label_id, target_size): 48 | """downsample the sentences to the target size 49 | Also modify the labels. 
50 | We assume all sentences are of the sample label 51 | 52 | Args: 53 | sentences: a list of the sentences to be upsampled 54 | label_id: the label of the sentences, an integer 55 | target_size: the size to sample to, an integer 56 | 57 | Returns: 58 | sampled_sentences: a list of sampled sentences 59 | labels: a list of extended labels 60 | """ 61 | sample_id = np.random.choice(len(sentences), target_size, False) 62 | 63 | sampled_sentences = [] 64 | for i in sample_id: sampled_sentences.append(sentences[i]) 65 | labels = [label_id] * len(sampled_sentences) 66 | return sampled_sentences, labels 67 | 68 | def read_data(dataset, setname, base_path, balance_method='upsample'): 69 | """Read the data 70 | 71 | Args: 72 | dataset: the name of the dataset 73 | setname: 'train', 'dev' or 'test' 74 | base_path: the base path of the datasets 75 | balance_method: 'upsample', 'downsample' 76 | 77 | Returns: 78 | sentences: a list of sentences. A sentence is a list of words. 79 | labels: a list of labels 80 | """ 81 | print('Reading the %s dataset, %s .. ' % (dataset, setname)) 82 | neg_path = base_path + dataset + '/' + setname + '.0' 83 | pos_path = base_path + dataset + '/' + setname + '.1' 84 | 85 | def _read_file_lines(f_path): 86 | with open(f_path, errors='ignore') as fd: 87 | lines = fd.readlines() 88 | lines_ = [] 89 | for l in lines: 90 | s = l.split() 91 | if(len(s) > 0): lines_.append(s) 92 | return lines_ 93 | 94 | neg_sentences = _read_file_lines(neg_path) 95 | pos_sentences = _read_file_lines(pos_path) 96 | 97 | neg_sent_num = len(neg_sentences) 98 | pos_sent_num = len(pos_sentences) 99 | print('neg sentence num: %d, pos num: %d' % (neg_sent_num, pos_sent_num)) 100 | 101 | if(balance_method == 'upsample'): 102 | if(neg_sent_num < pos_sent_num): 103 | neg_sentences, neg_labels = upsample(neg_sentences, 0, pos_sent_num) 104 | pos_labels = [1] * pos_sent_num 105 | else: 106 | pos_sentences, pos_labels = upsample(pos_sentences, 1, neg_sent_num) 107 | neg_labels = [0] * neg_sent_num 108 | else: 109 | if(neg_sent_num < pos_sent_num): 110 | pos_sentences, pos_labels = downsample(pos_sentences, 1, neg_sent_num) 111 | neg_labels = [0] * neg_sent_num 112 | else: 113 | neg_sentences, neg_labels = downsample(neg_sentences, 0, pos_sent_num) 114 | pos_labels = [1] * pos_sent_num 115 | 116 | sentences = neg_sentences + pos_sentences 117 | labels = np.array(neg_labels + pos_labels) 118 | return sentences, labels 119 | 120 | 121 | def build_vocab(sentences, 122 | is_bigram=False, is_trigram=False, cnt_threshold=5, filter_stop_words=1): 123 | """Build the vocabulary, bigram, and trigram 124 | 125 | Returns: 126 | unigram, bigram, trigram to id, and id to them 127 | 128 | Note: 129 | the sentences is also padded here if its length is less than 3 130 | """ 131 | print("Building the vocabulary ..., filter_stop_words = %d" % filter_stop_words) 132 | start_time = time.time() 133 | unigram_cnt = dict() 134 | bigram_cnt = dict() 135 | trigram_cnt = dict() 136 | 137 | unigram2id = dict() 138 | bigram2id = dict() 139 | trigram2id = dict() 140 | id2unigram = dict() 141 | id2bigram = dict() 142 | id2trigram = dict() 143 | stop_words = STOPWORDS 144 | 145 | for s in sentences: 146 | slen = len(s) 147 | if(slen < 3): 148 | for i in range(3 - slen): s.append("_PAD") 149 | slen = 3 150 | for w in s: 151 | if(w not in unigram_cnt): unigram_cnt[w] = 1 152 | else: unigram_cnt[w] += 1 153 | if(is_bigram): 154 | for i in range(slen - 1): 155 | bigram = s[i] + " " + s[i + 1] 156 | if(bigram not in bigram_cnt): 
bigram_cnt[bigram] = 1 157 | else: bigram_cnt[bigram] += 1 158 | if(is_trigram): 159 | for i in range(slen - 2): 160 | trigram = s[i] + " " + s[i + 1] + " " + s[i + 2] 161 | if(trigram not in trigram_cnt): trigram_cnt[trigram] = 1 162 | else: trigram_cnt[trigram] += 1 163 | 164 | num_unigram = 2 165 | unigram2id['_PAD'] = 0 166 | id2unigram[0] = '_PAD' 167 | unigram2id['_UNK'] = 1 168 | id2unigram[1] = '_UNK' 169 | # num_unigram = 0 170 | for unigram in unigram_cnt: 171 | # if(filter_stop_words == 1 and unigram in stop_words): continue 172 | # if(unigram in stop_words): continue 173 | if(unigram == "_PAD"): continue 174 | if(unigram_cnt[unigram] < cnt_threshold): continue 175 | unigram2id[unigram] = num_unigram 176 | id2unigram[num_unigram] = unigram 177 | num_unigram += 1 178 | 179 | num_bigram = 0 180 | if(is_bigram): 181 | for bigram in bigram_cnt: 182 | if(bigram == "_PAD _PAD"): continue 183 | if(bigram_cnt[bigram] < cnt_threshold): continue 184 | bigram2id[bigram] = num_bigram 185 | id2bigram[num_bigram] = bigram 186 | num_bigram += 1 187 | else: bigram2id, id2bigram = None, None 188 | 189 | num_trigram = 0 190 | if(is_trigram): 191 | for trigram in trigram_cnt: 192 | if(trigram_cnt[trigram] < cnt_threshold): continue 193 | trigram2id[trigram] = num_trigram 194 | id2trigram[num_trigram] = trigram 195 | num_trigram += 1 196 | else: trigram2id, id2trigram = None, None 197 | 198 | print("%d unigram, %d bigram, %d trigrams in total" % 199 | (num_unigram, num_bigram, num_trigram)) 200 | print("%.2s seconds cost" % (time.time() - start_time)) 201 | return unigram2id, id2unigram, bigram2id, id2bigram, trigram2id, id2trigram 202 | 203 | def build_style_word_sent(sentences, labels, unigram2id, bigram2id, trigram2id, 204 | max_slen, id2style): 205 | """Build the style-word set, style-sentence set 206 | 207 | Args: 208 | sentences: the sentence set, a list of sentences, 209 | a sentence is a list of words 210 | labels: the labels, a list of integers 211 | unigram2id: word to index dictionary 212 | bigram2id: bigram to index dictionary 213 | trigram2id: trigram to index dictionary 214 | max_slen: maximum sentence length. A sentence is a bag of words 215 | id2style: the index to style dictionary 216 | 217 | Returns: 218 | style_sent_unigram: the style-sentence distribution. 219 | A sentence is a bag of unigram 220 | style_sent_bigram: the style-sentence(bigram) distribution 221 | style_sent_trigram: the style-sentence(trigram) distribution 222 | style_unigram: the style-unigram distribution, a 2 * vocab_size matrix 223 | style_bigram: the style-bigram distribution 224 | style_trigram: the style-trigram distribution 225 | """ 226 | print("Building the style-related distributions ... 
") 227 | 228 | start_time = time.time() 229 | 230 | style_sent_unigram = [[] for _ in id2style] 231 | style_sent_bigram = [[] for _ in id2style] 232 | style_sent_trigram = [[] for _ in id2style] 233 | 234 | style_unigram = np.zeros([len(id2style), len(unigram2id)]) 235 | 236 | if(bigram2id is not None): 237 | style_bigram = np.zeros([len(id2style), len(bigram2id)]) 238 | else: style_bigram = None 239 | 240 | if(trigram2id is not None): 241 | style_trigram = np.zeros([len(id2style), len(trigram2id)]) 242 | else: style_trigram = None 243 | 244 | num_sentences = len(labels) 245 | 246 | for i in tqdm.tqdm(range(num_sentences)): 247 | s = sentences[i] 248 | lb = labels[i] 249 | 250 | s_unigram = [] 251 | s_bigram = [] 252 | s_trigram = [] 253 | slen = len(s) 254 | for j in range(slen): 255 | w = s[j] 256 | if(w in unigram2id): 257 | wid = unigram2id[w] 258 | s_unigram.append(wid) 259 | style_unigram[lb][wid] += 1 260 | s_unigram = set(s_unigram) 261 | 262 | if(bigram2id is not None): 263 | for j in range(slen - 1): 264 | bigram = s[j] + " " + s[j + 1] 265 | if(bigram in bigram2id): 266 | bid = bigram2id[bigram] 267 | s_bigram.append(bid) 268 | style_bigram[lb][bid] += 1 269 | s_bigram = set(s_bigram) 270 | 271 | if(trigram2id is not None): 272 | for j in range(slen - 2): 273 | trigram = s[j] + " " + s[j + 1] + " " + s[j + 2] 274 | if(trigram in trigram2id): 275 | tid = trigram2id[trigram] 276 | s_trigram.append(tid) 277 | style_trigram[lb][tid] += 1 278 | s_trigram = set(s_trigram) 279 | 280 | style_sent_unigram[lb].append(s_unigram) 281 | style_sent_bigram[lb].append(s_bigram) 282 | style_sent_trigram[lb].append(s_trigram) 283 | 284 | # pad sentence bag-of-words to maximum length 285 | style_sent_unigram = _set_to_array( 286 | style_sent_unigram, num_sentences, max_slen, len(id2style)) 287 | 288 | if(bigram2id is not None): 289 | style_sent_bigram = _set_to_array( 290 | style_sent_bigram, num_sentences, max_slen, len(id2style)) 291 | else: style_sent_bigram = None 292 | 293 | if(trigram2id is not None): 294 | style_sent_trigram = _set_to_array( 295 | style_sent_trigram, num_sentences, max_slen, len(id2style)) 296 | else: style_sent_trigram = None 297 | 298 | print("%.2s seconds cost" % (time.time() - start_time)) 299 | # print(style_unigram[0][0], style_unigram[1][0]) 300 | return (style_sent_unigram, style_sent_bigram, style_sent_trigram, 301 | style_unigram, style_bigram, style_trigram) 302 | 303 | def _set_to_array(style_sent, num_sentences, max_slen, num_style): 304 | sent_array = np.zeros([num_sentences, max_slen]).astype(np.int) - 1 305 | sid = 0 306 | for i in range(num_style): 307 | # print(len(style_sent[i])) 308 | for s in style_sent[i]: 309 | for wi, w in enumerate(s): 310 | sent_array[sid][wi] = w 311 | if(wi == max_slen - 1): break 312 | sid += 1 313 | assert(sid == num_sentences) 314 | return sent_array 315 | 316 | def filter_dist(P, id2word, threshold=20): 317 | """filter the numbers less than [threshold] in the distribution and create the new 318 | id2word dictionary, return the chosen index 319 | 320 | Args: 321 | P: style-word the distribution 322 | id2word: the index to word dictionary 323 | threshold: the filter threshold 324 | 325 | Returns: 326 | P_ret: the new style-word distribution 327 | id2word_new: the new id2word 328 | index_chosen: the chosen index 329 | """ 330 | index_chosen = np.unique(np.where(P > threshold)[1]) 331 | rows = [] 332 | columns = [] 333 | for i in range(P.shape[0]): 334 | rows.append([i] * len(index_chosen)) 335 | columns.append(index_chosen) 336 
| P_ret = np.array(P[rows, columns]) 337 | id2word_new = dict() 338 | filtered2prev = dict() 339 | for i, j in enumerate(index_chosen): 340 | id2word_new[i] = id2word[j] 341 | filtered2prev[i] = j 342 | return P_ret, id2word_new, index_chosen, filtered2prev 343 | 344 | def prec_recl_f1_dist(style_words): 345 | """get the precision, recall, and f1 distribution of words 346 | 347 | Args: 348 | style_words: the vocabulary distribution given a style 349 | 350 | Returns: 351 | prec_m: the precision matrix, prec_m[i][j] = what precision we can get if we 352 | use word j to predict style i 353 | recl_m: the recall matrix, recl_m[i][j] = what recall we can get if we use 354 | word j to predict style i 355 | f1_m: the f1 matrix, calculated as the harmonic mean of precision and recall 356 | """ 357 | prec_m = np.zeros(style_words.shape) 358 | recl_m = np.zeros(style_words.shape) 359 | f1_m = np.zeros(style_words.shape) 360 | for i in range(style_words.shape[0]): 361 | for j in range(style_words.shape[1]): 362 | if(np.sum(style_words.T[j]) != 0): 363 | prec_m[i][j] = style_words[i][j] / np.sum(style_words.T[j]) 364 | else: 365 | prec_m[i][j] = 0 366 | if(np.sum(style_words[i]) != 0): 367 | recl_m[i][j] = style_words[i][j] / np.sum(style_words[i]) 368 | else: 369 | recl_m[i][j] = 0 370 | f1_m[i][j] = 2 * prec_m[i][j] * recl_m[i][j] 371 | if(prec_m[i][j] + recl_m[i][j] != 0): 372 | f1_m[i][j] /= prec_m[i][j] + recl_m[i][j] 373 | else: 374 | f1_m[i][j] = 0 375 | return prec_m, recl_m, f1_m 376 | 377 | def get_pivot_words(prec, recl, filtered2prev, stop_words, 378 | prec_thres=0.7, recl_thres=0.): 379 | """Mine the high precision words 380 | 381 | Args: 382 | prec: the precision matrix 383 | recl: the recall matrix 384 | filtered2prev: filtered index to previous word index mapping 385 | """ 386 | pivot_words = [[], []] 387 | pivot_prec_recl = [[], []] 388 | pivot_prec = [{}, {}] 389 | for si in range(2): 390 | words = np.where(prec[si] > prec_thres) 391 | total_recl = 0. 
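# Note (added comment): words[0] holds the filtered-vocabulary indices whose
# precision for style si exceeds prec_thres. The loop below additionally requires
# recall above recl_thres, skips stop words, maps each surviving index back to
# the full-vocabulary id via filtered2prev, and sums its recall -- the per-class
# "pivot recall" printed at the end.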
392 | for wi in words[0]: 393 | if(recl[si][wi] > recl_thres): 394 | wid = filtered2prev[wi] 395 | if(wid in stop_words): continue 396 | pivot_words[si].append(wid) 397 | pivot_prec[si][wid] = prec[si][wi] 398 | # pivot_words[si].append(wi) 399 | pivot_prec_recl[si].append((prec[si][wi], recl[si][wi])) 400 | total_recl += recl[si][wi] 401 | print('class %d, %d pivots, pivot recall: %.4f' % 402 | (si, len(pivot_words[si]), total_recl)) 403 | return pivot_words, pivot_prec_recl, pivot_prec 404 | 405 | def get_pivot_range(pivots, lower, upper): 406 | """Get the pivot words within a range""" 407 | p = set([w for w in pivots if pivots[w] >= lower and pivots[w] < upper]) 408 | return p 409 | 410 | class Dataset(object): 411 | def __init__(self, config): 412 | self.name = config.dataset_name 413 | self.show_statistics = config.show_statistics 414 | self.base_path = config.dataset_base_path 415 | self.is_bigram = config.is_bigram 416 | self.is_trigram = config.is_trigram 417 | self.max_slen = config.max_slen[self.name] 418 | self.style2id = config.style2id[self.name] 419 | self.id2style = config.id2style[self.name] 420 | self.threshold_cnt = config.pivot_thres_cnt 421 | self.prec_thres = config.prec_thres 422 | self.recl_thres = config.recl_thres 423 | self.vocab_cnt_thres = config.vocab_cnt_thres 424 | self.filter_stop_words = config.filter_stop_words 425 | self.max_training_case = config.max_training_case 426 | self.max_test_case = config.max_test_case 427 | 428 | self.output_path = config.output_path 429 | 430 | self.word2id = None 431 | self.id2word = None 432 | self.sentences = {'train': None, 'dev': None, 'test': None} 433 | self.labels = {'train': None, 'dev': None, 'test': None} 434 | self.pivot_words = None 435 | self.pivot_prec = None 436 | self.pivot_words_prec_recl = None 437 | 438 | self.pivot_classifier = None 439 | self.prec_recl = None 440 | return 441 | 442 | def build(self): 443 | """Build the dataset 444 | 445 | * read the sentences 446 | * nlp-pipeline the sentences 447 | * build the word-class matrix 448 | * extract the pivot words 449 | * build the pivot classifier 450 | """ 451 | 452 | ## Read the dataset 453 | balance_method = 'downsample' if self.name in ['reddit', 'twitter']\ 454 | else 'upsample' 455 | train_sentences, train_labels = read_data( 456 | self.name, 'train', self.base_path, balance_method) 457 | dev_sentences, dev_labels = read_data( 458 | self.name, 'dev', self.base_path, balance_method) 459 | test_sentences, test_labels = read_data( 460 | self.name, 'test', self.base_path, balance_method) 461 | 462 | # Note: sentences are lists of words. 
Words are not converted to indices at 463 | # this stage 464 | self.sentences = { 465 | 'train': train_sentences, 'dev': dev_sentences, 'test': test_sentences} 466 | self.labels = { 467 | 'train': train_labels, 'dev': dev_labels, 'test': test_labels} 468 | 469 | ## Dataset statistics 470 | if(self.show_statistics): pass # TBC 471 | 472 | ## Pivot analysis pipeline 473 | # word to index 474 | (word2id, id2word, bigram2id, id2bigram, trigram2id, id2trigram) =\ 475 | build_vocab(train_sentences, self.is_bigram, self.is_trigram, 476 | self.vocab_cnt_thres, self.filter_stop_words) 477 | self.word2id, self.id2word = word2id, id2word 478 | 479 | # build style-word distribution 480 | (style_sent_unigram, style_sent_bigram, style_sent_trigram, 481 | style_words, style_bigram, style_trigram) = build_style_word_sent( 482 | train_sentences, train_labels, word2id, bigram2id, trigram2id, 483 | self.max_slen, self.id2style) 484 | 485 | # filter out words with low occurrence counts 486 | style_words_filtered, id2word_filtered, index_chosen, filtered2prev = \ 487 | filter_dist(style_words, id2word, self.threshold_cnt) 488 | 489 | # the precision and recall of each word 490 | prec, recl, f1 = prec_recl_f1_dist(style_words_filtered) 491 | self.prec = prec 492 | 493 | # pivot words are those with high precision 494 | stop_words = set(word2id[w] for w in STOPWORDS if w in word2id) 495 | self.pivot_words, self.pivot_words_prec_recl, self.pivot_prec =\ 496 | get_pivot_words( 497 | prec, recl, filtered2prev, stop_words, self.prec_thres, self.recl_thres) 498 | 499 | self.pivot_classifier = PivotClassifier( 500 | self.pivot_words[1], self.pivot_words[0]) 501 | return 502 | 503 | def get_prec_recl(self, bins=10): 504 | """Get the precision-recall distribution""" 505 | prec_recl = np.zeros([2, bins]) 506 | bin_range = 100. / bins
507 | for si in range(2): 508 | for j in range(bins): 509 | if(j < 5): continue 510 | lower = 0.01 * float(j) * bin_range 511 | upper = 0.01 * float(j + 1) * bin_range 512 | range_pivot_words = get_pivot_range(self.pivot_prec[si], lower, upper) 513 | pivot_recl = self.classify_w_pivot_list('train', range_pivot_words, si) 514 | print('class %d, prec lower %.3f, upper %.3f, %d pivot words, %.4f recall' 515 | % (si, lower, upper, len(range_pivot_words), pivot_recl)) 516 | 517 | print('%.4f recall' % pivot_recl) 518 | prec_recl[si][j] = pivot_recl 519 | 520 | self.prec_recl = prec_recl 521 | print('The precision-recall matrix:') 522 | print(self.prec_recl) 523 | 524 | out_path = self.output_path + self.name + '_prec_recl' 525 | print('Store to:') 526 | print(out_path) 527 | np.save(out_path, prec_recl) 528 | return prec_recl 529 | 530 | def store_pivots(self): 531 | """Store the pivot words, the labeled sentences, and the precision-recall 532 | histogram""" 533 | pivot_prec = {0: {}, 1: {}} 534 | for s in [0, 1]: 535 | out_path = self.output_path + self.name + '_%d.pivot' % s 536 | print('output stored in\n%s' % out_path) 537 | with open(out_path, 'w') as fd: 538 | for w, (p, r) in zip( 539 | self.pivot_words[s], self.pivot_words_prec_recl[s]): 540 | pivot_prec[s][w] = p 541 | fd.write('%s\t\t\t%.4f\t%.4f\n' % (self.id2word[w], p, r)) 542 | 543 | # write down the sentences 544 | fd = {0: open(self.output_path + self.name + '_0.sent', 'w'), 545 | 1: open(self.output_path + self.name + '_1.sent', 'w')} 546 | fd_hard = { 0: open(self.output_path + self.name + '_0.sent_hard', 'w'), 547 | 1: open(self.output_path + self.name + '_1.sent_hard', 'w')} 548 | sent_out = [0, 0] 549 | print(np.sum(self.labels['dev'] == 0), np.sum(self.labels['dev'] == 1)) 550 | for s, l in zip(self.sentences['dev'], self.labels['dev']): 551 | s_ = [self.word2id[w] if w in self.word2id else self.word2id['_UNK'] 552 | for w in s] 553 | s_num_pivots = 0 554 | s_out = [] 555 | for w, wid in zip(s, s_): 556 | if(wid in pivot_prec[l]): 557 | s_out.append(w + '(%.3f)' % pivot_prec[l][wid]) 558 | s_num_pivots += 1 559 | else: s_out.append(w) 560 | if(s_num_pivots >= 3): 561 | fd[l].write(' '.join(s_out) + '\n') 562 | sent_out[l] += 1 563 | if(s_num_pivots == 0): 564 | fd_hard[l].write(' '.join(s_out) + '\n') 565 | print('%d negative sentences written, %d positive' % 566 | (sent_out[0], sent_out[1])) 567 | 568 | # Store the precision-recall histogram 569 | # TBC 570 | return 571 | 572 | def classify_w_pivot_list(self, setname, pivot_list, s): 573 | """Classify the sentences with a given list of pivot words 574 | 575 | Args: 576 | setname: 'train', 'dev' or 'test' 577 | pivot_list: the list of pivot words 578 | s: the style label, 0 or 1 579 | """ 580 | recl = 0 581 | for i in range(len(self.sentences[setname])): 582 | x = self.sentences[setname][i] 583 | y = self.labels[setname][i] 584 | x = set(self.word2id[w] for w in x if w in self.word2id) 585 | if(len(set(x) & pivot_list) >= 1 and s == y): recl += 1 586 | recl = float(recl) / np.sum(self.labels[setname] == s) 587 | return recl 588 | 589 | def classify(self): 590 | """Test the pivot classifier""" 591 | self.pivot_classifier.test( 592 | self.sentences['train'], self.labels['train'], self.word2id, 'train') 593 | self.pivot_classifier.test( 594 | self.sentences['dev'], self.labels['dev'], self.word2id, 'dev') 595 | self.pivot_classifier.test( 596 | self.sentences['test'], self.labels['test'], self.word2id, 'test') 597 | return 598 | 599 | def to_bow_numpy(self, setname): 600 |
"""Raw data to bag of words numpy representation""" 601 | if( (setname == 'train' and 602 | len(self.sentences[setname]) > self.max_training_case) or 603 | (setname in ['dev', 'test'] and 604 | len(self.sentences[setname]) > self.max_test_case)): 605 | 606 | if(setname == 'train'): max_num_case = self.max_training_case 607 | else: max_num_case = self.max_test_case 608 | 609 | sample_id = np.random.choice( 610 | len(self.sentences[setname]), max_num_case, False) 611 | data = np.zeros([max_num_case, len(self.word2id)]) 612 | else: 613 | sample_id = range(len(self.sentences[setname])) 614 | data = np.zeros([len(self.sentences[setname]), len(self.word2id)]) 615 | si = 0 616 | for sid in sample_id: 617 | s = self.sentences[setname][sid] 618 | for w in s: 619 | if w in self.word2id: 620 | wid = self.word2id[w] 621 | data[si][wid] = 1 622 | si += 1 623 | labels = self.labels[setname][sample_id] 624 | return data, labels 625 | 626 | def to_sent_numpy(self, setname): 627 | """Raw data to numpy representation, a sentence is a list of word index""" 628 | 629 | if( (setname == 'train' and 630 | len(self.sentences[setname]) > self.max_training_case) or 631 | (setname in ['dev', 'test'] and 632 | len(self.sentences[setname]) > self.max_test_case)): 633 | 634 | if(setname == 'train'): max_num_case = self.max_training_case 635 | else: max_num_case = self.max_test_case 636 | 637 | sample_id = np.random.choice( 638 | len(self.sentences[setname]), max_num_case, False) 639 | else: 640 | sample_id = range(len(self.sentences[setname])) 641 | 642 | data = [] 643 | for sid in sample_id: 644 | s = self.sentences[setname][sid] 645 | s_ = [] 646 | for w in s: 647 | if w in self.word2id: wid = self.word2id[w] 648 | else: wid = self.word2id['_UNK'] 649 | s_.append(wid) 650 | data.append(s_) 651 | data = pad_sequences(data, maxlen=self.max_slen) 652 | labels = self.labels[setname][sample_id] 653 | return data, labels 654 | 655 | --------------------------------------------------------------------------------