├── requirements.txt
├── README.md
├── run_chinese_ref.py
├── Test.ipynb
├── TrainExample.ipynb
├── MoELayer.py
├── LICENSE
└── CustomBertModel.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | notebook>=7.2.0
2 | torch>=2.4.0
3 | numpy>=2.1.1
4 | transformers>=4.45.0
5 | pypinyin>=0.53.0
6 | pkuseg==0.0.25
7 | ltp>=4.2.14
8 | datasets>=3.4.1
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # zh-CN-Multi-Mask-BERT (CNMBERT🍋)
2 | ~~"Lemon-eating BERT"~~
3 | Official repository of the paper "CNMBERT: A Model for Converting Hanyu Pinyin Abbreviations to Chinese Characters", published at IJCNN 2025
4 | 
5 |
6 | ---
7 |
8 | A model for translating Hanyu Pinyin abbreviations and Chinese homophones into the intended Chinese words.
9 |
10 | CNMBERT is trained from [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by modifying its pre-training task to fit the Pinyin-abbreviation/homophone translation task. On this task it achieves state-of-the-art results compared with fine-tuned GPT models and GPT-4o.
11 |
12 | ---
13 |
14 | ## What are Pinyin abbreviations and Chinese homophones?
15 |
16 | For example:
17 |
18 | > "bhys" -> "不好意思"
19 | >
20 | > "kb" -> "看病"
21 | >
22 | > "将军是一 **支柱** " -> "将军是一 **只猪** "
23 | >
24 | > "想 **紫砂** 了" -> "想 **自杀** 了"
25 |
26 |
27 | If you're interested in Pinyin abbreviations, this is worth a read:
28 |
29 | [Why do people hate abbreviations? - answer by 远方青木 on Zhihu](https://www.zhihu.com/question/269016377/answer/2654824753)
30 |
31 | ### CNMBERT
32 |
33 | | Model            | Weights                                                     | Memory Usage (FP16) | Model Size | QPS   | MRR (%) | Acc (%) |
34 | | ---------------- | ----------------------------------------------------------- | ------------------- | ---------- | ----- | ------- | ------- |
35 | | CNMBERT-Default* | [Huggingface](https://huggingface.co/Midsummra/CNMBert)     | 0.4GB               | 131M       | 12.56 | 59.70   | 49.74   |
36 | | CNMBERT-MoE      | [Huggingface](https://huggingface.co/Midsummra/CNMBert-MoE) | 0.8GB               | 329M       | 3.20  | 61.53   | 51.86   |
37 |
38 | * All models were trained on the same 2,000,000-sentence corpus of Wikipedia, Zhihu, and Bilibili comments
39 | * Bilibili comment corpus: [repository](https://github.com/IgarashiAkatuki/BilibiliDatasets)
40 | * QPS: queries per second
41 | * MRR: mean reciprocal rank (defined below)
42 | * Acc: accuracy
43 | * A [quantized version](https://huggingface.co/mradermacher/CNMBert-GGUF) of CNMBERT-Default is available
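
For reference, MRR averages the reciprocal rank of the first correct candidate over the query set $Q$:

$$\mathrm{MRR} = \frac{1}{|Q|}\sum_{i=1}^{|Q|}\frac{1}{\mathrm{rank}_i}$$

so, roughly speaking, an MRR around 60% means the correct word is usually the first or second candidate.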
44 |
45 | Model architecture & performance comparison:
46 | 
47 | 
48 |
49 |
50 | ### Usage
51 |
52 | ```python
53 | from transformers import AutoTokenizer, BertConfig
54 |
55 | from CustomBertModel import predict, word_level_predict
56 | from MoELayer import BertWwmMoE
57 | ```
58 |
59 | Load the model:
60 |
61 | ```python
62 | # use CNMBert with MoE
63 | # To use CNMBert without MoE, replace all "Midsummra/CNMBert-MoE" with "Midsummra/CNMBert" and use BertForMaskedLM instead of BertWwmMoE
64 | tokenizer = AutoTokenizer.from_pretrained("Midsummra/CNMBert-MoE")
65 | config = BertConfig.from_pretrained('Midsummra/CNMBert-MoE')
66 | model = BertWwmMoE.from_pretrained('Midsummra/CNMBert-MoE', config=config).to('cuda')
67 |
68 | # model = BertForMaskedLM.from_pretrained('Midsummra/CNMBert').to('cuda')
69 | ```
70 |
71 | Predict words:
72 |
73 | ```python
74 | print(word_level_predict("将军是一支柱", "支柱", model, tokenizer)[:5])
75 | print(predict("我有两千kq", "kq", model, tokenizer)[:5])
76 | print(predict("快去给魔理沙看b吧", "b", model, tokenizer)[:5])
77 | ```
78 | > ['只猪', 0.013427094615127833, 1.0], ['支主', 0.012690062437477466, 1.0], ['支州', 0.012477088056586812, 0.9230769230769231], ['支战', 0.01260267308151233, 0.7692307692307692], ['侄子', 0.012531780478518316, 0.7272727272727273]
79 |
80 | > ['块钱', 1.2056937473156175], ['块前', 0.05837443749364857], ['开千', 0.0483869208528063], ['可千', 0.03996622172280445], ['口气', 0.037183335575008414]
81 |
82 | > ['病', 1.6893256306648254], ['吧', 0.1642467901110649], ['呗', 0.026976384222507477], ['包', 0.021441461518406868], ['报', 0.01396679226309061]
83 |
84 | ---
85 |
86 | ```python
87 | # The default predict function uses beam search
88 | def predict(sentence: str,
89 | predict_word: str,
90 | model,
91 | tokenizer,
92 | top_k=10,
93 |             beam_size=24, # beam width
94 |             threshold=0.005, # pruning threshold
95 |             fast_mode=True, # whether to use fast mode
96 |             strict_mode=True): # whether to validate the outputs
97 |
98 | # Brute-force search with backtracking and no pruning
99 | def backtrack_predict(sentence: str,
100 | predict_word: str,
101 | model,
102 | tokenizer,
103 | top_k=10,
104 | fast_mode=True,
105 | strict_mode=True):
106 |
107 | # To translate Chinese homophones, use word_level_predict
108 | def word_level_predict(sentence: str,
109 | predict_word: str,
110 | model,
111 | tokenizer,
112 | top_k=10,
113 |                        beam_size=24, # beam width
114 |                        threshold=0.005, # pruning threshold
115 |                        fast_mode=True, # whether to use fast mode
116 |                        strict_mode=True): # whether to validate the outputs and rank them by Levenshtein distance
117 | ```
118 |
119 | > Because of BERT's autoencoding nature, the order in which the [MASK] tokens are predicted changes the result. With `fast_mode` enabled, the input is predicted in both forward and reverse order, which improves accuracy a little (around 2%) at the cost of extra compute.
120 |
121 | > `strict_mode` checks whether the predicted word is a real Chinese word.
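
For example, both checks can be turned off to trade a little accuracy for speed (a minimal sketch reusing the model and tokenizer loaded above):

```python
# Faster but less accurate: single-direction prediction, no dictionary check
candidates = predict("我有两千kq", "kq", model, tokenizer,
                     fast_mode=False,   # skip the reverse-order pass
                     strict_mode=False) # do not filter to real words
print(candidates[:5])
```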
122 |
123 | ### How to fine-tune the model
124 |
125 | See [TrainExample.ipynb](https://github.com/IgarashiAkatuki/CNMBert/blob/main/TrainExample.ipynb). As for the dataset format, just make sure the first column of the csv contains the training corpus (see the sketch below).
126 | > Freezing the other layers and training only the embeddings might work too(?) I'll try it when I have time
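
A minimal sketch of loading such a csv (the file name `train.csv` is illustrative; the resulting `Dataset` then goes through the same `tokenize_func` as in the notebook):

```python
import csv

from datasets import Dataset

# Keep the first column of every row as the training corpus
with open("train.csv", encoding="utf-8") as f:
    lines = [row[0].strip() for row in csv.reader(f) if row]

train_data = Dataset.from_dict({"text": lines})
```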
127 |
128 | ### Q&A
129 |
130 | Q: The accuracy feels a bit low.
131 |
132 | A: Try setting `fast_mode` and `strict_mode` to `False`. The model was pre-trained on a very small corpus (2M sentences), so limited generalization is expected. You can fine-tune it on a larger or more domain-specific dataset; the procedure is much the same as for [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm), except that you replace the `DataCollator` with `DataCollatorForMultiMask` from `CustomBertModel.py`, as shown below.
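
A minimal sketch of that swap (the arguments follow [TrainExample.ipynb](https://github.com/IgarashiAkatuki/CNMBert/blob/main/TrainExample.ipynb)):

```python
from CustomBertModel import DataCollatorForMultiMask

# Drop-in replacement for the whole-word-masking DataCollator
data_collator = DataCollatorForMultiMask(tokenizer,
                                         mlm=True,
                                         mlm_probability=0.15,
                                         pad_to_multiple_of=64)
```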
133 |
134 | Q: Can't the model just detect the Pinyin abbreviations or homophones in a sentence and translate them directly?
135 |
136 | A: Work in progress. For Pinyin abbreviations, a detector easily mistakes them for English words in the sentence, so accuracy is very low. For homophones, some sentences read as perfectly well-formed: `你木琴没了` contains no grammatical error, so it is hard for a model to detect that `木琴` ("xylophone") stands for its homophone `母亲` ("mother").
137 |
138 | ### Citation
139 | If you are interested in the implementation details of CNMBERT, see:
140 | ```
141 | @misc{feng2024cnmbertmodelhanyupinyin,
142 | title={CNMBert: A Model For Hanyu Pinyin Abbreviation to Character Conversion Task},
143 | author={Zishuo Feng and Feng Cao},
144 | year={2024},
145 | eprint={2411.11770},
146 | archivePrefix={arXiv},
147 | primaryClass={cs.CL},
148 | url={https://arxiv.org/abs/2411.11770},
149 | }
150 | ```
151 |
152 |
--------------------------------------------------------------------------------
/run_chinese_ref.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from typing import List
4 |
5 | from ltp import LTP
6 |
7 | from transformers.models.bert.tokenization_bert import BertTokenizer
8 |
9 |
10 | def _is_chinese_char(cp):
11 | """Checks whether CP is the codepoint of a CJK character."""
12 | # This defines a "chinese character" as anything in the CJK Unicode block:
13 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
14 | #
15 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
16 | # despite its name. The modern Korean Hangul alphabet is a different block,
17 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
18 |     # space-separated words, so they are not treated specially and are handled
19 |     # like all of the other languages.
20 | if (
21 | (cp >= 0x4E00 and cp <= 0x9FFF)
22 | or (cp >= 0x3400 and cp <= 0x4DBF) #
23 | or (cp >= 0x20000 and cp <= 0x2A6DF) #
24 | or (cp >= 0x2A700 and cp <= 0x2B73F) #
25 | or (cp >= 0x2B740 and cp <= 0x2B81F) #
26 | or (cp >= 0x2B820 and cp <= 0x2CEAF) #
27 | or (cp >= 0xF900 and cp <= 0xFAFF)
28 | or (cp >= 0x2F800 and cp <= 0x2FA1F) #
29 | ): #
30 | return True
31 |
32 | return False
33 |
34 |
35 | def is_chinese(word: str):
36 | # word like '180' or '身高' or '神'
37 | for char in word:
38 | char = ord(char)
39 | if not _is_chinese_char(char):
40 | return 0
41 | return 1
42 |
43 |
44 | def get_chinese_word(tokens: List[str]):
45 | word_set = set()
46 |
47 | for token in tokens:
48 | chinese_word = len(token) > 1 and is_chinese(token)
49 | if chinese_word:
50 | word_set.add(token)
51 | word_list = list(word_set)
52 | return word_list
53 |
54 |
55 | def add_sub_symbol(bert_tokens: List[str], chinese_word_set: set):
56 | if not chinese_word_set:
57 | return bert_tokens
58 | max_word_len = max([len(w) for w in chinese_word_set])
59 |
60 | bert_word = bert_tokens
61 | start, end = 0, len(bert_word)
62 | while start < end:
63 | single_word = True
64 | if is_chinese(bert_word[start]):
65 | l = min(end - start, max_word_len)
66 | for i in range(l, 1, -1):
67 | whole_word = "".join(bert_word[start : start + i])
68 | if whole_word in chinese_word_set:
69 | for j in range(start + 1, start + i):
70 | bert_word[j] = "##" + bert_word[j]
71 | start = start + i
72 | single_word = False
73 | break
74 | if single_word:
75 | start += 1
76 | return bert_word
77 |
78 |
79 | def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokenizer):
80 | ltp_res = []
81 |
82 | for i in range(0, len(lines), 100):
83 | res = ltp_tokenizer.pipeline(lines[i : i + 100], tasks=["cws"]).cws
84 | res = [get_chinese_word(r) for r in res]
85 | ltp_res.extend(res)
86 | assert len(ltp_res) == len(lines)
87 |
88 | bert_res = []
89 | for i in range(0, len(lines), 100):
90 | res = bert_tokenizer(lines[i : i + 100], add_special_tokens=True, truncation=True, max_length=512)
91 | bert_res.extend(res["input_ids"])
92 | assert len(bert_res) == len(lines)
93 |
94 | ref_ids = []
95 | for input_ids, chinese_word in zip(bert_res, ltp_res):
96 | input_tokens = []
97 |         for token_id in input_ids:
98 |             token = bert_tokenizer._convert_id_to_token(token_id)
99 | input_tokens.append(token)
100 | input_tokens = add_sub_symbol(input_tokens, chinese_word)
101 | ref_id = []
102 |         # We only save the positions of Chinese subwords starting with "##", meaning they are part of a whole word.
103 | for i, token in enumerate(input_tokens):
104 | if token[:2] == "##":
105 | clean_token = token[2:]
106 |                 # save the Chinese tokens' positions
107 | if len(clean_token) == 1 and _is_chinese_char(ord(clean_token)):
108 | ref_id.append(i)
109 | ref_ids.append(ref_id)
110 |
111 | assert len(ref_ids) == len(bert_res)
112 |
113 | return ref_ids
114 |
115 |
116 | def main(args):
117 |     # For Chinese (Ro)BERT, the best results come from RoBERTa-wwm-ext (https://github.com/ymcui/Chinese-BERT-wwm)
118 |     # To fine-tune these models, we have to use the same segmenter: LTP (https://github.com/HIT-SCIR/ltp)
119 | with open(args.file_name, "r", encoding="utf-8") as f:
120 | data = f.readlines()
121 | data = [line.strip() for line in data if len(line) > 0 and not line.isspace()] # avoid delimiter like '\u2029'
122 |     ltp_tokenizer = LTP(args.ltp)  # faster on a GPU device
123 | bert_tokenizer = BertTokenizer.from_pretrained(args.bert)
124 |
125 | ref_ids = prepare_ref(data, ltp_tokenizer, bert_tokenizer)
126 |
127 | with open(args.save_path, "w", encoding="utf-8") as f:
128 | data = [json.dumps(ref) + "\n" for ref in ref_ids]
129 | f.writelines(data)
130 |
131 |
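# Example invocation (paths are illustrative and match the defaults below):
#   python run_chinese_ref.py --file_name ./resources/chinese-demo.txt \
#       --ltp ./resources/ltp --bert ./resources/robert --save_path ./resources/ref.txt
# Each output line is a JSON list for one input sentence, holding the positions of
# the "##"-prefixed Chinese subwords (the non-initial pieces of whole words).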
132 | if __name__ == "__main__":
133 | parser = argparse.ArgumentParser(description="prepare_chinese_ref")
134 | parser.add_argument(
135 | "--file_name",
136 | required=False,
137 | type=str,
138 | default="./resources/chinese-demo.txt",
139 |         help="file to process, same as the training data of the LM",
140 | )
141 | parser.add_argument(
142 | "--ltp",
143 | required=False,
144 | type=str,
145 | default="./resources/ltp",
146 | help="resources for LTP tokenizer, usually a path",
147 | )
148 | parser.add_argument(
149 | "--bert",
150 | required=False,
151 | type=str,
152 | default="./resources/robert",
153 | help="resources for Bert tokenizer",
154 | )
155 | parser.add_argument(
156 | "--save_path",
157 | required=False,
158 | type=str,
159 | default="./resources/ref.txt",
160 |         help="path to save the result",
161 | )
162 |
163 | args = parser.parse_args()
164 | main(args)
165 |
--------------------------------------------------------------------------------
/Test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "id": "initial_id",
6 | "metadata": {
7 | "collapsed": true,
8 | "ExecuteTime": {
9 | "end_time": "2025-03-29T02:38:37.862817Z",
10 | "start_time": "2025-03-29T02:38:33.921228Z"
11 | }
12 | },
13 | "source": [
14 | "from transformers import AutoTokenizer, BertConfig, BertForMaskedLM\n",
15 | "\n",
16 | "from CustomBertModel import predict, backtrack_predict, word_level_predict\n",
17 | "from MoELayer import BertWwmMoE"
18 | ],
19 | "outputs": [
20 | {
21 | "name": "stderr",
22 | "output_type": "stream",
23 | "text": [
24 | "E:\\Environment\\Anaconda\\envs\\speech\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
25 | " from .autonotebook import tqdm as notebook_tqdm\n"
26 | ]
27 | }
28 | ],
29 | "execution_count": 1
30 | },
31 | {
32 | "metadata": {
33 | "ExecuteTime": {
34 | "end_time": "2025-03-29T02:38:39.810910Z",
35 | "start_time": "2025-03-29T02:38:37.866817Z"
36 | }
37 | },
38 | "cell_type": "code",
39 | "source": [
40 | "# use CNMBert with MoE\n",
41 |     "# to use CNMBert without MoE, replace all \"Midsummra/CNMBert-MoE\" with \"Midsummra/CNMBert\" and use BertForMaskedLM instead of BertWwmMoE\n",
42 | "tokenizer = AutoTokenizer.from_pretrained(\"Midsummra/CNMBert-MoE\")\n",
43 | "config = BertConfig.from_pretrained('Midsummra/CNMBert-MoE')\n",
44 | "model = BertWwmMoE.from_pretrained('Midsummra/CNMBert-MoE', config=config).to('cuda')"
45 | ],
46 | "id": "5d9191a45178cd39",
47 | "outputs": [
48 | {
49 | "name": "stderr",
50 | "output_type": "stream",
51 | "text": [
52 | "BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n",
53 | " - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes\n",
54 | " - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n",
55 | " - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n",
56 | "Some weights of the model checkpoint at Midsummra/CNMBert-MoE were not used when initializing BertWwmMoE: ['bert.encoder.layer.0.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.10.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.12.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.14.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.2.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.4.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.6.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.8.intermediate.dense.sparse_moe.bias']\n",
57 | "- This IS expected if you are initializing BertWwmMoE from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
58 | "- This IS NOT expected if you are initializing BertWwmMoE from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
59 | "Some weights of BertWwmMoE were not initialized from the model checkpoint at Midsummra/CNMBert-MoE and are newly initialized: ['bert.encoder.layer.0.intermediate.dense.sparse_moe.beta', 'bert.encoder.layer.10.intermediate.dense.sparse_moe.beta', 'bert.encoder.layer.12.intermediate.dense.sparse_moe.beta', 'bert.encoder.layer.14.intermediate.dense.sparse_moe.beta', 'bert.encoder.layer.2.intermediate.dense.sparse_moe.beta', 'bert.encoder.layer.4.intermediate.dense.sparse_moe.beta', 'bert.encoder.layer.6.intermediate.dense.sparse_moe.beta', 'bert.encoder.layer.8.intermediate.dense.sparse_moe.beta']\n",
60 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
61 | ]
62 | }
63 | ],
64 | "execution_count": 2
65 | },
66 | {
67 | "metadata": {
68 | "ExecuteTime": {
69 | "end_time": "2025-03-29T02:47:24.841036Z",
70 | "start_time": "2025-03-29T02:47:23.974884Z"
71 | }
72 | },
73 | "cell_type": "code",
74 | "source": [
75 | "print(word_level_predict(\"将军是一支柱\", \"支柱\", model, tokenizer, fast_mode=False, strict_mode=False)[:10])\n",
76 | "print(predict(\"快去给魔理沙看b吧\", \"b\", model, tokenizer)[:5])"
77 | ],
78 | "id": "5f8789a21534588c",
79 | "outputs": [
80 | {
81 | "name": "stdout",
82 | "output_type": "stream",
83 | "text": [
84 | "[['只猪', 0.013427094615127833, 1.0], ['支主', 0.012690062437477466, 1.0], ['支州', 0.012477088056586812, 0.9230769230769231], ['支战', 0.01260267308151233, 0.7692307692307692], ['侄子', 0.012531780478518316, 0.7272727272727273], ['支子', 0.012490831659057209, 0.7272727272727273], ['种主', 0.012726939569656787, 0.7142857142857143], ['支长', 0.012681785355615857, 0.7142857142857143], ['支中', 0.012518818992013665, 0.7142857142857143], ['直在', 0.013213451601798279, 0.6666666666666667]]\n",
85 | "[['病', 0.19057052400975555], ['吧', 0.09717965056220808], ['包', 0.08986218881784686], ['呗', 0.08982954684417713], ['报', 0.08949582422272462]]\n"
86 | ]
87 | }
88 | ],
89 | "execution_count": 14
90 | },
91 | {
92 | "metadata": {},
93 | "cell_type": "code",
94 | "outputs": [],
95 | "execution_count": null,
96 | "source": "",
97 | "id": "51e5db5ca1807903"
98 | }
99 | ],
100 | "metadata": {
101 | "kernelspec": {
102 | "display_name": "Python 3",
103 | "language": "python",
104 | "name": "python3"
105 | },
106 | "language_info": {
107 | "codemirror_mode": {
108 | "name": "ipython",
109 | "version": 2
110 | },
111 | "file_extension": ".py",
112 | "mimetype": "text/x-python",
113 | "name": "python",
114 | "nbconvert_exporter": "python",
115 | "pygments_lexer": "ipython2",
116 | "version": "2.7.6"
117 | }
118 | },
119 | "nbformat": 4,
120 | "nbformat_minor": 5
121 | }
122 |
--------------------------------------------------------------------------------
/TrainExample.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "id": "initial_id",
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "source": [
10 | "from transformers import AutoTokenizer, BertConfig, TrainingArguments, Trainer\n",
11 | "from CustomBertModel import DataCollatorForMultiMask\n",
12 | "from MoELayer import BertWwmMoE\n",
13 | "from datasets import Dataset\n",
14 | "from ltp import LTP\n",
15 | "\n",
16 | "# https://github.com/huggingface/transformers/blob/main/examples/research_projects/mlm_wwm/run_chinese_ref.py\n",
17 | "from run_chinese_ref import prepare_ref\n",
18 | "\n",
19 | "import random\n",
20 | "import torch\n"
21 | ],
22 | "outputs": [],
23 | "execution_count": null
24 | },
25 | {
26 | "metadata": {},
27 | "cell_type": "code",
28 | "source": [
29 | "random.seed(123)\n",
30 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
31 | "ltp = LTP().to(device=device)\n",
32 | "\n",
33 | "tokenizer = AutoTokenizer.from_pretrained(\"Midsummra/CNMBert-MoE\")\n",
34 | "config = BertConfig.from_pretrained('Midsummra/CNMBert-MoE')\n",
35 | "model = BertWwmMoE.from_pretrained('Midsummra/CNMBert-MoE', config=config).to('cuda')"
36 | ],
37 | "id": "497437f079c218e0",
38 | "outputs": [],
39 | "execution_count": null
40 | },
41 | {
42 | "metadata": {},
43 | "cell_type": "code",
44 | "source": [
45 |     "# Data preprocessing\n",
46 | "\n",
47 | "text = set()\n",
48 | "bilibili = set()\n",
49 | "with open('../webtext/train.csv', mode='r', encoding='utf-8') as file:\n",
50 | " line = file.readline()\n",
51 | " while True:\n",
52 | " if not line:\n",
53 | " break\n",
54 | " text.add(line)\n",
55 | " line = file.readline()\n",
56 | "with open('../webtext/bilibili.csv', mode='r', encoding='utf-8') as file:\n",
57 | " line = file.readline()\n",
58 | " while True:\n",
59 | " if not line:\n",
60 | " break\n",
61 | " bilibili.add(line)\n",
62 | " line = file.readline()\n",
63 | "\n",
64 | "text = [t.replace('\\n', '') for t in list(text)]\n",
65 | "bilibili = [t.replace('\\n', '') for t in list(bilibili)]\n",
66 | "random.shuffle(text)\n",
67 | "random.shuffle(bilibili)\n",
68 | "\n",
69 | "train_data = {'text': text[:750000] + bilibili[:750000]}\n",
70 | "eval_data = {'text': text[len(text) - 20000:] + bilibili[len(bilibili) - 20000:]}\n",
71 | "\n",
72 | "train_data = Dataset.from_dict(train_data)\n",
73 | "eval_data = Dataset.from_dict(eval_data)"
74 | ],
75 | "id": "577b4745dde4c3ad",
76 | "outputs": [],
77 | "execution_count": null
78 | },
79 | {
80 | "metadata": {},
81 | "cell_type": "code",
82 | "source": [
83 | "def tokenize_func(dataset):\n",
84 | " tokens = tokenizer(dataset['text'],\n",
85 | " max_length=64,\n",
86 | " padding='max_length',\n",
87 | " truncation=True,\n",
88 | " return_tensors='pt'\n",
89 | " )\n",
90 | " ref = prepare_ref(dataset['text'], ltp, tokenizer)\n",
91 | " features = {'input_ids': tokens['input_ids'], 'chinese_ref': ref, 'attention_mask': tokens['attention_mask']}\n",
92 | " return features\n",
93 | "\n",
94 | "data_collator = DataCollatorForMultiMask(tokenizer,\n",
95 | " mlm_probability=0.15,\n",
96 | " mlm=True,\n",
97 | " pad_to_multiple_of=64)\n",
98 | "\n",
99 | "train_dataset = train_data.map(tokenize_func, batched=True, remove_columns=[\"text\"])\n",
100 | "eval_dataset = eval_data.map(tokenize_func, batched=True, remove_columns=[\"text\"])"
101 | ],
102 | "id": "5d68fd569a379e40",
103 | "outputs": [],
104 | "execution_count": null
105 | },
106 | {
107 | "metadata": {},
108 | "cell_type": "code",
109 | "source": [
110 |     "# Optional: train only the embeddings\n",
111 | "for name, param in model.named_parameters():\n",
112 | " if name.startswith('bert.embeddings.'):\n",
113 | " param.requires_grad = True\n",
114 | " else:\n",
115 | " param.requires_grad = False\n",
116 | " if param.requires_grad:\n",
117 | " print(name)"
118 | ],
119 | "id": "bcc91db91d15d0bb",
120 | "outputs": [],
121 | "execution_count": null
122 | },
123 | {
124 | "metadata": {},
125 | "cell_type": "code",
126 | "source": [
127 |     "# Training\n",
128 | "\n",
129 | "torch.manual_seed(42)\n",
130 | "\n",
131 | "model = model.to(device)\n",
132 | "for name, param in model.named_parameters():\n",
133 | " if param.requires_grad:\n",
134 | " print(f\"Trainable layer: {name}\")\n",
135 | " param.data = param.data.contiguous()\n",
136 | "\n",
137 | "\n",
138 | "training_args = TrainingArguments(\n",
139 | " output_dir='./model/checkpoints/',\n",
140 | " num_train_epochs=20,\n",
141 | " per_device_train_batch_size=128,\n",
142 | " eval_strategy='steps',\n",
143 | " eval_steps=500,\n",
144 |     "    learning_rate=1e-5, # recommended learning rate: 1e-5 to 2e-5\n",
145 | " weight_decay=1e-5,\n",
146 | " logging_dir='./model/logs/',\n",
147 | " logging_steps=100,\n",
148 | " logging_first_step=True,\n",
149 | " save_strategy='steps',\n",
150 | " save_steps=100,\n",
151 | " save_total_limit=4,\n",
152 | " max_grad_norm=1.0,\n",
153 | " warmup_ratio=1 / 20,\n",
154 | " disable_tqdm=True\n",
155 | ")\n",
156 | "\n",
157 | "trainer = Trainer(\n",
158 | " model=model,\n",
159 | " args=training_args,\n",
160 | " train_dataset=train_dataset,\n",
161 | " eval_dataset=eval_dataset,\n",
162 | " data_collator=data_collator,\n",
163 | ")\n"
164 | ],
165 | "id": "2c9bf9ebb808cf34",
166 | "outputs": [],
167 | "execution_count": null
168 | },
169 | {
170 | "metadata": {},
171 | "cell_type": "code",
172 | "source": [
173 | "trainer.train()\n",
174 | "trainer.save_model('./model/cnmbert-ft')\n",
175 | "eval_results = trainer.evaluate()\n",
176 | "print(f\"Evaluation cnmbert-ft: {eval_results}\")"
177 | ],
178 | "id": "b4004f82f39462d9",
179 | "outputs": [],
180 | "execution_count": null
181 | },
182 | {
183 | "metadata": {},
184 | "cell_type": "code",
185 | "source": "",
186 | "id": "d62062c3aa4e90fa",
187 | "outputs": [],
188 | "execution_count": null
189 | }
190 | ],
191 | "metadata": {
192 | "kernelspec": {
193 | "display_name": "Python 3",
194 | "language": "python",
195 | "name": "python3"
196 | },
197 | "language_info": {
198 | "codemirror_mode": {
199 | "name": "ipython",
200 | "version": 2
201 | },
202 | "file_extension": ".py",
203 | "mimetype": "text/x-python",
204 | "name": "python",
205 | "nbconvert_exporter": "python",
206 | "pygments_lexer": "ipython2",
207 | "version": "2.7.6"
208 | }
209 | },
210 | "nbformat": 4,
211 | "nbformat_minor": 5
212 | }
213 |
--------------------------------------------------------------------------------
/MoELayer.py:
--------------------------------------------------------------------------------
1 | from transformers import BertForMaskedLM
2 | from transformers import TrainerCallback
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | temp_indices = []  # debug buffer: DynamicRouter.forward appends its routing decisions here
7 |
8 | class Expert(nn.Module):
9 | def __init__(self, input_dim, hidden_dim, dropout=0.15):
10 | super(Expert, self).__init__()
11 | self.net = nn.Sequential(
12 | nn.Linear(input_dim, hidden_dim),
13 | nn.LeakyReLU(negative_slope=0.01),
14 | nn.Dropout(dropout),
15 | nn.Linear(hidden_dim, input_dim),
16 | )
17 |
18 | def forward(self, x):
19 | return self.net(x)
20 |
21 |
22 | class DynamicRouter(nn.Module):
23 | def __init__(self, input_dim,
24 | num_experts=8,
25 | top_k=2,
26 | noise_std=0.1):
27 | super(DynamicRouter, self).__init__()
28 | self.top_k = top_k
29 | self.linear = nn.Linear(input_dim, num_experts)
30 | self.noise_std = noise_std
31 | def forward(self, x):
32 | logits = self.linear(x)
33 | noise = torch.randn_like(logits) * self.noise_std
34 | logits += noise
35 | top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1)
36 |
37 |         zeros = torch.full_like(logits, float('-inf'))
38 | sparse_logits = zeros.scatter(-1, top_k_indices, top_k_logits)
39 |         global temp_indices
40 |         temp_indices.append([logits, top_k_indices])  # debug: records every routing decision (grows unboundedly)
41 | router_output = F.softmax(sparse_logits, dim=-1)
42 |
43 | return router_output, top_k_indices
44 |
45 |
46 | class SparseMoE(nn.Module):
47 | def __init__(self, input_dim, output_dim, num_experts=8, top_k=2, dropout=0.15):
48 | super(SparseMoE, self).__init__()
49 | self.router = DynamicRouter(input_dim, num_experts, top_k)
50 | self.experts = nn.ModuleList([Expert(input_dim, output_dim, dropout) for _ in range(num_experts)])
51 | self.shared_expert = Expert(input_dim, output_dim, dropout)
52 | self.top_k = top_k
53 | self.beta = nn.Parameter(torch.tensor(0.7), requires_grad=False)
54 | self.alpha = nn.Parameter(torch.tensor(0.3), requires_grad=True)
55 |
56 |     def forward(self, x):
57 |         # 1. Run the input through the router to get gating weights and top-k expert indices
58 |         gating_output, indices = self.router(x)
59 |         # 2. Initialize an all-zero tensor; expert outputs are accumulated into it
60 |         final_output = torch.zeros_like(x)
61 | 
62 |         # 3. Flatten across the batch (concatenate all tokens), for both the input x and the router output
63 |         flat_x = x.view(-1, x.size(-1))
64 |         flat_gating_output = gating_output.view(-1, gating_output.size(-1))
65 | 
66 |         # Operate expert by expert, weighting every token routed to the current expert
67 |         for i, expert in enumerate(self.experts):
68 |             # 4. For the current expert (e.g. expert 0), find the tokens that rank it in their top-k
69 |             expert_mask = (indices == i).any(dim=-1)
70 |             # 5. Flatten the mask
71 |             flat_mask = expert_mask.view(-1)
72 |             # If this expert is in the top-k of at least one token
73 |             if flat_mask.any():
74 |                 # 6. Select the hidden states of the tokens this expert acts on
75 |                 expert_input = flat_x[flat_mask]
76 |                 # 7. Run those tokens through the expert
77 |                 expert_output = expert(expert_input)
78 | 
79 |                 # 8. Gating scores of this expert for the selected tokens
80 |                 gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
81 |                 # 9. Weight the expert output by its gating scores
82 |                 weighted_output = expert_output * gating_scores
83 | 
84 |                 # 10. Accumulate the weighted outputs into the final result
85 |                 final_output[expert_mask] += weighted_output.squeeze(1)
86 | 
87 |         # Blend the always-on shared expert with the routed experts via learned softmax weights
88 |         weights = F.softmax(torch.stack([self.alpha, self.beta]), dim=0)
89 |         a, b = weights[0], weights[1]
90 |         final_output = self.shared_expert(x) * a + final_output * b
91 |         return final_output
93 |
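# A minimal shape check (illustrative): SparseMoE maps the hidden size back to
# itself (each Expert is input_dim -> hidden_dim -> input_dim), which is why
# layer.output.dense can be replaced by nn.Identity() in BertWwmMoE below.
#   moe = SparseMoE(input_dim=768, output_dim=3072, num_experts=8, top_k=2)
#   moe(torch.randn(2, 16, 768)).shape   # torch.Size([2, 16, 768])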
94 |
95 | class SparseMoEFFN(nn.Module):
96 | def __init__(self, config, num_experts=8, top_k=2, dropout=0.15):
97 | super(SparseMoEFFN, self).__init__()
98 | self.sparse_moe = SparseMoE(input_dim=768,
99 | output_dim=3072,
100 | num_experts=num_experts,
101 | top_k=top_k,
102 | dropout=dropout)
103 |
104 | def forward(self, x):
105 | return self.sparse_moe(x)
106 |
107 |
108 | class BertWwmMoE(BertForMaskedLM):
109 |     def __init__(self, config, num_experts=8, top_k=2, dropout=0.05):
110 |         super(BertWwmMoE, self).__init__(config)
111 |         # Replace the FFN of every even-numbered encoder layer with a sparse MoE;
112 |         # deeper layers get more experts (2 for layers 0-4, 4 for 5-7, 8 for 8-15).
113 |         # The num_experts/top_k constructor arguments are kept for compatibility,
114 |         # but the per-layer values below are fixed.
115 |         for layer_index, layer in enumerate(self.bert.encoder.layer):
116 |             if layer_index % 2 == 1:
117 |                 continue
118 |             if 8 <= layer_index <= 15:
119 |                 moe_ffn = SparseMoEFFN(config=config, num_experts=8, top_k=2, dropout=dropout)
120 |             elif 5 <= layer_index <= 7:
121 |                 moe_ffn = SparseMoEFFN(config=config, num_experts=4, top_k=1, dropout=dropout)
122 |             elif 0 <= layer_index <= 4:
123 |                 moe_ffn = SparseMoEFFN(config=config, num_experts=2, top_k=1, dropout=0.1)
124 |             else:
125 |                 continue
126 | 
127 |             # Warm-start the shared expert from the original FFN weights
128 |             # (net[0] is the up projection, net[3] the down projection).
129 |             for sub_index, module in enumerate(moe_ffn.sparse_moe.shared_expert.net):
130 |                 if sub_index == 0:
131 |                     module.weight.data = layer.intermediate.dense.weight.data.clone()
132 |                     module.bias.data = layer.intermediate.dense.bias.data.clone()
133 |                 if sub_index == 3:
134 |                     module.weight.data = layer.output.dense.weight.data.clone()
135 |                     module.bias.data = layer.output.dense.bias.data.clone()
136 | 
137 |             layer.intermediate.dense = moe_ffn
138 |             layer.output.dense = nn.Identity()
139 | 
140 | 
185 | class EvaluationCallback(TrainerCallback):
186 | def __init__(self, model):
187 | super().__init__()
188 | self.model = model
189 |
190 | def on_evaluate(self, args, state, control, **kwargs):
191 |         # Print the learned expert-mixing weights (alpha, beta) of layers 0-5
192 | for index, layer in enumerate(self.model.bert.encoder.layer):
193 | if 0 <= index < 6:
194 | print(layer.intermediate.dense.sparse_moe.alpha.data)
195 | print(layer.intermediate.dense.sparse_moe.beta.data)
196 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU AFFERO GENERAL PUBLIC LICENSE
2 | Version 3, 19 November 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU Affero General Public License is a free, copyleft license for
11 | software and other kinds of works, specifically designed to ensure
12 | cooperation with the community in the case of network server software.
13 |
14 | The licenses for most software and other practical works are designed
15 | to take away your freedom to share and change the works. By contrast,
16 | our General Public Licenses are intended to guarantee your freedom to
17 | share and change all versions of a program--to make sure it remains free
18 | software for all its users.
19 |
20 | When we speak of free software, we are referring to freedom, not
21 | price. Our General Public Licenses are designed to make sure that you
22 | have the freedom to distribute copies of free software (and charge for
23 | them if you wish), that you receive source code or can get it if you
24 | want it, that you can change the software or use pieces of it in new
25 | free programs, and that you know you can do these things.
26 |
27 | Developers that use our General Public Licenses protect your rights
28 | with two steps: (1) assert copyright on the software, and (2) offer
29 | you this License which gives you legal permission to copy, distribute
30 | and/or modify the software.
31 |
32 | A secondary benefit of defending all users' freedom is that
33 | improvements made in alternate versions of the program, if they
34 | receive widespread use, become available for other developers to
35 | incorporate. Many developers of free software are heartened and
36 | encouraged by the resulting cooperation. However, in the case of
37 | software used on network servers, this result may fail to come about.
38 | The GNU General Public License permits making a modified version and
39 | letting the public access it on a server without ever releasing its
40 | source code to the public.
41 |
42 | The GNU Affero General Public License is designed specifically to
43 | ensure that, in such cases, the modified source code becomes available
44 | to the community. It requires the operator of a network server to
45 | provide the source code of the modified version running there to the
46 | users of that server. Therefore, public use of a modified version, on
47 | a publicly accessible server, gives the public access to the source
48 | code of the modified version.
49 |
50 | An older license, called the Affero General Public License and
51 | published by Affero, was designed to accomplish similar goals. This is
52 | a different license, not a version of the Affero GPL, but Affero has
53 | released a new version of the Affero GPL which permits relicensing under
54 | this license.
55 |
56 | The precise terms and conditions for copying, distribution and
57 | modification follow.
58 |
59 | TERMS AND CONDITIONS
60 |
61 | 0. Definitions.
62 |
63 | "This License" refers to version 3 of the GNU Affero General Public License.
64 |
65 | "Copyright" also means copyright-like laws that apply to other kinds of
66 | works, such as semiconductor masks.
67 |
68 | "The Program" refers to any copyrightable work licensed under this
69 | License. Each licensee is addressed as "you". "Licensees" and
70 | "recipients" may be individuals or organizations.
71 |
72 | To "modify" a work means to copy from or adapt all or part of the work
73 | in a fashion requiring copyright permission, other than the making of an
74 | exact copy. The resulting work is called a "modified version" of the
75 | earlier work or a work "based on" the earlier work.
76 |
77 | A "covered work" means either the unmodified Program or a work based
78 | on the Program.
79 |
80 | To "propagate" a work means to do anything with it that, without
81 | permission, would make you directly or secondarily liable for
82 | infringement under applicable copyright law, except executing it on a
83 | computer or modifying a private copy. Propagation includes copying,
84 | distribution (with or without modification), making available to the
85 | public, and in some countries other activities as well.
86 |
87 | To "convey" a work means any kind of propagation that enables other
88 | parties to make or receive copies. Mere interaction with a user through
89 | a computer network, with no transfer of a copy, is not conveying.
90 |
91 | An interactive user interface displays "Appropriate Legal Notices"
92 | to the extent that it includes a convenient and prominently visible
93 | feature that (1) displays an appropriate copyright notice, and (2)
94 | tells the user that there is no warranty for the work (except to the
95 | extent that warranties are provided), that licensees may convey the
96 | work under this License, and how to view a copy of this License. If
97 | the interface presents a list of user commands or options, such as a
98 | menu, a prominent item in the list meets this criterion.
99 |
100 | 1. Source Code.
101 |
102 | The "source code" for a work means the preferred form of the work
103 | for making modifications to it. "Object code" means any non-source
104 | form of a work.
105 |
106 | A "Standard Interface" means an interface that either is an official
107 | standard defined by a recognized standards body, or, in the case of
108 | interfaces specified for a particular programming language, one that
109 | is widely used among developers working in that language.
110 |
111 | The "System Libraries" of an executable work include anything, other
112 | than the work as a whole, that (a) is included in the normal form of
113 | packaging a Major Component, but which is not part of that Major
114 | Component, and (b) serves only to enable use of the work with that
115 | Major Component, or to implement a Standard Interface for which an
116 | implementation is available to the public in source code form. A
117 | "Major Component", in this context, means a major essential component
118 | (kernel, window system, and so on) of the specific operating system
119 | (if any) on which the executable work runs, or a compiler used to
120 | produce the work, or an object code interpreter used to run it.
121 |
122 | The "Corresponding Source" for a work in object code form means all
123 | the source code needed to generate, install, and (for an executable
124 | work) run the object code and to modify the work, including scripts to
125 | control those activities. However, it does not include the work's
126 | System Libraries, or general-purpose tools or generally available free
127 | programs which are used unmodified in performing those activities but
128 | which are not part of the work. For example, Corresponding Source
129 | includes interface definition files associated with source files for
130 | the work, and the source code for shared libraries and dynamically
131 | linked subprograms that the work is specifically designed to require,
132 | such as by intimate data communication or control flow between those
133 | subprograms and other parts of the work.
134 |
135 | The Corresponding Source need not include anything that users
136 | can regenerate automatically from other parts of the Corresponding
137 | Source.
138 |
139 | The Corresponding Source for a work in source code form is that
140 | same work.
141 |
142 | 2. Basic Permissions.
143 |
144 | All rights granted under this License are granted for the term of
145 | copyright on the Program, and are irrevocable provided the stated
146 | conditions are met. This License explicitly affirms your unlimited
147 | permission to run the unmodified Program. The output from running a
148 | covered work is covered by this License only if the output, given its
149 | content, constitutes a covered work. This License acknowledges your
150 | rights of fair use or other equivalent, as provided by copyright law.
151 |
152 | You may make, run and propagate covered works that you do not
153 | convey, without conditions so long as your license otherwise remains
154 | in force. You may convey covered works to others for the sole purpose
155 | of having them make modifications exclusively for you, or provide you
156 | with facilities for running those works, provided that you comply with
157 | the terms of this License in conveying all material for which you do
158 | not control copyright. Those thus making or running the covered works
159 | for you must do so exclusively on your behalf, under your direction
160 | and control, on terms that prohibit them from making any copies of
161 | your copyrighted material outside their relationship with you.
162 |
163 | Conveying under any other circumstances is permitted solely under
164 | the conditions stated below. Sublicensing is not allowed; section 10
165 | makes it unnecessary.
166 |
167 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
168 |
169 | No covered work shall be deemed part of an effective technological
170 | measure under any applicable law fulfilling obligations under article
171 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
172 | similar laws prohibiting or restricting circumvention of such
173 | measures.
174 |
175 | When you convey a covered work, you waive any legal power to forbid
176 | circumvention of technological measures to the extent such circumvention
177 | is effected by exercising rights under this License with respect to
178 | the covered work, and you disclaim any intention to limit operation or
179 | modification of the work as a means of enforcing, against the work's
180 | users, your or third parties' legal rights to forbid circumvention of
181 | technological measures.
182 |
183 | 4. Conveying Verbatim Copies.
184 |
185 | You may convey verbatim copies of the Program's source code as you
186 | receive it, in any medium, provided that you conspicuously and
187 | appropriately publish on each copy an appropriate copyright notice;
188 | keep intact all notices stating that this License and any
189 | non-permissive terms added in accord with section 7 apply to the code;
190 | keep intact all notices of the absence of any warranty; and give all
191 | recipients a copy of this License along with the Program.
192 |
193 | You may charge any price or no price for each copy that you convey,
194 | and you may offer support or warranty protection for a fee.
195 |
196 | 5. Conveying Modified Source Versions.
197 |
198 | You may convey a work based on the Program, or the modifications to
199 | produce it from the Program, in the form of source code under the
200 | terms of section 4, provided that you also meet all of these conditions:
201 |
202 | a) The work must carry prominent notices stating that you modified
203 | it, and giving a relevant date.
204 |
205 | b) The work must carry prominent notices stating that it is
206 | released under this License and any conditions added under section
207 | 7. This requirement modifies the requirement in section 4 to
208 | "keep intact all notices".
209 |
210 | c) You must license the entire work, as a whole, under this
211 | License to anyone who comes into possession of a copy. This
212 | License will therefore apply, along with any applicable section 7
213 | additional terms, to the whole of the work, and all its parts,
214 | regardless of how they are packaged. This License gives no
215 | permission to license the work in any other way, but it does not
216 | invalidate such permission if you have separately received it.
217 |
218 | d) If the work has interactive user interfaces, each must display
219 | Appropriate Legal Notices; however, if the Program has interactive
220 | interfaces that do not display Appropriate Legal Notices, your
221 | work need not make them do so.
222 |
223 | A compilation of a covered work with other separate and independent
224 | works, which are not by their nature extensions of the covered work,
225 | and which are not combined with it such as to form a larger program,
226 | in or on a volume of a storage or distribution medium, is called an
227 | "aggregate" if the compilation and its resulting copyright are not
228 | used to limit the access or legal rights of the compilation's users
229 | beyond what the individual works permit. Inclusion of a covered work
230 | in an aggregate does not cause this License to apply to the other
231 | parts of the aggregate.
232 |
233 | 6. Conveying Non-Source Forms.
234 |
235 | You may convey a covered work in object code form under the terms
236 | of sections 4 and 5, provided that you also convey the
237 | machine-readable Corresponding Source under the terms of this License,
238 | in one of these ways:
239 |
240 | a) Convey the object code in, or embodied in, a physical product
241 | (including a physical distribution medium), accompanied by the
242 | Corresponding Source fixed on a durable physical medium
243 | customarily used for software interchange.
244 |
245 | b) Convey the object code in, or embodied in, a physical product
246 | (including a physical distribution medium), accompanied by a
247 | written offer, valid for at least three years and valid for as
248 | long as you offer spare parts or customer support for that product
249 | model, to give anyone who possesses the object code either (1) a
250 | copy of the Corresponding Source for all the software in the
251 | product that is covered by this License, on a durable physical
252 | medium customarily used for software interchange, for a price no
253 | more than your reasonable cost of physically performing this
254 | conveying of source, or (2) access to copy the
255 | Corresponding Source from a network server at no charge.
256 |
257 | c) Convey individual copies of the object code with a copy of the
258 | written offer to provide the Corresponding Source. This
259 | alternative is allowed only occasionally and noncommercially, and
260 | only if you received the object code with such an offer, in accord
261 | with subsection 6b.
262 |
263 | d) Convey the object code by offering access from a designated
264 | place (gratis or for a charge), and offer equivalent access to the
265 | Corresponding Source in the same way through the same place at no
266 | further charge. You need not require recipients to copy the
267 | Corresponding Source along with the object code. If the place to
268 | copy the object code is a network server, the Corresponding Source
269 | may be on a different server (operated by you or a third party)
270 | that supports equivalent copying facilities, provided you maintain
271 | clear directions next to the object code saying where to find the
272 | Corresponding Source. Regardless of what server hosts the
273 | Corresponding Source, you remain obligated to ensure that it is
274 | available for as long as needed to satisfy these requirements.
275 |
276 | e) Convey the object code using peer-to-peer transmission, provided
277 | you inform other peers where the object code and Corresponding
278 | Source of the work are being offered to the general public at no
279 | charge under subsection 6d.
280 |
281 | A separable portion of the object code, whose source code is excluded
282 | from the Corresponding Source as a System Library, need not be
283 | included in conveying the object code work.
284 |
285 | A "User Product" is either (1) a "consumer product", which means any
286 | tangible personal property which is normally used for personal, family,
287 | or household purposes, or (2) anything designed or sold for incorporation
288 | into a dwelling. In determining whether a product is a consumer product,
289 | doubtful cases shall be resolved in favor of coverage. For a particular
290 | product received by a particular user, "normally used" refers to a
291 | typical or common use of that class of product, regardless of the status
292 | of the particular user or of the way in which the particular user
293 | actually uses, or expects or is expected to use, the product. A product
294 | is a consumer product regardless of whether the product has substantial
295 | commercial, industrial or non-consumer uses, unless such uses represent
296 | the only significant mode of use of the product.
297 |
298 | "Installation Information" for a User Product means any methods,
299 | procedures, authorization keys, or other information required to install
300 | and execute modified versions of a covered work in that User Product from
301 | a modified version of its Corresponding Source. The information must
302 | suffice to ensure that the continued functioning of the modified object
303 | code is in no case prevented or interfered with solely because
304 | modification has been made.
305 |
306 | If you convey an object code work under this section in, or with, or
307 | specifically for use in, a User Product, and the conveying occurs as
308 | part of a transaction in which the right of possession and use of the
309 | User Product is transferred to the recipient in perpetuity or for a
310 | fixed term (regardless of how the transaction is characterized), the
311 | Corresponding Source conveyed under this section must be accompanied
312 | by the Installation Information. But this requirement does not apply
313 | if neither you nor any third party retains the ability to install
314 | modified object code on the User Product (for example, the work has
315 | been installed in ROM).
316 |
317 | The requirement to provide Installation Information does not include a
318 | requirement to continue to provide support service, warranty, or updates
319 | for a work that has been modified or installed by the recipient, or for
320 | the User Product in which it has been modified or installed. Access to a
321 | network may be denied when the modification itself materially and
322 | adversely affects the operation of the network or violates the rules and
323 | protocols for communication across the network.
324 |
325 | Corresponding Source conveyed, and Installation Information provided,
326 | in accord with this section must be in a format that is publicly
327 | documented (and with an implementation available to the public in
328 | source code form), and must require no special password or key for
329 | unpacking, reading or copying.
330 |
331 | 7. Additional Terms.
332 |
333 | "Additional permissions" are terms that supplement the terms of this
334 | License by making exceptions from one or more of its conditions.
335 | Additional permissions that are applicable to the entire Program shall
336 | be treated as though they were included in this License, to the extent
337 | that they are valid under applicable law. If additional permissions
338 | apply only to part of the Program, that part may be used separately
339 | under those permissions, but the entire Program remains governed by
340 | this License without regard to the additional permissions.
341 |
342 | When you convey a copy of a covered work, you may at your option
343 | remove any additional permissions from that copy, or from any part of
344 | it. (Additional permissions may be written to require their own
345 | removal in certain cases when you modify the work.) You may place
346 | additional permissions on material, added by you to a covered work,
347 | for which you have or can give appropriate copyright permission.
348 |
349 | Notwithstanding any other provision of this License, for material you
350 | add to a covered work, you may (if authorized by the copyright holders of
351 | that material) supplement the terms of this License with terms:
352 |
353 | a) Disclaiming warranty or limiting liability differently from the
354 | terms of sections 15 and 16 of this License; or
355 |
356 | b) Requiring preservation of specified reasonable legal notices or
357 | author attributions in that material or in the Appropriate Legal
358 | Notices displayed by works containing it; or
359 |
360 | c) Prohibiting misrepresentation of the origin of that material, or
361 | requiring that modified versions of such material be marked in
362 | reasonable ways as different from the original version; or
363 |
364 | d) Limiting the use for publicity purposes of names of licensors or
365 | authors of the material; or
366 |
367 | e) Declining to grant rights under trademark law for use of some
368 | trade names, trademarks, or service marks; or
369 |
370 | f) Requiring indemnification of licensors and authors of that
371 | material by anyone who conveys the material (or modified versions of
372 | it) with contractual assumptions of liability to the recipient, for
373 | any liability that these contractual assumptions directly impose on
374 | those licensors and authors.
375 |
376 | All other non-permissive additional terms are considered "further
377 | restrictions" within the meaning of section 10. If the Program as you
378 | received it, or any part of it, contains a notice stating that it is
379 | governed by this License along with a term that is a further
380 | restriction, you may remove that term. If a license document contains
381 | a further restriction but permits relicensing or conveying under this
382 | License, you may add to a covered work material governed by the terms
383 | of that license document, provided that the further restriction does
384 | not survive such relicensing or conveying.
385 |
386 | If you add terms to a covered work in accord with this section, you
387 | must place, in the relevant source files, a statement of the
388 | additional terms that apply to those files, or a notice indicating
389 | where to find the applicable terms.
390 |
391 | Additional terms, permissive or non-permissive, may be stated in the
392 | form of a separately written license, or stated as exceptions;
393 | the above requirements apply either way.
394 |
395 | 8. Termination.
396 |
397 | You may not propagate or modify a covered work except as expressly
398 | provided under this License. Any attempt otherwise to propagate or
399 | modify it is void, and will automatically terminate your rights under
400 | this License (including any patent licenses granted under the third
401 | paragraph of section 11).
402 |
403 | However, if you cease all violation of this License, then your
404 | license from a particular copyright holder is reinstated (a)
405 | provisionally, unless and until the copyright holder explicitly and
406 | finally terminates your license, and (b) permanently, if the copyright
407 | holder fails to notify you of the violation by some reasonable means
408 | prior to 60 days after the cessation.
409 |
410 | Moreover, your license from a particular copyright holder is
411 | reinstated permanently if the copyright holder notifies you of the
412 | violation by some reasonable means, this is the first time you have
413 | received notice of violation of this License (for any work) from that
414 | copyright holder, and you cure the violation prior to 30 days after
415 | your receipt of the notice.
416 |
417 | Termination of your rights under this section does not terminate the
418 | licenses of parties who have received copies or rights from you under
419 | this License. If your rights have been terminated and not permanently
420 | reinstated, you do not qualify to receive new licenses for the same
421 | material under section 10.
422 |
423 | 9. Acceptance Not Required for Having Copies.
424 |
425 | You are not required to accept this License in order to receive or
426 | run a copy of the Program. Ancillary propagation of a covered work
427 | occurring solely as a consequence of using peer-to-peer transmission
428 | to receive a copy likewise does not require acceptance. However,
429 | nothing other than this License grants you permission to propagate or
430 | modify any covered work. These actions infringe copyright if you do
431 | not accept this License. Therefore, by modifying or propagating a
432 | covered work, you indicate your acceptance of this License to do so.
433 |
434 | 10. Automatic Licensing of Downstream Recipients.
435 |
436 | Each time you convey a covered work, the recipient automatically
437 | receives a license from the original licensors, to run, modify and
438 | propagate that work, subject to this License. You are not responsible
439 | for enforcing compliance by third parties with this License.
440 |
441 | An "entity transaction" is a transaction transferring control of an
442 | organization, or substantially all assets of one, or subdividing an
443 | organization, or merging organizations. If propagation of a covered
444 | work results from an entity transaction, each party to that
445 | transaction who receives a copy of the work also receives whatever
446 | licenses to the work the party's predecessor in interest had or could
447 | give under the previous paragraph, plus a right to possession of the
448 | Corresponding Source of the work from the predecessor in interest, if
449 | the predecessor has it or can get it with reasonable efforts.
450 |
451 | You may not impose any further restrictions on the exercise of the
452 | rights granted or affirmed under this License. For example, you may
453 | not impose a license fee, royalty, or other charge for exercise of
454 | rights granted under this License, and you may not initiate litigation
455 | (including a cross-claim or counterclaim in a lawsuit) alleging that
456 | any patent claim is infringed by making, using, selling, offering for
457 | sale, or importing the Program or any portion of it.
458 |
459 | 11. Patents.
460 |
461 | A "contributor" is a copyright holder who authorizes use under this
462 | License of the Program or a work on which the Program is based. The
463 | work thus licensed is called the contributor's "contributor version".
464 |
465 | A contributor's "essential patent claims" are all patent claims
466 | owned or controlled by the contributor, whether already acquired or
467 | hereafter acquired, that would be infringed by some manner, permitted
468 | by this License, of making, using, or selling its contributor version,
469 | but do not include claims that would be infringed only as a
470 | consequence of further modification of the contributor version. For
471 | purposes of this definition, "control" includes the right to grant
472 | patent sublicenses in a manner consistent with the requirements of
473 | this License.
474 |
475 | Each contributor grants you a non-exclusive, worldwide, royalty-free
476 | patent license under the contributor's essential patent claims, to
477 | make, use, sell, offer for sale, import and otherwise run, modify and
478 | propagate the contents of its contributor version.
479 |
480 | In the following three paragraphs, a "patent license" is any express
481 | agreement or commitment, however denominated, not to enforce a patent
482 | (such as an express permission to practice a patent or covenant not to
483 | sue for patent infringement). To "grant" such a patent license to a
484 | party means to make such an agreement or commitment not to enforce a
485 | patent against the party.
486 |
487 | If you convey a covered work, knowingly relying on a patent license,
488 | and the Corresponding Source of the work is not available for anyone
489 | to copy, free of charge and under the terms of this License, through a
490 | publicly available network server or other readily accessible means,
491 | then you must either (1) cause the Corresponding Source to be so
492 | available, or (2) arrange to deprive yourself of the benefit of the
493 | patent license for this particular work, or (3) arrange, in a manner
494 | consistent with the requirements of this License, to extend the patent
495 | license to downstream recipients. "Knowingly relying" means you have
496 | actual knowledge that, but for the patent license, your conveying the
497 | covered work in a country, or your recipient's use of the covered work
498 | in a country, would infringe one or more identifiable patents in that
499 | country that you have reason to believe are valid.
500 |
501 | If, pursuant to or in connection with a single transaction or
502 | arrangement, you convey, or propagate by procuring conveyance of, a
503 | covered work, and grant a patent license to some of the parties
504 | receiving the covered work authorizing them to use, propagate, modify
505 | or convey a specific copy of the covered work, then the patent license
506 | you grant is automatically extended to all recipients of the covered
507 | work and works based on it.
508 |
509 | A patent license is "discriminatory" if it does not include within
510 | the scope of its coverage, prohibits the exercise of, or is
511 | conditioned on the non-exercise of one or more of the rights that are
512 | specifically granted under this License. You may not convey a covered
513 | work if you are a party to an arrangement with a third party that is
514 | in the business of distributing software, under which you make payment
515 | to the third party based on the extent of your activity of conveying
516 | the work, and under which the third party grants, to any of the
517 | parties who would receive the covered work from you, a discriminatory
518 | patent license (a) in connection with copies of the covered work
519 | conveyed by you (or copies made from those copies), or (b) primarily
520 | for and in connection with specific products or compilations that
521 | contain the covered work, unless you entered into that arrangement,
522 | or that patent license was granted, prior to 28 March 2007.
523 |
524 | Nothing in this License shall be construed as excluding or limiting
525 | any implied license or other defenses to infringement that may
526 | otherwise be available to you under applicable patent law.
527 |
528 | 12. No Surrender of Others' Freedom.
529 |
530 | If conditions are imposed on you (whether by court order, agreement or
531 | otherwise) that contradict the conditions of this License, they do not
532 | excuse you from the conditions of this License. If you cannot convey a
533 | covered work so as to satisfy simultaneously your obligations under this
534 | License and any other pertinent obligations, then as a consequence you may
535 | not convey it at all. For example, if you agree to terms that obligate you
536 | to collect a royalty for further conveying from those to whom you convey
537 | the Program, the only way you could satisfy both those terms and this
538 | License would be to refrain entirely from conveying the Program.
539 |
540 | 13. Remote Network Interaction; Use with the GNU General Public License.
541 |
542 | Notwithstanding any other provision of this License, if you modify the
543 | Program, your modified version must prominently offer all users
544 | interacting with it remotely through a computer network (if your version
545 | supports such interaction) an opportunity to receive the Corresponding
546 | Source of your version by providing access to the Corresponding Source
547 | from a network server at no charge, through some standard or customary
548 | means of facilitating copying of software. This Corresponding Source
549 | shall include the Corresponding Source for any work covered by version 3
550 | of the GNU General Public License that is incorporated pursuant to the
551 | following paragraph.
552 |
553 | Notwithstanding any other provision of this License, you have
554 | permission to link or combine any covered work with a work licensed
555 | under version 3 of the GNU General Public License into a single
556 | combined work, and to convey the resulting work. The terms of this
557 | License will continue to apply to the part which is the covered work,
558 | but the work with which it is combined will remain governed by version
559 | 3 of the GNU General Public License.
560 |
561 | 14. Revised Versions of this License.
562 |
563 | The Free Software Foundation may publish revised and/or new versions of
564 | the GNU Affero General Public License from time to time. Such new versions
565 | will be similar in spirit to the present version, but may differ in detail to
566 | address new problems or concerns.
567 |
568 | Each version is given a distinguishing version number. If the
569 | Program specifies that a certain numbered version of the GNU Affero General
570 | Public License "or any later version" applies to it, you have the
571 | option of following the terms and conditions either of that numbered
572 | version or of any later version published by the Free Software
573 | Foundation. If the Program does not specify a version number of the
574 | GNU Affero General Public License, you may choose any version ever published
575 | by the Free Software Foundation.
576 |
577 | If the Program specifies that a proxy can decide which future
578 | versions of the GNU Affero General Public License can be used, that proxy's
579 | public statement of acceptance of a version permanently authorizes you
580 | to choose that version for the Program.
581 |
582 | Later license versions may give you additional or different
583 | permissions. However, no additional obligations are imposed on any
584 | author or copyright holder as a result of your choosing to follow a
585 | later version.
586 |
587 | 15. Disclaimer of Warranty.
588 |
589 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597 |
598 | 16. Limitation of Liability.
599 |
600 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608 | SUCH DAMAGES.
609 |
610 | 17. Interpretation of Sections 15 and 16.
611 |
612 | If the disclaimer of warranty and limitation of liability provided
613 | above cannot be given local legal effect according to their terms,
614 | reviewing courts shall apply local law that most closely approximates
615 | an absolute waiver of all civil liability in connection with the
616 | Program, unless a warranty or assumption of liability accompanies a
617 | copy of the Program in return for a fee.
618 |
619 | END OF TERMS AND CONDITIONS
620 |
621 | How to Apply These Terms to Your New Programs
622 |
623 | If you develop a new program, and you want it to be of the greatest
624 | possible use to the public, the best way to achieve this is to make it
625 | free software which everyone can redistribute and change under these terms.
626 |
627 | To do so, attach the following notices to the program. It is safest
628 | to attach them to the start of each source file to most effectively
629 | state the exclusion of warranty; and each file should have at least
630 | the "copyright" line and a pointer to where the full notice is found.
631 |
632 |     <one line to give the program's name and a brief idea of what it does.>
633 |     Copyright (C) <year>  <name of author>
634 |
635 | This program is free software: you can redistribute it and/or modify
636 | it under the terms of the GNU Affero General Public License as published
637 | by the Free Software Foundation, either version 3 of the License, or
638 | (at your option) any later version.
639 |
640 | This program is distributed in the hope that it will be useful,
641 | but WITHOUT ANY WARRANTY; without even the implied warranty of
642 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643 | GNU Affero General Public License for more details.
644 |
645 | You should have received a copy of the GNU Affero General Public License
646 | along with this program. If not, see <https://www.gnu.org/licenses/>.
647 |
648 | Also add information on how to contact you by electronic and paper mail.
649 |
650 | If your software can interact with users remotely through a computer
651 | network, you should also make sure that it provides a way for users to
652 | get its source. For example, if your program is a web application, its
653 | interface could display a "Source" link that leads users to an archive
654 | of the code. There are many ways you could offer source, and different
655 | solutions will be better for different programs; see section 13 for the
656 | specific requirements.
657 |
658 | You should also get your employer (if you work as a programmer) or school,
659 | if any, to sign a "copyright disclaimer" for the program, if necessary.
660 | For more information on this, and how to apply and follow the GNU AGPL, see
661 | <https://www.gnu.org/licenses/>.
662 |
--------------------------------------------------------------------------------
/CustomBertModel.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | import random
4 | from copy import deepcopy
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | import re
9 | import copy
10 | import numpy as np
11 | import pkuseg
12 |
13 | from collections import defaultdict
14 | from typing import Union, Tuple, List, Mapping, Any, Dict
15 |
16 | from numpy.random import shuffle
17 | from transformers.data.data_collator import _torch_collate_batch, tolist, _numpy_collate_batch, \
18 | DataCollatorForLanguageModeling
19 | from transformers import BertTokenizer, BertTokenizerFast
20 | from pypinyin import pinyin, Style
21 |
22 |
23 | class DataCollatorForMultiMask(DataCollatorForLanguageModeling):
24 | """
25 | Data collator used for language modeling that masks entire words.
26 |
27 | - collates batches of tensors, honoring their tokenizer's pad_token
28 | - preprocesses batches for masked language modeling
29 |
30 |
31 |
32 | This collator relies on details of the implementation of subword tokenization by [`BertTokenizer`], specifically
33 | that subword tokens are prefixed with *##*. For tokenizers that do not adhere to this scheme, this collator will
34 | produce an output that is roughly equivalent to [`.DataCollatorForLanguageModeling`].
35 |
36 | """
37 |
38 | def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
39 | if isinstance(examples[0], Mapping):
40 | input_ids = [e["input_ids"] for e in examples]
41 | else:
42 | input_ids = examples
43 | examples = [{"input_ids": e} for e in examples]
44 |
45 | batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
46 |
47 | mask_labels = []
48 | for e in examples:
49 | ref_tokens = []
50 | for id in tolist(e["input_ids"]):
51 | token = self.tokenizer._convert_id_to_token(id)
52 | if id == 0:
53 | continue
54 | ref_tokens.append(token)
55 |
56 |             # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
57 | if "chinese_ref" in e:
58 | ref_pos = tolist(e["chinese_ref"])
59 | len_seq = len(e["input_ids"])
60 | for i in range(len_seq):
61 | if i in ref_pos:
62 | ref_tokens[i] = "##" + ref_tokens[i]
63 | mask_labels.append(self._whole_word_mask(ref_tokens))
64 | batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
65 | inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)
66 | attention_mask = (inputs != self.tokenizer.pad_token_id).long()
67 | return {"input_ids": inputs, "labels": labels, 'attention_mask': attention_mask}
68 |
69 | def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
70 | if isinstance(examples[0], Mapping):
71 | input_ids = [e["input_ids"] for e in examples]
72 | else:
73 | input_ids = examples
74 | examples = [{"input_ids": e} for e in examples]
75 |
76 | batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
77 |
78 | mask_labels = []
79 | for e in examples:
80 | ref_tokens = []
81 | for id in tolist(e["input_ids"]):
82 | token = self.tokenizer._convert_id_to_token(id)
83 | ref_tokens.append(token)
84 |
85 |             # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
86 | if "chinese_ref" in e:
87 | ref_pos = tolist(e["chinese_ref"])
88 | len_seq = len(e["input_ids"])
89 | for i in range(len_seq):
90 | if i in ref_pos:
91 | ref_tokens[i] = "##" + ref_tokens[i]
92 | mask_labels.append(self._whole_word_mask(ref_tokens))
93 | batch_mask = _numpy_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
94 | inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
95 | return {"input_ids": inputs, "labels": labels}
96 |
97 | def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
98 | """
99 | Get 0/1 labels for masked tokens with whole word mask proxy
100 | """
101 | if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
102 | warnings.warn(
103 | "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
104 | "Please refer to the documentation for more information."
105 | )
106 |
107 | cand_indexes = []
108 | for i, token in enumerate(input_tokens):
109 | if token == "[CLS]" or token == "[SEP]":
110 | continue
111 |
112 | if len(cand_indexes) >= 1 and token.startswith("##"):
113 | cand_indexes[-1].append(i)
114 | else:
115 | cand_indexes.append([i])
116 |
117 | weighted_cand_indexes = cand_indexes.copy()
118 |         # We want masking to prefer whole words over single Chinese characters, so we upweight multi-character words
119 | for val in cand_indexes:
120 |             # If the candidate is a multi-character word, append an extra copy so it is more likely to be chosen
121 | if 2 <= len(val) <= 3:
122 | weighted_cand_indexes.append(val.copy())
123 |
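        # Example: cand_indexes [[1], [2, 3], [4, 5, 6]] (token indices grouped into words)
        # becomes weighted_cand_indexes [[1], [2, 3], [4, 5, 6], [2, 3], [4, 5, 6]],
        # so 2- and 3-token words are twice as likely to be picked by the shuffle below.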
124 | random.shuffle(weighted_cand_indexes)
125 | num_to_predict = min(max_predictions, max(2, int(round(len(input_tokens) * self.mlm_probability))))
126 | masked_lms = []
127 | covered_indexes = set()
128 | for index_set in weighted_cand_indexes:
129 | if len(masked_lms) >= num_to_predict:
130 | break
131 | # If adding a whole-word mask would exceed the maximum number of
132 | # predictions, then just skip this candidate.
133 | if len(masked_lms) + len(index_set) > num_to_predict:
134 | continue
135 | if index_set[0] - 1 in masked_lms or index_set[-1] + 1 in masked_lms:
136 | flag = random.random() > 0.5
137 | if flag:
138 | continue
139 | is_any_index_covered = False
140 | for index in index_set:
141 | if index in covered_indexes:
142 | is_any_index_covered = True
143 | break
144 | if is_any_index_covered:
145 | continue
146 | for index in index_set:
147 | covered_indexes.add(index)
148 | masked_lms.append(index)
149 |
150 | if len(covered_indexes) != len(masked_lms):
151 | raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
152 | mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
153 | return mask_labels
154 |
155 | def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
156 | # inputs [batch_size, max_len]
157 | # mask_labels [batch_size, max_len]
158 | """
159 |         Prepare masked token inputs/labels for masked language modeling. Setting 'mask_labels' means whole word
160 |         masking (wwm): indices are masked directly according to their word reference, using pinyin mask tokens.
161 | """
162 | import torch
163 |
164 | if self.tokenizer.mask_token is None:
165 | raise ValueError(
166 | "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
167 | " --mlm flag if you want to use this tokenizer."
168 | )
169 | labels = inputs.clone()
170 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
171 |
172 | probability_matrix = mask_labels
173 |
174 | special_tokens_mask = [
175 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
176 | ]
177 | try:
178 | probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
179 | except Exception:
180 | print(special_tokens_mask)
181 | if self.tokenizer._pad_token is not None:
182 | padding_mask = labels.eq(self.tokenizer.pad_token_id)
183 | probability_matrix.masked_fill_(padding_mask, value=0.0)
184 |
185 | masked_indices = probability_matrix.bool()
186 | labels[~masked_indices] = -100 # We only compute loss on masked tokens
187 |
188 |         # All masked positions are selected here (probability 1.0) and replaced below with pinyin mask tokens
189 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 1.0)).bool() & masked_indices
190 |
191 | # print(replaced_index)
192 | for index in range(len(inputs)):
193 | replaced_index = [i for i, v in enumerate(indices_replaced[index]) if v]
194 | # print(replaced_index)
195 |             # Modified masking logic: replace the single [MASK] token with per-character pinyin mask tokens,
196 |             # and drop invalid positions such as punctuation (filtered by the vocab-id ranges below)
196 | indices_total = torch.nonzero(labels[index] != -100, as_tuple=True)[0].tolist()
197 | values = labels[index][indices_total]
198 | # temp = indices_total.copy()
199 | indices_total = [x for i, x in enumerate(indices_total) if not ((200 <= values[i] <= 209)
200 | or (345 <= values[i] <= 532)
201 | or (106 <= values[i] <= 120)
202 | or (131 <= values[i] <= 142)
203 | or values[i] == 8024)]
204 | values = labels[index][indices_total]
205 | indices_total = torch.tensor(indices_total)
206 | words = []
207 | pattern = re.compile(r'[\u4e00-\u9fff]')
208 | for i, val in enumerate(indices_total):
209 | if i == 0:
210 | words.append([values[i]])
211 | continue
212 | # print(self.tokenizer.decode(values[i]).replace('#', '').replace(' ', ''))
213 | temp_word = self.tokenizer.decode(values[i]).replace('#', '').replace(' ', '')
214 | if str.isdigit(temp_word) or temp_word == '°' or str.isascii(temp_word):
215 | words.append([values[i]])
216 | continue
217 | if indices_total[i] == indices_total[i - 1] + 1 and pattern.search(temp_word):
218 | words[-1].append(values[i])
219 | else:
220 | words.append([values[i]])
221 |
222 | # [LETTER_A] ~ [LETTER_Z] = 1 ~ 26
223 | # [NUM] = 28
224 | # [SPECIAL] = 29
225 | mask_ids = []
226 | for word in words:
227 | word = self.tokenizer.decode(word).replace(' ', '').replace('#', '')
228 | first_letter = pinyin(word, style=Style.FIRST_LETTER)
229 | for letter in first_letter:
230 | if str.islower(letter[0][0]):
231 | if 1 <= (ord(letter[0][0]) - 96) <= 26:
232 | mask_ids.append(ord(letter[0][0]) - 96)
233 | else:
234 | mask_ids.append(29)
235 | else:
236 | mask_ids.append(28)
237 |
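            # Example of the mapping above: the word "不好意思" has pinyin first letters
            # [['b'], ['h'], ['y'], ['s']], which map via ord(letter) - 96 to mask_ids [2, 8, 25, 19].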
238 | try:
239 | replaced_mask = []
240 | replaced_mask_index = []
241 | for i, val in enumerate(indices_total):
242 | if val in replaced_index:
243 | replaced_mask_index.append(val.tolist())
244 | # if i >= len(mask_ids):
245 | # continue
246 | replaced_mask.append(mask_ids[i])
247 |
248 | if len(replaced_mask_index) != 0:
249 | inputs[index][replaced_mask_index] = torch.tensor(replaced_mask)
250 | except Exception:
251 | print(mask_ids)
252 | print(indices_total)
253 |
254 | return inputs, labels
255 |
256 | def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
257 | """
258 |         Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Setting
259 |         'mask_labels' means we use whole word masking (wwm): we directly mask indices according to their ref.
260 | """
261 | if self.tokenizer.mask_token is None:
262 | raise ValueError(
263 | "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
264 | " --mlm flag if you want to use this tokenizer."
265 | )
266 | labels = np.copy(inputs)
267 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
268 |
269 | masked_indices = mask_labels.astype(bool)
270 |
271 | special_tokens_mask = [
272 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
273 | ]
274 | masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
275 | if self.tokenizer._pad_token is not None:
276 | padding_mask = labels == self.tokenizer.pad_token_id
277 | masked_indices[padding_mask] = 0
278 |
279 | labels[~masked_indices] = -100 # We only compute loss on masked tokens
280 |
281 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
282 | indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
283 | inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
284 |
285 | # 10% of the time, we replace masked input tokens with random word
286 | # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
287 | indices_random = (
288 | np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
289 | )
290 | random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
291 | inputs[indices_random] = random_words[indices_random]
292 |
293 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged
294 | return inputs, labels
295 |
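# A minimal usage sketch (assumed, not part of this repo): the collator inherits
# DataCollatorForLanguageModeling's constructor, so it can be built and called like
#
#     collator = DataCollatorForMultiMask(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
#     batch = collator([{"input_ids": ids, "chinese_ref": ref}])  # ids/ref e.g. from run_chinese_ref.py
#
# where "chinese_ref" lists the positions of sub-word (##) characters used for whole word masking.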
296 |
297 | class DataCollatorForDetector(DataCollatorForLanguageModeling):
298 | """
299 |     Data collator for the detector task: masks entire words with pinyin mask tokens and labels whether they were randomly corrupted.
300 |
301 | - collates batches of tensors, honoring their tokenizer's pad_token
302 | - preprocesses batches for masked language modeling
303 |
304 |
305 |
306 | This collator relies on details of the implementation of subword tokenization by [`BertTokenizer`], specifically
307 | that subword tokens are prefixed with *##*. For tokenizers that do not adhere to this scheme, this collator will
308 | produce an output that is roughly equivalent to [`.DataCollatorForLanguageModeling`].
309 |
310 | """
311 |
312 | def __random_token(self, current_token: int):
313 | random_token = current_token
314 | while random_token == current_token or random_token == 27:
315 | random_token = random.randint(1, 29)
316 | return random_token
317 |
318 | def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
319 | if isinstance(examples[0], Mapping):
320 | input_ids = [e["input_ids"] for e in examples]
321 | else:
322 | input_ids = examples
323 | examples = [{"input_ids": e} for e in examples]
324 |
325 | batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
326 |
327 | mask_labels = []
328 | for e in examples:
329 | ref_tokens = []
330 | for id in tolist(e["input_ids"]):
331 | token = self.tokenizer._convert_id_to_token(id)
332 | if id == 0:
333 | continue
334 | ref_tokens.append(token)
335 |
336 |             # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
337 | if "chinese_ref" in e:
338 | ref_pos = tolist(e["chinese_ref"])
339 | len_seq = len(e["input_ids"])
340 | for i in range(len_seq):
341 | if i in ref_pos:
342 | ref_tokens[i] = "##" + ref_tokens[i]
343 | mask_labels.append(self._whole_word_mask(ref_tokens))
344 | batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
345 | inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)
346 | attention_mask = (inputs != self.tokenizer.pad_token_id).long()
347 | return {"input_ids": inputs, "labels": labels, 'attention_mask': attention_mask}
348 |
349 | def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
350 | if isinstance(examples[0], Mapping):
351 | input_ids = [e["input_ids"] for e in examples]
352 | else:
353 | input_ids = examples
354 | examples = [{"input_ids": e} for e in examples]
355 |
356 | batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
357 |
358 | mask_labels = []
359 | for e in examples:
360 | ref_tokens = []
361 | for id in tolist(e["input_ids"]):
362 | token = self.tokenizer._convert_id_to_token(id)
363 | ref_tokens.append(token)
364 |
365 |             # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
366 | if "chinese_ref" in e:
367 | ref_pos = tolist(e["chinese_ref"])
368 | len_seq = len(e["input_ids"])
369 | for i in range(len_seq):
370 | if i in ref_pos:
371 | ref_tokens[i] = "##" + ref_tokens[i]
372 | mask_labels.append(self._whole_word_mask(ref_tokens))
373 | batch_mask = _numpy_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
374 | inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
375 | return {"input_ids": inputs, "labels": labels}
376 |
377 | def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
378 | """
379 | Get 0/1 labels for masked tokens with whole word mask proxy
380 | """
381 | if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
382 | warnings.warn(
383 | "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
384 | "Please refer to the documentation for more information."
385 | )
386 |
387 | cand_indexes = []
388 | for i, token in enumerate(input_tokens):
389 | if token == "[CLS]" or token == "[SEP]":
390 | continue
391 |
392 | if len(cand_indexes) >= 1 and token.startswith("##"):
393 | cand_indexes[-1].append(i)
394 | else:
395 | cand_indexes.append([i])
396 |
397 | weighted_cand_indexes = cand_indexes.copy()
398 |         # We want masking to prefer whole words over single Chinese characters, so we upweight multi-character words
399 | for val in cand_indexes:
400 |             # If the candidate is a multi-character word, append extra copies so it is more likely to be chosen
401 | if 2 <= len(val) <= 3:
402 | weighted_cand_indexes.append(val.copy())
403 | weighted_cand_indexes.append(val.copy())
404 |
405 | random.shuffle(weighted_cand_indexes)
406 | num_to_predict = min(max_predictions, max(2, int(round(len(input_tokens) * self.mlm_probability))))
407 | masked_lms = []
408 | covered_indexes = set()
409 | for index_set in weighted_cand_indexes:
410 | if len(masked_lms) >= num_to_predict:
411 | break
412 | # If adding a whole-word mask would exceed the maximum number of
413 | # predictions, then just skip this candidate.
414 | if len(masked_lms) + len(index_set) > num_to_predict:
415 | continue
416 | if index_set[0] - 1 in masked_lms or index_set[-1] + 1 in masked_lms:
417 | flag = random.random() > 0.5
418 | if flag:
419 | continue
420 | is_any_index_covered = False
421 | for index in index_set:
422 | if index in covered_indexes:
423 | is_any_index_covered = True
424 | break
425 | if is_any_index_covered:
426 | continue
427 | for index in index_set:
428 | covered_indexes.add(index)
429 | masked_lms.append(index)
430 |
431 | if len(covered_indexes) != len(masked_lms):
432 | raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
433 | mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
434 | return mask_labels
435 |
436 | def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
437 | # inputs [batch_size, max_len]
438 | # mask_labels [batch_size, max_len]
439 | """
440 |         Prepare masked token inputs/labels for the detector task. Setting 'mask_labels' means whole word
441 |         masking (wwm); masked words are replaced with pinyin mask tokens that are randomly corrupted or kept.
442 | """
443 | import torch
444 |
445 | if self.tokenizer.mask_token is None:
446 | raise ValueError(
447 | "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
448 | " --mlm flag if you want to use this tokenizer."
449 | )
450 | labels = inputs.clone()
451 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
452 |
453 | probability_matrix = mask_labels
454 |
455 | special_tokens_mask = [
456 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
457 | ]
458 | try:
459 | probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
460 | except Exception:
461 | print(special_tokens_mask)
462 | if self.tokenizer._pad_token is not None:
463 | padding_mask = labels.eq(self.tokenizer.pad_token_id)
464 | probability_matrix.masked_fill_(padding_mask, value=0.0)
465 |
466 | masked_indices = probability_matrix.bool()
467 | labels[~masked_indices] = -100 # We only compute loss on masked tokens
468 |
469 |         # All masked positions are selected here (probability 1.0) and replaced below with pinyin mask tokens
470 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 1.0)).bool() & masked_indices
471 |
472 | # print(replaced_index)
473 | for index in range(len(inputs)):
474 | replaced_index = [i for i, v in enumerate(indices_replaced[index]) if v]
475 | # print(replaced_index)
476 |             # Modified masking logic: replace the single [MASK] token with per-character pinyin mask tokens,
477 |             # and drop invalid positions such as punctuation (filtered by the vocab-id ranges below)
477 | indices_total = torch.nonzero(labels[index] != -100, as_tuple=True)[0].tolist()
478 | values = labels[index][indices_total]
479 | # temp = indices_total.copy()
480 | indices_total = [x for i, x in enumerate(indices_total) if not ((200 <= values[i] <= 209)
481 | or (345 <= values[i] <= 532)
482 | or (106 <= values[i] <= 120)
483 | or (131 <= values[i] <= 142)
484 | or values[i] == 8024)]
485 | values = labels[index][indices_total]
486 | indices_total = torch.tensor(indices_total)
487 | words = []
488 | pattern = re.compile(r'[\u4e00-\u9fff]')
489 | for i, val in enumerate(indices_total):
490 | if i == 0:
491 | words.append([values[i]])
492 | continue
493 | # print(self.tokenizer.decode(values[i]).replace('#', '').replace(' ', ''))
494 | temp_word = self.tokenizer.decode(values[i]).replace('#', '').replace(' ', '')
495 | if str.isdigit(temp_word) or temp_word == '°' or str.isascii(temp_word):
496 | words.append([values[i]])
497 | continue
498 | if indices_total[i] == indices_total[i - 1] + 1 and pattern.search(temp_word):
499 | words[-1].append(values[i])
500 | else:
501 | words.append([values[i]])
502 |
503 | # [LETTER_A] ~ [LETTER_Z] = 1 ~ 26
504 | # [NUM] = 28
505 | # [SPECIAL] = 29
506 | mask_ids = []
507 | for word in words:
508 | word = self.tokenizer.decode(word).replace(' ', '').replace('#', '')
509 | first_letter = pinyin(word, style=Style.FIRST_LETTER)
510 | for letter in first_letter:
511 | if str.islower(letter[0][0]):
512 | if 1 <= (ord(letter[0][0]) - 96) <= 26:
513 | mask_ids.append(ord(letter[0][0]) - 96)
514 | else:
515 | mask_ids.append(29)
516 | else:
517 | mask_ids.append(28)
518 |
519 | try:
520 | replaced_mask = []
521 | replaced_mask_index = []
522 | for i, val in enumerate(indices_total):
523 | if val in replaced_index:
524 | replaced_mask_index.append(val.tolist())
525 | # if i >= len(mask_ids):
526 | # continue
527 | replaced_mask.append(mask_ids[i])
528 |
529 | labels[index] = torch.full(labels[index].shape, 2)
530 | if len(replaced_mask_index) != 0:
531 | word_indices = []
532 | for i, val in enumerate(replaced_mask_index):
533 | if i == 0:
534 | word_indices.append([val])
535 | continue
536 | if val == word_indices[-1][-1] + 1:
537 | word_indices[-1].append(val)
538 | else:
539 | word_indices.append([val])
540 |
541 | random_replaced_mask = deepcopy(replaced_mask)
542 | for word in word_indices:
543 | if random.random() < 0.5:
544 |                             for val in word:  # corrupt this word's tokens at their positions in the flat mask list
545 |                                 pos = replaced_mask_index.index(val); random_replaced_mask[pos] = self.__random_token(random_replaced_mask[pos])
546 | labels[index][word] = torch.ones_like(torch.tensor(word))
547 | else:
548 | labels[index][word] = torch.zeros_like(torch.tensor(word))
549 | inputs[index][replaced_mask_index] = torch.tensor(random_replaced_mask)
550 | except Exception:
551 | print("error")
552 | print(mask_ids)
553 | print(indices_total)
554 | # print(inputs, labels)
555 | return inputs, labels
556 |
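    # Resulting detector labels per position: 2 = not part of a masked word,
    # 1 = a word whose pinyin mask tokens were randomized, 0 = a word whose tokens were kept correct.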
557 |
558 | def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
559 | """
560 |         Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Setting
561 |         'mask_labels' means we use whole word masking (wwm): we directly mask indices according to their ref.
562 | """
563 | if self.tokenizer.mask_token is None:
564 | raise ValueError(
565 | "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
566 | " --mlm flag if you want to use this tokenizer."
567 | )
568 | labels = np.copy(inputs)
569 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
570 |
571 | masked_indices = mask_labels.astype(bool)
572 |
573 | special_tokens_mask = [
574 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
575 | ]
576 | masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
577 | if self.tokenizer._pad_token is not None:
578 | padding_mask = labels == self.tokenizer.pad_token_id
579 | masked_indices[padding_mask] = 0
580 |
581 | labels[~masked_indices] = -100 # We only compute loss on masked tokens
582 |
583 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
584 | indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
585 | inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
586 |
587 | # 10% of the time, we replace masked input tokens with random word
588 | # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
589 | indices_random = (
590 | np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
591 | )
592 | random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
593 | inputs[indices_random] = random_words[indices_random]
594 |
595 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged
596 | return inputs, labels
597 |
598 |
599 | __seg = pkuseg.pkuseg()  # pkuseg word segmenter; strict_mode uses it to filter implausible candidates
600 |
601 | def predict(sentence: str,
602 | predict_word: str,
603 | model,
604 | tokenizer,
605 | top_k=10,
606 | beam_size=16,
607 | threshold=0.000,
608 | fast_mode=True,
609 | strict_mode=True):
610 | replaced_word = []
611 | for letter in predict_word:
612 | if str.isdigit(letter):
613 | replaced_word.append(27)
614 | continue
615 | id = ord(letter) - 96
616 | if 1 <= id <= 26:
617 | replaced_word.append(id)
618 | else:
619 | replaced_word.append(28)
620 |
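    # Example: predict_word "kq" encodes to replaced_word [11, 17] (ord(letter) - 96);
    # digits map to 27 and any other ASCII character to 28.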
621 | inputs = tokenizer(sentence, max_length=64,
622 | padding='max_length',
623 | truncation=True,
624 | return_tensors='pt').to('cuda')
625 | index = sentence.find(predict_word)
626 | length = len(predict_word)
627 |
628 | try:
629 | inputs['input_ids'][0][index + 1:index + 1 + length] = torch.tensor(replaced_word).to('cuda')
630 | except Exception:
631 | print('error')
632 | return []
633 |
634 | results_a = []
635 | results_b = []
636 | __beam_search(results_a,
637 | beam_size=beam_size,
638 | max_depth=length,
639 | tokenizer=tokenizer,
640 | inputs=copy.deepcopy(inputs),
641 | model=model,
642 | top_k=top_k,
643 | threshold=threshold,
644 | index=index,
645 | strict_mode=strict_mode)
646 |
647 | if not fast_mode:
648 | __beam_search_back(results_b,
649 | beam_size=beam_size,
650 | max_depth=length,
651 | tokenizer=tokenizer,
652 | inputs=copy.deepcopy(inputs),
653 | model=model,
654 | top_k=top_k,
655 | threshold=threshold,
656 | index=index + length,
657 | strict_mode=strict_mode)
658 | result_dict = defaultdict(int)
659 | for val in results_a + results_b:
660 | key = val[0]
661 | value = val[1]
662 | result_dict[key] += value
663 | results = [[key, val] for key, val in result_dict.items()]
664 | results.sort(key=lambda x: x[1], reverse=True)
665 | return results
666 |
667 |
668 | def backtrack_predict(sentence: str,
669 | predict_word: str,
670 | model,
671 | tokenizer,
672 | top_k=10,
673 | fast_mode=True,
674 | strict_mode=True):
675 | replaced_word = []
676 | for letter in predict_word:
677 | if str.isdigit(letter):
678 | replaced_word.append(27)
679 | continue
680 | id = ord(letter) - 96
681 | if 1 <= id <= 26:
682 | replaced_word.append(id)
683 | else:
684 | replaced_word.append(28)
685 |
686 | inputs = tokenizer(sentence, max_length=64,
687 | padding='max_length',
688 | truncation=True,
689 | return_tensors='pt').to('cuda')
690 | index = sentence.find(predict_word)
691 | length = len(predict_word)
692 |
693 | try:
694 | inputs['input_ids'][0][index + 1:index + 1 + length] = torch.tensor(replaced_word).to('cuda')
695 | except Exception:
696 | print('error')
697 | return []
698 |
699 | results_a = []
700 | results_b = []
701 | __fixed_dfs(results=results_a,
702 | depth=length-1,
703 | probability=1.0,
704 | tokenizer=tokenizer,
705 | sentence=[],
706 | inputs=copy.deepcopy(inputs),
707 | model=model,
708 | index=index,
709 | top_k=top_k,
710 | strict_mode=strict_mode)
711 |
712 | if not fast_mode:
713 | __fixed_dfs_back(results=results_b,
714 | depth=length-1,
715 | probability=1.0,
716 | tokenizer=tokenizer,
717 | sentence=[],
718 | inputs=copy.deepcopy(inputs),
719 | model=model,
720 | index=index + length,
721 | top_k=top_k,
722 | strict_mode=strict_mode)
723 | result_dict = defaultdict(int)
724 | for val in results_a + results_b:
725 | key = val[0]
726 | value = val[1]
727 | result_dict[key] += value
728 | results = [[key, val] for key, val in result_dict.items()]
729 | results.sort(key=lambda x: x[1], reverse=True)
730 | return results
731 |
732 |
733 | def __fixed_dfs(results: list,
734 | depth: int,
735 | probability: float,
736 | tokenizer,
737 | sentence: list,
738 |                 inputs: Mapping[str, Any],
739 | model,
740 | index: int,
741 | top_k: int=5,
742 | strict_mode=True):
743 |
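    # Depth-first expansion: fill the mask at input_ids[0][index + 1] with each top-k candidate,
    # recurse with a shrinking top_k, and collect complete candidates (optionally filtered by pkuseg) at depth 0.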
744 | with torch.no_grad():
745 | logits = model(**inputs).logits
746 |
747 | # retrieve index of [MASK]
748 | mask_token_index = torch.where((inputs['input_ids'] == 103) |
749 | ((inputs['input_ids'] >= 1) &
750 | (inputs['input_ids'] <= 28)))[1].tolist()
751 | mask_token_logits = logits[0, mask_token_index, :]
752 | mask_token_probs = F.softmax(mask_token_logits, dim=-1)
753 | top_k_probs, top_k_tokens = torch.topk(mask_token_probs, top_k, dim=1)
754 | token = tokenizer.decode(top_k_tokens[0, 0:top_k]).split(' ')
755 | for i in range(len(token)):
756 | sentence.append(token[i])
757 | prob = top_k_probs[0, i].item()
758 | new_probability = probability * prob
759 |
760 | if depth == 0:
761 | if new_probability >= 0.00:
762 | if not strict_mode or len(__seg.cut(''.join(sentence))) <= max(len(sentence) - 1, 1):
763 | results.append([''.join(sentence), new_probability])
764 | else:
765 | original_value = torch.clone(inputs['input_ids'][0][index + 1])
766 | inputs['input_ids'][0][index + 1] = top_k_tokens[0, i]
767 | __fixed_dfs(results=results,
768 | depth=depth-1,
769 | probability=new_probability,
770 | tokenizer=tokenizer,
771 | sentence=sentence,
772 | inputs=inputs,
773 | model=model,
774 | index=index+1,
775 |                         top_k=max(top_k - 2, 1), strict_mode=strict_mode)
776 | inputs['input_ids'][0][index + 1] = original_value
777 | sentence.pop()
778 |
779 |
780 |
781 | def __fixed_dfs_back(results: list,
782 | depth: int,
783 | probability: float,
784 | tokenizer,
785 | sentence: list,
786 |                      inputs: Mapping[str, Any],
787 | model,
788 | index: int,
789 | top_k: int=5,
790 | strict_mode=True):
791 |
792 | with torch.no_grad():
793 | logits = model(**inputs).logits
794 |
795 | # retrieve index of [MASK]
796 | mask_token_index = []
797 | temp = torch.nonzero(inputs['input_ids'] == 103, as_tuple=True)
798 | if len(temp[0]) > 0:
799 | mask_token_index.extend(temp[1].tolist())
800 | for i in range(28):
801 | temp = torch.nonzero(inputs['input_ids'] == (i + 1), as_tuple=True)
802 | if len(temp[0]) > 0:
803 | mask_token_index.extend(temp[1].tolist())
804 | mask_token_index = sorted(mask_token_index)
805 | mask_token_logits = logits[0, mask_token_index, :]
806 | mask_token_probs = F.softmax(mask_token_logits, dim=-1)
807 | top_k_probs, top_k_tokens = torch.topk(mask_token_probs, top_k, dim=1)
808 | token = tokenizer.decode(top_k_tokens[-1, 0:top_k]).split(' ')
809 |
810 | for i in range(len(token)):
811 | sentence.insert(0, token[i])
812 | prob = top_k_probs[-1, i].item()
813 | new_probability = probability * prob
814 | if depth == 0:
815 | if new_probability >= 0.00:
816 | if not strict_mode or len(__seg.cut(''.join(sentence))) <= max(len(sentence) - 1, 1):
817 | results.append([''.join(sentence), new_probability])
818 | else:
819 | inputs['input_ids'][0][index] = top_k_tokens[-1, i]
820 | __fixed_dfs_back(results=results,
821 | depth=depth-1,
822 | probability=new_probability,
823 | tokenizer=tokenizer,
824 | sentence=sentence,
825 | inputs=copy.deepcopy(inputs),
826 | model=model,
827 | index=index-1,
828 |                                 top_k=max(top_k - 2, 1), strict_mode=strict_mode)
829 | sentence.pop(0)
830 |
831 | def __beam_search_back(results: list,
832 | beam_size: int,
833 | max_depth: int,
834 | tokenizer,
835 |                        inputs: Mapping[str, Any],
836 | model,
837 | top_k: int=5,
838 | threshold: float=0.01,
839 | index: int=0,
840 | strict_mode=True):
841 | beams = [[inputs, [], 1.0, index]]
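    # Each beam is [inputs, tokens decoded so far, score, position of the mask to fill];
    # this backward variant fills masks right to left, prepending tokens and decrementing the index.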
842 |
843 | for depth in range(max_depth):
844 | new_beams = []
845 |
846 | for inputs, sentence, probability, index in beams:
847 | with torch.no_grad():
848 | logits = model(**inputs).logits
849 |
850 | mask_token_index = torch.where(
851 | (inputs['input_ids'] == 103) |
852 | ((inputs['input_ids'] >= 1) & (inputs['input_ids'] <= 28))
853 | )[1].tolist()
854 |
855 | if not mask_token_index:
856 | continue
857 |
858 | mask_token_logits = logits[0, mask_token_index, :]
859 | mask_token_probs = F.softmax(mask_token_logits, dim=-1)
860 | top_k_probs, top_k_tokens = torch.topk(mask_token_probs, top_k, dim=1)
861 |
862 | tokens = tokenizer.decode(top_k_tokens[-1, 0:top_k]).split(' ')
863 | for i in range(len(tokens)):
864 | new_probability = probability * top_k_probs[-1, i].item()
865 |
866 | if new_probability < threshold:
867 | continue
868 |
869 | new_inputs = copy.deepcopy(inputs)
870 | new_inputs['input_ids'][0][index] = top_k_tokens[-1, i]
871 |
872 | new_sentence = sentence.copy()
873 | new_sentence.insert(0, tokens[i])
874 | new_beams.append([new_inputs, new_sentence, new_probability, index - 1])
875 | if len(new_beams) > 0:
876 | softmax(new_beams)
877 | new_beams = sorted(new_beams, key=lambda x: x[2], reverse=True)[:beam_size]
878 |
879 | if not new_beams:
880 | break
881 |
882 | if top_k > 2:
883 | top_k -= 2
884 | beams = new_beams
885 |
886 | for _, sentence, probability, _ in beams:
887 | sentence = ''.join(sentence)
888 | if not strict_mode or len(__seg.cut(sentence)) <= max(len(sentence) - 1, 1):
889 | results.append((sentence, probability))
890 |
891 | return results
892 |
893 | def __beam_search(results: list,
894 | beam_size: int,
895 | max_depth: int,
896 | tokenizer,
897 |                   inputs: Mapping[str, Any],
898 | model,
899 | top_k: int=5,
900 | threshold: float=0.01,
901 | index: int=0,
902 | strict_mode=True):
903 | beams = [[inputs, [], 1.0, index]]
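    # Each beam is [inputs, tokens decoded so far, score, current offset]; the next mask filled is
    # input_ids[0][index + 1] (the +1 skips [CLS]), and scores are renormalized by softmax() each step.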
904 |
905 | for depth in range(max_depth):
906 | new_beams = []
907 |
908 | for inputs, sentence, probability, index in beams:
909 | with torch.no_grad():
910 | logits = model(**inputs).logits
911 |
912 | mask_token_index = torch.where(
913 | (inputs['input_ids'] == 103) |
914 | ((inputs['input_ids'] >= 1) & (inputs['input_ids'] <= 28))
915 | )[1].tolist()
916 |
917 | if not mask_token_index:
918 | continue
919 |
920 | mask_token_logits = logits[0, mask_token_index, :]
921 | mask_token_probs = F.softmax(mask_token_logits, dim=-1)
922 | top_k_probs, top_k_tokens = torch.topk(mask_token_probs, top_k, dim=1)
923 |
924 | tokens = tokenizer.decode(top_k_tokens[0, 0:top_k]).split(' ')
925 | for i in range(len(tokens)):
926 | new_probability = probability * top_k_probs[0, i].item()
927 |
928 | if new_probability < threshold:
929 | continue
930 |
931 | new_inputs = copy.deepcopy(inputs)
932 | new_inputs['input_ids'][0][index + 1] = top_k_tokens[0, i]
933 |
934 | new_sentence = sentence + [tokens[i]]
935 | new_beams.append([new_inputs, new_sentence, new_probability, index + 1])
936 | if len(new_beams) > 0:
937 | softmax(new_beams)
938 | new_beams = sorted(new_beams, key=lambda x: x[2], reverse=True)[:beam_size]
939 |
940 | if not new_beams:
941 | break
942 |
943 | if top_k > 2:
944 | top_k -= 2
945 | beams = new_beams
946 |
947 | for _, sentence, probability, _ in beams:
948 | sentence = ''.join(sentence)
949 | if not strict_mode or len(__seg.cut(sentence)) <= max(len(sentence) - 1, 1):
950 | results.append((sentence, probability))
951 |
952 | return results
953 |
954 | def softmax(beams: List):
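    # Renormalizes the beam scores (row[2]) in place; subtracting the max before exp()
    # keeps the computation numerically stable for very small probabilities.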
955 | column_2 = [row[2] for row in beams]
956 |
957 | max_val = max(column_2)
958 | exp_values = [math.exp(v - max_val) for v in column_2]
959 | sum_exp = sum(exp_values)
960 | softmax_column = [v / sum_exp for v in exp_values]
961 |
962 | for i, row in enumerate(beams):
963 | row[2] = softmax_column[i]
964 |
965 | import Levenshtein  # from the 'Levenshtein' (a.k.a. python-Levenshtein) package; used to rank candidates by pinyin similarity
966 |
967 | def word_level_predict(sentence: str,
968 | predict_word: str,
969 | model,
970 | tokenizer,
971 | top_k=10,
972 | beam_size=16,
973 | threshold=0.000,
974 | fast_mode=True,
975 | strict_mode=True):
976 | if predict_word.isascii():
977 | return predict(sentence=sentence, predict_word=predict_word, model=model, tokenizer=tokenizer, top_k=top_k, threshold=threshold, fast_mode=fast_mode, beam_size=beam_size, strict_mode=strict_mode)
978 | abbr_pinyin = []
979 | full_pinyin = []
980 | for character in predict_word:
981 | if character.isascii():
982 | abbr_pinyin.append(character)
983 | full_pinyin.append(character)
984 | else:
985 | abbr_pinyin.append(pinyin(character, Style.FIRST_LETTER)[0][0])
986 | full_pinyin.append(pinyin(character, Style.NORMAL)[0][0])
987 | abbr_pinyin = ''.join(abbr_pinyin)
988 | full_pinyin = ''.join(full_pinyin)
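    # Example: predict_word "支柱" yields abbr_pinyin "zz" and full_pinyin "zhizhu";
    # the sentence is then rewritten with "zz" and decoded like an ASCII abbreviation.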
989 | sentence = sentence.replace(predict_word, abbr_pinyin)
990 | result = predict(sentence=sentence, predict_word=abbr_pinyin, model=model, tokenizer=tokenizer, top_k=top_k, threshold=threshold, fast_mode=fast_mode, beam_size=beam_size, strict_mode=strict_mode)
991 |
992 | if fast_mode:
993 | return result
994 | else:
995 | for val in result:
996 | word_pinyin = []
997 | for temp in pinyin(val[0], Style.NORMAL):
998 | word_pinyin.append(temp[0])
999 | word_pinyin = ''.join(word_pinyin)
1000 | val.append(Levenshtein.ratio(word_pinyin, full_pinyin))
1001 |         result.sort(key=lambda x: (x[2], x[1]), reverse=True)  # rank by pinyin similarity first, then by model score
1002 | return result
1003 |
--------------------------------------------------------------------------------