├── _config.yml ├── src ├── wikipedia │ ├── __init__.py │ ├── __init__.pyc │ ├── wikipedia.pyc │ ├── util.py │ ├── exceptions.py │ ├── wikipedia-standard.py │ └── wikipedia.py ├── preprocess_wikipedia.py ├── baseline_ESA_emotion.py ├── baseline_ESA_situation.py ├── baseline_ESA_yahoo.py ├── baseline_word2vec.py ├── load_data_ESA.py ├── jsonlines.py ├── ESA.py ├── preprocess_situation.py ├── preprocess_emotion.py ├── preprocess_yahoo.py └── demo.py └── README.md /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-minimal -------------------------------------------------------------------------------- /src/wikipedia/__init__.py: -------------------------------------------------------------------------------- 1 | from .wikipedia import * 2 | from .exceptions import * 3 | 4 | __version__ = (1, 4, 0) 5 | -------------------------------------------------------------------------------- /src/wikipedia/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CogComp/Benchmarking-Zero-shot-Text-Classification/HEAD/src/wikipedia/__init__.pyc -------------------------------------------------------------------------------- /src/wikipedia/wikipedia.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CogComp/Benchmarking-Zero-shot-Text-Classification/HEAD/src/wikipedia/wikipedia.pyc -------------------------------------------------------------------------------- /src/wikipedia/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, unicode_literals 2 | 3 | import sys 4 | import functools 5 | 6 | def debug(fn): 7 | def wrapper(*args, **kwargs): 8 | print(fn.__name__, 'called!') 9 | print(sorted(args), tuple(sorted(kwargs.items()))) 10 | res = fn(*args, **kwargs) 11 | print(res) 12 | return res 13 | return wrapper 14 | 15 | 16 | class cache(object): 17 | 18 | def __init__(self, fn): 19 | self.fn = fn 20 | self._cache = {} 21 | functools.update_wrapper(self, fn) 22 | 23 | def __call__(self, *args, **kwargs): 24 | key = str(args) + str(kwargs) 25 | if key in self._cache: 26 | ret = self._cache[key] 27 | else: 28 | ret = self._cache[key] = self.fn(*args, **kwargs) 29 | 30 | return ret 31 | 32 | def clear_cache(self): 33 | self._cache = {} 34 | 35 | 36 | # from http://stackoverflow.com/questions/3627793/best-output-type-and-encoding-practices-for-repr-functions 37 | def stdout_encode(u, default='UTF8'): 38 | encoding = sys.stdout.encoding or default 39 | if sys.version_info > (3, 0): 40 | return u.encode(encoding).decode(encoding) 41 | return u.encode(encoding) 42 | -------------------------------------------------------------------------------- /src/preprocess_wikipedia.py: -------------------------------------------------------------------------------- 1 | from wikipedia import WikipediaPage 2 | import json 3 | import codecs 4 | 5 | def build_wiki_category_dataset(): 6 | readfile = codecs.open('/export/home/Dataset/wikipedia/parsed_output/tokenized_wiki/tokenized_wiki.txt', 'r', 'utf-8') 7 | writefile = codecs.open('/export/home/Dataset/wikipedia/parsed_output/tokenized_wiki/tokenized_wiki2categories.txt', 'w', 'utf-8') 8 | co = 0 9 | for line in readfile: 10 | try: 11 | line_dic = json.loads(line) 12 | except ValueError: 13 | continue 14 | 15 | try: 16 | # title = line_dic.get('title') 17 | title_id = line_dic.get('id') 18 | 
article = WikipediaPage(pageid=title_id) 19 | except AttributeError: 20 | continue 21 | type_list = article.categories 22 | # print(type_list) 23 | line_dic['categories'] = type_list 24 | writefile.write(json.dumps(line_dic)+'\n') 25 | co+=1 26 | if co % 5 == 0: 27 | print(co) 28 | if co == 100000: 29 | break 30 | writefile.close() 31 | readfile.close() 32 | print('over') 33 | 34 | 35 | if __name__ == '__main__': 36 | build_wiki_category_dataset() 37 | -------------------------------------------------------------------------------- /src/wikipedia/exceptions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global wikipedia exception and warning classes. 3 | """ 4 | 5 | import sys 6 | 7 | 8 | ODD_ERROR_MESSAGE = "This shouldn't happen. Please report on GitHub: github.com/goldsmith/Wikipedia" 9 | 10 | 11 | class WikipediaException(Exception): 12 | """Base Wikipedia exception class.""" 13 | 14 | def __init__(self, error): 15 | self.error = error 16 | 17 | def __unicode__(self): 18 | return "An unknown error occured: \"{0}\". Please report it on GitHub!".format(self.error) 19 | 20 | if sys.version_info > (3, 0): 21 | def __str__(self): 22 | return self.__unicode__() 23 | 24 | else: 25 | def __str__(self): 26 | return self.__unicode__().encode('utf8') 27 | 28 | 29 | class PageError(WikipediaException): 30 | """Exception raised when no Wikipedia matched a query.""" 31 | 32 | def __init__(self, pageid=None, *args): 33 | if pageid: 34 | self.pageid = pageid 35 | else: 36 | self.title = args[0] 37 | 38 | def __unicode__(self): 39 | if hasattr(self, 'title'): 40 | return u"\"{0}\" does not match any pages. Try another query!".format(self.title) 41 | else: 42 | return u"Page id \"{0}\" does not match any pages. Try another id!".format(self.pageid) 43 | 44 | 45 | class DisambiguationError(WikipediaException): 46 | """ 47 | Exception raised when a page resolves to a Disambiguation page. 48 | 49 | The `options` property contains a list of titles 50 | of Wikipedia pages that the query may refer to. 51 | 52 | .. note:: `options` does not include titles that do not link to a valid Wikipedia page. 53 | """ 54 | 55 | def __init__(self, title, may_refer_to): 56 | self.title = title 57 | self.options = may_refer_to 58 | 59 | def __unicode__(self): 60 | return u"\"{0}\" may refer to: \n{1}".format(self.title, '\n'.join(self.options)) 61 | 62 | 63 | class RedirectError(WikipediaException): 64 | """Exception raised when a page title unexpectedly resolves to a redirect.""" 65 | 66 | def __init__(self, title): 67 | self.title = title 68 | 69 | def __unicode__(self): 70 | return u"\"{0}\" resulted in a redirect. Set the redirect property to True to allow automatic redirects.".format(self.title) 71 | 72 | 73 | class HTTPTimeoutError(WikipediaException): 74 | """Exception raised when a request to the Mediawiki servers times out.""" 75 | 76 | def __init__(self, query): 77 | self.query = query 78 | 79 | def __unicode__(self): 80 | return u"Searching for \"{0}\" resulted in a timeout. 
Try again in a few seconds, and make sure you have rate limiting set to True.".format(self.query)
81 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Benchmarking Zero-shot Text Classification: Datasets, Evaluation and Entailment Approach
2 | 
3 | ## Description
4 | Zero-shot text classification (0SHOT-TC) is a challenging NLU problem to which little attention has been paid by the research community. 0SHOT-TC aims to associate an appropriate label with a piece of text, irrespective of the text domain and the aspect (e.g., topic, emotion, event, etc.) described by the label. Only a few articles have studied 0SHOT-TC, all focusing on topical categorization, which, we argue, is just the tip of the iceberg in 0SHOT-TC. In addition, the inconsistent experimental setups in the literature prevent uniform comparison, which blurs the progress. This work benchmarks the 0SHOT-TC problem by providing unified datasets, standardized evaluations, and state-of-the-art baselines. Our contributions include: i) The datasets we provide facilitate studying 0SHOT-TC relative to conceptually different and diverse aspects: the “topic” aspect includes “sports” and “politics” as labels; the “emotion” aspect includes “joy” and “anger”; the “situation” aspect includes “medical assistance” and “water shortage”. ii) We extend the existing evaluation setup (label-partially-unseen) – given a dataset, train on some labels, test on all labels – to include a more challenging yet realistic evaluation, label-fully-unseen 0SHOT-TC, which aims at classifying text snippets without seeing any task-specific training data at all. iii) We unify the 0SHOT-TC of diverse aspects within a textual entailment formulation and study it this way.
5 | This repository contains the code and the data for the EMNLP 2019 paper "Benchmarking Zero-shot Text Classification: Datasets, Evaluation and Entailment Approach".
6 | 
7 | ## Datasets
8 | Datasets for "topic detection", "emotion detection" and "situation detection":
9 | - https://drive.google.com/open?id=1qGmyEVD19ruvLLz9J0QGV7rsZPFEz2Az
10 | 
11 | Wikipedia data and three pretrained entailment models (RTE, MNLI, FEVER):
12 | - https://drive.google.com/file/d/1ILCQR_y-OSTdgkz45LP7JsHcelEsvoIn/view?usp=sharing
13 | 
14 | 
15 | ## Requirements
16 | - PyTorch
17 | - Transformers (Hugging Face): https://github.com/huggingface/transformers
18 | - GPU
19 | 
20 | ## Usage
21 | To rerun the code (taking "baseline_wiki_based_emotion.py" as an example):
22 | 
23 | CUDA_VISIBLE_DEVICES=1 python -u baseline_wiki_based_emotion.py --task_name rte --do_train --do_lower_case --bert_model bert-base-uncased --max_seq_length 128 --train_batch_size 32 --learning_rate 2e-5 --num_train_epochs 3 --data_dir '' --output_dir ''
24 | 
25 | The following steps are very important before running:
26 | Our code was written with "pytorch-transformers", the old version of the Hugging Face Transformers library.
27 | 1) Update the "pytorch_transformers" imports to "transformers" before running the code. For example, change the following:
28 | 
29 | from pytorch_transformers.file_utils import PYTORCH_TRANSFORMERS_CACHE
30 | from pytorch_transformers.modeling_bert import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
31 | from pytorch_transformers.tokenization_bert import BertTokenizer
32 | from pytorch_transformers.optimization import AdamW
33 | 
34 | To:
35 | 
36 | from transformers.file_utils import PYTORCH_TRANSFORMERS_CACHE
37 | from transformers.modeling_bert import BertForSequenceClassification
38 | from transformers.tokenization_bert import BertTokenizer
39 | from transformers.optimization import AdamW
40 | 
41 | 2) The forward() of "BertForSequenceClassification" in the new Transformers library has a parameter order slightly different from that in the prior "pytorch_transformers". Therefore, please change the following:
42 | 
43 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
44 | position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
45 | To:
46 | 
47 | def forward(self, input_ids=None, token_type_ids=None, attention_mask=None,
48 | position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
49 | 
50 | As the token type ids and the attention mask are exchanged, we need to change the input order (only for "token_type_ids" and "attention_mask") whenever we call the model. For example, change the following:
51 | 
52 | logits = model(input_ids, input_mask, segment_ids, labels=None)
53 | 
54 | 
55 | To:
56 | 
57 | logits = model(input_ids, segment_ids, input_mask, labels=None)
58 | 
59 | 
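Alternatively, the argument-order issue can be sidestepped without editing any signatures by calling the model with keyword arguments. The sketch below is not part of the original release; it is a minimal illustration that assumes a recent Hugging Face Transformers version, and the toy input ids and the two-label head are made up for the example:

    import torch
    from transformers import BertForSequenceClassification

    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

    # toy inputs, only to illustrate the call convention
    input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])  # arbitrary bert-base-uncased ids ([CLS] ... [SEP])
    segment_ids = torch.zeros_like(input_ids)                       # token_type_ids
    input_mask = torch.ones_like(input_ids)                         # attention_mask

    # keyword arguments make the call independent of the positional parameter order
    outputs = model(input_ids=input_ids,
                    token_type_ids=segment_ids,
                    attention_mask=input_mask)
    logits = outputs[0]  # the first element of the returned output is the logits

If the calls in the scripts are rewritten in this keyword style, the reordering described in step 2) is no longer needed, and the calls stay robust to future changes in the forward() signature.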
60 | ## Citation
61 | For code and data:
62 | 
63 | @inproceedings{yinroth2019zeroshot,
64 | title={Benchmarking Zero-shot Text Classification: Datasets, Evaluation and Entailment Approach},
65 | author={Wenpeng Yin and Jamaal Hay and Dan Roth},
66 | booktitle={{EMNLP}},
67 | url = {https://arxiv.org/abs/1909.00161},
68 | year={2019}
69 | }
70 | 
71 | ## Contacts
72 | 
73 | For any questions: mr.yinwenpeng@gmail.com
74 | 
--------------------------------------------------------------------------------
/src/baseline_ESA_emotion.py:
--------------------------------------------------------------------------------
1 | import time
2 | from load_data_ESA import load_emotion_and_labelnames
3 | from ESA import load_ESA_sparse_matrix, divide_sparseMatrix_by_list_row_wise, multiply_sparseMatrix_by_list_row_wise
4 | from sklearn.metrics.pairwise import cosine_similarity
5 | from sklearn.metrics.pairwise import euclidean_distances
6 | from scipy.sparse import vstack
7 | import numpy as np
8 | from operator import itemgetter
9 | from scipy.special import softmax
10 | from preprocess_emotion import emotion_f1_given_goldlist_and_predlist
11 | 
12 | 
13 | 
14 | all_texts, all_labels, all_word2DF, labelnames = load_emotion_and_labelnames()
15 | ESA_sparse_matrix = load_ESA_sparse_matrix().tocsr()
16 | # ESA_sparse_matrix_2_dict = {}
17 | #
18 | # def ESA_sparse_matrix_into_dict():
19 | # global ESA_sparse_matrix_2_dict
20 | # for i in range(ESA_sparse_matrix.shape[0]):
21 | # ESA_sparse_matrix_2_dict[i] = ESA_sparse_matrix.getrow(i)
22 | # print('ESA_sparse_matrix_into_dict succeed')
23 | 
24 | 
25 | def text_idlist_2_ESAVector(idlist, text_bool):
26 | # sub_matrix = ESA_sparse_matrix[idlist,:]
27 | # return sub_matrix.mean(axis=0)
28 | # matrix_list = []
29 | # for id in idlist:
30 | # matrix_list.append(ESA_sparse_matrix_2_dict.get(id))
31 | # stack_matrix = vstack(matrix_list)
32 | # return stack_matrix.mean(axis=0)
33 | # print('idlist:', idlist)
34 | if text_bool:
35 | sub_matrix = 
ESA_sparse_matrix[idlist,:] 36 | # myvalues = list(itemgetter(*idlist)(all_word2DF)) 37 | myvalues = [all_word2DF.get(id) for id in idlist] 38 | weighted_matrix = divide_sparseMatrix_by_list_row_wise(sub_matrix, myvalues) 39 | return weighted_matrix.sum(axis=0) 40 | else: #label names 41 | sub_matrix = ESA_sparse_matrix[idlist,:] 42 | return sub_matrix.sum(axis=0) 43 | 44 | def text_idlist_2_ESAVector_attention(idlist, text_bool, label_veclist): 45 | 46 | 47 | sub_matrix = ESA_sparse_matrix[idlist,:] 48 | result_veclist = [] 49 | for vec in label_veclist: 50 | cos = cosine_similarity(sub_matrix, vec) 51 | att_matrix = multiply_sparseMatrix_by_list_row_wise(sub_matrix, softmax(cos)) 52 | result_veclist.append(att_matrix.sum(axis=0)) 53 | 54 | return result_veclist 55 | 56 | # if text_bool: 57 | # sub_matrix = ESA_sparse_matrix[idlist,:] 58 | # myvalues = [all_word2DF.get(id) for id in idlist] 59 | # weighted_matrix = divide_sparseMatrix_by_list_row_wise(sub_matrix, myvalues) 60 | # return weighted_matrix.sum(axis=0) 61 | # else: #label names 62 | # sub_matrix = ESA_sparse_matrix[idlist,:] 63 | # return sub_matrix.sum(axis=0) 64 | 65 | def ESA_cosine(): 66 | origin_type_list = ['sadness', 'joy', 'anger', 'disgust', 'fear', 'surprise', 'shame', 'guilt', 'love'] 67 | label_veclist = [] 68 | for i in range(len(labelnames)): 69 | labelname_idlist = labelnames[i] 70 | '''label rep is sum up all word ESA vectors''' 71 | label_veclist.append(text_idlist_2_ESAVector(labelname_idlist, False)) 72 | # print(label_veclist) 73 | print('all labelnames are in vec succeed') 74 | labels = all_labels[0] 75 | # print('all_labels:', labels[:10]) 76 | sample_size = len(labels) 77 | print('total test size:', sample_size) 78 | hit_size = 0 79 | co=0 80 | start_time = time.time() 81 | pred_type_list = [] 82 | gold_type_list = [] 83 | for sample_index in range(sample_size): 84 | text_idlist = all_texts[0][sample_index] 85 | '''text rep is weighted sum up of ESA vectors''' 86 | text_vec = text_idlist_2_ESAVector(text_idlist, True) 87 | cos_array=cosine_similarity(text_vec, np.vstack(label_veclist)) 88 | list_cosine = list(cos_array[0]) 89 | # print('list_cosine:', list_cosine) 90 | max_prob = -100.0 91 | max_index = -1 92 | for i in range(len(list_cosine)): 93 | if list_cosine[i] > max_prob: 94 | max_prob = list_cosine[i] 95 | max_index = i 96 | if max_index ==-1 or max_prob < 0.01: 97 | pred_type = 'noemo' 98 | else: 99 | pred_type = origin_type_list[max_index] 100 | # print('pred_type:', pred_type) 101 | # print('pred_type_list_i:', pred_type_list_i) 102 | # print('all_labels:', labels[:10]) 103 | # print('gold_type_list_i:',labels[i], i ) 104 | pred_type_list.append(pred_type) 105 | gold_type_list.append(labels[sample_index]) 106 | 107 | v0, v1, all_f1 = emotion_f1_given_goldlist_and_predlist(gold_type_list, pred_type_list, set(['sadness', 'anger', 'fear', 'shame', 'love']), set(['joy', 'disgust', 'surprise', 'guilt'])) 108 | # seen_f1_v1, unseen_f1_v1, all_f1_v1 = situation_f1_given_goldlist_and_predlist(gold_type_list, pred_type_list, set(['evac','utils', 'shelter','food', 'terrorism'])) 109 | 110 | co+=1 111 | print(co, '...', v0, v1, all_f1) 112 | if co%10==0: 113 | spend_time = (time.time()-start_time)/60.0 114 | print('\t\t\t\t\t',spend_time, 'mins') 115 | 116 | print('over.') 117 | 118 | 119 | 120 | if __name__ == '__main__': 121 | ESA_cosine() 122 | -------------------------------------------------------------------------------- /src/baseline_ESA_situation.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | from load_data_ESA import load_situation_and_labelnames 3 | from ESA import load_ESA_sparse_matrix, divide_sparseMatrix_by_list_row_wise, multiply_sparseMatrix_by_list_row_wise 4 | from sklearn.metrics.pairwise import cosine_similarity 5 | from sklearn.metrics.pairwise import euclidean_distances 6 | from scipy.sparse import vstack 7 | import numpy as np 8 | from operator import itemgetter 9 | from scipy.special import softmax 10 | from preprocess_situation import situation_f1_given_goldlist_and_predlist 11 | 12 | 13 | 14 | all_texts, all_labels, all_word2DF, labelnames = load_situation_and_labelnames() 15 | ESA_sparse_matrix = load_ESA_sparse_matrix().tocsr() 16 | # ESA_sparse_matrix_2_dict = {} 17 | # 18 | # def ESA_sparse_matrix_into_dict(): 19 | # global ESA_sparse_matrix_2_dict 20 | # for i in range(ESA_sparse_matrix.shape[0]): 21 | # ESA_sparse_matrix_2_dict[i] = ESA_sparse_matrix.getrow(i) 22 | # print('ESA_sparse_matrix_into_dict succeed') 23 | 24 | 25 | def text_idlist_2_ESAVector(idlist, text_bool): 26 | # sub_matrix = ESA_sparse_matrix[idlist,:] 27 | # return sub_matrix.mean(axis=0) 28 | # matrix_list = [] 29 | # for id in idlist: 30 | # matrix_list.append(ESA_sparse_matrix_2_dict.get(id)) 31 | # stack_matrix = vstack(matrix_list) 32 | # return stack_matrix.mean(axis=0) 33 | # print('idlist:', idlist) 34 | if text_bool: 35 | sub_matrix = ESA_sparse_matrix[idlist,:] 36 | # myvalues = list(itemgetter(*idlist)(all_word2DF)) 37 | myvalues = [all_word2DF.get(id) for id in idlist] 38 | weighted_matrix = divide_sparseMatrix_by_list_row_wise(sub_matrix, myvalues) 39 | return weighted_matrix.sum(axis=0) 40 | else: #label names 41 | sub_matrix = ESA_sparse_matrix[idlist,:] 42 | return sub_matrix.sum(axis=0) 43 | 44 | def text_idlist_2_ESAVector_attention(idlist, text_bool, label_veclist): 45 | 46 | 47 | sub_matrix = ESA_sparse_matrix[idlist,:] 48 | result_veclist = [] 49 | for vec in label_veclist: 50 | cos = cosine_similarity(sub_matrix, vec) 51 | att_matrix = multiply_sparseMatrix_by_list_row_wise(sub_matrix, softmax(cos)) 52 | result_veclist.append(att_matrix.sum(axis=0)) 53 | 54 | return result_veclist 55 | 56 | # if text_bool: 57 | # sub_matrix = ESA_sparse_matrix[idlist,:] 58 | # myvalues = [all_word2DF.get(id) for id in idlist] 59 | # weighted_matrix = divide_sparseMatrix_by_list_row_wise(sub_matrix, myvalues) 60 | # return weighted_matrix.sum(axis=0) 61 | # else: #label names 62 | # sub_matrix = ESA_sparse_matrix[idlist,:] 63 | # return sub_matrix.sum(axis=0) 64 | 65 | def ESA_cosine(): 66 | origin_type_list = ['search','evac','infra','utils','water','shelter','med','food', 'crimeviolence', 'terrorism', 'regimechange'] 67 | label_veclist = [] 68 | for i in range(len(labelnames)): 69 | labelname_idlist = labelnames[i] 70 | '''label rep is sum up all word ESA vectors''' 71 | label_veclist.append(text_idlist_2_ESAVector(labelname_idlist, False)) 72 | # print(label_veclist) 73 | print('all labelnames are in vec succeed') 74 | labels = all_labels[0] 75 | # print('all_labels:', labels[:10]) 76 | sample_size = len(labels) 77 | print('total test size:', sample_size) 78 | hit_size = 0 79 | co=0 80 | start_time = time.time() 81 | pred_type_list = [] 82 | gold_type_list = [] 83 | for sample_index in range(sample_size): 84 | text_idlist = all_texts[0][sample_index] 85 | '''text rep is weighted sum up of ESA vectors''' 86 | text_vec = text_idlist_2_ESAVector(text_idlist, True) 87 | 
cos_array=cosine_similarity(text_vec, np.vstack(label_veclist)) 88 | list_cosine = list(cos_array[0]) 89 | pred_type_list_i = [] 90 | for i in range(len(list_cosine)): 91 | if list_cosine[i] > 0.03: 92 | pred_type_list_i.append(origin_type_list[i]) 93 | if len(pred_type_list_i) == 0: 94 | pred_type_list_i.append('out-of-domain') 95 | # print('pred_type_list_i:', pred_type_list_i) 96 | # print('all_labels:', labels[:10]) 97 | # print('gold_type_list_i:',labels[i], i ) 98 | pred_type_list.append(pred_type_list_i) 99 | gold_type_list.append(labels[sample_index]) 100 | 101 | v0, v1, all_f1 = situation_f1_given_goldlist_and_predlist(gold_type_list, pred_type_list, set(['search','infra','water','med','crimeviolence', 'regimechange']), set(['evac','utils', 'shelter','food', 'terrorism'])) 102 | # seen_f1_v1, unseen_f1_v1, all_f1_v1 = situation_f1_given_goldlist_and_predlist(gold_type_list, pred_type_list, set(['evac','utils', 'shelter','food', 'terrorism'])) 103 | 104 | co+=1 105 | print(co, '...', v0, v1, all_f1) 106 | if co%10==0: 107 | spend_time = (time.time()-start_time)/60.0 108 | print('\t\t\t\t\t',spend_time, 'mins') 109 | 110 | print('over.') 111 | 112 | def ESA_cosine_attention(): 113 | '''not used finally''' 114 | label_veclist = [] 115 | for i in range(len(labelnames)): 116 | labelname_idlist = labelnames[i] 117 | label_veclist.append(text_idlist_2_ESAVector(labelname_idlist, False)) 118 | # print(label_veclist) 119 | print('all labelnames are in vec succeed') 120 | labels = all_labels[0] 121 | sample_size = len(labels) 122 | print('total test size:', sample_size) 123 | hit_size = 0 124 | co=0 125 | start_time = time.time() 126 | for i in range(sample_size): 127 | text_idlist = all_texts[0][i] 128 | text_veclist = text_idlist_2_ESAVector_attention(text_idlist, True, label_veclist) 129 | cos_array=cosine_similarity(np.vstack(text_veclist), np.vstack(label_veclist)) 130 | print('cos_array:',cos_array) 131 | print('diagonal:', cos_array.diagonal(), 'ground truth:', labels[i]) 132 | # exit(0) 133 | max_id = np.argmax(cos_array.diagonal()) 134 | if max_id == labels[i]: 135 | hit_size+=1 136 | co+=1 137 | print(co, '...', hit_size/sample_size, hit_size/co) 138 | if co%10==0: 139 | spend_time = (time.time()-start_time)/60.0 140 | print('\t\t\t\t\t',spend_time, 'mins') 141 | acc = hit_size/sample_size 142 | print('acc:', acc) 143 | 144 | if __name__ == '__main__': 145 | ESA_cosine() 146 | -------------------------------------------------------------------------------- /src/baseline_ESA_yahoo.py: -------------------------------------------------------------------------------- 1 | import time 2 | from load_data_ESA import load_yahoo_and_labelnames 3 | from ESA import load_ESA_sparse_matrix, divide_sparseMatrix_by_list_row_wise, multiply_sparseMatrix_by_list_row_wise 4 | from sklearn.metrics.pairwise import cosine_similarity 5 | from sklearn.metrics.pairwise import euclidean_distances 6 | from scipy.sparse import vstack 7 | import numpy as np 8 | from operator import itemgetter 9 | from scipy.special import softmax 10 | 11 | all_texts, all_labels, all_word2DF, labelnames = load_yahoo_and_labelnames() 12 | ESA_sparse_matrix = load_ESA_sparse_matrix().tocsr() 13 | # ESA_sparse_matrix_2_dict = {} 14 | # 15 | # def ESA_sparse_matrix_into_dict(): 16 | # global ESA_sparse_matrix_2_dict 17 | # for i in range(ESA_sparse_matrix.shape[0]): 18 | # ESA_sparse_matrix_2_dict[i] = ESA_sparse_matrix.getrow(i) 19 | # print('ESA_sparse_matrix_into_dict succeed') 20 | 21 | 22 | def text_idlist_2_ESAVector(idlist, 
text_bool): 23 | # sub_matrix = ESA_sparse_matrix[idlist,:] 24 | # return sub_matrix.mean(axis=0) 25 | # matrix_list = [] 26 | # for id in idlist: 27 | # matrix_list.append(ESA_sparse_matrix_2_dict.get(id)) 28 | # stack_matrix = vstack(matrix_list) 29 | # return stack_matrix.mean(axis=0) 30 | # print('idlist:', idlist) 31 | if text_bool: 32 | sub_matrix = ESA_sparse_matrix[idlist,:] 33 | # myvalues = list(itemgetter(*idlist)(all_word2DF)) 34 | myvalues = [all_word2DF.get(id) for id in idlist] 35 | weighted_matrix = divide_sparseMatrix_by_list_row_wise(sub_matrix, myvalues) 36 | return weighted_matrix.sum(axis=0) 37 | else: #label names 38 | sub_matrix = ESA_sparse_matrix[idlist,:] 39 | return sub_matrix.sum(axis=0) 40 | 41 | def text_idlist_2_ESAVector_attention(idlist, text_bool, label_veclist): 42 | 43 | 44 | sub_matrix = ESA_sparse_matrix[idlist,:] 45 | result_veclist = [] 46 | for vec in label_veclist: 47 | cos = cosine_similarity(sub_matrix, vec) 48 | att_matrix = multiply_sparseMatrix_by_list_row_wise(sub_matrix, softmax(cos)) 49 | result_veclist.append(att_matrix.sum(axis=0)) 50 | 51 | return result_veclist 52 | 53 | # if text_bool: 54 | # sub_matrix = ESA_sparse_matrix[idlist,:] 55 | # myvalues = [all_word2DF.get(id) for id in idlist] 56 | # weighted_matrix = divide_sparseMatrix_by_list_row_wise(sub_matrix, myvalues) 57 | # return weighted_matrix.sum(axis=0) 58 | # else: #label names 59 | # sub_matrix = ESA_sparse_matrix[idlist,:] 60 | # return sub_matrix.sum(axis=0) 61 | 62 | def ESA_cosine(): 63 | label_veclist = [] 64 | for i in range(len(labelnames)): 65 | labelname_idlist = labelnames[i] 66 | '''label rep is sum up all word ESA vectors''' 67 | label_veclist.append(text_idlist_2_ESAVector(labelname_idlist, False)) 68 | # print(label_veclist) 69 | print('all labelnames are in vec succeed') 70 | labels = all_labels[0] 71 | sample_size = len(labels) 72 | print('total test size:', sample_size) 73 | seen_type_v0 = set([0,2,4,6,8]) 74 | seen_type_v1 = set([1,3,5,7,9]) 75 | hit_size_seen_v0 = 0 76 | all_size_seen_v0 = 0 77 | hit_size_unseen_v0 = 0 78 | all_size_unseen_v0 = 0 79 | 80 | hit_size_seen_v1 = 0 81 | all_size_seen_v1 = 0 82 | hit_size_unseen_v1 = 0 83 | all_size_unseen_v1= 0 84 | hit_size = 0 85 | co=0 86 | start_time = time.time() 87 | for i in range(sample_size): 88 | text_idlist = all_texts[0][i] 89 | '''text rep is weighted sum up of ESA vectors''' 90 | text_vec = text_idlist_2_ESAVector(text_idlist, True) 91 | cos_array=cosine_similarity(text_vec, np.vstack(label_veclist)) 92 | max_id = np.argmax(cos_array, axis=1) 93 | gold_label = labels[i] 94 | pred_label = max_id[0] 95 | '''v0''' 96 | if gold_label in seen_type_v0: 97 | all_size_seen_v0+=1 98 | if gold_label == pred_label: 99 | hit_size_seen_v0+=1 100 | else: 101 | all_size_unseen_v0+=1 102 | if gold_label == pred_label: 103 | hit_size_unseen_v0+=1 104 | '''v1''' 105 | if gold_label in seen_type_v1: 106 | all_size_seen_v1+=1 107 | if gold_label == pred_label: 108 | hit_size_seen_v1+=1 109 | else: 110 | all_size_unseen_v1+=1 111 | if gold_label == pred_label: 112 | hit_size_unseen_v1+=1 113 | 114 | 115 | 116 | if max_id[0] == labels[i]: 117 | hit_size+=1 118 | co+=1 119 | print(co, '...', hit_size/sample_size, hit_size/co, 'v0:', hit_size_seen_v0/(1e-8+all_size_seen_v0), hit_size_unseen_v0/(1e-8+all_size_unseen_v0), 'v1:', hit_size_seen_v1/(1e-8+all_size_seen_v1), hit_size_unseen_v1/(1e-8+all_size_unseen_v1), 'seen vs. 
unseen size:', all_size_seen_v0, all_size_unseen_v0) 120 | if co%10==0: 121 | spend_time = (time.time()-start_time)/60.0 122 | print('\t\t\t\t\t',spend_time, 'mins') 123 | acc = hit_size/sample_size 124 | print('acc:', acc) 125 | 126 | def ESA_cosine_attention(): 127 | '''not used finally''' 128 | label_veclist = [] 129 | for i in range(len(labelnames)): 130 | labelname_idlist = labelnames[i] 131 | label_veclist.append(text_idlist_2_ESAVector(labelname_idlist, False)) 132 | # print(label_veclist) 133 | print('all labelnames are in vec succeed') 134 | labels = all_labels[0] 135 | sample_size = len(labels) 136 | print('total test size:', sample_size) 137 | hit_size = 0 138 | co=0 139 | start_time = time.time() 140 | for i in range(sample_size): 141 | text_idlist = all_texts[0][i] 142 | text_veclist = text_idlist_2_ESAVector_attention(text_idlist, True, label_veclist) 143 | cos_array=cosine_similarity(np.vstack(text_veclist), np.vstack(label_veclist)) 144 | print('cos_array:',cos_array) 145 | print('diagonal:', cos_array.diagonal(), 'ground truth:', labels[i]) 146 | # exit(0) 147 | max_id = np.argmax(cos_array.diagonal()) 148 | if max_id == labels[i]: 149 | hit_size+=1 150 | co+=1 151 | print(co, '...', hit_size/sample_size, hit_size/co) 152 | if co%10==0: 153 | spend_time = (time.time()-start_time)/60.0 154 | print('\t\t\t\t\t',spend_time, 'mins') 155 | acc = hit_size/sample_size 156 | print('acc:', acc) 157 | 158 | if __name__ == '__main__': 159 | ESA_cosine() 160 | -------------------------------------------------------------------------------- /src/baseline_word2vec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.distance import cosine 3 | from sklearn.metrics import f1_score 4 | import codecs 5 | 6 | 7 | def sent_2_emb(wordlist): 8 | emb_list = [] 9 | for word in wordlist: 10 | emb = word2vec.get(word, None) 11 | if emb is not None: 12 | emb_list.append(emb) 13 | if len(emb_list) > 0: 14 | arr = np.array(emb_list) 15 | return np.sum(arr, axis=0) 16 | else: 17 | return np.array([0.0]*300) 18 | 19 | 20 | 21 | def baseline_w2v(): 22 | 23 | 24 | 25 | 26 | 27 | '''emotion''' 28 | type_list = ['sadness', 'joy', 'anger', 'disgust', 'fear', 'surprise', 'shame', 'guilt', 'love']#, 'noemo'] 29 | type_2_emb = [] 30 | for type in type_list: 31 | type_2_emb.append(sent_2_emb(type.split())) 32 | readfile = codecs.open('/export/home/Dataset/Stuttgart_Emotion/unify-emotion-datasets-master/zero-shot-split/test.txt', 'r', 'utf-8') 33 | gold_label_list = [] 34 | pred_label_list = [] 35 | co = 0 36 | for line in readfile: 37 | parts = line.strip().split('\t') 38 | gold_label_list.append(parts[0]) 39 | text = parts[2].strip() 40 | max_cos = 0.0 41 | max_type = -1 42 | text_emb = sent_2_emb(text.split()) 43 | for i, type in enumerate(type_list): 44 | 45 | type_emb = type_2_emb[i] 46 | cos = 1.0-cosine(text_emb, type_emb) 47 | if cos > max_cos: 48 | max_cos = cos 49 | max_type = type 50 | if max_cos > 0.0: 51 | pred_label_list.append(max_type) 52 | else: 53 | pred_label_list.append('noemo') 54 | co+=1 55 | if co % 1000 == 0: 56 | print('emotion co:', co) 57 | readfile.close() 58 | print('gold_label_list:', gold_label_list[:200]) 59 | print('pred_label_list:', pred_label_list[:200]) 60 | all_test_labels = list(set(gold_label_list)) 61 | f1_score_per_type = f1_score(gold_label_list, pred_label_list, labels = all_test_labels, average=None) 62 | seen_types_group = [['sadness', 'anger', 'fear', 'shame', 'love'],['joy', 'disgust', 
'surprise', 'guilt']] 63 | for i in range(len(seen_types_group)): 64 | seen_types = seen_types_group[i] 65 | 66 | seen_f1_accu = 0.0 67 | seen_size = 0 68 | unseen_f1_accu = 0.0 69 | unseen_size = 0 70 | for i in range(len(all_test_labels)): 71 | f1=f1_score_per_type[i] 72 | co = gold_label_list.count(all_test_labels[i]) 73 | if all_test_labels[i] in seen_types: 74 | seen_f1_accu+=f1*co 75 | seen_size+=co 76 | else: 77 | unseen_f1_accu+=f1*co 78 | unseen_size+=co 79 | print('seen:', seen_f1_accu/seen_size, 'unseen:', unseen_f1_accu/unseen_size) 80 | print('overall:', f1_score(gold_label_list, pred_label_list, labels = all_test_labels, average='weighted')) 81 | 82 | '''situation''' 83 | origin_type_list = ['search','evac','infra','utils','water','shelter','med','food', 'crimeviolence', 'terrorism', 'regimechange', 'out-of-domain'] 84 | type_list = ['search','evacuation','infrastructure','utilities utility','water','shelter','medical assistance','food', 'crime violence', 'terrorism', 'regime change']#, 'out-of-domain'] 85 | type_2_emb = [] 86 | for type in type_list: 87 | type_2_emb.append(sent_2_emb(type.split())) 88 | readfile = codecs.open('/export/home/Dataset/LORELEI/zero-shot-split/test.txt', 'r', 'utf-8') 89 | gold_label_list = [] 90 | pred_label_list = [] 91 | co=0 92 | for line in readfile: 93 | parts = line.strip().split('\t') 94 | gold_label_list.append(parts[0].split()) 95 | text = parts[1].strip() 96 | max_cos = 0.0 97 | pred_type_i = [] 98 | text_emb = sent_2_emb(text.split()) 99 | for i, type in enumerate(origin_type_list[:-1]): 100 | type_emb = type_2_emb[i] 101 | cos = 1.0-cosine(text_emb, type_emb) 102 | if cos > 0.5: 103 | pred_type_i.append(type) 104 | if len(pred_type_i) == 0: 105 | pred_type_i.append('out-of-domain') 106 | pred_label_list.append(pred_type_i) 107 | co+=1 108 | if co % 1000 == 0: 109 | print('situation co:', co) 110 | readfile.close() 111 | print('gold_label_list:', gold_label_list[:200]) 112 | print('pred_label_list:', pred_label_list[:200]) 113 | 114 | assert len(pred_label_list) == len(gold_label_list) 115 | total_premise_size = len(gold_label_list) 116 | type_in_test = ['search','evac','infra','utils','water','shelter','med','food', 'crimeviolence', 'terrorism', 'regimechange', 'out-of-domain'] 117 | type2col = { type:i for i, type in enumerate(type_in_test)} 118 | gold_array = np.zeros((total_premise_size,12), dtype=int) 119 | pred_array = np.zeros((total_premise_size,12), dtype=int) 120 | for i in range(total_premise_size): 121 | for type in pred_label_list[i]: 122 | pred_array[i,type2col.get(type)]=1 123 | for type in gold_label_list[i]: 124 | gold_array[i,type2col.get(type)]=1 125 | 126 | f1_list = [] 127 | size_list = [] 128 | for i in range(len(type_in_test)): 129 | f1=f1_score(gold_array[:,i], pred_array[:,i], pos_label=1, average='binary') 130 | co = sum(gold_array[:,i]) 131 | f1_list.append(f1) 132 | size_list.append(co) 133 | 134 | print('f1_list:',f1_list) 135 | print('size_list:', size_list) 136 | seen_types_group = [['search','infra','water','med', 'crimeviolence', 'regimechange'], 137 | ['evac','utils','shelter','food', 'terrorism']] 138 | for i in range(len(seen_types_group)): 139 | seen_types = seen_types_group[i] 140 | 141 | seen_f1_accu = 0.0 142 | seen_size = 0 143 | unseen_f1_accu = 0.0 144 | unseen_size = 0 145 | for i in range(len(type_in_test)): 146 | if type_in_test[i] in seen_types: 147 | seen_f1_accu+=f1_list[i]*size_list[i] 148 | seen_size+=size_list[i] 149 | else: 150 | unseen_f1_accu+=f1_list[i]*size_list[i] 151 | 
unseen_size+=size_list[i] 152 | print('seen:', seen_f1_accu/seen_size, 'unseen:', unseen_f1_accu/unseen_size) 153 | 154 | overall = sum([f1_list[i]*size_list[i] for i in range(len(f1_list))])/sum(size_list) 155 | print('overall:', overall) 156 | 157 | 158 | 159 | '''yahoo''' 160 | type_list = ['society & culture', 'science & mathematics', 'health', 'education & reference','computer & internet','sports sport','business & finance','entertainment & music','Family & relationships relationship','politics & government'] 161 | type_2_emb = [] 162 | for type in type_list: 163 | type_2_emb.append(sent_2_emb(type.split())) 164 | readfile = codecs.open('/export/home/Dataset/YahooClassification/yahoo_answers_csv/zero-shot-split/test.txt', 'r', 'utf-8') 165 | gold_label_list = [] 166 | pred_label_list = [] 167 | co = 0 168 | for line in readfile: 169 | parts = line.strip().split('\t') 170 | gold_label_list.append(parts[0]) 171 | text = parts[1].strip() 172 | max_cos = 0.0 173 | max_type = '' 174 | text_emb = sent_2_emb(text.split()) 175 | for i, type in enumerate(type_list): 176 | 177 | type_emb = type_2_emb[i] 178 | cos = 1.0-cosine(text_emb, type_emb) 179 | if cos > max_cos: 180 | max_cos = cos 181 | max_type = str(i) 182 | pred_label_list.append(max_type) 183 | co+=1 184 | if co % 1000 == 0: 185 | print('yahoo co:', co) 186 | readfile.close() 187 | print('gold_label_list:', gold_label_list[:200]) 188 | print('pred_label_list:', pred_label_list[:200]) 189 | 190 | # all_test_labels = list(set(gold_label_list)) 191 | # f1_score_per_type = f1_score(gold_label_list, pred_label_list, labels = all_test_labels, average=None) 192 | seen_types_group = [['0','2','4','6','8'],['1','3','5','7','9']] 193 | 194 | for i in range(len(seen_types_group)): 195 | seen_types = set(seen_types_group[i]) 196 | 197 | seen_hit = 0.0 198 | seen_size = 0 199 | unseen_hit = 0.0 200 | unseen_size = 0 201 | for i in range(len(gold_label_list)): 202 | if gold_label_list[i] in seen_types: 203 | 204 | seen_size+=1 205 | if gold_label_list[i] == pred_label_list[i]: 206 | seen_hit+=1 207 | else: 208 | unseen_size+=1 209 | if gold_label_list[i] == pred_label_list[i]: 210 | unseen_hit+=1 211 | print('seen:', seen_hit/seen_size, 'unseen:', unseen_hit/unseen_size) 212 | 213 | all_hit = 0 214 | for i in range(len(gold_label_list)): 215 | if gold_label_list[i] == pred_label_list[i]: 216 | all_hit+=1 217 | 218 | 219 | print('overall:', all_hit/len(gold_label_list)) 220 | 221 | if __name__ == '__main__': 222 | 223 | '''first load word2vec embeddings''' 224 | word2vec = {} 225 | 226 | print("==> loading 300d word2vec") 227 | 228 | f=open('/export/home/Dataset/word2vec_words_300d.txt', 'r')#glove.6B.300d.txt, word2vec_words_300d.txt, glove.840B.300d.txt 229 | co = 0 230 | for line in f: 231 | l = line.split() 232 | word2vec[l[0]] = list(map(float, l[1:])) 233 | co+=1 234 | if co % 50000 == 0: 235 | print('loading w2v size:', co) 236 | # if co % 10000 == 0: 237 | # break 238 | print("==> word2vec is loaded") 239 | baseline_w2v() 240 | -------------------------------------------------------------------------------- /src/load_data_ESA.py: -------------------------------------------------------------------------------- 1 | #conding=utf8 2 | 3 | import json 4 | import codecs 5 | # import nltk 6 | from collections import defaultdict 7 | # import numpy as np 8 | # from scipy.sparse import coo_matrix, csr_matrix, lil_matrix 9 | # import time 10 | # from scipy import sparse 11 | # from sklearn.metrics.pairwise import cosine_similarity 12 | ESA_word2id={} 
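# ESA_word2id is filled by load_ESA_word2id() below from word2id.json, which is written
# out by store_ESA() in ESA.py; it maps each Wikipedia token to that token's integer row
# index in the ESA sparse matrix, e.g. {'water': 1834, ...} (illustrative id only, the
# real ids depend on the Wikipedia dump used).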
13 | 14 | def transfer_wordlist_2_idlist_with_existing_word2id(token_list): 15 | ''' 16 | From such as ['i', 'love', 'Munich'] to idlist [23, 129, 34], if maxlen is 5, then pad two zero in the left side, becoming [0, 0, 23, 129, 34] 17 | ''' 18 | idlist=[] 19 | for word in token_list: 20 | id=ESA_word2id.get(word) 21 | if id is not None: # if word was not in the vocabulary 22 | idlist.append(id) 23 | return idlist 24 | 25 | def load_ESA_word2id(): 26 | global ESA_word2id 27 | route = '/export/home/Dataset/wikipedia/parsed_output/statistics_from_json/' 28 | with open(route+'word2id.json', 'r') as fp2: 29 | ESA_word2id = json.load(fp2) 30 | print('load ESA word2id succeed') 31 | 32 | def load_yahoo(): 33 | yahoo_path = '/export/home/Dataset/YahooClassification/yahoo_answers_csv/' 34 | files = ['zero-shot-split/test.txt'] #'train_tokenized.txt','zero-shot-split/test.txt' 35 | # word2id={} 36 | all_texts=[] 37 | # all_masks=[] 38 | all_labels=[] 39 | all_word2DF=defaultdict(int) 40 | max_sen_len=0 41 | for i in range(len(files)): 42 | print('loading file:', yahoo_path+files[i], '...') 43 | 44 | texts=[] 45 | # text_masks=[] 46 | labels=[] 47 | readfile=codecs.open(yahoo_path+files[i], 'r', 'utf-8') 48 | line_co=0 49 | for line in readfile: 50 | parts = line.strip().split('\t') 51 | if len(parts)==2: 52 | label_id = int(parts[0]) 53 | '''truncate can speed up''' 54 | text_wordlist = parts[1].strip().lower().split()[:100]#[:30] 55 | text_len=len(text_wordlist) 56 | if text_len > max_sen_len: 57 | max_sen_len=text_len 58 | text_idlist=transfer_wordlist_2_idlist_with_existing_word2id(text_wordlist) 59 | if len(text_idlist) >0: 60 | texts.append(text_idlist) 61 | labels.append(label_id) 62 | idset = set(text_idlist) 63 | for iddd in idset: 64 | all_word2DF[iddd]+=1 65 | else: 66 | continue 67 | 68 | line_co+=1 69 | if line_co%10000==0: 70 | print('line_co:', line_co) 71 | # if i==0 and line_co==train_size_limit: 72 | # break 73 | 74 | 75 | all_texts.append(texts) 76 | all_labels.append(labels) 77 | print('\t\t\t size:', len(labels), 'samples') 78 | print('load yahoo text succeed, max sen len:', max_sen_len) 79 | return all_texts, all_labels, all_word2DF 80 | 81 | 82 | def load_situation(): 83 | yahoo_path = '/export/home/Dataset/LORELEI/' 84 | files = ['zero-shot-split/test.txt'] #'train_tokenized.txt','zero-shot-split/test.txt' 85 | # word2id={} 86 | all_texts=[] 87 | # all_masks=[] 88 | all_labels=[] 89 | all_word2DF=defaultdict(int) 90 | max_sen_len=0 91 | for i in range(len(files)): 92 | print('loading file:', yahoo_path+files[i], '...') 93 | 94 | texts=[] 95 | # text_masks=[] 96 | labels=[] 97 | readfile=codecs.open(yahoo_path+files[i], 'r', 'utf-8') 98 | line_co=0 99 | for line in readfile: 100 | parts = line.strip().split('\t') 101 | if len(parts)==2: 102 | label_id = parts[0].strip().split() 103 | '''truncate can speed up''' 104 | text_wordlist = parts[1].strip().lower().split()[:30]#[:30] 105 | text_len=len(text_wordlist) 106 | if text_len > max_sen_len: 107 | max_sen_len=text_len 108 | text_idlist=transfer_wordlist_2_idlist_with_existing_word2id(text_wordlist) 109 | if len(text_idlist) >0: 110 | texts.append(text_idlist) 111 | labels.append(label_id) 112 | idset = set(text_idlist) 113 | for iddd in idset: 114 | all_word2DF[iddd]+=1 115 | else: 116 | continue 117 | 118 | line_co+=1 119 | if line_co%100==0: 120 | print('line_co:', line_co) 121 | # if i==0 and line_co==train_size_limit: 122 | # break 123 | 124 | readfile.close() 125 | all_texts.append(texts) 126 | all_labels.append(labels) 
127 | print('\t\t\t size:', len(labels), 'samples') 128 | print('load situation text succeed, max sen len:', max_sen_len) 129 | return all_texts, all_labels, all_word2DF 130 | 131 | def load_emotion(): 132 | yahoo_path = '/export/home/Dataset/Stuttgart_Emotion/unify-emotion-datasets-master/' 133 | files = ['zero-shot-split/test.txt'] #'train_tokenized.txt','zero-shot-split/test.txt' 134 | # word2id={} 135 | all_texts=[] 136 | # all_masks=[] 137 | all_labels=[] 138 | all_word2DF=defaultdict(int) 139 | max_sen_len=0 140 | for i in range(len(files)): 141 | print('loading file:', yahoo_path+files[i], '...') 142 | 143 | texts=[] 144 | # text_masks=[] 145 | labels=[] 146 | readfile=codecs.open(yahoo_path+files[i], 'r', 'utf-8') 147 | line_co=0 148 | for line in readfile: 149 | parts = line.strip().split('\t') 150 | if len(parts)==3: 151 | label_id = parts[0].strip() 152 | '''truncate can speed up''' 153 | text_wordlist = parts[2].strip().lower().split()[:30]#[:30] 154 | '''we found use the tokenzied text make performance always zero''' 155 | # text_wordlist = [word for word in nltk.word_tokenize(parts[2].strip()) if word.isalpha()] 156 | # text_wordlist = text_wordlist[:30]#[:30] 157 | text_len=len(text_wordlist) 158 | if text_len > max_sen_len: 159 | max_sen_len=text_len 160 | text_idlist=transfer_wordlist_2_idlist_with_existing_word2id(text_wordlist) 161 | if len(text_idlist) >0: 162 | texts.append(text_idlist) 163 | labels.append(label_id) 164 | idset = set(text_idlist) 165 | for iddd in idset: 166 | all_word2DF[iddd]+=1 167 | else: 168 | continue 169 | 170 | line_co+=1 171 | if line_co%100==0: 172 | print('line_co:', line_co) 173 | # if i==0 and line_co==train_size_limit: 174 | # break 175 | 176 | readfile.close() 177 | all_texts.append(texts) 178 | all_labels.append(labels) 179 | print('\t\t\t size:', len(labels), 'samples') 180 | print('load situation text succeed, max sen len:', max_sen_len) 181 | return all_texts, all_labels, all_word2DF 182 | 183 | def load_labels(): 184 | yahoo_path = '/export/home/Dataset/YahooClassification/yahoo_answers_csv/' 185 | texts=[] 186 | # text_masks=[] 187 | 188 | readfile=codecs.open(yahoo_path+'classes.txt', 'r', 'utf-8') 189 | for line in readfile: 190 | wordlist = line.strip().replace('&', ' ').lower().split() 191 | 192 | text_idlist=transfer_wordlist_2_idlist_with_existing_word2id(wordlist) 193 | if len(text_idlist) >0: 194 | texts.append(text_idlist) 195 | 196 | print('load yahoo labelnames succeed, totally :', len(texts), 'label names') 197 | 198 | return texts 199 | 200 | def load_labels_situation(): 201 | # yahoo_path = '/export/home/Dataset/YahooClassification/yahoo_answers_csv/' 202 | predefined_types_enriched = ['search','evacuation','infrastructure','utilities utility','water','shelter','medical assistance','food', 'crime violence', 'terrorism', 'regime change'] 203 | origin_type_list = ['search','evac','infra','utils','water','shelter','med','food', 'crimeviolence', 'terrorism', 'regimechange'] 204 | texts=[] 205 | # text_masks=[] 206 | 207 | # readfile=codecs.open(yahoo_path+'classes.txt', 'r', 'utf-8') 208 | for type in predefined_types_enriched: 209 | wordlist = type.split() 210 | 211 | text_idlist=transfer_wordlist_2_idlist_with_existing_word2id(wordlist) 212 | if len(text_idlist) >0: 213 | texts.append(text_idlist) 214 | assert len(texts) == len(predefined_types_enriched) 215 | 216 | print('load yahoo labelnames succeed, totally :', len(texts), 'label names') 217 | 218 | return texts 219 | 220 | def load_labels_emotion(): 221 | # 
yahoo_path = '/export/home/Dataset/YahooClassification/yahoo_answers_csv/' 222 | type_list = ['sadness', 'joy', 'anger', 'disgust', 'fear', 'surprise', 'shame', 'guilt', 'love'] 223 | texts=[] 224 | # text_masks=[] 225 | 226 | # readfile=codecs.open(yahoo_path+'classes.txt', 'r', 'utf-8') 227 | for type in type_list: 228 | wordlist = type.split() 229 | 230 | text_idlist=transfer_wordlist_2_idlist_with_existing_word2id(wordlist) 231 | if len(text_idlist) >0: 232 | texts.append(text_idlist) 233 | assert len(texts) == len(type_list) 234 | 235 | print('load yahoo labelnames succeed, totally :', len(texts), 'label names') 236 | 237 | return texts 238 | 239 | def load_yahoo_and_labelnames(): 240 | load_ESA_word2id() 241 | all_texts, all_labels, all_word2DF = load_yahoo() 242 | labelnames = load_labels() 243 | return all_texts, all_labels, all_word2DF, labelnames 244 | 245 | def load_situation_and_labelnames(): 246 | load_ESA_word2id() 247 | all_texts, all_labels, all_word2DF = load_situation() 248 | # print('load all_labels:', all_labels[0][:10]) 249 | labelnames = load_labels_situation() 250 | return all_texts, all_labels, all_word2DF, labelnames 251 | 252 | def load_emotion_and_labelnames(): 253 | load_ESA_word2id() 254 | all_texts, all_labels, all_word2DF = load_emotion() 255 | # print('load all_labels:', all_labels[0][:10]) 256 | labelnames = load_labels_emotion() 257 | return all_texts, all_labels, all_word2DF, labelnames 258 | -------------------------------------------------------------------------------- /src/jsonlines.py: -------------------------------------------------------------------------------- 1 | """ 2 | jsonlines implementation 3 | """ 4 | 5 | import numbers 6 | import io 7 | import json 8 | 9 | import six 10 | 11 | 12 | TYPE_MAPPING = { 13 | dict: dict, 14 | list: list, 15 | str: six.text_type, 16 | int: six.integer_types, 17 | float: float, 18 | numbers.Number: numbers.Number, 19 | bool: bool, 20 | } 21 | 22 | 23 | class Error(Exception): 24 | """Base error class.""" 25 | pass 26 | 27 | 28 | class InvalidLineError(Error, ValueError): 29 | """ 30 | Error raised when an invalid line is encountered. 31 | 32 | This happens when the line does not contain valid JSON, or if a 33 | specific data type has been requested, and the line contained a 34 | different data type. 35 | 36 | The original line itself is stored on the exception instance as the 37 | ``.line`` attribute, and the line number as ``.lineno``. 38 | 39 | This class subclasses both ``jsonlines.Error`` and the built-in 40 | ``ValueError``. 41 | """ 42 | #: The invalid line 43 | line = None 44 | 45 | #: The line number 46 | lineno = None 47 | 48 | def __init__(self, msg, line, lineno): 49 | msg = "{} (line {})".format(msg, lineno) 50 | self.line = line.rstrip() 51 | self.lineno = lineno 52 | super(InvalidLineError, self).__init__(msg) 53 | 54 | 55 | class ReaderWriterBase(object): 56 | """ 57 | Base class with shared behaviour for both the reader and writer. 58 | """ 59 | def close(self): 60 | """ 61 | Close this reader/writer. 62 | 63 | This closes the underlying file if that file has been opened by 64 | this reader/writer. When an already opened file-like object was 65 | provided, the caller is responsible for closing it. 
66 | """ 67 | if self._closed: 68 | return 69 | self._closed = True 70 | if self._should_close_fp: 71 | self._fp.close() 72 | 73 | def __repr__(self): 74 | name = getattr(self._fp, 'name', None) 75 | if name: 76 | wrapping = repr(name) 77 | else: 78 | wrapping = '<{} at 0x{:x}>'.format( 79 | type(self._fp).__name__, 80 | id(self._fp)) 81 | return ''.format( 82 | type(self).__name__, id(self), wrapping) 83 | 84 | def __enter__(self): 85 | return self 86 | 87 | def __exit__(self, *exc_info): 88 | self.close() 89 | return False 90 | 91 | 92 | class Reader(ReaderWriterBase): 93 | """ 94 | Reader for the jsonlines format. 95 | 96 | The first argument must be an iterable that yields JSON encoded 97 | strings. Usually this will be a readable file-like object, such as 98 | an open file or an ``io.TextIO`` instance, but it can also be 99 | something else as long as it yields strings when iterated over. 100 | 101 | The `loads` argument can be used to replace the standard json 102 | decoder. If specified, it must be a callable that accepts a 103 | (unicode) string and returns the decoded object. 104 | 105 | Instances are iterable and can be used as a context manager. 106 | 107 | :param file-like iterable: iterable yielding lines as strings 108 | :param callable loads: custom json decoder callable 109 | """ 110 | def __init__(self, iterable, loads=None): 111 | self._fp = iterable 112 | self._should_close_fp = False 113 | self._closed = False 114 | if loads is None: 115 | loads = json.loads 116 | self._loads = loads 117 | self._line_iter = enumerate(iterable, 1) 118 | 119 | def read(self, type=None, allow_none=False, skip_empty=False): 120 | """ 121 | Read and decode a line. 122 | 123 | The optional `type` argument specifies the expected data type. 124 | Supported types are ``dict``, ``list``, ``str``, ``int``, 125 | ``float``, ``numbers.Number`` (accepts both integers and 126 | floats), and ``bool``. When specified, non-conforming lines 127 | result in :py:exc:`InvalidLineError`. 128 | 129 | By default, input lines containing ``null`` (in JSON) are 130 | considered invalid, and will cause :py:exc:`InvalidLineError`. 131 | The `allow_none` argument can be used to change this behaviour, 132 | in which case ``None`` will be returned instead. 133 | 134 | If `skip_empty` is set to ``True``, empty lines and lines 135 | containing only whitespace are silently skipped. 
136 | """ 137 | if self._closed: 138 | raise RuntimeError('reader is closed') 139 | if type is not None and type not in TYPE_MAPPING: 140 | raise ValueError("invalid type specified") 141 | 142 | try: 143 | lineno, line = next(self._line_iter) 144 | while skip_empty and not line.rstrip(): 145 | lineno, line = next(self._line_iter) 146 | except StopIteration: 147 | six.raise_from(EOFError, None) 148 | 149 | if isinstance(line, six.binary_type): 150 | try: 151 | line = line.decode('utf-8') 152 | except UnicodeDecodeError as orig_exc: 153 | exc = InvalidLineError( 154 | "line is not valid utf-8: {}".format(orig_exc), 155 | line, lineno) 156 | six.raise_from(exc, orig_exc) 157 | 158 | try: 159 | value = self._loads(line) 160 | except ValueError as orig_exc: 161 | exc = InvalidLineError( 162 | "line contains invalid json: {}".format(orig_exc), 163 | line, lineno) 164 | six.raise_from(exc, orig_exc) 165 | 166 | if value is None: 167 | if allow_none: 168 | return None 169 | raise InvalidLineError( 170 | "line contains null value", line, lineno) 171 | 172 | if type is not None: 173 | valid = isinstance(value, TYPE_MAPPING[type]) 174 | if type in (int, numbers.Number): 175 | valid = valid and not isinstance(value, bool) 176 | if not valid: 177 | raise InvalidLineError( 178 | "line does not match requested type", line, lineno) 179 | 180 | return value 181 | 182 | def iter(self, type=None, allow_none=False, skip_empty=False, 183 | skip_invalid=False): 184 | """ 185 | Iterate over all lines. 186 | 187 | This is the iterator equivalent to repeatedly calling 188 | :py:meth:`~Reader.read()`. If no arguments are specified, this 189 | is the same as directly iterating over this :py:class:`Reader` 190 | instance. 191 | 192 | When `skip_invalid` is set to ``True``, invalid lines will be 193 | silently ignored. 194 | 195 | See :py:meth:`~Reader.read()` for a description of the other 196 | arguments. 197 | """ 198 | try: 199 | while True: 200 | try: 201 | yield self.read( 202 | type=type, 203 | allow_none=allow_none, 204 | skip_empty=skip_empty) 205 | except InvalidLineError: 206 | if not skip_invalid: 207 | raise 208 | except EOFError: 209 | pass 210 | 211 | def __iter__(self): 212 | """ 213 | See :py:meth:`~Reader.iter()`. 214 | """ 215 | return self.iter() 216 | 217 | 218 | class Writer(ReaderWriterBase): 219 | """ 220 | Writer for the jsonlines format. 221 | 222 | The `fp` argument must be a file-like object with a ``.write()`` 223 | method accepting either text (unicode) or bytes. 224 | 225 | The `compact` argument can be used to to produce smaller output. 226 | 227 | The `sort_keys` argument can be used to sort keys in json objects, 228 | and will produce deterministic output. 229 | 230 | For more control, provide a a custom encoder callable using the 231 | `dumps` argument. The callable must produce (unicode) string output. 232 | If specified, the `compact` and `sort` arguments will be ignored. 233 | 234 | When the `flush` argument is set to ``True``, the writer will call 235 | ``fp.flush()`` after each written line. 236 | 237 | Instances can be used as a context manager. 
238 | 239 | :param file-like fp: writable file-like object 240 | :param bool compact: whether to use a compact output format 241 | :param bool sort_keys: whether to sort object keys 242 | :param callable dumps: custom encoder callable 243 | :param bool flush: whether to flush the file-like object after 244 | writing each line 245 | """ 246 | def __init__( 247 | self, fp, compact=False, sort_keys=False, dumps=None, flush=False): 248 | self._closed = False 249 | try: 250 | fp.write(u'') 251 | self._fp_is_binary = False 252 | except TypeError: 253 | self._fp_is_binary = True 254 | if dumps is None: 255 | encoder_kwargs = dict(ensure_ascii=False, sort_keys=sort_keys) 256 | if compact: 257 | encoder_kwargs.update(separators=(',', ':')) 258 | dumps = json.JSONEncoder(**encoder_kwargs).encode 259 | self._fp = fp 260 | self._should_close_fp = False 261 | self._dumps = dumps 262 | self._flush = flush 263 | 264 | def write(self, obj): 265 | """ 266 | Encode and write a single object. 267 | 268 | :param obj: the object to encode and write 269 | """ 270 | if self._closed: 271 | raise RuntimeError('writer is closed') 272 | line = self._dumps(obj) 273 | # On Python 2, the JSON module has the nasty habit of returning 274 | # either a byte string or unicode string, depending on whether 275 | # the serialised structure can be encoded using ASCII only, so 276 | # this means this code needs to handle all combinations. 277 | if self._fp_is_binary: 278 | if not isinstance(line, six.binary_type): 279 | line = line.encode('utf-8') 280 | self._fp.write(line) 281 | self._fp.write(b'\n') 282 | else: 283 | if not isinstance(line, six.text_type): 284 | line = line.decode('ascii') # For Python 2. 285 | self._fp.write(line) 286 | self._fp.write(u'\n') 287 | if self._flush: 288 | self._fp.flush() 289 | 290 | def write_all(self, iterable): 291 | """ 292 | Encode and write multiple objects. 293 | 294 | :param iterable: an iterable of objects 295 | """ 296 | for obj in iterable: 297 | self.write(obj) 298 | 299 | 300 | def open(name, mode='r', **kwargs): 301 | """ 302 | Open a jsonlines file for reading or writing. 303 | 304 | This is a convenience function that opens a file, and wraps it in 305 | either a :py:class:`Reader` or :py:class:`Writer` instance, 306 | depending on the specified `mode`. 307 | 308 | Any additional keyword arguments will be passed on to the reader and 309 | writer: see their documentation for available options. 310 | 311 | The resulting reader or writer must be closed after use by the 312 | caller, which will also close the opened file. This can be done by 313 | calling ``.close()``, but the easiest way to ensure proper resource 314 | finalisation is to use a ``with`` block (context manager), e.g. 315 | 316 | :: 317 | 318 | with jsonlines.open('out.jsonl', mode='w') as writer: 319 | writer.write(...) 320 | 321 | :param file-like fp: name of the file to open 322 | :param str mode: whether to open the file for reading (``r``), 323 | writing (``w``) or appending (``a``). 
324 | :param \*\*kwargs: additional arguments, forwarded to the reader or writer 325 | """ 326 | if mode not in {'r', 'w', 'a'}: 327 | raise ValueError("'mode' must be either 'r', 'w', or 'a'") 328 | fp = io.open(name, mode=mode + 't', encoding='utf-8') 329 | if mode == 'r': 330 | instance = Reader(fp, **kwargs) 331 | else: 332 | instance = Writer(fp, **kwargs) 333 | instance._should_close_fp = True 334 | return instance 335 | -------------------------------------------------------------------------------- /src/ESA.py: -------------------------------------------------------------------------------- 1 | #conding=utf8 2 | import os 3 | import json 4 | import codecs 5 | import nltk 6 | from collections import defaultdict, Counter 7 | import numpy as np 8 | from scipy.sparse import coo_matrix, csr_matrix, lil_matrix 9 | import time 10 | from scipy import sparse 11 | from sklearn.metrics.pairwise import cosine_similarity 12 | '''seven global variables''' 13 | title2id={} #5903490+1 14 | title_size = 0 15 | word2id={} #6161731+1 16 | word_size = 0 17 | # WordTitle2Count= lil_matrix((298099,40000))#(6113524, 5828563)) 18 | WordTitle2Count= lil_matrix((6161731, 5903486))#(6113524, 5828563)) 19 | Word2TileCount=defaultdict(int) 20 | fileset=set() 21 | 22 | def scan_all_json_files(rootDir): 23 | global fileset 24 | for lists in os.listdir(rootDir): 25 | path = os.path.join(rootDir, lists) 26 | if os.path.isdir(path): 27 | scan_all_json_files(path) 28 | else: # is a file 29 | fileset.add(path) 30 | 31 | def load_json(): 32 | global title2id 33 | global word2id 34 | global title_size 35 | global word_size 36 | global WordTitle2Count 37 | global Word2TileCount 38 | global fileset 39 | 40 | json_file_size = 0 41 | wiki_file_size = 0 42 | for json_input in fileset: 43 | json_file_size+=1 44 | print('\t\t\t', json_input) 45 | with codecs.open(json_input, 'r', 'utf-8') as f: 46 | for line in f: 47 | try: 48 | line_dic = json.loads(line) 49 | except ValueError: 50 | continue 51 | title = line_dic.get('title') 52 | title_id = title2id.get(title) 53 | if title_id is None: # if word was not in the vocabulary 54 | title_id=title_size # id of true words starts from 1, leaving 0 to "pad id" 55 | title2id[title]=title_id 56 | title_size+=1 57 | 58 | content = line_dic.get('text') 59 | '''this tokenizer step should be time-consuming''' 60 | tokenized_text = nltk.word_tokenize(content) 61 | word_id_set = set() 62 | for word in tokenized_text: 63 | if word.isalpha(): 64 | word_id = word2id.get(word) 65 | if word_id is None: 66 | word_id = word_size 67 | word2id[word]=word_id 68 | word_size+=1 69 | WordTitle2Count[str(word_id)+':'+str(title_id)]+=1 70 | word_id_set.add(word_id) 71 | for each_word_id in word_id_set: 72 | Word2TileCount[str(each_word_id)]+=1 #this word meets a new title 73 | wiki_file_size+=1 74 | print(json_file_size, '&',wiki_file_size, '...over') 75 | # if wiki_file_size ==4: 76 | # return 77 | 78 | def load_tokenized_json(): 79 | ''' 80 | we first tokenzie tool output json files into a single json file "tokenized_wiki.txt" 81 | now we do statistics on it 82 | ''' 83 | start_time = time.time() 84 | global title2id 85 | global word2id 86 | global title_size 87 | global word_size 88 | global WordTitle2Count 89 | global Word2TileCount 90 | # global fileset 91 | 92 | route = '/export/home/Dataset/wikipedia/parsed_output/tokenized_wiki/' 93 | wiki_file_size = 0 94 | with codecs.open(route+'tokenized_wiki.txt', 'r', 'utf-8') as f: 95 | for line in f: 96 | try: 97 | line_dic = json.loads(line) 98 | except 
ValueError: 99 | continue 100 | title = line_dic.get('title') 101 | title_id = title2id.get(title) 102 | if title_id is None: # if word was not in the vocabulary 103 | title_id=title_size # id of true words starts from 1, leaving 0 to "pad id" 104 | title2id[title]=title_id 105 | title_size+=1 106 | 107 | # content = line_dic.get('text') 108 | tokenized_text = line_dic.get('text').split() 109 | word2tf=Counter(tokenized_text) 110 | for word, tf in word2tf.items(): 111 | word_id = word2id.get(word) 112 | if word_id is None: 113 | word_id = word_size 114 | word2id[word]=word_id 115 | word_size+=1 116 | WordTitle2Count[word_id, title_id]=tf 117 | Word2TileCount[word_id]+=1 #this word meets a new title 118 | wiki_file_size+=1 119 | if wiki_file_size%10000==0: 120 | print(wiki_file_size, '...over') 121 | # if wiki_file_size ==4000: 122 | # break 123 | f.close() 124 | print('load_tokenized_json over.....words:', word_size, ' title size:', title_size) 125 | WordTitle2Count = divide_sparseMatrix_by_list_row_wise(WordTitle2Count, Word2TileCount.values()) 126 | print('divide_sparseMatrix_by_list_row_wise....over') 127 | spend_time = (time.time()-start_time)/60.0 128 | print(spend_time, 'mins') 129 | 130 | def store_ESA(): 131 | start_time = time.time() 132 | global title2id 133 | global word2id 134 | global WordTitle2Count 135 | global Word2TileCount 136 | route = '/export/home/Dataset/wikipedia/parsed_output/statistics_from_json/' 137 | with open(route+'title2id.json', 'w') as fp1: 138 | json.dump(title2id, fp1) 139 | with open(route+'word2id.json', 'w') as fp2: 140 | json.dump(word2id, fp2) 141 | # with open(route+'WordTitle2Count.json', 'w') as f3: 142 | # json.dump(WordTitle2Count, f3) 143 | '''note that WordTitle2Count is always a sparse matrix, not a dictionary''' 144 | sparse.save_npz(route+"ESA_Sparse_v1.npz", WordTitle2Count) 145 | print('ESA sparse matrix stored over, congrats!!!') 146 | with open(route+'Word2TileCount.json', 'w') as f4: 147 | json.dump(Word2TileCount, f4) 148 | print('store ESA over') 149 | spend_time = (time.time()-start_time)/60.0 150 | print(spend_time, 'mins') 151 | 152 | 153 | def tokenize_filter_tokens(): 154 | global fileset 155 | route = '/export/home/Dataset/wikipedia/parsed_output/tokenized_wiki/' 156 | writefile = codecs.open(route+'tokenized_wiki.txt' ,'a+', 'utf-8') 157 | json_file_size = 0 158 | wiki_file_size = 0 159 | for json_input in fileset: 160 | json_file_size+=1 161 | print('\t\t\t', json_input) 162 | with codecs.open(json_input, 'r', 'utf-8') as f: 163 | for line in f: 164 | try: 165 | line_dic = json.loads(line) 166 | except ValueError: 167 | continue 168 | # title = line_dic.get('title') 169 | content = line_dic.get('text') 170 | tokenized_text = nltk.word_tokenize(content) 171 | new_text = [] 172 | for word in tokenized_text: 173 | if word.isalpha(): 174 | new_text.append(word) 175 | line_dic['text']=' '.join(new_text) 176 | json.dump(line_dic, writefile) 177 | writefile.write('\n') 178 | wiki_file_size+=1 179 | print(json_file_size, '&',wiki_file_size, '...over') 180 | print('tokenize over') 181 | writefile.close() 182 | 183 | 184 | def reformat_into_expected_ESA(): 185 | ''' 186 | super super slow. 
do not use it 187 | ''' 188 | start_time = time.time() 189 | global Word2TileCount 190 | global WordTitle2Count 191 | route = '/home/wyin3/Datasets/Wikipedia20190320/parsed_output/statistics_from_json/' 192 | rows=[] 193 | cols=[] 194 | values =[] 195 | size = 0 196 | print('WordTitle2Count:', WordTitle2Count) 197 | for key, value in WordTitle2Count.items(): #"0:0": 8, "1:0": 24, 198 | key_parts = key.split(':') 199 | word_id_str = key_parts[0] 200 | concept_id_str = key_parts[1] 201 | word_df =Word2TileCount.get(word_id_str) 202 | rows.append(int(word_id_str)) 203 | cols.append(int(concept_id_str)) 204 | values.append(value/word_df) 205 | size+=1 206 | if size%10000000 ==0: 207 | print('reformat entry sizes:', size) 208 | WordTitle2Count=None # release the memory of big dictionary 209 | print('reformat entry over, building sparse matrix...') 210 | sparse_matrix = csr_matrix((values, (rows, cols))) 211 | non_zero=sparse_matrix.nonzero() 212 | row_array = list(non_zero[0]) 213 | col_array = non_zero[1] 214 | print('sparse matrix build succeed, start store...') 215 | writefile = codecs.open(route+'ESA.v1.json', 'w', 'utf-8') 216 | prior_row = -1 217 | finish_size=0 218 | for id, row_id in enumerate(row_array): 219 | if row_id !=prior_row: 220 | if row_id>0: 221 | # print(prior_row.dtype) 222 | json.dump({str(prior_row):new_list}, writefile) 223 | writefile.write('\n') 224 | new_list=None 225 | finish_size+=1 226 | if finish_size %1000: 227 | print('finish store rows ', finish_size) 228 | # else: 229 | new_list=[] 230 | new_list.append(str(col_array[id])+':'+str(sparse_matrix[row_id,col_array[id]])) 231 | prior_row=row_id 232 | else: 233 | new_list.append(str(col_array[id])+':'+str(sparse_matrix[row_id,col_array[id]])) 234 | 235 | json.dump({str(prior_row):new_list}, writefile) # the last row 236 | writefile.close() 237 | print('ESA format over') 238 | spend_time = (time.time()-start_time)/60.0 239 | print(spend_time, 'mins') 240 | 241 | def reformat_into_sparse_matrix_store(): 242 | start_time = time.time() 243 | global Word2TileCount 244 | global WordTitle2Count 245 | route = '/export/home/Dataset/wikipedia/parsed_output/statistics_from_json/' 246 | rows=[] 247 | cols=[] 248 | values =[] 249 | size = 0 250 | for key, value in WordTitle2Count.items(): #"0:0": 8, "1:0": 24, 251 | key_parts = key.split(':') 252 | word_id_str = key_parts[0] 253 | concept_id_str = key_parts[1] 254 | word_df =Word2TileCount.get(word_id_str) 255 | rows.append(int(word_id_str)) 256 | cols.append(int(concept_id_str)) 257 | values.append(value/word_df) 258 | size+=1 259 | if size%10000000 ==0: 260 | print('reformat entry sizes:', size) 261 | WordTitle2Count=None # release the memory of big dictionary 262 | print('reformat entry over, building sparse matrix...') 263 | sparse_matrix = csr_matrix((values, (rows, cols))) 264 | print('sparse matrix build succeed, start store...') 265 | sparse.save_npz(route+"ESA_Sparse_v1.npz", sparse_matrix) 266 | print('ESA sparse matrix stored over, congrats!!!') 267 | spend_time = (time.time()-start_time)/60.0 268 | print(spend_time, 'mins') 269 | 270 | def divide_sparseMatrix_by_list_row_wise(mat, lis): 271 | # C=lil_matrix([[2,4,6], [5,10,15]]) 272 | # print(C) 273 | D=np.asarray(list(lis)) 274 | r,c = mat.nonzero() 275 | val = np.repeat(1.0/D, mat.getnnz(axis=1)) 276 | rD_sp = csr_matrix((val, (r,c)), shape=(mat.shape)) 277 | out = mat.multiply(rD_sp) 278 | return out 279 | 280 | def multiply_sparseMatrix_by_list_row_wise(mat, lis): 281 | # C=lil_matrix([[2,4,6], [5,10,15]]) 282 | 
# print(C) 283 | D=np.asarray(list(lis)) 284 | r,c = mat.nonzero() 285 | val = np.repeat(D, mat.getnnz(axis=1)) 286 | rD_sp = csr_matrix((val, (r,c)), shape=(mat.shape)) 287 | out = mat.multiply(rD_sp) 288 | return out 289 | 290 | def load_sparse_matrix_4_cos(row1, row2): 291 | print('loading sparse matrix for cosine computation...') 292 | sparse_matrix = sparse.load_npz('/home/wyin3/Datasets/Wikipedia20190320/parsed_output/statistics_from_json/ESA_Sparse_v1.npz') 293 | print('cos: ', cosine_similarity(sparse_matrix.getrow(row1), sparse_matrix.getrow(row2))) 294 | 295 | def load_ESA_sparse_matrix(): 296 | # print('loading sparse matrix for cosine computation...') 297 | sparse_matrix = sparse.load_npz('/export/home/Dataset/wikipedia/parsed_output/statistics_from_json/ESA_Sparse_v1.npz') 298 | print('load ESA sparse matrix succeed') 299 | return sparse_matrix 300 | 301 | def crs_matrix_play(): 302 | # mat = lil_matrix((3, 5)) 303 | # mat[0,0]+=1 304 | # print(mat) 305 | # simi = cosine_similarity(mat.getrow(0), mat.getrow(0)) 306 | # print(simi) 307 | # C=lil_matrix([[2,4,6], [5,10,15]]) 308 | # print(C) 309 | # D=[2,5] 310 | # C=divide_sparseMatrix_by_list_row_wise(C,D) 311 | # print(C) 312 | 313 | C=lil_matrix([[2,4,6], [5,10,15], [1,10,9]]) 314 | sub=C[[0,2],:] 315 | print(C) 316 | print('haha',sub) 317 | print(sub.sum(axis=0)) 318 | 319 | 320 | def get_wordsize_pagesize(): 321 | ''' 322 | we first tokenzie tool output json files into a single json file "tokenized_wiki.txt" 323 | now we do statistics on it 324 | ''' 325 | start_time = time.time() 326 | global title2id 327 | global word2id 328 | global title_size 329 | global word_size 330 | # global WordTitle2Count 331 | # global Word2TileCount 332 | # global fileset 333 | 334 | route = '/export/home/Dataset/wikipedia/parsed_output/tokenized_wiki/' 335 | wiki_file_size = 0 336 | with codecs.open(route+'tokenized_wiki.txt', 'r', 'utf-8') as f: 337 | for line in f: 338 | try: 339 | line_dic = json.loads(line) 340 | except ValueError: 341 | continue 342 | title = line_dic.get('title') 343 | title_id = title2id.get(title) 344 | if title_id is None: # if word was not in the vocabulary 345 | title_id=title_size # id of true words starts from 1, leaving 0 to "pad id" 346 | title2id[title]=title_id 347 | title_size+=1 348 | 349 | # content = line_dic.get('text') 350 | tokenized_text = line_dic.get('text').split() 351 | word2tf=Counter(tokenized_text) 352 | for word, tf in word2tf.items(): 353 | word_id = word2id.get(word) 354 | if word_id is None: 355 | word_id = word_size 356 | word2id[word]=word_id 357 | word_size+=1 358 | # WordTitle2Count[word_id, title_id]=tf 359 | # Word2TileCount[word_id]+=1 #this word meets a new title 360 | wiki_file_size+=1 361 | if wiki_file_size%1000==0: 362 | print(wiki_file_size, '...over') 363 | if wiki_file_size ==4000: 364 | break 365 | f.close() 366 | print('word size:', word_size, ' title size:', title_size) 367 | 368 | if __name__ == '__main__': 369 | # scan_all_json_files('/export/home/Dataset/wikipedia/parsed_output/json/') 370 | '''note that file size 13354 does not mean wiki pages; each file contains multiple wiki pages''' 371 | # print('fileset size:', len(fileset)) #fileset size: 13354 372 | # load_json() #time-consuming, not useful 373 | # store_ESA() 374 | '''to save time, we tokenize wiki dump and save into files for future loading''' 375 | # tokenize_filter_tokens() 376 | '''word size 6161731; page size: 5903486''' 377 | # get_wordsize_pagesize() 378 | load_tokenized_json() 379 | '''store all the 
statistic dictionary into files for future loading''' 380 | store_ESA() 381 | # load_sparse_matrix_4_cos(1,2) 382 | 383 | # reformat_into_sparse_matrix_store() 384 | -------------------------------------------------------------------------------- /src/preprocess_situation.py: -------------------------------------------------------------------------------- 1 | 2 | import codecs 3 | from collections import defaultdict 4 | from sklearn.metrics import f1_score 5 | import random 6 | import numpy as np 7 | 8 | path = '/export/home/Dataset/LORELEI/' 9 | 10 | def combine_all_available_labeled_datasets(): 11 | 12 | files = [ 13 | 'full_BBN_multi.txt', 14 | 'il9_sf_gold.txt', #repeat 15 | 'il10_sf_gold.txt', # repeat 16 | 'il5_translated_seg_level_as_training_all_fields.txt', 17 | 'il3_sf_gold.txt', 18 | 'Mandarin_sf_gold.txt' #repeat 19 | ] 20 | writefile = codecs.open(path+'sf_all_labeled_data_multilabel.txt', 'w', 'utf-8') 21 | all_size = 0 22 | label2co = defaultdict(int) 23 | for fil in files: 24 | print('loading file:', path+fil, '...') 25 | size = 0 26 | readfile=codecs.open(path+fil, 'r', 'utf-8') 27 | stored_lines = set() 28 | for line in readfile: 29 | '''some labeled files have repeated lines''' 30 | if line.strip() not in stored_lines: 31 | parts=line.strip().split('\t') #lowercase all tokens, as we guess this is not important for sentiment task 32 | label_list = parts[1].strip().split() 33 | for label in set(label_list): 34 | label2co[label]+=1 35 | text=parts[2].strip() 36 | writefile.write(' '.join(label_list)+'\t'+text+'\n') 37 | size+=1 38 | all_size+=1 39 | stored_lines.add(line.strip()) 40 | readfile.close() 41 | print('size:', size) 42 | writefile.close() 43 | print('all_size:', all_size, label2co) 44 | 45 | 46 | # def split_all_labeleddata_into_subdata_per_label(): 47 | # readfile = codecs.open(path+'sf_all_labeled_data_multilabel.txt', 'r', 'utf-8') 48 | # label_list = ['search','evac','infra','utils','water','shelter','med','food', 'crimeviolence', 'terrorism', 'regimechange', 'out-of-domain'] 49 | # writefile_list = [] 50 | # for label in label_list: 51 | # writefile = codecs.open(path+'data_per_label/'+label+'.txt', 'w', 'utf-8') 52 | # writefile_list.append(writefile) 53 | # for line in readfile: 54 | # parts=line.strip().split('\t') 55 | # label_list_instance = parts[0].strip().split() 56 | # for label in label_list_instance: 57 | # writefile_exit = writefile_list[label_list.index(label)] 58 | # writefile_exit.write(parts[1].strip()+'\n') 59 | # 60 | # for writefile in writefile_list: 61 | # writefile.close() 62 | # readfile.close() 63 | 64 | 65 | 66 | def build_zeroshot_test_dev_train_set(): 67 | 68 | # test_label_size_max = {'search':80, 'evac':70, 'infra':120, 'utils':100,'water':120,'shelter':175, 69 | # 'med':250,'food':190,'regimechange':30,'terrorism':70,'crimeviolence':250,'out-of-domain':400} 70 | # dev_label_size_max = {'search':50, 'evac':30, 'infra':50, 'utils':50,'water':50,'shelter':75, 71 | # 'med':100,'food':80,'regimechange':15,'terrorism':40,'crimeviolence':100,'out-of-domain':200} 72 | 73 | label_list = ['search','evac','infra','utils','water','shelter','med','food', 'crimeviolence', 'terrorism', 'regimechange', 'out-of-domain'] 74 | 75 | test_store_size = defaultdict(int) 76 | dev_store_size = defaultdict(int) 77 | write_test = codecs.open(path+'zero-shot-split/test.txt', 'w', 'utf-8') 78 | write_dev = codecs.open(path+'zero-shot-split/dev.txt', 'w', 'utf-8') 79 | write_train_v0 = codecs.open(path+'zero-shot-split/train_pu_half_v0.txt', 'w', 'utf-8') 80 
| seen_types_v0 = ['search','infra','water','med','crimeviolence', 'regimechange'] 81 | write_train_v1 = codecs.open(path+'zero-shot-split/train_pu_half_v1.txt', 'w', 'utf-8') 82 | seen_types_v1 = ['evac','utils', 'shelter','food', 'terrorism'] 83 | readfile = codecs.open(path+'sf_all_labeled_data_multilabel.txt', 'r', 'utf-8') 84 | for line in readfile: 85 | parts = line.strip().split('\t') 86 | type_set = set(parts[0].strip().split()) 87 | '''test and dev set build''' 88 | rand_value = random.uniform(0, 1) 89 | if rand_value > 2.0/5.0: 90 | write_test.write(line.strip()+'\n') 91 | else: 92 | write_dev.write(line.strip()+'\n') 93 | 94 | '''train set build''' 95 | remain_type_v0 = type_set & set(seen_types_v0) 96 | if len(remain_type_v0) > 0: 97 | write_train_v0.write(' '.join(list(remain_type_v0))+'\t'+parts[1].strip()+'\n') 98 | remain_type_v1 = type_set & set(seen_types_v1) 99 | if len(remain_type_v1) > 0: 100 | write_train_v1.write(' '.join(list(remain_type_v1))+'\t'+parts[1].strip()+'\n') 101 | write_test.close() 102 | write_dev.close() 103 | write_train_v0.close() 104 | write_train_v1.close() 105 | print('zero-shot data split over') 106 | 107 | def statistics(): 108 | filename=[path+'zero-shot-split/test.txt', path+'zero-shot-split/dev.txt', 109 | path+'zero-shot-split/train_pu_half_v0.txt',path+'zero-shot-split/train_pu_half_v1.txt'] 110 | for fil in filename: 111 | type2size= defaultdict(int) 112 | readfile=codecs.open(fil, 'r', 'utf-8') 113 | for line in readfile: 114 | type_list = line.strip().split('\t')[0].split() 115 | for type in type_list: 116 | type2size[type]+=1 117 | readfile.close() 118 | print('type2size:', type2size) 119 | 120 | 121 | # def build_zeroshot_train_set(): 122 | # readfile_remain = codecs.open(path+'unified-dataset-wo-devandtest.txt', 'r', 'utf-8') 123 | # '''we do not put None type in train''' 124 | # label_list = ['search','evac','infra','utils','water','shelter','med','food', 'crimeviolence', 'terrorism', 'regimechange'] 125 | # writefile_PU_half_0 = codecs.open(path+'zero-shot-split/train_pu_half_v0.txt', 'w', 'utf-8') 126 | # writefile_PU_half_1 = codecs.open(path+'zero-shot-split/train_pu_half_v1.txt', 'w', 'utf-8') 127 | # 128 | # for line in readfile_remain: 129 | # parts = line.strip().split('\t') 130 | # type = parts[0] 131 | # if type in set(label_list): 132 | # if label_list.index(type) %2==0: 133 | # writefile_PU_half_0.write(line.strip()+'\n') 134 | # else: 135 | # writefile_PU_half_1.write(line.strip()+'\n') 136 | # writefile_PU_half_0.close() 137 | # writefile_PU_half_1.close() 138 | # print('PU half over') 139 | # '''PU_one''' 140 | # for i in range(len(label_list)): 141 | # readfile=codecs.open(path+'unified-dataset-wo-devandtest.txt', 'r', 'utf-8') 142 | # writefile_PU_one = codecs.open(path+'zero-shot-split/train_pu_one_'+'wo_'+str(i)+'.txt', 'w', 'utf-8') 143 | # line_co=0 144 | # for line in readfile: 145 | # parts = line.strip().split('\t') 146 | # type = parts[0] 147 | # if type in set(label_list): 148 | # label_id = label_list.index(type) 149 | # if label_id != i: 150 | # writefile_PU_one.write(line.strip()+'\n') 151 | # line_co+=1 152 | # writefile_PU_one.close() 153 | # readfile.close() 154 | # print('write size:', line_co) 155 | # print('build train over') 156 | 157 | 158 | 159 | def evaluate_situation_zeroshot_TwpPhasePred(pred_probs, pred_binary_labels_harsh, pred_binary_labels_loose, eval_label_list, eval_hypo_seen_str_indicator, eval_hypo_2_type_index, seen_types): 160 | ''' 161 | pred_probs: a list, the prob for "entail" 162 
| pred_binary_labels: a lit, each for 0 or 1 163 | eval_label_list: the gold type index; list length == lines in dev.txt 164 | eval_hypo_seen_str_indicator: totally hypo size, seen or unseen 165 | eval_hypo_2_type_index:: total hypo size, the type in [0,...n] 166 | seen_types: a set of type indices 167 | ''' 168 | 169 | pred_probs = list(pred_probs) 170 | # pred_binary_labels = list(pred_binary_labels) 171 | total_hypo_size = len(eval_hypo_seen_str_indicator) 172 | total_premise_size = len(eval_label_list) 173 | assert len(pred_probs) == total_premise_size*total_hypo_size 174 | assert len(eval_hypo_seen_str_indicator) == len(eval_hypo_2_type_index) 175 | 176 | pred_label_list = [] 177 | 178 | for i in range(total_premise_size): 179 | pred_probs_per_premise = pred_probs[i*total_hypo_size: (i+1)*total_hypo_size] 180 | pred_binary_labels_per_premise_harsh = pred_binary_labels_harsh[i*total_hypo_size: (i+1)*total_hypo_size] 181 | pred_binary_labels_per_premise_loose = pred_binary_labels_loose[i*total_hypo_size: (i+1)*total_hypo_size] 182 | 183 | pred_type = [] 184 | for j in range(total_hypo_size): 185 | if (eval_hypo_seen_str_indicator[j] == 'seen' and pred_probs_per_premise[j]>0.6) or \ 186 | (eval_hypo_seen_str_indicator[j] == 'unseen' and pred_probs_per_premise[j]>0.5): 187 | pred_type.append(eval_hypo_2_type_index[j]) 188 | 189 | if len(pred_type) ==0: 190 | pred_type.append('out-of-domain') 191 | pred_label_list.append(pred_type) 192 | 193 | assert len(pred_label_list) == len(eval_label_list) 194 | type_in_test = ['search','evac','infra','utils','water','shelter','med','food', 'crimeviolence', 'terrorism', 'regimechange', 'out-of-domain'] 195 | type2col = { type:i for i, type in enumerate(type_in_test)} 196 | gold_array = np.zeros((total_premise_size,12), dtype=int) 197 | pred_array = np.zeros((total_premise_size,12), dtype=int) 198 | for i in range(total_premise_size): 199 | for type in pred_label_list[i]: 200 | pred_array[i,type2col.get(type)]=1 201 | for type in eval_label_list[i]: 202 | gold_array[i,type2col.get(type)]=1 203 | 204 | '''seen F1''' 205 | seen_f1_accu = 0.0 206 | seen_size = 0 207 | unseen_f1_accu = 0.0 208 | unseen_size = 0 209 | for i in range(len(type_in_test)): 210 | f1=f1_score(gold_array[:,i], pred_array[:,i], pos_label=1, average='binary') 211 | print(i, ':', f1) 212 | co = sum(gold_array[:,i]) 213 | if type_in_test[i] in seen_types: 214 | seen_f1_accu+=f1*co 215 | seen_size+=co 216 | else: 217 | unseen_f1_accu+=f1*co 218 | unseen_size+=co 219 | 220 | seen_f1 = seen_f1_accu/(1e-6+seen_size) 221 | unseen_f1 = unseen_f1_accu/(1e-6+unseen_size) 222 | 223 | return seen_f1, unseen_f1 224 | 225 | 226 | def situation_f1_given_goldlist_and_predlist(eval_label_list, pred_label_list, seen_types_v0, seen_types_v1): 227 | assert len(pred_label_list) == len(eval_label_list) 228 | total_premise_size = len(eval_label_list) 229 | type_in_test = ['search','evac','infra','utils','water','shelter','med','food', 'crimeviolence', 'terrorism', 'regimechange', 'out-of-domain'] 230 | type2col = { type:i for i, type in enumerate(type_in_test)} 231 | gold_array = np.zeros((total_premise_size,12), dtype=int) 232 | pred_array = np.zeros((total_premise_size,12), dtype=int) 233 | for i in range(total_premise_size): 234 | for type in pred_label_list[i]: 235 | pred_array[i,type2col.get(type)]=1 236 | for type in eval_label_list[i]: 237 | gold_array[i,type2col.get(type)]=1 238 | # print('gold_array:', gold_array) 239 | # print('pred_array:', pred_array) 240 | # print('seen_types:', seen_types) 
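    # The rest of this function turns the per-label binary F1 scores computed
    # below into support-weighted averages: once over the full label set, and
    # once for each of the two seen/unseen partitions (v0 and v1). A tiny
    # worked example of the weighting, with made-up numbers rather than real
    # dataset counts:
    #     labels = ['water', 'food']; per-label F1 = [0.5, 0.8]; gold support = [10, 30]
    #     weighted F1 = (0.5*10 + 0.8*30) / (10 + 30) = 29.0 / 40 = 0.725
    # Labels that receive no positive prediction are assigned F1 = 0.0 directly
    # (sklearn would otherwise emit an undefined-metric warning for them).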
241 | '''seen F1''' 242 | 243 | # 244 | 245 | 246 | f1_list = [] 247 | co_list = [] 248 | for i in range(len(type_in_test)): 249 | if sum(pred_array[:,i]) < 1: 250 | f1=0.0 251 | else: 252 | f1=f1_score(gold_array[:,i], pred_array[:,i], pos_label=1, average='binary') 253 | co = sum(gold_array[:,i]) 254 | f1_list.append(f1) 255 | co_list.append(co) 256 | 257 | seen_f1_accu_v0 = 0.0 258 | seen_size_v0 = 0 259 | unseen_f1_accu_v0 = 0.0 260 | unseen_size_v0 = 0 261 | 262 | seen_f1_accu_v1 = 0.0 263 | seen_size_v1 = 0 264 | unseen_f1_accu_v1 = 0.0 265 | unseen_size_v1 = 0 266 | 267 | f1_accu = 0.0 268 | size_accu = 0 269 | for i in range(len(type_in_test)): 270 | f1 = f1_list[i] 271 | co =co_list[i] 272 | 273 | f1_accu+=f1*co 274 | size_accu+=co 275 | 276 | if type_in_test[i] in seen_types_v0: 277 | seen_f1_accu_v0+=f1*co 278 | seen_size_v0+=co 279 | else: 280 | unseen_f1_accu_v0+=f1*co 281 | unseen_size_v0+=co 282 | if type_in_test[i] in seen_types_v1: 283 | seen_f1_accu_v1+=f1*co 284 | seen_size_v1+=co 285 | else: 286 | unseen_f1_accu_v1+=f1*co 287 | unseen_size_v1+=co 288 | 289 | 290 | all_f1 = f1_accu/(1e-6+size_accu) 291 | 292 | v0 = (seen_f1_accu_v0/(1e-6+seen_size_v0), unseen_f1_accu_v0/(1e-6+unseen_size_v0)) 293 | v1 = (seen_f1_accu_v1/(1e-6+seen_size_v1), unseen_f1_accu_v1/(1e-6+unseen_size_v1)) 294 | 295 | return v0, v1, all_f1 296 | 297 | 298 | def evaluate_situation_zeroshot_SinglePhasePred(pred_probs, pred_binary_labels_harsh, pred_binary_labels_loose, eval_label_list, eval_hypo_seen_str_indicator, eval_hypo_2_type_index, seen_types): 299 | ''' 300 | pred_probs: a list, the prob for "entail" 301 | pred_binary_labels: a lit, each for 0 or 1 302 | eval_label_list: the gold type index; list length == lines in dev.txt 303 | eval_hypo_seen_str_indicator: totally hypo size, seen or unseen 304 | eval_hypo_2_type_index:: total hypo size, the type in [0,...n] 305 | seen_types: a set of type indices 306 | ''' 307 | 308 | pred_probs = list(pred_probs) 309 | # pred_binary_labels = list(pred_binary_labels) 310 | total_hypo_size = len(eval_hypo_seen_str_indicator) 311 | total_premise_size = len(eval_label_list) 312 | assert len(pred_probs) == total_premise_size*total_hypo_size 313 | assert len(eval_hypo_seen_str_indicator) == len(eval_hypo_2_type_index) 314 | 315 | # print('seen_types:', seen_types) 316 | # print('eval_hypo_seen_str_indicator:', eval_hypo_seen_str_indicator) 317 | # print('eval_hypo_2_type_index:', eval_hypo_2_type_index) 318 | 319 | 320 | pred_label_list = [] 321 | 322 | for i in range(total_premise_size): 323 | pred_probs_per_premise = pred_probs[i*total_hypo_size: (i+1)*total_hypo_size] 324 | pred_binary_labels_per_premise_harsh = pred_binary_labels_harsh[i*total_hypo_size: (i+1)*total_hypo_size] 325 | pred_binary_labels_per_premise_loose = pred_binary_labels_loose[i*total_hypo_size: (i+1)*total_hypo_size] 326 | 327 | pred_type = [] 328 | for j in range(total_hypo_size): 329 | if pred_binary_labels_per_premise_loose[j]==0: # is entailment 330 | pred_type.append(eval_hypo_2_type_index[j]) 331 | 332 | if len(pred_type) ==0: 333 | pred_type.append('out-of-domain') 334 | pred_label_list.append(pred_type) 335 | 336 | assert len(pred_label_list) == len(eval_label_list) 337 | type_in_test = ['search','evac','infra','utils','water','shelter','med','food', 'crimeviolence', 'terrorism', 'regimechange', 'out-of-domain'] 338 | type2col = { type:i for i, type in enumerate(type_in_test)} 339 | gold_array = np.zeros((total_premise_size,12), dtype=int) 340 | pred_array = 
np.zeros((total_premise_size,12), dtype=int) 341 | for i in range(total_premise_size): 342 | for type in pred_label_list[i]: 343 | pred_array[i,type2col.get(type)]=1 344 | for type in eval_label_list[i]: 345 | gold_array[i,type2col.get(type)]=1 346 | 347 | '''seen F1''' 348 | seen_f1_accu = 0.0 349 | seen_size = 0 350 | unseen_f1_accu = 0.0 351 | unseen_size = 0 352 | for i in range(len(type_in_test)): 353 | f1=f1_score(gold_array[:,i], pred_array[:,i], pos_label=1, average='binary') 354 | co = sum(gold_array[:,i]) 355 | if type_in_test[i] in seen_types: 356 | seen_f1_accu+=f1*co 357 | seen_size+=co 358 | else: 359 | unseen_f1_accu+=f1*co 360 | unseen_size+=co 361 | 362 | seen_f1 = seen_f1_accu/(1e-6+seen_size) 363 | unseen_f1 = unseen_f1_accu/(1e-6+unseen_size) 364 | 365 | return seen_f1, unseen_f1 366 | 367 | 368 | def majority_baseline(): 369 | readfile = codecs.open(path+'zero-shot-split/test.txt', 'r', 'utf-8') 370 | gold_label_list = [] 371 | for line in readfile: 372 | gold_label_list.append(line.strip().split('\t')[0].split()) 373 | '''out-of-domain is the main type''' 374 | pred_label_list = [['out-of-domain']] *len(gold_label_list) 375 | # seen_labels = set(['search','evac','infra','utils','water','shelter','med','food', 'crimeviolence', 'terrorism', 'regimechange']) 376 | # seen_types = set(['evac','utils','shelter','food', 'terrorism']) 377 | seen_types = set(['search','infra','water','med', 'crimeviolence', 'regimechange']) 378 | # f1_score_per_type = f1_score(gold_label_list, pred_label_list, labels = list(set(gold_label_list)), average='weighted') 379 | 380 | assert len(pred_label_list) == len(gold_label_list) 381 | total_premise_size = len(gold_label_list) 382 | type_in_test = ['search','evac','infra','utils','water','shelter','med','food', 'crimeviolence', 'terrorism', 'regimechange', 'out-of-domain'] 383 | type2col = { type:i for i, type in enumerate(type_in_test)} 384 | gold_array = np.zeros((total_premise_size,12), dtype=int) 385 | pred_array = np.zeros((total_premise_size,12), dtype=int) 386 | for i in range(total_premise_size): 387 | for type in pred_label_list[i]: 388 | pred_array[i,type2col.get(type)]=1 389 | for type in gold_label_list[i]: 390 | gold_array[i,type2col.get(type)]=1 391 | 392 | '''seen F1''' 393 | seen_f1_accu = 0.0 394 | seen_size = 0 395 | unseen_f1_accu = 0.0 396 | unseen_size = 0 397 | 398 | f1_accu = 0.0 399 | size_accu = 0 400 | for i in range(len(type_in_test)): 401 | f1=f1_score(gold_array[:,i], pred_array[:,i], pos_label=1, average='binary') 402 | co = sum(gold_array[:,i]) 403 | 404 | f1_accu+=f1*co 405 | size_accu+=co 406 | if type_in_test[i] in seen_types: 407 | seen_f1_accu+=f1*co 408 | seen_size+=co 409 | else: 410 | unseen_f1_accu+=f1*co 411 | unseen_size+=co 412 | 413 | seen_f1 = seen_f1_accu/(1e-6+seen_size) 414 | unseen_f1 = unseen_f1_accu/(1e-6+unseen_size) 415 | 416 | all_f1 = f1_accu/(1e-6+size_accu) 417 | print('seen_f1:', seen_f1, 'unseen_f1:', unseen_f1, 'all:', all_f1) 418 | 419 | 420 | if __name__ == '__main__': 421 | # combine_all_available_labeled_datasets() 422 | '''not useful''' 423 | # split_all_labeleddata_into_subdata_per_label() 424 | # build_zeroshot_test_dev_set() 425 | # build_zeroshot_train_set() 426 | 427 | # build_zeroshot_test_dev_train_set() 428 | # statistics() 429 | 430 | majority_baseline() 431 | -------------------------------------------------------------------------------- /src/preprocess_emotion.py: -------------------------------------------------------------------------------- 1 | 2 | import jsonlines 
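# The helpers in this file read the unified emotion corpus one JSON object per
# line. A sketch of the record shape they rely on (field values here are
# purely illustrative, not taken from the data):
#   {"text": "...", "domain": "tweets", "source": "ssec",
#    "labeled": "single", "emotions": {"joy": 1, "sadness": 0, ...}}
# Only the keys accessed through line2dict.get(...) below are assumed.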
3 | from collections import defaultdict 4 | import codecs 5 | from sklearn.metrics import f1_score 6 | 7 | path = '/export/home/Dataset/Stuttgart_Emotion/unify-emotion-datasets-master/' 8 | 9 | def statistics(): 10 | readfile = jsonlines.open(path+'unified-dataset.jsonl' ,'r') 11 | domain2size = defaultdict(int) 12 | source2size = defaultdict(int) 13 | emotion2size = defaultdict(int) 14 | '''single-label or multi-label''' 15 | single2size = defaultdict(int) 16 | emo_dom_size = defaultdict(int) 17 | line_co = 0 18 | valid_line_co = 0 19 | for line2dict in readfile: 20 | valid_line = False 21 | text = line2dict.get('text') 22 | domain = line2dict.get('domain') #tweets etc 23 | 24 | 25 | source_dataset = line2dict.get('source') 26 | 27 | single = line2dict.get('labeled') 28 | 29 | emotions =line2dict.get('emotions') 30 | if domain == 'headlines' or domain == 'facebook-messages': 31 | print(emotions) 32 | for emotion, label in emotions.items(): 33 | if label == 1: 34 | emotion2size[emotion]+=1 35 | emo_dom_size[(emotion, domain)]+=1 36 | valid_line = True 37 | if valid_line: 38 | valid_line_co+=1 39 | domain2size[domain]+=1 40 | source2size[source_dataset]+=1 41 | single2size[single]+=1 42 | line_co+=1 43 | if line_co%100==0: 44 | print(line_co) 45 | readfile.close() 46 | print('domain2size:',domain2size) 47 | print('source2size:',source2size) 48 | print('emotion2size:',emotion2size) 49 | print('single2size:',single2size) 50 | print('emo_dom_size:',emo_dom_size) 51 | print('line_co:', line_co) 52 | print('valid_line_co:', valid_line_co) 53 | 54 | ''' 55 | domain2size: defaultdict(, {'tweets': 54203, 'emotional_events': 7666, 'fairytale_sentences': 14771, 'artificial_sentences': 2268}) 56 | source2size: defaultdict(, {'grounded_emotions': 2585, 'ssec': 4776, 'isear': 7666, 'crowdflower': 39740, 'tales-emotion': 14771, 'emotion-cause': 2268, 'emoint': 7102}) 57 | emotion2size: defaultdict(, {'sadness': 12947, 'joy': 17833, 'anger': 8335, 'disgust': 3931, 'trust': 2700, 'fear': 14752, 'surprise': 4304, 'shame': 1096, 'guilt': 1093, 'noemo': 18765, 'love': 3820}) 58 | single2size: defaultdict(, {'single': 74132, 'multi': 4776}) 59 | 60 | 61 | emo_dom_size: defaultdict(, {('sadness', 'tweets'): 10355, ('joy', 'tweets'): 14433, ('anger', 'tweets'): 6024, ('disgust', 'tweets'): 2362, ('trust', 'tweets'): 2700, ('fear', 'tweets'): 12522, ('surprise', 'tweets'): 3285, ('joy', 'emotional_events'): 1094, ('fear', 'emotional_events'): 1095, ('anger', 'emotional_events'): 1096, ('sadness', 'emotional_events'): 1096, ('disgust', 'emotional_events'): 1096, ('shame', 'emotional_events'): 1096, ('guilt', 'emotional_events'): 1093, ('noemo', 'tweets'): 9370, ('love', 'tweets'): 3820, ('noemo', 'fairytale_sentences'): 9395, ('disgust', 'fairytale_sentences'): 378, ('joy', 'fairytale_sentences'): 1827, ('surprise', 'fairytale_sentences'): 806, ('fear', 'fairytale_sentences'): 712, ('anger', 'fairytale_sentences'): 732, ('sadness', 'fairytale_sentences'): 921, ('joy', 'artificial_sentences'): 479, ('sadness', 'artificial_sentences'): 575, ('surprise', 'artificial_sentences'): 213, ('disgust', 'artificial_sentences'): 95, ('anger', 'artificial_sentences'): 483, ('fear', 'artificial_sentences'): 423}) 62 | ''' 63 | 64 | 65 | def build_zeroshot_test_dev_set(): 66 | readfile = jsonlines.open(path+'unified-dataset.jsonl' ,'r') 67 | writefile_test = codecs.open(path+'zero-shot-split/test.txt', 'w', 'utf-8') 68 | writefile_dev = codecs.open(path+'zero-shot-split/dev.txt', 'w', 'utf-8') 69 | writefile_remain = 
codecs.open(path+'unified-dataset-wo-devandtest.txt', 'w', 'utf-8') 70 | 71 | emotion_type_list = ['sadness', 'joy', 'anger', 'disgust', 'trust', 'fear', 'surprise', 'shame', 'guilt', 'love', 'noemo'] 72 | domain_list = ['tweets', 'emotional_events', 'fairytale_sentences', 'artificial_sentences'] 73 | test_size_matrix = [[1500,2150,1650,50,800,2150,880,0,0,1100,1000], 74 | [300,200,400,400,0,200,0,300,300,0,0], 75 | [300,500,250,120,0,250,220,0,0,0,1000], 76 | [200,150,200,30,0,100,100,0,0,0,0]] 77 | 78 | dev_size_matrix = [[900,1050,400,40,250,1200,370,0,0,400,500], 79 | [150,150,150,150,0,150,0,100,100,0,0], 80 | [150,300,150,90,0,150,80,0,0,0,500], 81 | [100,100,100,20,0,100,50,0,0,0,0]] 82 | 83 | test_write_size = defaultdict(int) 84 | dev_write_size = defaultdict(int) 85 | 86 | line_co = 0 87 | spec_co = 0 88 | for line2dict in readfile: 89 | valid_line = False 90 | text = line2dict.get('text').strip() 91 | domain = line2dict.get('domain') #tweets etc 92 | 93 | 94 | source_dataset = line2dict.get('source') 95 | 96 | single = line2dict.get('labeled') 97 | '''we only consider single-label instances''' 98 | if single == 'single': 99 | target_emotion = '' 100 | emotions =line2dict.get('emotions') 101 | for emotion, label in emotions.items(): 102 | # print(emotion, label, label == 1) 103 | if label == 1: 104 | target_emotion = emotion 105 | break 106 | '''there is weird case that no positive label in the instances''' 107 | if len(target_emotion) > 0: 108 | if target_emotion == 'disgust' and domain =='tweets': 109 | spec_co+=1 110 | 111 | emotion_index = emotion_type_list.index(target_emotion) 112 | domain_index = domain_list.index(domain) 113 | if test_write_size.get((domain, target_emotion),0) < test_size_matrix[domain_index][emotion_index]: 114 | writefile_test.write(target_emotion+'\t'+domain+'\t'+text+'\n') 115 | test_write_size[(domain, target_emotion)]+=1 116 | elif dev_write_size.get((domain, target_emotion),0) < dev_size_matrix[domain_index][emotion_index]: 117 | writefile_dev.write(target_emotion+'\t'+domain+'\t'+text+'\n') 118 | dev_write_size[(domain, target_emotion)]+=1 119 | else: 120 | writefile_remain.write(target_emotion+'\t'+domain+'\t'+text+'\n') 121 | 122 | line_co+=1 123 | if line_co%100==0: 124 | print(line_co) 125 | writefile_test.close() 126 | writefile_dev.close() 127 | writefile_remain.close() 128 | print('test, dev, train build over') 129 | print(spec_co) 130 | 131 | writefile_test = codecs.open(path+'zero-shot-split/test.txt', 'r', 'utf-8') 132 | co=defaultdict(int) 133 | for line in writefile_test: 134 | parts = line.strip().split('\t') 135 | co[(parts[0], parts[1])]+=1 136 | writefile_test.close() 137 | print(co, '\n') 138 | 139 | writefile_dev = codecs.open(path+'zero-shot-split/dev.txt', 'r', 'utf-8') 140 | co=defaultdict(int) 141 | for line in writefile_dev: 142 | parts = line.strip().split('\t') 143 | co[(parts[0], parts[1])]+=1 144 | writefile_dev.close() 145 | print(co, '\n') 146 | writefile_remain = codecs.open(path+'unified-dataset-wo-devandtest.txt', 'r', 'utf-8') 147 | co=defaultdict(int) 148 | for line in writefile_remain: 149 | parts = line.strip().split('\t') 150 | co[(parts[0], parts[1])]+=1 151 | writefile_remain.close() 152 | print(co) 153 | 154 | def build_zeroshot_train_set(): 155 | readfile_remain = codecs.open(path+'unified-dataset-wo-devandtest.txt', 'r', 'utf-8') 156 | emotion_type_list = ['sadness', 'joy', 'anger', 'disgust', 'fear', 'surprise', 'shame', 'guilt', 'love'] 157 | writefile_PU_half_0 = 
codecs.open(path+'zero-shot-split/train_pu_half_v0.txt', 'w', 'utf-8') 158 | writefile_PU_half_1 = codecs.open(path+'zero-shot-split/train_pu_half_v1.txt', 'w', 'utf-8') 159 | 160 | for line in readfile_remain: 161 | parts = line.strip().split('\t') 162 | emotion = parts[0] 163 | if emotion in set(emotion_type_list): 164 | if emotion_type_list.index(emotion) %2==0: 165 | writefile_PU_half_0.write(line.strip()+'\n') 166 | else: 167 | writefile_PU_half_1.write(line.strip()+'\n') 168 | writefile_PU_half_0.close() 169 | writefile_PU_half_1.close() 170 | print('PU half over') 171 | '''PU_one''' 172 | for i in range(len(emotion_type_list)): 173 | readfile=codecs.open(path+'unified-dataset-wo-devandtest.txt', 'r', 'utf-8') 174 | writefile_PU_one = codecs.open(path+'zero-shot-split/train_pu_one_'+'wo_'+str(i)+'.txt', 'w', 'utf-8') 175 | line_co=0 176 | for line in readfile: 177 | parts = line.strip().split('\t') 178 | if len(parts)==3: 179 | emotion = parts[0] 180 | if emotion in set(emotion_type_list): 181 | label_id = emotion_type_list.index(emotion) 182 | if label_id != i: 183 | writefile_PU_one.write(line.strip()+'\n') 184 | line_co+=1 185 | writefile_PU_one.close() 186 | readfile.close() 187 | print('write size:', line_co) 188 | print('build train over') 189 | 190 | 191 | 192 | 193 | def evaluate_emotion_zeroshot_TwpPhasePred(pred_probs, pred_binary_labels_harsh, pred_binary_labels_loose, eval_label_list, eval_hypo_seen_str_indicator, eval_hypo_2_type_index, seen_types): 194 | ''' 195 | pred_probs: a list, the prob for "entail" 196 | pred_binary_labels: a lit, each for 0 or 1 197 | eval_label_list: the gold type index; list length == lines in dev.txt 198 | eval_hypo_seen_str_indicator: totally hypo size, seen or unseen 199 | eval_hypo_2_type_index:: total hypo size, the type in [0,...n] 200 | seen_types: a set of type indices 201 | ''' 202 | 203 | pred_probs = list(pred_probs) 204 | # pred_binary_labels = list(pred_binary_labels) 205 | total_hypo_size = len(eval_hypo_seen_str_indicator) 206 | total_premise_size = len(eval_label_list) 207 | assert len(pred_probs) == total_premise_size*total_hypo_size 208 | assert len(eval_hypo_seen_str_indicator) == len(eval_hypo_2_type_index) 209 | 210 | # print('seen_types:', seen_types) 211 | # print('eval_hypo_seen_str_indicator:', eval_hypo_seen_str_indicator) 212 | # print('eval_hypo_2_type_index:', eval_hypo_2_type_index) 213 | 214 | 215 | pred_label_list = [] 216 | 217 | for i in range(total_premise_size): 218 | pred_probs_per_premise = pred_probs[i*total_hypo_size: (i+1)*total_hypo_size] 219 | pred_binary_labels_per_premise_harsh = pred_binary_labels_harsh[i*total_hypo_size: (i+1)*total_hypo_size] 220 | pred_binary_labels_per_premise_loose = pred_binary_labels_loose[i*total_hypo_size: (i+1)*total_hypo_size] 221 | 222 | 223 | # print('pred_probs_per_premise:',pred_probs_per_premise) 224 | # print('pred_binary_labels_per_premise:', pred_binary_labels_per_premise) 225 | 226 | 227 | '''first check if seen types get 'entailment''' 228 | seen_get_entail_flag=False 229 | for j in range(total_hypo_size): 230 | if eval_hypo_seen_str_indicator[j] == 'seen' and pred_binary_labels_per_premise_loose[j]==0: 231 | seen_get_entail_flag=True 232 | break 233 | '''first check if unseen types get 'entailment''' 234 | unseen_get_entail_flag=False 235 | for j in range(total_hypo_size): 236 | if eval_hypo_seen_str_indicator[j] == 'unseen' and pred_binary_labels_per_premise_loose[j]==0: 237 | unseen_get_entail_flag=True 238 | break 239 | 240 | if seen_get_entail_flag and 
unseen_get_entail_flag: 241 | 242 | '''compare their max prob''' 243 | max_prob_seen = -1.0 244 | max_seen_index = -1 245 | max_prob_unseen = -1.0 246 | max_unseen_index = -1 247 | for j in range(total_hypo_size): 248 | its_prob = pred_probs_per_premise[j] 249 | if eval_hypo_seen_str_indicator[j] == 'unseen': 250 | if its_prob > max_prob_unseen: 251 | max_prob_unseen = its_prob 252 | max_unseen_index = j 253 | else: 254 | if its_prob > max_prob_seen: 255 | max_prob_seen = its_prob 256 | max_seen_index = j 257 | if max_prob_seen - max_prob_unseen > 0.05: 258 | pred_type = eval_hypo_2_type_index[max_seen_index] 259 | else: 260 | pred_type = eval_hypo_2_type_index[max_unseen_index] 261 | 262 | elif unseen_get_entail_flag: 263 | '''find the unseen type with highest prob''' 264 | max_j = -1 265 | max_prob = -1.0 266 | for j in range(total_hypo_size): 267 | if eval_hypo_seen_str_indicator[j] == 'unseen': 268 | its_prob = pred_probs_per_premise[j] 269 | if its_prob > max_prob: 270 | max_prob = its_prob 271 | max_j = j 272 | pred_type = eval_hypo_2_type_index[max_j] 273 | 274 | elif seen_get_entail_flag: 275 | '''find the seen type with highest prob''' 276 | max_j = -1 277 | max_prob = -1.0 278 | for j in range(total_hypo_size): 279 | if eval_hypo_seen_str_indicator[j] == 'seen' and pred_binary_labels_per_premise_loose[j]==0: 280 | its_prob = pred_probs_per_premise[j] 281 | if its_prob > max_prob: 282 | max_prob = its_prob 283 | max_j = j 284 | assert max_prob > 0.5 285 | pred_type = eval_hypo_2_type_index[max_j] 286 | elif (not seen_get_entail_flag) and (not unseen_get_entail_flag): 287 | '''it means noemo''' 288 | pred_type = 'noemo' 289 | pred_label_list.append(pred_type) 290 | 291 | assert len(pred_label_list) == len(eval_label_list) 292 | 293 | all_test_labels = list(set(eval_label_list)) 294 | f1_score_per_type = f1_score(eval_label_list, pred_label_list, labels = all_test_labels, average=None) 295 | print('all_test_labels:', all_test_labels) 296 | print('f1_score_per_type:', f1_score_per_type) 297 | print('type size:', [eval_label_list.count(type) for type in all_test_labels]) 298 | 299 | '''seen F1''' 300 | seen_f1_accu = 0.0 301 | seen_size = 0 302 | unseen_f1_accu = 0.0 303 | unseen_size = 0 304 | for i in range(len(all_test_labels)): 305 | f1=f1_score_per_type[i] 306 | co = eval_label_list.count(all_test_labels[i]) 307 | if all_test_labels[i] in seen_types: 308 | seen_f1_accu+=f1*co 309 | seen_size+=co 310 | else: 311 | unseen_f1_accu+=f1*co 312 | unseen_size+=co 313 | 314 | 315 | 316 | 317 | seen_f1 = seen_f1_accu/(1e-6+seen_size) 318 | unseen_f1 = unseen_f1_accu/(1e-6+unseen_size) 319 | 320 | return seen_f1, unseen_f1 321 | 322 | def evaluate_emotion_zeroshot_SinglePhasePred(pred_probs, pred_binary_labels_harsh, pred_binary_labels_loose, eval_label_list, eval_hypo_seen_str_indicator, eval_hypo_2_type_index, seen_types): 323 | ''' 324 | pred_probs: a list, the prob for "entail" 325 | pred_binary_labels: a lit, each for 0 or 1 326 | eval_label_list: the gold type index; list length == lines in dev.txt 327 | eval_hypo_seen_str_indicator: totally hypo size, seen or unseen 328 | eval_hypo_2_type_index:: total hypo size, the type in [0,...n] 329 | seen_types: a set of type indices 330 | ''' 331 | 332 | pred_probs = list(pred_probs) 333 | # pred_binary_labels = list(pred_binary_labels) 334 | total_hypo_size = len(eval_hypo_seen_str_indicator) 335 | total_premise_size = len(eval_label_list) 336 | assert len(pred_probs) == total_premise_size*total_hypo_size 337 | assert 
len(eval_hypo_seen_str_indicator) == len(eval_hypo_2_type_index) 338 | 339 | # print('seen_types:', seen_types) 340 | # print('eval_hypo_seen_str_indicator:', eval_hypo_seen_str_indicator) 341 | # print('eval_hypo_2_type_index:', eval_hypo_2_type_index) 342 | 343 | 344 | pred_label_list = [] 345 | 346 | for i in range(total_premise_size): 347 | pred_probs_per_premise = pred_probs[i*total_hypo_size: (i+1)*total_hypo_size] 348 | pred_binary_labels_per_premise_harsh = pred_binary_labels_harsh[i*total_hypo_size: (i+1)*total_hypo_size] 349 | pred_binary_labels_per_premise_loose = pred_binary_labels_loose[i*total_hypo_size: (i+1)*total_hypo_size] 350 | 351 | 352 | max_prob = -100.0 353 | max_index = -1 354 | for j in range(total_hypo_size): 355 | if pred_binary_labels_per_premise_loose[j]==0: # is entailment 356 | if pred_probs_per_premise[j] > max_prob: 357 | max_prob = pred_probs_per_premise[j] 358 | max_index = j 359 | 360 | if max_index == -1: 361 | pred_label_list.append('out-of-domain') 362 | else: 363 | pred_label_list.append(eval_hypo_2_type_index[max_index]) 364 | 365 | assert len(pred_label_list) == len(eval_label_list) 366 | 367 | all_test_labels = list(set(eval_label_list)) 368 | f1_score_per_type = f1_score(eval_label_list, pred_label_list, labels = all_test_labels, average=None) 369 | print('all_test_labels:', all_test_labels) 370 | print('f1_score_per_type:', f1_score_per_type) 371 | print('type size:', [eval_label_list.count(type) for type in all_test_labels]) 372 | 373 | '''seen F1''' 374 | seen_f1_accu = 0.0 375 | seen_size = 0 376 | unseen_f1_accu = 0.0 377 | unseen_size = 0 378 | for i in range(len(all_test_labels)): 379 | f1=f1_score_per_type[i] 380 | co = eval_label_list.count(all_test_labels[i]) 381 | if all_test_labels[i] in seen_types: 382 | seen_f1_accu+=f1*co 383 | seen_size+=co 384 | else: 385 | unseen_f1_accu+=f1*co 386 | unseen_size+=co 387 | 388 | 389 | 390 | 391 | seen_f1 = seen_f1_accu/(1e-6+seen_size) 392 | unseen_f1 = unseen_f1_accu/(1e-6+unseen_size) 393 | 394 | return seen_f1, unseen_f1 395 | 396 | def majority_baseline(): 397 | readfile = codecs.open(path+'zero-shot-split/test.txt', 'r', 'utf-8') 398 | gold_label_list = [] 399 | for line in readfile: 400 | gold_label_list.append(line.strip().split('\t')[0]) 401 | '''joy is the main emoion''' 402 | pred_label_list = ['joy'] *len(gold_label_list) 403 | # seen_labels = set(['sadness', 'joy', 'anger', 'disgust', 'fear', 'surprise', 'shame', 'guilt', 'love']) 404 | seen_types = set(['joy', 'disgust', 'surprise', 'guilt']) 405 | # f1_score_per_type = f1_score(gold_label_list, pred_label_list, labels = list(set(gold_label_list)), average='weighted') 406 | 407 | all_test_labels = list(set(gold_label_list)) 408 | f1_score_per_type = f1_score(gold_label_list, pred_label_list, labels = all_test_labels, average=None) 409 | 410 | seen_f1_accu = 0.0 411 | seen_size = 0 412 | unseen_f1_accu = 0.0 413 | unseen_size = 0 414 | for i in range(len(all_test_labels)): 415 | f1=f1_score_per_type[i] 416 | co = gold_label_list.count(all_test_labels[i]) 417 | if all_test_labels[i] in seen_types: 418 | seen_f1_accu+=f1*co 419 | seen_size+=co 420 | else: 421 | unseen_f1_accu+=f1*co 422 | unseen_size+=co 423 | 424 | 425 | 426 | 427 | seen_f1 = seen_f1_accu/(1e-6+seen_size) 428 | unseen_f1 = unseen_f1_accu/(1e-6+unseen_size) 429 | 430 | print('seen_f1:', seen_f1, 'unseen_f1:', unseen_f1) 431 | 432 | def emotion_f1_given_goldlist_and_predlist(gold_label_list, pred_label_list, seen_types_v0, seen_types_v1): 433 | 434 | # 
print('gold_label_list:', gold_label_list) 435 | # print('pred_label_list:', pred_label_list) 436 | all_test_labels = list(set(gold_label_list)) 437 | f1_score_per_type = f1_score(gold_label_list, pred_label_list, labels = all_test_labels, average=None) 438 | # print('f1_score_per_type:', f1_score_per_type) 439 | seen_f1_accu_v0 = 0.0 440 | seen_size_v0 = 0 441 | unseen_f1_accu_v0 = 0.0 442 | unseen_size_v0 = 0 443 | 444 | seen_f1_accu_v1 = 0.0 445 | seen_size_v1 = 0 446 | unseen_f1_accu_v1 = 0.0 447 | unseen_size_v1 = 0 448 | 449 | f1_accu = 0.0 450 | size_accu = 0 451 | for i in range(len(all_test_labels)): 452 | f1=f1_score_per_type[i] 453 | co = gold_label_list.count(all_test_labels[i]) 454 | # print('f1:', f1) 455 | # print('co:', co) 456 | 457 | f1_accu+=f1*co 458 | size_accu+=co 459 | 460 | if all_test_labels[i] in seen_types_v0: 461 | seen_f1_accu_v0+=f1*co 462 | seen_size_v0+=co 463 | else: 464 | unseen_f1_accu_v0+=f1*co 465 | unseen_size_v0+=co 466 | 467 | if all_test_labels[i] in seen_types_v1: 468 | seen_f1_accu_v1+=f1*co 469 | seen_size_v1+=co 470 | else: 471 | unseen_f1_accu_v1+=f1*co 472 | unseen_size_v1+=co 473 | 474 | 475 | v0 = (seen_f1_accu_v0/(1e-6+seen_size_v0), unseen_f1_accu_v0/(1e-6+unseen_size_v0)) 476 | v1 = (seen_f1_accu_v1/(1e-6+seen_size_v1), unseen_f1_accu_v1/(1e-6+unseen_size_v1)) 477 | all_f1 = f1_accu/(1e-6+size_accu) 478 | 479 | 480 | return v0, v1, all_f1 481 | 482 | 483 | def forfun(): 484 | readfile = codecs.open('/export/home/Dataset/Stuttgart_Emotion/unify-emotion-datasets-master/zero-shot-split/test.txt', 'r', 'utf-8') 485 | co=0 486 | for line in readfile: 487 | if line.strip().split('\t')[0] != 'noemo': 488 | co+=1 489 | else: 490 | print(co) #4685 491 | break 492 | readfile.close() 493 | if __name__ == '__main__': 494 | # statistics() 495 | # build_zeroshot_test_dev_set() 496 | # build_zeroshot_train_set() 497 | 498 | # majority_baseline() 499 | 500 | forfun() 501 | -------------------------------------------------------------------------------- /src/wikipedia/wikipedia-standard.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | import requests 4 | import time 5 | from bs4 import BeautifulSoup 6 | from datetime import datetime, timedelta 7 | from decimal import Decimal 8 | 9 | from .exceptions import ( 10 | PageError, DisambiguationError, RedirectError, HTTPTimeoutError, 11 | WikipediaException, ODD_ERROR_MESSAGE) 12 | from .util import cache, stdout_encode, debug 13 | import re 14 | 15 | API_URL = 'http://en.wikipedia.org/w/api.php' 16 | RATE_LIMIT = False 17 | RATE_LIMIT_MIN_WAIT = None 18 | RATE_LIMIT_LAST_CALL = None 19 | USER_AGENT = 'wikipedia (https://github.com/goldsmith/Wikipedia/)' 20 | 21 | 22 | def set_lang(prefix): 23 | ''' 24 | Change the language of the API being requested. 25 | Set `prefix` to one of the two letter prefixes found on the `list of all Wikipedias `_. 26 | 27 | After setting the language, the cache for ``search``, ``suggest``, and ``summary`` will be cleared. 28 | 29 | .. note:: Make sure you search for page titles in the language that you have set. 30 | ''' 31 | global API_URL 32 | API_URL = 'http://' + prefix.lower() + '.wikipedia.org/w/api.php' 33 | 34 | for cached_func in (search, suggest, summary): 35 | cached_func.clear_cache() 36 | 37 | 38 | def set_user_agent(user_agent_string): 39 | ''' 40 | Set the User-Agent string to be used for all requests. 
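    For example (a hypothetical identifier, not a required format):
    set_user_agent('MyResearchBot/0.1 (you@example.com)')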
41 | 42 | Arguments: 43 | 44 | * user_agent_string - (string) a string specifying the User-Agent header 45 | ''' 46 | global USER_AGENT 47 | USER_AGENT = user_agent_string 48 | 49 | 50 | def set_rate_limiting(rate_limit, min_wait=timedelta(milliseconds=50)): 51 | ''' 52 | Enable or disable rate limiting on requests to the Mediawiki servers. 53 | If rate limiting is not enabled, under some circumstances (depending on 54 | load on Wikipedia, the number of requests you and other `wikipedia` users 55 | are making, and other factors), Wikipedia may return an HTTP timeout error. 56 | 57 | Enabling rate limiting generally prevents that issue, but please note that 58 | HTTPTimeoutError still might be raised. 59 | 60 | Arguments: 61 | 62 | * rate_limit - (Boolean) whether to enable rate limiting or not 63 | 64 | Keyword arguments: 65 | 66 | * min_wait - if rate limiting is enabled, `min_wait` is a timedelta describing the minimum time to wait before requests. 67 | Defaults to timedelta(milliseconds=50) 68 | ''' 69 | global RATE_LIMIT 70 | global RATE_LIMIT_MIN_WAIT 71 | global RATE_LIMIT_LAST_CALL 72 | 73 | RATE_LIMIT = rate_limit 74 | if not rate_limit: 75 | RATE_LIMIT_MIN_WAIT = None 76 | else: 77 | RATE_LIMIT_MIN_WAIT = min_wait 78 | 79 | RATE_LIMIT_LAST_CALL = None 80 | 81 | 82 | @cache 83 | def search(query, results=10, suggestion=False): 84 | ''' 85 | Do a Wikipedia search for `query`. 86 | 87 | Keyword arguments: 88 | 89 | * results - the maxmimum number of results returned 90 | * suggestion - if True, return results and suggestion (if any) in a tuple 91 | ''' 92 | 93 | search_params = { 94 | 'list': 'search', 95 | 'srprop': '', 96 | 'srlimit': results, 97 | 'limit': results, 98 | 'srsearch': query 99 | } 100 | if suggestion: 101 | search_params['srinfo'] = 'suggestion' 102 | 103 | raw_results = _wiki_request(search_params) 104 | 105 | if 'error' in raw_results: 106 | if raw_results['error']['info'] in ('HTTP request timed out.', 'Pool queue is full'): 107 | raise HTTPTimeoutError(query) 108 | else: 109 | raise WikipediaException(raw_results['error']['info']) 110 | 111 | search_results = (d['title'] for d in raw_results['query']['search']) 112 | 113 | if suggestion: 114 | if raw_results['query'].get('searchinfo'): 115 | return list(search_results), raw_results['query']['searchinfo']['suggestion'] 116 | else: 117 | return list(search_results), None 118 | 119 | return list(search_results) 120 | 121 | 122 | @cache 123 | def geosearch(latitude, longitude, title=None, results=10, radius=1000): 124 | ''' 125 | Do a wikipedia geo search for `latitude` and `longitude` 126 | using HTTP API described in http://www.mediawiki.org/wiki/Extension:GeoData 127 | 128 | Arguments: 129 | 130 | * latitude (float or decimal.Decimal) 131 | * longitude (float or decimal.Decimal) 132 | 133 | Keyword arguments: 134 | 135 | * title - The title of an article to search for 136 | * results - the maximum number of results returned 137 | * radius - Search radius in meters. 
The value must be between 10 and 10000 138 | ''' 139 | 140 | search_params = { 141 | 'list': 'geosearch', 142 | 'gsradius': radius, 143 | 'gscoord': '{0}|{1}'.format(latitude, longitude), 144 | 'gslimit': results 145 | } 146 | if title: 147 | search_params['titles'] = title 148 | 149 | raw_results = _wiki_request(search_params) 150 | 151 | if 'error' in raw_results: 152 | if raw_results['error']['info'] in ('HTTP request timed out.', 'Pool queue is full'): 153 | raise HTTPTimeoutError('{0}|{1}'.format(latitude, longitude)) 154 | else: 155 | raise WikipediaException(raw_results['error']['info']) 156 | 157 | search_pages = raw_results['query'].get('pages', None) 158 | if search_pages: 159 | search_results = (v['title'] for k, v in search_pages.items() if k != '-1') 160 | else: 161 | search_results = (d['title'] for d in raw_results['query']['geosearch']) 162 | 163 | return list(search_results) 164 | 165 | 166 | @cache 167 | def suggest(query): 168 | ''' 169 | Get a Wikipedia search suggestion for `query`. 170 | Returns a string or None if no suggestion was found. 171 | ''' 172 | 173 | search_params = { 174 | 'list': 'search', 175 | 'srinfo': 'suggestion', 176 | 'srprop': '', 177 | } 178 | search_params['srsearch'] = query 179 | 180 | raw_result = _wiki_request(search_params) 181 | 182 | if raw_result['query'].get('searchinfo'): 183 | return raw_result['query']['searchinfo']['suggestion'] 184 | 185 | return None 186 | 187 | 188 | def random(pages=1): 189 | ''' 190 | Get a list of random Wikipedia article titles. 191 | 192 | .. note:: Random only gets articles from namespace 0, meaning no Category, User talk, or other meta-Wikipedia pages. 193 | 194 | Keyword arguments: 195 | 196 | * pages - the number of random pages returned (max of 10) 197 | ''' 198 | #http://en.wikipedia.org/w/api.php?action=query&list=random&rnlimit=5000&format=jsonfm 199 | query_params = { 200 | 'list': 'random', 201 | 'rnnamespace': 0, 202 | 'rnlimit': pages, 203 | } 204 | 205 | request = _wiki_request(query_params) 206 | titles = [page['title'] for page in request['query']['random']] 207 | 208 | if len(titles) == 1: 209 | return titles[0] 210 | 211 | return titles 212 | 213 | 214 | @cache 215 | def summary(title, sentences=0, chars=0, auto_suggest=True, redirect=True): 216 | ''' 217 | Plain text summary of the page. 218 | 219 | .. note:: This is a convenience wrapper - auto_suggest and redirect are enabled by default 220 | 221 | Keyword arguments: 222 | 223 | * sentences - if set, return the first `sentences` sentences (can be no greater than 10). 224 | * chars - if set, return only the first `chars` characters (actual text returned may be slightly longer). 
225 | * auto_suggest - let Wikipedia find a valid page title for the query 226 | * redirect - allow redirection without raising RedirectError 227 | ''' 228 | 229 | # use auto_suggest and redirect to get the correct article 230 | # also, use page's error checking to raise DisambiguationError if necessary 231 | page_info = page(title, auto_suggest=auto_suggest, redirect=redirect) 232 | title = page_info.title 233 | pageid = page_info.pageid 234 | 235 | query_params = { 236 | 'prop': 'extracts', 237 | 'explaintext': '', 238 | 'titles': title 239 | } 240 | 241 | if sentences: 242 | query_params['exsentences'] = sentences 243 | elif chars: 244 | query_params['exchars'] = chars 245 | else: 246 | query_params['exintro'] = '' 247 | 248 | request = _wiki_request(query_params) #the core step to retrieve from wikipedia given query parameters 249 | summary = request['query']['pages'][pageid]['extract'] 250 | 251 | return summary 252 | 253 | 254 | def page(title=None, pageid=None, auto_suggest=True, redirect=True, preload=False): 255 | ''' 256 | Get a WikipediaPage object for the page with title `title` or the pageid 257 | `pageid` (mutually exclusive). 258 | 259 | Keyword arguments: 260 | 261 | * title - the title of the page to load 262 | * pageid - the numeric pageid of the page to load 263 | * auto_suggest - let Wikipedia find a valid page title for the query 264 | * redirect - allow redirection without raising RedirectError 265 | * preload - load content, summary, images, references, and links during initialization 266 | ''' 267 | 268 | if title is not None: 269 | if auto_suggest: 270 | results, suggestion = search(title, results=1, suggestion=True) 271 | try: 272 | title = suggestion or results[0] 273 | except IndexError: 274 | # if there is no suggestion or search results, the page doesn't exist 275 | raise PageError(title) 276 | return WikipediaPage(title, redirect=redirect, preload=preload) 277 | elif pageid is not None: 278 | return WikipediaPage(pageid=pageid, preload=preload) 279 | else: 280 | raise ValueError("Either a title or a pageid must be specified") 281 | 282 | 283 | 284 | class WikipediaPage(object): 285 | ''' 286 | Contains data from a Wikipedia page. 287 | Uses property methods to filter data from the raw HTML. 288 | ''' 289 | 290 | def __init__(self, title=None, pageid=None, redirect=True, preload=False, original_title=''): 291 | if title is not None: 292 | self.title = title 293 | self.original_title = original_title or title 294 | elif pageid is not None: 295 | self.pageid = pageid 296 | else: 297 | raise ValueError("Either a title or a pageid must be specified") 298 | 299 | self.__load(redirect=redirect, preload=preload) 300 | 301 | if preload: 302 | for prop in ('content', 'summary', 'images', 'references', 'links', 'sections'): 303 | getattr(self, prop) 304 | 305 | def __repr__(self): 306 | return stdout_encode(u''.format(self.title)) 307 | 308 | def __eq__(self, other): 309 | try: 310 | return ( 311 | self.pageid == other.pageid 312 | and self.title == other.title 313 | and self.url == other.url 314 | ) 315 | except: 316 | return False 317 | 318 | def __load(self, redirect=True, preload=False): 319 | ''' 320 | Load basic information from Wikipedia. 321 | Confirm that page exists and is not a disambiguation/redirect. 322 | 323 | Does not need to be called manually, should be called automatically during __init__. 
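        Depending on the API response, this can raise PageError (missing page),
        RedirectError (a redirect was found while `redirect` is False), or
        DisambiguationError (the title resolves to a disambiguation page), as
        implemented below.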
324 | ''' 325 | query_params = { 326 | 'prop': 'info|pageprops', 327 | 'inprop': 'url', 328 | 'ppprop': 'disambiguation', 329 | 'redirects': '', 330 | } 331 | if not getattr(self, 'pageid', None): 332 | query_params['titles'] = self.title 333 | else: 334 | query_params['pageids'] = self.pageid 335 | 336 | request = _wiki_request(query_params) 337 | 338 | query = request['query'] 339 | pageid = list(query['pages'].keys())[0] 340 | page = query['pages'][pageid] 341 | 342 | # missing is present if the page is missing 343 | if 'missing' in page: 344 | if hasattr(self, 'title'): 345 | raise PageError(self.title) 346 | else: 347 | raise PageError(pageid=self.pageid) 348 | 349 | # same thing for redirect, except it shows up in query instead of page for 350 | # whatever silly reason 351 | elif 'redirects' in query: 352 | if redirect: 353 | redirects = query['redirects'][0] 354 | 355 | if 'normalized' in query: 356 | normalized = query['normalized'][0] 357 | assert normalized['from'] == self.title, ODD_ERROR_MESSAGE 358 | 359 | from_title = normalized['to'] 360 | 361 | else: 362 | from_title = self.title 363 | 364 | assert redirects['from'] == from_title, ODD_ERROR_MESSAGE 365 | 366 | # change the title and reload the whole object 367 | self.__init__(redirects['to'], redirect=redirect, preload=preload) 368 | 369 | else: 370 | raise RedirectError(getattr(self, 'title', page['title'])) 371 | 372 | # since we only asked for disambiguation in ppprop, 373 | # if a pageprop is returned, 374 | # then the page must be a disambiguation page 375 | elif 'pageprops' in page: 376 | query_params = { 377 | 'prop': 'revisions', 378 | 'rvprop': 'content', 379 | 'rvparse': '', 380 | 'rvlimit': 1 381 | } 382 | if hasattr(self, 'pageid'): 383 | query_params['pageids'] = self.pageid 384 | else: 385 | query_params['titles'] = self.title 386 | request = _wiki_request(query_params) 387 | html = request['query']['pages'][pageid]['revisions'][0]['*'] 388 | 389 | lis = BeautifulSoup(html, 'html.parser').find_all('li') 390 | filtered_lis = [li for li in lis if not 'tocsection' in ''.join(li.get('class', []))] 391 | may_refer_to = [li.a.get_text() for li in filtered_lis if li.a] 392 | 393 | raise DisambiguationError(getattr(self, 'title', page['title']), may_refer_to) 394 | 395 | else: 396 | self.pageid = pageid 397 | self.title = page['title'] 398 | self.url = page['fullurl'] 399 | 400 | def __continued_query(self, query_params): 401 | ''' 402 | Based on https://www.mediawiki.org/wiki/API:Query#Continuing_queries 403 | ''' 404 | query_params.update(self.__title_query_param) 405 | 406 | last_continue = {} 407 | prop = query_params.get('prop', None) 408 | 409 | while True: 410 | params = query_params.copy() 411 | params.update(last_continue) 412 | 413 | request = _wiki_request(params) 414 | 415 | if 'query' not in request: 416 | break 417 | 418 | pages = request['query']['pages'] 419 | if 'generator' in query_params: 420 | for datum in pages.values(): # in python 3.3+: "yield from pages.values()" 421 | yield datum 422 | else: 423 | for datum in pages[self.pageid][prop]: 424 | yield datum 425 | 426 | if 'continue' not in request: 427 | break 428 | 429 | last_continue = request['continue'] 430 | 431 | @property 432 | def __title_query_param(self): 433 | if getattr(self, 'title', None) is not None: 434 | return {'titles': self.title} 435 | else: 436 | return {'pageids': self.pageid} 437 | 438 | def html(self): 439 | ''' 440 | Get full page HTML. 441 | 442 | .. warning:: This can get pretty slow on long pages. 
443 | ''' 444 | 445 | if not getattr(self, '_html', False): 446 | query_params = { 447 | 'prop': 'revisions', 448 | 'rvprop': 'content', 449 | 'rvlimit': 1, 450 | 'rvparse': '', 451 | 'titles': self.title 452 | } 453 | 454 | request = _wiki_request(query_params) 455 | self._html = request['query']['pages'][self.pageid]['revisions'][0]['*'] 456 | 457 | return self._html 458 | 459 | @property 460 | def content(self): 461 | ''' 462 | Plain text content of the page, excluding images, tables, and other data. 463 | ''' 464 | 465 | if not getattr(self, '_content', False): 466 | query_params = { 467 | 'prop': 'extracts|revisions', 468 | 'explaintext': '', 469 | 'rvprop': 'ids' 470 | } 471 | if not getattr(self, 'title', None) is None: 472 | query_params['titles'] = self.title 473 | else: 474 | query_params['pageids'] = self.pageid 475 | request = _wiki_request(query_params) 476 | self._content = request['query']['pages'][self.pageid]['extract'] 477 | self._revision_id = request['query']['pages'][self.pageid]['revisions'][0]['revid'] 478 | self._parent_id = request['query']['pages'][self.pageid]['revisions'][0]['parentid'] 479 | 480 | return self._content 481 | 482 | @property 483 | def revision_id(self): 484 | ''' 485 | Revision ID of the page. 486 | 487 | The revision ID is a number that uniquely identifies the current 488 | version of the page. It can be used to create the permalink or for 489 | other direct API calls. See `Help:Page history 490 | `_ for more 491 | information. 492 | ''' 493 | 494 | if not getattr(self, '_revid', False): 495 | # fetch the content (side effect is loading the revid) 496 | self.content 497 | 498 | return self._revision_id 499 | 500 | @property 501 | def parent_id(self): 502 | ''' 503 | Revision ID of the parent version of the current revision of this 504 | page. See ``revision_id`` for more information. 505 | ''' 506 | 507 | if not getattr(self, '_parentid', False): 508 | # fetch the content (side effect is loading the revid) 509 | self.content 510 | 511 | return self._parent_id 512 | 513 | @property 514 | def summary(self): 515 | ''' 516 | Plain text summary of the page. 517 | ''' 518 | 519 | if not getattr(self, '_summary', False): 520 | query_params = { 521 | 'prop': 'extracts', 522 | 'explaintext': '', 523 | 'exintro': '', 524 | } 525 | if not getattr(self, 'title', None) is None: 526 | query_params['titles'] = self.title 527 | else: 528 | query_params['pageids'] = self.pageid 529 | 530 | request = _wiki_request(query_params) 531 | self._summary = request['query']['pages'][self.pageid]['extract'] 532 | 533 | return self._summary 534 | 535 | @property 536 | def images(self): 537 | ''' 538 | List of URLs of images on the page. 
539 | ''' 540 | 541 | if not getattr(self, '_images', False): 542 | self._images = [ 543 | page['imageinfo'][0]['url'] 544 | for page in self.__continued_query({ 545 | 'generator': 'images', 546 | 'gimlimit': 'max', 547 | 'prop': 'imageinfo', 548 | 'iiprop': 'url', 549 | }) 550 | if 'imageinfo' in page 551 | ] 552 | 553 | return self._images 554 | 555 | @property 556 | def coordinates(self): 557 | ''' 558 | Tuple of Decimals in the form of (lat, lon) or None 559 | ''' 560 | if not getattr(self, '_coordinates', False): 561 | query_params = { 562 | 'prop': 'coordinates', 563 | 'colimit': 'max', 564 | 'titles': self.title, 565 | } 566 | 567 | request = _wiki_request(query_params) 568 | 569 | if 'query' in request: 570 | coordinates = request['query']['pages'][self.pageid]['coordinates'] 571 | self._coordinates = (Decimal(coordinates[0]['lat']), Decimal(coordinates[0]['lon'])) 572 | else: 573 | self._coordinates = None 574 | 575 | return self._coordinates 576 | 577 | @property 578 | def references(self): 579 | ''' 580 | List of URLs of external links on a page. 581 | May include external links within page that aren't technically cited anywhere. 582 | ''' 583 | 584 | if not getattr(self, '_references', False): 585 | def add_protocol(url): 586 | return url if url.startswith('http') else 'http:' + url 587 | 588 | self._references = [ 589 | add_protocol(link['*']) 590 | for link in self.__continued_query({ 591 | 'prop': 'extlinks', 592 | 'ellimit': 'max' 593 | }) 594 | ] 595 | 596 | return self._references 597 | 598 | @property 599 | def links(self): 600 | ''' 601 | List of titles of Wikipedia page links on a page. 602 | 603 | .. note:: Only includes articles from namespace 0, meaning no Category, User talk, or other meta-Wikipedia pages. 604 | ''' 605 | 606 | if not getattr(self, '_links', False): 607 | self._links = [ 608 | link['title'] 609 | for link in self.__continued_query({ 610 | 'prop': 'links', 611 | 'plnamespace': 0, 612 | 'pllimit': 'max' 613 | }) 614 | ] 615 | 616 | return self._links 617 | 618 | @property 619 | def categories(self): 620 | ''' 621 | List of categories of a page. 622 | ''' 623 | 624 | if not getattr(self, '_categories', False): 625 | self._categories = [re.sub(r'^Category:', '', x) for x in 626 | [link['title'] 627 | for link in self.__continued_query({ 628 | 'prop': 'categories', 629 | 'cllimit': 'max' 630 | }) 631 | ]] 632 | 633 | return self._categories 634 | 635 | @property 636 | def sections(self): 637 | ''' 638 | List of section titles from the table of contents on the page. 639 | ''' 640 | 641 | if not getattr(self, '_sections', False): 642 | query_params = { 643 | 'action': 'parse', 644 | 'prop': 'sections', 645 | } 646 | query_params.update(self.__title_query_param) 647 | 648 | request = _wiki_request(query_params) 649 | self._sections = [section['line'] for section in request['parse']['sections']] 650 | 651 | return self._sections 652 | 653 | def section(self, section_title): 654 | ''' 655 | Get the plain text content of a section from `self.sections`. 656 | Returns None if `section_title` isn't found, otherwise returns a whitespace stripped string. 657 | 658 | This is a convenience method that wraps self.content. 659 | 660 | .. warning:: Calling `section` on a section that has subheadings will NOT return 661 | the full text of all of the subsections. It only gets the text between 662 | `section_title` and the next subheading, which is often empty. 
663 | ''' 664 | 665 | section = u"== {} ==".format(section_title) 666 | try: 667 | index = self.content.index(section) + len(section) 668 | except ValueError: 669 | return None 670 | 671 | try: 672 | next_index = self.content.index("==", index) 673 | except ValueError: 674 | next_index = len(self.content) 675 | 676 | return self.content[index:next_index].lstrip("=").strip() 677 | 678 | 679 | @cache 680 | def languages(): 681 | ''' 682 | List all the currently supported language prefixes (usually ISO language code). 683 | 684 | Can be inputted to `set_lang` to change the Mediawiki that `wikipedia` requests 685 | results from. 686 | 687 | Returns: dict of : pairs. To get just a list of prefixes, 688 | use `wikipedia.languages().keys()`. 689 | ''' 690 | response = _wiki_request({ 691 | 'meta': 'siteinfo', 692 | 'siprop': 'languages' 693 | }) 694 | 695 | languages = response['query']['languages'] 696 | 697 | return { 698 | lang['code']: lang['*'] 699 | for lang in languages 700 | } 701 | 702 | 703 | def donate(): 704 | ''' 705 | Open up the Wikimedia donate page in your favorite browser. 706 | ''' 707 | import webbrowser 708 | 709 | webbrowser.open('https://donate.wikimedia.org/w/index.php?title=Special:FundraiserLandingPage', new=2) 710 | 711 | 712 | def _wiki_request(params): 713 | ''' 714 | Make a request to the Wikipedia API using the given search parameters. 715 | Returns a parsed dict of the JSON response. 716 | ''' 717 | global RATE_LIMIT_LAST_CALL 718 | global USER_AGENT 719 | 720 | params['format'] = 'json' 721 | if not 'action' in params: 722 | params['action'] = 'query' 723 | 724 | headers = { 725 | 'User-Agent': USER_AGENT 726 | } 727 | 728 | if RATE_LIMIT and RATE_LIMIT_LAST_CALL and \ 729 | RATE_LIMIT_LAST_CALL + RATE_LIMIT_MIN_WAIT > datetime.now(): 730 | 731 | # it hasn't been long enough since the last API call 732 | # so wait until we're in the clear to make the request 733 | 734 | wait_time = (RATE_LIMIT_LAST_CALL + RATE_LIMIT_MIN_WAIT) - datetime.now() 735 | time.sleep(int(wait_time.total_seconds())) 736 | 737 | r = requests.get(API_URL, params=params, headers=headers) 738 | 739 | if RATE_LIMIT: 740 | RATE_LIMIT_LAST_CALL = datetime.now() 741 | 742 | return r.json() 743 | -------------------------------------------------------------------------------- /src/wikipedia/wikipedia.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | import requests 4 | import time 5 | from bs4 import BeautifulSoup 6 | from datetime import datetime, timedelta 7 | from decimal import Decimal 8 | 9 | from .exceptions import ( 10 | PageError, DisambiguationError, RedirectError, HTTPTimeoutError, 11 | WikipediaException, ODD_ERROR_MESSAGE) 12 | from .util import cache, stdout_encode, debug 13 | import re 14 | 15 | API_URL = 'http://en.wikipedia.org/w/api.php' 16 | RATE_LIMIT = False 17 | RATE_LIMIT_MIN_WAIT = None 18 | RATE_LIMIT_LAST_CALL = None 19 | USER_AGENT = 'wikipedia (https://github.com/goldsmith/Wikipedia/)' 20 | 21 | 22 | def set_lang(prefix): 23 | ''' 24 | Change the language of the API being requested. 25 | Set `prefix` to one of the two letter prefixes found on the `list of all Wikipedias `_. 26 | 27 | After setting the language, the cache for ``search``, ``suggest``, and ``summary`` will be cleared. 28 | 29 | .. note:: Make sure you search for page titles in the language that you have set. 
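For example, set_lang('fr') points subsequent requests at http://fr.wikipedia.org/w/api.php (see the assignment to API_URL below) and clears the caches of `search`, `suggest`, and `summary`.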
30 | ''' 31 | global API_URL 32 | API_URL = 'http://' + prefix.lower() + '.wikipedia.org/w/api.php' 33 | 34 | for cached_func in (search, suggest, summary): 35 | cached_func.clear_cache() 36 | 37 | 38 | def set_user_agent(user_agent_string): 39 | ''' 40 | Set the User-Agent string to be used for all requests. 41 | 42 | Arguments: 43 | 44 | * user_agent_string - (string) a string specifying the User-Agent header 45 | ''' 46 | global USER_AGENT 47 | USER_AGENT = user_agent_string 48 | 49 | 50 | def set_rate_limiting(rate_limit, min_wait=timedelta(milliseconds=50)): 51 | ''' 52 | Enable or disable rate limiting on requests to the Mediawiki servers. 53 | If rate limiting is not enabled, under some circumstances (depending on 54 | load on Wikipedia, the number of requests you and other `wikipedia` users 55 | are making, and other factors), Wikipedia may return an HTTP timeout error. 56 | 57 | Enabling rate limiting generally prevents that issue, but please note that 58 | HTTPTimeoutError still might be raised. 59 | 60 | Arguments: 61 | 62 | * rate_limit - (Boolean) whether to enable rate limiting or not 63 | 64 | Keyword arguments: 65 | 66 | * min_wait - if rate limiting is enabled, `min_wait` is a timedelta describing the minimum time to wait before requests. 67 | Defaults to timedelta(milliseconds=50) 68 | ''' 69 | global RATE_LIMIT 70 | global RATE_LIMIT_MIN_WAIT 71 | global RATE_LIMIT_LAST_CALL 72 | 73 | RATE_LIMIT = rate_limit 74 | if not rate_limit: 75 | RATE_LIMIT_MIN_WAIT = None 76 | else: 77 | RATE_LIMIT_MIN_WAIT = min_wait 78 | 79 | RATE_LIMIT_LAST_CALL = None 80 | 81 | 82 | @cache 83 | def search(query, results=10, suggestion=False): 84 | ''' 85 | Do a Wikipedia search for `query`. 86 | 87 | Keyword arguments: 88 | 89 | * results - the maxmimum number of results returned 90 | * suggestion - if True, return results and suggestion (if any) in a tuple 91 | ''' 92 | 93 | search_params = { 94 | 'list': 'search', 95 | 'srprop': '', 96 | 'srlimit': results, 97 | 'limit': results, 98 | 'srsearch': query 99 | } 100 | if suggestion: 101 | search_params['srinfo'] = 'suggestion' 102 | 103 | raw_results = _wiki_request(search_params) 104 | 105 | if 'error' in raw_results: 106 | if raw_results['error']['info'] in ('HTTP request timed out.', 'Pool queue is full'): 107 | raise HTTPTimeoutError(query) 108 | else: 109 | raise WikipediaException(raw_results['error']['info']) 110 | 111 | search_results = (d['title'] for d in raw_results['query']['search']) 112 | 113 | if suggestion: 114 | if raw_results['query'].get('searchinfo'): 115 | return list(search_results), raw_results['query']['searchinfo']['suggestion'] 116 | else: 117 | return list(search_results), None 118 | 119 | return list(search_results) 120 | 121 | 122 | @cache 123 | def geosearch(latitude, longitude, title=None, results=10, radius=1000): 124 | ''' 125 | Do a wikipedia geo search for `latitude` and `longitude` 126 | using HTTP API described in http://www.mediawiki.org/wiki/Extension:GeoData 127 | 128 | Arguments: 129 | 130 | * latitude (float or decimal.Decimal) 131 | * longitude (float or decimal.Decimal) 132 | 133 | Keyword arguments: 134 | 135 | * title - The title of an article to search for 136 | * results - the maximum number of results returned 137 | * radius - Search radius in meters. 
The value must be between 10 and 10000 138 | ''' 139 | 140 | search_params = { 141 | 'list': 'geosearch', 142 | 'gsradius': radius, 143 | 'gscoord': '{0}|{1}'.format(latitude, longitude), 144 | 'gslimit': results 145 | } 146 | if title: 147 | search_params['titles'] = title 148 | 149 | raw_results = _wiki_request(search_params) 150 | 151 | if 'error' in raw_results: 152 | if raw_results['error']['info'] in ('HTTP request timed out.', 'Pool queue is full'): 153 | raise HTTPTimeoutError('{0}|{1}'.format(latitude, longitude)) 154 | else: 155 | raise WikipediaException(raw_results['error']['info']) 156 | 157 | search_pages = raw_results['query'].get('pages', None) 158 | if search_pages: 159 | search_results = (v['title'] for k, v in search_pages.items() if k != '-1') 160 | else: 161 | search_results = (d['title'] for d in raw_results['query']['geosearch']) 162 | 163 | return list(search_results) 164 | 165 | 166 | @cache 167 | def suggest(query): 168 | ''' 169 | Get a Wikipedia search suggestion for `query`. 170 | Returns a string or None if no suggestion was found. 171 | ''' 172 | 173 | search_params = { 174 | 'list': 'search', 175 | 'srinfo': 'suggestion', 176 | 'srprop': '', 177 | } 178 | search_params['srsearch'] = query 179 | 180 | raw_result = _wiki_request(search_params) 181 | 182 | if raw_result['query'].get('searchinfo'): 183 | return raw_result['query']['searchinfo']['suggestion'] 184 | 185 | return None 186 | 187 | 188 | def random(pages=1): 189 | ''' 190 | Get a list of random Wikipedia article titles. 191 | 192 | .. note:: Random only gets articles from namespace 0, meaning no Category, User talk, or other meta-Wikipedia pages. 193 | 194 | Keyword arguments: 195 | 196 | * pages - the number of random pages returned (max of 10) 197 | ''' 198 | #http://en.wikipedia.org/w/api.php?action=query&list=random&rnlimit=5000&format=jsonfm 199 | query_params = { 200 | 'list': 'random', 201 | 'rnnamespace': 0, 202 | 'rnlimit': pages, 203 | } 204 | 205 | request = _wiki_request(query_params) 206 | titles = [page['title'] for page in request['query']['random']] 207 | 208 | if len(titles) == 1: 209 | return titles[0] 210 | 211 | return titles 212 | 213 | 214 | @cache 215 | def summary(title, sentences=0, chars=0, auto_suggest=True, redirect=True): 216 | ''' 217 | Plain text summary of the page. 218 | 219 | .. note:: This is a convenience wrapper - auto_suggest and redirect are enabled by default 220 | 221 | Keyword arguments: 222 | 223 | * sentences - if set, return the first `sentences` sentences (can be no greater than 10). 224 | * chars - if set, return only the first `chars` characters (actual text returned may be slightly longer). 225 | * auto_suggest - let Wikipedia find a valid page title for the query 226 | * redirect - allow redirection without raising RedirectError 227 | ''' 228 | 229 | # use auto_suggest and redirect to get the correct article 230 | # also, use page's error checking to raise DisambiguationError if necessary 231 | page_info = page(title, auto_suggest=auto_suggest, redirect=redirect) 232 | if page_info == 'invalidreturn': 233 | return '............' 234 | else: 235 | if page_info.valid is False: 236 | return '............' 
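# NOTE (an inference about intent, not stated in the code): this forked copy signals a
# failed lookup by returning 'invalidreturn' from page() and the placeholder string
# '............' here, instead of raising PageError as the unmodified version of this
# module does, presumably so that large batch jobs can skip bad titles without wrapping
# every call in try/except.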
237 | else: 238 | title = page_info.title 239 | pageid = page_info.pageid 240 | 241 | query_params = { 242 | 'prop': 'extracts', 243 | 'explaintext': '', 244 | 'titles': title 245 | } 246 | 247 | if sentences: 248 | query_params['exsentences'] = sentences 249 | elif chars: 250 | query_params['exchars'] = chars 251 | else: 252 | query_params['exintro'] = '' 253 | 254 | request = _wiki_request(query_params) #the core step to retrieve from wikipedia given query parameters 255 | summary = request['query']['pages'][pageid]['extract'] 256 | 257 | return summary 258 | 259 | 260 | def page(title=None, pageid=None, auto_suggest=True, redirect=True, preload=False): 261 | ''' 262 | Get a WikipediaPage object for the page with title `title` or the pageid 263 | `pageid` (mutually exclusive). 264 | 265 | Keyword arguments: 266 | 267 | * title - the title of the page to load 268 | * pageid - the numeric pageid of the page to load 269 | * auto_suggest - let Wikipedia find a valid page title for the query 270 | * redirect - allow redirection without raising RedirectError 271 | * preload - load content, summary, images, references, and links during initialization 272 | ''' 273 | 274 | if title is not None: 275 | if auto_suggest: 276 | results, suggestion = search(title, results=1, suggestion=True) 277 | try: 278 | title = suggestion or results[0] 279 | except IndexError: 280 | # if there is no suggestion or search results, the page doesn't exist 281 | return 'invalidreturn' 282 | return WikipediaPage(title, redirect=redirect, preload=preload) 283 | elif pageid is not None: 284 | return WikipediaPage(pageid=pageid, preload=preload) 285 | else: 286 | return 'invalidreturn' 287 | 288 | 289 | 290 | class WikipediaPage(object): 291 | ''' 292 | Contains data from a Wikipedia page. 293 | Uses property methods to filter data from the raw HTML. 294 | ''' 295 | 296 | def __init__(self, title=None, pageid=None, redirect=True, preload=False, original_title=''): 297 | self.valid=True 298 | if title is not None: 299 | self.title = title 300 | self.original_title = original_title or title 301 | elif pageid is not None: 302 | self.pageid = pageid 303 | else: 304 | self.valid=False 305 | 306 | self.__load(redirect=redirect, preload=preload) 307 | 308 | if preload: 309 | for prop in ('content', 'summary', 'images', 'references', 'links', 'sections'): 310 | getattr(self, prop) 311 | 312 | def __repr__(self): 313 | return stdout_encode(u''.format(self.title)) 314 | 315 | def __eq__(self, other): 316 | try: 317 | return ( 318 | self.pageid == other.pageid 319 | and self.title == other.title 320 | and self.url == other.url 321 | ) 322 | except: 323 | return False 324 | 325 | def __load(self, redirect=True, preload=False): 326 | ''' 327 | Load basic information from Wikipedia. 328 | Confirm that page exists and is not a disambiguation/redirect. 329 | 330 | Does not need to be called manually, should be called automatically during __init__. 
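In this modified copy, failure cases (missing page, unexpected redirect, disambiguation page) set `self.valid = False` instead of raising PageError, RedirectError, or DisambiguationError; callers such as summary() are expected to check the `valid` flag.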
331 | ''' 332 | query_params = { 333 | 'prop': 'info|pageprops', 334 | 'inprop': 'url', 335 | 'ppprop': 'disambiguation', 336 | 'redirects': '', 337 | } 338 | if not getattr(self, 'pageid', None): 339 | query_params['titles'] = self.title 340 | else: 341 | query_params['pageids'] = self.pageid 342 | 343 | request = _wiki_request(query_params) 344 | 345 | query = request['query'] 346 | pageid = list(query['pages'].keys())[0] 347 | page = query['pages'][pageid] 348 | 349 | # missing is present if the page is missing 350 | if 'missing' in page: 351 | if hasattr(self, 'title'): 352 | self.valid=False 353 | else: 354 | self.valid=False 355 | 356 | # same thing for redirect, except it shows up in query instead of page for 357 | # whatever silly reason 358 | elif 'redirects' in query: 359 | if redirect: 360 | redirects = query['redirects'][0] 361 | 362 | if 'normalized' in query: 363 | normalized = query['normalized'][0] 364 | assert normalized['from'] == self.title, ODD_ERROR_MESSAGE 365 | 366 | from_title = normalized['to'] 367 | 368 | else: 369 | from_title = self.title 370 | 371 | assert redirects['from'] == from_title, ODD_ERROR_MESSAGE 372 | 373 | # change the title and reload the whole object 374 | self.__init__(redirects['to'], redirect=redirect, preload=preload) 375 | 376 | else: 377 | self.valid=False 378 | 379 | # since we only asked for disambiguation in ppprop, 380 | # if a pageprop is returned, 381 | # then the page must be a disambiguation page 382 | elif 'pageprops' in page: 383 | query_params = { 384 | 'prop': 'revisions', 385 | 'rvprop': 'content', 386 | 'rvparse': '', 387 | 'rvlimit': 1 388 | } 389 | if hasattr(self, 'pageid'): 390 | query_params['pageids'] = self.pageid 391 | else: 392 | query_params['titles'] = self.title 393 | request = _wiki_request(query_params) 394 | html = request['query']['pages'][pageid]['revisions'][0]['*'] 395 | 396 | lis = BeautifulSoup(html, 'html.parser').find_all('li') 397 | filtered_lis = [li for li in lis if not 'tocsection' in ''.join(li.get('class', []))] 398 | may_refer_to = [li.a.get_text() for li in filtered_lis if li.a] 399 | 400 | self.valid=False 401 | 402 | else: 403 | self.pageid = pageid 404 | self.title = page['title'] 405 | self.url = page['fullurl'] 406 | 407 | def __continued_query(self, query_params): 408 | ''' 409 | Based on https://www.mediawiki.org/wiki/API:Query#Continuing_queries 410 | ''' 411 | query_params.update(self.__title_query_param) 412 | 413 | last_continue = {} 414 | prop = query_params.get('prop', None) 415 | 416 | while True: 417 | params = query_params.copy() 418 | params.update(last_continue) 419 | 420 | request = _wiki_request(params) 421 | 422 | if 'query' not in request: 423 | break 424 | 425 | pages = request['query']['pages'] 426 | if 'generator' in query_params: 427 | for datum in pages.values(): # in python 3.3+: "yield from pages.values()" 428 | yield datum 429 | else: 430 | for datum in pages[self.pageid][prop]: 431 | yield datum 432 | 433 | if 'continue' not in request: 434 | break 435 | 436 | last_continue = request['continue'] 437 | 438 | @property 439 | def __title_query_param(self): 440 | if getattr(self, 'title', None) is not None: 441 | return {'titles': self.title} 442 | else: 443 | return {'pageids': self.pageid} 444 | 445 | def html(self): 446 | ''' 447 | Get full page HTML. 448 | 449 | .. warning:: This can get pretty slow on long pages. 
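As with the other accessors in this class (content, summary, images, references, links, categories, sections), the value is fetched lazily on first access and cached in an underscore-prefixed attribute.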
450 | ''' 451 | 452 | if not getattr(self, '_html', False): 453 | query_params = { 454 | 'prop': 'revisions', 455 | 'rvprop': 'content', 456 | 'rvlimit': 1, 457 | 'rvparse': '', 458 | 'titles': self.title 459 | } 460 | 461 | request = _wiki_request(query_params) 462 | self._html = request['query']['pages'][self.pageid]['revisions'][0]['*'] 463 | 464 | return self._html 465 | 466 | @property 467 | def content(self): 468 | ''' 469 | Plain text content of the page, excluding images, tables, and other data. 470 | ''' 471 | 472 | if not getattr(self, '_content', False): 473 | query_params = { 474 | 'prop': 'extracts|revisions', 475 | 'explaintext': '', 476 | 'rvprop': 'ids' 477 | } 478 | if not getattr(self, 'title', None) is None: 479 | query_params['titles'] = self.title 480 | else: 481 | query_params['pageids'] = self.pageid 482 | request = _wiki_request(query_params) 483 | self._content = request['query']['pages'][self.pageid]['extract'] 484 | self._revision_id = request['query']['pages'][self.pageid]['revisions'][0]['revid'] 485 | self._parent_id = request['query']['pages'][self.pageid]['revisions'][0]['parentid'] 486 | 487 | return self._content 488 | 489 | @property 490 | def revision_id(self): 491 | ''' 492 | Revision ID of the page. 493 | 494 | The revision ID is a number that uniquely identifies the current 495 | version of the page. It can be used to create the permalink or for 496 | other direct API calls. See `Help:Page history 497 | `_ for more 498 | information. 499 | ''' 500 | 501 | if not getattr(self, '_revid', False): 502 | # fetch the content (side effect is loading the revid) 503 | self.content 504 | 505 | return self._revision_id 506 | 507 | @property 508 | def parent_id(self): 509 | ''' 510 | Revision ID of the parent version of the current revision of this 511 | page. See ``revision_id`` for more information. 512 | ''' 513 | 514 | if not getattr(self, '_parentid', False): 515 | # fetch the content (side effect is loading the revid) 516 | self.content 517 | 518 | return self._parent_id 519 | 520 | @property 521 | def summary(self): 522 | ''' 523 | Plain text summary of the page. 524 | ''' 525 | 526 | if not getattr(self, '_summary', False): 527 | query_params = { 528 | 'prop': 'extracts', 529 | 'explaintext': '', 530 | 'exintro': '', 531 | } 532 | if not getattr(self, 'title', None) is None: 533 | query_params['titles'] = self.title 534 | else: 535 | query_params['pageids'] = self.pageid 536 | 537 | request = _wiki_request(query_params) 538 | self._summary = request['query']['pages'][self.pageid]['extract'] 539 | 540 | return self._summary 541 | 542 | @property 543 | def images(self): 544 | ''' 545 | List of URLs of images on the page. 
546 | ''' 547 | 548 | if not getattr(self, '_images', False): 549 | self._images = [ 550 | page['imageinfo'][0]['url'] 551 | for page in self.__continued_query({ 552 | 'generator': 'images', 553 | 'gimlimit': 'max', 554 | 'prop': 'imageinfo', 555 | 'iiprop': 'url', 556 | }) 557 | if 'imageinfo' in page 558 | ] 559 | 560 | return self._images 561 | 562 | @property 563 | def coordinates(self): 564 | ''' 565 | Tuple of Decimals in the form of (lat, lon) or None 566 | ''' 567 | if not getattr(self, '_coordinates', False): 568 | query_params = { 569 | 'prop': 'coordinates', 570 | 'colimit': 'max', 571 | 'titles': self.title, 572 | } 573 | 574 | request = _wiki_request(query_params) 575 | 576 | if 'query' in request: 577 | coordinates = request['query']['pages'][self.pageid]['coordinates'] 578 | self._coordinates = (Decimal(coordinates[0]['lat']), Decimal(coordinates[0]['lon'])) 579 | else: 580 | self._coordinates = None 581 | 582 | return self._coordinates 583 | 584 | @property 585 | def references(self): 586 | ''' 587 | List of URLs of external links on a page. 588 | May include external links within page that aren't technically cited anywhere. 589 | ''' 590 | 591 | if not getattr(self, '_references', False): 592 | def add_protocol(url): 593 | return url if url.startswith('http') else 'http:' + url 594 | 595 | self._references = [ 596 | add_protocol(link['*']) 597 | for link in self.__continued_query({ 598 | 'prop': 'extlinks', 599 | 'ellimit': 'max' 600 | }) 601 | ] 602 | 603 | return self._references 604 | 605 | @property 606 | def links(self): 607 | ''' 608 | List of titles of Wikipedia page links on a page. 609 | 610 | .. note:: Only includes articles from namespace 0, meaning no Category, User talk, or other meta-Wikipedia pages. 611 | ''' 612 | 613 | if not getattr(self, '_links', False): 614 | self._links = [ 615 | link['title'] 616 | for link in self.__continued_query({ 617 | 'prop': 'links', 618 | 'plnamespace': 0, 619 | 'pllimit': 'max' 620 | }) 621 | ] 622 | 623 | return self._links 624 | 625 | @property 626 | def categories(self): 627 | ''' 628 | List of categories of a page. 629 | ''' 630 | 631 | if not getattr(self, '_categories', False): 632 | self._categories = [re.sub(r'^Category:', '', x) for x in 633 | [link['title'] 634 | for link in self.__continued_query({ 635 | 'prop': 'categories', 636 | 'cllimit': 'max' 637 | }) 638 | ]] 639 | 640 | return self._categories 641 | 642 | @property 643 | def sections(self): 644 | ''' 645 | List of section titles from the table of contents on the page. 646 | ''' 647 | 648 | if not getattr(self, '_sections', False): 649 | query_params = { 650 | 'action': 'parse', 651 | 'prop': 'sections', 652 | } 653 | query_params.update(self.__title_query_param) 654 | 655 | request = _wiki_request(query_params) 656 | self._sections = [section['line'] for section in request['parse']['sections']] 657 | 658 | return self._sections 659 | 660 | def section(self, section_title): 661 | ''' 662 | Get the plain text content of a section from `self.sections`. 663 | Returns None if `section_title` isn't found, otherwise returns a whitespace stripped string. 664 | 665 | This is a convenience method that wraps self.content. 666 | 667 | .. warning:: Calling `section` on a section that has subheadings will NOT return 668 | the full text of all of the subsections. It only gets the text between 669 | `section_title` and the next subheading, which is often empty. 
670 | ''' 671 | 672 | section = u"== {} ==".format(section_title) 673 | try: 674 | index = self.content.index(section) + len(section) 675 | except ValueError: 676 | return None 677 | 678 | try: 679 | next_index = self.content.index("==", index) 680 | except ValueError: 681 | next_index = len(self.content) 682 | 683 | return self.content[index:next_index].lstrip("=").strip() 684 | 685 | 686 | @cache 687 | def languages(): 688 | ''' 689 | List all the currently supported language prefixes (usually ISO language code). 690 | 691 | Can be inputted to `set_lang` to change the Mediawiki that `wikipedia` requests 692 | results from. 693 | 694 | Returns: dict of : pairs. To get just a list of prefixes, 695 | use `wikipedia.languages().keys()`. 696 | ''' 697 | response = _wiki_request({ 698 | 'meta': 'siteinfo', 699 | 'siprop': 'languages' 700 | }) 701 | 702 | languages = response['query']['languages'] 703 | 704 | return { 705 | lang['code']: lang['*'] 706 | for lang in languages 707 | } 708 | 709 | 710 | def donate(): 711 | ''' 712 | Open up the Wikimedia donate page in your favorite browser. 713 | ''' 714 | import webbrowser 715 | 716 | webbrowser.open('https://donate.wikimedia.org/w/index.php?title=Special:FundraiserLandingPage', new=2) 717 | 718 | 719 | def _wiki_request(params): 720 | ''' 721 | Make a request to the Wikipedia API using the given search parameters. 722 | Returns a parsed dict of the JSON response. 723 | ''' 724 | global RATE_LIMIT_LAST_CALL 725 | global USER_AGENT 726 | 727 | params['format'] = 'json' 728 | if not 'action' in params: 729 | params['action'] = 'query' 730 | 731 | headers = { 732 | 'User-Agent': USER_AGENT 733 | } 734 | 735 | if RATE_LIMIT and RATE_LIMIT_LAST_CALL and \ 736 | RATE_LIMIT_LAST_CALL + RATE_LIMIT_MIN_WAIT > datetime.now(): 737 | 738 | # it hasn't been long enough since the last API call 739 | # so wait until we're in the clear to make the request 740 | 741 | wait_time = (RATE_LIMIT_LAST_CALL + RATE_LIMIT_MIN_WAIT) - datetime.now() 742 | time.sleep(int(wait_time.total_seconds())) 743 | 744 | r = requests.get(API_URL, params=params, headers=headers) 745 | 746 | if RATE_LIMIT: 747 | RATE_LIMIT_LAST_CALL = datetime.now() 748 | 749 | return r.json() 750 | -------------------------------------------------------------------------------- /src/preprocess_yahoo.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | from collections import defaultdict 3 | import numpy as np 4 | import statistics 5 | yahoo_path = '/export/home/Dataset/YahooClassification/yahoo_answers_csv/' 6 | 7 | # ''' 8 | # Society & Culture 9 | # Science & Mathematics 10 | # Health 11 | # Education & Reference 12 | # Computers & Internet 13 | # Sports 14 | # Business & Finance 15 | # Entertainment & Music 16 | # Family & Relationships 17 | # Politics & Government 18 | # ''' 19 | # type2hypothesis = { 20 | # 0: ['it is related with society or culture', 'this text describes something about an extended social group having a distinctive cultural and economic organization or a particular society at a particular time and place'], 21 | # 1:['it is related with science or mathematics', 'this text describes something about a particular branch of scientific knowledge or a science (or group of related sciences) dealing with the logic of quantity and shape and arrangement'], 22 | # 2: ['it is related with health', 'this text describes something about a healthy state of wellbeing free from disease'], 23 | # 3: ['it is related with education or reference', 
'this text describes something about the activities of educating or instructing or activities that impart knowledge or skill or an indicator that orients you generally'], 24 | # 4: ['it is related with computers or Internet', 'this text describes something about a machine for performing calculations automatically or a computer network consisting of a worldwide network of computer networks that use the TCP/IP network protocols to facilitate data transmission and exchange'], 25 | # 5: ['it is related with sports', 'this text describes something about an active diversion requiring physical exertion and competition'], 26 | # 6: ['it is related with business or finance', 'this text describes something about a commercial or industrial enterprise and the people who constitute it or the commercial activity of providing funds and capital'], 27 | # 7: ['it is related with entertainment or music', 'this text describes something about an activity that is diverting and that holds the attention or an artistic form of auditory communication incorporating instrumental or vocal tones in a structured and continuous manner'], 28 | # 8: ['it is related with family or relationships', 'this text describes something about a social unit living together, primary social group; parents and children or a relation between people'], 29 | # 9: ['it is related with politics or government', 'this text describes something about social relations involving intrigue to gain authority or power or the organization that is the governing authority of a political unit']} 30 | # # society, culture, science, mathematics, health, education, reference, computers, Internet, sports, business, finance, entertainment, music, family, relationships, politics, government 31 | # 32 | # # type2hypothesis = { 33 | # # 0: ['this text describes something about society or culture, not about science, mathematics, health, education, reference, computers, Internet, sports, business, finance, entertainment, music, family, relationships, politics, government'], 34 | # # 1:['this text describes something about science or mathematics, not about society or culture, not about health, education, reference, computers, Internet, sports, business, finance, entertainment, music, family, relationships, politics, government'], 35 | # # 2: ['this text describes something about health, not about society, culture, science, mathematics, education, reference, computers, Internet, sports, business, finance, entertainment, music, family, relationships, politics, government'], 36 | # # 3: ['this text describes something about education or reference, not about society, culture, science, mathematics, education, reference, computers, Internet, sports, business, finance, entertainment, music, family, relationships, politics, government'], 37 | # # 4: ['this text describes something about computers or Internet, not about society, culture, science, mathematics, health, education, reference, sports, business, finance, entertainment, music, family, relationships, politics, government'], 38 | # # 5: ['this text describes something about sports, not about society, culture, science, mathematics, health, education, reference, computers, Internet, business, finance, entertainment, music, family, relationships, politics, government'], 39 | # # 6: ['this text describes something about business or finance, not about society, culture, science, mathematics, health, education, reference, computers, Internet, sports, entertainment, music, family, relationships, politics, government'], 40 | 
# # 7: ['this text describes something about entertainment or music, not about society, culture, science, mathematics, health, education, reference, computers, Internet, sports, business, finance, family, relationships, politics, government'], 41 | # # 8: ['this text describes something about family or relationships, not about society, culture, science, mathematics, health, education, reference, computers, Internet, sports, business, finance, entertainment, music, politics, government'], 42 | # # 9: ['this text describes something about politics or government, not about society, culture, science, mathematics, health, education, reference, computers, Internet, sports, business, finance, entertainment, music, family, relationships']} 43 | # 44 | # def load_labels(word2id, maxlen): 45 | # 46 | # texts=[] 47 | # text_masks=[] 48 | # 49 | # readfile=codecs.open(yahoo_path+'classes.txt', 'r', 'utf-8') 50 | # for line in readfile: 51 | # wordlist = line.strip().replace('&', ' ').lower().split() 52 | # 53 | # text_idlist, text_masklist=transfer_wordlist_2_idlist_with_maxlen(wordlist, word2id, maxlen) 54 | # texts.append(text_idlist) 55 | # text_masks.append(text_masklist) 56 | # 57 | # print('\t\t\t totally :', len(texts), 'label names') 58 | # 59 | # return texts, text_masks, word2id 60 | # 61 | # 62 | # 63 | # 64 | # 65 | # def convert_yahoo_train_zeroshot(): 66 | # train_type_set = set([0,2,4,6,8]) 67 | # # train_type_set = set([1,3,5,7,9]) 68 | # id2size = defaultdict(int) 69 | # 70 | # readfile=codecs.open(yahoo_path+'train_tokenized.txt', 'r', 'utf-8') 71 | # writefile = codecs.open(yahoo_path+'zero-shot-split/train.two.phases.txt', 'w', 'utf-8') 72 | # line_co=0 73 | # for line in readfile: 74 | # parts = line.strip().split('\t') 75 | # if len(parts)==2: 76 | # label_id = int(parts[0]) 77 | # if label_id in train_type_set: 78 | # id2size[label_id]+=1 79 | # sent1 = parts[1].strip() 80 | # # start write hypo 81 | # idd=0 82 | # while idd < 10: 83 | # '''only consider pos and neg in seen labels''' 84 | # if idd in train_type_set: 85 | # hypo_list = type2hypothesis.get(idd) 86 | # for hypo in hypo_list: 87 | # if idd == label_id: 88 | # writefile.write('1\t'+sent1+'\t'+hypo+'\t'+str(line_co)+':'+str(idd)+'\n') 89 | # else: 90 | # writefile.write('0\t'+sent1+'\t'+hypo+'\t'+str(line_co)+':'+str(idd)+'\n') 91 | # idd +=1 92 | # line_co+=1 93 | # if line_co%10000==0: 94 | # print('line_co:', line_co) 95 | # print('dataset loaded over, id2size:', id2size, 'total read lines:',line_co ) 96 | # 97 | # def convert_yahoo_test_zeroshot(): 98 | # id2size = defaultdict(int) 99 | # 100 | # readfile=codecs.open(yahoo_path+'test_tokenized.txt', 'r', 'utf-8') 101 | # writefile = codecs.open(yahoo_path+'zero-shot-split/test.two.phases.txt', 'w', 'utf-8') 102 | # line_co=0 103 | # for line in readfile: 104 | # parts = line.strip().split('\t') 105 | # if len(parts)==2: 106 | # label_id = int(parts[0]) 107 | # id2size[label_id]+=1 108 | # sent1 = parts[1].strip() 109 | # # start write hypo 110 | # idd=0 111 | # while idd < 10: 112 | # hypo_list = type2hypothesis.get(idd) 113 | # for hypo in hypo_list: 114 | # if idd == label_id: 115 | # writefile.write('1\t'+sent1+'\t'+hypo+'\t'+str(line_co)+':'+str(idd)+'\n') 116 | # else: 117 | # writefile.write('0\t'+sent1+'\t'+hypo+'\t'+str(line_co)+':'+str(idd)+'\n') 118 | # idd +=1 119 | # line_co+=1 120 | # if line_co%10000==0: 121 | # print('line_co:', line_co) 122 | # print('dataset loaded over, id2size:', id2size, 'total read lines:',line_co ) 123 | # 124 | # def 
evaluate_Yahoo_zeroshot(preds, gold_label_list, coord_list, seen_col_set): 125 | # ''' 126 | # preds: probability vector 127 | # ''' 128 | # pred_list = list(preds) 129 | # assert len(pred_list) == len(gold_label_list) 130 | # seen_hit=0 131 | # unseen_hit = 0 132 | # seen_size = 0 133 | # unseen_size = 0 134 | # 135 | # 136 | # start = 0 137 | # end = 0 138 | # total_sizes = [0.0]*10 139 | # hit_sizes = [0.0]*10 140 | # while end< len(coord_list): 141 | # # print('end:', end) 142 | # # print('start:', start) 143 | # # print('len(coord_list):', len(coord_list)) 144 | # while end< len(coord_list) and int(coord_list[end].split(':')[0]) == int(coord_list[start].split(':')[0]): 145 | # end+=1 146 | # preds_row = pred_list[start:end] 147 | # # print('preds_row:',preds_row) 148 | # gold_label_row = gold_label_list[start:end] 149 | # # print('gold_label_row:',gold_label_row) 150 | # # print(start,end) 151 | # # assert sum(gold_label_row) >= 1 152 | # coord_list_row = [int(x.split(':')[1]) for x in coord_list[start:end]] 153 | # # print('coord_list_row:',coord_list_row) 154 | # # print(start,end) 155 | # # assert coord_list_row == [0,0,1,2,3,4,5,6,7,8,9] 156 | # '''max_pred_id = np.argmax(np.asarray(preds_row)) is wrong, since argmax can be >=10''' 157 | # max_pred_id = np.argmax(np.asarray(preds_row)) 158 | # pred_label_id = coord_list_row[max_pred_id] 159 | # gold_label = -1 160 | # for idd, gold in enumerate(gold_label_row): 161 | # if gold == 1: 162 | # gold_label = coord_list_row[idd] 163 | # break 164 | # # assert gold_label!=-1 165 | # if gold_label == -1: 166 | # if end == len(coord_list): 167 | # break 168 | # else: 169 | # print('gold_label_row:',gold_label_row) 170 | # exit(0) 171 | # 172 | # total_sizes[gold_label]+=1 173 | # if gold_label == pred_label_id: 174 | # hit_sizes[gold_label]+=1 175 | # 176 | # start = end 177 | # 178 | # # seen_acc = statistics.mean([hit_sizes[i]/total_sizes[i] for i in range(10) if i in seen_col_set]) 179 | # seen_hit = sum([hit_sizes[i] for i in range(10) if i in seen_col_set]) 180 | # seen_total = sum([total_sizes[i] for i in range(10) if i in seen_col_set]) 181 | # unseen_hit = sum([hit_sizes[i] for i in range(10) if i not in seen_col_set]) 182 | # unseen_total = sum([total_sizes[i] for i in range(10) if i not in seen_col_set]) 183 | # # unseen_acc = statistics.mean([hit_sizes[i]/total_sizes[i] for i in range(10) if i not in seen_col_set]) 184 | # print('acc for each label:', [hit_sizes[i]/total_sizes[i] for i in range(10)]) 185 | # print('total_sizes:',total_sizes) 186 | # print('hit_sizes:',hit_sizes) 187 | # 188 | # return seen_hit/(1e-6+seen_total), unseen_hit/(1e-6+unseen_total) 189 | # 190 | # 191 | # def evaluate_Yahoo_zeroshot_2phases(pred_probs, pred_labels, gold_label_list, coord_list, seen_col_set): 192 | # ''' 193 | # pred_probs: probability vector 194 | # pred_labels: a list of 0/1 195 | # ''' 196 | # pred_list = list(pred_probs) 197 | # pred_labels = list(pred_labels) 198 | # assert len(pred_list) == len(gold_label_list) 199 | # seen_hit=0 200 | # unseen_hit = 0 201 | # seen_size = 0 202 | # unseen_size = 0 203 | # 204 | # 205 | # start = 0 206 | # end = 0 207 | # total_sizes = [0.0]*10 208 | # hit_sizes = [0.0]*10 209 | # while end< len(coord_list): 210 | # # print('end:', end) 211 | # # print('start:', start) 212 | # # print('len(coord_list):', len(coord_list)) 213 | # while end< len(coord_list) and int(coord_list[end].split(':')[0]) == int(coord_list[start].split(':')[0]): 214 | # end+=1 215 | # pred_probs_row = pred_list[start:end] 216 
| # pred_label_row = pred_labels[start:end] 217 | # 218 | # gold_label_row = gold_label_list[start:end] 219 | # # print('gold_label_row:',gold_label_row) 220 | # # print(start,end) 221 | # # assert sum(gold_label_row) >= 1 222 | # coord_list_row = [int(x.split(':')[1]) for x in coord_list[start:end]] 223 | # # print('coord_list_row:',coord_list_row) 224 | # # print(start,end) 225 | # # assert coord_list_row == [0,0,1,2,3,4,5,6,7,8,9] 226 | # '''max_pred_id = np.argmax(np.asarray(pred_probs_row)) is wrong, since argmax can be >=10''' 227 | # '''pred label -- finalize''' 228 | # pred_label = -1 229 | # unseen_col_with_max_prob = -1 230 | # max_prob = -10.0 231 | # for idd, col in enumerate(coord_list_row): 232 | # if col in seen_col_set and pred_label_row[idd] == 0: # 0 is entailment 233 | # pred_label = col 234 | # elif col not in seen_col_set: # unseen class 235 | # if pred_probs_row[idd] > max_prob: 236 | # max_prob = pred_probs_row[idd] 237 | # unseen_col_with_max_prob = col 238 | # pred_label = unseen_col_with_max_prob if pred_label==-1 else pred_label 239 | # 240 | # 241 | # # max_pred_id = np.argmax(np.asarray(pred_probs_row)) 242 | # # pred_label_id = coord_list_row[max_pred_id] 243 | # '''gold label''' 244 | # gold_label = -1 245 | # for idd, gold in enumerate(gold_label_row): 246 | # if gold == 1: 247 | # gold_label = coord_list_row[idd] 248 | # break 249 | # # assert gold_label!=-1 250 | # if gold_label == -1: 251 | # if end == len(coord_list): 252 | # break 253 | # else: 254 | # print('gold_label_row:',gold_label_row) 255 | # exit(0) 256 | # print('pred_probs_row:',pred_probs_row) 257 | # print('pred_label_row:',pred_label_row) 258 | # print('gold_label_row:',gold_label_row) 259 | # print('coord_list_row:',coord_list_row) 260 | # print('gold_label:',gold_label) 261 | # print('pred_label:',pred_label) 262 | # 263 | # total_sizes[gold_label]+=1 264 | # if gold_label == pred_label: 265 | # hit_sizes[gold_label]+=1 266 | # 267 | # start = end 268 | # 269 | # # seen_acc = statistics.mean([hit_sizes[i]/total_sizes[i] for i in range(10) if i in seen_col_set]) 270 | # seen_hit = sum([hit_sizes[i] for i in range(10) if i in seen_col_set]) 271 | # seen_total = sum([total_sizes[i] for i in range(10) if i in seen_col_set]) 272 | # unseen_hit = sum([hit_sizes[i] for i in range(10) if i not in seen_col_set]) 273 | # unseen_total = sum([total_sizes[i] for i in range(10) if i not in seen_col_set]) 274 | # # unseen_acc = statistics.mean([hit_sizes[i]/total_sizes[i] for i in range(10) if i not in seen_col_set]) 275 | # print('acc for each label:', [hit_sizes[i]/total_sizes[i] for i in range(10)]) 276 | # print('total_sizes:',total_sizes) 277 | # print('hit_sizes:',hit_sizes) 278 | # 279 | # return seen_hit/(1e-6+seen_total), unseen_hit/(1e-6+unseen_total) 280 | 281 | 282 | def build_zeroshot_devset(): 283 | '''directly copy from current testset''' 284 | id2size = defaultdict(int) 285 | 286 | readfile=codecs.open(yahoo_path+'test_tokenized.txt', 'r', 'utf-8') 287 | writefile = codecs.open(yahoo_path+'zero-shot-split/dev.txt', 'w', 'utf-8') 288 | line_co=0 289 | for line in readfile: 290 | writefile.write(line.strip()+'\n') 291 | line_co+=1 292 | writefile.close() 293 | readfile.close() 294 | print('create dev set size:',line_co ) 295 | print('build dev over') 296 | 297 | def build_zeroshot_testset(): 298 | '''extract 100K from current train as test set''' 299 | # train_type_set = set([0,2,4,6,8]) 300 | # # train_type_set = set([1,3,5,7,9]) 301 | id2size = defaultdict(int) 302 | 303 | 
readfile=codecs.open(yahoo_path+'train_tokenized.txt', 'r', 'utf-8') 304 | writefile_test = codecs.open(yahoo_path+'zero-shot-split/test.txt', 'w', 'utf-8') 305 | writefile_remain = codecs.open(yahoo_path+'train_tokenized_wo_test.txt', 'w', 'utf-8') 306 | line_co=0 307 | for line in readfile: 308 | parts = line.strip().split('\t') 309 | if len(parts)==2: 310 | label_id = int(parts[0]) 311 | copy_size = id2size.get(label_id, 0) 312 | if copy_size < 10000: 313 | '''copy to test''' 314 | writefile_test.write(line.strip()+'\n') 315 | id2size[label_id]+=1 316 | line_co+=1 317 | else: 318 | '''keep in train''' 319 | writefile_remain.write(line.strip()+'\n') 320 | writefile_test.close() 321 | writefile_remain.close() 322 | print('dataset loaded over, id2size:', id2size, 'total read lines:',line_co ) 323 | print('build test over') 324 | 325 | def build_zeroshot_trainset(): 326 | '''extract 100K from current train as test set''' 327 | train_type_set = set([0,2,4,6,8]) 328 | # # train_type_set = set([1,3,5,7,9]) 329 | id2size = defaultdict(int) 330 | 331 | readfile=codecs.open(yahoo_path+'train_tokenized_wo_test.txt', 'r', 'utf-8') 332 | '''store classes 0,2,4,6,8''' 333 | writefile_PU_half_0 = codecs.open(yahoo_path+'zero-shot-split/train_pu_half_v0.txt', 'w', 'utf-8') 334 | '''store classes 1,3,5,7,9''' 335 | writefile_PU_half_1 = codecs.open(yahoo_path+'zero-shot-split/train_pu_half_v1.txt', 'w', 'utf-8') 336 | 337 | line_co=0 338 | for line in readfile: 339 | parts = line.strip().split('\t') 340 | if len(parts)==2: 341 | label_id = int(parts[0]) 342 | if label_id in train_type_set: 343 | writefile_PU_half_0.write(line.strip()+'\n') 344 | else: 345 | writefile_PU_half_1.write(line.strip()+'\n') 346 | writefile_PU_half_0.close() 347 | writefile_PU_half_1.close() 348 | readfile.close() 349 | print('PU half over') 350 | '''PU_one''' 351 | 352 | for i in range(10): 353 | readfile=codecs.open(yahoo_path+'train_tokenized_wo_test.txt', 'r', 'utf-8') 354 | writefile_PU_one = codecs.open(yahoo_path+'zero-shot-split/train_pu_one_'+'wo_'+str(i)+'.txt', 'w', 'utf-8') 355 | line_co=0 356 | for line in readfile: 357 | parts = line.strip().split('\t') 358 | if len(parts)==2: 359 | label_id = int(parts[0]) 360 | if label_id != i: 361 | writefile_PU_one.write(line.strip()+'\n') 362 | line_co+=1 363 | writefile_PU_one.close() 364 | readfile.close() 365 | print('write size:', line_co) 366 | print('build train over') 367 | 368 | 369 | def evaluate_Yahoo_zeroshot_TwpPhasePred(pred_probs, pred_binary_labels_harsh, pred_binary_labels_loose, eval_label_list, eval_hypo_seen_str_indicator, eval_hypo_2_type_index, seen_types): 370 | ''' 371 | pred_probs: a list, the prob for "entail" 372 | pred_binary_labels: a lit, each for 0 or 1 373 | eval_label_list: the gold type index; list length == lines in dev.txt 374 | eval_hypo_seen_str_indicator: totally hypo size, seen or unseen 375 | eval_hypo_2_type_index:: total hypo size, the type in [0,...n] 376 | seen_types: a set of type indices 377 | ''' 378 | 379 | pred_probs = list(pred_probs) 380 | # pred_binary_labels = list(pred_binary_labels) 381 | total_hypo_size = len(eval_hypo_seen_str_indicator) 382 | total_premise_size = len(eval_label_list) 383 | assert len(pred_probs) == total_premise_size*total_hypo_size 384 | assert len(eval_hypo_seen_str_indicator) == len(eval_hypo_2_type_index) 385 | 386 | # print('seen_types:', seen_types) 387 | # print('eval_hypo_seen_str_indicator:', eval_hypo_seen_str_indicator) 388 | # print('eval_hypo_2_type_index:', eval_hypo_2_type_index) 389 | 
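# Decision rule implemented by the branches below: for each premise, first check (under
# the loose binarization) whether any seen-type hypothesis and/or any unseen-type
# hypothesis was predicted as entailment (label 0). If only unseen types fire, pick the
# unseen hypothesis with the highest probability; if only seen types fire, pick the
# highest-probability seen hypothesis among those predicted as entailment. If both or
# neither fire, compare the best seen and best unseen probabilities and keep the seen
# type only when it wins by a margin larger than 0.1, otherwise take the best unseen type.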
390 | 391 | seen_hit=0 392 | unseen_hit = 0 393 | seen_size = 0 394 | unseen_size = 0 395 | 396 | for i in range(total_premise_size): 397 | pred_probs_per_premise = pred_probs[i*total_hypo_size: (i+1)*total_hypo_size] 398 | pred_binary_labels_per_premise_harsh = pred_binary_labels_harsh[i*total_hypo_size: (i+1)*total_hypo_size] 399 | pred_binary_labels_per_premise_loose = pred_binary_labels_loose[i*total_hypo_size: (i+1)*total_hypo_size] 400 | 401 | 402 | # print('pred_probs_per_premise:',pred_probs_per_premise) 403 | # print('pred_binary_labels_per_premise:', pred_binary_labels_per_premise) 404 | 405 | 406 | '''first check if seen types get 'entailment''' 407 | seen_get_entail_flag=False 408 | for j in range(total_hypo_size): 409 | if eval_hypo_seen_str_indicator[j] == 'seen' and pred_binary_labels_per_premise_loose[j]==0: 410 | seen_get_entail_flag=True 411 | break 412 | '''first check if unseen types get 'entailment''' 413 | unseen_get_entail_flag=False 414 | for j in range(total_hypo_size): 415 | if eval_hypo_seen_str_indicator[j] == 'unseen' and pred_binary_labels_per_premise_loose[j]==0: 416 | unseen_get_entail_flag=True 417 | break 418 | 419 | if seen_get_entail_flag and unseen_get_entail_flag or \ 420 | (not seen_get_entail_flag and not unseen_get_entail_flag): 421 | '''compare their max prob''' 422 | max_prob_seen = -1.0 423 | max_seen_index = -1 424 | max_prob_unseen = -1.0 425 | max_unseen_index = -1 426 | for j in range(total_hypo_size): 427 | its_prob = pred_probs_per_premise[j] 428 | if eval_hypo_seen_str_indicator[j] == 'unseen': 429 | if its_prob > max_prob_unseen: 430 | max_prob_unseen = its_prob 431 | max_unseen_index = j 432 | else: 433 | if its_prob > max_prob_seen: 434 | max_prob_seen = its_prob 435 | max_seen_index = j 436 | if max_prob_seen - max_prob_unseen > 0.1: 437 | pred_type = eval_hypo_2_type_index[max_seen_index] 438 | else: 439 | pred_type = eval_hypo_2_type_index[max_unseen_index] 440 | 441 | elif unseen_get_entail_flag: 442 | '''find the unseen type with highest prob''' 443 | max_j = -1 444 | max_prob = -1.0 445 | for j in range(total_hypo_size): 446 | if eval_hypo_seen_str_indicator[j] == 'unseen': 447 | its_prob = pred_probs_per_premise[j] 448 | if its_prob > max_prob: 449 | max_prob = its_prob 450 | max_j = j 451 | pred_type = eval_hypo_2_type_index[max_j] 452 | 453 | elif seen_get_entail_flag: 454 | '''find the seen type with highest prob''' 455 | max_j = -1 456 | max_prob = -1.0 457 | for j in range(total_hypo_size): 458 | if eval_hypo_seen_str_indicator[j] == 'seen' and pred_binary_labels_per_premise_loose[j]==0: 459 | its_prob = pred_probs_per_premise[j] 460 | if its_prob > max_prob: 461 | max_prob = its_prob 462 | max_j = j 463 | assert max_prob > 0.5 464 | pred_type = eval_hypo_2_type_index[max_j] 465 | gold_type = eval_label_list[i] 466 | 467 | # print('pred_type:', pred_type, 'gold_type:', gold_type) 468 | if gold_type in seen_types: 469 | seen_size+=1 470 | if gold_type == pred_type: 471 | seen_hit+=1 472 | else: 473 | unseen_size+=1 474 | if gold_type == pred_type: 475 | unseen_hit+=1 476 | 477 | seen_acc = seen_hit/(1e-6+seen_size) 478 | unseen_acc = unseen_hit/(1e-6+unseen_size) 479 | 480 | return seen_acc, unseen_acc 481 | 482 | def evaluate_Yahoo_zeroshot_SinglePhasePred(pred_probs, pred_binary_labels_harsh, pred_binary_labels_loose, eval_label_list, eval_hypo_seen_str_indicator, eval_hypo_2_type_index, seen_types): 483 | ''' 484 | pred_probs: a list, the prob for "entail" 485 | pred_binary_labels: a lit, each for 0 or 1 486 | 
eval_label_list: the gold type index; list length == lines in dev.txt 487 | eval_hypo_seen_str_indicator: totally hypo size, seen or unseen 488 | eval_hypo_2_type_index:: total hypo size, the type in [0,...n] 489 | seen_types: a set of type indices 490 | ''' 491 | 492 | pred_probs = list(pred_probs) 493 | # pred_binary_labels = list(pred_binary_labels) 494 | total_hypo_size = len(eval_hypo_seen_str_indicator) 495 | total_premise_size = len(eval_label_list) 496 | assert len(pred_probs) == total_premise_size*total_hypo_size 497 | assert len(eval_hypo_seen_str_indicator) == len(eval_hypo_2_type_index) 498 | 499 | seen_hit=0 500 | unseen_hit = 0 501 | seen_size = 0 502 | unseen_size = 0 503 | 504 | for i in range(total_premise_size): 505 | pred_probs_per_premise = pred_probs[i*total_hypo_size: (i+1)*total_hypo_size] 506 | pred_binary_labels_per_premise_harsh = pred_binary_labels_harsh[i*total_hypo_size: (i+1)*total_hypo_size] 507 | pred_binary_labels_per_premise_loose = pred_binary_labels_loose[i*total_hypo_size: (i+1)*total_hypo_size] 508 | 509 | max_prob = -100.0 510 | max_index = -1 511 | for j in range(total_hypo_size): 512 | if pred_probs_per_premise[j] > max_prob: 513 | max_prob = pred_probs_per_premise[j] 514 | max_index = j 515 | 516 | pred_type = eval_hypo_2_type_index[max_index] 517 | gold_type = eval_label_list[i] 518 | 519 | # print('pred_type:', pred_type, 'gold_type:', gold_type) 520 | if gold_type in seen_types: 521 | seen_size+=1 522 | if gold_type == pred_type: 523 | seen_hit+=1 524 | else: 525 | unseen_size+=1 526 | if gold_type == pred_type: 527 | unseen_hit+=1 528 | 529 | seen_acc = seen_hit/(1e-6+seen_size) 530 | unseen_acc = unseen_hit/(1e-6+unseen_size) 531 | 532 | return seen_acc, unseen_acc 533 | 534 | 535 | 536 | 537 | if __name__ == '__main__': 538 | build_zeroshot_devset() 539 | build_zeroshot_testset() 540 | build_zeroshot_trainset() 541 | -------------------------------------------------------------------------------- /src/demo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
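For comparison, evaluate_Yahoo_zeroshot_SinglePhasePred above skips the seen/unseen gating entirely and simply takes the highest-probability hypothesis for each premise. A minimal NumPy sketch of that reduction (the function name and the reshape-based layout are illustrative, assuming the same premise-major ordering of pred_probs):

```python
import numpy as np

def single_phase_predictions(pred_probs, total_hypo_size, hypo_2_type_index):
    probs = np.asarray(pred_probs).reshape(-1, total_hypo_size)  # one row per premise
    best_hypo = probs.argmax(axis=1)                             # best hypothesis per row
    return [hypo_2_type_index[j] for j in best_hypo]
```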
16 | """BERT finetuning runner.""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import csv 22 | import logging 23 | import os 24 | import random 25 | import sys 26 | import codecs 27 | import numpy as np 28 | import torch 29 | from collections import defaultdict 30 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 31 | TensorDataset) 32 | from torch.utils.data.distributed import DistributedSampler 33 | from tqdm import tqdm, trange 34 | 35 | from torch.nn import CrossEntropyLoss, MSELoss 36 | from scipy.special import softmax 37 | from scipy.stats import pearsonr, spearmanr 38 | from sklearn.metrics import matthews_corrcoef, f1_score 39 | 40 | # from pytorch_transformers.file_utils import PYTORCH_TRANSFORMERS_CACHE 41 | # from pytorch_transformers.modeling_bert import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME 42 | # from pytorch_transformers.tokenization_bert import BertTokenizer 43 | # from pytorch_transformers.optimization import AdamW 44 | 45 | from transformers.file_utils import PYTORCH_TRANSFORMERS_CACHE 46 | from transformers.modeling_bert import BertForSequenceClassification 47 | from transformers.tokenization_bert import BertTokenizer 48 | from transformers.optimization import AdamW 49 | 50 | # from pytorch_transformers import * 51 | 52 | from preprocess_yahoo import evaluate_Yahoo_zeroshot_TwpPhasePred 53 | # import torch.optim as optimizer_wenpeng 54 | 55 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 56 | datefmt = '%m/%d/%Y %H:%M:%S', 57 | level = logging.INFO) 58 | logger = logging.getLogger(__name__) 59 | 60 | 61 | type2hypothesis = { 62 | 0: ['it is related with society or culture', 'this text describes something about an extended social group having a distinctive cultural and economic organization or a particular society at a particular time and place'], 63 | 1:['it is related with science or mathematics', 'this text describes something about a particular branch of scientific knowledge or a science (or group of related sciences) dealing with the logic of quantity and shape and arrangement'], 64 | 2: ['it is related with health', 'this text describes something about a healthy state of wellbeing free from disease'], 65 | 3: ['it is related with education or reference', 'this text describes something about the activities of educating or instructing or activities that impart knowledge or skill or an indicator that orients you generally'], 66 | 4: ['it is related with computers or Internet', 'this text describes something about a machine for performing calculations automatically or a computer network consisting of a worldwide network of computer networks that use the TCP/IP network protocols to facilitate data transmission and exchange'], 67 | 5: ['it is related with sports', 'this text describes something about an active diversion requiring physical exertion and competition'], 68 | 6: ['it is related with business or finance', 'this text describes something about a commercial or industrial enterprise and the people who constitute it or the commercial activity of providing funds and capital'], 69 | 7: ['it is related with entertainment or music', 'this text describes something about an activity that is diverting and that holds the attention or an artistic form of auditory communication incorporating instrumental or vocal tones in a structured and continuous manner'], 70 | 8: ['it is related with family or relationships', 
'this text describes something about a social unit living together, primary social group; parents and children or a relation between people'], 71 | 9: ['it is related with politics or government', 'this text describes something about social relations involving intrigue to gain authority or power or the organization that is the governing authority of a political unit']} 72 | 73 | class InputExample(object): 74 | """A single training/test example for simple sequence classification.""" 75 | 76 | def __init__(self, guid, text_a, text_b=None, label=None): 77 | """Constructs a InputExample. 78 | 79 | Args: 80 | guid: Unique id for the example. 81 | text_a: string. The untokenized text of the first sequence. For single 82 | sequence tasks, only this sequence must be specified. 83 | text_b: (Optional) string. The untokenized text of the second sequence. 84 | Only must be specified for sequence pair tasks. 85 | label: (Optional) string. The label of the example. This should be 86 | specified for train and dev examples, but not for test examples. 87 | """ 88 | self.guid = guid 89 | self.text_a = text_a 90 | self.text_b = text_b 91 | self.label = label 92 | 93 | 94 | class InputFeatures(object): 95 | """A single set of features of data.""" 96 | 97 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 98 | self.input_ids = input_ids 99 | self.input_mask = input_mask 100 | self.segment_ids = segment_ids 101 | self.label_id = label_id 102 | 103 | 104 | class DataProcessor(object): 105 | """Base class for data converters for sequence classification data sets.""" 106 | 107 | def get_train_examples(self, data_dir): 108 | """Gets a collection of `InputExample`s for the train set.""" 109 | raise NotImplementedError() 110 | 111 | def get_dev_examples(self, data_dir): 112 | """Gets a collection of `InputExample`s for the dev set.""" 113 | raise NotImplementedError() 114 | 115 | def get_labels(self): 116 | """Gets the list of labels for this data set.""" 117 | raise NotImplementedError() 118 | 119 | @classmethod 120 | def _read_tsv(cls, input_file, quotechar=None): 121 | """Reads a tab separated value file.""" 122 | with open(input_file, "r") as f: 123 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 124 | lines = [] 125 | for line in reader: 126 | if sys.version_info[0] == 2: 127 | line = list(unicode(cell, 'utf-8') for cell in line) 128 | lines.append(line) 129 | return lines 130 | 131 | class RteProcessor(DataProcessor): 132 | """Processor for the RTE data set (GLUE version).""" 133 | def get_train_examples_wenpeng(self, filename): 134 | readfile = codecs.open(filename, 'r', 'utf-8') 135 | line_co=0 136 | examples=[] 137 | for row in readfile: 138 | if line_co>0: 139 | line=row.strip().split('\t') 140 | guid = "train-"+line[0] 141 | text_a = line[1] 142 | text_b = line[2] 143 | label = line[-1] 144 | examples.append( 145 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 146 | line_co+=1 147 | else: 148 | line_co+=1 149 | continue 150 | readfile.close() 151 | print('loaded training size:', line_co) 152 | return examples 153 | 154 | 155 | def get_examples_Yahoo_train(self, filename, size_limit_per_type): 156 | readfile = codecs.open(filename, 'r', 'utf-8') 157 | line_co=0 158 | exam_co = 0 159 | examples=[] 160 | label_list = [] 161 | 162 | '''first get all the seen types, since we will only create pos and neg hypo in seen types''' 163 | seen_types = set() 164 | for row in readfile: 165 | line=row.strip().split('\t') 166 | if len(line)==2: # label_id, text 167 | 
type_index = int(line[0]) 168 | seen_types.add(type_index) 169 | readfile.close() 170 | 171 | readfile = codecs.open(filename, 'r', 'utf-8') 172 | type_load_size = defaultdict(int) 173 | for row in readfile: 174 | line=row.strip().split('\t') 175 | if len(line)==2: # label_id, text 176 | 177 | type_index = int(line[0]) 178 | if type_load_size.get(type_index,0)< size_limit_per_type: 179 | for i in range(10): 180 | hypo_list = type2hypothesis.get(i) 181 | if i == type_index: 182 | '''pos pair''' 183 | for hypo in hypo_list: 184 | guid = "train-"+str(exam_co) 185 | text_a = line[1] 186 | text_b = hypo 187 | label = 'entailment' #if line[0] == '1' else 'not_entailment' 188 | examples.append( 189 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 190 | exam_co+=1 191 | elif i in seen_types: 192 | '''neg pair''' 193 | for hypo in hypo_list: 194 | guid = "train-"+str(exam_co) 195 | text_a = line[1] 196 | text_b = hypo 197 | label = 'not_entailment' #if line[0] == '1' else 'not_entailment' 198 | examples.append( 199 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 200 | exam_co+=1 201 | line_co+=1 202 | if line_co % 10000 == 0: 203 | print('loading training size:', line_co) 204 | 205 | type_load_size[type_index]+=1 206 | else: 207 | continue 208 | readfile.close() 209 | print('loaded size:', line_co) 210 | print('seen_types:', seen_types) 211 | return examples, seen_types 212 | 213 | 214 | 215 | 216 | def get_examples_Yahoo_test(self, filename, seen_types): 217 | readfile = codecs.open(filename, 'r', 'utf-8') 218 | line_co=0 219 | exam_co = 0 220 | examples=[] 221 | 222 | hypo_seen_str_indicator=[] 223 | hypo_2_type_index=[] 224 | for i in range(10): 225 | hypo_list = type2hypothesis.get(i) 226 | for hypo in hypo_list: 227 | hypo_2_type_index.append(i) # this hypo is for type i 228 | if i in seen_types: 229 | hypo_seen_str_indicator.append('seen')# this hypo is for a seen type 230 | else: 231 | hypo_seen_str_indicator.append('unseen') 232 | 233 | gold_label_list = [] 234 | for row in readfile: 235 | line=row.strip().split('\t') 236 | if len(line)==2: # label_id, text 237 | 238 | type_index = int(line[0]) 239 | gold_label_list.append(type_index) 240 | for i in range(10): 241 | hypo_list = type2hypothesis.get(i) 242 | if i == type_index: 243 | '''pos pair''' 244 | for hypo in hypo_list: 245 | guid = "test-"+str(exam_co) 246 | text_a = line[1] 247 | text_b = hypo 248 | label = 'entailment' #if line[0] == '1' else 'not_entailment' 249 | examples.append( 250 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 251 | exam_co+=1 252 | else: 253 | '''neg pair''' 254 | for hypo in hypo_list: 255 | guid = "test-"+str(exam_co) 256 | text_a = line[1] 257 | text_b = hypo 258 | label = 'not_entailment' #if line[0] == '1' else 'not_entailment' 259 | examples.append( 260 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 261 | exam_co+=1 262 | line_co+=1 263 | if line_co % 1000 == 0: 264 | print('loading test size:', line_co) 265 | # if line_co == 1000: 266 | # break 267 | 268 | 269 | readfile.close() 270 | print('loaded size:', line_co) 271 | return examples, gold_label_list, hypo_seen_str_indicator, hypo_2_type_index 272 | 273 | def get_train_examples(self, data_dir): 274 | """See base class.""" 275 | return self._create_examples( 276 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 277 | 278 | def get_dev_examples(self, data_dir): 279 | """See base class.""" 280 | return self._create_examples( 281 | 
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
282 | 
283 |     def get_labels(self):
284 |         """See base class."""
285 |         return ["entailment", "not_entailment"]
286 | 
287 |     def _create_examples(self, lines, set_type):
288 |         """Creates examples for the training and dev sets."""
289 |         examples = []
290 |         for (i, line) in enumerate(lines):
291 |             if i == 0:
292 |                 continue
293 |             guid = "%s-%s" % (set_type, line[0])
294 |             text_a = line[1]
295 |             text_b = line[2]
296 |             label = line[-1]
297 |             examples.append(
298 |                 InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
299 |         return examples
300 | 
301 | def load_demo_input(premise_str, hypo_list):
302 | 
303 |     examples=[]
304 |     exam_co = 0
305 |     for hypo in hypo_list:
306 |         guid = "test-"+str(exam_co)
307 |         text_a = premise_str
308 |         text_b = hypo
309 |         label = 'entailment'  # placeholder label; the demo only uses the predicted probabilities
310 |         examples.append(
311 |             InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
312 |     return examples#, gold_label_list, hypo_seen_str_indicator, hypo_2_type_index
313 | 
314 | def convert_examples_to_features(examples, label_list, max_seq_length,
315 |                                  tokenizer, output_mode):
316 |     """Converts a list of `InputExample`s into a list of `InputFeatures`."""
317 | 
318 |     label_map = {label : i for i, label in enumerate(label_list)}
319 | 
320 |     premise_2_tokenzed={}
321 |     hypothesis_2_tokenzed={}
322 |     list_2_tokenizedID = {}
323 | 
324 |     features = []
325 |     for (ex_index, example) in enumerate(examples):
326 |         if ex_index % 10000 == 0:
327 |             logger.info("Writing example %d of %d" % (ex_index, len(examples)))
328 | 
329 |         tokens_a = premise_2_tokenzed.get(example.text_a)
330 |         if tokens_a is None:
331 |             tokens_a = tokenizer.tokenize(example.text_a)
332 |             premise_2_tokenzed[example.text_a] = tokens_a
333 | 
334 |         tokens_b = hypothesis_2_tokenzed.get(example.text_b)  # look up the hypothesis cache
335 |         if tokens_b is None:
336 |             tokens_b = tokenizer.tokenize(example.text_b)
337 |             hypothesis_2_tokenzed[example.text_b] = tokens_b
338 | 
339 |         _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
340 | 
341 |         tokens_A = ["[CLS]"] + tokens_a + ["[SEP]"]
342 |         segment_ids_A = [0] * len(tokens_A)
343 |         tokens_B = tokens_b + ["[SEP]"]
344 |         segment_ids_B = [1] * (len(tokens_b) + 1)
345 |         tokens = tokens_A+tokens_B
346 |         segment_ids = segment_ids_A+segment_ids_B
347 | 
348 | 
349 |         input_ids_A = list_2_tokenizedID.get(' '.join(tokens_A))
350 |         if input_ids_A is None:
351 |             input_ids_A = tokenizer.convert_tokens_to_ids(tokens_A)
352 |             list_2_tokenizedID[' '.join(tokens_A)] = input_ids_A
353 |         input_ids_B = list_2_tokenizedID.get(' '.join(tokens_B))
354 |         if input_ids_B is None:
355 |             input_ids_B = tokenizer.convert_tokens_to_ids(tokens_B)
356 |             list_2_tokenizedID[' '.join(tokens_B)] = input_ids_B
357 |         input_ids = input_ids_A + input_ids_B
358 | 
359 | 
360 |         # tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
361 |         # segment_ids = [0] * len(tokens)
362 |         #
363 |         # tokens += tokens_b + ["[SEP]"]
364 |         # segment_ids += [1] * (len(tokens_b) + 1)
365 |         # input_ids = tokenizer.convert_tokens_to_ids(tokens)
366 | 
367 | 
368 |         # The mask has 1 for real tokens and 0 for padding tokens. Only real
369 |         # tokens are attended to.
370 |         input_mask = [1] * len(input_ids)
371 | 
372 |         # Zero-pad up to the sequence length.
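        # Illustrative example: assuming max_seq_length = 12 and a toy pair
        # ("win the game", "it is related with sports"), the pieces built above
        # would look roughly like
        #   tokens:      [CLS] win the game [SEP] it is related with sports [SEP]
        #   segment_ids:   0    0   0   0    0    1  1     1     1     1     1
        #   input_mask:    1    1   1   1    1    1  1     1     1     1     1
        # and the padding step below appends a single 0 to input_ids, input_mask
        # and segment_ids so that every feature has length max_seq_length.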
373 | padding = [0] * (max_seq_length - len(input_ids)) 374 | input_ids += padding 375 | input_mask += padding 376 | segment_ids += padding 377 | 378 | assert len(input_ids) == max_seq_length 379 | assert len(input_mask) == max_seq_length 380 | assert len(segment_ids) == max_seq_length 381 | 382 | if output_mode == "classification": 383 | label_id = label_map[example.label] 384 | elif output_mode == "regression": 385 | label_id = float(example.label) 386 | else: 387 | raise KeyError(output_mode) 388 | 389 | if ex_index < 5: 390 | logger.info("*** Example ***") 391 | logger.info("guid: %s" % (example.guid)) 392 | logger.info("tokens: %s" % " ".join( 393 | [str(x) for x in tokens])) 394 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 395 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 396 | logger.info( 397 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 398 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 399 | 400 | features.append( 401 | InputFeatures(input_ids=input_ids, 402 | input_mask=input_mask, 403 | segment_ids=segment_ids, 404 | label_id=label_id)) 405 | return features 406 | 407 | 408 | 409 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 410 | """Truncates a sequence pair in place to the maximum length.""" 411 | 412 | # This is a simple heuristic which will always truncate the longer sequence 413 | # one token at a time. This makes more sense than truncating an equal percent 414 | # of tokens from each, since if one sequence is very short then each token 415 | # that's truncated likely contains more information than a longer sequence. 416 | while True: 417 | total_length = len(tokens_a) + len(tokens_b) 418 | if total_length <= max_length: 419 | break 420 | if len(tokens_a) > len(tokens_b): 421 | tokens_a.pop() 422 | else: 423 | tokens_b.pop() 424 | 425 | 426 | def simple_accuracy(preds, labels): 427 | return (preds == labels).mean() 428 | 429 | 430 | def acc_and_f1(preds, labels): 431 | acc = simple_accuracy(preds, labels) 432 | f1 = f1_score(y_true=labels, y_pred=preds) 433 | return { 434 | "acc": acc, 435 | "f1": f1, 436 | "acc_and_f1": (acc + f1) / 2, 437 | } 438 | 439 | 440 | def pearson_and_spearman(preds, labels): 441 | pearson_corr = pearsonr(preds, labels)[0] 442 | spearman_corr = spearmanr(preds, labels)[0] 443 | return { 444 | "pearson": pearson_corr, 445 | "spearmanr": spearman_corr, 446 | "corr": (pearson_corr + spearman_corr) / 2, 447 | } 448 | 449 | 450 | def compute_metrics(task_name, preds, labels): 451 | assert len(preds) == len(labels) 452 | if task_name == "cola": 453 | return {"mcc": matthews_corrcoef(labels, preds)} 454 | elif task_name == "sst-2": 455 | return {"acc": simple_accuracy(preds, labels)} 456 | elif task_name == "mrpc": 457 | return acc_and_f1(preds, labels) 458 | elif task_name == "sts-b": 459 | return pearson_and_spearman(preds, labels) 460 | elif task_name == "qqp": 461 | return acc_and_f1(preds, labels) 462 | elif task_name == "mnli": 463 | return {"acc": simple_accuracy(preds, labels)} 464 | elif task_name == "mnli-mm": 465 | return {"acc": simple_accuracy(preds, labels)} 466 | elif task_name == "qnli": 467 | return {"acc": simple_accuracy(preds, labels)} 468 | elif task_name == "rte": 469 | return {"acc": simple_accuracy(preds, labels)} 470 | elif task_name == "wnli": 471 | return {"acc": simple_accuracy(preds, labels)} 472 | elif task_name == 'F1': 473 | return {"f1": f1_score(y_true=labels, y_pred=preds)} 474 | else: 475 | raise KeyError(task_name) 476 
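As a quick sanity check on the metric helpers above, here is a tiny, hypothetical usage; the arrays are made up, and it assumes the snippet runs inside demo.py (or after `from demo import compute_metrics`):

```python
import numpy as np

preds = np.array([0, 1, 1, 0])    # predicted class ids, e.g. argmax over logits
labels = np.array([0, 1, 0, 0])   # gold class ids

print(compute_metrics('rte', preds, labels))  # {'acc': 0.75}
print(compute_metrics('F1', preds, labels))   # {'f1': 0.666...}
```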
| 477 | 478 | def main(): 479 | parser = argparse.ArgumentParser() 480 | 481 | ## Required parameters 482 | ''' 483 | python -u demo.py 484 | ''' 485 | parser.add_argument("--premise_str", 486 | default=None, 487 | type=str, 488 | required=True, 489 | help="text to classify") 490 | parser.add_argument("--hypo_list", 491 | default=None, 492 | type=str, 493 | required=True, 494 | help="sentences separated by |") 495 | parser.add_argument("--task_name", 496 | default='rte', 497 | type=str, 498 | help="The name of the task to train.") 499 | parser.add_argument("--max_seq_length", 500 | default=128, 501 | type=int, 502 | help="The maximum total input sequence length after WordPiece tokenization. \n" 503 | "Sequences longer than this will be truncated, and sequences shorter \n" 504 | "than this will be padded.") 505 | parser.add_argument("--do_eval", 506 | action='store_true', 507 | help="Whether to run eval on the dev set.") 508 | # parser.add_argument("--do_lower_case", 509 | # action='store_true', 510 | # help="Set this flag if you are using an uncased model.") 511 | parser.add_argument("--train_batch_size", 512 | default=32, 513 | type=int, 514 | help="Total batch size for training.") 515 | parser.add_argument("--eval_batch_size", 516 | default=256, 517 | type=int, 518 | help="Total batch size for eval.") 519 | parser.add_argument("--learning_rate", 520 | default=5e-5, 521 | type=float, 522 | help="The initial learning rate for Adam.") 523 | parser.add_argument("--warmup_proportion", 524 | default=0.1, 525 | type=float, 526 | help="Proportion of training to perform linear learning rate warmup for. " 527 | "E.g., 0.1 = 10%% of training.") 528 | parser.add_argument("--no_cuda", 529 | action='store_true', 530 | help="Whether not to use CUDA when available") 531 | parser.add_argument("--local_rank", 532 | type=int, 533 | default=-1, 534 | help="local_rank for distributed training on gpus") 535 | parser.add_argument('--seed', 536 | type=int, 537 | default=42, 538 | help="random seed for initialization") 539 | parser.add_argument('--gradient_accumulation_steps', 540 | type=int, 541 | default=1, 542 | help="Number of updates steps to accumulate before performing a backward/update pass.") 543 | parser.add_argument('--fp16', 544 | action='store_true', 545 | help="Whether to use 16-bit float precision instead of 32-bit") 546 | parser.add_argument('--loss_scale', 547 | type=float, default=0, 548 | help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" 549 | "0 (default value): dynamic loss scaling.\n" 550 | "Positive power of 2: static loss scaling value.\n") 551 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") 552 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") 553 | args = parser.parse_args() 554 | 555 | processors = { 556 | "rte": RteProcessor 557 | } 558 | 559 | output_modes = { 560 | "rte": "classification" 561 | } 562 | 563 | if args.local_rank == -1 or args.no_cuda: 564 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 565 | n_gpu = torch.cuda.device_count() 566 | else: 567 | torch.cuda.set_device(args.local_rank) 568 | device = torch.device("cuda", args.local_rank) 569 | n_gpu = 1 570 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 571 | torch.distributed.init_process_group(backend='nccl') 572 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 573 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 574 | 575 | if args.gradient_accumulation_steps < 1: 576 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 577 | args.gradient_accumulation_steps)) 578 | 579 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 580 | 581 | random.seed(args.seed) 582 | np.random.seed(args.seed) 583 | torch.manual_seed(args.seed) 584 | if n_gpu > 0: 585 | torch.cuda.manual_seed_all(args.seed) 586 | 587 | 588 | 589 | task_name = args.task_name.lower() 590 | 591 | if task_name not in processors: 592 | raise ValueError("Task not found: %s" % (task_name)) 593 | 594 | processor = processors[task_name]() 595 | output_mode = output_modes[task_name] 596 | 597 | label_list = processor.get_labels() #[0,1] 598 | num_labels = len(label_list) 599 | 600 | 601 | 602 | train_examples = None 603 | 604 | # Prepare model 605 | # cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_TRANSFORMERS_CACHE), 'distributed_{}'.format(args.local_rank)) 606 | # model = BertForSequenceClassification.from_pretrained(args.bert_model, 607 | # cache_dir=cache_dir, 608 | # num_labels=num_labels) 609 | # tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 610 | 611 | pretrain_model_dir = '/export/home/Dataset/fine_tune_Bert_stored/FineTuneOnRTE' #FineTuneOnCombined'# FineTuneOnMNLI 612 | model = BertForSequenceClassification.from_pretrained(pretrain_model_dir, num_labels=num_labels) 613 | tokenizer = BertTokenizer.from_pretrained(pretrain_model_dir) 614 | 615 | model.to(device) 616 | 617 | if n_gpu > 1: 618 | model = torch.nn.DataParallel(model) 619 | 620 | # Prepare optimizer 621 | # param_optimizer = list(model.named_parameters()) 622 | # no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 623 | # optimizer_grouped_parameters = [ 624 | # {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 625 | # {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 626 | # ] 627 | # optimizer = AdamW(optimizer_grouped_parameters, 628 | # lr=args.learning_rate) 629 | global_step = 0 630 | nb_tr_steps = 0 631 | tr_loss = 0 632 | max_test_unseen_acc = 0.0 633 | max_dev_unseen_acc = 0.0 634 | max_dev_seen_acc = 0.0 635 | max_overall_acc = 0.0 636 | '''load test set''' 637 | 638 | 639 | seen_types 
= set() 640 | # test_examples, test_label_list, test_hypo_seen_str_indicator, test_hypo_2_type_index = processor.get_examples_Yahoo_test('/export/home/Dataset/YahooClassification/yahoo_answers_csv/zero-shot-split/test.txt', seen_types) 641 | # test_examples = load_demo_input(premise_str, hypo_list) 642 | # test_examples = load_demo_input('fuck why my email not come yet', ['anger', 'this text expresses anger', 'the guy is very unhappy']) 643 | test_examples = load_demo_input(args.premise_str, args.hypo_list.split(' | ')) 644 | test_features = convert_examples_to_features( 645 | test_examples, label_list, args.max_seq_length, tokenizer, output_mode) 646 | 647 | test_all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long) 648 | test_all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long) 649 | test_all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long) 650 | test_all_label_ids = torch.tensor([f.label_id for f in test_features], dtype=torch.long) 651 | 652 | test_data = TensorDataset(test_all_input_ids, test_all_input_mask, test_all_segment_ids, test_all_label_ids) 653 | test_sampler = SequentialSampler(test_data) 654 | test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.eval_batch_size) 655 | 656 | ''' 657 | start evaluate on test set after this epoch 658 | ''' 659 | model.eval() 660 | 661 | logger.info("***** Running testing *****") 662 | logger.info(" Num examples = %d", len(test_examples)) 663 | logger.info(" Batch size = %d", args.eval_batch_size) 664 | 665 | test_loss = 0 666 | nb_test_steps = 0 667 | preds = [] 668 | # print('Testing...') 669 | for input_ids, input_mask, segment_ids, label_ids in test_dataloader: 670 | input_ids = input_ids.to(device) 671 | input_mask = input_mask.to(device) 672 | segment_ids = segment_ids.to(device) 673 | label_ids = label_ids.to(device) 674 | 675 | with torch.no_grad(): 676 | logits = model(input_ids, segment_ids, input_mask, labels=None) 677 | logits = logits[0] 678 | if len(preds) == 0: 679 | preds.append(logits.detach().cpu().numpy()) 680 | else: 681 | preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0) 682 | 683 | # eval_loss = eval_loss / nb_eval_steps 684 | preds = preds[0] 685 | pred_probs = softmax(preds,axis=1)[:,0] 686 | return max(pred_probs) 687 | if __name__ == "__main__": 688 | prob = main() 689 | print('prob:', prob) 690 | 691 | ''' 692 | CUDA_VISIBLE_DEVICES=7 python -u demo.py --premise_str 'fuck why my email not come yet' --hypo_list 'anger | this text expresses anger | the guy is very unhappy' 693 | ''' 694 | --------------------------------------------------------------------------------
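As written, main() returns only the highest entailment probability over the supplied hypotheses. Because load_demo_input builds exactly one example per hypothesis, in order, and the SequentialSampler preserves that order, pred_probs[j] lines up with the j-th entry of --hypo_list; a small, hypothetical variant of the last lines of main() could therefore also report which hypothesis won:

```python
# hypothetical tail for main(); pred_probs is the array computed just above
hypo_list = args.hypo_list.split(' | ')
best_j = int(np.argmax(pred_probs))
print('best hypothesis: "%s" (entailment prob %.4f)' % (hypo_list[best_j], pred_probs[best_j]))
```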