├── requirements.txt ├── .gitignore ├── README.md └── code ├── run.py ├── utils.py ├── data_iterator.py ├── extractor.py ├── tokenizer.py ├── gensim_preprocess.py └── bert_model.py /requirements.txt: -------------------------------------------------------------------------------- 1 | gensim 2 | h5py 3 | numpy 4 | Pillow 5 | pyrouge 6 | scipy 7 | six 8 | smart-open 9 | torch>=1.0 10 | torchaudio 11 | torchvision 12 | typing-extensions -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # the environments directory 2 | env/ 3 | venv/ 4 | 5 | # cache directories 6 | **/__pycache__/** 7 | 8 | # when running the code a temp directory is created containing the extracted data 9 | temp/ 10 | 11 | # data 12 | data/ 13 | 14 | # models 15 | pacssum_models/ 16 | 17 | # user's created file 18 | scripts.sh 19 | 20 | # user's config file 21 | setup.ini -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PacSum 2 | 3 | This is the code for the ACL 2019 paper [Sentence Centrality Revisited for Unsupervised Summarization](https://arxiv.org/pdf/1906.03508.pdf). 4 | 5 | Some code is borrowed from [pytorch_pretrained_bert](https://github.com/huggingface/pytorch-transformers) and [gensim](https://github.com/RaRe-Technologies/gensim). 6 | 7 | 8 | ------- 9 | ### Dependencies 10 | Python 3.6, PyTorch >= 1.0, numpy, gensim, pyrouge 11 | 12 | 13 | ------- 14 | ### Data used in the paper: 15 | 16 | Download https://drive.google.com/open?id=1gNKWkZG4dVr5XrOeQBVicy1fdnpH2d5l 17 | 18 | ### BERT models fine-tuned using the approach in the paper: 19 | 20 | Download https://drive.google.com/file/d/1wbMlLmnbD_0j7Qs8YY8cSCh935WKKdsP/view?usp=sharing 21 | 22 | 23 | ### Tuning the hyperparameters and testing the performance using the TF-IDF or BERT representation 24 | ``` 25 | python run.py --rep tfidf --mode tune --tune_data_file path/to/validation/data --test_data_file path/to/test/data 26 | ``` 27 | ``` 28 | python run.py --rep bert --mode tune --tune_data_file path/to/validation/data --test_data_file path/to/test/data --bert_model_file path/to/model --bert_config_file path/to/config --bert_vocab_file path/to/vocab 29 | ``` 30 | -------------------------------------------------------------------------------- /code/run.py: -------------------------------------------------------------------------------- 1 | from extractor import PacSumExtractorWithBert, PacSumExtractorWithTfIdf 2 | from data_iterator import Dataset 3 | 4 | import argparse 5 | 6 | if __name__ == '__main__': 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--mode', type=str, choices = ['tune', 'test'], help='tune or test') 10 | parser.add_argument('--rep', type=str, choices = ['tfidf', 'bert'], help='tfidf or bert') 11 | parser.add_argument('--extract_num', type=int, default=3, help='number of extracted sentences') 12 | parser.add_argument('--bert_config_file', type=str, default='/disk/scratch1/s1797858/bert_model_path/uncased_L-12_H-768_A-12/bert_config.json', help='bert configuration file') 13 | parser.add_argument('--bert_model_file', type=str, help='bert model file') 14 | parser.add_argument('--bert_vocab_file', type=str, default='/disk/scratch1/s1797858/bert_model_path/uncased_L-12_H-768_A-12/vocab.txt',help='bert vocabulary file') 15 | 16 | parser.add_argument('--beta', type=float, default=0.,
help='beta') 17 | parser.add_argument('--lambda1', type=float, default=0., help='lambda1') 18 | parser.add_argument('--lambda2', type=float, default=1., help='lambda2') 19 | 20 | parser.add_argument('--tune_data_file', type=str, help='data for tunining hyperparameters') 21 | parser.add_argument('--test_data_file', type=str, help='data for testing') 22 | 23 | 24 | 25 | args = parser.parse_args() 26 | print(args) 27 | 28 | if args.rep == 'tfidf': 29 | extractor = PacSumExtractorWithTfIdf(beta = args.beta, 30 | lambda1=args.lambda1, 31 | lambda2=args.lambda2) 32 | #tune 33 | if args.mode == 'tune': 34 | tune_dataset = Dataset(args.tune_data_file) 35 | tune_dataset_iterator = tune_dataset.iterate_once_doc_tfidf() 36 | extractor.tune_hparams(tune_dataset_iterator) 37 | 38 | #test 39 | test_dataset = Dataset(args.test_data_file) 40 | test_dataset_iterator = test_dataset.iterate_once_doc_tfidf() 41 | extractor.extract_summary(test_dataset_iterator) 42 | 43 | 44 | 45 | elif args.rep == 'bert': 46 | extractor = PacSumExtractorWithBert(bert_model_file = args.bert_model_file, 47 | bert_config_file = args.bert_config_file, 48 | beta = args.beta, 49 | lambda1=args.lambda1, 50 | lambda2=args.lambda2) 51 | #tune 52 | if args.mode == 'tune': 53 | tune_dataset = Dataset(args.tune_data_file, vocab_file = args.bert_vocab_file) 54 | tune_dataset_iterator = tune_dataset.iterate_once_doc_bert() 55 | extractor.tune_hparams(tune_dataset_iterator) 56 | 57 | #test 58 | test_dataset = Dataset(args.test_data_file, vocab_file = args.bert_vocab_file) 59 | test_dataset_iterator = test_dataset.iterate_once_doc_bert() 60 | extractor.extract_summary(test_dataset_iterator) 61 | -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | from pyrouge import Rouge155 2 | import os, shutil, random, string 3 | 4 | from gensim_preprocess import preprocess_documents 5 | 6 | 7 | def evaluate_rouge(summaries, references, remove_temp=False, rouge_args=[]): 8 | ''' 9 | Args: 10 | summaries: [[sentence]]. Each summary is a list of strings (sentences) 11 | references: [[[sentence]]]. Each reference is a list of candidate summaries. 12 | remove_temp: bool. Whether to remove the temporary files created during evaluation. 13 | rouge_args: [string]. A list of arguments to pass to the ROUGE CLI. 
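For reference, the CLI flow in `run.py` above can also be driven as a library call. The sketch below mirrors the `tfidf` test branch only; the data path is a placeholder and the hyperparameter values are the `run.py` defaults, so treat it as a minimal usage sketch rather than the canonical entry point.

```python
# Minimal sketch of the tfidf test path from run.py, invoked programmatically.
# 'path/to/test/data' is a placeholder for an HDF5 file in the expected format.
from data_iterator import Dataset
from extractor import PacSumExtractorWithTfIdf

extractor = PacSumExtractorWithTfIdf(beta=0.0, lambda1=0.0, lambda2=1.0)

test_dataset = Dataset('path/to/test/data')            # file path or glob pattern
test_iterator = test_dataset.iterate_once_doc_tfidf()  # yields (article, abstract, [segmented_article])
extractor.extract_summary(test_iterator)               # prints ROUGE output via pyrouge
```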
14 | ''' 15 | temp_dir = ''.join(random.choices(string.ascii_uppercase + string.digits, k=10)) 16 | temp_dir = os.path.join("temp",temp_dir) 17 | print(temp_dir) 18 | system_dir = os.path.join(temp_dir, 'system') 19 | model_dir = os.path.join(temp_dir, 'model') 20 | # directory for generated summaries 21 | os.makedirs(system_dir) 22 | # directory for reference summaries 23 | os.makedirs(model_dir) 24 | print(temp_dir, system_dir, model_dir) 25 | 26 | assert len(summaries) == len(references) 27 | for i, (summary, candidates) in enumerate(zip(summaries, references)): 28 | summary_fn = '%i.txt' % i 29 | for j, candidate in enumerate(candidates): 30 | candidate_fn = '%i.%i.txt' % (i, j) 31 | with open(os.path.join(model_dir, candidate_fn), 'w') as f: 32 | #print(candidate) 33 | f.write('\n'.join(candidate)) 34 | 35 | with open(os.path.join(system_dir, summary_fn), 'w') as f: 36 | f.write('\n'.join(summary)) 37 | 38 | args_str = ' '.join(map(str, rouge_args)) 39 | rouge = Rouge155(rouge_args=args_str) 40 | rouge.system_dir = system_dir 41 | rouge.model_dir = model_dir 42 | rouge.system_filename_pattern = '(\d+).txt' 43 | rouge.model_filename_pattern = '#ID#.\d+.txt' 44 | 45 | #rouge_args = '-c 95 -2 -1 -U -r 1000 -n 4 -w 1.2 -a' 46 | #output = rouge.convert_and_evaluate(rouge_args=rouge_args) 47 | output = rouge.convert_and_evaluate() 48 | 49 | r = rouge.output_to_dict(output) 50 | print(output) 51 | #print(r) 52 | 53 | # remove the created temporary files 54 | #if remove_temp: 55 | # shutil.rmtree(temp_dir) 56 | return r 57 | 58 | def clean_text_by_sentences(text): 59 | """Tokenize a given text into sentences, applying filters and lemmatize them. 60 | 61 | Parameters 62 | ---------- 63 | text : str 64 | Given text. 65 | 66 | Returns 67 | ------- 68 | list of :class:`~gensim.summarization.syntactic_unit.SyntacticUnit` 69 | Sentences of the given text. 70 | 71 | """ 72 | original_sentences = text 73 | filtered_sentences = [join_words(sentence) for sentence in preprocess_documents(original_sentences)] 74 | 75 | return filtered_sentences 76 | 77 | 78 | def join_words(words, separator=" "): 79 | """Concatenates `words` with `separator` between elements. 80 | 81 | Parameters 82 | ---------- 83 | words : list of str 84 | Given words. 85 | separator : str, optional 86 | The separator between elements. 87 | 88 | Returns 89 | ------- 90 | str 91 | String of merged words with separator between elements. 
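To make the nesting that `evaluate_rouge` expects concrete (a summary is a list of sentence strings; each reference entry is a list of candidate summaries), a minimal call might look like the sketch below. The sentences are invented, and a working pyrouge/ROUGE installation is assumed.

```python
# Toy call illustrating the summaries/references nesting used by evaluate_rouge.
from utils import evaluate_rouge

summaries = [
    ["the cat sat on the mat .", "it was sunny ."],                   # system summary for document 0
]
references = [
    [["a cat was sitting on a mat .", "the weather was sunny ."]],    # one candidate reference for document 0
]

scores = evaluate_rouge(summaries, references)
print(scores["rouge_1_f_score"])
```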
92 | 93 | """ 94 | return separator.join(words) 95 | -------------------------------------------------------------------------------- /code/data_iterator.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import glob 3 | import json 4 | import random 5 | import math 6 | 7 | import numpy as np 8 | import torch 9 | import h5py 10 | 11 | from tokenizer import FullTokenizer 12 | from utils import clean_text_by_sentences 13 | 14 | 15 | class Dataset(object): 16 | 17 | def __init__(self, file_pattern = None, vocab_file = None): 18 | 19 | self._file_pattern = file_pattern 20 | self._max_len = 60 21 | print("input max len : "+str(self._max_len)) 22 | if vocab_file is not None: 23 | self._tokenizer = FullTokenizer(vocab_file, True) 24 | 25 | def iterate_once_doc_tfidf(self): 26 | def file_stream(): 27 | for file_name in glob.glob(self._file_pattern): 28 | yield file_name 29 | for value in self._doc_stream_tfidf(file_stream()): 30 | yield value 31 | 32 | def _doc_stream_tfidf(self, file_stream): 33 | for file_name in file_stream: 34 | for doc in self._parse_file2doc_tfidf(file_name): 35 | yield doc 36 | 37 | def _parse_file2doc_tfidf(self, file_name): 38 | print("Processing file: %s" % file_name) 39 | with h5py.File(file_name,'r') as f: 40 | for j_str in f['dataset']: 41 | obj = json.loads(j_str) 42 | article, abstract = obj['article'], obj['abstract'] 43 | #article, abstract = obj['article'], obj['abstracts'] 44 | clean_article = clean_text_by_sentences(article) 45 | segmented_artile = [sentence.split() for sentence in clean_article] 46 | #print(tokenized_article[0]) 47 | 48 | yield article, abstract, [segmented_artile] 49 | 50 | 51 | def iterate_once_doc_bert(self): 52 | def file_stream(): 53 | for file_name in glob.glob(self._file_pattern): 54 | yield file_name 55 | for value in self._doc_iterate_bert(self._doc_stream_bert(file_stream())): 56 | yield value 57 | 58 | def _doc_stream_bert(self, file_stream): 59 | for file_name in file_stream: 60 | for doc in self._parse_file2doc_bert(file_name): 61 | yield doc 62 | 63 | def _parse_file2doc_bert(self, file_name): 64 | print("Processing file: %s" % file_name) 65 | with h5py.File(file_name,'r') as f: 66 | for j_str in f['dataset']: 67 | obj = json.loads(j_str) 68 | article, abstract = obj['article'], obj['abstract'] 69 | #article, abstract = obj['article'], obj['abstracts'] 70 | tokenized_article = [self._tokenizer.tokenize(sen) for sen in article] 71 | #print(tokenized_article[0]) 72 | 73 | article_token_ids = [] 74 | article_seg_ids = [] 75 | article_token_ids_c = [] 76 | article_seg_ids_c = [] 77 | pair_indice = [] 78 | k = 0 79 | for i in range(len(article)): 80 | for j in range(i+1, len(article)): 81 | 82 | tokens_a = tokenized_article[i] 83 | tokens_b = tokenized_article[j] 84 | 85 | input_ids, segment_ids = self._2bert_rep(tokens_a) 86 | input_ids_c, segment_ids_c = self._2bert_rep(tokens_b) 87 | assert len(input_ids) == len(segment_ids) 88 | assert len(input_ids_c) == len(segment_ids_c) 89 | article_token_ids.append(input_ids) 90 | article_seg_ids.append(segment_ids) 91 | article_token_ids_c.append(input_ids_c) 92 | article_seg_ids_c.append(segment_ids_c) 93 | 94 | pair_indice.append(((i,j), k)) 95 | k+=1 96 | yield article_token_ids, article_seg_ids, article_token_ids_c, article_seg_ids_c, pair_indice, article, abstract 97 | 98 | def _doc_iterate_bert(self, docs): 99 | 100 | for article_token_ids, article_seg_ids, article_token_ids_c, article_seg_ids_c, pair_indice, article, abstract in 
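The loaders in `data_iterator.py` read an HDF5 file whose `dataset` entry holds JSON strings with `article` and `abstract` fields, each a list of sentence strings (field names taken from `_parse_file2doc_tfidf`). A hedged sketch of building such a file for testing; the documents and file name are invented, and exact string-dtype handling can vary across h5py versions:

```python
# Build a tiny HDF5 file in the format _parse_file2doc_tfidf / _parse_file2doc_bert read.
import json
import h5py

docs = [
    {
        "article": ["the cat sat on the mat .", "it was a sunny day .", "the dog barked ."],
        "abstract": ["a cat sat on a mat on a sunny day ."],
    },
]

with h5py.File("toy_data.h5", "w") as f:
    dt = h5py.string_dtype(encoding="utf-8")
    f.create_dataset("dataset", data=[json.dumps(d) for d in docs], dtype=dt)
```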
docs: 101 | 102 | if len(article_token_ids) == 0: 103 | yield None, None, None, None, None, None, pair_indice, article, abstract 104 | continue 105 | num_steps = max(len(item) for item in article_token_ids) 106 | #num_steps = max(len(item) for item in iarticle) 107 | batch_size = len(article_token_ids) 108 | x = np.zeros([batch_size, num_steps], np.int32) 109 | t = np.zeros([batch_size, num_steps], np.int32) 110 | w = np.zeros([batch_size, num_steps], np.uint8) 111 | 112 | num_steps_c = max(len(item) for item in article_token_ids_c) 113 | #num_steps = max(len(item) for item in iarticle) 114 | x_c = np.zeros([batch_size, num_steps_c], np.int32) 115 | t_c = np.zeros([batch_size, num_steps_c], np.int32) 116 | w_c = np.zeros([batch_size, num_steps_c], np.uint8) 117 | for i in range(batch_size): 118 | num_tokens = len(article_token_ids[i]) 119 | x[i,:num_tokens] = article_token_ids[i] 120 | t[i,:num_tokens] = article_seg_ids[i] 121 | w[i,:num_tokens] = 1 122 | 123 | num_tokens_c = len(article_token_ids_c[i]) 124 | x_c[i,:num_tokens_c] = article_token_ids_c[i] 125 | t_c[i,:num_tokens_c] = article_seg_ids_c[i] 126 | w_c[i,:num_tokens_c] = 1 127 | 128 | if not np.any(w): 129 | return 130 | out_x = torch.LongTensor(x) 131 | out_t = torch.LongTensor(t) 132 | out_w = torch.LongTensor(w) 133 | 134 | out_x_c = torch.LongTensor(x_c) 135 | out_t_c = torch.LongTensor(t_c) 136 | out_w_c = torch.LongTensor(w_c) 137 | 138 | yield article, abstract, (out_x, out_t, out_w, out_x_c, out_t_c, out_w_c, pair_indice) 139 | 140 | def _2bert_rep(self, tokens_a, tokens_b=None): 141 | 142 | if tokens_b is None: 143 | tokens_a = tokens_a[: self._max_len - 2] 144 | else: 145 | self._truncate_seq_pair(tokens_a, tokens_b, self._max_len - 3) 146 | 147 | tokens = [] 148 | segment_ids = [] 149 | tokens.append("[CLS]") 150 | segment_ids.append(0) 151 | for token in tokens_a: 152 | tokens.append(token) 153 | segment_ids.append(0) 154 | tokens.append("[SEP]") 155 | segment_ids.append(0) 156 | 157 | if tokens_b is not None: 158 | 159 | for token in tokens_b: 160 | tokens.append(token) 161 | segment_ids.append(1) 162 | 163 | tokens.append("[SEP]") 164 | segment_ids.append(1) 165 | #print(tokens) 166 | input_ids = self._tokenizer.convert_tokens_to_ids(tokens) 167 | #print(input_ids) 168 | 169 | return input_ids, segment_ids 170 | 171 | def _truncate_seq_pair(self, tokens_a, tokens_b, max_length): 172 | """Truncates a sequence pair in place to the maximum length.""" 173 | 174 | # This is a simple heuristic which will always truncate the longer sequence 175 | # one token at a time. This makes more sense than truncating an equal percent 176 | # of tokens from each, since if one sequence is very short then each token 177 | # that's truncated likely contains more information than a longer sequence. 
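To make the pair construction and padding above concrete: `_2bert_rep` returns WordPiece ids for `[CLS] tokens [SEP]` with all-zero segment ids (a second segment only appears when `tokens_b` is given), and `_doc_iterate_bert` right-pads every pair in a document to the longest sequence. A small hand-worked illustration; the id values come from an invented vocabulary, not the real BERT vocab:

```python
# Illustration only: ids are hypothetical, padding follows _doc_iterate_bert.
tokens      = ["[CLS]", "the", "cat", "sat", "[SEP]"]
segment_ids = [0, 0, 0, 0, 0]                 # single-sentence input -> all segment 0
input_ids   = [101, 1996, 4937, 2938, 102]    # hypothetical convert_tokens_to_ids output

num_steps = 7                                 # longest sequence in the document
x = input_ids + [0] * (num_steps - len(input_ids))               # zero-padded token ids
w = [1] * len(input_ids) + [0] * (num_steps - len(input_ids))    # attention mask
```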
178 | while True: 179 | total_length = len(tokens_a) + len(tokens_b) 180 | if total_length <= max_length: 181 | break 182 | if len(tokens_a) > len(tokens_b): 183 | tokens_a.pop() 184 | else: 185 | tokens_b.pop() 186 | -------------------------------------------------------------------------------- /code/extractor.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import numpy as np 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import random 7 | import time 8 | import io 9 | import codecs 10 | 11 | 12 | 13 | from utils import evaluate_rouge 14 | from bert_model import BertEdgeScorer, BertConfig 15 | 16 | class PacSumExtractor: 17 | 18 | def __init__(self, extract_num = 3, beta = 3, lambda1 = -0.2, lambda2 = -0.2): 19 | 20 | self.extract_num = extract_num 21 | self.beta = beta 22 | self.lambda1 = lambda1 23 | self.lambda2 = lambda2 24 | 25 | def extract_summary(self, data_iterator): 26 | 27 | summaries = [] 28 | references = [] 29 | 30 | for item in data_iterator: 31 | article, abstract, inputs = item 32 | if len(article) <= self.extract_num: 33 | summaries.append(article) 34 | references.append([abstract]) 35 | continue 36 | 37 | edge_scores = self._calculate_similarity_matrix(*inputs) 38 | ids = self._select_tops(edge_scores, beta=self.beta, lambda1=self.lambda1, lambda2=self.lambda2) 39 | summary = list(map(lambda x: article[x], ids)) 40 | summaries.append(summary) 41 | references.append([abstract]) 42 | 43 | result = evaluate_rouge(summaries, references, remove_temp=True, rouge_args=[]) 44 | 45 | def tune_hparams(self, data_iterator, example_num=1000): 46 | 47 | 48 | summaries, references = [], [] 49 | k = 0 50 | for item in data_iterator: 51 | article, abstract, inputs = item 52 | edge_scores = self._calculate_similarity_matrix(*inputs) 53 | tops_list, hparam_list = self._tune_extractor(edge_scores) 54 | 55 | summary_list = [list(map(lambda x: article[x], ids)) for ids in tops_list] 56 | summaries.append(summary_list) 57 | references.append([abstract]) 58 | k += 1 59 | print(k) 60 | if k % example_num == 0: 61 | break 62 | 63 | best_rouge = 0 64 | best_hparam = None 65 | for i in range(len(summaries[0])): 66 | print("threshold : "+str(hparam_list[i])+'\n') 67 | #print("non-lead ratio : "+str(ratios[i])+'\n') 68 | result = evaluate_rouge([summaries[k][i] for k in range(len(summaries))], references, remove_temp=True, rouge_args=[]) 69 | 70 | if result['rouge_1_f_score'] > best_rouge: 71 | best_rouge = result['rouge_1_f_score'] 72 | best_hparam = hparam_list[i] 73 | 74 | print("The best hyper-parameter : beta %.4f , lambda1 %.4f, lambda2 %.4f " % (best_hparam[0], best_hparam[1], best_hparam[2])) 75 | print("The best rouge_1_f_score : %.4f " % best_rouge) 76 | 77 | self.beta = best_hparam[0] 78 | self.lambda1 = best_hparam[1] 79 | self.lambda2 = best_hparam[2] 80 | 81 | def _calculate_similarity_matrix(self, *inputs): 82 | 83 | raise NotImplementedError 84 | 85 | 86 | def _select_tops(self, edge_scores, beta, lambda1, lambda2): 87 | 88 | min_score = edge_scores.min() 89 | max_score = edge_scores.max() 90 | edge_threshold = min_score + beta * (max_score - min_score) 91 | new_edge_scores = edge_scores - edge_threshold 92 | forward_scores, backward_scores, _ = self._compute_scores(new_edge_scores, 0) 93 | forward_scores = 0 - forward_scores 94 | 95 | paired_scores = [] 96 | for node in range(len(forward_scores)): 97 | paired_scores.append([node, lambda1 * forward_scores[node] + lambda2 * 
backward_scores[node]]) 98 | 99 | #shuffle to avoid any possible bias 100 | random.shuffle(paired_scores) 101 | paired_scores.sort(key = lambda x: x[1], reverse = True) 102 | extracted = [item[0] for item in paired_scores[:self.extract_num]] 103 | 104 | 105 | return extracted 106 | 107 | def _compute_scores(self, similarity_matrix, edge_threshold): 108 | 109 | forward_scores = [0 for i in range(len(similarity_matrix))] 110 | backward_scores = [0 for i in range(len(similarity_matrix))] 111 | edges = [] 112 | for i in range(len(similarity_matrix)): 113 | for j in range(i+1, len(similarity_matrix[i])): 114 | edge_score = similarity_matrix[i][j] 115 | if edge_score > edge_threshold: 116 | forward_scores[j] += edge_score 117 | backward_scores[i] += edge_score 118 | edges.append((i,j,edge_score)) 119 | 120 | return np.asarray(forward_scores), np.asarray(backward_scores), edges 121 | 122 | 123 | def _tune_extractor(self, edge_scores): 124 | 125 | tops_list = [] 126 | hparam_list = [] 127 | num = 10 128 | for k in range(num + 1): 129 | beta = k / num 130 | for i in range(11): 131 | lambda1 = i/10 132 | lambda2 = 1 - lambda1 133 | extracted = self._select_tops(edge_scores, beta=beta, lambda1=lambda1, lambda2=lambda2) 134 | 135 | tops_list.append(extracted) 136 | hparam_list.append((beta, lambda1, lambda2)) 137 | 138 | return tops_list, hparam_list 139 | 140 | 141 | class PacSumExtractorWithBert(PacSumExtractor): 142 | 143 | def __init__(self, bert_model_file, bert_config_file, extract_num = 3, beta = 3, lambda1 = -0.2, lambda2 = -0.2): 144 | 145 | super(PacSumExtractorWithBert, self).__init__(extract_num, beta, lambda1, lambda2) 146 | self.model = self._load_edge_model(bert_model_file, bert_config_file) 147 | 148 | def _calculate_similarity_matrix(self, x, t, w, x_c, t_c, w_c, pair_indice): 149 | #doc: a list of sequences, each sequence is a list of words 150 | 151 | def pairdown(scores, pair_indice, length): 152 | #1 for self score 153 | out_matrix = np.ones((length, length)) 154 | for pair in pair_indice: 155 | out_matrix[pair[0][0]][pair[0][1]] = scores[pair[1]] 156 | out_matrix[pair[0][1]][pair[0][0]] = scores[pair[1]] 157 | 158 | return out_matrix 159 | 160 | scores = self._generate_score(x, t, w, x_c, t_c, w_c) 161 | doc_len = int(math.sqrt(len(x)*2)) + 1 162 | similarity_matrix = pairdown(scores, pair_indice, doc_len) 163 | 164 | return similarity_matrix 165 | 166 | def _generate_score(self, x, t, w, x_c, t_c, w_c): 167 | 168 | #score = log PMI -log k 169 | scores = torch.zeros(len(x)).cuda() 170 | step = 20 171 | for i in range(0,len(x),step): 172 | 173 | batch_x = x[i:i+step] 174 | batch_t = t[i:i+step] 175 | batch_w = w[i:i+step] 176 | batch_x_c = x_c[i:i+step] 177 | batch_t_c = t_c[i:i+step] 178 | batch_w_c = w_c[i:i+step] 179 | 180 | inputs = tuple(t.to('cuda') for t in (batch_x, batch_t, batch_w, batch_x_c, batch_t_c, batch_w_c)) 181 | batch_scores, batch_pros = self.model(*inputs) 182 | scores[i:i+step] = batch_scores.detach() 183 | 184 | 185 | return scores 186 | 187 | def _load_edge_model(self, bert_model_file, bert_config_file): 188 | 189 | bert_config = BertConfig.from_json_file(bert_config_file) 190 | model = BertEdgeScorer(bert_config) 191 | model_states = torch.load(bert_model_file) 192 | print(model_states.keys()) 193 | model.bert.load_state_dict(model_states) 194 | 195 | model.cuda() 196 | model.eval() 197 | return model 198 | 199 | 200 | class PacSumExtractorWithTfIdf(PacSumExtractor): 201 | 202 | def __init__(self, extract_num = 3, beta = 3, lambda1 = -0.2, lambda2 = -0.2): 
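The selection logic above boils down to: shift the similarity matrix by the threshold `min + beta * (max - min)`, accumulate the positive shifted edges into a forward score (mass arriving from earlier sentences) and a backward score (mass sent to later sentences), and rank sentences by `lambda1 * (-forward) + lambda2 * backward`. A toy recomputation on an invented 3x3 similarity matrix, independent of the class:

```python
# Toy recomputation of PacSum's centrality scores (values are made up).
import numpy as np

sim = np.array([[1.0, 0.6, 0.2],
                [0.6, 1.0, 0.4],
                [0.2, 0.4, 1.0]])
beta, lambda1, lambda2 = 0.0, 0.0, 1.0        # run.py defaults

thr = sim.min() + beta * (sim.max() - sim.min())   # threshold relative to the score range
shifted = sim - thr

forward = np.zeros(3)     # similarity mass arriving from earlier sentences
backward = np.zeros(3)    # similarity mass sent to later sentences
for i in range(3):
    for j in range(i + 1, 3):
        if shifted[i, j] > 0:
            forward[j] += shifted[i, j]
            backward[i] += shifted[i, j]

scores = lambda1 * (-forward) + lambda2 * backward
print(np.argsort(-scores))   # with lambda1=0, lambda2=1 this favors sentences that are
                             # similar to what follows them, i.e. a lead bias
```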
203 | 204 | super(PacSumExtractorWithTfIdf, self).__init__(extract_num, beta, lambda1, lambda2) 205 | 206 | def _calculate_similarity_matrix(self, doc): 207 | 208 | idf_score = self._calculate_idf_scores(doc) 209 | 210 | tf_scores = [ 211 | Counter(sentence) for sentence in doc 212 | ] 213 | length = len(doc) 214 | 215 | similarity_matrix = np.zeros([length] * 2) 216 | 217 | for i in range(length): 218 | for j in range(i, length): 219 | similarity = self._idf_modified_dot(tf_scores, i, j, idf_score) 220 | 221 | if similarity: 222 | similarity_matrix[i, j] = similarity 223 | similarity_matrix[j, i] = similarity 224 | 225 | return similarity_matrix 226 | 227 | 228 | def _idf_modified_dot(self, tf_scores, i, j, idf_score): 229 | 230 | if i == j: 231 | return 1 232 | 233 | tf_i, tf_j = tf_scores[i], tf_scores[j] 234 | words_i, words_j = set(tf_i.keys()), set(tf_j.keys()) 235 | 236 | score = 0 237 | 238 | for word in words_i & words_j: 239 | idf = idf_score[word] 240 | score += tf_i[word] * tf_j[word] * idf ** 2 241 | 242 | 243 | return score 244 | 245 | 246 | def _calculate_idf_scores(self, doc): 247 | 248 | doc_number_total = 0. 249 | df = {} 250 | for i, sen in enumerate(doc): 251 | tf = Counter(sen) 252 | for word in tf.keys(): 253 | if word not in df: 254 | df[word] = 0 255 | df[word] += 1 256 | doc_number_total += 1 257 | 258 | idf_score = {} 259 | for word, freq in df.items(): 260 | idf_score[word] = math.log(doc_number_total - freq + 0.5) - math.log(freq + 0.5) 261 | 262 | return idf_score 263 | -------------------------------------------------------------------------------- /code/tokenizer.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import collections 3 | import unicodedata 4 | import six 5 | 6 | 7 | def convert_to_unicode(text): 8 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 9 | if six.PY3: 10 | if isinstance(text, str): 11 | return text 12 | elif isinstance(text, bytes): 13 | return text.decode("utf-8", "ignore") 14 | else: 15 | raise ValueError("Unsupported string type: %s" % (type(text))) 16 | elif six.PY2: 17 | if isinstance(text, str): 18 | return text.decode("utf-8", "ignore") 19 | elif isinstance(text, unicode): 20 | return text 21 | else: 22 | raise ValueError("Unsupported string type: %s" % (type(text))) 23 | else: 24 | raise ValueError("Not running on Python2 or Python 3?") 25 | 26 | 27 | def printable_text(text): 28 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 29 | 30 | # These functions want `str` for both Python2 and Python3, but in one case 31 | # it's a Unicode string and in the other it's a byte string. 
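For intuition, the TF-IDF edge weight used by `PacSumExtractorWithTfIdf` above is just the sum over shared words of `tf_i(w) * tf_j(w) * idf(w)^2`, with a smoothed IDF computed per document from sentence-level frequencies. A toy recomputation outside the class (the sentences are invented):

```python
# Toy recomputation of the tf-idf edge weight from _idf_modified_dot.
import math
from collections import Counter

doc = [["cat", "sat", "mat"], ["cat", "slept"], ["dog", "barked"]]

# Smoothed IDF over sentences, as in _calculate_idf_scores.
n = len(doc)
df = Counter(w for sen in doc for w in set(sen))
idf = {w: math.log(n - f + 0.5) - math.log(f + 0.5) for w, f in df.items()}

tf = [Counter(sen) for sen in doc]

def edge(i, j):
    # sum over shared words of tf_i * tf_j * idf^2
    return sum(tf[i][w] * tf[j][w] * idf[w] ** 2 for w in tf[i].keys() & tf[j].keys())

print(edge(0, 1))   # only "cat" is shared, so this equals idf["cat"] ** 2
print(edge(0, 2))   # no shared words -> 0.0
```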
32 | if six.PY3: 33 | if isinstance(text, str): 34 | return text 35 | elif isinstance(text, bytes): 36 | return text.decode("utf-8", "ignore") 37 | else: 38 | raise ValueError("Unsupported string type: %s" % (type(text))) 39 | elif six.PY2: 40 | if isinstance(text, str): 41 | return text 42 | elif isinstance(text, unicode): 43 | return text.encode("utf-8") 44 | else: 45 | raise ValueError("Unsupported string type: %s" % (type(text))) 46 | else: 47 | raise ValueError("Not running on Python2 or Python 3?") 48 | 49 | 50 | def load_vocab(vocab_file): 51 | """Loads a vocabulary file into a dictionary.""" 52 | vocab = collections.OrderedDict() 53 | index = 0 54 | with open(vocab_file, "r") as reader: 55 | #with codecs.open(vocab_file, "r", "utf-8") as reader: 56 | while True: 57 | line = reader.readline() 58 | #print(line) 59 | token = convert_to_unicode(line) 60 | if not token: 61 | break 62 | token = token.strip() 63 | vocab[token] = index 64 | index += 1 65 | #print(index) 66 | #print(str(vocab)) 67 | return vocab 68 | 69 | 70 | def convert_tokens_to_ids(vocab, tokens): 71 | """Converts a sequence of tokens into ids using the vocab.""" 72 | ids = [] 73 | for token in tokens: 74 | ids.append(vocab[token]) 75 | return ids 76 | 77 | 78 | def whitespace_tokenize(text): 79 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 80 | text = text.strip() 81 | if not text: 82 | return [] 83 | tokens = text.split() 84 | return tokens 85 | 86 | 87 | class FullTokenizer(object): 88 | """Runs end-to-end tokenziation.""" 89 | 90 | def __init__(self, vocab_file, do_lower_case=True): 91 | self.vocab = load_vocab(vocab_file) 92 | #print(self.vocab) 93 | #print(len(self.vocab)) 94 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 95 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 96 | 97 | def tokenize(self, text): 98 | split_tokens = [] 99 | for token in self.basic_tokenizer.tokenize(text): 100 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 101 | split_tokens.append(sub_token) 102 | 103 | return split_tokens 104 | 105 | def convert_tokens_to_ids(self, tokens): 106 | return convert_tokens_to_ids(self.vocab, tokens) 107 | 108 | 109 | class BasicTokenizer(object): 110 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 111 | 112 | def __init__(self, do_lower_case=True): 113 | """Constructs a BasicTokenizer. 114 | Args: 115 | do_lower_case: Whether to lower case the input. 
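A quick way to see how `load_vocab`, WordPiece tokenization and `convert_tokens_to_ids` fit together is to run `FullTokenizer` over a tiny throwaway vocabulary. The vocabulary below is invented and exists only to keep the example self-contained.

```python
# Self-contained FullTokenizer round trip with a tiny invented vocabulary.
import tempfile
from tokenizer import FullTokenizer

vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "un", "##aff", "##able", "the", "cat"]
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("\n".join(vocab) + "\n")
    vocab_file = f.name

tokenizer = FullTokenizer(vocab_file, do_lower_case=True)
tokens = tokenizer.tokenize("The unaffable cat")
print(tokens)                                    # ['the', 'un', '##aff', '##able', 'cat']
print(tokenizer.convert_tokens_to_ids(tokens))   # [7, 4, 5, 6, 8]
```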
116 | """ 117 | self.do_lower_case = do_lower_case 118 | 119 | def tokenize(self, text): 120 | """Tokenizes a piece of text.""" 121 | text = convert_to_unicode(text) 122 | text = self._clean_text(text) 123 | 124 | text = self._tokenize_chinese_chars(text) 125 | 126 | orig_tokens = whitespace_tokenize(text) 127 | split_tokens = [] 128 | for token in orig_tokens: 129 | if self.do_lower_case: 130 | token = token.lower() 131 | token = self._run_strip_accents(token) 132 | split_tokens.extend(self._run_split_on_punc(token)) 133 | 134 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 135 | return output_tokens 136 | 137 | def _run_strip_accents(self, text): 138 | """Strips accents from a piece of text.""" 139 | text = unicodedata.normalize("NFD", text) 140 | output = [] 141 | for char in text: 142 | cat = unicodedata.category(char) 143 | if cat == "Mn": 144 | continue 145 | output.append(char) 146 | return "".join(output) 147 | 148 | def _run_split_on_punc(self, text): 149 | """Splits punctuation on a piece of text.""" 150 | chars = list(text) 151 | i = 0 152 | start_new_word = True 153 | output = [] 154 | while i < len(chars): 155 | char = chars[i] 156 | if _is_punctuation(char): 157 | output.append([char]) 158 | start_new_word = True 159 | else: 160 | if start_new_word: 161 | output.append([]) 162 | start_new_word = False 163 | output[-1].append(char) 164 | i += 1 165 | 166 | return ["".join(x) for x in output] 167 | 168 | def _tokenize_chinese_chars(self, text): 169 | """Adds whitespace around any CJK character.""" 170 | output = [] 171 | for char in text: 172 | cp = ord(char) 173 | if self._is_chinese_char(cp): 174 | output.append(" ") 175 | output.append(char) 176 | output.append(" ") 177 | else: 178 | output.append(char) 179 | return "".join(output) 180 | 181 | def _is_chinese_char(self, cp): 182 | """Checks whether CP is the codepoint of a CJK character.""" 183 | # This defines a "chinese character" as anything in the CJK Unicode block: 184 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 185 | # 186 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 187 | # despite its name. The modern Korean Hangul alphabet is a different block, 188 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 189 | # space-separated words, so they are not treated specially and handled 190 | # like the all of the other languages. 
191 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 192 | (cp >= 0x3400 and cp <= 0x4DBF) or # 193 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 194 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 195 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 196 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 197 | (cp >= 0xF900 and cp <= 0xFAFF) or # 198 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 199 | return True 200 | 201 | return False 202 | 203 | 204 | def _clean_text(self, text): 205 | """Performs invalid character removal and whitespace cleanup on text.""" 206 | output = [] 207 | for char in text: 208 | cp = ord(char) 209 | if cp == 0 or cp == 0xfffd or _is_control(char): 210 | continue 211 | if _is_whitespace(char): 212 | output.append(" ") 213 | else: 214 | output.append(char) 215 | return "".join(output) 216 | 217 | 218 | class WordpieceTokenizer(object): 219 | """Runs WordPiece tokenization.""" 220 | 221 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 222 | self.vocab = vocab 223 | self.unk_token = unk_token 224 | self.max_input_chars_per_word = max_input_chars_per_word 225 | 226 | def tokenize(self, text): 227 | """Tokenizes a piece of text into its word pieces. 228 | This uses a greedy longest-match-first algorithm to perform tokenization 229 | using the given vocabulary. 230 | For example: 231 | input = "unaffable" 232 | output = ["un", "##aff", "##able"] 233 | Args: 234 | text: A single token or whitespace separated tokens. This should have 235 | already been passed through `BasicTokenizer. 236 | Returns: 237 | A list of wordpiece tokens. 238 | """ 239 | 240 | text = convert_to_unicode(text) 241 | 242 | output_tokens = [] 243 | for token in whitespace_tokenize(text): 244 | chars = list(token) 245 | if len(chars) > self.max_input_chars_per_word: 246 | output_tokens.append(self.unk_token) 247 | continue 248 | 249 | is_bad = False 250 | start = 0 251 | sub_tokens = [] 252 | while start < len(chars): 253 | end = len(chars) 254 | cur_substr = None 255 | while start < end: 256 | substr = "".join(chars[start:end]) 257 | if start > 0: 258 | substr = "##" + substr 259 | if substr in self.vocab: 260 | cur_substr = substr 261 | break 262 | end -= 1 263 | if cur_substr is None: 264 | is_bad = True 265 | break 266 | sub_tokens.append(cur_substr) 267 | start = end 268 | 269 | if is_bad: 270 | output_tokens.append(self.unk_token) 271 | else: 272 | output_tokens.extend(sub_tokens) 273 | return output_tokens 274 | 275 | 276 | def _is_whitespace(char): 277 | """Checks whether `chars` is a whitespace character.""" 278 | # \t, \n, and \r are technically contorl characters but we treat them 279 | # as whitespace since they are generally considered as such. 280 | if char == " " or char == "\t" or char == "\n" or char == "\r": 281 | return True 282 | cat = unicodedata.category(char) 283 | if cat == "Zs": 284 | return True 285 | return False 286 | 287 | 288 | def _is_control(char): 289 | """Checks whether `chars` is a control character.""" 290 | # These are technically control characters but we count them as whitespace 291 | # characters. 292 | if char == "\t" or char == "\n" or char == "\r": 293 | return False 294 | cat = unicodedata.category(char) 295 | if cat.startswith("C"): 296 | return True 297 | return False 298 | 299 | 300 | def _is_punctuation(char): 301 | """Checks whether `chars` is a punctuation character.""" 302 | cp = ord(char) 303 | # We treat all non-letter/number ASCII as punctuation. 
304 | # Characters such as "^", "$", and "`" are not in the Unicode 305 | # Punctuation class but we treat them as punctuation anyways, for 306 | # consistency. 307 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 308 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 309 | return True 310 | cat = unicodedata.category(char) 311 | if cat.startswith("P"): 312 | return True 313 | 314 | return False 315 | -------------------------------------------------------------------------------- /code/gensim_preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html 5 | 6 | """This module contains methods for parsing and preprocessing strings. Let's consider the most noticeable: 7 | 8 | * :func:`~gensim.parsing.preprocessing.remove_stopwords` - remove all stopwords from string 9 | * :func:`~gensim.parsing.preprocessing.preprocess_string` - preprocess string (in default NLP meaning) 10 | 11 | Examples: 12 | --------- 13 | .. sourcecode:: pycon 14 | 15 | >>> from gensim.parsing.preprocessing import remove_stopwords 16 | >>> remove_stopwords("Better late than never, but better never late.") 17 | u'Better late never, better late.' 18 | >>> 19 | >>> preprocess_string("Hel 9lo Wo9 rld! Th3 weather_is really g00d today, isn't it?") 20 | [u'hel', u'rld', u'weather', u'todai', u'isn'] 21 | 22 | 23 | Data: 24 | ----- 25 | 26 | .. data:: STOPWORDS - Set of stopwords from Stone, Denis, Kwantes (2010). 27 | .. data:: RE_PUNCT - Regexp for search an punctuation. 28 | .. data:: RE_TAGS - Regexp for search an tags. 29 | .. data:: RE_NUMERIC - Regexp for search an numbers. 30 | .. data:: RE_NONALPHA - Regexp for search an non-alphabetic character. 31 | .. data:: RE_AL_NUM - Regexp for search a position between letters and digits. 32 | .. data:: RE_NUM_AL - Regexp for search a position between digits and letters . 33 | .. data:: RE_WHITESPACE - Regexp for search space characters. 34 | .. data:: DEFAULT_FILTERS - List of function for string preprocessing. 
35 | 36 | """ 37 | 38 | import re 39 | import string 40 | import glob 41 | 42 | from gensim import utils 43 | from gensim.parsing.porter import PorterStemmer 44 | 45 | 46 | STOPWORDS = frozenset([ 47 | 'all', 'six', 'just', 'less', 'being', 'indeed', 'over', 'move', 'anyway', 'four', 'not', 'own', 'through', 48 | 'using', 'fifty', 'where', 'mill', 'only', 'find', 'before', 'one', 'whose', 'system', 'how', 'somewhere', 49 | 'much', 'thick', 'show', 'had', 'enough', 'should', 'to', 'must', 'whom', 'seeming', 'yourselves', 'under', 50 | 'ours', 'two', 'has', 'might', 'thereafter', 'latterly', 'do', 'them', 'his', 'around', 'than', 'get', 'very', 51 | 'de', 'none', 'cannot', 'every', 'un', 'they', 'front', 'during', 'thus', 'now', 'him', 'nor', 'name', 'regarding', 52 | 'several', 'hereafter', 'did', 'always', 'who', 'didn', 'whither', 'this', 'someone', 'either', 'each', 'become', 53 | 'thereupon', 'sometime', 'side', 'towards', 'therein', 'twelve', 'because', 'often', 'ten', 'our', 'doing', 'km', 54 | 'eg', 'some', 'back', 'used', 'up', 'go', 'namely', 'computer', 'are', 'further', 'beyond', 'ourselves', 'yet', 55 | 'out', 'even', 'will', 'what', 'still', 'for', 'bottom', 'mine', 'since', 'please', 'forty', 'per', 'its', 56 | 'everything', 'behind', 'does', 'various', 'above', 'between', 'it', 'neither', 'seemed', 'ever', 'across', 'she', 57 | 'somehow', 'be', 'we', 'full', 'never', 'sixty', 'however', 'here', 'otherwise', 'were', 'whereupon', 'nowhere', 58 | 'although', 'found', 'alone', 're', 'along', 'quite', 'fifteen', 'by', 'both', 'about', 'last', 'would', 59 | 'anything', 'via', 'many', 'could', 'thence', 'put', 'against', 'keep', 'etc', 'amount', 'became', 'ltd', 'hence', 60 | 'onto', 'or', 'con', 'among', 'already', 'co', 'afterwards', 'formerly', 'within', 'seems', 'into', 'others', 61 | 'while', 'whatever', 'except', 'down', 'hers', 'everyone', 'done', 'least', 'another', 'whoever', 'moreover', 62 | 'couldnt', 'throughout', 'anyhow', 'yourself', 'three', 'from', 'her', 'few', 'together', 'top', 'there', 'due', 63 | 'been', 'next', 'anyone', 'eleven', 'cry', 'call', 'therefore', 'interest', 'then', 'thru', 'themselves', 64 | 'hundred', 'really', 'sincere', 'empty', 'more', 'himself', 'elsewhere', 'mostly', 'on', 'fire', 'am', 'becoming', 65 | 'hereby', 'amongst', 'else', 'part', 'everywhere', 'too', 'kg', 'herself', 'former', 'those', 'he', 'me', 'myself', 66 | 'made', 'twenty', 'these', 'was', 'bill', 'cant', 'us', 'until', 'besides', 'nevertheless', 'below', 'anywhere', 67 | 'nine', 'can', 'whether', 'of', 'your', 'toward', 'my', 'say', 'something', 'and', 'whereafter', 'whenever', 68 | 'give', 'almost', 'wherever', 'is', 'describe', 'beforehand', 'herein', 'doesn', 'an', 'as', 'itself', 'at', 69 | 'have', 'in', 'seem', 'whence', 'ie', 'any', 'fill', 'again', 'hasnt', 'inc', 'thereby', 'thin', 'no', 'perhaps', 70 | 'latter', 'meanwhile', 'when', 'detail', 'same', 'wherein', 'beside', 'also', 'that', 'other', 'take', 'which', 71 | 'becomes', 'you', 'if', 'nobody', 'unless', 'whereas', 'see', 'though', 'may', 'after', 'upon', 'most', 'hereupon', 72 | 'eight', 'but', 'serious', 'nothing', 'such', 'why', 'off', 'a', 'don', 'whereby', 'third', 'i', 'whole', 'noone', 73 | 'sometimes', 'well', 'amoungst', 'yours', 'their', 'rather', 'without', 'so', 'five', 'the', 'first', 'with', 74 | 'make', 'once' 75 | ]) 76 | 77 | 78 | RE_PUNCT = re.compile(r'([%s])+' % re.escape(string.punctuation), re.UNICODE) 79 | RE_TAGS = re.compile(r"<([^>]+)>", re.UNICODE) 80 | RE_NUMERIC = 
re.compile(r"[0-9]+", re.UNICODE) 81 | RE_NONALPHA = re.compile(r"\W", re.UNICODE) 82 | RE_AL_NUM = re.compile(r"([a-z]+)([0-9]+)", flags=re.UNICODE) 83 | RE_NUM_AL = re.compile(r"([0-9]+)([a-z]+)", flags=re.UNICODE) 84 | RE_WHITESPACE = re.compile(r"(\s)+", re.UNICODE) 85 | 86 | 87 | def remove_stopwords(s): 88 | """Remove :const:`~gensim.parsing.preprocessing.STOPWORDS` from `s`. 89 | 90 | Parameters 91 | ---------- 92 | s : str 93 | 94 | Returns 95 | ------- 96 | str 97 | Unicode string without :const:`~gensim.parsing.preprocessing.STOPWORDS`. 98 | 99 | Examples 100 | -------- 101 | .. sourcecode:: pycon 102 | 103 | >>> from gensim.parsing.preprocessing import remove_stopwords 104 | >>> remove_stopwords("Better late than never, but better never late.") 105 | u'Better late never, better late.' 106 | 107 | """ 108 | s = utils.to_unicode(s) 109 | return " ".join(w for w in s.split() if w not in STOPWORDS) 110 | 111 | 112 | def strip_punctuation(s): 113 | """Replace punctuation characters with spaces in `s` using :const:`~gensim.parsing.preprocessing.RE_PUNCT`. 114 | 115 | Parameters 116 | ---------- 117 | s : str 118 | 119 | Returns 120 | ------- 121 | str 122 | Unicode string without punctuation characters. 123 | 124 | Examples 125 | -------- 126 | .. sourcecode:: pycon 127 | 128 | >>> from gensim.parsing.preprocessing import strip_punctuation 129 | >>> strip_punctuation("A semicolon is a stronger break than a comma, but not as much as a full stop!") 130 | u'A semicolon is a stronger break than a comma but not as much as a full stop ' 131 | 132 | """ 133 | s = utils.to_unicode(s) 134 | return RE_PUNCT.sub(" ", s) 135 | 136 | 137 | strip_punctuation2 = strip_punctuation 138 | 139 | 140 | def strip_tags(s): 141 | """Remove tags from `s` using :const:`~gensim.parsing.preprocessing.RE_TAGS`. 142 | 143 | Parameters 144 | ---------- 145 | s : str 146 | 147 | Returns 148 | ------- 149 | str 150 | Unicode string without tags. 151 | 152 | Examples 153 | -------- 154 | .. sourcecode:: pycon 155 | 156 | >>> from gensim.parsing.preprocessing import strip_tags 157 | >>> strip_tags("Hello World!") 158 | u'Hello World!' 159 | 160 | """ 161 | s = utils.to_unicode(s) 162 | return RE_TAGS.sub("", s) 163 | 164 | 165 | def strip_short(s, minsize=3): 166 | """Remove words with length lesser than `minsize` from `s`. 167 | 168 | Parameters 169 | ---------- 170 | s : str 171 | minsize : int, optional 172 | 173 | Returns 174 | ------- 175 | str 176 | Unicode string without short words. 177 | 178 | Examples 179 | -------- 180 | .. sourcecode:: pycon 181 | 182 | >>> from gensim.parsing.preprocessing import strip_short 183 | >>> strip_short("salut les amis du 59") 184 | u'salut les amis' 185 | >>> 186 | >>> strip_short("one two three four five six seven eight nine ten", minsize=5) 187 | u'three seven eight' 188 | 189 | """ 190 | s = utils.to_unicode(s) 191 | return " ".join(e for e in s.split() if len(e) >= minsize) 192 | 193 | 194 | def strip_numeric(s): 195 | """Remove digits from `s` using :const:`~gensim.parsing.preprocessing.RE_NUMERIC`. 196 | 197 | Parameters 198 | ---------- 199 | s : str 200 | 201 | Returns 202 | ------- 203 | str 204 | Unicode string without digits. 205 | 206 | Examples 207 | -------- 208 | .. 
sourcecode:: pycon 209 | 210 | >>> from gensim.parsing.preprocessing import strip_numeric 211 | >>> strip_numeric("0text24gensim365test") 212 | u'textgensimtest' 213 | 214 | """ 215 | s = utils.to_unicode(s) 216 | return RE_NUMERIC.sub("", s) 217 | 218 | 219 | def strip_non_alphanum(s): 220 | """Remove non-alphabetic characters from `s` using :const:`~gensim.parsing.preprocessing.RE_NONALPHA`. 221 | 222 | Parameters 223 | ---------- 224 | s : str 225 | 226 | Returns 227 | ------- 228 | str 229 | Unicode string with alphabetic characters only. 230 | 231 | Notes 232 | ----- 233 | Word characters - alphanumeric & underscore. 234 | 235 | Examples 236 | -------- 237 | .. sourcecode:: pycon 238 | 239 | >>> from gensim.parsing.preprocessing import strip_non_alphanum 240 | >>> strip_non_alphanum("if-you#can%read$this&then@this#method^works") 241 | u'if you can read this then this method works' 242 | 243 | """ 244 | s = utils.to_unicode(s) 245 | return RE_NONALPHA.sub(" ", s) 246 | 247 | 248 | def strip_multiple_whitespaces(s): 249 | r"""Remove repeating whitespace characters (spaces, tabs, line breaks) from `s` 250 | and turns tabs & line breaks into spaces using :const:`~gensim.parsing.preprocessing.RE_WHITESPACE`. 251 | 252 | Parameters 253 | ---------- 254 | s : str 255 | 256 | Returns 257 | ------- 258 | str 259 | Unicode string without repeating in a row whitespace characters. 260 | 261 | Examples 262 | -------- 263 | .. sourcecode:: pycon 264 | 265 | >>> from gensim.parsing.preprocessing import strip_multiple_whitespaces 266 | >>> strip_multiple_whitespaces("salut" + '\r' + " les" + '\n' + " loulous!") 267 | u'salut les loulous!' 268 | 269 | """ 270 | s = utils.to_unicode(s) 271 | return RE_WHITESPACE.sub(" ", s) 272 | 273 | 274 | def split_alphanum(s): 275 | """Add spaces between digits & letters in `s` using :const:`~gensim.parsing.preprocessing.RE_AL_NUM`. 276 | 277 | Parameters 278 | ---------- 279 | s : str 280 | 281 | Returns 282 | ------- 283 | str 284 | Unicode string with spaces between digits & letters. 285 | 286 | Examples 287 | -------- 288 | .. sourcecode:: pycon 289 | 290 | >>> from gensim.parsing.preprocessing import split_alphanum 291 | >>> split_alphanum("24.0hours7 days365 a1b2c3") 292 | u'24.0 hours 7 days 365 a 1 b 2 c 3' 293 | 294 | """ 295 | s = utils.to_unicode(s) 296 | s = RE_AL_NUM.sub(r"\1 \2", s) 297 | return RE_NUM_AL.sub(r"\1 \2", s) 298 | 299 | 300 | def stem_text(text): 301 | """Transform `s` into lowercase and stem it. 302 | 303 | Parameters 304 | ---------- 305 | text : str 306 | 307 | Returns 308 | ------- 309 | str 310 | Unicode lowercased and porter-stemmed version of string `text`. 311 | 312 | Examples 313 | -------- 314 | .. sourcecode:: pycon 315 | 316 | >>> from gensim.parsing.preprocessing import stem_text 317 | >>> stem_text("While it is quite useful to be able to search a large collection of documents almost instantly.") 318 | u'while it is quit us to be abl to search a larg collect of document almost instantly.' 319 | 320 | """ 321 | text = utils.to_unicode(text) 322 | p = PorterStemmer() 323 | return ' '.join(p.stem(word) for word in text.split()) 324 | 325 | 326 | stem = stem_text 327 | 328 | 329 | DEFAULT_FILTERS = [ 330 | lambda x: x.lower(), strip_tags, strip_punctuation, 331 | strip_multiple_whitespaces, strip_numeric, 332 | remove_stopwords, strip_short, stem_text 333 | ] 334 | 335 | 336 | def preprocess_string(s, filters=DEFAULT_FILTERS): 337 | """Apply list of chosen filters to `s`. 
338 | 339 | Default list of filters: 340 | 341 | * :func:`~gensim.parsing.preprocessing.strip_tags`, 342 | * :func:`~gensim.parsing.preprocessing.strip_punctuation`, 343 | * :func:`~gensim.parsing.preprocessing.strip_multiple_whitespaces`, 344 | * :func:`~gensim.parsing.preprocessing.strip_numeric`, 345 | * :func:`~gensim.parsing.preprocessing.remove_stopwords`, 346 | * :func:`~gensim.parsing.preprocessing.strip_short`, 347 | * :func:`~gensim.parsing.preprocessing.stem_text`. 348 | 349 | Parameters 350 | ---------- 351 | s : str 352 | filters: list of functions, optional 353 | 354 | Returns 355 | ------- 356 | list of str 357 | Processed strings (cleaned). 358 | 359 | Examples 360 | -------- 361 | .. sourcecode:: pycon 362 | 363 | >>> from gensim.parsing.preprocessing import preprocess_string 364 | >>> preprocess_string("Hel 9lo Wo9 rld! Th3 weather_is really g00d today, isn't it?") 365 | [u'hel', u'rld', u'weather', u'todai', u'isn'] 366 | >>> 367 | >>> s = "Hel 9lo Wo9 rld! Th3 weather_is really g00d today, isn't it?" 368 | >>> CUSTOM_FILTERS = [lambda x: x.lower(), strip_tags, strip_punctuation] 369 | >>> preprocess_string(s, CUSTOM_FILTERS) 370 | [u'hel', u'9lo', u'wo9', u'rld', u'th3', u'weather', u'is', u'really', u'g00d', u'today', u'isn', u't', u'it'] 371 | 372 | """ 373 | s = utils.to_unicode(s) 374 | for f in filters: 375 | s = f(s) 376 | return s.split() 377 | 378 | 379 | def preprocess_documents(docs): 380 | """Apply :const:`~gensim.parsing.preprocessing.DEFAULT_FILTERS` to the documents strings. 381 | 382 | Parameters 383 | ---------- 384 | docs : list of str 385 | 386 | Returns 387 | ------- 388 | list of list of str 389 | Processed documents split by whitespace. 390 | 391 | Examples 392 | -------- 393 | 394 | .. sourcecode:: pycon 395 | 396 | >>> from gensim.parsing.preprocessing import preprocess_documents 397 | >>> preprocess_documents(["Hel 9lo Wo9 rld!", "Th3 weather_is really g00d today, isn't it?"]) 398 | [[u'hel', u'rld'], [u'weather', u'todai', u'isn']] 399 | 400 | """ 401 | return [preprocess_string(d) for d in docs] 402 | 403 | 404 | def read_file(path): 405 | with utils.smart_open(path) as fin: 406 | return fin.read() 407 | 408 | 409 | def read_files(pattern): 410 | return [read_file(fname) for fname in glob.glob(pattern)] 411 | -------------------------------------------------------------------------------- /code/bert_model.py: -------------------------------------------------------------------------------- 1 | """PyTorch BERT model.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import copy 8 | import json 9 | import math 10 | import six 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn import CrossEntropyLoss 14 | 15 | def gelu(x): 16 | """Implementation of the gelu activation function. 17 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 18 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 19 | """ 20 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 21 | 22 | 23 | class BertConfig(object): 24 | """Configuration class to store the configuration of a `BertModel`. 
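The `gelu` defined above is the exact erf form; the docstring's note about OpenAI GPT refers to the tanh approximation. A quick numerical comparison, purely illustrative:

```python
# Compare the exact erf-based gelu with the tanh approximation from the docstring.
import math
import torch

def gelu_exact(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

x = torch.linspace(-3, 3, steps=7)
print((gelu_exact(x) - gelu_tanh(x)).abs().max())   # small but nonzero difference
```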
25 | """ 26 | def __init__(self, 27 | vocab_size, 28 | hidden_size=768, 29 | num_hidden_layers=12, 30 | num_attention_heads=12, 31 | intermediate_size=3072, 32 | hidden_act="gelu", 33 | hidden_dropout_prob=0.1, 34 | attention_probs_dropout_prob=0.1, 35 | max_position_embeddings=512, 36 | type_vocab_size=16, 37 | initializer_range=0.02): 38 | """Constructs BertConfig. 39 | Args: 40 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 41 | hidden_size: Size of the encoder layers and the pooler layer. 42 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 43 | num_attention_heads: Number of attention heads for each attention layer in 44 | the Transformer encoder. 45 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 46 | layer in the Transformer encoder. 47 | hidden_act: The non-linear activation function (function or string) in the 48 | encoder and pooler. 49 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 50 | layers in the embeddings, encoder, and pooler. 51 | attention_probs_dropout_prob: The dropout ratio for the attention 52 | probabilities. 53 | max_position_embeddings: The maximum sequence length that this model might 54 | ever be used with. Typically set this to something large just in case 55 | (e.g., 512 or 1024 or 2048). 56 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 57 | `BertModel`. 58 | initializer_range: The sttdev of the truncated_normal_initializer for 59 | initializing all weight matrices. 60 | """ 61 | self.vocab_size = vocab_size 62 | self.hidden_size = hidden_size 63 | self.num_hidden_layers = num_hidden_layers 64 | self.num_attention_heads = num_attention_heads 65 | self.hidden_act = hidden_act 66 | self.intermediate_size = intermediate_size 67 | self.hidden_dropout_prob = hidden_dropout_prob 68 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 69 | self.max_position_embeddings = max_position_embeddings 70 | self.type_vocab_size = type_vocab_size 71 | self.initializer_range = initializer_range 72 | 73 | @classmethod 74 | def from_dict(cls, json_object): 75 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 76 | config = BertConfig(vocab_size=None) 77 | for (key, value) in six.iteritems(json_object): 78 | config.__dict__[key] = value 79 | return config 80 | 81 | @classmethod 82 | def from_json_file(cls, json_file): 83 | """Constructs a `BertConfig` from a json file of parameters.""" 84 | with open(json_file, "r") as reader: 85 | text = reader.read() 86 | return cls.from_dict(json.loads(text)) 87 | 88 | def to_dict(self): 89 | """Serializes this instance to a Python dictionary.""" 90 | output = copy.deepcopy(self.__dict__) 91 | return output 92 | 93 | def to_json_string(self): 94 | """Serializes this instance to a JSON string.""" 95 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 96 | 97 | 98 | class BERTLayerNorm(nn.Module): 99 | def __init__(self, config, variance_epsilon=1e-12): 100 | """Construct a layernorm module in the TF style (epsilon inside the square root). 
101 | """ 102 | super(BERTLayerNorm, self).__init__() 103 | self.gamma = nn.Parameter(torch.ones(config.hidden_size)) 104 | self.beta = nn.Parameter(torch.zeros(config.hidden_size)) 105 | self.variance_epsilon = variance_epsilon 106 | 107 | def forward(self, x): 108 | u = x.mean(-1, keepdim=True) 109 | s = (x - u).pow(2).mean(-1, keepdim=True) 110 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 111 | return self.gamma * x + self.beta 112 | 113 | class BERTEmbeddings(nn.Module): 114 | def __init__(self, config): 115 | super(BERTEmbeddings, self).__init__() 116 | """Construct the embedding module from word, position and token_type embeddings. 117 | """ 118 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 119 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 120 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 121 | 122 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 123 | # any TensorFlow checkpoint file 124 | self.LayerNorm = BERTLayerNorm(config) 125 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 126 | 127 | def forward(self, input_ids, token_type_ids=None): 128 | seq_length = input_ids.size(1) 129 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 130 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 131 | if token_type_ids is None: 132 | token_type_ids = torch.zeros_like(input_ids) 133 | 134 | words_embeddings = self.word_embeddings(input_ids) 135 | position_embeddings = self.position_embeddings(position_ids) 136 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 137 | 138 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 139 | embeddings = self.LayerNorm(embeddings) 140 | embeddings = self.dropout(embeddings) 141 | return embeddings 142 | 143 | 144 | class BERTSelfAttention(nn.Module): 145 | def __init__(self, config): 146 | super(BERTSelfAttention, self).__init__() 147 | if config.hidden_size % config.num_attention_heads != 0: 148 | raise ValueError( 149 | "The hidden size (%d) is not a multiple of the number of attention " 150 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 151 | self.num_attention_heads = config.num_attention_heads 152 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 153 | self.all_head_size = self.num_attention_heads * self.attention_head_size 154 | 155 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 156 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 157 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 158 | 159 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 160 | 161 | def transpose_for_scores(self, x): 162 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 163 | x = x.view(*new_x_shape) 164 | return x.permute(0, 2, 1, 3) 165 | 166 | def forward(self, hidden_states, attention_mask): 167 | mixed_query_layer = self.query(hidden_states) 168 | mixed_key_layer = self.key(hidden_states) 169 | mixed_value_layer = self.value(hidden_states) 170 | 171 | query_layer = self.transpose_for_scores(mixed_query_layer) 172 | key_layer = self.transpose_for_scores(mixed_key_layer) 173 | value_layer = self.transpose_for_scores(mixed_value_layer) 174 | 175 | # Take the dot product between "query" and "key" to get the raw attention scores. 
176 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 177 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 178 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 179 | attention_scores = attention_scores + attention_mask 180 | 181 | # Normalize the attention scores to probabilities. 182 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 183 | 184 | # This is actually dropping out entire tokens to attend to, which might 185 | # seem a bit unusual, but is taken from the original Transformer paper. 186 | attention_probs = self.dropout(attention_probs) 187 | 188 | context_layer = torch.matmul(attention_probs, value_layer) 189 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 190 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 191 | context_layer = context_layer.view(*new_context_layer_shape) 192 | return context_layer 193 | 194 | 195 | class BERTSelfOutput(nn.Module): 196 | def __init__(self, config): 197 | super(BERTSelfOutput, self).__init__() 198 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 199 | self.LayerNorm = BERTLayerNorm(config) 200 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 201 | 202 | def forward(self, hidden_states, input_tensor): 203 | hidden_states = self.dense(hidden_states) 204 | hidden_states = self.dropout(hidden_states) 205 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 206 | return hidden_states 207 | 208 | 209 | class BERTAttention(nn.Module): 210 | def __init__(self, config): 211 | super(BERTAttention, self).__init__() 212 | self.self = BERTSelfAttention(config) 213 | self.output = BERTSelfOutput(config) 214 | 215 | def forward(self, input_tensor, attention_mask): 216 | self_output = self.self(input_tensor, attention_mask) 217 | attention_output = self.output(self_output, input_tensor) 218 | return attention_output 219 | 220 | 221 | class BERTIntermediate(nn.Module): 222 | def __init__(self, config): 223 | super(BERTIntermediate, self).__init__() 224 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 225 | self.intermediate_act_fn = gelu 226 | 227 | def forward(self, hidden_states): 228 | return self.intermediate_act_fn(self.dense(hidden_states)) 229 | 230 | class BERTOutput(nn.Module): 231 | def __init__(self, config): 232 | super(BERTOutput, self).__init__() 233 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 234 | self.LayerNorm = BERTLayerNorm(config) 235 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 236 | 237 | def forward(self, hidden_states, input_tensor): 238 | hidden_states = self.dense(hidden_states) 239 | hidden_states = self.dropout(hidden_states) 240 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 241 | return hidden_states 242 | 243 | 244 | class BERTLayer(nn.Module): 245 | def __init__(self, config): 246 | super(BERTLayer, self).__init__() 247 | self.attention = BERTAttention(config) 248 | self.intermediate = BERTIntermediate(config) 249 | self.output = BERTOutput(config) 250 | 251 | def forward(self, hidden_states, attention_mask): 252 | attention_output = self.attention(hidden_states, attention_mask) 253 | intermediate_output = self.intermediate(attention_output) 254 | layer_output = self.output(intermediate_output, attention_output) 255 | return layer_output 256 | 257 | 258 | class BERTEncoder(nn.Module): 259 | def __init__(self, config): 260 | super(BERTEncoder, self).__init__() 261 | layer = 
BERTLayer(config) 262 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 263 | 264 | def forward(self, hidden_states, attention_mask): 265 | for layer_module in self.layer: 266 | hidden_states = layer_module(hidden_states, attention_mask) 267 | return hidden_states 268 | 269 | 270 | class BERTPooler(nn.Module): 271 | def __init__(self, config): 272 | super(BERTPooler, self).__init__() 273 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 274 | self.activation = nn.Tanh() 275 | 276 | def forward(self, hidden_states): 277 | # We "pool" the model by simply taking the hidden state corresponding 278 | # to the first token. 279 | first_token_tensor = hidden_states[:, 0] 280 | pooled_output = self.dense(first_token_tensor) 281 | pooled_output = self.activation(pooled_output) 282 | return pooled_output 283 | 284 | 285 | class BertModel(nn.Module): 286 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 287 | Example usage: 288 | ```python 289 | # Already been converted into WordPiece token ids 290 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 291 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 292 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 293 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 294 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 295 | model = modeling.BertModel(config=config) 296 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 297 | ``` 298 | """ 299 | def __init__(self, config: BertConfig): 300 | """Constructor for BertModel. 301 | Args: 302 | config: `BertConfig` instance. 303 | """ 304 | super(BertModel, self).__init__() 305 | self.embeddings = BERTEmbeddings(config) 306 | self.encoder = BERTEncoder(config) 307 | self.pooler = BERTPooler(config) 308 | 309 | def forward(self, input_ids, token_type_ids=None, attention_mask=None): 310 | if attention_mask is None: 311 | attention_mask = torch.ones_like(input_ids) 312 | if token_type_ids is None: 313 | token_type_ids = torch.zeros_like(input_ids) 314 | 315 | # We create a 3D attention mask from a 2D tensor mask. 316 | # Sizes are [batch_size, 1, 1, from_seq_length] 317 | # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length] 318 | # this attention mask is more simple than the triangular masking of causal attention 319 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 320 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 321 | 322 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 323 | # masked positions, this operation will create a tensor which is 0.0 for 324 | # positions we want to attend and -10000.0 for masked positions. 325 | # Since we are adding it to the raw scores before the softmax, this is 326 | # effectively the same as removing these entirely. 327 | extended_attention_mask = extended_attention_mask.float() 328 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 329 | 330 | embedding_output = self.embeddings(input_ids, token_type_ids) 331 | hidden_states = self.encoder(embedding_output, extended_attention_mask) 332 | pooled_output = self.pooler(hidden_states) 333 | return pooled_output 334 | 335 | 336 | class BertEdgeScorer(nn.Module): 337 | """BERT model for computing sentence similarity and scoring edges. 
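The attention-mask trick in `BertModel.forward` above turns a 0/1 padding mask into additive biases that are broadcast over heads and query positions. A one-line check of the arithmetic, with an invented mask:

```python
# 1 -> 0.0 (attend), 0 -> -10000.0 (masked); shape [B, 1, 1, T] broadcasts over heads.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
extended = attention_mask.unsqueeze(1).unsqueeze(2).float()
extended = (1.0 - extended) * -10000.0
print(extended.shape, extended)   # torch.Size([1, 1, 1, 5]), zeros then -10000.0
```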
338 | 339 | """ 340 | def __init__(self, config): 341 | super(BertEdgeScorer, self).__init__() 342 | self.bert = BertModel(config) 343 | 344 | def forward(self, input_ids, token_type_ids, attention_mask, input_ids_c, token_type_ids_c, attention_mask_c, labels=None): 345 | #nputs_id: B * T 346 | 347 | pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 348 | pooled_output_c = self.bert(input_ids_c, token_type_ids_c, attention_mask_c) 349 | logits = torch.bmm(pooled_output.unsqueeze(1), pooled_output_c.unsqueeze(2)).view(-1) 350 | pros = torch.sigmoid(logits) 351 | 352 | return logits, pros 353 | --------------------------------------------------------------------------------
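To see how `BertEdgeScorer` scores a sentence pair (dot product of the two pooled `[CLS]` vectors, followed by a sigmoid), here is a minimal CPU smoke test. All sizes are invented to keep it fast; the real pipeline uses bert-base dimensions and the fine-tuned weights loaded in `extractor.py`.

```python
# Hedged smoke test: random ids through BertEdgeScorer with a tiny config.
import torch
from bert_model import BertConfig, BertEdgeScorer

config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=64,
                    max_position_embeddings=64, type_vocab_size=2)
model = BertEdgeScorer(config).eval()

B, T = 3, 10                               # 3 sentence pairs, 10 tokens per side
ids   = torch.randint(0, 100, (B, T))      # first sentence of each pair
ids_c = torch.randint(0, 100, (B, T))      # second sentence of each pair
segs  = torch.zeros(B, T, dtype=torch.long)
mask  = torch.ones(B, T, dtype=torch.long)

with torch.no_grad():
    logits, probs = model(ids, segs, mask, ids_c, segs, mask)
print(logits.shape, probs.shape)           # torch.Size([3]) torch.Size([3])
```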