├── requirements.txt
├── .gitignore
├── README.md
└── code
├── run.py
├── utils.py
├── data_iterator.py
├── extractor.py
├── tokenizer.py
├── gensim_preprocess.py
└── bert_model.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | gensim
2 | h5py
3 | numpy
4 | Pillow
5 | pyrouge
6 | scipy
7 | six
8 | smart-open
9 | torch>=1.0
10 | torchaudio
11 | torchvision
12 | typing-extensions
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # virtual environment directories
2 | env/
3 | venv/
4 |
5 | # cache directories
6 | **/__pycache__/**
7 |
8 | # temp directory created at run time to hold the extracted summaries for ROUGE evaluation
9 | temp/
10 |
11 | # data
12 | data/
13 |
14 | # models
15 | pacssum_models/
16 |
17 | # user-created script
18 | scripts.sh
19 |
20 | # user's config file
21 | setup.ini
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PacSum
2 |
3 | This code accompanies the paper [Sentence Centrality Revisited for Unsupervised Summarization](https://arxiv.org/pdf/1906.03508.pdf) (ACL 2019).
4 |
5 | Some code is borrowed from [pytorch_pretrained_bert](https://github.com/huggingface/pytorch-transformers) and [gensim](https://github.com/RaRe-Technologies/gensim).
6 |
7 |
8 | -------
9 | ### Dependencies
10 | Python 3.6, PyTorch >= 1.0, numpy, gensim, pyrouge
11 |
12 |
13 | -------
14 | ### Data used in the paper:
15 |
16 | Download https://drive.google.com/open?id=1gNKWkZG4dVr5XrOeQBVicy1fdnpH2d5l
17 |
18 | ### BERT models fine-tuned using the approach in the paper:
19 |
20 | Download https://drive.google.com/file/d/1wbMlLmnbD_0j7Qs8YY8cSCh935WKKdsP/view?usp=sharing
21 |
22 |
23 | ### Tuning the hyperparameters and testing the performance using the TF-IDF or BERT representation
24 | ```
25 | python run.py --rep tfidf --mode tune --tune_data_file path/to/validation/data --test_data_file path/to/test/data
26 | ```
27 | ```
28 | python run.py --rep bert --mode tune --tune_data_file path/to/validation/data --test_data_file path/to/test/data --bert_model_file path/to/model --bert_config_file path/to/config --bert_vocab_file path/to/vocab
29 | ```
30 |
--------------------------------------------------------------------------------
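
The expected layout of each data file, as inferred from `code/data_iterator.py`, is an HDF5 file containing a dataset named `dataset` whose entries are JSON strings with `article` and `abstract` keys, each a list of sentence strings. A minimal sketch for building a compatible toy file follows; the filename `toy.h5` and the example sentences are placeholders.

```
# Build a toy HDF5 file in the layout Dataset expects: one dataset named
# 'dataset', each entry a JSON string with 'article' and 'abstract' lists
# of sentence strings.
import json
import h5py

docs = [
    {"article": ["The cat sat on the mat.", "It purred loudly.", "Then it fell asleep."],
     "abstract": ["A cat sat, purred and slept."]},
]

str_dt = h5py.special_dtype(vlen=str)             # variable-length string dtype
with h5py.File("toy.h5", "w") as f:
    ds = f.create_dataset("dataset", (len(docs),), dtype=str_dt)
    for i, doc in enumerate(docs):
        ds[i] = json.dumps(doc)
```
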
/code/run.py:
--------------------------------------------------------------------------------
1 | from extractor import PacSumExtractorWithBert, PacSumExtractorWithTfIdf
2 | from data_iterator import Dataset
3 |
4 | import argparse
5 |
6 | if __name__ == '__main__':
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--mode', type=str, choices = ['tune', 'test'], help='tune or test')
10 | parser.add_argument('--rep', type=str, choices = ['tfidf', 'bert'], help='tfidf or bert')
11 | parser.add_argument('--extract_num', type=int, default=3, help='number of extracted sentences')
12 | parser.add_argument('--bert_config_file', type=str, default='/disk/scratch1/s1797858/bert_model_path/uncased_L-12_H-768_A-12/bert_config.json', help='bert configuration file')
13 | parser.add_argument('--bert_model_file', type=str, help='bert model file')
14 | parser.add_argument('--bert_vocab_file', type=str, default='/disk/scratch1/s1797858/bert_model_path/uncased_L-12_H-768_A-12/vocab.txt',help='bert vocabulary file')
15 |
16 | parser.add_argument('--beta', type=float, default=0., help='beta')
17 | parser.add_argument('--lambda1', type=float, default=0., help='lambda1')
18 | parser.add_argument('--lambda2', type=float, default=1., help='lambda2')
19 |
20 |     parser.add_argument('--tune_data_file', type=str, help='data for tuning hyperparameters')
21 | parser.add_argument('--test_data_file', type=str, help='data for testing')
22 |
23 |
24 |
25 | args = parser.parse_args()
26 | print(args)
27 |
28 | if args.rep == 'tfidf':
29 | extractor = PacSumExtractorWithTfIdf(beta = args.beta,
30 | lambda1=args.lambda1,
31 | lambda2=args.lambda2)
32 | #tune
33 | if args.mode == 'tune':
34 | tune_dataset = Dataset(args.tune_data_file)
35 | tune_dataset_iterator = tune_dataset.iterate_once_doc_tfidf()
36 | extractor.tune_hparams(tune_dataset_iterator)
37 |
38 | #test
39 | test_dataset = Dataset(args.test_data_file)
40 | test_dataset_iterator = test_dataset.iterate_once_doc_tfidf()
41 | extractor.extract_summary(test_dataset_iterator)
42 |
43 |
44 |
45 | elif args.rep == 'bert':
46 | extractor = PacSumExtractorWithBert(bert_model_file = args.bert_model_file,
47 | bert_config_file = args.bert_config_file,
48 | beta = args.beta,
49 | lambda1=args.lambda1,
50 | lambda2=args.lambda2)
51 | #tune
52 | if args.mode == 'tune':
53 | tune_dataset = Dataset(args.tune_data_file, vocab_file = args.bert_vocab_file)
54 | tune_dataset_iterator = tune_dataset.iterate_once_doc_bert()
55 | extractor.tune_hparams(tune_dataset_iterator)
56 |
57 | #test
58 | test_dataset = Dataset(args.test_data_file, vocab_file = args.bert_vocab_file)
59 | test_dataset_iterator = test_dataset.iterate_once_doc_bert()
60 | extractor.extract_summary(test_dataset_iterator)
61 |
--------------------------------------------------------------------------------
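
For reference, a minimal sketch of what `run.py` does for `--rep tfidf --mode test`; the hyperparameter values and the data path below are placeholders.

```
# Programmatic equivalent of `--rep tfidf --mode test`; hyperparameters and
# the data path are placeholders.
from extractor import PacSumExtractorWithTfIdf
from data_iterator import Dataset

extractor = PacSumExtractorWithTfIdf(extract_num=3, beta=0.3, lambda1=0.0, lambda2=1.0)
test_iterator = Dataset('path/to/test/data').iterate_once_doc_tfidf()
extractor.extract_summary(test_iterator)   # prints ROUGE scores via utils.evaluate_rouge
```
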
/code/utils.py:
--------------------------------------------------------------------------------
1 | from pyrouge import Rouge155
2 | import os, shutil, random, string
3 |
4 | from gensim_preprocess import preprocess_documents
5 |
6 |
7 | def evaluate_rouge(summaries, references, remove_temp=False, rouge_args=[]):
8 | '''
9 | Args:
10 | summaries: [[sentence]]. Each summary is a list of strings (sentences)
11 | references: [[[sentence]]]. Each reference is a list of candidate summaries.
12 | remove_temp: bool. Whether to remove the temporary files created during evaluation.
13 | rouge_args: [string]. A list of arguments to pass to the ROUGE CLI.
14 | '''
15 | temp_dir = ''.join(random.choices(string.ascii_uppercase + string.digits, k=10))
16 | temp_dir = os.path.join("temp",temp_dir)
17 | print(temp_dir)
18 | system_dir = os.path.join(temp_dir, 'system')
19 | model_dir = os.path.join(temp_dir, 'model')
20 | # directory for generated summaries
21 | os.makedirs(system_dir)
22 | # directory for reference summaries
23 | os.makedirs(model_dir)
24 | print(temp_dir, system_dir, model_dir)
25 |
26 | assert len(summaries) == len(references)
27 | for i, (summary, candidates) in enumerate(zip(summaries, references)):
28 | summary_fn = '%i.txt' % i
29 | for j, candidate in enumerate(candidates):
30 | candidate_fn = '%i.%i.txt' % (i, j)
31 | with open(os.path.join(model_dir, candidate_fn), 'w') as f:
32 | #print(candidate)
33 | f.write('\n'.join(candidate))
34 |
35 | with open(os.path.join(system_dir, summary_fn), 'w') as f:
36 | f.write('\n'.join(summary))
37 |
38 | args_str = ' '.join(map(str, rouge_args))
39 | rouge = Rouge155(rouge_args=args_str)
40 | rouge.system_dir = system_dir
41 | rouge.model_dir = model_dir
42 |     rouge.system_filename_pattern = r'(\d+).txt'
43 |     rouge.model_filename_pattern = r'#ID#.\d+.txt'
44 |
45 | #rouge_args = '-c 95 -2 -1 -U -r 1000 -n 4 -w 1.2 -a'
46 | #output = rouge.convert_and_evaluate(rouge_args=rouge_args)
47 | output = rouge.convert_and_evaluate()
48 |
49 | r = rouge.output_to_dict(output)
50 | print(output)
51 | #print(r)
52 |
53 |     # remove the created temporary files
54 |     if remove_temp:
55 |         shutil.rmtree(temp_dir)
56 | return r
57 |
58 | def clean_text_by_sentences(text):
59 | """Tokenize a given text into sentences, applying filters and lemmatize them.
60 |
61 | Parameters
62 | ----------
63 | text : str
64 | Given text.
65 |
66 | Returns
67 | -------
68 | list of :class:`~gensim.summarization.syntactic_unit.SyntacticUnit`
69 | Sentences of the given text.
70 |
71 | """
72 | original_sentences = text
73 | filtered_sentences = [join_words(sentence) for sentence in preprocess_documents(original_sentences)]
74 |
75 | return filtered_sentences
76 |
77 |
78 | def join_words(words, separator=" "):
79 | """Concatenates `words` with `separator` between elements.
80 |
81 | Parameters
82 | ----------
83 | words : list of str
84 | Given words.
85 | separator : str, optional
86 | The separator between elements.
87 |
88 | Returns
89 | -------
90 | str
91 | String of merged words with separator between elements.
92 |
93 | """
94 | return separator.join(words)
95 |
--------------------------------------------------------------------------------
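
A minimal usage sketch for `evaluate_rouge`; it assumes a working pyrouge / ROUGE-1.5.5 installation and writes its temporary files under `temp/`. The toy sentences are placeholders.

```
# One generated summary against one single-candidate reference; both are
# lists of sentence strings, matching the docstring of evaluate_rouge.
from utils import evaluate_rouge

summaries = [["The cat sat on the mat.", "It purred loudly."]]
references = [[["A cat sat and purred."]]]   # each document: a list of candidate reference summaries
scores = evaluate_rouge(summaries, references, remove_temp=True)
print(scores["rouge_1_f_score"])
```
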
/code/data_iterator.py:
--------------------------------------------------------------------------------
1 | import codecs
2 | import glob
3 | import json
4 | import random
5 | import math
6 |
7 | import numpy as np
8 | import torch
9 | import h5py
10 |
11 | from tokenizer import FullTokenizer
12 | from utils import clean_text_by_sentences
13 |
14 |
15 | class Dataset(object):
16 |
17 | def __init__(self, file_pattern = None, vocab_file = None):
18 |
19 | self._file_pattern = file_pattern
20 | self._max_len = 60
21 | print("input max len : "+str(self._max_len))
22 | if vocab_file is not None:
23 | self._tokenizer = FullTokenizer(vocab_file, True)
24 |
25 | def iterate_once_doc_tfidf(self):
26 | def file_stream():
27 | for file_name in glob.glob(self._file_pattern):
28 | yield file_name
29 | for value in self._doc_stream_tfidf(file_stream()):
30 | yield value
31 |
32 | def _doc_stream_tfidf(self, file_stream):
33 | for file_name in file_stream:
34 | for doc in self._parse_file2doc_tfidf(file_name):
35 | yield doc
36 |
37 | def _parse_file2doc_tfidf(self, file_name):
38 | print("Processing file: %s" % file_name)
39 | with h5py.File(file_name,'r') as f:
40 | for j_str in f['dataset']:
41 | obj = json.loads(j_str)
42 | article, abstract = obj['article'], obj['abstract']
43 | #article, abstract = obj['article'], obj['abstracts']
44 | clean_article = clean_text_by_sentences(article)
45 |                 segmented_article = [sentence.split() for sentence in clean_article]
46 |                 #print(segmented_article[0])
47 | 
48 |                 yield article, abstract, [segmented_article]
49 |
50 |
51 | def iterate_once_doc_bert(self):
52 | def file_stream():
53 | for file_name in glob.glob(self._file_pattern):
54 | yield file_name
55 | for value in self._doc_iterate_bert(self._doc_stream_bert(file_stream())):
56 | yield value
57 |
58 | def _doc_stream_bert(self, file_stream):
59 | for file_name in file_stream:
60 | for doc in self._parse_file2doc_bert(file_name):
61 | yield doc
62 |
63 | def _parse_file2doc_bert(self, file_name):
64 | print("Processing file: %s" % file_name)
65 | with h5py.File(file_name,'r') as f:
66 | for j_str in f['dataset']:
67 | obj = json.loads(j_str)
68 | article, abstract = obj['article'], obj['abstract']
69 | #article, abstract = obj['article'], obj['abstracts']
70 | tokenized_article = [self._tokenizer.tokenize(sen) for sen in article]
71 | #print(tokenized_article[0])
72 |
73 | article_token_ids = []
74 | article_seg_ids = []
75 | article_token_ids_c = []
76 | article_seg_ids_c = []
77 | pair_indice = []
78 | k = 0
79 | for i in range(len(article)):
80 | for j in range(i+1, len(article)):
81 |
82 | tokens_a = tokenized_article[i]
83 | tokens_b = tokenized_article[j]
84 |
85 | input_ids, segment_ids = self._2bert_rep(tokens_a)
86 | input_ids_c, segment_ids_c = self._2bert_rep(tokens_b)
87 | assert len(input_ids) == len(segment_ids)
88 | assert len(input_ids_c) == len(segment_ids_c)
89 | article_token_ids.append(input_ids)
90 | article_seg_ids.append(segment_ids)
91 | article_token_ids_c.append(input_ids_c)
92 | article_seg_ids_c.append(segment_ids_c)
93 |
94 | pair_indice.append(((i,j), k))
95 | k+=1
96 | yield article_token_ids, article_seg_ids, article_token_ids_c, article_seg_ids_c, pair_indice, article, abstract
97 |
98 | def _doc_iterate_bert(self, docs):
99 |
100 | for article_token_ids, article_seg_ids, article_token_ids_c, article_seg_ids_c, pair_indice, article, abstract in docs:
101 |
102 | if len(article_token_ids) == 0:
103 |                 yield article, abstract, None  # no sentence pairs to score for this document
104 | continue
105 | num_steps = max(len(item) for item in article_token_ids)
106 | #num_steps = max(len(item) for item in iarticle)
107 | batch_size = len(article_token_ids)
108 | x = np.zeros([batch_size, num_steps], np.int32)
109 | t = np.zeros([batch_size, num_steps], np.int32)
110 | w = np.zeros([batch_size, num_steps], np.uint8)
111 |
112 | num_steps_c = max(len(item) for item in article_token_ids_c)
113 | #num_steps = max(len(item) for item in iarticle)
114 | x_c = np.zeros([batch_size, num_steps_c], np.int32)
115 | t_c = np.zeros([batch_size, num_steps_c], np.int32)
116 | w_c = np.zeros([batch_size, num_steps_c], np.uint8)
117 | for i in range(batch_size):
118 | num_tokens = len(article_token_ids[i])
119 | x[i,:num_tokens] = article_token_ids[i]
120 | t[i,:num_tokens] = article_seg_ids[i]
121 | w[i,:num_tokens] = 1
122 |
123 | num_tokens_c = len(article_token_ids_c[i])
124 | x_c[i,:num_tokens_c] = article_token_ids_c[i]
125 | t_c[i,:num_tokens_c] = article_seg_ids_c[i]
126 | w_c[i,:num_tokens_c] = 1
127 |
128 |             if not np.any(w):
129 |                 continue  # skip empty documents instead of ending the whole stream
130 | out_x = torch.LongTensor(x)
131 | out_t = torch.LongTensor(t)
132 | out_w = torch.LongTensor(w)
133 |
134 | out_x_c = torch.LongTensor(x_c)
135 | out_t_c = torch.LongTensor(t_c)
136 | out_w_c = torch.LongTensor(w_c)
137 |
138 | yield article, abstract, (out_x, out_t, out_w, out_x_c, out_t_c, out_w_c, pair_indice)
139 |
140 | def _2bert_rep(self, tokens_a, tokens_b=None):
141 |
142 | if tokens_b is None:
143 | tokens_a = tokens_a[: self._max_len - 2]
144 | else:
145 | self._truncate_seq_pair(tokens_a, tokens_b, self._max_len - 3)
146 |
147 | tokens = []
148 | segment_ids = []
149 | tokens.append("[CLS]")
150 | segment_ids.append(0)
151 | for token in tokens_a:
152 | tokens.append(token)
153 | segment_ids.append(0)
154 | tokens.append("[SEP]")
155 | segment_ids.append(0)
156 |
157 | if tokens_b is not None:
158 |
159 | for token in tokens_b:
160 | tokens.append(token)
161 | segment_ids.append(1)
162 |
163 | tokens.append("[SEP]")
164 | segment_ids.append(1)
165 | #print(tokens)
166 | input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
167 | #print(input_ids)
168 |
169 | return input_ids, segment_ids
170 |
171 | def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
172 | """Truncates a sequence pair in place to the maximum length."""
173 |
174 | # This is a simple heuristic which will always truncate the longer sequence
175 | # one token at a time. This makes more sense than truncating an equal percent
176 | # of tokens from each, since if one sequence is very short then each token
177 | # that's truncated likely contains more information than a longer sequence.
178 | while True:
179 | total_length = len(tokens_a) + len(tokens_b)
180 | if total_length <= max_length:
181 | break
182 | if len(tokens_a) > len(tokens_b):
183 | tokens_a.pop()
184 | else:
185 | tokens_b.pop()
186 |
--------------------------------------------------------------------------------
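
A short sketch of consuming the TF-IDF iterator, assuming a data file such as the `toy.h5` built above (`file_pattern` may also be a glob).

```
# Each item from iterate_once_doc_tfidf() is (article, abstract, [segmented_article]),
# where segmented_article holds the preprocessed article as lists of tokens.
from data_iterator import Dataset

dataset = Dataset("toy.h5")                 # file_pattern; a glob such as "data/*.h5" also works
for article, abstract, (segmented_article,) in dataset.iterate_once_doc_tfidf():
    print(len(article), "sentences; first sentence tokens:", segmented_article[0])
```
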
/code/extractor.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | import numpy as np
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import random
7 | import time
8 | import io
9 | import codecs
10 |
11 |
12 |
13 | from utils import evaluate_rouge
14 | from bert_model import BertEdgeScorer, BertConfig
15 |
16 | class PacSumExtractor:
17 |
18 | def __init__(self, extract_num = 3, beta = 3, lambda1 = -0.2, lambda2 = -0.2):
19 |
20 | self.extract_num = extract_num
21 | self.beta = beta
22 | self.lambda1 = lambda1
23 | self.lambda2 = lambda2
24 |
25 | def extract_summary(self, data_iterator):
26 |
27 | summaries = []
28 | references = []
29 |
30 | for item in data_iterator:
31 | article, abstract, inputs = item
32 | if len(article) <= self.extract_num:
33 | summaries.append(article)
34 | references.append([abstract])
35 | continue
36 |
37 | edge_scores = self._calculate_similarity_matrix(*inputs)
38 | ids = self._select_tops(edge_scores, beta=self.beta, lambda1=self.lambda1, lambda2=self.lambda2)
39 | summary = list(map(lambda x: article[x], ids))
40 | summaries.append(summary)
41 | references.append([abstract])
42 |
43 | result = evaluate_rouge(summaries, references, remove_temp=True, rouge_args=[])
44 |
45 | def tune_hparams(self, data_iterator, example_num=1000):
46 |
47 |
48 | summaries, references = [], []
49 | k = 0
50 | for item in data_iterator:
51 | article, abstract, inputs = item
52 | edge_scores = self._calculate_similarity_matrix(*inputs)
53 | tops_list, hparam_list = self._tune_extractor(edge_scores)
54 |
55 | summary_list = [list(map(lambda x: article[x], ids)) for ids in tops_list]
56 | summaries.append(summary_list)
57 | references.append([abstract])
58 | k += 1
59 | print(k)
60 | if k % example_num == 0:
61 | break
62 |
63 | best_rouge = 0
64 | best_hparam = None
65 | for i in range(len(summaries[0])):
66 |             print("hyperparameters (beta, lambda1, lambda2): " + str(hparam_list[i]) + '\n')
67 | #print("non-lead ratio : "+str(ratios[i])+'\n')
68 | result = evaluate_rouge([summaries[k][i] for k in range(len(summaries))], references, remove_temp=True, rouge_args=[])
69 |
70 | if result['rouge_1_f_score'] > best_rouge:
71 | best_rouge = result['rouge_1_f_score']
72 | best_hparam = hparam_list[i]
73 |
74 | print("The best hyper-parameter : beta %.4f , lambda1 %.4f, lambda2 %.4f " % (best_hparam[0], best_hparam[1], best_hparam[2]))
75 | print("The best rouge_1_f_score : %.4f " % best_rouge)
76 |
77 | self.beta = best_hparam[0]
78 | self.lambda1 = best_hparam[1]
79 | self.lambda2 = best_hparam[2]
80 |
81 | def _calculate_similarity_matrix(self, *inputs):
82 |
83 | raise NotImplementedError
84 |
85 |
86 | def _select_tops(self, edge_scores, beta, lambda1, lambda2):
87 |
88 | min_score = edge_scores.min()
89 | max_score = edge_scores.max()
90 | edge_threshold = min_score + beta * (max_score - min_score)
91 | new_edge_scores = edge_scores - edge_threshold
92 | forward_scores, backward_scores, _ = self._compute_scores(new_edge_scores, 0)
93 | forward_scores = 0 - forward_scores
94 |
95 | paired_scores = []
96 | for node in range(len(forward_scores)):
97 | paired_scores.append([node, lambda1 * forward_scores[node] + lambda2 * backward_scores[node]])
98 |
99 | #shuffle to avoid any possible bias
100 | random.shuffle(paired_scores)
101 | paired_scores.sort(key = lambda x: x[1], reverse = True)
102 | extracted = [item[0] for item in paired_scores[:self.extract_num]]
103 |
104 |
105 | return extracted
106 |
107 | def _compute_scores(self, similarity_matrix, edge_threshold):
108 |
109 | forward_scores = [0 for i in range(len(similarity_matrix))]
110 | backward_scores = [0 for i in range(len(similarity_matrix))]
111 | edges = []
112 | for i in range(len(similarity_matrix)):
113 | for j in range(i+1, len(similarity_matrix[i])):
114 | edge_score = similarity_matrix[i][j]
115 | if edge_score > edge_threshold:
116 | forward_scores[j] += edge_score
117 | backward_scores[i] += edge_score
118 | edges.append((i,j,edge_score))
119 |
120 | return np.asarray(forward_scores), np.asarray(backward_scores), edges
121 |
122 |
123 | def _tune_extractor(self, edge_scores):
124 |
125 | tops_list = []
126 | hparam_list = []
127 | num = 10
128 | for k in range(num + 1):
129 | beta = k / num
130 | for i in range(11):
131 | lambda1 = i/10
132 | lambda2 = 1 - lambda1
133 | extracted = self._select_tops(edge_scores, beta=beta, lambda1=lambda1, lambda2=lambda2)
134 |
135 | tops_list.append(extracted)
136 | hparam_list.append((beta, lambda1, lambda2))
137 |
138 | return tops_list, hparam_list
139 |
140 |
141 | class PacSumExtractorWithBert(PacSumExtractor):
142 |
143 | def __init__(self, bert_model_file, bert_config_file, extract_num = 3, beta = 3, lambda1 = -0.2, lambda2 = -0.2):
144 |
145 | super(PacSumExtractorWithBert, self).__init__(extract_num, beta, lambda1, lambda2)
146 | self.model = self._load_edge_model(bert_model_file, bert_config_file)
147 |
148 | def _calculate_similarity_matrix(self, x, t, w, x_c, t_c, w_c, pair_indice):
149 |         # x, t, w: token ids, segment ids and masks for the first sentence of each pair; the *_c tensors hold the second sentence
150 |
151 | def pairdown(scores, pair_indice, length):
152 | #1 for self score
153 | out_matrix = np.ones((length, length))
154 | for pair in pair_indice:
155 | out_matrix[pair[0][0]][pair[0][1]] = scores[pair[1]]
156 | out_matrix[pair[0][1]][pair[0][0]] = scores[pair[1]]
157 |
158 | return out_matrix
159 |
160 | scores = self._generate_score(x, t, w, x_c, t_c, w_c)
161 | doc_len = int(math.sqrt(len(x)*2)) + 1
162 | similarity_matrix = pairdown(scores, pair_indice, doc_len)
163 |
164 | return similarity_matrix
165 |
166 | def _generate_score(self, x, t, w, x_c, t_c, w_c):
167 |
168 | #score = log PMI -log k
169 | scores = torch.zeros(len(x)).cuda()
170 | step = 20
171 | for i in range(0,len(x),step):
172 |
173 | batch_x = x[i:i+step]
174 | batch_t = t[i:i+step]
175 | batch_w = w[i:i+step]
176 | batch_x_c = x_c[i:i+step]
177 | batch_t_c = t_c[i:i+step]
178 | batch_w_c = w_c[i:i+step]
179 |
180 |             inputs = tuple(tensor.to('cuda') for tensor in (batch_x, batch_t, batch_w, batch_x_c, batch_t_c, batch_w_c))
181 | batch_scores, batch_pros = self.model(*inputs)
182 | scores[i:i+step] = batch_scores.detach()
183 |
184 |
185 | return scores
186 |
187 | def _load_edge_model(self, bert_model_file, bert_config_file):
188 |
189 | bert_config = BertConfig.from_json_file(bert_config_file)
190 | model = BertEdgeScorer(bert_config)
191 | model_states = torch.load(bert_model_file)
192 | print(model_states.keys())
193 | model.bert.load_state_dict(model_states)
194 |
195 | model.cuda()
196 | model.eval()
197 | return model
198 |
199 |
200 | class PacSumExtractorWithTfIdf(PacSumExtractor):
201 |
202 | def __init__(self, extract_num = 3, beta = 3, lambda1 = -0.2, lambda2 = -0.2):
203 |
204 | super(PacSumExtractorWithTfIdf, self).__init__(extract_num, beta, lambda1, lambda2)
205 |
206 | def _calculate_similarity_matrix(self, doc):
207 |
208 | idf_score = self._calculate_idf_scores(doc)
209 |
210 | tf_scores = [
211 | Counter(sentence) for sentence in doc
212 | ]
213 | length = len(doc)
214 |
215 | similarity_matrix = np.zeros([length] * 2)
216 |
217 | for i in range(length):
218 | for j in range(i, length):
219 | similarity = self._idf_modified_dot(tf_scores, i, j, idf_score)
220 |
221 | if similarity:
222 | similarity_matrix[i, j] = similarity
223 | similarity_matrix[j, i] = similarity
224 |
225 | return similarity_matrix
226 |
227 |
228 | def _idf_modified_dot(self, tf_scores, i, j, idf_score):
229 |
230 | if i == j:
231 | return 1
232 |
233 | tf_i, tf_j = tf_scores[i], tf_scores[j]
234 | words_i, words_j = set(tf_i.keys()), set(tf_j.keys())
235 |
236 | score = 0
237 |
238 | for word in words_i & words_j:
239 | idf = idf_score[word]
240 | score += tf_i[word] * tf_j[word] * idf ** 2
241 |
242 |
243 | return score
244 |
245 |
246 | def _calculate_idf_scores(self, doc):
247 |
248 | doc_number_total = 0.
249 | df = {}
250 | for i, sen in enumerate(doc):
251 | tf = Counter(sen)
252 | for word in tf.keys():
253 | if word not in df:
254 | df[word] = 0
255 | df[word] += 1
256 | doc_number_total += 1
257 |
258 | idf_score = {}
259 | for word, freq in df.items():
260 | idf_score[word] = math.log(doc_number_total - freq + 0.5) - math.log(freq + 0.5)
261 |
262 | return idf_score
263 |
--------------------------------------------------------------------------------
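
A toy, self-contained illustration of the scoring inside `_select_tops` / `_compute_scores`: the edge threshold is `min + beta * (max - min)`, the forward/backward scores sum the thresholded similarities to preceding/following sentences, and sentences are ranked by `lambda1 * (-forward) + lambda2 * backward`. The 4x4 matrix and the value of `beta` below are made up for illustration.

```
# Toy version of PacSumExtractor._compute_scores / _select_tops on a
# hand-made 4x4 sentence-similarity matrix.
import numpy as np

sim = np.array([[1.0, 0.6, 0.1, 0.3],
                [0.6, 1.0, 0.2, 0.5],
                [0.1, 0.2, 1.0, 0.4],
                [0.3, 0.5, 0.4, 1.0]])

beta, lambda1, lambda2 = 0.3, 0.0, 1.0            # lambda values match run.py defaults; beta is arbitrary
tau = sim.min() + beta * (sim.max() - sim.min())  # edge threshold
shifted = sim - tau

n = len(shifted)
forward = np.zeros(n)    # similarity mass received from preceding sentences
backward = np.zeros(n)   # similarity mass sent to following sentences
for i in range(n):
    for j in range(i + 1, n):
        if shifted[i, j] > 0:
            forward[j] += shifted[i, j]
            backward[i] += shifted[i, j]

centrality = lambda1 * (-forward) + lambda2 * backward
print(np.argsort(-centrality))   # sentence indices ranked; the extractor keeps the top extract_num
```
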
/code/tokenizer.py:
--------------------------------------------------------------------------------
1 | import codecs
2 | import collections
3 | import unicodedata
4 | import six
5 |
6 |
7 | def convert_to_unicode(text):
8 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
9 | if six.PY3:
10 | if isinstance(text, str):
11 | return text
12 | elif isinstance(text, bytes):
13 | return text.decode("utf-8", "ignore")
14 | else:
15 | raise ValueError("Unsupported string type: %s" % (type(text)))
16 | elif six.PY2:
17 | if isinstance(text, str):
18 | return text.decode("utf-8", "ignore")
19 | elif isinstance(text, unicode):
20 | return text
21 | else:
22 | raise ValueError("Unsupported string type: %s" % (type(text)))
23 | else:
24 | raise ValueError("Not running on Python2 or Python 3?")
25 |
26 |
27 | def printable_text(text):
28 | """Returns text encoded in a way suitable for print or `tf.logging`."""
29 |
30 | # These functions want `str` for both Python2 and Python3, but in one case
31 | # it's a Unicode string and in the other it's a byte string.
32 | if six.PY3:
33 | if isinstance(text, str):
34 | return text
35 | elif isinstance(text, bytes):
36 | return text.decode("utf-8", "ignore")
37 | else:
38 | raise ValueError("Unsupported string type: %s" % (type(text)))
39 | elif six.PY2:
40 | if isinstance(text, str):
41 | return text
42 | elif isinstance(text, unicode):
43 | return text.encode("utf-8")
44 | else:
45 | raise ValueError("Unsupported string type: %s" % (type(text)))
46 | else:
47 | raise ValueError("Not running on Python2 or Python 3?")
48 |
49 |
50 | def load_vocab(vocab_file):
51 | """Loads a vocabulary file into a dictionary."""
52 | vocab = collections.OrderedDict()
53 | index = 0
54 | with open(vocab_file, "r") as reader:
55 | #with codecs.open(vocab_file, "r", "utf-8") as reader:
56 | while True:
57 | line = reader.readline()
58 | #print(line)
59 | token = convert_to_unicode(line)
60 | if not token:
61 | break
62 | token = token.strip()
63 | vocab[token] = index
64 | index += 1
65 | #print(index)
66 | #print(str(vocab))
67 | return vocab
68 |
69 |
70 | def convert_tokens_to_ids(vocab, tokens):
71 | """Converts a sequence of tokens into ids using the vocab."""
72 | ids = []
73 | for token in tokens:
74 | ids.append(vocab[token])
75 | return ids
76 |
77 |
78 | def whitespace_tokenize(text):
79 |     """Runs basic whitespace cleaning and splitting on a piece of text."""
80 | text = text.strip()
81 | if not text:
82 | return []
83 | tokens = text.split()
84 | return tokens
85 |
86 |
87 | class FullTokenizer(object):
88 |     """Runs end-to-end tokenization."""
89 |
90 | def __init__(self, vocab_file, do_lower_case=True):
91 | self.vocab = load_vocab(vocab_file)
92 | #print(self.vocab)
93 | #print(len(self.vocab))
94 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
95 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
96 |
97 | def tokenize(self, text):
98 | split_tokens = []
99 | for token in self.basic_tokenizer.tokenize(text):
100 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
101 | split_tokens.append(sub_token)
102 |
103 | return split_tokens
104 |
105 | def convert_tokens_to_ids(self, tokens):
106 | return convert_tokens_to_ids(self.vocab, tokens)
107 |
108 |
109 | class BasicTokenizer(object):
110 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
111 |
112 | def __init__(self, do_lower_case=True):
113 | """Constructs a BasicTokenizer.
114 | Args:
115 | do_lower_case: Whether to lower case the input.
116 | """
117 | self.do_lower_case = do_lower_case
118 |
119 | def tokenize(self, text):
120 | """Tokenizes a piece of text."""
121 | text = convert_to_unicode(text)
122 | text = self._clean_text(text)
123 |
124 | text = self._tokenize_chinese_chars(text)
125 |
126 | orig_tokens = whitespace_tokenize(text)
127 | split_tokens = []
128 | for token in orig_tokens:
129 | if self.do_lower_case:
130 | token = token.lower()
131 | token = self._run_strip_accents(token)
132 | split_tokens.extend(self._run_split_on_punc(token))
133 |
134 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
135 | return output_tokens
136 |
137 | def _run_strip_accents(self, text):
138 | """Strips accents from a piece of text."""
139 | text = unicodedata.normalize("NFD", text)
140 | output = []
141 | for char in text:
142 | cat = unicodedata.category(char)
143 | if cat == "Mn":
144 | continue
145 | output.append(char)
146 | return "".join(output)
147 |
148 | def _run_split_on_punc(self, text):
149 | """Splits punctuation on a piece of text."""
150 | chars = list(text)
151 | i = 0
152 | start_new_word = True
153 | output = []
154 | while i < len(chars):
155 | char = chars[i]
156 | if _is_punctuation(char):
157 | output.append([char])
158 | start_new_word = True
159 | else:
160 | if start_new_word:
161 | output.append([])
162 | start_new_word = False
163 | output[-1].append(char)
164 | i += 1
165 |
166 | return ["".join(x) for x in output]
167 |
168 | def _tokenize_chinese_chars(self, text):
169 | """Adds whitespace around any CJK character."""
170 | output = []
171 | for char in text:
172 | cp = ord(char)
173 | if self._is_chinese_char(cp):
174 | output.append(" ")
175 | output.append(char)
176 | output.append(" ")
177 | else:
178 | output.append(char)
179 | return "".join(output)
180 |
181 | def _is_chinese_char(self, cp):
182 | """Checks whether CP is the codepoint of a CJK character."""
183 | # This defines a "chinese character" as anything in the CJK Unicode block:
184 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
185 | #
186 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
187 | # despite its name. The modern Korean Hangul alphabet is a different block,
188 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
189 | # space-separated words, so they are not treated specially and handled
190 |         # like all of the other languages.
191 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
192 | (cp >= 0x3400 and cp <= 0x4DBF) or #
193 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
194 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
195 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
196 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
197 | (cp >= 0xF900 and cp <= 0xFAFF) or #
198 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
199 | return True
200 |
201 | return False
202 |
203 |
204 | def _clean_text(self, text):
205 | """Performs invalid character removal and whitespace cleanup on text."""
206 | output = []
207 | for char in text:
208 | cp = ord(char)
209 | if cp == 0 or cp == 0xfffd or _is_control(char):
210 | continue
211 | if _is_whitespace(char):
212 | output.append(" ")
213 | else:
214 | output.append(char)
215 | return "".join(output)
216 |
217 |
218 | class WordpieceTokenizer(object):
219 | """Runs WordPiece tokenization."""
220 |
221 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
222 | self.vocab = vocab
223 | self.unk_token = unk_token
224 | self.max_input_chars_per_word = max_input_chars_per_word
225 |
226 | def tokenize(self, text):
227 | """Tokenizes a piece of text into its word pieces.
228 | This uses a greedy longest-match-first algorithm to perform tokenization
229 | using the given vocabulary.
230 | For example:
231 | input = "unaffable"
232 | output = ["un", "##aff", "##able"]
233 | Args:
234 | text: A single token or whitespace separated tokens. This should have
235 |                 already been passed through `BasicTokenizer`.
236 | Returns:
237 | A list of wordpiece tokens.
238 | """
239 |
240 | text = convert_to_unicode(text)
241 |
242 | output_tokens = []
243 | for token in whitespace_tokenize(text):
244 | chars = list(token)
245 | if len(chars) > self.max_input_chars_per_word:
246 | output_tokens.append(self.unk_token)
247 | continue
248 |
249 | is_bad = False
250 | start = 0
251 | sub_tokens = []
252 | while start < len(chars):
253 | end = len(chars)
254 | cur_substr = None
255 | while start < end:
256 | substr = "".join(chars[start:end])
257 | if start > 0:
258 | substr = "##" + substr
259 | if substr in self.vocab:
260 | cur_substr = substr
261 | break
262 | end -= 1
263 | if cur_substr is None:
264 | is_bad = True
265 | break
266 | sub_tokens.append(cur_substr)
267 | start = end
268 |
269 | if is_bad:
270 | output_tokens.append(self.unk_token)
271 | else:
272 | output_tokens.extend(sub_tokens)
273 | return output_tokens
274 |
275 |
276 | def _is_whitespace(char):
277 | """Checks whether `chars` is a whitespace character."""
278 |     # \t, \n, and \r are technically control characters but we treat them
279 | # as whitespace since they are generally considered as such.
280 | if char == " " or char == "\t" or char == "\n" or char == "\r":
281 | return True
282 | cat = unicodedata.category(char)
283 | if cat == "Zs":
284 | return True
285 | return False
286 |
287 |
288 | def _is_control(char):
289 | """Checks whether `chars` is a control character."""
290 | # These are technically control characters but we count them as whitespace
291 | # characters.
292 | if char == "\t" or char == "\n" or char == "\r":
293 | return False
294 | cat = unicodedata.category(char)
295 | if cat.startswith("C"):
296 | return True
297 | return False
298 |
299 |
300 | def _is_punctuation(char):
301 | """Checks whether `chars` is a punctuation character."""
302 | cp = ord(char)
303 | # We treat all non-letter/number ASCII as punctuation.
304 | # Characters such as "^", "$", and "`" are not in the Unicode
305 | # Punctuation class but we treat them as punctuation anyways, for
306 | # consistency.
307 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
308 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
309 | return True
310 | cat = unicodedata.category(char)
311 | if cat.startswith("P"):
312 | return True
313 |
314 | return False
315 |
--------------------------------------------------------------------------------
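
A minimal sketch of `FullTokenizer` usage; `vocab.txt` is a placeholder for the BERT vocabulary file passed to `run.py` via `--bert_vocab_file`.

```
# Tokenize a sentence and map the WordPiece tokens to vocabulary ids, as
# Dataset._2bert_rep does internally.
from tokenizer import FullTokenizer

tokenizer = FullTokenizer("vocab.txt", do_lower_case=True)   # path is a placeholder
tokens = tokenizer.tokenize("Sentence centrality revisited for unsupervised summarization.")
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens[:5], ids[:5])
```
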
/code/gensim_preprocess.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
5 |
6 | """This module contains methods for parsing and preprocessing strings. Let's consider the most noticeable:
7 |
8 | * :func:`~gensim.parsing.preprocessing.remove_stopwords` - remove all stopwords from string
9 | * :func:`~gensim.parsing.preprocessing.preprocess_string` - preprocess string (in default NLP meaning)
10 |
11 | Examples:
12 | ---------
13 | .. sourcecode:: pycon
14 |
15 | >>> from gensim.parsing.preprocessing import remove_stopwords
16 | >>> remove_stopwords("Better late than never, but better never late.")
17 | u'Better late never, better late.'
18 | >>>
19 | >>> preprocess_string("Hel 9lo Wo9 rld! Th3 weather_is really g00d today, isn't it?")
20 | [u'hel', u'rld', u'weather', u'todai', u'isn']
21 |
22 |
23 | Data:
24 | -----
25 |
26 | .. data:: STOPWORDS - Set of stopwords from Stone, Denis, Kwantes (2010).
27 | .. data:: RE_PUNCT - Regexp to match punctuation.
28 | .. data:: RE_TAGS - Regexp to match tags.
29 | .. data:: RE_NUMERIC - Regexp to match numbers.
30 | .. data:: RE_NONALPHA - Regexp to match non-alphanumeric characters.
31 | .. data:: RE_AL_NUM - Regexp to match a position between letters and digits.
32 | .. data:: RE_NUM_AL - Regexp to match a position between digits and letters.
33 | .. data:: RE_WHITESPACE - Regexp to match whitespace characters.
34 | .. data:: DEFAULT_FILTERS - List of functions for string preprocessing.
35 |
36 | """
37 |
38 | import re
39 | import string
40 | import glob
41 |
42 | from gensim import utils
43 | from gensim.parsing.porter import PorterStemmer
44 |
45 |
46 | STOPWORDS = frozenset([
47 | 'all', 'six', 'just', 'less', 'being', 'indeed', 'over', 'move', 'anyway', 'four', 'not', 'own', 'through',
48 | 'using', 'fifty', 'where', 'mill', 'only', 'find', 'before', 'one', 'whose', 'system', 'how', 'somewhere',
49 | 'much', 'thick', 'show', 'had', 'enough', 'should', 'to', 'must', 'whom', 'seeming', 'yourselves', 'under',
50 | 'ours', 'two', 'has', 'might', 'thereafter', 'latterly', 'do', 'them', 'his', 'around', 'than', 'get', 'very',
51 | 'de', 'none', 'cannot', 'every', 'un', 'they', 'front', 'during', 'thus', 'now', 'him', 'nor', 'name', 'regarding',
52 | 'several', 'hereafter', 'did', 'always', 'who', 'didn', 'whither', 'this', 'someone', 'either', 'each', 'become',
53 | 'thereupon', 'sometime', 'side', 'towards', 'therein', 'twelve', 'because', 'often', 'ten', 'our', 'doing', 'km',
54 | 'eg', 'some', 'back', 'used', 'up', 'go', 'namely', 'computer', 'are', 'further', 'beyond', 'ourselves', 'yet',
55 | 'out', 'even', 'will', 'what', 'still', 'for', 'bottom', 'mine', 'since', 'please', 'forty', 'per', 'its',
56 | 'everything', 'behind', 'does', 'various', 'above', 'between', 'it', 'neither', 'seemed', 'ever', 'across', 'she',
57 | 'somehow', 'be', 'we', 'full', 'never', 'sixty', 'however', 'here', 'otherwise', 'were', 'whereupon', 'nowhere',
58 | 'although', 'found', 'alone', 're', 'along', 'quite', 'fifteen', 'by', 'both', 'about', 'last', 'would',
59 | 'anything', 'via', 'many', 'could', 'thence', 'put', 'against', 'keep', 'etc', 'amount', 'became', 'ltd', 'hence',
60 | 'onto', 'or', 'con', 'among', 'already', 'co', 'afterwards', 'formerly', 'within', 'seems', 'into', 'others',
61 | 'while', 'whatever', 'except', 'down', 'hers', 'everyone', 'done', 'least', 'another', 'whoever', 'moreover',
62 | 'couldnt', 'throughout', 'anyhow', 'yourself', 'three', 'from', 'her', 'few', 'together', 'top', 'there', 'due',
63 | 'been', 'next', 'anyone', 'eleven', 'cry', 'call', 'therefore', 'interest', 'then', 'thru', 'themselves',
64 | 'hundred', 'really', 'sincere', 'empty', 'more', 'himself', 'elsewhere', 'mostly', 'on', 'fire', 'am', 'becoming',
65 | 'hereby', 'amongst', 'else', 'part', 'everywhere', 'too', 'kg', 'herself', 'former', 'those', 'he', 'me', 'myself',
66 | 'made', 'twenty', 'these', 'was', 'bill', 'cant', 'us', 'until', 'besides', 'nevertheless', 'below', 'anywhere',
67 | 'nine', 'can', 'whether', 'of', 'your', 'toward', 'my', 'say', 'something', 'and', 'whereafter', 'whenever',
68 | 'give', 'almost', 'wherever', 'is', 'describe', 'beforehand', 'herein', 'doesn', 'an', 'as', 'itself', 'at',
69 | 'have', 'in', 'seem', 'whence', 'ie', 'any', 'fill', 'again', 'hasnt', 'inc', 'thereby', 'thin', 'no', 'perhaps',
70 | 'latter', 'meanwhile', 'when', 'detail', 'same', 'wherein', 'beside', 'also', 'that', 'other', 'take', 'which',
71 | 'becomes', 'you', 'if', 'nobody', 'unless', 'whereas', 'see', 'though', 'may', 'after', 'upon', 'most', 'hereupon',
72 | 'eight', 'but', 'serious', 'nothing', 'such', 'why', 'off', 'a', 'don', 'whereby', 'third', 'i', 'whole', 'noone',
73 | 'sometimes', 'well', 'amoungst', 'yours', 'their', 'rather', 'without', 'so', 'five', 'the', 'first', 'with',
74 | 'make', 'once'
75 | ])
76 |
77 |
78 | RE_PUNCT = re.compile(r'([%s])+' % re.escape(string.punctuation), re.UNICODE)
79 | RE_TAGS = re.compile(r"<([^>]+)>", re.UNICODE)
80 | RE_NUMERIC = re.compile(r"[0-9]+", re.UNICODE)
81 | RE_NONALPHA = re.compile(r"\W", re.UNICODE)
82 | RE_AL_NUM = re.compile(r"([a-z]+)([0-9]+)", flags=re.UNICODE)
83 | RE_NUM_AL = re.compile(r"([0-9]+)([a-z]+)", flags=re.UNICODE)
84 | RE_WHITESPACE = re.compile(r"(\s)+", re.UNICODE)
85 |
86 |
87 | def remove_stopwords(s):
88 | """Remove :const:`~gensim.parsing.preprocessing.STOPWORDS` from `s`.
89 |
90 | Parameters
91 | ----------
92 | s : str
93 |
94 | Returns
95 | -------
96 | str
97 | Unicode string without :const:`~gensim.parsing.preprocessing.STOPWORDS`.
98 |
99 | Examples
100 | --------
101 | .. sourcecode:: pycon
102 |
103 | >>> from gensim.parsing.preprocessing import remove_stopwords
104 | >>> remove_stopwords("Better late than never, but better never late.")
105 | u'Better late never, better late.'
106 |
107 | """
108 | s = utils.to_unicode(s)
109 | return " ".join(w for w in s.split() if w not in STOPWORDS)
110 |
111 |
112 | def strip_punctuation(s):
113 | """Replace punctuation characters with spaces in `s` using :const:`~gensim.parsing.preprocessing.RE_PUNCT`.
114 |
115 | Parameters
116 | ----------
117 | s : str
118 |
119 | Returns
120 | -------
121 | str
122 | Unicode string without punctuation characters.
123 |
124 | Examples
125 | --------
126 | .. sourcecode:: pycon
127 |
128 | >>> from gensim.parsing.preprocessing import strip_punctuation
129 | >>> strip_punctuation("A semicolon is a stronger break than a comma, but not as much as a full stop!")
130 | u'A semicolon is a stronger break than a comma but not as much as a full stop '
131 |
132 | """
133 | s = utils.to_unicode(s)
134 | return RE_PUNCT.sub(" ", s)
135 |
136 |
137 | strip_punctuation2 = strip_punctuation
138 |
139 |
140 | def strip_tags(s):
141 | """Remove tags from `s` using :const:`~gensim.parsing.preprocessing.RE_TAGS`.
142 |
143 | Parameters
144 | ----------
145 | s : str
146 |
147 | Returns
148 | -------
149 | str
150 | Unicode string without tags.
151 |
152 | Examples
153 | --------
154 | .. sourcecode:: pycon
155 |
156 | >>> from gensim.parsing.preprocessing import strip_tags
157 | >>> strip_tags("Hello World!")
158 | u'Hello World!'
159 |
160 | """
161 | s = utils.to_unicode(s)
162 | return RE_TAGS.sub("", s)
163 |
164 |
165 | def strip_short(s, minsize=3):
166 |     """Remove words shorter than `minsize` characters from `s`.
167 |
168 | Parameters
169 | ----------
170 | s : str
171 | minsize : int, optional
172 |
173 | Returns
174 | -------
175 | str
176 | Unicode string without short words.
177 |
178 | Examples
179 | --------
180 | .. sourcecode:: pycon
181 |
182 | >>> from gensim.parsing.preprocessing import strip_short
183 | >>> strip_short("salut les amis du 59")
184 | u'salut les amis'
185 | >>>
186 | >>> strip_short("one two three four five six seven eight nine ten", minsize=5)
187 | u'three seven eight'
188 |
189 | """
190 | s = utils.to_unicode(s)
191 | return " ".join(e for e in s.split() if len(e) >= minsize)
192 |
193 |
194 | def strip_numeric(s):
195 | """Remove digits from `s` using :const:`~gensim.parsing.preprocessing.RE_NUMERIC`.
196 |
197 | Parameters
198 | ----------
199 | s : str
200 |
201 | Returns
202 | -------
203 | str
204 | Unicode string without digits.
205 |
206 | Examples
207 | --------
208 | .. sourcecode:: pycon
209 |
210 | >>> from gensim.parsing.preprocessing import strip_numeric
211 | >>> strip_numeric("0text24gensim365test")
212 | u'textgensimtest'
213 |
214 | """
215 | s = utils.to_unicode(s)
216 | return RE_NUMERIC.sub("", s)
217 |
218 |
219 | def strip_non_alphanum(s):
220 |     """Replace non-alphanumeric characters in `s` with spaces using :const:`~gensim.parsing.preprocessing.RE_NONALPHA`.
221 |
222 | Parameters
223 | ----------
224 | s : str
225 |
226 | Returns
227 | -------
228 | str
229 |         Unicode string containing only word characters (letters, digits, underscore) and spaces.
230 |
231 | Notes
232 | -----
233 | Word characters - alphanumeric & underscore.
234 |
235 | Examples
236 | --------
237 | .. sourcecode:: pycon
238 |
239 | >>> from gensim.parsing.preprocessing import strip_non_alphanum
240 | >>> strip_non_alphanum("if-you#can%read$this&then@this#method^works")
241 | u'if you can read this then this method works'
242 |
243 | """
244 | s = utils.to_unicode(s)
245 | return RE_NONALPHA.sub(" ", s)
246 |
247 |
248 | def strip_multiple_whitespaces(s):
249 | r"""Remove repeating whitespace characters (spaces, tabs, line breaks) from `s`
250 | and turns tabs & line breaks into spaces using :const:`~gensim.parsing.preprocessing.RE_WHITESPACE`.
251 |
252 | Parameters
253 | ----------
254 | s : str
255 |
256 | Returns
257 | -------
258 | str
259 |         Unicode string without consecutive whitespace characters.
260 |
261 | Examples
262 | --------
263 | .. sourcecode:: pycon
264 |
265 | >>> from gensim.parsing.preprocessing import strip_multiple_whitespaces
266 | >>> strip_multiple_whitespaces("salut" + '\r' + " les" + '\n' + " loulous!")
267 | u'salut les loulous!'
268 |
269 | """
270 | s = utils.to_unicode(s)
271 | return RE_WHITESPACE.sub(" ", s)
272 |
273 |
274 | def split_alphanum(s):
275 | """Add spaces between digits & letters in `s` using :const:`~gensim.parsing.preprocessing.RE_AL_NUM`.
276 |
277 | Parameters
278 | ----------
279 | s : str
280 |
281 | Returns
282 | -------
283 | str
284 | Unicode string with spaces between digits & letters.
285 |
286 | Examples
287 | --------
288 | .. sourcecode:: pycon
289 |
290 | >>> from gensim.parsing.preprocessing import split_alphanum
291 | >>> split_alphanum("24.0hours7 days365 a1b2c3")
292 | u'24.0 hours 7 days 365 a 1 b 2 c 3'
293 |
294 | """
295 | s = utils.to_unicode(s)
296 | s = RE_AL_NUM.sub(r"\1 \2", s)
297 | return RE_NUM_AL.sub(r"\1 \2", s)
298 |
299 |
300 | def stem_text(text):
301 |     """Transform `text` into lowercase and stem it.
302 |
303 | Parameters
304 | ----------
305 | text : str
306 |
307 | Returns
308 | -------
309 | str
310 | Unicode lowercased and porter-stemmed version of string `text`.
311 |
312 | Examples
313 | --------
314 | .. sourcecode:: pycon
315 |
316 | >>> from gensim.parsing.preprocessing import stem_text
317 | >>> stem_text("While it is quite useful to be able to search a large collection of documents almost instantly.")
318 | u'while it is quit us to be abl to search a larg collect of document almost instantly.'
319 |
320 | """
321 | text = utils.to_unicode(text)
322 | p = PorterStemmer()
323 | return ' '.join(p.stem(word) for word in text.split())
324 |
325 |
326 | stem = stem_text
327 |
328 |
329 | DEFAULT_FILTERS = [
330 | lambda x: x.lower(), strip_tags, strip_punctuation,
331 | strip_multiple_whitespaces, strip_numeric,
332 | remove_stopwords, strip_short, stem_text
333 | ]
334 |
335 |
336 | def preprocess_string(s, filters=DEFAULT_FILTERS):
337 | """Apply list of chosen filters to `s`.
338 |
339 | Default list of filters:
340 |
341 | * :func:`~gensim.parsing.preprocessing.strip_tags`,
342 | * :func:`~gensim.parsing.preprocessing.strip_punctuation`,
343 | * :func:`~gensim.parsing.preprocessing.strip_multiple_whitespaces`,
344 | * :func:`~gensim.parsing.preprocessing.strip_numeric`,
345 | * :func:`~gensim.parsing.preprocessing.remove_stopwords`,
346 | * :func:`~gensim.parsing.preprocessing.strip_short`,
347 | * :func:`~gensim.parsing.preprocessing.stem_text`.
348 |
349 | Parameters
350 | ----------
351 | s : str
352 | filters: list of functions, optional
353 |
354 | Returns
355 | -------
356 | list of str
357 | Processed strings (cleaned).
358 |
359 | Examples
360 | --------
361 | .. sourcecode:: pycon
362 |
363 | >>> from gensim.parsing.preprocessing import preprocess_string
364 | >>> preprocess_string("Hel 9lo Wo9 rld! Th3 weather_is really g00d today, isn't it?")
365 | [u'hel', u'rld', u'weather', u'todai', u'isn']
366 | >>>
367 | >>> s = "Hel 9lo Wo9 rld! Th3 weather_is really g00d today, isn't it?"
368 | >>> CUSTOM_FILTERS = [lambda x: x.lower(), strip_tags, strip_punctuation]
369 | >>> preprocess_string(s, CUSTOM_FILTERS)
370 | [u'hel', u'9lo', u'wo9', u'rld', u'th3', u'weather', u'is', u'really', u'g00d', u'today', u'isn', u't', u'it']
371 |
372 | """
373 | s = utils.to_unicode(s)
374 | for f in filters:
375 | s = f(s)
376 | return s.split()
377 |
378 |
379 | def preprocess_documents(docs):
380 |     """Apply :const:`~gensim.parsing.preprocessing.DEFAULT_FILTERS` to the document strings.
381 |
382 | Parameters
383 | ----------
384 | docs : list of str
385 |
386 | Returns
387 | -------
388 | list of list of str
389 | Processed documents split by whitespace.
390 |
391 | Examples
392 | --------
393 |
394 | .. sourcecode:: pycon
395 |
396 | >>> from gensim.parsing.preprocessing import preprocess_documents
397 | >>> preprocess_documents(["Hel 9lo Wo9 rld!", "Th3 weather_is really g00d today, isn't it?"])
398 | [[u'hel', u'rld'], [u'weather', u'todai', u'isn']]
399 |
400 | """
401 | return [preprocess_string(d) for d in docs]
402 |
403 |
404 | def read_file(path):
405 | with utils.smart_open(path) as fin:
406 | return fin.read()
407 |
408 |
409 | def read_files(pattern):
410 | return [read_file(fname) for fname in glob.glob(pattern)]
411 |
--------------------------------------------------------------------------------
/code/bert_model.py:
--------------------------------------------------------------------------------
1 | """PyTorch BERT model."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import copy
8 | import json
9 | import math
10 | import six
11 | import torch
12 | import torch.nn as nn
13 | from torch.nn import CrossEntropyLoss
14 |
15 | def gelu(x):
16 | """Implementation of the gelu activation function.
17 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
18 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
19 | """
20 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
21 |
22 |
23 | class BertConfig(object):
24 | """Configuration class to store the configuration of a `BertModel`.
25 | """
26 | def __init__(self,
27 | vocab_size,
28 | hidden_size=768,
29 | num_hidden_layers=12,
30 | num_attention_heads=12,
31 | intermediate_size=3072,
32 | hidden_act="gelu",
33 | hidden_dropout_prob=0.1,
34 | attention_probs_dropout_prob=0.1,
35 | max_position_embeddings=512,
36 | type_vocab_size=16,
37 | initializer_range=0.02):
38 | """Constructs BertConfig.
39 | Args:
40 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
41 | hidden_size: Size of the encoder layers and the pooler layer.
42 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
43 | num_attention_heads: Number of attention heads for each attention layer in
44 | the Transformer encoder.
45 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
46 | layer in the Transformer encoder.
47 | hidden_act: The non-linear activation function (function or string) in the
48 | encoder and pooler.
49 |         hidden_dropout_prob: The dropout probability for all fully connected
50 | layers in the embeddings, encoder, and pooler.
51 | attention_probs_dropout_prob: The dropout ratio for the attention
52 | probabilities.
53 | max_position_embeddings: The maximum sequence length that this model might
54 | ever be used with. Typically set this to something large just in case
55 | (e.g., 512 or 1024 or 2048).
56 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
57 | `BertModel`.
58 |         initializer_range: The stddev of the truncated_normal_initializer for
59 | initializing all weight matrices.
60 | """
61 | self.vocab_size = vocab_size
62 | self.hidden_size = hidden_size
63 | self.num_hidden_layers = num_hidden_layers
64 | self.num_attention_heads = num_attention_heads
65 | self.hidden_act = hidden_act
66 | self.intermediate_size = intermediate_size
67 | self.hidden_dropout_prob = hidden_dropout_prob
68 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
69 | self.max_position_embeddings = max_position_embeddings
70 | self.type_vocab_size = type_vocab_size
71 | self.initializer_range = initializer_range
72 |
73 | @classmethod
74 | def from_dict(cls, json_object):
75 | """Constructs a `BertConfig` from a Python dictionary of parameters."""
76 | config = BertConfig(vocab_size=None)
77 | for (key, value) in six.iteritems(json_object):
78 | config.__dict__[key] = value
79 | return config
80 |
81 | @classmethod
82 | def from_json_file(cls, json_file):
83 | """Constructs a `BertConfig` from a json file of parameters."""
84 | with open(json_file, "r") as reader:
85 | text = reader.read()
86 | return cls.from_dict(json.loads(text))
87 |
88 | def to_dict(self):
89 | """Serializes this instance to a Python dictionary."""
90 | output = copy.deepcopy(self.__dict__)
91 | return output
92 |
93 | def to_json_string(self):
94 | """Serializes this instance to a JSON string."""
95 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
96 |
97 |
98 | class BERTLayerNorm(nn.Module):
99 | def __init__(self, config, variance_epsilon=1e-12):
100 | """Construct a layernorm module in the TF style (epsilon inside the square root).
101 | """
102 | super(BERTLayerNorm, self).__init__()
103 | self.gamma = nn.Parameter(torch.ones(config.hidden_size))
104 | self.beta = nn.Parameter(torch.zeros(config.hidden_size))
105 | self.variance_epsilon = variance_epsilon
106 |
107 | def forward(self, x):
108 | u = x.mean(-1, keepdim=True)
109 | s = (x - u).pow(2).mean(-1, keepdim=True)
110 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
111 | return self.gamma * x + self.beta
112 |
113 | class BERTEmbeddings(nn.Module):
114 | def __init__(self, config):
115 | super(BERTEmbeddings, self).__init__()
116 | """Construct the embedding module from word, position and token_type embeddings.
117 | """
118 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
119 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
120 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
121 |
122 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
123 | # any TensorFlow checkpoint file
124 | self.LayerNorm = BERTLayerNorm(config)
125 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
126 |
127 | def forward(self, input_ids, token_type_ids=None):
128 | seq_length = input_ids.size(1)
129 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
130 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
131 | if token_type_ids is None:
132 | token_type_ids = torch.zeros_like(input_ids)
133 |
134 | words_embeddings = self.word_embeddings(input_ids)
135 | position_embeddings = self.position_embeddings(position_ids)
136 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
137 |
138 | embeddings = words_embeddings + position_embeddings + token_type_embeddings
139 | embeddings = self.LayerNorm(embeddings)
140 | embeddings = self.dropout(embeddings)
141 | return embeddings
142 |
143 |
144 | class BERTSelfAttention(nn.Module):
145 | def __init__(self, config):
146 | super(BERTSelfAttention, self).__init__()
147 | if config.hidden_size % config.num_attention_heads != 0:
148 | raise ValueError(
149 | "The hidden size (%d) is not a multiple of the number of attention "
150 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
151 | self.num_attention_heads = config.num_attention_heads
152 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
153 | self.all_head_size = self.num_attention_heads * self.attention_head_size
154 |
155 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
156 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
157 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
158 |
159 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
160 |
161 | def transpose_for_scores(self, x):
162 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
163 | x = x.view(*new_x_shape)
164 | return x.permute(0, 2, 1, 3)
165 |
166 | def forward(self, hidden_states, attention_mask):
167 | mixed_query_layer = self.query(hidden_states)
168 | mixed_key_layer = self.key(hidden_states)
169 | mixed_value_layer = self.value(hidden_states)
170 |
171 | query_layer = self.transpose_for_scores(mixed_query_layer)
172 | key_layer = self.transpose_for_scores(mixed_key_layer)
173 | value_layer = self.transpose_for_scores(mixed_value_layer)
174 |
175 | # Take the dot product between "query" and "key" to get the raw attention scores.
176 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
177 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
178 |         # Apply the attention mask (precomputed for all layers in BertModel's forward() function)
179 | attention_scores = attention_scores + attention_mask
180 |
181 | # Normalize the attention scores to probabilities.
182 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
183 |
184 | # This is actually dropping out entire tokens to attend to, which might
185 | # seem a bit unusual, but is taken from the original Transformer paper.
186 | attention_probs = self.dropout(attention_probs)
187 |
188 | context_layer = torch.matmul(attention_probs, value_layer)
189 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
190 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
191 | context_layer = context_layer.view(*new_context_layer_shape)
192 | return context_layer
193 |
194 |
195 | class BERTSelfOutput(nn.Module):
196 | def __init__(self, config):
197 | super(BERTSelfOutput, self).__init__()
198 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
199 | self.LayerNorm = BERTLayerNorm(config)
200 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
201 |
202 | def forward(self, hidden_states, input_tensor):
203 | hidden_states = self.dense(hidden_states)
204 | hidden_states = self.dropout(hidden_states)
205 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
206 | return hidden_states
207 |
208 |
209 | class BERTAttention(nn.Module):
210 | def __init__(self, config):
211 | super(BERTAttention, self).__init__()
212 | self.self = BERTSelfAttention(config)
213 | self.output = BERTSelfOutput(config)
214 |
215 | def forward(self, input_tensor, attention_mask):
216 | self_output = self.self(input_tensor, attention_mask)
217 | attention_output = self.output(self_output, input_tensor)
218 | return attention_output
219 |
220 |
221 | class BERTIntermediate(nn.Module):
222 | def __init__(self, config):
223 | super(BERTIntermediate, self).__init__()
224 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
225 | self.intermediate_act_fn = gelu
226 |
227 | def forward(self, hidden_states):
228 | return self.intermediate_act_fn(self.dense(hidden_states))
229 |
230 | class BERTOutput(nn.Module):
231 | def __init__(self, config):
232 | super(BERTOutput, self).__init__()
233 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
234 | self.LayerNorm = BERTLayerNorm(config)
235 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
236 |
237 | def forward(self, hidden_states, input_tensor):
238 | hidden_states = self.dense(hidden_states)
239 | hidden_states = self.dropout(hidden_states)
240 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
241 | return hidden_states
242 |
243 |
244 | class BERTLayer(nn.Module):
245 | def __init__(self, config):
246 | super(BERTLayer, self).__init__()
247 | self.attention = BERTAttention(config)
248 | self.intermediate = BERTIntermediate(config)
249 | self.output = BERTOutput(config)
250 |
251 | def forward(self, hidden_states, attention_mask):
252 | attention_output = self.attention(hidden_states, attention_mask)
253 | intermediate_output = self.intermediate(attention_output)
254 | layer_output = self.output(intermediate_output, attention_output)
255 | return layer_output
256 |
257 |
258 | class BERTEncoder(nn.Module):
259 | def __init__(self, config):
260 | super(BERTEncoder, self).__init__()
261 | layer = BERTLayer(config)
262 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
263 |
264 | def forward(self, hidden_states, attention_mask):
265 | for layer_module in self.layer:
266 | hidden_states = layer_module(hidden_states, attention_mask)
267 | return hidden_states
268 |
269 |
270 | class BERTPooler(nn.Module):
271 | def __init__(self, config):
272 | super(BERTPooler, self).__init__()
273 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
274 | self.activation = nn.Tanh()
275 |
276 | def forward(self, hidden_states):
277 | # We "pool" the model by simply taking the hidden state corresponding
278 | # to the first token.
279 | first_token_tensor = hidden_states[:, 0]
280 | pooled_output = self.dense(first_token_tensor)
281 | pooled_output = self.activation(pooled_output)
282 | return pooled_output
283 |
284 |
285 | class BertModel(nn.Module):
286 |     """BERT model ("Bidirectional Encoder Representations from Transformers").
287 | Example usage:
288 | ```python
289 | # Already been converted into WordPiece token ids
290 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
291 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
292 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
293 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
294 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
295 | model = modeling.BertModel(config=config)
296 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
297 | ```
298 | """
299 | def __init__(self, config: BertConfig):
300 | """Constructor for BertModel.
301 | Args:
302 | config: `BertConfig` instance.
303 | """
304 | super(BertModel, self).__init__()
305 | self.embeddings = BERTEmbeddings(config)
306 | self.encoder = BERTEncoder(config)
307 | self.pooler = BERTPooler(config)
308 |
309 | def forward(self, input_ids, token_type_ids=None, attention_mask=None):
310 | if attention_mask is None:
311 | attention_mask = torch.ones_like(input_ids)
312 | if token_type_ids is None:
313 | token_type_ids = torch.zeros_like(input_ids)
314 |
315 | # We create a 3D attention mask from a 2D tensor mask.
316 | # Sizes are [batch_size, 1, 1, from_seq_length]
317 | # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length]
318 |         # this attention mask is simpler than the triangular masking of causal attention
319 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
320 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
321 |
322 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
323 | # masked positions, this operation will create a tensor which is 0.0 for
324 | # positions we want to attend and -10000.0 for masked positions.
325 | # Since we are adding it to the raw scores before the softmax, this is
326 | # effectively the same as removing these entirely.
327 | extended_attention_mask = extended_attention_mask.float()
328 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
329 |
330 | embedding_output = self.embeddings(input_ids, token_type_ids)
331 | hidden_states = self.encoder(embedding_output, extended_attention_mask)
332 | pooled_output = self.pooler(hidden_states)
333 | return pooled_output
334 |
335 |
336 | class BertEdgeScorer(nn.Module):
337 | """BERT model for computing sentence similarity and scoring edges.
338 |
339 | """
340 | def __init__(self, config):
341 | super(BertEdgeScorer, self).__init__()
342 | self.bert = BertModel(config)
343 |
344 | def forward(self, input_ids, token_type_ids, attention_mask, input_ids_c, token_type_ids_c, attention_mask_c, labels=None):
345 |         # input_ids: B * T (the *_c tensors hold the second sentence of each pair)
346 |
347 | pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
348 | pooled_output_c = self.bert(input_ids_c, token_type_ids_c, attention_mask_c)
349 | logits = torch.bmm(pooled_output.unsqueeze(1), pooled_output_c.unsqueeze(2)).view(-1)
350 | pros = torch.sigmoid(logits)
351 |
352 | return logits, pros
353 |
--------------------------------------------------------------------------------
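
A small CPU sketch of `BertEdgeScorer`'s pairwise scoring with a tiny random-weight configuration (toy sizes, random token ids); in the actual pipeline `PacSumExtractorWithBert._load_edge_model` loads the fine-tuned weights and runs on GPU.

```
# Toy configuration and inputs; shapes only, no pretrained weights.
import torch
from bert_model import BertConfig, BertEdgeScorer

config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64,
                    max_position_embeddings=64, type_vocab_size=2)
model = BertEdgeScorer(config).eval()

ids = torch.randint(0, 100, (3, 10))   # 3 sentence pairs, 10 tokens per side
segs = torch.zeros_like(ids)           # single-segment inputs, as in Dataset._2bert_rep
mask = torch.ones_like(ids)            # no padding in this toy batch
with torch.no_grad():
    logits, probs = model(ids, segs, mask, ids, segs, mask)
print(logits.shape, probs.shape)       # torch.Size([3]) torch.Size([3])
```
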