├── model ├── ernie_onnx │ └── pun_model.onnx └── pun_models_pytorch │ └── config.json ├── data └── iwslt2012_zh.rar ├── .idea ├── .gitignore ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── misc.xml ├── modules.xml └── punctuation_prediction.iml ├── README.md └── onnx_infer.py /model/ernie_onnx/pun_model.onnx: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/iwslt2012_zh.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiangnanboy/punctuation_prediction/HEAD/data/iwslt2012_zh.rar -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/punctuation_prediction.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
-------------------------------------------------------------------------------- /model/pun_models_pytorch/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "intermediate_size": 3072, 7 | "initializer_range": 0.02, 8 | "max_position_embeddings": 2048, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 6, 11 | "task_type_vocab_size": 16, 12 | "type_vocab_size": 4, 13 | "use_task_id": true, 14 | "vocab_size": 40000, 15 | "layer_norm_eps": 1e-05, 16 | "model_type": "ernie", 17 | "architectures": [ 18 | "ErnieForMaskedLM" 19 | ] 20 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #### 中文句子标点符号预测 2 | 3 | 对一个没有标点符号的句子预测标点,主要预测逗号、句号以及问号(,。?) 4 | 5 | ###### 给句子添加标点符号 6 | 7 | 请下载模型 [pun_model.onnx],将模型放入model/ernie_onnx目录下。 8 | 9 | 链接:https://pan.baidu.com/s/1l62YmuU3giNPkT2TonZRKA 10 | 提取码:sy12 11 | 12 | ``` 13 | def onnx_infer(sess, tokenizer, sent): 14 | tokenized_tokens = tokenizer(sent) 15 | input_ids = np.array([tokenized_tokens['input_ids']], dtype=np.int64) 16 | token_type_ids = np.array([tokenized_tokens['token_type_ids']], dtype=np.int64) 17 | result = sess.run( 18 | output_names=None, 19 | input_feed={"input_ids": input_ids, 20 | "token_type_ids": token_type_ids} 21 | )[0] 22 | return result, input_ids 23 | ``` 24 | 25 | 输出结果: 26 | ``` 27 | sent: 从小我有个梦想这个梦想是我想当一个科学家 -> result: 从小我有个梦想,这个梦想是我想当一个科学家。 28 | ------------------------------------------------ 29 | sent: 中国的首都是北京我爱我的祖国 -> result: 中国的首都是北京。我爱我的祖国。 30 | ------------------------------------------------ 31 | sent: 早上起来穿衣吃饭后我就上学了在路上碰见了许久不见的一个朋友 -> result: 早上起来,穿衣吃饭后,我就上学了,在路上碰见了许久不见的一个朋友。 32 | 33 | ``` -------------------------------------------------------------------------------- /onnx_infer.py: 
# --------------------------------------------------------------------------------
"""Punctuation prediction for unpunctuated Chinese sentences with an ONNX model.

The model assigns each token a label that indexes ``punc_list``; label 0 means
"no punctuation after this token".
"""
import re
from pathlib import Path

import numpy as np


def onnx_infer(sess, tokenizer, sent):
    """Run the punctuation model on a single sentence.

    Args:
        sess: an ``onnxruntime.InferenceSession`` (or any object with a
            compatible ``run(output_names, input_feed)`` method) whose graph
            takes int64 ``input_ids`` and ``token_type_ids`` inputs.
        tokenizer: a callable tokenizer (e.g. ``transformers.BertTokenizer``)
            returning a mapping with ``input_ids`` and ``token_type_ids``.
        sent: the sentence to punctuate, normally already passed through
            :func:`clean_text`.

    Returns:
        ``(result, input_ids)``: ``result`` is the first graph output and
        ``input_ids`` is the ``(1, seq_len)`` int64 array fed to the model.
        NOTE(review): ``post_process`` slices ``result`` as a 1-D sequence of
        per-token label ids, so the exported graph presumably already applies
        argmax and flattens the batch -- confirm against the ONNX model.
    """
    encoded = tokenizer(sent)
    input_ids = np.array([encoded['input_ids']], dtype=np.int64)
    token_type_ids = np.array([encoded['token_type_ids']], dtype=np.int64)
    result = sess.run(
        output_names=None,
        input_feed={"input_ids": input_ids, "token_type_ids": token_type_ids},
    )[0]
    return result, input_ids


