├── requirements.txt
├── README.md
├── run_chinese_ref.py
├── Test.ipynb
├── TrainExample.ipynb
├── MoELayer.py
├── LICENSE
└── CustomBertModel.py

/requirements.txt:
--------------------------------------------------------------------------------
notebook>=7.2.0
torch>=2.4.0
numpy>=2.1.1
transformers>=4.45.0
pypinyin>=0.53.0
pkuseg==0.0.25
ltp>=4.2.14
datasets>=3.4.1

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# zh-CN-Multi-Mask-BERT (CNMBERT🍋)
~~吃柠檬Bert ("Lemon-Eating BERT")~~
Official repository of the paper "CNMBERT: A Model for Converting Hanyu Pinyin Abbreviations to Chinese Characters", published at IJCNN 2025
![image](https://github.com/user-attachments/assets/a888fde7-6766-43f1-a753-810399418bda)

---

A model for translating Hanyu Pinyin abbreviations and Chinese homophones.

CNMBERT is trained from [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) with a modified pre-training task that adapts it to the pinyin-abbreviation/homophone translation task. It achieves state-of-the-art results on this task compared with fine-tuned GPT models and GPT-4o.

---

## What are pinyin abbreviations and Chinese homophones?

For example:

> "bhys" -> "不好意思" ("sorry")
>
> "kb" -> "看病" ("see a doctor")
>
> "将军是一 **支柱** " -> "将军是一 **只猪** " ("支柱", pillar, is a homophone of "只猪", a pig)
>
> "想 **紫砂** 了" -> "想 **自杀** 了" ("紫砂", purple clay, is a homophone of "自杀", suicide)


If you are interested in pinyin abbreviations, have a look at this ↓

[Why do people hate abbreviations? - answer by 远方青木 - Zhihu](https://www.zhihu.com/question/269016377/answer/2654824753)

### CNMBERT

| Model | Weights | Memory Usage (FP16) | Model Size | QPS | MRR | Acc |
| --------------- | ----------------------------------------------------------- | ------------------- | ---------- | ----- | ----- | ----- |
| CNMBERT-Default* | [Huggingface](https://huggingface.co/Midsummra/CNMBert) | 0.4GB | 131M | 12.56 | 59.70 | 49.74 |
| CNMBERT-MoE | [Huggingface](https://huggingface.co/Midsummra/CNMBert-MoE) | 0.8GB | 329M | 3.20 | 61.53 | 51.86 |

* All models were trained on the same corpus of 2 million sentences from Wikipedia, Zhihu, and Bilibili comments
* Bilibili comment corpus: [repository](https://github.com/IgarashiAkatuki/BilibiliDatasets)
* QPS: queries per second
* MRR: mean reciprocal rank
* Acc: accuracy
* A [quantized version](https://huggingface.co/mradermacher/CNMBert-GGUF) of CNMBERT-Default is available

Model architecture & performance comparison:
![overall (1)](https://github.com/user-attachments/assets/cf9575c4-c37d-484b-8a3b-f8f536ca78c9)
![output](https://github.com/user-attachments/assets/3de2b56d-f8cb-40f1-8ffa-68968bbd2ed5)


### Usage

```python
from transformers import AutoTokenizer, BertConfig

from CustomBertModel import predict, word_level_predict
from MoELayer import BertWwmMoE
```

Load the model:

```python
# use CNMBert with MoE
# To use CNMBert without MoE, replace all "Midsummra/CNMBert-MoE" with "Midsummra/CNMBert" and use BertForMaskedLM instead of BertWwmMoE
tokenizer = AutoTokenizer.from_pretrained("Midsummra/CNMBert-MoE")
config = BertConfig.from_pretrained('Midsummra/CNMBert-MoE')
model = BertWwmMoE.from_pretrained('Midsummra/CNMBert-MoE', config=config).to('cuda')

# model = BertForMaskedLM.from_pretrained('Midsummra/CNMBert').to('cuda')
```

Predict words:

```python
print(word_level_predict("将军是一支柱", "支柱", model, tokenizer)[:5])
print(predict("我有两千kq", "kq", model, tokenizer)[:5])
print(predict("快去给魔理沙看b吧", "b", model, tokenizer)[:5])
```
> ['只猪', 0.013427094615127833, 1.0], ['支主', 0.012690062437477466, 1.0], ['支州', 0.012477088056586812, 0.9230769230769231], ['支战', 0.01260267308151233, 0.7692307692307692], ['侄子', 0.012531780478518316, 0.7272727272727273]

> ['块钱', 1.2056937473156175], ['块前', 0.05837443749364857], ['开千', 0.0483869208528063], ['可千', 0.03996622172280445], ['口气', 0.037183335575008414]

> ['病', 1.6893256306648254], ['吧', 0.1642467901110649], ['呗', 0.026976384222507477], ['包', 0.021441461518406868], ['报', 0.01396679226309061]

---

```python
# The default predict function uses beam search
def predict(sentence: str,
            predict_word: str,
            model,
            tokenizer,
            top_k=10,
            beam_size=24,       # beam width
            threshold=0.005,    # probability threshold
            fast_mode=True,     # whether to use fast mode
            strict_mode=True):  # whether to validate the output

# Brute-force backtracking search without pruning
def backtrack_predict(sentence: str,
                      predict_word: str,
                      model,
                      tokenizer,
                      top_k=10,
                      fast_mode=True,
                      strict_mode=True):

# To translate Chinese homophones, use word_level_predict instead
def word_level_predict(sentence: str,
                       predict_word: str,
                       model,
                       tokenizer,
                       top_k=10,
                       beam_size=24,       # beam width
                       threshold=0.005,    # probability threshold
                       fast_mode=True,     # whether to use fast mode
                       strict_mode=True):  # whether to validate the output and rank it by Levenshtein distance
```

> Because BERT is an autoencoding model, the order in which MASK tokens are predicted changes the result. With `fast_mode` enabled, the input is predicted in both forward and reverse order, which improves accuracy a little (about 2%) at a noticeably higher computational cost.

> `strict_mode` checks whether the prediction is a real Chinese word.
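
For example, when the defaults filter out too many candidates (see the Q&A below), the search can be loosened. A quick, hedged illustration; these exact values are untuned:

```python
# Wider beam, lower threshold, no output filtering
candidates = predict("我有两千kq", "kq", model, tokenizer,
                     beam_size=32,       # wider than the default 24
                     threshold=0.001,    # keep lower-probability candidates
                     fast_mode=False,    # single-direction prediction only
                     strict_mode=False)  # skip the real-word check
print(candidates[:5])
```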

### How to fine-tune the model

See [TrainExample.ipynb](https://github.com/IgarashiAkatuki/CNMBert/blob/main/TrainExample.ipynb). As for the dataset format, just make sure that the first column of the csv contains the training corpus.
> Freezing all the other layers and training only the embeddings might also work (?); I will try it when I find time.

### Q&A

Q: The accuracy of this thing feels a bit low.

A: You can try setting `fast_mode` and `strict_mode` to `False`. The model was pre-trained on a very small corpus (2M sentences), so limited generalization is to be expected. You can fine-tune it on a larger or more domain-specific dataset; the procedure differs little from [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm): you only need to replace the `DataCollator` with the `DataCollatorForMultiMask` from `CustomBertModel.py`, as sketched below.
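
A minimal sketch of that collator swap (identifiers such as `training_args` and `train_dataset` follow TrainExample.ipynb; the rest of the Trainer setup is unchanged):

```python
from transformers import Trainer
from CustomBertModel import DataCollatorForMultiMask

# multi-mask collator instead of the whole-word-masking collator
data_collator = DataCollatorForMultiMask(tokenizer,
                                         mlm=True,
                                         mlm_probability=0.15,
                                         pad_to_multiple_of=64)
trainer = Trainer(model=model,
                  args=training_args,
                  train_dataset=train_dataset,
                  eval_dataset=eval_dataset,
                  data_collator=data_collator)
```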

Q: Can't it just detect the pinyin abbreviations or homophones in a sentence and translate them directly?

A: Work in progress. For pinyin abbreviations, a detector easily mistakes them for the English words in a sentence, which makes accuracy very low. For homophones, some sentences that contain them, such as `你木琴没了`, are perfectly well-formed, so it is hard for the model to detect that `木琴` ("xylophone") stands in for its homophone `母亲` ("mother").

### Citation
If you are interested in the implementation details of CNMBERT, you can refer to:
```
@misc{feng2024cnmbertmodelhanyupinyin,
      title={CNMBert: A Model For Hanyu Pinyin Abbreviation to Character Conversion Task},
      author={Zishuo Feng and Feng Cao},
      year={2024},
      eprint={2411.11770},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2411.11770},
}
```

--------------------------------------------------------------------------------
/run_chinese_ref.py:
--------------------------------------------------------------------------------
import argparse
import json
from typing import List

from ltp import LTP

from transformers.models.bert.tokenization_bert import BertTokenizer


def _is_chinese_char(cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and are handled
    # like all of the other languages.
    if (
        (cp >= 0x4E00 and cp <= 0x9FFF)
        or (cp >= 0x3400 and cp <= 0x4DBF)
        or (cp >= 0x20000 and cp <= 0x2A6DF)
        or (cp >= 0x2A700 and cp <= 0x2B73F)
        or (cp >= 0x2B740 and cp <= 0x2B81F)
        or (cp >= 0x2B820 and cp <= 0x2CEAF)
        or (cp >= 0xF900 and cp <= 0xFAFF)
        or (cp >= 0x2F800 and cp <= 0x2FA1F)
    ):
        return True

    return False


def is_chinese(word: str):
    # word like '180' or '身高' or '神'
    for char in word:
        char = ord(char)
        if not _is_chinese_char(char):
            return 0
    return 1


def get_chinese_word(tokens: List[str]):
    word_set = set()

    for token in tokens:
        chinese_word = len(token) > 1 and is_chinese(token)
        if chinese_word:
            word_set.add(token)
    word_list = list(word_set)
    return word_list


def add_sub_symbol(bert_tokens: List[str], chinese_word_set: set):
    if not chinese_word_set:
        return bert_tokens
    max_word_len = max([len(w) for w in chinese_word_set])

    bert_word = bert_tokens
    start, end = 0, len(bert_word)
    while start < end:
        single_word = True
        if is_chinese(bert_word[start]):
            l = min(end - start, max_word_len)
            for i in range(l, 1, -1):
                whole_word = "".join(bert_word[start : start + i])
                if whole_word in chinese_word_set:
                    for j in range(start + 1, start + i):
                        bert_word[j] = "##" + bert_word[j]
                    start = start + i
                    single_word = False
                    break
        if single_word:
            start += 1
    return bert_word


def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokenizer):
    ltp_res = []

    for i in range(0, len(lines), 100):
        res = ltp_tokenizer.pipeline(lines[i : i + 100], tasks=["cws"]).cws
        res = [get_chinese_word(r) for r in res]
        ltp_res.extend(res)
    assert len(ltp_res) == len(lines)

    bert_res = []
    for i in range(0, len(lines), 100):
        res = bert_tokenizer(lines[i : i + 100], add_special_tokens=True, truncation=True, max_length=512)
        bert_res.extend(res["input_ids"])
    assert len(bert_res) == len(lines)

    ref_ids = []
    for input_ids, chinese_word in zip(bert_res, ltp_res):
        input_tokens = []
        for token_id in input_ids:
            token = bert_tokenizer._convert_id_to_token(token_id)
            input_tokens.append(token)
        input_tokens = add_sub_symbol(input_tokens, chinese_word)
        ref_id = []
        # We only save the positions of Chinese subwords starting with ##,
        # which means they are part of a whole word.
        for i, token in enumerate(input_tokens):
            if token[:2] == "##":
                clean_token = token[2:]
                # save the positions of Chinese tokens
                if len(clean_token) == 1 and _is_chinese_char(ord(clean_token)):
                    ref_id.append(i)
        ref_ids.append(ref_id)

    assert len(ref_ids) == len(bert_res)

    return ref_ids


def main(args):
    # For Chinese (Ro)BERT, the best results come from RoBERTa-wwm-ext (https://github.com/ymcui/Chinese-BERT-wwm)
    # To fine-tune that model, we have to use the same tokenizer: LTP (https://github.com/HIT-SCIR/ltp)
    with open(args.file_name, "r", encoding="utf-8") as f:
        data = f.readlines()
    data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]  # avoid delimiters like '\u2029'
    ltp_tokenizer = LTP(args.ltp)  # faster on a GPU device
    bert_tokenizer = BertTokenizer.from_pretrained(args.bert)

    ref_ids = prepare_ref(data, ltp_tokenizer, bert_tokenizer)

    with open(args.save_path, "w", encoding="utf-8") as f:
        data = [json.dumps(ref) + "\n" for ref in ref_ids]
        f.writelines(data)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="prepare_chinese_ref")
    parser.add_argument(
        "--file_name",
        required=False,
        type=str,
        default="./resources/chinese-demo.txt",
        help="file to process, same as the training data for the LM",
    )
    parser.add_argument(
        "--ltp",
        required=False,
        type=str,
        default="./resources/ltp",
        help="resources for the LTP tokenizer, usually a path",
    )
    parser.add_argument(
        "--bert",
        required=False,
        type=str,
        default="./resources/robert",
        help="resources for the Bert tokenizer",
    )
    parser.add_argument(
        "--save_path",
        required=False,
        type=str,
        default="./resources/ref.txt",
        help="path to save the results",
    )

    args = parser.parse_args()
    main(args)

--------------------------------------------------------------------------------
/Test.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-03-29T02:38:37.862817Z",
     "start_time": "2025-03-29T02:38:33.921228Z"
    }
   },
   "source": [
    "from transformers import AutoTokenizer, BertConfig, BertForMaskedLM\n",
    "\n",
    "from CustomBertModel import predict, backtrack_predict, word_level_predict\n",
    "from MoELayer import BertWwmMoE"
   ],
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E:\\Environment\\Anaconda\\envs\\speech\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-29T02:38:39.810910Z",
     "start_time": "2025-03-29T02:38:37.866817Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# use CNMBert with MoE\n",
    "# if you want to use CNMBert without MoE, please change all \"Midsummra/CNMBert-MoE\" to \"Midsummra/CNMBert\" and use BertForMaskedLM instead of using BertWwmMoE\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"Midsummra/CNMBert-MoE\")\n",
    "config = BertConfig.from_pretrained('Midsummra/CNMBert-MoE')\n",
    "model = BertWwmMoE.from_pretrained('Midsummra/CNMBert-MoE', config=config).to('cuda')"
   ],
   "id": "5d9191a45178cd39",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n",
      "  - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes\n",
      "  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n",
      "  - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n",
      "Some weights of the model checkpoint at Midsummra/CNMBert-MoE were not used when initializing BertWwmMoE: ['bert.encoder.layer.0.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.10.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.12.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.14.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.2.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.4.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.6.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.8.intermediate.dense.sparse_moe.bias']\n",
      "- This IS expected if you are initializing BertWwmMoE from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertWwmMoE from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertWwmMoE were not initialized from the model checkpoint at Midsummra/CNMBert-MoE and are newly initialized: ['bert.encoder.layer.0.intermediate.dense.sparse_moe.beta', 'bert.encoder.layer.10.intermediate.dense.sparse_moe.beta', 'bert.encoder.layer.12.intermediate.dense.sparse_moe.beta', 'bert.encoder.layer.14.intermediate.dense.sparse_moe.beta', 'bert.encoder.layer.2.intermediate.dense.sparse_moe.beta', 'bert.encoder.layer.4.intermediate.dense.sparse_moe.beta', 'bert.encoder.layer.6.intermediate.dense.sparse_moe.beta', 'bert.encoder.layer.8.intermediate.dense.sparse_moe.beta']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-29T02:47:24.841036Z",
     "start_time": "2025-03-29T02:47:23.974884Z"
    }
   },
   "cell_type": "code",
   "source": [
    "print(word_level_predict(\"将军是一支柱\", \"支柱\", model, tokenizer, fast_mode=False, strict_mode=False)[:10])\n",
    "print(predict(\"快去给魔理沙看b吧\", \"b\", model, tokenizer)[:5])"
   ],
   "id": "5f8789a21534588c",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[['只猪', 0.013427094615127833, 1.0], ['支主', 0.012690062437477466, 1.0], ['支州', 0.012477088056586812, 0.9230769230769231], ['支战', 0.01260267308151233, 0.7692307692307692], ['侄子', 0.012531780478518316, 0.7272727272727273], ['支子', 0.012490831659057209, 0.7272727272727273], ['种主', 0.012726939569656787, 0.7142857142857143], ['支长', 0.012681785355615857, 0.7142857142857143], ['支中', 0.012518818992013665, 0.7142857142857143], ['直在', 0.013213451601798279, 0.6666666666666667]]\n",
      "[['病', 0.19057052400975555], ['吧', 0.09717965056220808], ['包', 0.08986218881784686], ['呗', 0.08982954684417713], ['报', 0.08949582422272462]]\n"
     ]
    }
   ],
   "execution_count": 14
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "51e5db5ca1807903"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

--------------------------------------------------------------------------------
/TrainExample.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "from transformers import AutoTokenizer, BertConfig, TrainingArguments, Trainer\n",
    "from CustomBertModel import DataCollatorForMultiMask\n",
    "from MoELayer import BertWwmMoE\n",
    "from datasets import Dataset\n",
"from ltp import LTP\n", 15 | "\n", 16 | "# https://github.com/huggingface/transformers/blob/main/examples/research_projects/mlm_wwm/run_chinese_ref.py\n", 17 | "from run_chinese_ref import prepare_ref\n", 18 | "\n", 19 | "import random\n", 20 | "import torch\n" 21 | ], 22 | "outputs": [], 23 | "execution_count": null 24 | }, 25 | { 26 | "metadata": {}, 27 | "cell_type": "code", 28 | "source": [ 29 | "random.seed(123)\n", 30 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 31 | "ltp = LTP().to(device=device)\n", 32 | "\n", 33 | "tokenizer = AutoTokenizer.from_pretrained(\"Midsummra/CNMBert-MoE\")\n", 34 | "config = BertConfig.from_pretrained('Midsummra/CNMBert-MoE')\n", 35 | "model = BertWwmMoE.from_pretrained('Midsummra/CNMBert-MoE', config=config).to('cuda')" 36 | ], 37 | "id": "497437f079c218e0", 38 | "outputs": [], 39 | "execution_count": null 40 | }, 41 | { 42 | "metadata": {}, 43 | "cell_type": "code", 44 | "source": [ 45 | "# 数据预处理\n", 46 | "\n", 47 | "text = set()\n", 48 | "bilibili = set()\n", 49 | "with open('../webtext/train.csv', mode='r', encoding='utf-8') as file:\n", 50 | " line = file.readline()\n", 51 | " while True:\n", 52 | " if not line:\n", 53 | " break\n", 54 | " text.add(line)\n", 55 | " line = file.readline()\n", 56 | "with open('../webtext/bilibili.csv', mode='r', encoding='utf-8') as file:\n", 57 | " line = file.readline()\n", 58 | " while True:\n", 59 | " if not line:\n", 60 | " break\n", 61 | " bilibili.add(line)\n", 62 | " line = file.readline()\n", 63 | "\n", 64 | "text = [t.replace('\\n', '') for t in list(text)]\n", 65 | "bilibili = [t.replace('\\n', '') for t in list(bilibili)]\n", 66 | "random.shuffle(text)\n", 67 | "random.shuffle(bilibili)\n", 68 | "\n", 69 | "train_data = {'text': text[:750000] + bilibili[:750000]}\n", 70 | "eval_data = {'text': text[len(text) - 20000:] + bilibili[len(bilibili) - 20000:]}\n", 71 | "\n", 72 | "train_data = Dataset.from_dict(train_data)\n", 73 | "eval_data = Dataset.from_dict(eval_data)" 74 | ], 75 | "id": "577b4745dde4c3ad", 76 | "outputs": [], 77 | "execution_count": null 78 | }, 79 | { 80 | "metadata": {}, 81 | "cell_type": "code", 82 | "source": [ 83 | "def tokenize_func(dataset):\n", 84 | " tokens = tokenizer(dataset['text'],\n", 85 | " max_length=64,\n", 86 | " padding='max_length',\n", 87 | " truncation=True,\n", 88 | " return_tensors='pt'\n", 89 | " )\n", 90 | " ref = prepare_ref(dataset['text'], ltp, tokenizer)\n", 91 | " features = {'input_ids': tokens['input_ids'], 'chinese_ref': ref, 'attention_mask': tokens['attention_mask']}\n", 92 | " return features\n", 93 | "\n", 94 | "data_collator = DataCollatorForMultiMask(tokenizer,\n", 95 | " mlm_probability=0.15,\n", 96 | " mlm=True,\n", 97 | " pad_to_multiple_of=64)\n", 98 | "\n", 99 | "train_dataset = train_data.map(tokenize_func, batched=True, remove_columns=[\"text\"])\n", 100 | "eval_dataset = eval_data.map(tokenize_func, batched=True, remove_columns=[\"text\"])" 101 | ], 102 | "id": "5d68fd569a379e40", 103 | "outputs": [], 104 | "execution_count": null 105 | }, 106 | { 107 | "metadata": {}, 108 | "cell_type": "code", 109 | "source": [ 110 | "# 可选,只训练embeddings\n", 111 | "for name, param in model.named_parameters():\n", 112 | " if name.startswith('bert.embeddings.'):\n", 113 | " param.requires_grad = True\n", 114 | " else:\n", 115 | " param.requires_grad = False\n", 116 | " if param.requires_grad:\n", 117 | " print(name)" 118 | ], 119 | "id": "bcc91db91d15d0bb", 120 | "outputs": [], 121 | "execution_count": null 122 | }, 123 | { 
124 | "metadata": {}, 125 | "cell_type": "code", 126 | "source": [ 127 | "# 训练\n", 128 | "\n", 129 | "torch.manual_seed(42)\n", 130 | "\n", 131 | "model = model.to(device)\n", 132 | "for name, param in model.named_parameters():\n", 133 | " if param.requires_grad:\n", 134 | " print(f\"Trainable layer: {name}\")\n", 135 | " param.data = param.data.contiguous()\n", 136 | "\n", 137 | "\n", 138 | "training_args = TrainingArguments(\n", 139 | " output_dir='./model/checkpoints/',\n", 140 | " num_train_epochs=20,\n", 141 | " per_device_train_batch_size=128,\n", 142 | " eval_strategy='steps',\n", 143 | " eval_steps=500,\n", 144 | " learning_rate=1e-5, #学习率建议给1e-5~2e-5\n", 145 | " weight_decay=1e-5,\n", 146 | " logging_dir='./model/logs/',\n", 147 | " logging_steps=100,\n", 148 | " logging_first_step=True,\n", 149 | " save_strategy='steps',\n", 150 | " save_steps=100,\n", 151 | " save_total_limit=4,\n", 152 | " max_grad_norm=1.0,\n", 153 | " warmup_ratio=1 / 20,\n", 154 | " disable_tqdm=True\n", 155 | ")\n", 156 | "\n", 157 | "trainer = Trainer(\n", 158 | " model=model,\n", 159 | " args=training_args,\n", 160 | " train_dataset=train_dataset,\n", 161 | " eval_dataset=eval_dataset,\n", 162 | " data_collator=data_collator,\n", 163 | ")\n" 164 | ], 165 | "id": "2c9bf9ebb808cf34", 166 | "outputs": [], 167 | "execution_count": null 168 | }, 169 | { 170 | "metadata": {}, 171 | "cell_type": "code", 172 | "source": [ 173 | "trainer.train()\n", 174 | "trainer.save_model('./model/cnmbert-ft')\n", 175 | "eval_results = trainer.evaluate()\n", 176 | "print(f\"Evaluation cnmbert-ft: {eval_results}\")" 177 | ], 178 | "id": "b4004f82f39462d9", 179 | "outputs": [], 180 | "execution_count": null 181 | }, 182 | { 183 | "metadata": {}, 184 | "cell_type": "code", 185 | "source": "", 186 | "id": "d62062c3aa4e90fa", 187 | "outputs": [], 188 | "execution_count": null 189 | } 190 | ], 191 | "metadata": { 192 | "kernelspec": { 193 | "display_name": "Python 3", 194 | "language": "python", 195 | "name": "python3" 196 | }, 197 | "language_info": { 198 | "codemirror_mode": { 199 | "name": "ipython", 200 | "version": 2 201 | }, 202 | "file_extension": ".py", 203 | "mimetype": "text/x-python", 204 | "name": "python", 205 | "nbconvert_exporter": "python", 206 | "pygments_lexer": "ipython2", 207 | "version": "2.7.6" 208 | } 209 | }, 210 | "nbformat": 4, 211 | "nbformat_minor": 5 212 | } 213 | -------------------------------------------------------------------------------- /MoELayer.py: -------------------------------------------------------------------------------- 1 | from transformers import BertForMaskedLM 2 | from transformers import TrainerCallback 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | temp_indices = [] 7 | 8 | class Expert(nn.Module): 9 | def __init__(self, input_dim, hidden_dim, dropout=0.15): 10 | super(Expert, self).__init__() 11 | self.net = nn.Sequential( 12 | nn.Linear(input_dim, hidden_dim), 13 | nn.LeakyReLU(negative_slope=0.01), 14 | nn.Dropout(dropout), 15 | nn.Linear(hidden_dim, input_dim), 16 | ) 17 | 18 | def forward(self, x): 19 | return self.net(x) 20 | 21 | 22 | class DynamicRouter(nn.Module): 23 | def __init__(self, input_dim, 24 | num_experts=8, 25 | top_k=2, 26 | noise_std=0.1): 27 | super(DynamicRouter, self).__init__() 28 | self.top_k = top_k 29 | self.linear = nn.Linear(input_dim, num_experts) 30 | self.noise_std = noise_std 31 | def forward(self, x): 32 | logits = self.linear(x) 33 | noise = torch.randn_like(logits) * self.noise_std 34 | logits += noise 35 | 
        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1)

        zeros = torch.full_like(logits, float('-inf')).to(logits.device)
        sparse_logits = zeros.scatter(-1, top_k_indices, top_k_logits)
        global temp_indices
        temp_indices.append([logits, top_k_indices])
        router_output = F.softmax(sparse_logits, dim=-1)

        return router_output, top_k_indices


class SparseMoE(nn.Module):
    def __init__(self, input_dim, output_dim, num_experts=8, top_k=2, dropout=0.15):
        super(SparseMoE, self).__init__()
        self.router = DynamicRouter(input_dim, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(input_dim, output_dim, dropout) for _ in range(num_experts)])
        self.shared_expert = Expert(input_dim, output_dim, dropout)
        self.top_k = top_k
        self.beta = nn.Parameter(torch.tensor(0.7), requires_grad=False)
        self.alpha = nn.Parameter(torch.tensor(0.3), requires_grad=True)

    def forward(self, x):
        # 1. The router returns the gating weights and the top-k expert indices
        gating_output, indices = self.router(x)
        # 2. Initialize a zero tensor; expert outputs are accumulated into it
        final_output = torch.zeros_like(x)

        # 3. Flatten the batch and sequence dimensions, for both the input x
        #    and the router output
        flat_x = x.view(-1, x.size(-1))
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        # Work expert by expert: weight all the tokens routed to the current expert
        for i, expert in enumerate(self.experts):
            # 4. Mask of tokens for which this expert is among the top-k
            expert_mask = (indices == i).any(dim=-1)
            # 5. Flatten the mask
            flat_mask = expert_mask.view(-1)
            # If this expert is in the top-k of at least one token
            if flat_mask.any():
                # 6. Gather the representations of the tokens this expert handles
                expert_input = flat_x[flat_mask]
                # 7. Run the expert on those tokens
                expert_output = expert(expert_input)

                # 8. Gating scores of this expert for those tokens
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                # 9. Weight the expert output by its gating score
                weighted_output = expert_output * gating_scores

                # 10. Accumulate this expert's weighted output into the final result
                final_output[expert_mask] += weighted_output.squeeze(1)

        # Mix the shared expert and the routed experts with learned weights
        weights = F.softmax(torch.stack([self.alpha, self.beta]), dim=0)
        a, b = weights[0], weights[1]
        final_output = self.shared_expert(x) * a + final_output * b
        # global temp_indices
        # temp_indices.append([final_output, indices])
        return final_output


class SparseMoEFFN(nn.Module):
    def __init__(self, config, num_experts=8, top_k=2, dropout=0.15):
        super(SparseMoEFFN, self).__init__()
        # 768/3072 match the hidden/intermediate sizes of BERT-base
        self.sparse_moe = SparseMoE(input_dim=768,
                                    output_dim=3072,
                                    num_experts=num_experts,
                                    top_k=top_k,
                                    dropout=dropout)

    def forward(self, x):
        return self.sparse_moe(x)


class BertWwmMoE(BertForMaskedLM):
    def __init__(self, config, num_experts=8, top_k=2, dropout=0.05):
        super(BertWwmMoE, self).__init__(config)
        # Replace the FFN of every second encoder layer with a sparse MoE,
        # using more experts in the upper layers; the shared expert is
        # initialized from the original dense FFN weights (net[0] from
        # intermediate.dense, net[3] from output.dense).
        for index, layer in enumerate(self.bert.encoder.layer):
            if 8 <= index <= 15:
                if index % 2 == 1:
                    continue
                moe_ffn = SparseMoEFFN(config=config,
                                       num_experts=8,
                                       top_k=2,
                                       dropout=dropout)

                for i, module in enumerate(moe_ffn.sparse_moe.shared_expert.net):
                    if i == 0:
                        module.weight.data = layer.intermediate.dense.weight.data.clone()
                        module.bias.data = layer.intermediate.dense.bias.data.clone()
                    if i == 3:
                        module.weight.data = layer.output.dense.weight.data.clone()
                        module.bias.data = layer.output.dense.bias.data.clone()

                layer.intermediate.dense = moe_ffn
                layer.output.dense = nn.Identity()

            if 5 <= index <= 7:
                if index % 2 == 1:
                    continue
                moe_ffn = SparseMoEFFN(config=config,
                                       num_experts=4,
                                       top_k=1,
                                       dropout=dropout)

                for i, module in enumerate(moe_ffn.sparse_moe.shared_expert.net):
                    if i == 0:
                        module.weight.data = layer.intermediate.dense.weight.data.clone()
                        module.bias.data = layer.intermediate.dense.bias.data.clone()
                    if i == 3:
                        module.weight.data = layer.output.dense.weight.data.clone()
                        module.bias.data = layer.output.dense.bias.data.clone()

                layer.intermediate.dense = moe_ffn
                layer.output.dense = nn.Identity()
            if 0 <= index <= 4:
                if index % 2 == 1:
                    continue
                moe_ffn = SparseMoEFFN(config=config,
                                       num_experts=2,
                                       top_k=1,
                                       dropout=0.1)

                for i, module in enumerate(moe_ffn.sparse_moe.shared_expert.net):
                    if i == 0:
                        module.weight.data = layer.intermediate.dense.weight.data.clone()
                        module.bias.data = layer.intermediate.dense.bias.data.clone()
                    if i == 3:
                        module.weight.data = layer.output.dense.weight.data.clone()
                        module.bias.data = layer.output.dense.bias.data.clone()

                layer.intermediate.dense = moe_ffn
                layer.output.dense = nn.Identity()


class EvaluationCallback(TrainerCallback):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def on_evaluate(self, args, state, control, **kwargs):
        # Fetch and print the MoE mixing weights (alpha/beta) of the lower
        # layers; odd layers keep their dense FFN, so guard with hasattr
        for index, layer in enumerate(self.model.bert.encoder.layer):
            if 0 <= index < 6 and hasattr(layer.intermediate.dense, 'sparse_moe'):
                print(layer.intermediate.dense.sparse_moe.alpha.data)
                print(layer.intermediate.dense.sparse_moe.beta.data)

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                    GNU AFFERO GENERAL PUBLIC LICENSE
                       Version 3, 19 November 2007

 Copyright (C) 2007 Free Software Foundation, Inc.
 Everyone is permitted to copy and distribute verbatim copies
 of this license document, but changing it is not allowed.

                            Preamble

The GNU Affero General Public License is a free, copyleft license for
software and other kinds of works, specifically designed to ensure
cooperation with the community in the case of network server software.

The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
our General Public Licenses are intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.

When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.

Developers that use our General Public Licenses protect your rights
with two steps: (1) assert copyright on the software, and (2) offer
you this License which gives you legal permission to copy, distribute
and/or modify the software.

A secondary benefit of defending all users' freedom is that
improvements made in alternate versions of the program, if they
receive widespread use, become available for other developers to
incorporate. Many developers of free software are heartened and
encouraged by the resulting cooperation. However, in the case of
software used on network servers, this result may fail to come about.
The GNU General Public License permits making a modified version and
letting the public access it on a server without ever releasing its
source code to the public.

The GNU Affero General Public License is designed specifically to
ensure that, in such cases, the modified source code becomes available
to the community. It requires the operator of a network server to
provide the source code of the modified version running there to the
users of that server. Therefore, public use of a modified version, on
a publicly accessible server, gives the public access to the source
code of the modified version.

An older license, called the Affero General Public License and
published by Affero, was designed to accomplish similar goals. This is
a different license, not a version of the Affero GPL, but Affero has
released a new version of the Affero GPL which permits relicensing under
this license.

The precise terms and conditions for copying, distribution and
modification follow.

                       TERMS AND CONDITIONS

0. Definitions.

"This License" refers to version 3 of the GNU Affero General Public License.

"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.

"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.

To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.

A "covered work" means either the unmodified Program or a work based
on the Program.

To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.

To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.

An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.

1. Source Code.

The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.

A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.

The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.

The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.

The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.

The Corresponding Source for a work in source code form is that
same work.

2. Basic Permissions.

All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.

You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.

Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.

3. Protecting Users' Legal Rights From Anti-Circumvention Law.

No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.

When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.

4. Conveying Verbatim Copies.

You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.

You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.

5. Conveying Modified Source Versions.

You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:

    a) The work must carry prominent notices stating that you modified
    it, and giving a relevant date.

    b) The work must carry prominent notices stating that it is
    released under this License and any conditions added under section
    7. This requirement modifies the requirement in section 4 to
    "keep intact all notices".

    c) You must license the entire work, as a whole, under this
    License to anyone who comes into possession of a copy. This
    License will therefore apply, along with any applicable section 7
    additional terms, to the whole of the work, and all its parts,
    regardless of how they are packaged. This License gives no
    permission to license the work in any other way, but it does not
    invalidate such permission if you have separately received it.

    d) If the work has interactive user interfaces, each must display
    Appropriate Legal Notices; however, if the Program has interactive
    interfaces that do not display Appropriate Legal Notices, your
    work need not make them do so.

A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.

6. Conveying Non-Source Forms.

You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:

    a) Convey the object code in, or embodied in, a physical product
    (including a physical distribution medium), accompanied by the
    Corresponding Source fixed on a durable physical medium
    customarily used for software interchange.

    b) Convey the object code in, or embodied in, a physical product
    (including a physical distribution medium), accompanied by a
    written offer, valid for at least three years and valid for as
    long as you offer spare parts or customer support for that product
    model, to give anyone who possesses the object code either (1) a
    copy of the Corresponding Source for all the software in the
    product that is covered by this License, on a durable physical
    medium customarily used for software interchange, for a price no
    more than your reasonable cost of physically performing this
    conveying of source, or (2) access to copy the
    Corresponding Source from a network server at no charge.

    c) Convey individual copies of the object code with a copy of the
    written offer to provide the Corresponding Source. This
    alternative is allowed only occasionally and noncommercially, and
    only if you received the object code with such an offer, in accord
    with subsection 6b.

    d) Convey the object code by offering access from a designated
    place (gratis or for a charge), and offer equivalent access to the
    Corresponding Source in the same way through the same place at no
    further charge. You need not require recipients to copy the
    Corresponding Source along with the object code. If the place to
    copy the object code is a network server, the Corresponding Source
    may be on a different server (operated by you or a third party)
    that supports equivalent copying facilities, provided you maintain
    clear directions next to the object code saying where to find the
    Corresponding Source. Regardless of what server hosts the
    Corresponding Source, you remain obligated to ensure that it is
    available for as long as needed to satisfy these requirements.

    e) Convey the object code using peer-to-peer transmission, provided
    you inform other peers where the object code and Corresponding
    Source of the work are being offered to the general public at no
    charge under subsection 6d.

A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.

A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.

"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.

If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).

The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.

Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.

7. Additional Terms.

"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.

When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.

Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:

    a) Disclaiming warranty or limiting liability differently from the
    terms of sections 15 and 16 of this License; or

    b) Requiring preservation of specified reasonable legal notices or
    author attributions in that material or in the Appropriate Legal
    Notices displayed by works containing it; or

    c) Prohibiting misrepresentation of the origin of that material, or
    requiring that modified versions of such material be marked in
    reasonable ways as different from the original version; or

    d) Limiting the use for publicity purposes of names of licensors or
    authors of the material; or

    e) Declining to grant rights under trademark law for use of some
    trade names, trademarks, or service marks; or

    f) Requiring indemnification of licensors and authors of that
    material by anyone who conveys the material (or modified versions of
    it) with contractual assumptions of liability to the recipient, for
    any liability that these contractual assumptions directly impose on
    those licensors and authors.

All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.

If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.

Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.

8. Termination.

You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).

However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.

Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.

Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.

9. Acceptance Not Required for Having Copies.

You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.

10. Automatic Licensing of Downstream Recipients.

Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.

An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.

You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.

11. Patents.

A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
464 | 465 | A contributor's "essential patent claims" are all patent claims 466 | owned or controlled by the contributor, whether already acquired or 467 | hereafter acquired, that would be infringed by some manner, permitted 468 | by this License, of making, using, or selling its contributor version, 469 | but do not include claims that would be infringed only as a 470 | consequence of further modification of the contributor version. For 471 | purposes of this definition, "control" includes the right to grant 472 | patent sublicenses in a manner consistent with the requirements of 473 | this License. 474 | 475 | Each contributor grants you a non-exclusive, worldwide, royalty-free 476 | patent license under the contributor's essential patent claims, to 477 | make, use, sell, offer for sale, import and otherwise run, modify and 478 | propagate the contents of its contributor version. 479 | 480 | In the following three paragraphs, a "patent license" is any express 481 | agreement or commitment, however denominated, not to enforce a patent 482 | (such as an express permission to practice a patent or covenant not to 483 | sue for patent infringement). To "grant" such a patent license to a 484 | party means to make such an agreement or commitment not to enforce a 485 | patent against the party. 486 | 487 | If you convey a covered work, knowingly relying on a patent license, 488 | and the Corresponding Source of the work is not available for anyone 489 | to copy, free of charge and under the terms of this License, through a 490 | publicly available network server or other readily accessible means, 491 | then you must either (1) cause the Corresponding Source to be so 492 | available, or (2) arrange to deprive yourself of the benefit of the 493 | patent license for this particular work, or (3) arrange, in a manner 494 | consistent with the requirements of this License, to extend the patent 495 | license to downstream recipients. "Knowingly relying" means you have 496 | actual knowledge that, but for the patent license, your conveying the 497 | covered work in a country, or your recipient's use of the covered work 498 | in a country, would infringe one or more identifiable patents in that 499 | country that you have reason to believe are valid. 500 | 501 | If, pursuant to or in connection with a single transaction or 502 | arrangement, you convey, or propagate by procuring conveyance of, a 503 | covered work, and grant a patent license to some of the parties 504 | receiving the covered work authorizing them to use, propagate, modify 505 | or convey a specific copy of the covered work, then the patent license 506 | you grant is automatically extended to all recipients of the covered 507 | work and works based on it. 508 | 509 | A patent license is "discriminatory" if it does not include within 510 | the scope of its coverage, prohibits the exercise of, or is 511 | conditioned on the non-exercise of one or more of the rights that are 512 | specifically granted under this License. 
You may not convey a covered 513 | work if you are a party to an arrangement with a third party that is 514 | in the business of distributing software, under which you make payment 515 | to the third party based on the extent of your activity of conveying 516 | the work, and under which the third party grants, to any of the 517 | parties who would receive the covered work from you, a discriminatory 518 | patent license (a) in connection with copies of the covered work 519 | conveyed by you (or copies made from those copies), or (b) primarily 520 | for and in connection with specific products or compilations that 521 | contain the covered work, unless you entered into that arrangement, 522 | or that patent license was granted, prior to 28 March 2007. 523 | 524 | Nothing in this License shall be construed as excluding or limiting 525 | any implied license or other defenses to infringement that may 526 | otherwise be available to you under applicable patent law. 527 | 528 | 12. No Surrender of Others' Freedom. 529 | 530 | If conditions are imposed on you (whether by court order, agreement or 531 | otherwise) that contradict the conditions of this License, they do not 532 | excuse you from the conditions of this License. If you cannot convey a 533 | covered work so as to satisfy simultaneously your obligations under this 534 | License and any other pertinent obligations, then as a consequence you may 535 | not convey it at all. For example, if you agree to terms that obligate you 536 | to collect a royalty for further conveying from those to whom you convey 537 | the Program, the only way you could satisfy both those terms and this 538 | License would be to refrain entirely from conveying the Program. 539 | 540 | 13. Remote Network Interaction; Use with the GNU General Public License. 541 | 542 | Notwithstanding any other provision of this License, if you modify the 543 | Program, your modified version must prominently offer all users 544 | interacting with it remotely through a computer network (if your version 545 | supports such interaction) an opportunity to receive the Corresponding 546 | Source of your version by providing access to the Corresponding Source 547 | from a network server at no charge, through some standard or customary 548 | means of facilitating copying of software. This Corresponding Source 549 | shall include the Corresponding Source for any work covered by version 3 550 | of the GNU General Public License that is incorporated pursuant to the 551 | following paragraph. 552 | 553 | Notwithstanding any other provision of this License, you have 554 | permission to link or combine any covered work with a work licensed 555 | under version 3 of the GNU General Public License into a single 556 | combined work, and to convey the resulting work. The terms of this 557 | License will continue to apply to the part which is the covered work, 558 | but the work with which it is combined will remain governed by version 559 | 3 of the GNU General Public License. 560 | 561 | 14. Revised Versions of this License. 562 | 563 | The Free Software Foundation may publish revised and/or new versions of 564 | the GNU Affero General Public License from time to time. Such new versions 565 | will be similar in spirit to the present version, but may differ in detail to 566 | address new problems or concerns. 567 | 568 | Each version is given a distinguishing version number. 
If the 569 | Program specifies that a certain numbered version of the GNU Affero General 570 | Public License "or any later version" applies to it, you have the 571 | option of following the terms and conditions either of that numbered 572 | version or of any later version published by the Free Software 573 | Foundation. If the Program does not specify a version number of the 574 | GNU Affero General Public License, you may choose any version ever published 575 | by the Free Software Foundation. 576 | 577 | If the Program specifies that a proxy can decide which future 578 | versions of the GNU Affero General Public License can be used, that proxy's 579 | public statement of acceptance of a version permanently authorizes you 580 | to choose that version for the Program. 581 | 582 | Later license versions may give you additional or different 583 | permissions. However, no additional obligations are imposed on any 584 | author or copyright holder as a result of your choosing to follow a 585 | later version. 586 | 587 | 15. Disclaimer of Warranty. 588 | 589 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 590 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 591 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 592 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 593 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 594 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 595 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 596 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 597 | 598 | 16. Limitation of Liability. 599 | 600 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 601 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 602 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 603 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 604 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 605 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 606 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 607 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 608 | SUCH DAMAGES. 609 | 610 | 17. Interpretation of Sections 15 and 16. 611 | 612 | If the disclaimer of warranty and limitation of liability provided 613 | above cannot be given local legal effect according to their terms, 614 | reviewing courts shall apply local law that most closely approximates 615 | an absolute waiver of all civil liability in connection with the 616 | Program, unless a warranty or assumption of liability accompanies a 617 | copy of the Program in return for a fee. 618 | 619 | END OF TERMS AND CONDITIONS 620 | 621 | How to Apply These Terms to Your New Programs 622 | 623 | If you develop a new program, and you want it to be of the greatest 624 | possible use to the public, the best way to achieve this is to make it 625 | free software which everyone can redistribute and change under these terms. 626 | 627 | To do so, attach the following notices to the program. It is safest 628 | to attach them to the start of each source file to most effectively 629 | state the exclusion of warranty; and each file should have at least 630 | the "copyright" line and a pointer to where the full notice is found. 
631 | 
632 |     &lt;one line to give the program's name and a brief idea of what it does.&gt;
633 |     Copyright (C) &lt;year&gt;  &lt;name of author&gt;
634 | 
635 |     This program is free software: you can redistribute it and/or modify
636 |     it under the terms of the GNU Affero General Public License as published
637 |     by the Free Software Foundation, either version 3 of the License, or
638 |     (at your option) any later version.
639 | 
640 |     This program is distributed in the hope that it will be useful,
641 |     but WITHOUT ANY WARRANTY; without even the implied warranty of
642 |     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
643 |     GNU Affero General Public License for more details.
644 | 
645 |     You should have received a copy of the GNU Affero General Public License
646 |     along with this program.  If not, see &lt;https://www.gnu.org/licenses/&gt;.
647 | 
648 | Also add information on how to contact you by electronic and paper mail.
649 | 
650 |   If your software can interact with users remotely through a computer
651 | network, you should also make sure that it provides a way for users to
652 | get its source.  For example, if your program is a web application, its
653 | interface could display a "Source" link that leads users to an archive
654 | of the code.  There are many ways you could offer source, and different
655 | solutions will be better for different programs; see section 13 for the
656 | specific requirements.
657 | 
658 |   You should also get your employer (if you work as a programmer) or school,
659 | if any, to sign a "copyright disclaimer" for the program, if necessary.
660 | For more information on this, and how to apply and follow the GNU AGPL, see
661 | &lt;https://www.gnu.org/licenses/&gt;.
662 | 
--------------------------------------------------------------------------------
/CustomBertModel.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | import random
4 | from copy import deepcopy
5 | 
6 | import torch
7 | import torch.nn.functional as F
8 | import re
9 | import copy
10 | import numpy as np
11 | import pkuseg
12 | 
13 | from collections import defaultdict
14 | from typing import Union, Tuple, List, Mapping, Any, Dict
15 | 
16 | from numpy.random import shuffle
17 | from transformers.data.data_collator import _torch_collate_batch, tolist, _numpy_collate_batch, \
18 |     DataCollatorForLanguageModeling
19 | from transformers import BertTokenizer, BertTokenizerFast
20 | from pypinyin import pinyin, Style
21 | 
22 | 
23 | class DataCollatorForMultiMask(DataCollatorForLanguageModeling):
24 |     """
25 |     Data collator used for language modeling that masks entire words.
26 | 
27 |     - collates batches of tensors, honoring their tokenizer's pad_token
28 |     - preprocesses batches for masked language modeling
29 | 
30 |     &lt;Tip&gt;
31 | 
32 |     This collator relies on details of the implementation of subword tokenization by [`BertTokenizer`], specifically
33 |     that subword tokens are prefixed with *##*. For tokenizers that do not adhere to this scheme, this collator will
34 |     produce an output that is roughly equivalent to [`.DataCollatorForLanguageModeling`].
35 |     &lt;/Tip&gt;
36 |     """
37 | 
38 |     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
39 |         if isinstance(examples[0], Mapping):
40 |             input_ids = [e["input_ids"] for e in examples]
41 |         else:
42 |             input_ids = examples
43 |             examples = [{"input_ids": e} for e in examples]
44 | 
45 |         batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
46 | 
47 |         mask_labels = []
48 |         for e in examples:
49 |             ref_tokens = []
50 |             for id in tolist(e["input_ids"]):
51 |                 token = self.tokenizer._convert_id_to_token(id)
52 |                 if id == 0:
53 |                     continue
54 |                 ref_tokens.append(token)
55 | 
56 |             # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢]-> [喜,##欢]
57 |             if "chinese_ref" in e:
58 |                 ref_pos = tolist(e["chinese_ref"])
59 |                 len_seq = len(e["input_ids"])
60 |                 for i in range(len_seq):
61 |                     if i in ref_pos:
62 |                         ref_tokens[i] = "##" + ref_tokens[i]
63 |             mask_labels.append(self._whole_word_mask(ref_tokens))
64 |         batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
65 |         inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)
66 |         attention_mask = (inputs != self.tokenizer.pad_token_id).long()
67 |         return {"input_ids": inputs, "labels": labels, 'attention_mask': attention_mask}
68 | 
69 |     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
70 |         if isinstance(examples[0], Mapping):
71 |             input_ids = [e["input_ids"] for e in examples]
72 |         else:
73 |             input_ids = examples
74 |             examples = [{"input_ids": e} for e in examples]
75 | 
76 |         batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
77 | 
78 |         mask_labels = []
79 |         for e in examples:
80 |             ref_tokens = []
81 |             for id in tolist(e["input_ids"]):
82 |                 token = self.tokenizer._convert_id_to_token(id)
83 |                 ref_tokens.append(token)
84 | 
85 |             # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢]-> [喜,##欢]
86 |             if "chinese_ref" in e:
87 |                 ref_pos = tolist(e["chinese_ref"])
88 |                 len_seq = len(e["input_ids"])
89 |                 for i in range(len_seq):
90 |                     if i in ref_pos:
91 |                         ref_tokens[i] = "##" + ref_tokens[i]
92 |             mask_labels.append(self._whole_word_mask(ref_tokens))
93 |         batch_mask = _numpy_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
94 |         inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
95 |         return {"input_ids": inputs, "labels": labels}
96 | 
97 |     def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
98 |         """
99 |         Get 0/1 labels for masked tokens with whole word mask proxy
100 |         """
101 |         if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
102 |             warnings.warn(
103 |                 "DataCollatorForMultiMask is only suitable for BertTokenizer-like tokenizers. "
104 |                 "Please refer to the documentation for more information."
105 |             )
106 | 
107 |         cand_indexes = []
108 |         for i, token in enumerate(input_tokens):
109 |             if token == "[CLS]" or token == "[SEP]":
110 |                 continue
111 | 
112 |             if len(cand_indexes) >= 1 and token.startswith("##"):
113 |                 cand_indexes[-1].append(i)
114 |             else:
115 |                 cand_indexes.append([i])
116 | 
117 |         weighted_cand_indexes = cand_indexes.copy()
118 |         # When masking, we want the model to pick whole words more often than single characters, so word groups are up-weighted
119 |         for val in cand_indexes:
120 |             # If this candidate is a multi-character word, append a copy of it to the list to double its sampling weight
121 |             if 2 <= len(val) <= 3:
122 |                 weighted_cand_indexes.append(val.copy())
123 | 
124 |         random.shuffle(weighted_cand_indexes)
125 |         num_to_predict = min(max_predictions, max(2, int(round(len(input_tokens) * self.mlm_probability))))
126 |         masked_lms = []
127 |         covered_indexes = set()
128 |         for index_set in weighted_cand_indexes:
129 |             if len(masked_lms) >= num_to_predict:
130 |                 break
131 |             # If adding a whole-word mask would exceed the maximum number of
132 |             # predictions, then just skip this candidate.
133 |             if len(masked_lms) + len(index_set) > num_to_predict:
134 |                 continue
135 |             if index_set[0] - 1 in masked_lms or index_set[-1] + 1 in masked_lms:
136 |                 flag = random.random() > 0.5
137 |                 if flag:
138 |                     continue
139 |             is_any_index_covered = False
140 |             for index in index_set:
141 |                 if index in covered_indexes:
142 |                     is_any_index_covered = True
143 |                     break
144 |             if is_any_index_covered:
145 |                 continue
146 |             for index in index_set:
147 |                 covered_indexes.add(index)
148 |                 masked_lms.append(index)
149 | 
150 |         if len(covered_indexes) != len(masked_lms):
151 |             raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
152 |         mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
153 |         return mask_labels
154 | 
155 |     def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
156 |         # inputs [batch_size, max_len]
157 |         # mask_labels [batch_size, max_len]
158 |         """
159 |         Prepare masked tokens inputs/labels for masked language modeling. Unlike the stock 80%/10%/10% scheme, every
160 |         whole-word position selected by 'mask_labels' (wwm) is replaced with a pinyin-initial mask token; punctuation and other invalid positions are skipped.
161 |         """
162 |         import torch
163 | 
164 |         if self.tokenizer.mask_token is None:
165 |             raise ValueError(
166 |                 "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
167 |                 " --mlm flag if you want to use this tokenizer."
168 |             )
169 |         labels = inputs.clone()
170 |         # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
171 | 
172 |         probability_matrix = mask_labels
173 | 
174 |         special_tokens_mask = [
175 |             self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
176 |         ]
177 |         try:
178 |             probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
179 |         except Exception:
180 |             print(special_tokens_mask)
181 |         if self.tokenizer._pad_token is not None:
182 |             padding_mask = labels.eq(self.tokenizer.pad_token_id)
183 |             probability_matrix.masked_fill_(padding_mask, value=0.0)
184 | 
185 |         masked_indices = probability_matrix.bool()
186 |         labels[~masked_indices] = -100  # We only compute loss on masked tokens
187 | 
188 |         # Every masked position is selected (Bernoulli with p=1.0); the actual replacement with pinyin mask ids happens below
189 |         indices_replaced = torch.bernoulli(torch.full(labels.shape, 1.0)).bool() & masked_indices
190 | 
191 |         # print(replaced_index)
192 |         for index in range(len(inputs)):
193 |             replaced_index = [i for i, v in enumerate(indices_replaced[index]) if v]
194 |             # print(replaced_index)
195 |             # Modified masking logic: replace the single [MASK] token with several pinyin-specific mask tokens, and drop invalid positions such as punctuation
196 |             indices_total = torch.nonzero(labels[index] != -100, as_tuple=True)[0].tolist()
197 |             values = labels[index][indices_total]
198 |             # temp = indices_total.copy()
199 |             indices_total = [x for i, x in enumerate(indices_total) if not ((200 <= values[i] <= 209)
200 |                                                                             or (345 <= values[i] <= 532)
201 |                                                                             or (106 <= values[i] <= 120)
202 |                                                                             or (131 <= values[i] <= 142)
203 |                                                                             or values[i] == 8024)]
204 |             values = labels[index][indices_total]
205 |             indices_total = torch.tensor(indices_total)
206 |             words = []
207 |             pattern = re.compile(r'[\u4e00-\u9fff]')
208 |             for i, val in enumerate(indices_total):
209 |                 if i == 0:
210 |                     words.append([values[i]])
211 |                     continue
212 |                 # print(self.tokenizer.decode(values[i]).replace('#', '').replace(' ', ''))
213 |                 temp_word = self.tokenizer.decode(values[i]).replace('#', '').replace(' ', '')
214 |                 if str.isdigit(temp_word) or temp_word == '°' or str.isascii(temp_word):
215 |                     words.append([values[i]])
216 |                     continue
217 |                 if indices_total[i] == indices_total[i - 1] + 1 and pattern.search(temp_word):
218 |                     words[-1].append(values[i])
219 |                 else:
220 |                     words.append([values[i]])
221 | 
222 |             # [LETTER_A] ~ [LETTER_Z] = 1 ~ 26
223 |             # [NUM] = 28
224 |             # [SPECIAL] = 29
225 |             mask_ids = []
226 |             for word in words:
227 |                 word = self.tokenizer.decode(word).replace(' ', '').replace('#', '')
228 |                 first_letter = pinyin(word, style=Style.FIRST_LETTER)
229 |                 for letter in first_letter:
230 |                     if str.islower(letter[0][0]):
231 |                         if 1 <= (ord(letter[0][0]) - 96) <= 26:
232 |                             mask_ids.append(ord(letter[0][0]) - 96)
233 |                         else:
234 |                             mask_ids.append(29)
235 |                     else:
236 |                         mask_ids.append(28)
237 | 
238 |             try:
239 |                 replaced_mask = []
240 |                 replaced_mask_index = []
241 |                 for i, val in enumerate(indices_total):
242 |                     if val in replaced_index:
243 |                         replaced_mask_index.append(val.tolist())
244 |                         # if i >= len(mask_ids):
245 |                         #     continue
246 |                         replaced_mask.append(mask_ids[i])
247 | 
248 |                 if len(replaced_mask_index) != 0:
249 |                     inputs[index][replaced_mask_index] = torch.tensor(replaced_mask)
250 |             except Exception:
251 |                 print(mask_ids)
252 |                 print(indices_total)
253 | 
254 |         return inputs, labels
255 | 
256 |     def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
257 |         """
258 |         Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Setting
259 |         'mask_labels' means we use whole word masking (wwm) and mask indices directly according to its ref.
260 |         """
261 |         if self.tokenizer.mask_token is None:
262 |             raise ValueError(
263 |                 "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
264 |                 " --mlm flag if you want to use this tokenizer."
265 |             )
266 |         labels = np.copy(inputs)
267 |         # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
268 | 
269 |         masked_indices = mask_labels.astype(bool)
270 | 
271 |         special_tokens_mask = [
272 |             self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
273 |         ]
274 |         masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
275 |         if self.tokenizer._pad_token is not None:
276 |             padding_mask = labels == self.tokenizer.pad_token_id
277 |             masked_indices[padding_mask] = 0
278 | 
279 |         labels[~masked_indices] = -100  # We only compute loss on masked tokens
280 | 
281 |         # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
282 |         indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
283 |         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
284 | 
285 |         # 10% of the time, we replace masked input tokens with random word
286 |         # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
287 |         indices_random = (
288 |             np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
289 |         )
290 |         random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
291 |         inputs[indices_random] = random_words[indices_random]
292 | 
293 |         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
294 |         return inputs, labels
295 | 
296 | 
297 | class DataCollatorForDetector(DataCollatorForLanguageModeling):
298 |     """
299 |     Data collator used for language modeling that masks entire words.
300 | 
301 |     - collates batches of tensors, honoring their tokenizer's pad_token
302 |     - preprocesses batches for masked language modeling
303 | 
304 |     &lt;Tip&gt;
305 | 
306 |     This collator relies on details of the implementation of subword tokenization by [`BertTokenizer`], specifically
307 |     that subword tokens are prefixed with *##*. For tokenizers that do not adhere to this scheme, this collator will
308 |     produce an output that is roughly equivalent to [`.DataCollatorForLanguageModeling`].
309 |     &lt;/Tip&gt;
310 |     """
311 | 
312 |     def __random_token(self, current_token: int):  # draw a random mask id in [1, 29] different from current_token (27 is never drawn)
313 |         random_token = current_token
314 |         while random_token == current_token or random_token == 27:
315 |             random_token = random.randint(1, 29)
316 |         return random_token
317 | 
318 |     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
319 |         if isinstance(examples[0], Mapping):
320 |             input_ids = [e["input_ids"] for e in examples]
321 |         else:
322 |             input_ids = examples
323 |             examples = [{"input_ids": e} for e in examples]
324 | 
325 |         batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
326 | 
327 |         mask_labels = []
328 |         for e in examples:
329 |             ref_tokens = []
330 |             for id in tolist(e["input_ids"]):
331 |                 token = self.tokenizer._convert_id_to_token(id)
332 |                 if id == 0:
333 |                     continue
334 |                 ref_tokens.append(token)
335 | 
336 |             # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢]-> [喜,##欢]
337 |             if "chinese_ref" in e:
338 |                 ref_pos = tolist(e["chinese_ref"])
339 |                 len_seq = len(e["input_ids"])
340 |                 for i in range(len_seq):
341 |                     if i in ref_pos:
342 |                         ref_tokens[i] = "##" + ref_tokens[i]
343 |             mask_labels.append(self._whole_word_mask(ref_tokens))
344 |         batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
345 |         inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)
346 |         attention_mask = (inputs != self.tokenizer.pad_token_id).long()
347 |         return {"input_ids": inputs, "labels": labels, 'attention_mask': attention_mask}
348 | 
349 |     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
350 |         if isinstance(examples[0], Mapping):
351 |             input_ids = [e["input_ids"] for e in examples]
352 |         else:
353 |             input_ids = examples
354 |             examples = [{"input_ids": e} for e in examples]
355 | 
356 |         batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
357 | 
358 |         mask_labels = []
359 |         for e in examples:
360 |             ref_tokens = []
361 |             for id in tolist(e["input_ids"]):
362 |                 token = self.tokenizer._convert_id_to_token(id)
363 |                 ref_tokens.append(token)
364 | 
365 |             # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢]-> [喜,##欢]
366 |             if "chinese_ref" in e:
367 |                 ref_pos = tolist(e["chinese_ref"])
368 |                 len_seq = len(e["input_ids"])
369 |                 for i in range(len_seq):
370 |                     if i in ref_pos:
371 |                         ref_tokens[i] = "##" + ref_tokens[i]
372 |             mask_labels.append(self._whole_word_mask(ref_tokens))
373 |         batch_mask = _numpy_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
374 |         inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
375 |         return {"input_ids": inputs, "labels": labels}
376 | 
377 |     def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
378 |         """
379 |         Get 0/1 labels for masked tokens with whole word mask proxy
380 |         """
381 |         if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
382 |             warnings.warn(
383 |                 "DataCollatorForDetector is only suitable for BertTokenizer-like tokenizers. "
384 |                 "Please refer to the documentation for more information."
385 |             )
386 | 
387 |         cand_indexes = []
388 |         for i, token in enumerate(input_tokens):
389 |             if token == "[CLS]" or token == "[SEP]":
390 |                 continue
391 | 
392 |             if len(cand_indexes) >= 1 and token.startswith("##"):
393 |                 cand_indexes[-1].append(i)
394 |             else:
395 |                 cand_indexes.append([i])
396 | 
397 |         weighted_cand_indexes = cand_indexes.copy()
398 |         # When masking, we want the model to pick whole words more often than single characters, so word groups are up-weighted
399 |         for val in cand_indexes:
400 |             # If this candidate is a multi-character word, append two copies of it to the list to triple its sampling weight
401 |             if 2 <= len(val) <= 3:
402 |                 weighted_cand_indexes.append(val.copy())
403 |                 weighted_cand_indexes.append(val.copy())
404 | 
405 |         random.shuffle(weighted_cand_indexes)
406 |         num_to_predict = min(max_predictions, max(2, int(round(len(input_tokens) * self.mlm_probability))))
407 |         masked_lms = []
408 |         covered_indexes = set()
409 |         for index_set in weighted_cand_indexes:
410 |             if len(masked_lms) >= num_to_predict:
411 |                 break
412 |             # If adding a whole-word mask would exceed the maximum number of
413 |             # predictions, then just skip this candidate.
414 |             if len(masked_lms) + len(index_set) > num_to_predict:
415 |                 continue
416 |             if index_set[0] - 1 in masked_lms or index_set[-1] + 1 in masked_lms:
417 |                 flag = random.random() > 0.5
418 |                 if flag:
419 |                     continue
420 |             is_any_index_covered = False
421 |             for index in index_set:
422 |                 if index in covered_indexes:
423 |                     is_any_index_covered = True
424 |                     break
425 |             if is_any_index_covered:
426 |                 continue
427 |             for index in index_set:
428 |                 covered_indexes.add(index)
429 |                 masked_lms.append(index)
430 | 
431 |         if len(covered_indexes) != len(masked_lms):
432 |             raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
433 |         mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
434 |         return mask_labels
435 | 
436 |     def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
437 |         # inputs [batch_size, max_len]
438 |         # mask_labels [batch_size, max_len]
439 |         """
440 |         Prepare inputs/labels for the corruption detector: every whole-word position selected by 'mask_labels' (wwm) is
441 |         replaced with pinyin mask ids; each word is randomized with probability 0.5, and labels mark words as 1 (corrupted), 0 (clean) or 2 (elsewhere).
442 |         """
443 |         import torch
444 | 
445 |         if self.tokenizer.mask_token is None:
446 |             raise ValueError(
447 |                 "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
448 |                 " --mlm flag if you want to use this tokenizer."
449 |             )
450 |         labels = inputs.clone()
451 |         # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
452 | 
453 |         probability_matrix = mask_labels
454 | 
455 |         special_tokens_mask = [
456 |             self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
457 |         ]
458 |         try:
459 |             probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
460 |         except Exception:
461 |             print(special_tokens_mask)
462 |         if self.tokenizer._pad_token is not None:
463 |             padding_mask = labels.eq(self.tokenizer.pad_token_id)
464 |             probability_matrix.masked_fill_(padding_mask, value=0.0)
465 | 
466 |         masked_indices = probability_matrix.bool()
467 |         labels[~masked_indices] = -100  # We only compute loss on masked tokens
468 | 
469 |         # 100% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
470 |         indices_replaced = torch.bernoulli(torch.full(labels.shape, 1.0)).bool() & masked_indices
471 | 
472 |         # print(replaced_index)
473 |         for index in range(len(inputs)):
474 |             replaced_index = [i for i, v in enumerate(indices_replaced[index]) if v]
475 |             # print(replaced_index)
476 |             # Modified masking logic: replace the single [MASK] token with several pinyin-specific mask tokens, and drop invalid positions such as punctuation
477 |             indices_total = torch.nonzero(labels[index] != -100, as_tuple=True)[0].tolist()
478 |             values = labels[index][indices_total]
479 |             # temp = indices_total.copy()
480 |             indices_total = [x for i, x in enumerate(indices_total) if not ((200 <= values[i] <= 209)
481 |                                                                             or (345 <= values[i] <= 532)
482 |                                                                             or (106 <= values[i] <= 120)
483 |                                                                             or (131 <= values[i] <= 142)
484 |                                                                             or values[i] == 8024)]
485 |             values = labels[index][indices_total]
486 |             indices_total = torch.tensor(indices_total)
487 |             words = []
488 |             pattern = re.compile(r'[\u4e00-\u9fff]')
489 |             for i, val in enumerate(indices_total):
490 |                 if i == 0:
491 |                     words.append([values[i]])
492 |                     continue
493 |                 # print(self.tokenizer.decode(values[i]).replace('#', '').replace(' ', ''))
494 |                 temp_word = self.tokenizer.decode(values[i]).replace('#', '').replace(' ', '')
495 |                 if str.isdigit(temp_word) or temp_word == '°' or str.isascii(temp_word):
496 |                     words.append([values[i]])
497 |                     continue
498 |                 if indices_total[i] == indices_total[i - 1] + 1 and pattern.search(temp_word):
499 |                     words[-1].append(values[i])
500 |                 else:
501 |                     words.append([values[i]])
502 | 
503 |             # [LETTER_A] ~ [LETTER_Z] = 1 ~ 26
504 |             # [NUM] = 28
505 |             # [SPECIAL] = 29
506 |             mask_ids = []
507 |             for word in words:
508 |                 word = self.tokenizer.decode(word).replace(' ', '').replace('#', '')
509 |                 first_letter = pinyin(word, style=Style.FIRST_LETTER)
510 |                 for letter in first_letter:
511 |                     if str.islower(letter[0][0]):
512 |                         if 1 <= (ord(letter[0][0]) - 96) <= 26:
513 |                             mask_ids.append(ord(letter[0][0]) - 96)
514 |                         else:
515 |                             mask_ids.append(29)
516 |                     else:
517 |                         mask_ids.append(28)
518 | 
519 |             try:
520 |                 replaced_mask = []
521 |                 replaced_mask_index = []
522 |                 for i, val in enumerate(indices_total):
523 |                     if val in replaced_index:
524 |                         replaced_mask_index.append(val.tolist())
525 |                         # if i >= len(mask_ids):
526 |                         #     continue
527 |                         replaced_mask.append(mask_ids[i])
528 | 
529 |                 labels[index] = torch.full(labels[index].shape, 2)
530 |                 if len(replaced_mask_index) != 0:
531 |                     word_indices = []
532 |                     for i, val in enumerate(replaced_mask_index):
533 |                         if i == 0:
534 |                             word_indices.append([val])
535 |                             continue
536 |                         if val == word_indices[-1][-1] + 1:
537 |                             word_indices[-1].append(val)
538 |                         else:
539 |                             word_indices.append([val])
540 | 
541 |                     random_replaced_mask = deepcopy(replaced_mask)
542 |                     for word in word_indices:
543 |                         if random.random() < 0.5:
544 |                             for j in [replaced_mask_index.index(val) for val in word]:
545 |                                 random_replaced_mask[j] = self.__random_token(random_replaced_mask[j])
546 |                             labels[index][word] = torch.ones_like(torch.tensor(word))
547 |                         else:
548 |                             labels[index][word] = torch.zeros_like(torch.tensor(word))
549 |                     inputs[index][replaced_mask_index] = torch.tensor(random_replaced_mask)
550 |             except Exception:
551 |                 print("error")
552 |                 print(mask_ids)
553 |                 print(indices_total)
554 |         # print(inputs, labels)
555 |         return inputs, labels
556 | 
557 | 
558 |     def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
559 |         """
560 |         Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Setting
561 |         'mask_labels' means we use whole word masking (wwm) and mask indices directly according to its ref.
562 |         """
563 |         if self.tokenizer.mask_token is None:
564 |             raise ValueError(
565 |                 "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
566 |                 " --mlm flag if you want to use this tokenizer."
567 |             )
568 |         labels = np.copy(inputs)
569 |         # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
570 | 
571 |         masked_indices = mask_labels.astype(bool)
572 | 
573 |         special_tokens_mask = [
574 |             self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
575 |         ]
576 |         masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
577 |         if self.tokenizer._pad_token is not None:
578 |             padding_mask = labels == self.tokenizer.pad_token_id
579 |             masked_indices[padding_mask] = 0
580 | 
581 |         labels[~masked_indices] = -100  # We only compute loss on masked tokens
582 | 
583 |         # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
584 |         indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
585 |         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
586 | 
587 |         # 10% of the time, we replace masked input tokens with random word
588 |         # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
589 |         indices_random = (
590 |             np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
591 |         )
592 |         random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
593 |         inputs[indices_random] = random_words[indices_random]
594 | 
595 |         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
596 |         return inputs, labels
597 | 
598 | 
599 | __seg = pkuseg.pkuseg()
600 | 
601 | def predict(sentence: str,
602 |             predict_word: str,
603 |             model,
604 |             tokenizer,
605 |             top_k=10,
606 |             beam_size=16,
607 |             threshold=0.000,
608 |             fast_mode=True,
609 |             strict_mode=True):
610 |     replaced_word = []
611 |     for letter in predict_word:
612 |         if str.isdigit(letter):
613 |             replaced_word.append(27)
614 |             continue
615 |         id = ord(letter) - 96
616 |         if 1 <= id <= 26:
617 |             replaced_word.append(id)
618 |         else:
619 |             replaced_word.append(28)
620 | 
621 |     inputs = tokenizer(sentence, max_length=64,
622 |                        padding='max_length',
623 |                        truncation=True,
624 |                        return_tensors='pt').to('cuda')
625 |     index = sentence.find(predict_word)
626 |     length = len(predict_word)
627 | 
628 |     try:
629 | 
inputs['input_ids'][0][index + 1:index + 1 + length] = torch.tensor(replaced_word).to('cuda') 630 | except Exception: 631 | print('error') 632 | return [] 633 | 634 | results_a = [] 635 | results_b = [] 636 | __beam_search(results_a, 637 | beam_size=beam_size, 638 | max_depth=length, 639 | tokenizer=tokenizer, 640 | inputs=copy.deepcopy(inputs), 641 | model=model, 642 | top_k=top_k, 643 | threshold=threshold, 644 | index=index, 645 | strict_mode=strict_mode) 646 | 647 | if not fast_mode: 648 | __beam_search_back(results_b, 649 | beam_size=beam_size, 650 | max_depth=length, 651 | tokenizer=tokenizer, 652 | inputs=copy.deepcopy(inputs), 653 | model=model, 654 | top_k=top_k, 655 | threshold=threshold, 656 | index=index + length, 657 | strict_mode=strict_mode) 658 | result_dict = defaultdict(int) 659 | for val in results_a + results_b: 660 | key = val[0] 661 | value = val[1] 662 | result_dict[key] += value 663 | results = [[key, val] for key, val in result_dict.items()] 664 | results.sort(key=lambda x: x[1], reverse=True) 665 | return results 666 | 667 | 668 | def backtrack_predict(sentence: str, 669 | predict_word: str, 670 | model, 671 | tokenizer, 672 | top_k=10, 673 | fast_mode=True, 674 | strict_mode=True): 675 | replaced_word = [] 676 | for letter in predict_word: 677 | if str.isdigit(letter): 678 | replaced_word.append(27) 679 | continue 680 | id = ord(letter) - 96 681 | if 1 <= id <= 26: 682 | replaced_word.append(id) 683 | else: 684 | replaced_word.append(28) 685 | 686 | inputs = tokenizer(sentence, max_length=64, 687 | padding='max_length', 688 | truncation=True, 689 | return_tensors='pt').to('cuda') 690 | index = sentence.find(predict_word) 691 | length = len(predict_word) 692 | 693 | try: 694 | inputs['input_ids'][0][index + 1:index + 1 + length] = torch.tensor(replaced_word).to('cuda') 695 | except Exception: 696 | print('error') 697 | return [] 698 | 699 | results_a = [] 700 | results_b = [] 701 | __fixed_dfs(results=results_a, 702 | depth=length-1, 703 | probability=1.0, 704 | tokenizer=tokenizer, 705 | sentence=[], 706 | inputs=copy.deepcopy(inputs), 707 | model=model, 708 | index=index, 709 | top_k=top_k, 710 | strict_mode=strict_mode) 711 | 712 | if not fast_mode: 713 | __fixed_dfs_back(results=results_b, 714 | depth=length-1, 715 | probability=1.0, 716 | tokenizer=tokenizer, 717 | sentence=[], 718 | inputs=copy.deepcopy(inputs), 719 | model=model, 720 | index=index + length, 721 | top_k=top_k, 722 | strict_mode=strict_mode) 723 | result_dict = defaultdict(int) 724 | for val in results_a + results_b: 725 | key = val[0] 726 | value = val[1] 727 | result_dict[key] += value 728 | results = [[key, val] for key, val in result_dict.items()] 729 | results.sort(key=lambda x: x[1], reverse=True) 730 | return results 731 | 732 | 733 | def __fixed_dfs(results: list, 734 | depth: int, 735 | probability: float, 736 | tokenizer, 737 | sentence: list, 738 | inputs: str, 739 | model, 740 | index: int, 741 | top_k: int=5, 742 | strict_mode=True): 743 | 744 | with torch.no_grad(): 745 | logits = model(**inputs).logits 746 | 747 | # retrieve index of [MASK] 748 | mask_token_index = torch.where((inputs['input_ids'] == 103) | 749 | ((inputs['input_ids'] >= 1) & 750 | (inputs['input_ids'] <= 28)))[1].tolist() 751 | mask_token_logits = logits[0, mask_token_index, :] 752 | mask_token_probs = F.softmax(mask_token_logits, dim=-1) 753 | top_k_probs, top_k_tokens = torch.topk(mask_token_probs, top_k, dim=1) 754 | token = tokenizer.decode(top_k_tokens[0, 0:top_k]).split(' ') 755 | for i in 
range(len(token)):
756 |         sentence.append(token[i])
757 |         prob = top_k_probs[0, i].item()
758 |         new_probability = probability * prob
759 | 
760 |         if depth == 0:
761 |             if new_probability >= 0.00:
762 |                 if not strict_mode or len(__seg.cut(''.join(sentence))) <= max(len(sentence) - 1, 1):
763 |                     results.append([''.join(sentence), new_probability])
764 |         else:
765 |             original_value = torch.clone(inputs['input_ids'][0][index + 1])
766 |             inputs['input_ids'][0][index + 1] = top_k_tokens[0, i]
767 |             __fixed_dfs(results=results,
768 |                         depth=depth-1,
769 |                         probability=new_probability,
770 |                         tokenizer=tokenizer,
771 |                         sentence=sentence,
772 |                         inputs=inputs,
773 |                         model=model,
774 |                         index=index+1,
775 |                         top_k=max(top_k - 2, 2), strict_mode=strict_mode)
776 |             inputs['input_ids'][0][index + 1] = original_value
777 |         sentence.pop()
778 | 
779 | 
780 | 
781 | def __fixed_dfs_back(results: list,
782 |                      depth: int,
783 |                      probability: float,
784 |                      tokenizer,
785 |                      sentence: list,
786 |                      inputs: str,
787 |                      model,
788 |                      index: int,
789 |                      top_k: int=5,
790 |                      strict_mode=True):
791 | 
792 |     with torch.no_grad():
793 |         logits = model(**inputs).logits
794 | 
795 |     # retrieve index of [MASK]
796 |     mask_token_index = []
797 |     temp = torch.nonzero(inputs['input_ids'] == 103, as_tuple=True)
798 |     if len(temp[0]) > 0:
799 |         mask_token_index.extend(temp[1].tolist())
800 |     for i in range(28):
801 |         temp = torch.nonzero(inputs['input_ids'] == (i + 1), as_tuple=True)
802 |         if len(temp[0]) > 0:
803 |             mask_token_index.extend(temp[1].tolist())
804 |     mask_token_index = sorted(mask_token_index)
805 |     mask_token_logits = logits[0, mask_token_index, :]
806 |     mask_token_probs = F.softmax(mask_token_logits, dim=-1)
807 |     top_k_probs, top_k_tokens = torch.topk(mask_token_probs, top_k, dim=1)
808 |     token = tokenizer.decode(top_k_tokens[-1, 0:top_k]).split(' ')
809 | 
810 |     for i in range(len(token)):
811 |         sentence.insert(0, token[i])
812 |         prob = top_k_probs[-1, i].item()
813 |         new_probability = probability * prob
814 |         if depth == 0:
815 |             if new_probability >= 0.00:
816 |                 if not strict_mode or len(__seg.cut(''.join(sentence))) <= max(len(sentence) - 1, 1):
817 |                     results.append([''.join(sentence), new_probability])
818 |         else:
819 |             inputs['input_ids'][0][index] = top_k_tokens[-1, i]
820 |             __fixed_dfs_back(results=results,
821 |                              depth=depth-1,
822 |                              probability=new_probability,
823 |                              tokenizer=tokenizer,
824 |                              sentence=sentence,
825 |                              inputs=copy.deepcopy(inputs),
826 |                              model=model,
827 |                              index=index-1,
828 |                              top_k=max(top_k - 2, 2), strict_mode=strict_mode)
829 |         sentence.pop(0)
830 | 
831 | def __beam_search_back(results: list,
832 |                        beam_size: int,
833 |                        max_depth: int,
834 |                        tokenizer,
835 |                        inputs: str,
836 |                        model,
837 |                        top_k: int=5,
838 |                        threshold: float=0.01,
839 |                        index: int=0,
840 |                        strict_mode=True):
841 |     beams = [[inputs, [], 1.0, index]]
842 | 
843 |     for depth in range(max_depth):
844 |         new_beams = []
845 | 
846 |         for inputs, sentence, probability, index in beams:
847 |             with torch.no_grad():
848 |                 logits = model(**inputs).logits
849 | 
850 |             mask_token_index = torch.where(
851 |                 (inputs['input_ids'] == 103) |
852 |                 ((inputs['input_ids'] >= 1) & (inputs['input_ids'] <= 28))
853 |             )[1].tolist()
854 | 
855 |             if not mask_token_index:
856 |                 continue
857 | 
858 |             mask_token_logits = logits[0, mask_token_index, :]
859 |             mask_token_probs = F.softmax(mask_token_logits, dim=-1)
860 |             top_k_probs, top_k_tokens = torch.topk(mask_token_probs, top_k, dim=1)
861 | 
862 |             tokens = tokenizer.decode(top_k_tokens[-1, 0:top_k]).split(' ')
863 |             for i in range(len(tokens)):
864 |                 new_probability = probability *
top_k_probs[-1, i].item() 865 | 866 | if new_probability < threshold: 867 | continue 868 | 869 | new_inputs = copy.deepcopy(inputs) 870 | new_inputs['input_ids'][0][index] = top_k_tokens[-1, i] 871 | 872 | new_sentence = sentence.copy() 873 | new_sentence.insert(0, tokens[i]) 874 | new_beams.append([new_inputs, new_sentence, new_probability, index - 1]) 875 | if len(new_beams) > 0: 876 | softmax(new_beams) 877 | new_beams = sorted(new_beams, key=lambda x: x[2], reverse=True)[:beam_size] 878 | 879 | if not new_beams: 880 | break 881 | 882 | if top_k > 2: 883 | top_k -= 2 884 | beams = new_beams 885 | 886 | for _, sentence, probability, _ in beams: 887 | sentence = ''.join(sentence) 888 | if not strict_mode or len(__seg.cut(sentence)) <= max(len(sentence) - 1, 1): 889 | results.append((sentence, probability)) 890 | 891 | return results 892 | 893 | def __beam_search(results: list, 894 | beam_size: int, 895 | max_depth: int, 896 | tokenizer, 897 | inputs: str, 898 | model, 899 | top_k: int=5, 900 | threshold: float=0.01, 901 | index: int=0, 902 | strict_mode=True): 903 | beams = [[inputs, [], 1.0, index]] 904 | 905 | for depth in range(max_depth): 906 | new_beams = [] 907 | 908 | for inputs, sentence, probability, index in beams: 909 | with torch.no_grad(): 910 | logits = model(**inputs).logits 911 | 912 | mask_token_index = torch.where( 913 | (inputs['input_ids'] == 103) | 914 | ((inputs['input_ids'] >= 1) & (inputs['input_ids'] <= 28)) 915 | )[1].tolist() 916 | 917 | if not mask_token_index: 918 | continue 919 | 920 | mask_token_logits = logits[0, mask_token_index, :] 921 | mask_token_probs = F.softmax(mask_token_logits, dim=-1) 922 | top_k_probs, top_k_tokens = torch.topk(mask_token_probs, top_k, dim=1) 923 | 924 | tokens = tokenizer.decode(top_k_tokens[0, 0:top_k]).split(' ') 925 | for i in range(len(tokens)): 926 | new_probability = probability * top_k_probs[0, i].item() 927 | 928 | if new_probability < threshold: 929 | continue 930 | 931 | new_inputs = copy.deepcopy(inputs) 932 | new_inputs['input_ids'][0][index + 1] = top_k_tokens[0, i] 933 | 934 | new_sentence = sentence + [tokens[i]] 935 | new_beams.append([new_inputs, new_sentence, new_probability, index + 1]) 936 | if len(new_beams) > 0: 937 | softmax(new_beams) 938 | new_beams = sorted(new_beams, key=lambda x: x[2], reverse=True)[:beam_size] 939 | 940 | if not new_beams: 941 | break 942 | 943 | if top_k > 2: 944 | top_k -= 2 945 | beams = new_beams 946 | 947 | for _, sentence, probability, _ in beams: 948 | sentence = ''.join(sentence) 949 | if not strict_mode or len(__seg.cut(sentence)) <= max(len(sentence) - 1, 1): 950 | results.append((sentence, probability)) 951 | 952 | return results 953 | 954 | def softmax(beams: List): 955 | column_2 = [row[2] for row in beams] 956 | 957 | max_val = max(column_2) 958 | exp_values = [math.exp(v - max_val) for v in column_2] 959 | sum_exp = sum(exp_values) 960 | softmax_column = [v / sum_exp for v in exp_values] 961 | 962 | for i, row in enumerate(beams): 963 | row[2] = softmax_column[i] 964 | 965 | import Levenshtein 966 | 967 | def word_level_predict(sentence: str, 968 | predict_word: str, 969 | model, 970 | tokenizer, 971 | top_k=10, 972 | beam_size=16, 973 | threshold=0.000, 974 | fast_mode=True, 975 | strict_mode=True): 976 | if predict_word.isascii(): 977 | return predict(sentence=sentence, predict_word=predict_word, model=model, tokenizer=tokenizer, top_k=top_k, threshold=threshold, fast_mode=fast_mode, beam_size=beam_size, strict_mode=strict_mode) 978 | abbr_pinyin = [] 979 | 
full_pinyin = [] 980 | for character in predict_word: 981 | if character.isascii(): 982 | abbr_pinyin.append(character) 983 | full_pinyin.append(character) 984 | else: 985 | abbr_pinyin.append(pinyin(character, Style.FIRST_LETTER)[0][0]) 986 | full_pinyin.append(pinyin(character, Style.NORMAL)[0][0]) 987 | abbr_pinyin = ''.join(abbr_pinyin) 988 | full_pinyin = ''.join(full_pinyin) 989 | sentence = sentence.replace(predict_word, abbr_pinyin) 990 | result = predict(sentence=sentence, predict_word=abbr_pinyin, model=model, tokenizer=tokenizer, top_k=top_k, threshold=threshold, fast_mode=fast_mode, beam_size=beam_size, strict_mode=strict_mode) 991 | 992 | if fast_mode: 993 | return result 994 | else: 995 | for val in result: 996 | word_pinyin = [] 997 | for temp in pinyin(val[0], Style.NORMAL): 998 | word_pinyin.append(temp[0]) 999 | word_pinyin = ''.join(word_pinyin) 1000 | val.append(Levenshtein.ratio(word_pinyin, full_pinyin)) 1001 | result.sort(key=lambda x: (x[2], x[1]), reverse=True) 1002 | return result 1003 | --------------------------------------------------------------------------------
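For readers who want to fine-tune with the collator above, here is a minimal sketch (not a file from this repository) of wiring `DataCollatorForMultiMask` into a `transformers` `Trainer`. The corpus filename `corpus.csv`, the `text` column name, and all hyperparameters are illustrative assumptions; adjust them to your dataset.

```python
# finetune_sketch.py -- illustrative only; file names, the "text" column and
# all hyperparameters are assumptions, not part of the repository.
from datasets import load_dataset
from transformers import (AutoTokenizer, BertForMaskedLM, Trainer,
                          TrainingArguments)

from CustomBertModel import DataCollatorForMultiMask

tokenizer = AutoTokenizer.from_pretrained("Midsummra/CNMBert")
model = BertForMaskedLM.from_pretrained("Midsummra/CNMBert")

# The first CSV column holds the training text (here assumed to be named "text").
dataset = load_dataset("csv", data_files="corpus.csv")["train"]
dataset = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=64),
    batched=True,
    remove_columns=dataset.column_names,
)

# The collator applies whole-word masking and swaps masked words for
# pinyin-initial mask tokens, as implemented in torch_mask_tokens above.
collator = DataCollatorForMultiMask(tokenizer=tokenizer, mlm=True,
                                    mlm_probability=0.15)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="cnmbert-ft",
                           per_device_train_batch_size=32,
                           num_train_epochs=1),
    train_dataset=dataset,
    data_collator=collator,
)
trainer.train()
```

Because the collator's `torch_call` already builds `input_ids`, `labels`, and an `attention_mask`, no preprocessing beyond tokenization is required; for better whole-word boundaries, a `chinese_ref` column can additionally be produced with run_chinese_ref.py.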