├── model
├── ernie_onnx
│ └── pun_model.onnx
└── pun_models_pytorch
│ └── config.json
├── data
└── iwslt2012_zh.rar
├── .idea
├── .gitignore
├── inspectionProfiles
│ ├── profiles_settings.xml
│ └── Project_Default.xml
├── misc.xml
├── modules.xml
└── punctuation_prediction.iml
├── README.md
└── onnx_infer.py
/model/ernie_onnx/pun_model.onnx:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data/iwslt2012_zh.rar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/punctuation_prediction/HEAD/data/iwslt2012_zh.rar
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/punctuation_prediction.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/model/pun_models_pytorch/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "hidden_dropout_prob": 0.1,
5 | "hidden_size": 768,
6 | "intermediate_size": 3072,
7 | "initializer_range": 0.02,
8 | "max_position_embeddings": 2048,
9 | "num_attention_heads": 12,
10 | "num_hidden_layers": 6,
11 | "task_type_vocab_size": 16,
12 | "type_vocab_size": 4,
13 | "use_task_id": true,
14 | "vocab_size": 40000,
15 | "layer_norm_eps": 1e-05,
16 | "model_type": "ernie",
17 | "architectures": [
18 | "ErnieForMaskedLM"
19 | ]
20 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | #### 中文句子标点符号预测
2 |
3 | 对一个没有标点符号的句子预测标点,主要预测逗号、句号以及问号(,。?)
4 |
5 | ###### 给句子添加标点符号
6 |
7 | 请下载模型 `pun_model.onnx`,并将其放入 `model/ernie_onnx` 目录下。
8 |
9 | 链接:https://pan.baidu.com/s/1l62YmuU3giNPkT2TonZRKA
10 | 提取码:sy12
11 |
12 | ```python
13 | def onnx_infer(sess, tokenizer, sent):
14 | tokenized_tokens = tokenizer(sent)
15 | input_ids = np.array([tokenized_tokens['input_ids']], dtype=np.int64)
16 | token_type_ids = np.array([tokenized_tokens['token_type_ids']], dtype=np.int64)
17 | result = sess.run(
18 | output_names=None,
19 | input_feed={"input_ids": input_ids,
20 | "token_type_ids": token_type_ids}
21 | )[0]
22 | return result, input_ids
23 | ```
24 |
25 | 输出结果:
26 | ```
27 | sent: 从小我有个梦想这个梦想是我想当一个科学家 -> result: 从小我有个梦想,这个梦想是我想当一个科学家。
28 | ------------------------------------------------
29 | sent: 中国的首都是北京我爱我的祖国 -> result: 中国的首都是北京。我爱我的祖国。
30 | ------------------------------------------------
31 | sent: 早上起来穿衣吃饭后我就上学了在路上碰见了许久不见的一个朋友 -> result: 早上起来,穿衣吃饭后,我就上学了,在路上碰见了许久不见的一个朋友。
32 |
33 | ```
--------------------------------------------------------------------------------
/onnx_infer.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from transformers import BertTokenizer
4 | import numpy as np
5 | import onnx
6 | import onnxruntime as ort
7 |
def onnx_infer(sess, tokenizer, sent):
    """Tokenize ``sent`` and run it through the ONNX session.

    Args:
        sess: Object exposing ``run(output_names=..., input_feed=...)``;
            the script passes an ``onnxruntime.InferenceSession``.
        tokenizer: Callable returning a mapping with ``input_ids`` and
            ``token_type_ids`` entries for ``sent``.
        sent: Text to feed to the model.

    Returns:
        A ``(result, input_ids)`` tuple: the first output returned by
        ``sess.run`` and the ``int64`` array of token ids that was fed in.
    """
    encoded = tokenizer(sent)

    def as_batch(key):
        # Wrapping in a list adds the leading batch axis (batch size 1).
        return np.array([encoded[key]], dtype=np.int64)

    input_ids = as_batch('input_ids')
    feed = {
        "input_ids": input_ids,
        "token_type_ids": as_batch('token_type_ids'),
    }
    outputs = sess.run(output_names=None, input_feed=feed)
    return outputs[0], input_ids
18 |
def clean_text(text, punc_list):
    """Normalize ``text`` before it is fed to the punctuation model.

    The text is lower-cased, every character that is not an ASCII letter,
    an ASCII digit or in the range U+4E00-U+9FA5 is dropped, and finally
    any character belonging to the punctuation marks in ``punc_list`` is
    removed.

    Args:
        text: Raw input sentence.
        punc_list: Punctuation labels; the first entry is skipped (the
            script uses it as the empty "no punctuation" label).

    Returns:
        The cleaned string.
    """
    text = text.lower()
    text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
    # Escape the marks so characters such as '^', ']' or '-' are matched
    # literally instead of changing the meaning of the character class.
    punc_chars = re.escape(''.join(list(punc_list)[1:]))
    # An empty character class "[]" is not a valid pattern, so skip the
    # pass entirely when there is nothing to strip.
    if punc_chars:
        text = re.sub(f'[{punc_chars}]', '', text)
    return text
24 |
def post_process(tokenizer, input_ids, result, punc_list):
    """Post-process the recognition result into punctuated text.

    The first and last positions of the sequence are skipped on both the
    token ids and the predicted labels; each remaining token is emitted,
    followed by ``punc_list[label]`` whenever its label is non-zero.

    Args:
        tokenizer: Object providing ``convert_ids_to_tokens``.
        input_ids: Batched token ids; only ``input_ids[0]`` is used.
        result: Per-position labels supporting slicing and ``tolist()``.
        punc_list: Punctuation strings indexed by label.

    Returns:
        The tokens joined into one string with punctuation inserted.
    """
    first_row = input_ids[0]
    last = len(first_row) - 1
    tokens = tokenizer.convert_ids_to_tokens(first_row[1:last])
    labels = result[1:last].tolist()
    assert len(tokens) == len(labels)

    pieces = []
    for token, label in zip(tokens, labels):
        pieces.append(token)
        if label != 0:
            pieces.append(punc_list[label])
    return ''.join(pieces)
37 |
if __name__ == '__main__':
    from pathlib import Path

    # Resolve the model locations relative to this script so the demo runs
    # from any checkout instead of one developer's absolute Windows path.
    base_dir = Path(__file__).resolve().parent
    onnx_path = base_dir / 'model' / 'ernie_onnx' / 'pun_model.onnx'
    tokenizer_dir = base_dir / 'model' / 'pun_models_pytorch'

    # Load the ONNX model and create an inference session from its bytes.
    model = onnx.load(str(onnx_path))
    sess = ort.InferenceSession(bytes(model.SerializeToString()))
    # Load the tokenizer.
    tokenizer = BertTokenizer.from_pretrained(str(tokenizer_dir))
    # Punctuation strings indexed by predicted label; label 0 adds nothing.
    punc_list = ['', ',', '。', '?']

    # Run every demo sentence; previously each assignment overwrote the
    # last one, so only the final sentence was ever processed.
    sentences = [
        '从小我有个梦想这个梦想是我想当一个科学家',
        '中国的首都是北京我爱我的祖国',
        '早上起来穿衣吃饭后我就上学了在路上碰见了许久不见的一个朋友',
    ]
    for sent in sentences:
        sent = clean_text(sent, punc_list)
        result, input_ids = onnx_infer(sess, tokenizer, sent)
        text = post_process(tokenizer, input_ids, result, punc_list)
        print(text)
59 |
60 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
66 |
67 |
68 |
--------------------------------------------------------------------------------