def clean_text(text, punc_list):
    """Normalize *text* before inference.

    The text is lower-cased, every character that is not an ASCII letter,
    ASCII digit or CJK ideograph in the range U+4E00..U+9FA5 is dropped, and
    finally every character occurring in ``punc_list[1:]`` is removed
    (``punc_list[0]`` is the "no punctuation" label and is skipped).

    Args:
        text: raw input sentence.
        punc_list: label index -> punctuation string.

    Returns:
        The cleaned text.
    """
    text = text.lower()
    text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
    # Escape the labels so regex metacharacters such as ']', '^' or '-' are
    # matched literally, and skip the substitution entirely when there is
    # nothing to remove: an empty character class '[]' is an invalid pattern.
    punc_chars = ''.join(punc_list[1:])
    if punc_chars:
        text = re.sub(f'[{re.escape(punc_chars)}]', '', text)
    return text


# 后处理识别结果 (post-process the recognition result)
def post_process(tokenizer, input_ids, result, punc_list):
    """Rebuild the punctuated sentence from token ids and per-token labels.

    Args:
        tokenizer: object exposing ``convert_ids_to_tokens``.
        input_ids: the ``(1, seq_len)`` id array returned by
            :func:`onnx_infer`. The first and last positions (the special
            tokens BertTokenizer adds around the sentence) are dropped.
        result: per-position label ids aligned with ``input_ids[0]``; label
            ``l != 0`` appends ``punc_list[l]`` after its token.
        punc_list: label index -> punctuation string.

    Returns:
        The reconstructed text with punctuation inserted.

    Raises:
        ValueError: if the number of tokens and labels differ.
    """
    ids = input_ids[0]
    seq_len = len(ids)
    tokens = tokenizer.convert_ids_to_tokens(ids[1:seq_len - 1])
    labels = result[1:seq_len - 1].tolist()
    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would let zip() silently truncate a mismatched pair.
    if len(tokens) != len(labels):
        raise ValueError(
            f'token/label length mismatch: {len(tokens)} tokens vs {len(labels)} labels'
        )
    pieces = []
    for token, label in zip(tokens, labels):
        # WordPiece marks word-internal sub-tokens with a '##' prefix; the
        # marker is not part of the text, and clean_text() already removed
        # the spaces between words, so the piece is glued on directly.
        # NOTE(review): out-of-vocabulary characters still come back as the
        # tokenizer's unknown token rather than the original character.
        pieces.append(token[2:] if token.startswith('##') else token)
        if label != 0:
            pieces.append(punc_list[label])
    return ''.join(pieces)


def main():
    """Punctuate the sample sentences and print each result."""
    # Heavy runtime dependencies are imported here rather than at module level
    # so the pure helpers above can be imported without onnx, onnxruntime or
    # transformers installed.
    import onnx
    import onnxruntime as ort
    from transformers import BertTokenizer

    # Resolve the model directory relative to this script (the README asks for
    # the ONNX file in model/ernie_onnx) instead of a machine-specific path.
    model_dir = Path(__file__).resolve().parent / 'model'

    # load onnx model
    onnx_path = model_dir / 'ernie_onnx' / 'pun_model.onnx'
    model = onnx.load(str(onnx_path))
    # onnxruntime >= 1.9 requires `providers` to be passed explicitly on builds
    # that ship more than the CPU execution provider.
    sess = ort.InferenceSession(
        model.SerializeToString(), providers=['CPUExecutionProvider']
    )
    # load tokenizer
    tokenizer = BertTokenizer.from_pretrained(str(model_dir / 'pun_models_pytorch'))

    # Label index -> punctuation; index 0 means "no punctuation".
    punc_list = ['', ',', '。', '?']

    sentences = [
        '从小我有个梦想这个梦想是我想当一个科学家',
        '中国的首都是北京我爱我的祖国',
        '早上起来穿衣吃饭后我就上学了在路上碰见了许久不见的一个朋友',
    ]
    for sent in sentences:
        sent = clean_text(sent, punc_list)
        result, input_ids = onnx_infer(sess, tokenizer, sent)
        text = post_process(tokenizer, input_ids, result, punc_list)
        print(text)


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /.idea/inspectionProfiles/Project_Default.xml:
# --------------------------------------------------------------------------------
# 1 | 2 | 3 | 68 |
# --------------------------------------------------------------------------------