├── .gitignore
├── LICENSE
├── README.md
├── bojone_snippets.py
├── bojone_tokenizers.py
├── configuration
│   ├── __init__.py
│   └── config.py
├── opt.py
├── train_and_eval.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
.idea
.DS_Store
*.iml
*.egg-info
*.pyc
*.csv
*.json
*.pt
data

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 wakafengfan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# simcse-pytorch

A PyTorch port of SimCSE, the unsupervised sentence-embedding model that has recently drawn wide attention, adapted from Su Jianlin's Keras implementation.
This is an initial release; more experiments will be added later, together with results for Danqi Chen's official PyTorch version on Chinese data.

So far only roberta-wwm has been evaluated, trained unsupervised on LCQMC; the evaluation metric is Spearman correlation.

| Model                       | Settings                                                      | Correlation score |
| --------------------------- | ------------------------------------------------------------- | ----------------- |
| `roberta-wwm`               | dropout_rate=0.1, learning_rate=1e-5, pooling: first-last-avg | 0.67029           |
| `roberta-wwm` (no training) | pooling: first-last-avg                                       | 0.60377           |

### Environment

- python==3.6.*
- pytorch==1.8
- transformers==4.4.2

### References

SimCSE: Simple Contrastive Learning of Sentence Embeddings https://arxiv.org/pdf/2104.08821.pdf
https://kexue.fm/archives/8348
--------------------------------------------------------------------------------
/bojone_snippets.py:
--------------------------------------------------------------------------------
import numpy as np


def sequence_padding(inputs, length=None, padding=0, mode='post', with_mask=False):
    """Numpy helper: pad a batch of sequences to the same length.
    """
    if length is None:
        length = max([len(x) for x in inputs])

    pad_width = [(0, 0) for _ in np.shape(inputs[0])]
    outputs = []
    output_masks = []
    for x in inputs:
        x = x[:length]
        if mode == 'post':
            pad_width[0] = (0, length - len(x))
        elif mode == 'pre':
            pad_width[0] = (length - len(x), 0)
        else:
            raise ValueError('"mode" argument must be "post" or "pre".')
        m = np.pad([1]*len(x), pad_width, 'constant', constant_values=padding)
        output_masks.append(m)
        x = np.pad(x, pad_width, 'constant',
constant_values=padding) 24 | outputs.append(x) 25 | 26 | if with_mask: 27 | return np.array(outputs), np.array(output_masks) 28 | 29 | return np.array(outputs) 30 | 31 | 32 | def text_segmentate(text, maxlen, seps='\n', strips=None): 33 | """将文本按照标点符号划分为若干个短句 34 | """ 35 | text = text.strip().strip(strips) 36 | if seps and len(text) > maxlen: 37 | pieces = text.split(seps[0]) 38 | text, texts = '', [] 39 | for i, p in enumerate(pieces): 40 | if text and p and len(text) + len(p) > maxlen - 1: 41 | texts.extend(text_segmentate(text, maxlen, seps[1:], strips)) 42 | text = '' 43 | if i + 1 == len(pieces): 44 | text = text + p 45 | else: 46 | text = text + p + seps[0] 47 | if text: 48 | texts.extend(text_segmentate(text, maxlen, seps[1:], strips)) 49 | return texts 50 | else: 51 | return [text] 52 | 53 | 54 | class DataGenerator(): 55 | def __init__(self, data, batch_size=32, buffer_size=None, **kwargs): 56 | self.data = data 57 | self.batch_size = batch_size 58 | if hasattr(self.data, '__len__'): 59 | self.steps = len(self.data) // self.batch_size 60 | if len(self.data) % self.batch_size != 0: 61 | self.steps += 1 62 | else: 63 | self.steps = None 64 | self.buffer_size = buffer_size or batch_size * 1000 65 | self.kwargs = kwargs 66 | 67 | def __len__(self): 68 | return self.steps 69 | 70 | def sample(self, random=False): 71 | """采样函数,每个样本同时返回一个is_end标记 72 | """ 73 | if random: 74 | if self.steps is None: 75 | 76 | def generator(): 77 | caches, isfull = [], False 78 | for d in self.data: 79 | caches.append(d) 80 | if isfull: 81 | i = np.random.randint(len(caches)) 82 | yield caches.pop(i) 83 | elif len(caches) == self.buffer_size: 84 | isfull = True 85 | while caches: 86 | i = np.random.randint(len(caches)) 87 | yield caches.pop(i) 88 | 89 | else: 90 | 91 | def generator(): 92 | indices = list(range(len(self.data))) 93 | np.random.shuffle(indices) 94 | for i in indices: 95 | yield self.data[i] 96 | 97 | data = generator() 98 | else: 99 | data = iter(self.data) 100 | 101 | d_current = next(data) 102 | for d_next in data: 103 | yield False, d_current 104 | d_current = d_next 105 | 106 | yield True, d_current 107 | 108 | def __iter__(self, random=False): 109 | raise NotImplementedError 110 | 111 | 112 | def softmax(x, axis=-1): 113 | """numpy版softmax 114 | """ 115 | x = x - x.max(axis=axis, keepdims=True) 116 | x = np.exp(x) 117 | return x / x.sum(axis=axis, keepdims=True) 118 | 119 | 120 | class AutoRegressiveDecoder(object): 121 | """通用自回归生成模型解码基类 122 | 包含beam search和random sample两种策略 123 | """ 124 | def __init__(self, start_id, end_id, maxlen, minlen=1): 125 | self.start_id = start_id 126 | self.end_id = end_id 127 | self.maxlen = maxlen 128 | self.minlen = minlen 129 | self.models = {} 130 | if start_id is None: 131 | self.first_output_ids = np.empty((1, 0), dtype=int) 132 | else: 133 | self.first_output_ids = np.array([[self.start_id]]) 134 | 135 | @staticmethod 136 | def wraps(default_rtype='probas', use_states=False): 137 | """用来进一步完善predict函数 138 | 目前包含:1. 设置rtype参数,并做相应处理; 139 | 2. 确定states的使用,并做相应处理; 140 | 3. 
设置温度参数,并做相应处理。 141 | """ 142 | def actual_decorator(predict): 143 | def new_predict( 144 | self, 145 | inputs, 146 | output_ids, 147 | states, 148 | temperature=1, 149 | rtype=default_rtype 150 | ): 151 | assert rtype in ['probas', 'logits'] 152 | prediction = predict(self, inputs, output_ids, states) 153 | 154 | if not use_states: 155 | prediction = (prediction, None) 156 | 157 | if default_rtype == 'logits': 158 | prediction = ( 159 | softmax(prediction[0] / temperature), prediction[1] 160 | ) 161 | elif temperature != 1: 162 | probas = np.power(prediction[0], 1.0 / temperature) 163 | probas = probas / probas.sum(axis=-1, keepdims=True) 164 | prediction = (probas, prediction[1]) 165 | 166 | if rtype == 'probas': 167 | return prediction 168 | else: 169 | return np.log(prediction[0] + 1e-12), prediction[1] 170 | 171 | return new_predict 172 | 173 | return actual_decorator 174 | 175 | def last_token(self, model): 176 | """创建一个只返回最后一个token输出的新Model 177 | """ 178 | if model not in self.models: 179 | outputs = [ 180 | keras.layers.Lambda(lambda x: x[:, -1])(output) 181 | for output in model.outputs 182 | ] 183 | self.models[model] = keras.models.Model(model.inputs, outputs) 184 | 185 | return self.models[model] 186 | 187 | def predict(self, inputs, output_ids, states=None): 188 | """用户需自定义递归预测函数 189 | 说明:定义的时候,需要用wraps方法进行装饰,传入default_rtype和use_states, 190 | 其中default_rtype为字符串logits或probas,probas时返回归一化的概率, 191 | rtype=logits时则返回softmax前的结果或者概率对数。 192 | 返回:二元组 (得分或概率, states) 193 | """ 194 | raise NotImplementedError 195 | 196 | def beam_search(self, inputs, topk, states=None, temperature=1, min_ends=1): 197 | """beam search解码 198 | 说明:这里的topk即beam size; 199 | 返回:最优解码序列。 200 | """ 201 | inputs = [np.array([i]) for i in inputs] 202 | output_ids, output_scores = self.first_output_ids, np.zeros(1) 203 | for step in range(self.maxlen): 204 | scores, states = self.predict( 205 | inputs, output_ids, states, temperature, 'logits' 206 | ) # 计算当前得分 207 | if step == 0: # 第1步预测后将输入重复topk次 208 | inputs = [np.repeat(i, topk, axis=0) for i in inputs] 209 | scores = output_scores.reshape((-1, 1)) + scores # 综合累积得分 210 | indices = scores.argpartition(-topk, axis=None)[-topk:] # 仅保留topk 211 | indices_1 = indices // scores.shape[1] # 行索引 212 | indices_2 = (indices % scores.shape[1]).reshape((-1, 1)) # 列索引 213 | output_ids = np.concatenate([output_ids[indices_1], indices_2], 214 | 1) # 更新输出 215 | output_scores = np.take_along_axis( 216 | scores, indices, axis=None 217 | ) # 更新得分 218 | end_counts = (output_ids == self.end_id).sum(1) # 统计出现的end标记 219 | if output_ids.shape[1] >= self.minlen: # 最短长度判断 220 | best_one = output_scores.argmax() # 得分最大的那个 221 | if end_counts[best_one] == min_ends: # 如果已经终止 222 | return output_ids[best_one] # 直接输出 223 | else: # 否则,只保留未完成部分 224 | flag = (end_counts < min_ends) # 标记未完成序列 225 | if not flag.all(): # 如果有已完成的 226 | inputs = [i[flag] for i in inputs] # 扔掉已完成序列 227 | output_ids = output_ids[flag] # 扔掉已完成序列 228 | output_scores = output_scores[flag] # 扔掉已完成序列 229 | end_counts = end_counts[flag] # 扔掉已完成end计数 230 | topk = flag.sum() # topk相应变化 231 | # 达到长度直接输出 232 | return output_ids[output_scores.argmax()] 233 | 234 | def random_sample( 235 | self, 236 | inputs, 237 | n, 238 | topk=None, 239 | topp=None, 240 | states=None, 241 | temperature=1, 242 | min_ends=1 243 | ): 244 | """随机采样n个结果 245 | 说明:非None的topk表示每一步只从概率最高的topk个中采样;而非None的topp 246 | 表示每一步只从概率最高的且概率之和刚好达到topp的若干个token中采样。 247 | 返回:n个解码序列组成的list。 248 | """ 249 | inputs = [np.array([i]) for i in inputs] 250 | output_ids = 
self.first_output_ids 251 | results = [] 252 | for step in range(self.maxlen): 253 | probas, states = self.predict( 254 | inputs, output_ids, states, temperature, 'probas' 255 | ) # 计算当前概率 256 | probas /= probas.sum(axis=1, keepdims=True) # 确保归一化 257 | if step == 0: # 第1步预测后将结果重复n次 258 | probas = np.repeat(probas, n, axis=0) 259 | inputs = [np.repeat(i, n, axis=0) for i in inputs] 260 | output_ids = np.repeat(output_ids, n, axis=0) 261 | if topk is not None: 262 | k_indices = probas.argpartition(-topk, 263 | axis=1)[:, -topk:] # 仅保留topk 264 | probas = np.take_along_axis(probas, k_indices, axis=1) # topk概率 265 | probas /= probas.sum(axis=1, keepdims=True) # 重新归一化 266 | if topp is not None: 267 | p_indices = probas.argsort(axis=1)[:, ::-1] # 从高到低排序 268 | probas = np.take_along_axis(probas, p_indices, axis=1) # 排序概率 269 | cumsum_probas = np.cumsum(probas, axis=1) # 累积概率 270 | flag = np.roll(cumsum_probas >= topp, 1, axis=1) # 标记超过topp的部分 271 | flag[:, 0] = False # 结合上面的np.roll,实现平移一位的效果 272 | probas[flag] = 0 # 后面的全部置零 273 | probas /= probas.sum(axis=1, keepdims=True) # 重新归一化 274 | sample_func = lambda p: np.random.choice(len(p), p=p) # 按概率采样函数 275 | sample_ids = np.apply_along_axis(sample_func, 1, probas) # 执行采样 276 | sample_ids = sample_ids.reshape((-1, 1)) # 对齐形状 277 | if topp is not None: 278 | sample_ids = np.take_along_axis( 279 | p_indices, sample_ids, axis=1 280 | ) # 对齐原id 281 | if topk is not None: 282 | sample_ids = np.take_along_axis( 283 | k_indices, sample_ids, axis=1 284 | ) # 对齐原id 285 | output_ids = np.concatenate([output_ids, sample_ids], 1) # 更新输出 286 | end_counts = (output_ids == self.end_id).sum(1) # 统计出现的end标记 287 | if output_ids.shape[1] >= self.minlen: # 最短长度判断 288 | flag = (end_counts == min_ends) # 标记已完成序列 289 | if flag.any(): # 如果有已完成的 290 | for ids in output_ids[flag]: # 存好已完成序列 291 | results.append(ids) 292 | flag = (flag == False) # 标记未完成序列 293 | inputs = [i[flag] for i in inputs] # 只保留未完成部分输入 294 | output_ids = output_ids[flag] # 只保留未完成部分候选集 295 | end_counts = end_counts[flag] # 只保留未完成部分end计数 296 | if len(output_ids) == 0: 297 | break 298 | # 如果还有未完成序列,直接放入结果 299 | for ids in output_ids: 300 | results.append(ids) 301 | # 返回结果 302 | return results 303 | 304 | -------------------------------------------------------------------------------- /bojone_tokenizers.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | import numpy as np 3 | import re 4 | 5 | 6 | def load_vocab(dict_path, encoding='utf-8', simplified=False, startswith=None): 7 | """从bert的词典文件中读取词典 8 | """ 9 | token_dict = {} 10 | with open(dict_path, encoding=encoding) as reader: 11 | for line in reader: 12 | token = line.split() 13 | token = token[0] if token else line.strip() 14 | token_dict[token] = len(token_dict) 15 | 16 | if simplified: # 过滤冗余部分token 17 | new_token_dict, keep_tokens = {}, [] 18 | startswith = startswith or [] 19 | for t in startswith: 20 | new_token_dict[t] = len(new_token_dict) 21 | keep_tokens.append(token_dict[t]) 22 | 23 | for t, _ in sorted(token_dict.items(), key=lambda s: s[1]): 24 | if t not in new_token_dict: 25 | keep = True 26 | if len(t) > 1: 27 | for c in Tokenizer.stem(t): 28 | if ( 29 | Tokenizer._is_cjk_character(c) or 30 | Tokenizer._is_punctuation(c) 31 | ): 32 | keep = False 33 | break 34 | if keep: 35 | new_token_dict[t] = len(new_token_dict) 36 | keep_tokens.append(token_dict[t]) 37 | 38 | return new_token_dict, keep_tokens 39 | else: 40 | return token_dict 41 | 42 | 43 | def save_vocab(dict_path, 
token_dict, encoding='utf-8'): 44 | """将词典(比如精简过的)保存为文件 45 | """ 46 | with open(dict_path, 'w', encoding=encoding) as writer: 47 | for k, v in sorted(token_dict.items(), key=lambda s: s[1]): 48 | writer.write(k + '\n') 49 | 50 | 51 | def truncate_sequences(maxlen, index, *sequences): 52 | """截断总长度至不超过maxlen 53 | """ 54 | sequences = [s for s in sequences if s] 55 | while True: 56 | lengths = [len(s) for s in sequences] 57 | if sum(lengths) > maxlen: 58 | i = np.argmax(lengths) 59 | sequences[i].pop(index) 60 | else: 61 | return sequences 62 | 63 | 64 | def convert_to_unicode(text, encoding='utf-8', errors='ignore'): 65 | """字符串转换为unicode格式(假设输入为utf-8格式) 66 | """ 67 | if isinstance(text, bytes): 68 | text = text.decode(encoding, errors=errors) 69 | return text 70 | 71 | 72 | def is_string(s): 73 | """判断是否是字符串 74 | """ 75 | return isinstance(s, str) 76 | 77 | 78 | class TokenizerBase(object): 79 | """分词器基类 80 | """ 81 | def __init__( 82 | self, 83 | token_start='[CLS]', 84 | token_end='[SEP]', 85 | pre_tokenize=None, 86 | token_translate=None 87 | ): 88 | """参数说明: 89 | pre_tokenize:外部传入的分词函数,用作对文本进行预分词。如果传入 90 | pre_tokenize,则先执行pre_tokenize(text),然后在它 91 | 的基础上执行原本的tokenize函数; 92 | token_translate:映射字典,主要用在tokenize之后,将某些特殊的token 93 | 替换为对应的token。 94 | """ 95 | self._token_pad = '[PAD]' 96 | self._token_unk = '[UNK]' 97 | self._token_mask = '[MASK]' 98 | self._token_start = token_start 99 | self._token_end = token_end 100 | self._pre_tokenize = pre_tokenize 101 | self._token_translate = token_translate or {} 102 | self._token_translate_inv = { 103 | v: k 104 | for k, v in self._token_translate.items() 105 | } 106 | 107 | def tokenize(self, text, maxlen=None): 108 | """分词函数 109 | """ 110 | tokens = [ 111 | self._token_translate.get(token) or token 112 | for token in self._tokenize(text) 113 | ] 114 | if self._token_start is not None: 115 | tokens.insert(0, self._token_start) 116 | if self._token_end is not None: 117 | tokens.append(self._token_end) 118 | 119 | if maxlen is not None: 120 | index = int(self._token_end is not None) + 1 121 | truncate_sequences(maxlen, -index, tokens) 122 | 123 | return tokens 124 | 125 | def token_to_id(self, token): 126 | """token转换为对应的id 127 | """ 128 | raise NotImplementedError 129 | 130 | def tokens_to_ids(self, tokens): 131 | """token序列转换为对应的id序列 132 | """ 133 | return [self.token_to_id(token) for token in tokens] 134 | 135 | def encode( 136 | self, first_text, second_text=None, maxlen=None, pattern='S*E*E' 137 | ): 138 | """输出文本对应token id和segment id 139 | """ 140 | if is_string(first_text): 141 | first_tokens = self.tokenize(first_text) 142 | else: 143 | first_tokens = first_text 144 | 145 | if second_text is None: 146 | second_tokens = None 147 | elif is_string(second_text): 148 | if pattern == 'S*E*E': 149 | idx = int(bool(self._token_start)) 150 | second_tokens = self.tokenize(second_text)[idx:] 151 | elif pattern == 'S*ES*E': 152 | second_tokens = self.tokenize(second_text) 153 | else: 154 | second_tokens = second_text 155 | 156 | if maxlen is not None: 157 | index = int(self._token_end is not None) + 1 158 | truncate_sequences(maxlen, -index, first_tokens, second_tokens) 159 | 160 | first_token_ids = self.tokens_to_ids(first_tokens) 161 | first_segment_ids = [0] * len(first_token_ids) 162 | 163 | if second_text is not None: 164 | second_token_ids = self.tokens_to_ids(second_tokens) 165 | second_segment_ids = [1] * len(second_token_ids) 166 | first_token_ids.extend(second_token_ids) 167 | first_segment_ids.extend(second_segment_ids) 168 | 169 | return 
first_token_ids, first_segment_ids 170 | 171 | def id_to_token(self, i): 172 | """id序列为对应的token 173 | """ 174 | raise NotImplementedError 175 | 176 | def ids_to_tokens(self, ids): 177 | """id序列转换为对应的token序列 178 | """ 179 | return [self.id_to_token(i) for i in ids] 180 | 181 | def decode(self, ids): 182 | """转为可读文本 183 | """ 184 | raise NotImplementedError 185 | 186 | def _tokenize(self, text): 187 | """基本分词函数 188 | """ 189 | raise NotImplementedError 190 | 191 | 192 | class Tokenizer(TokenizerBase): 193 | """Bert原生分词器 194 | 纯Python实现,代码修改自keras_bert的tokenizer实现 195 | """ 196 | def __init__( 197 | self, token_dict, do_lower_case=False, word_maxlen=200, **kwargs 198 | ): 199 | super(Tokenizer, self).__init__(**kwargs) 200 | if is_string(token_dict): 201 | token_dict = load_vocab(token_dict) 202 | 203 | self._do_lower_case = do_lower_case 204 | self._token_dict = token_dict 205 | self._token_dict_inv = {v: k for k, v in token_dict.items()} 206 | self._vocab_size = len(token_dict) 207 | self._word_maxlen = word_maxlen 208 | 209 | for token in ['pad', 'unk', 'mask', 'start', 'end']: 210 | try: 211 | _token_id = token_dict[getattr(self, '_token_%s' % token)] 212 | setattr(self, '_token_%s_id' % token, _token_id) 213 | except: 214 | pass 215 | 216 | def token_to_id(self, token): 217 | """token转换为对应的id 218 | """ 219 | return self._token_dict.get(token, self._token_unk_id) 220 | 221 | def id_to_token(self, i): 222 | """id转换为对应的token 223 | """ 224 | return self._token_dict_inv[i] 225 | 226 | def decode(self, ids, tokens=None): 227 | """转为可读文本 228 | """ 229 | tokens = tokens or self.ids_to_tokens(ids) 230 | tokens = [token for token in tokens if not self._is_special(token)] 231 | 232 | text, flag = '', False 233 | for i, token in enumerate(tokens): 234 | if token[:2] == '##': 235 | text += token[2:] 236 | elif len(token) == 1 and self._is_cjk_character(token): 237 | text += token 238 | elif len(token) == 1 and self._is_punctuation(token): 239 | text += token 240 | text += ' ' 241 | elif i > 0 and self._is_cjk_character(text[-1]): 242 | text += token 243 | else: 244 | text += ' ' 245 | text += token 246 | 247 | text = re.sub(' +', ' ', text) 248 | text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text) 249 | punctuation = self._cjk_punctuation() + '+-/={(<[' 250 | punctuation_regex = '|'.join([re.escape(p) for p in punctuation]) 251 | punctuation_regex = '(%s) ' % punctuation_regex 252 | text = re.sub(punctuation_regex, '\\1', text) 253 | text = re.sub('(\d\.) 
(\d)', '\\1\\2', text) 254 | 255 | return text.strip() 256 | 257 | def _tokenize(self, text, pre_tokenize=True): 258 | """基本分词函数 259 | """ 260 | if self._do_lower_case: 261 | text = text.lower() 262 | text = unicodedata.normalize('NFD', text) 263 | text = ''.join([ 264 | ch for ch in text if unicodedata.category(ch) != 'Mn' 265 | ]) 266 | 267 | if pre_tokenize and self._pre_tokenize is not None: 268 | tokens = [] 269 | for token in self._pre_tokenize(text): 270 | if token in self._token_dict: 271 | tokens.append(token) 272 | else: 273 | tokens.extend(self._tokenize(token, False)) 274 | return tokens 275 | 276 | spaced = '' 277 | for ch in text: 278 | if self._is_punctuation(ch) or self._is_cjk_character(ch): 279 | spaced += ' ' + ch + ' ' 280 | elif self._is_space(ch): 281 | spaced += ' ' 282 | elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch): 283 | continue 284 | else: 285 | spaced += ch 286 | 287 | tokens = [] 288 | for word in spaced.strip().split(): 289 | tokens.extend(self._word_piece_tokenize(word)) 290 | 291 | return tokens 292 | 293 | def _word_piece_tokenize(self, word): 294 | """word内分成subword 295 | """ 296 | if len(word) > self._word_maxlen: 297 | return [word] 298 | 299 | tokens, start, end = [], 0, 0 300 | while start < len(word): 301 | end = len(word) 302 | while end > start: 303 | sub = word[start:end] 304 | if start > 0: 305 | sub = '##' + sub 306 | if sub in self._token_dict: 307 | break 308 | end -= 1 309 | if start == end: 310 | return [word] 311 | else: 312 | tokens.append(sub) 313 | start = end 314 | 315 | return tokens 316 | 317 | @staticmethod 318 | def stem(token): 319 | """获取token的“词干”(如果是##开头,则自动去掉##) 320 | """ 321 | if token[:2] == '##': 322 | return token[2:] 323 | else: 324 | return token 325 | 326 | @staticmethod 327 | def _is_space(ch): 328 | """空格类字符判断 329 | """ 330 | return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \ 331 | unicodedata.category(ch) == 'Zs' 332 | 333 | @staticmethod 334 | def _is_punctuation(ch): 335 | """标点符号类字符判断(全/半角均在此内) 336 | 提醒:unicodedata.category这个函数在py2和py3下的 337 | 表现可能不一样,比如u'§'字符,在py2下的结果为'So', 338 | 在py3下的结果是'Po'。 339 | """ 340 | code = ord(ch) 341 | return 33 <= code <= 47 or \ 342 | 58 <= code <= 64 or \ 343 | 91 <= code <= 96 or \ 344 | 123 <= code <= 126 or \ 345 | unicodedata.category(ch).startswith('P') 346 | 347 | @staticmethod 348 | def _cjk_punctuation(): 349 | return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002' 350 | 351 | @staticmethod 352 | def _is_cjk_character(ch): 353 | """CJK类字符判断(包括中文字符也在此列) 354 | 参考:https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 355 | """ 356 | code = ord(ch) 357 | return 0x4E00 <= code <= 0x9FFF or \ 358 | 0x3400 <= code <= 0x4DBF or \ 359 | 0x20000 <= code <= 0x2A6DF or \ 360 | 0x2A700 <= code <= 0x2B73F or \ 361 | 0x2B740 <= code <= 0x2B81F or \ 362 | 0x2B820 <= code <= 0x2CEAF or \ 363 | 0xF900 <= code <= 0xFAFF or \ 364 | 0x2F800 <= code <= 0x2FA1F 365 | 366 | @staticmethod 367 | def _is_control(ch): 368 | """控制类字符判断 369 | """ 370 | return unicodedata.category(ch) in ('Cc', 'Cf') 371 | 372 | @staticmethod 
373 | def _is_special(ch): 374 | """判断是不是有特殊含义的符号 375 | """ 376 | return bool(ch) and (ch[0] == '[') and (ch[-1] == ']') 377 | 378 | def rematch(self, text, tokens): 379 | """给出原始的text和tokenize后的tokens的映射关系 380 | """ 381 | if self._do_lower_case: 382 | text = text.lower() 383 | 384 | normalized_text, char_mapping = '', [] 385 | for i, ch in enumerate(text): 386 | if self._do_lower_case: 387 | ch = unicodedata.normalize('NFD', ch) 388 | ch = ''.join([c for c in ch if unicodedata.category(c) != 'Mn']) 389 | ch = ''.join([ 390 | c for c in ch 391 | if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c)) 392 | ]) 393 | normalized_text += ch 394 | char_mapping.extend([i] * len(ch)) 395 | 396 | text, token_mapping, offset = normalized_text, [], 0 397 | for token in tokens: 398 | if self._is_special(token): 399 | token_mapping.append([]) 400 | else: 401 | token = self.stem(token) 402 | start = text[offset:].index(token) + offset 403 | end = start + len(token) 404 | token_mapping.append(char_mapping[start:end]) 405 | offset = end 406 | 407 | return token_mapping 408 | 409 | 410 | class SpTokenizer(TokenizerBase): 411 | """基于SentencePiece模型的封装,使用上跟Tokenizer基本一致。 412 | """ 413 | def __init__(self, sp_model_path, **kwargs): 414 | super(SpTokenizer, self).__init__(**kwargs) 415 | import sentencepiece as spm 416 | self.sp_model = spm.SentencePieceProcessor() 417 | self.sp_model.Load(sp_model_path) 418 | self._token_pad = self.sp_model.id_to_piece(self.sp_model.pad_id()) 419 | self._token_unk = self.sp_model.id_to_piece(self.sp_model.unk_id()) 420 | self._vocab_size = self.sp_model.get_piece_size() 421 | 422 | for token in ['pad', 'unk', 'mask', 'start', 'end']: 423 | try: 424 | _token = getattr(self, '_token_%s' % token) 425 | _token_id = self.sp_model.piece_to_id(_token) 426 | setattr(self, '_token_%s_id' % token, _token_id) 427 | except: 428 | pass 429 | 430 | def token_to_id(self, token): 431 | """token转换为对应的id 432 | """ 433 | return self.sp_model.piece_to_id(token) 434 | 435 | def id_to_token(self, i): 436 | """id转换为对应的token 437 | """ 438 | if i < self._vocab_size: 439 | return self.sp_model.id_to_piece(i) 440 | else: 441 | return '' 442 | 443 | def decode(self, ids): 444 | """转为可读文本 445 | """ 446 | tokens = [ 447 | self._token_translate_inv.get(token) or token 448 | for token in self.ids_to_tokens(ids) 449 | ] 450 | text = self.sp_model.decode_pieces(tokens) 451 | return convert_to_unicode(text) 452 | 453 | def _tokenize(self, text): 454 | """基本分词函数 455 | """ 456 | if self._pre_tokenize is not None: 457 | text = ' '.join(self._pre_tokenize(text)) 458 | 459 | tokens = self.sp_model.encode_as_pieces(text) 460 | return tokens 461 | 462 | def _is_special(self, i): 463 | """判断是不是有特殊含义的符号 464 | """ 465 | return self.sp_model.is_control(i) or \ 466 | self.sp_model.is_unknown(i) or \ 467 | self.sp_model.is_unused(i) 468 | 469 | def _is_decodable(self, i): 470 | """判断是否应该被解码输出 471 | """ 472 | return (i < self._vocab_size) and not self._is_special(i) -------------------------------------------------------------------------------- /configuration/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /configuration/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | import json 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from collections 
import defaultdict 9 | 10 | 11 | ROOT_PATH = os.path.normpath(os.path.join(os.path.abspath(os.path.dirname(__file__)), "..")) 12 | 13 | # data 14 | data_dir = Path(ROOT_PATH)/ "data" 15 | model_dir = Path(ROOT_PATH) / "model" 16 | 17 | bert_data_path = Path.home() / 'db__pytorch_pretrained_bert' 18 | bert_vocab_path = bert_data_path / 'bert-base-chinese' / 'vocab.txt' 19 | bert_model_path = bert_data_path / 'bert-base-chinese' 20 | uer_bert_base_model_path = bert_data_path / 'uer-bert-base' 21 | uer_bert_large_model_path = bert_data_path / 'uer-bert-large' 22 | 23 | tencent_w2v_path = Path.home() / 'db__word2vec' 24 | 25 | roberta_large_model_path = bert_data_path / 'chinese_Roberta_bert_wwm_large_ext_pytorch' 26 | 27 | bert_insurance_path = bert_data_path / 'bert_insurance_v2' 28 | 29 | bert_wwm_path = bert_data_path / "chinese_wwm_ext_L-12_H-768_A-12" 30 | bert_wwm_pt_path = bert_data_path / "chinese_wwm_ext_pytorch" 31 | robert_wwm_pt_path = bert_data_path / "chinese_roberta_wwm_ext_pytorch" 32 | 33 | mt5_pt_path = bert_data_path / "mt5_small_pt" 34 | nezha_pt_path = bert_data_path / "nezha-cn-base" 35 | 36 | simbert_path = bert_data_path / "chinese_simbert_L-12_H-768_A-12" 37 | simbert_pt_path = bert_data_path / "chinese_simbert_pt" 38 | 39 | common_data_path = Path.home() / 'db__common_dataset' 40 | open_dataset_path = common_data_path / "open_dataset" 41 | 42 | 43 | 44 | 45 | ############################################### 46 | # log 47 | ############################################### 48 | 49 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 50 | datefmt='%m/%d/%y %H:%M:%S', 51 | level=logging.INFO) 52 | logger = logging.getLogger(__name__) 53 | 54 | logger.info(f'begin progress ...') 55 | 56 | 57 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim.adamw import AdamW 5 | from transformers import get_scheduler 6 | 7 | 8 | def create_optimizer_and_scheduler(model, lr, 9 | num_training_steps, 10 | weight_decay=0.0, 11 | warmup_ratio=0.0, 12 | warmup_steps=0, 13 | lr_scheduler_type="linear"): 14 | decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm]) 15 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 16 | optimizer_grouped_parameters = [ 17 | { 18 | "params": [p for n, p in model.named_parameters() if n in decay_parameters], 19 | "weight_decay": weight_decay, 20 | }, 21 | { 22 | "params": [p for n, p in model.named_parameters() if n not in decay_parameters], 23 | "weight_decay": 0.0, 24 | }, 25 | ] 26 | 27 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr) 28 | 29 | warmup_steps = warmup_steps if warmup_steps > 0 else math.ceil(num_training_steps * warmup_ratio) 30 | 31 | lr_scheduler = get_scheduler( 32 | lr_scheduler_type, 33 | optimizer, 34 | num_warmup_steps=warmup_steps, 35 | num_training_steps=num_training_steps, 36 | ) 37 | 38 | return optimizer, lr_scheduler 39 | 40 | 41 | 42 | 43 | 44 | def get_parameter_names(model, forbidden_layer_types): 45 | """ 46 | Returns the names of the model parameters that are not inside a forbidden layer. 
    """
    result = []
    for name, child in model.named_children():
        result += [
            f"{name}.{n}"
            for n in get_parameter_names(child, forbidden_layer_types)
            if not isinstance(child, tuple(forbidden_layer_types))
        ]
    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
    result += list(model._parameters.keys())
    return result
--------------------------------------------------------------------------------
/train_and_eval.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertPreTrainedModel, BertModel, AutoConfig

from bojone_snippets import DataGenerator, sequence_padding
from bojone_tokenizers import Tokenizer
from configuration.config import *
from opt import create_optimizer_and_scheduler
from utils import l2_normalize, compute_corrcoef

batch_size = 64
maxlen = 64
task_name = "LCQMC"
epochs = 1
gradient_accumulation_steps = 1


# Load the data
def load_data(data_path):
    D = []
    for line in data_path.open():
        text1, text2, label = line.strip().split("\t")
        D.append((text1, text2, float(label)))
    return D


# Load the tokenizer
dict_path = str(robert_wwm_pt_path / "vocab.txt")
tokenizer = Tokenizer(dict_path, do_lower_case=True)


class data_generator(DataGenerator):
    """Training corpus generator.
    """
    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids = [], []
        for is_end, text in self.sample(random):
            token_ids, _ = tokenizer.encode(text, maxlen=maxlen)
            batch_token_ids.append(token_ids)
            if "mode" in self.kwargs and self.kwargs["mode"] == "train":
                batch_token_ids.append(token_ids)
                batch_segment_ids.append([1] * len(token_ids))
            batch_segment_ids.append([1] * len(token_ids))

            if len(batch_token_ids) == self.batch_size * 2 or is_end:
                batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long)
                batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long)
                yield batch_token_ids, batch_segment_ids
                batch_token_ids, batch_segment_ids = [], []


class EncodingModel(BertPreTrainedModel):
    def __init__(self, config):
        super(EncodingModel, self).__init__(config)
        self.bert = BertModel(config)

    def forward(self, input_ids, attention_mask, encoder_type="first-last-avg"):
        """
        :param input_ids:
        :param attention_mask:
        :param encoder_type: "first-last-avg", "last-avg", "cls", "pooler(cls + dense)"
        :return:
        """
        output = self.bert(input_ids, attention_mask, output_hidden_states=True)

        if encoder_type == "first-last-avg":
            # output.hidden_states holds 13 tensors: the first one is the embedding output,
            # the second is the hidden state of the first encoder layer.
            first = output.hidden_states[1]
            last = output.hidden_states[-1]
            seq_length = first.size(1)
            first_avg = torch.avg_pool1d(first.transpose(1, 2), kernel_size=seq_length).squeeze(-1)  # [b,d]
            last_avg = torch.avg_pool1d(last.transpose(1, 2), kernel_size=seq_length).squeeze(-1)  # [b,d]
            final_encoding = torch.avg_pool1d(torch.cat([first_avg.unsqueeze(1), last_avg.unsqueeze(1)], dim=1).transpose(1, 2), kernel_size=2).squeeze(-1)
            return final_encoding

        if encoder_type == "last-avg":
            sequence_output = output.last_hidden_state  # [b,s,d]
            seq_length = sequence_output.size(1)
            final_encoding = torch.avg_pool1d(sequence_output.transpose(1, 2), kernel_size=seq_length).squeeze(-1)  # [b,d]
            return final_encoding

        if encoder_type == "cls":
            sequence_output = output.last_hidden_state
            cls = sequence_output[:, 0]  # [b,d]
            return cls

        if encoder_type == "pooler":
            pooler_output = output.pooler_output  # [b,d]
            return pooler_output


def convert_to_ids(data):
    """Convert text data to id form.
    """
    a_token_ids, b_token_ids, labels = [], [], []
    for d in tqdm(data):
        token_ids = tokenizer.encode(d[0], maxlen=maxlen)[0]
        a_token_ids.append(token_ids)
        token_ids = tokenizer.encode(d[1], maxlen=maxlen)[0]
        b_token_ids.append(token_ids)
        labels.append(d[2])
    a_token_ids = sequence_padding(a_token_ids)
    b_token_ids = sequence_padding(b_token_ids)
    return a_token_ids, b_token_ids, labels


def split_data(dat):
    a_texts, b_texts, labels = [], [], []
    for d in tqdm(dat):
        a_texts.append(d[0])
        b_texts.append(d[1])
        labels.append(d[2])
    return a_texts, b_texts, labels


datasets = {fn: load_data(open_dataset_path / task_name / f"{fn}.tsv") for fn in ["train", "dev", "test"]}
all_weights, all_texts, all_labels = [], [], []
train_texts = []
for name, data in datasets.items():
    a_texts, b_texts, labels = split_data(data)
    all_weights.append(len(data))
    all_texts.append((a_texts, b_texts))
    all_labels.append(labels)

    train_texts.extend(a_texts)
    train_texts.extend(b_texts)

np.random.shuffle(train_texts)
train_texts = train_texts[:10000]
train_generator = data_generator(train_texts, batch_size, mode="train")


# Compute the loss
loss_func = nn.BCEWithLogitsLoss()
def simcse_loss(y_pred):
    """Loss used for SimCSE training.
    """
    # Construct the labels
    idxs = torch.arange(0, y_pred.size(0))  # [b]

    idxs_1 = idxs[None, :]  # [1,b]
    idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]  # [b,1]
    y_true = idxs_1 == idxs_2
    y_true = y_true.to(torch.float).to(device)
    # Compute the similarities
    y_pred = F.normalize(y_pred, dim=1, p=2)
    similarities = torch.matmul(y_pred, y_pred.transpose(0, 1))  # [b,d] x [d,b] -> [b,b]
    similarities = similarities - torch.eye(y_pred.size(0)).to(device) * 1e12
    similarities = similarities * 20
    loss = loss_func(similarities, y_true)
    return loss


# Load the model
config_path = robert_wwm_pt_path / "bert_config.json"
config = AutoConfig.from_pretrained(pretrained_model_name_or_path=config_path, hidden_dropout_prob=0.1)
model = EncodingModel.from_pretrained(robert_wwm_pt_path, config=config)

optimizer, scheduler = create_optimizer_and_scheduler(model=model, lr=1e-5, num_training_steps=train_generator.steps * epochs // gradient_accumulation_steps)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# train
model.zero_grad()
for e in range(epochs):
    model.train()
    for step, batch in enumerate(train_generator):
        # if step > 1: break
        batch = [_.to(device) for _ in batch]
        input_ids, seg_ids = batch
        encoding_output = model(input_ids, seg_ids)

        loss = simcse_loss(encoding_output)
        loss.backward()

        if step % gradient_accumulation_steps == 0 and step != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % 100 == 0 and step != 0:
            print(f"epoch: {e} - batch: {step}/{train_generator.steps} - loss: {loss}")

model.eval()

# Vectorize the corpora
all_vecs = []
for a_texts, b_texts in all_texts:
    a_text_generator = data_generator(a_texts, batch_size, mode="eval")
    b_text_generator = data_generator(b_texts, batch_size, mode="eval")

    all_a_vecs = []
    for eval_batch in tqdm(a_text_generator):
        eval_batch = [_.to(device) for _ in eval_batch]
        with torch.no_grad():
            eval_encodings = model(*eval_batch)
            eval_encodings = eval_encodings.cpu().detach().numpy()
            all_a_vecs.extend(eval_encodings)

    all_b_vecs = []
    for eval_batch in tqdm(b_text_generator):
        eval_batch = [_.to(device) for _ in eval_batch]
        with torch.no_grad():
            eval_encodings = model(*eval_batch)
            eval_encodings = eval_encodings.cpu().detach().numpy()
            all_b_vecs.extend(eval_encodings)

    all_vecs.append((np.array(all_a_vecs), np.array(all_b_vecs)))


# Normalization, similarities, correlation coefficients
all_corrcoefs = []
for (a_vecs, b_vecs), labels in zip(all_vecs, all_labels):
    a_vecs = l2_normalize(a_vecs)
    b_vecs = l2_normalize(b_vecs)
    sims = (a_vecs * b_vecs).sum(axis=1)
    corrcoef = compute_corrcoef(labels, sims)
    all_corrcoefs.append(corrcoef)

all_corrcoefs.extend([
    np.average(all_corrcoefs),
    np.average(all_corrcoefs, weights=all_weights)
])

print(all_corrcoefs)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import scipy.stats


def l2_normalize(vecs):
    """L2-normalize each row vector.
    """
    norms = (vecs**2).sum(axis=1, keepdims=True)**0.5
    return vecs / np.clip(norms, 1e-8, np.inf)


def compute_corrcoef(x, y):
    """Spearman correlation coefficient.
    """
    return scipy.stats.spearmanr(x, y).correlation
--------------------------------------------------------------------------------
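A quick look at what Tokenizer.encode produces, since train_and_eval.py leans on it throughout. This is a minimal sketch: it assumes the robert_wwm_pt_path checkpoint directory defined in configuration/config.py exists locally, and the example sentence is arbitrary.

from bojone_tokenizers import Tokenizer
from configuration.config import robert_wwm_pt_path

tokenizer = Tokenizer(str(robert_wwm_pt_path / "vocab.txt"), do_lower_case=True)

token_ids, segment_ids = tokenizer.encode("今天天气不错", maxlen=64)
print(tokenizer.ids_to_tokens(token_ids))  # e.g. ['[CLS]', '今', '天', '天', '气', '不', '错', '[SEP]']
print(segment_ids)                         # all zeros for a single sentence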
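How a SimCSE training batch is assembled: data_generator appends every sentence twice in train mode, so the two copies pass through BERT with different dropout masks and become a positive pair, and sequence_padding right-pads with zeros. The sketch below uses made-up token ids; computing the mask from the padding is also this sketch's shortcut, whereas the script itself feeds all-ones `[1] * len(token_ids)` segment ids as the attention mask.

import numpy as np

from bojone_snippets import sequence_padding

# Two toy "sentences" as token-id lists; the ids are arbitrary placeholders.
toy_token_ids = [
    [101, 17, 23, 102],
    [101, 5, 8, 13, 21, 102],
]

batch_token_ids = []
for ids in toy_token_ids:
    batch_token_ids.append(ids)  # first copy
    batch_token_ids.append(ids)  # second copy: same ids, different dropout at run time

padded = sequence_padding(batch_token_ids)      # shape (4, 6), zero-padded on the right
attention_mask = (padded != 0).astype("int64")  # 1 for real tokens, 0 for padding
print(padded)
print(attention_mask)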
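The label construction inside simcse_loss is terse; this worked example for a batch of six rows (three sentences, each encoded twice) prints the target matrix it builds — entry (i, j) is 1 exactly when row j is row i's dropout twin.

import torch

b = 6  # three sentences, each encoded twice
idxs = torch.arange(0, b)
idxs_1 = idxs[None, :]                       # [1, b]
idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]  # [b, 1]: maps 0<->1, 2<->3, 4<->5
y_true = (idxs_1 == idxs_2).float()
print(y_true)
# tensor([[0., 1., 0., 0., 0., 0.],
#         [1., 0., 0., 0., 0., 0.],
#         [0., 0., 0., 1., 0., 0.],
#         [0., 0., 1., 0., 0., 0.],
#         [0., 0., 0., 0., 0., 1.],
#         [0., 0., 0., 0., 1., 0.]])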
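For comparison: the script scores the [b, b] similarity matrix with BCEWithLogitsLoss against that one-hot target, while the SimCSE paper formulates the objective as a softmax cross-entropy over in-batch negatives. Below is a minimal sketch of the cross-entropy variant; the function name and the temperature=0.05 value (the counterpart of the `* 20` scaling above) are choices made for this sketch, not something taken from the repository.

import torch
import torch.nn.functional as F

def simcse_cross_entropy_loss(y_pred, temperature=0.05):
    """Softmax cross-entropy over in-batch negatives, given [b, d] encodings
    where rows 2k and 2k+1 are two dropout views of the same sentence."""
    device = y_pred.device
    idxs = torch.arange(0, y_pred.size(0), device=device)
    labels = idxs + 1 - idxs % 2 * 2                              # index of each row's positive
    y_pred = F.normalize(y_pred, dim=1, p=2)
    sim = torch.matmul(y_pred, y_pred.t())                        # [b, b] cosine similarities
    sim = sim - torch.eye(y_pred.size(0), device=device) * 1e12   # remove self-similarity
    return F.cross_entropy(sim / temperature, labels)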
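The evaluation stage boils down to cosine similarity between L2-normalized sentence vectors followed by Spearman correlation against the labels. Here is the same call pattern on random toy data — the numbers are meaningless, only the shapes and function calls mirror the script.

import numpy as np

from utils import l2_normalize, compute_corrcoef

rng = np.random.RandomState(0)
a_vecs = rng.randn(8, 4)   # toy "sentence A" vectors
b_vecs = rng.randn(8, 4)   # toy "sentence B" vectors
labels = [1, 0, 1, 1, 0, 0, 1, 0]

a_vecs = l2_normalize(a_vecs)
b_vecs = l2_normalize(b_vecs)
sims = (a_vecs * b_vecs).sum(axis=1)   # row-wise dot product == cosine similarity
print(compute_corrcoef(labels, sims))  # Spearman correlation in [-1, 1]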