├── .gitignore
├── README.md
├── configs
│   ├── __init__.py
│   ├── bert_config.yml
│   ├── bilstm_config.yml
│   ├── ccks2017.yml
│   ├── ccks2019.yml
│   ├── config.yml
│   └── confighelper.py
├── dataset
│   ├── conll.py
│   ├── embedding.py
│   ├── preprocess
│   │   ├── __init__.py
│   │   ├── preprocess_ccks2017.py
│   │   └── preprocess_ccks2019.py
│   ├── processor.py
│   └── utils.py
├── db.sqlite3
├── main.py
├── manage.py
├── model
│   ├── base.py
│   ├── bert.py
│   ├── bilstm.py
│   └── crf.py
├── port
│   ├── settings.py
│   ├── urls.py
│   ├── view.py
│   └── wsgi.py
├── requirements.txt
├── train.sh
├── train
│   ├── device.py
│   ├── eval.py
│   ├── optimizer.py
│   ├── plot.py
│   ├── pretrain.py
│   └── trainer.py
└── utils
    └── datautils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | datasets/
3 | logs/
4 | output/
5 | papers/
6 | *.pdf
7 | 
8 | *.pyc
9 | .ipynb_checkpoints
10 | .idea/
11 | runs/
12 | .vscode/
13 | tmp/
14 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PytorchNER_zh
2 | This is a named-entity-recognition sub-project of my ehr-journey project. It implements BERT and BiLSTM networks fine-tuned from pretrained Chinese character/word vectors, demonstrates them on the CCKS2019 Task 1 dataset, and provides a Django API.
3 | 
4 | I am still new to NLP and this is only a toy project for now; corrections are very welcome.
5 | 
6 | 
7 | 
8 | ## 1 Dependencies
9 | 
10 | - numpy==1.16.4
11 | - gensim==3.8.0
12 | - pytorch-transformers==1.1.0
13 | - torch==1.1.0
14 | - TorchSnooper==0.7
15 | - Django==2.0.5
16 | - scikit-learn==0.21.3
17 | - tqdm==4.23.4
18 | 
19 | 
20 | 
21 | ## 2 Training dataset
22 | 
23 | CCKS2019 Task 1: named entity recognition for Chinese electronic medical records. [Data](http://openkg.cn/dataset/yiducloud-ccks2019task1)
24 | 
25 | 
26 | 
27 | ## 3 Models
28 | 
29 | Five models are implemented: BiLSTM, BiLSTMCRF, Bert, BertCRF and BertBiLSTMCRF.
30 | 
31 | - The Bert part follows [pytorch_transformers](https://github.com/huggingface/pytorch-transformers); the pretrained model is [Chinese pretrained BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm).
32 | - The BiLSTM uses pretrained word2vec embeddings (Baidu Encyclopedia, Word + Character + Ngram, 300d), which can be downloaded from [Chinese Word Vectors](https://github.com/Embedding/Chinese-Word-Vectors).
33 | 
34 | - The CRF layer is adapted from [SLTK](https://github.com/liu-nlper/SLTK).
35 | 
36 | 
37 | 
38 | ### Training
39 | 
40 | The configuration files live under `configs`. Download the dataset, the pretrained model and the word vectors, place them in the configured paths, adjust the parameters, then run `train.sh`.
41 | 
42 | 
43 | 
44 | ### Prediction
45 | 
46 | After training finishes, set the `model_class` parameter in `config.yml` and run
47 | 
48 | ```shell
49 | python main.py --task eval
50 | ```
51 | 
52 | 
53 | 
54 | ## 4 Results
55 | 
56 | | labels/F1-score | BiLSTM | BiLSTMCRF | Bert | BertCRF | BertBiLSTMCRF | support |
57 | | --------------- | ------ | --------- | ---- | -------- | ------------- | ------- |
58 | | O | 1.00 | 1.00 | 0.95 | 1.00 | **1.00** | 386687 |
59 | | B-LABCHECK | 0.86 | 0.86 | 0.40 | 0.86 | **0.90** | 227 |
60 | | I-LABCHECK | 0.87 | 0.87 | 0.45 | 0.91 | **0.93** | 692 |
61 | | B-PICCHECK | 0.83 | 0.83 | 0.35 | 0.86 | **0.87** | 185 |
62 | | I-PICCHECK | 0.84 | 0.85 | 0.32 | **0.90** | **0.90** | 525 |
63 | | B-SURGERY | 0.79 | **0.86** | 0.32 | **0.86** | **0.86** | 225 |
64 | | I-SURGERY | 0.93 | 0.93 | 0.43 | 0.95 | **0.96** | 2386 |
65 | | B-DISEASE | 0.81 | 0.83 | 0.34 | 0.85 | **0.86** | 814 |
66 | | I-DISEASE | 0.83 | 0.84 | 0.36 | 0.85 | **0.87** | 5306 |
67 | | B-DRUGS | 0.88 | 0.90 | 0.31 | 0.94 | **0.95** | 354 |
68 | | I-DRUGS | 0.89 | 0.91 | 0.40 | **0.94** | **0.94** | 954 |
69 | | B-ANABODY | 0.88 | 0.88 | 0.73 | **0.91** | 0.90 | 1636 |
70 | | I-ANABODY | 0.82 | 0.85 | 0.44 | **0.86** | **0.86** | 2697 |
71 | | macro avg | 0.86 | 0.88 | 0.44 | 0.90 | **0.91** | 402688 |
72 | 
73 | 
74 | 
75 | ## 5 API
76 | 
77 | Start the server with
78 | 
79 | ```shell
80 | python manage.py runserver 0.0.0.0:8000
81 | ```
82 | 
83 | Request payload
84 | 
85 | ```json
86 | [{"sentence":"入院后完善相关辅助检查,给予口服活血止痛、调节血压药物及物理治疗,患者血脂异常,补充诊断:混合性高脂血症,给予调节血脂药物治疗;患者诉心慌、无力,急查心电图提示:心房颤动,ST段改变。急请内科会诊,考虑为:1.冠心病 不稳定型心绞痛 心律失常 室性期前收缩 房性期前收缩 心房颤动;2.高血压病3级 极高危组。给予处理:1.急查心肌酶学、离子,定期复查心电图;2.给予持续心电、血压、血氧监测3.给予吸氧、西地兰0.2mg加5%葡萄糖注射液15ml稀释后缓慢静推,给予硝酸甘油10mg加入5%葡萄糖注射液500ml以5~10ugmin缓慢静点,继续口服阿司匹林100mg日一次,辛伐他汀20mg日一次,硝酸异山梨酯10mg日三次口服,稳心颗粒1袋日三次,美托洛尔12.5mg日二次,非洛地平5mg日一次治疗,患者病情好转出院。","model_class":["BertBiLSTMCRF"],"dataset": "CCKS2019"}]
87 | ```
88 | 
89 | Response
90 | 
91 | ```json
92 | '{"手术": [], "影像检查": ["心电图", "心电图"], "解剖部位": ["心"], "疾病和诊断": ["混合性高脂血症", "心房颤动", "t段", "冠心病", "不稳定型心绞痛", "心律失常", "室性期前收缩", "房性期前收缩", "心房颤动", "高血压病3级极高危组"], "实验室检验": [], "药物": ["西地兰", "葡萄糖", "硝酸甘油", "葡萄糖", "阿司匹林", "辛伐他汀", "硝酸异山梨酯", "稳心颗粒", "美托洛尔", "非洛地平"]}'
93 | ```
94 | 
95 | > The `sentence` in this example is taken from the CCKS2017 data.
96 | 
97 | 
98 | 
99 | ## 6 Summary
100 | 
101 | 1. For the Chinese electronic-medical-record named-entity task, this project implements BERT- and BiLSTM-based NER models together with a serving API.
102 | 
103 | 
104 | 
105 | ### Future work
106 | 
107 | 1. Judging from the API output, performance drops when the input text comes from a different hospital than the training data; enlarging the training set and trying domain-adaptive training may improve the model.
108 | 2. The plain Bert scores suggest that something may still be wrong in its training procedure; this will be fixed step by step.
109 | 3. Consider introducing medical knowledge-graph embeddings, and augmenting the current data with synonyms and same-category terms to see how it affects the results.
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kenshinpg/PytorchNER_zh/85c46c40a5482029130f65e2fdf8e003729dac24/configs/__init__.py
--------------------------------------------------------------------------------
/configs/bert_config.yml:
--------------------------------------------------------------------------------
1 | 
2 | # coding: utf-8
3 | 
4 | 
5 | 
6 | 
7 | # global
8 | name: bert-pytorch
9 | model_class: [Bert, BertCRF, BertBiLSTMCRF]
10 | 
11 | # path
12 | data_dir: ./data/
13 | finetune_model_dir: ./tmp/models/
14 | pretrained_model_dir: ../pretrained/chinese_wwm_ext_pytorch
15 | output_dir: ./tmp/outputs/
16 | 
17 | # pretrain
18 | use_pretrained_embedding: True
19 | pretrain_embed_file: ../pretrained/sgns.baidubaike.bigram-char/sgns.baidubaike.bigram-char
20 | pretrain_embed_pkl: ./tmp/outputs/pretrain_word_embeddings.pkl
21 | requires_grad: True
22 | 
23 | # device
24 | use_cuda: True
25 | gpu_memory: 8192
26 | 
27 | # model
28 | dropout_rate: 0.2
29 | random_seed: 1301
30 | max_seq_length: 256
31 | lower_case: True
32 | 
33 | # optimizer
34 | optimizer_type: adamw
35 | learning_rate: !!float 5e-5
36 | warmup_proportion: 0.1
37 | max_grad_norm: 1.0
38 | weight_decay: 0.01
39 | l2_rate: 1.0e-8
40 | momentum: 0.
41 | lr_decay: 0.05 42 | 43 | # epoch 44 | batch_size: 12 45 | nb_epoch: 15 46 | save_checkpoint: False 47 | average_batch: False 48 | 49 | # early stopping 50 | max_patience: 5 51 | 52 | # rnn 53 | bert_embedding: 768 54 | rnn_hidden: 100 55 | rnn_layers: 1 56 | 57 | # eval 58 | evalset: test -------------------------------------------------------------------------------- /configs/bilstm_config.yml: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # global 5 | name: bert-pytorch 6 | model_class: [BiLSTM, BiLSTMCRF] 7 | 8 | # path 9 | data_dir: ./data/ 10 | finetune_model_dir: ./tmp/models/ 11 | pretrained_model_dir: ../pretrained/chinese_wwm_ext_pytorch 12 | output_dir: ./tmp/outputs/ 13 | 14 | # pretrain 15 | use_pretrained_embedding: True 16 | pretrain_embed_file: ../pretrained/sgns.baidubaike.bigram-char/sgns.baidubaike.bigram-char 17 | pretrain_embed_pkl: ./tmp/outputs/pretrain_word_embeddings.pkl 18 | requires_grad: True 19 | 20 | # device 21 | use_cuda: True 22 | gpu_memory: 8192 23 | 24 | # model 25 | dropout_rate: 0.2 26 | random_seed: 1301 27 | max_seq_length: 256 28 | lower_case: True 29 | 30 | # optimizer 31 | optimizer_type: adamw 32 | learning_rate: !!float 1e-3 33 | warmup_proportion: 0.1 34 | max_grad_norm: 1.0 35 | weight_decay: 0.01 36 | l2_rate: 1.0e-8 37 | momentum: 0. 38 | lr_decay: 0.05 39 | 40 | # epoch 41 | # batch_size: 12 42 | batch_size: 128 43 | nb_epoch: 100 44 | save_checkpoint: False 45 | average_batch: False 46 | 47 | # early stopping 48 | max_patience: 5 49 | 50 | # rnn 51 | bert_embedding: 768 52 | rnn_hidden: 100 53 | rnn_layers: 1 54 | 55 | # eval 56 | evalset: test -------------------------------------------------------------------------------- /configs/ccks2017.yml: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | dataset: ccks2017 5 | 6 | # 数据相关参数 7 | data_path: ./data/ccks2017/ 8 | txt_path: ../datasets/ccks2017/ 9 | label_file: ./data/ccks2017/conll.txt 10 | 11 | label: 12 | class2label: 13 | 检查和检验: CHECK 14 | 症状和体征: SIGNS 15 | 疾病和诊断: DISEASE 16 | 治疗: TREATMENT 17 | 身体部位: BODY 18 | label2id: 19 | O: 1 20 | B-CHECK: 2 21 | I-CHECK: 3 22 | B-SIGNS: 4 23 | I-SIGNS: 5 24 | B-DISEASE: 6 25 | I-DISEASE: 7 26 | B-TREATMENT: 8 27 | I-TREATMENT: 9 28 | B-BODY: 10 29 | I-BODY: 11 -------------------------------------------------------------------------------- /configs/ccks2019.yml: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | dataset: ccks2019 5 | 6 | # 数据相关参数 7 | data_path: ./data/ccks2019/ 8 | txt_path: ../datasets/ccks2019/ 9 | label_file: ./data/ccks2019/conll.txt 10 | 11 | label: 12 | class2label: 13 | 实验室检验: LABCHECK 14 | 影像检查: PICCHECK 15 | 手术: SURGERY 16 | 疾病和诊断: DISEASE 17 | 药物: DRUGS 18 | 解剖部位: ANABODY 19 | label2id: 20 | O: 1 21 | B-LABCHECK: 2 22 | I-LABCHECK: 3 23 | B-PICCHECK: 4 24 | I-PICCHECK: 5 25 | B-SURGERY: 6 26 | I-SURGERY: 7 27 | B-DISEASE: 8 28 | I-DISEASE: 9 29 | B-DRUGS: 10 30 | I-DRUGS: 11 31 | B-ANABODY: 12 32 | I-ANABODY: 13 -------------------------------------------------------------------------------- /configs/config.yml: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | 5 | # global 6 | name: bert-pytorch 7 | model_class: [Bert, BertCRF, BertBiLSTMCRF] 8 | 9 | # path 10 | data_dir: ./data/ 11 | finetune_model_dir: ./tmp/models/ 12 | pretrained_model_dir: 
D:/DataWarehouse/pretrain/chinese_wwm_ext_pytorch 13 | output_dir: ./tmp/outputs/ 14 | 15 | # pretrain 16 | use_pretrained_embedding: True 17 | pretrain_embed_file: D:/DataWarehouse/pretrain/sgns.baidubaike.bigram-char/sgns.baidubaike.bigram-char 18 | pretrain_embed_pkl: ./tmp/outputs/pretrain_word_embeddings.pkl 19 | requires_grad: True 20 | 21 | # device 22 | use_cuda: True 23 | gpu_memory: 8192 24 | 25 | # model 26 | dropout_rate: 0.2 27 | random_seed: 1301 28 | max_seq_length: 256 29 | lower_case: True 30 | 31 | # optimizer 32 | optimizer_type: adamw 33 | learning_rate: !!float 5e-5 34 | # learning_rate: !!float 1e-3 35 | warmup_proportion: 0.1 36 | max_grad_norm: 1.0 37 | weight_decay: 0.01 38 | l2_rate: 1.0e-8 39 | momentum: 0. 40 | lr_decay: 0.05 41 | 42 | # epoch 43 | batch_size: 12 44 | # batch_size: 128 45 | nb_epoch: 15 46 | save_checkpoint: False 47 | average_batch: False 48 | 49 | # early stopping 50 | max_patience: 5 51 | 52 | # rnn 53 | bert_embedding: 768 54 | rnn_hidden: 100 55 | rnn_layers: 1 56 | 57 | # eval 58 | evalset: test -------------------------------------------------------------------------------- /configs/confighelper.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import os 5 | import yaml 6 | import codecs 7 | import argparse 8 | 9 | def read_yml(yml_file): 10 | return yaml.load(codecs.open(yml_file, encoding = 'utf-8')) 11 | 12 | def config_loader(config_file = './configs/cner.yml'): 13 | return read_yml(config_file) 14 | 15 | # def config_update(args, configs): 16 | 17 | 18 | 19 | def args_parser(): 20 | # start parser 21 | parser = argparse.ArgumentParser() 22 | # required parameters 23 | parser.add_argument("--config_path", default='./configs/config.yml', type=str) 24 | parser.add_argument("--dataset", default='CCKS2019', type=str, help="dataset name") 25 | parser.add_argument("--task", default='eval', type=str, help = "task type, train/eval/conll") 26 | # parser.add_argument("--output_dir", default=None, 27 | # type=str, required=True, help="the outptu directory where the model predictions and checkpoints will") 28 | 29 | # # other parameters 30 | # parser.add_argument("--use_cuda", type=bool, default=True) 31 | # parser.add_argument("--max_len_limit", default=100, 32 | # type=int, help="the maximum total input sequence length after ") 33 | 34 | # parser.add_argument("--hidden_dim", default=100, type=int) 35 | # parser.add_argument("--num_rnn_layers", default=1, type=int) 36 | # parser.add_argument("--bi_flag", default = True, type = bool) 37 | 38 | # parser.add_argument("--batch_size", default=32, type=int) 39 | # parser.add_argument("--average_batch", default=False, type=bool) 40 | # parser.add_argument("--optimizer", default='sgd', type=str) 41 | # parser.add_argument("--use_pretrained_embedding", default=True, type=bool) 42 | # parser.add_argument("--max_patience", default=50, type=int) 43 | 44 | # parser.add_argument("--test_batch_size", default=8, type=int) 45 | # parser.add_argument("--learning_rate", default=0.015, type=float) 46 | # parser.add_argument("--nb_epoch", default=1000, type=float) 47 | # parser.add_argument("--random_seed", type=int, default=1301) 48 | # parser.add_argument("--export_model", type=bool, default=True) 49 | # parser.add_argument("--output_dir", type=str, default='./output/') 50 | 51 | args = parser.parse_args() 52 | 53 | # os.makedirs(args.output_dir, exist_ok=True) 54 | 55 | return args 56 | 
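
The two helpers above are the glue that `main.py` (further down in this dump) uses: the model YAML and the dataset YAML are loaded separately and carried around as plain dicts. A minimal usage sketch follows, assuming the repository's own config files are in place; note that the bare `yaml.load` call above needs an explicit `Loader=` argument on newer PyYAML releases (the pinned requirements predate that).

```python
# Minimal sketch (not part of the original file): how main.py combines these helpers.
from configs.confighelper import args_parser, config_loader

args = args_parser()                                      # --config_path / --dataset / --task
model_configs = config_loader(args.config_path)           # e.g. ./configs/config.yml
conll_configs = config_loader('./configs/ccks2019.yml')   # label maps and data paths

print(model_configs['model_class'])                       # ['Bert', 'BertCRF', 'BertBiLSTMCRF']
print(conll_configs['label']['label2id']['B-DISEASE'])    # 8
```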
-------------------------------------------------------------------------------- /dataset/conll.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import os 5 | import math 6 | import random 7 | import numpy as np 8 | 9 | def get_conll_data(data_file): 10 | data = [] 11 | with open(data_file, 'r', encoding = 'utf-8') as f: 12 | chars = [] 13 | labels = [] 14 | for line in f: 15 | line = line.rstrip().split('\t') 16 | if not line: 17 | continue 18 | char = line[0] 19 | if not char: 20 | continue 21 | label = line[-1] 22 | chars.append(char) 23 | labels.append(label) 24 | if char in ['。','?','!','!','?']: 25 | data.append([chars, labels]) 26 | chars = [] 27 | labels = [] 28 | return data 29 | 30 | def shuffle(data_ids, random_seed = 1301): 31 | random.seed(random_seed) 32 | random.shuffle(data_ids) 33 | 34 | def conll_to_train_test_dev(data_file, output_dir, random_seed = 1301, validation = 0.2): 35 | data = get_conll_data(data_file) 36 | data_count = len(data) 37 | data_ids = list(range(data_count)) 38 | 39 | dev_div = validation 40 | test_div = validation 41 | train_div = 1 - dev_div - test_div 42 | 43 | # 因为数据集可能会出现按顺序分布不均的情况,所以需要对data_ids进行shuffle 44 | # 比如ccks2019数据集就有病史特点,出院情况,一般项目,诊疗经过 45 | shuffle(data_ids, random_seed) 46 | 47 | train_ids = data_ids[:math.ceil(data_count * train_div)] 48 | dev_ids = data_ids[math.ceil(data_count * train_div):math.ceil(data_count * (1 - test_div))] 49 | test_ids = data_ids[math.ceil(data_count * (1 - test_div)):] 50 | train_data = np.array(data)[train_ids] 51 | dev_data = np.array(data)[dev_ids] 52 | test_data = np.array(data)[test_ids] 53 | 54 | with open(os.path.join(output_dir, 'train.txt'), 'w', encoding = 'utf-8') as f: 55 | for index, (chars, labels) in enumerate(train_data): 56 | for char, label in zip(chars, labels): 57 | f.write(char + '\t' + label + '\n') 58 | f.close() 59 | with open(os.path.join(output_dir, 'dev.txt'), 'w', encoding = 'utf-8') as f: 60 | for index, (chars, labels) in enumerate(dev_data): 61 | for char, label in zip(chars, labels): 62 | f.write(char + '\t' + label + '\n') 63 | f.close() 64 | with open(os.path.join(output_dir, 'test.txt'), 'w', encoding = 'utf-8') as f: 65 | for index, (chars, labels) in enumerate(test_data): 66 | for char, label in zip(chars, labels): 67 | f.write(char + '\t' + label + '\n') 68 | f.close() 69 | 70 | def oov_to_vocab(tok, tokenizer): 71 | if tok in tokenizer.vocab: 72 | return tok 73 | elif tok.lower() in tokenizer.vocab: 74 | return tok.lower() 75 | elif tok == '“': 76 | return '"' 77 | elif tok == '”': 78 | return '"' 79 | else: 80 | return '[UNK]' -------------------------------------------------------------------------------- /dataset/embedding.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import os 5 | import numpy as np 6 | from gensim.test.utils import datapath, get_tmpfile 7 | from gensim.models import KeyedVectors 8 | from gensim.scripts.glove2word2vec import glove2word2vec 9 | 10 | from configs.confighelper import config_loader 11 | from utils.datautils import pickle_load, pickle_save 12 | 13 | def glove_to_gensim(glove_embed, gensim_embed): 14 | """ 15 | 将glove的embedding转化为用gensim可读取的格式 16 | glove_embed: stanford官网下载的glove.6B.txt 17 | new_word2vec_embed: 可用gensim读取的embedding_file 18 | """ 19 | # 输入文件 20 | glove_file = datapath(glove_embed) 21 | # 输出文件 22 | tmp_file = get_tmpfile(gensim_embed) 23 | # 开始转换 24 | glove2word2vec(glove_file, tmp_file) 
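    # The GloVe text format and the word2vec text format differ only in that word2vec
    # files begin with a "vocab_size vector_dim" header line; glove2word2vec prepends
    # that header and copies the vectors unchanged, so the converted file can then be
    # read with KeyedVectors.load_word2vec_format (see load_embed_with_gensim below).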
25 | 26 | # # 加载转化后的文件 27 | # model = KeyedVectors.load_word2vec_format(tmp_file) 28 | 29 | def load_embed_with_gensim(path_embed, binary = False): 30 | """ 31 | 读取预训练的embedding 32 | binary = True 二进制embedding 33 | """ 34 | return KeyedVectors.load_word2vec_format(path_embed, binary=binary) 35 | 36 | def get_pretrained_embedding(pretrain_embed_file, pretrain_embed_pkl): 37 | if os.path.exists(pretrain_embed_pkl): 38 | word_vectors = pickle_load(pretrain_embed_pkl) 39 | else: 40 | word_vectors = load_embed_with_gensim(pretrain_embed_file) 41 | pickle_save(word_vectors, pretrain_embed_pkl) 42 | return word_vectors 43 | 44 | # def get_stoi_from_tokenizer(tokenizer): 45 | # """ 46 | # 从BertTokenizer得到word2id_dict 47 | # """ 48 | 49 | 50 | # return word2id_dict 51 | 52 | def build_word_embed(tokenizer, pretrain_embed_file, pretrain_embed_pkl, seed=1301): 53 | """ 54 | 从预训练的文件中构建word embedding表 55 | Args: 56 | tokenizer: BertTokenizer 57 | Returns: 58 | word_embed_table: np.array, shape=[word_count, embed_dim] 59 | match_count: int, 匹配的词数 60 | unknown_count: int, 未匹配的词数 61 | """ 62 | word_vectors = get_pretrained_embedding(pretrain_embed_file, pretrain_embed_pkl) 63 | word_dim = word_vectors.vector_size 64 | word_count = tokenizer.vocab_size # 0 is for padding value 65 | np.random.seed(seed) 66 | scope = np.sqrt(3. / word_dim) 67 | word_embed_table = np.random.uniform( 68 | -scope, scope, size=(word_count, word_dim)).astype('float32') 69 | # match_count, unknown_count = 0, 0 70 | # for word in word2id_dict: 71 | # if word in word_vectors.vocab: 72 | # word_embed_table[word2id_dict[word]] = word_vectors[word] 73 | # match_count += 1 74 | # else: 75 | # unknown_count += 1 76 | # total_count = match_count + unknown_count 77 | # print('\tmatch: {0} / {1}'.format(match_count, total_count)) 78 | # print('\tOOV: {0} / {1}'.format(unknown_count, total_count)) 79 | return word_embed_table 80 | -------------------------------------------------------------------------------- /dataset/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | from .preprocess_ccks2019 import CCKS2019NER 4 | from .preprocess_ccks2017 import CCKS2017NER 5 | -------------------------------------------------------------------------------- /dataset/preprocess/preprocess_ccks2017.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import re 5 | import os 6 | import math 7 | import random 8 | import numpy as np 9 | 10 | from collections import Counter 11 | from utils.datautils import check_dir 12 | 13 | class CCKS2017NER(object): 14 | 15 | def __init__(self, configs, vocab_size = None, min_freq = 1, random_seed = 1301): 16 | 17 | self.vocab_size = vocab_size 18 | self.min_freq = min_freq 19 | self.random_seed = random_seed 20 | self.trainfile = configs['label_file'] 21 | self.configs = configs 22 | self.origin_datapath = os.path.join(configs['txt_path'], 'origin_data') 23 | check_dir(configs['data_path']) 24 | 25 | self.txtloader() 26 | self.label_dict = configs['label']['label2id'] 27 | self.class_dict = configs['label']['class2label'] 28 | self.txtnerlabel() 29 | self.nervocab() 30 | self.get_raw_data() 31 | self.get_train_data() 32 | 33 | self.data_count = len(self.data) 34 | self.data_ids = list(range(self.data_count)) 35 | self.train_test_split() 36 | 37 | def txtloader(self): 38 | self.originalText = [] 39 | self.entities = [] 40 | for root,dirs,files in os.walk(self.origin_datapath): 41 
| for file in files: 42 | filepath = os.path.join(root, file) 43 | if 'txtoriginal' not in filepath: 44 | with open(filepath, 'r', encoding = 'utf-8') as f: 45 | entities = [] 46 | for line in f: 47 | res = line.strip().split(' ') 48 | entity = {} 49 | entity['start_pos'] = int(res[1]) 50 | entity['end_pos'] = int(res[2])+1 51 | entity['label_type'] = res[3] 52 | entities.append(entity) 53 | f.close() 54 | self.entities.append(entities) 55 | else: 56 | text = re.sub('\n', '', open(filepath, encoding = 'UTF-8').read()) 57 | self.originalText.append(text) 58 | 59 | def txtnerlabel(self): 60 | if not os.path.exists(self.trainfile): 61 | with open(self.trainfile, 'w', encoding = 'utf-8') as f: 62 | for i in range(len(self.originalText)): 63 | text = self.originalText[i] 64 | res_dict = {} 65 | for e in self.entities[i]: 66 | start = e['start_pos'] 67 | end = e['end_pos'] 68 | label = self.configs['label']['class2label'][e['label_type']] 69 | for i in range(start, end): 70 | if i == start: 71 | label_cate = 'B-' + label 72 | else: 73 | label_cate = 'I-' + label 74 | res_dict[i] = label_cate 75 | for indx, char in enumerate(text): 76 | char_label = res_dict.get(indx, 'O') 77 | f.write(char + '\t' + char_label + '\n') 78 | # 保证每条文本末尾都以中文句号结尾 79 | if indx == len(text)-1 and char not in ['。','?','!','!','?']: 80 | f.write('。' + '\t' + 'O' + '\n') 81 | f.close() 82 | 83 | def nervocab(self): 84 | """ 85 | 获得NER所需要的字特征 86 | """ 87 | if not os.path.exists(self.trainfile): 88 | self.txtnerlabel() 89 | words = [] 90 | counter = Counter() 91 | with open(self.trainfile, 'r', encoding = 'utf-8') as f: 92 | for line in f: 93 | words.append(line[0]) 94 | f.close() 95 | for word in words: 96 | counter[word] += 1 97 | 98 | # 将词token按词频freq排序,方便用vocab_size限制词表大小 99 | self.token_freqs = sorted(counter.items(), key = lambda tup: tup[0]) 100 | self.token_freqs.sort(key = lambda tup: tup[1], reverse = True) 101 | 102 | self.itos = [] 103 | 104 | # 剔除低频词 105 | for tok, freq in self.token_freqs: 106 | if freq < self.min_freq or len(self.itos) == self.vocab_size: 107 | break 108 | self.itos.append(tok) 109 | 110 | self.stoi = {tok: i for i, tok in enumerate(self.itos)} 111 | 112 | def get_raw_data(self): 113 | if not os.path.exists(self.trainfile): 114 | self.txtnerlabel() 115 | self.raw_data = [] 116 | with open(self.trainfile, 'r', encoding = 'utf-8') as f: 117 | chars = [] 118 | labels = [] 119 | for line in f: 120 | line = line.rstrip().split('\t') 121 | if not line: 122 | continue 123 | char = line[0] 124 | if not char: 125 | continue 126 | label = line[-1] 127 | chars.append(char) 128 | labels.append(label) 129 | if char in ['。','?','!','!','?']: 130 | self.raw_data.append([chars, labels]) 131 | chars = [] 132 | labels = [] 133 | f.close() 134 | 135 | def get_train_data(self): 136 | label2id = self.configs['label']['label2id'] 137 | self.data = [] 138 | for i, item in enumerate(self.raw_data): 139 | sentence = item[0] 140 | label = item[1] 141 | s2id = [self.stoi[tok] for tok in sentence] 142 | l2id = [label2id[la] for la in label] 143 | self.data.append([s2id, l2id]) 144 | 145 | def shuffle(self): 146 | random.seed(self.random_seed) 147 | random.shuffle(self.data_ids) 148 | 149 | def train_test_split(self, validation = 0.3, random_seed = 1301): 150 | self.shuffle() 151 | train_ids = self.data_ids[:math.ceil(self.data_count * (1 - validation))] 152 | test_ids = self.data_ids[math.ceil(self.data_count * (1 - validation)):] 153 | self.train_data = np.array(self.data)[train_ids] 154 | self.test_data = 
np.array(self.data)[test_ids] -------------------------------------------------------------------------------- /dataset/preprocess/preprocess_ccks2019.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import os 5 | import math 6 | import random 7 | import numpy as np 8 | 9 | from collections import Counter 10 | 11 | class CCKS2019NER(object): 12 | 13 | def __init__(self, configs, vocab_size = None, min_freq = 1, random_seed = 1301): 14 | 15 | self.vocab_size = vocab_size 16 | self.min_freq = min_freq 17 | self.random_seed = random_seed 18 | self.trainfile = configs['label_file'] 19 | self.configs = configs 20 | check_dir(configs['data_path']) 21 | 22 | self.txtloader() 23 | self.label_dict = configs['label']['label2id'] 24 | self.class_dict = configs['label']['class2label'] 25 | self.txtnerlabel() 26 | self.nervocab() 27 | self.get_raw_data() 28 | self.get_train_data() 29 | 30 | self.data_count = len(self.data) 31 | self.data_ids = list(range(self.data_count)) 32 | self.train_test_split() 33 | 34 | def txtloader(self): 35 | self.originalText = {} 36 | self.entities = {} 37 | with open(os.path.join(self.configs['txt_path'], 'subtask1_training_part1.txt'), 'r', encoding = 'utf-8') as f: 38 | i = 0 39 | for line in f: 40 | self.originalText[i] = eval(line)['originalText'] 41 | self.entities[i] = eval(line)['entities'] 42 | i += 1 43 | f.close() 44 | with open(os.path.join(self.configs['txt_path'], 'subtask1_training_part2.txt'), 'r', encoding = 'utf-8') as f: 45 | for line in f: 46 | self.originalText[i] = eval(line)['originalText'] 47 | self.entities[i] = eval(line)['entities'] 48 | i += 1 49 | f.close() 50 | 51 | def txtnerlabel(self): 52 | if not os.path.exists(self.trainfile): 53 | with open(self.trainfile, 'w', encoding = 'utf-8') as f: 54 | for i in range(len(self.originalText)): 55 | text = self.originalText[i] 56 | res_dict = {} 57 | for e in self.entities[i]: 58 | start = e['start_pos'] 59 | end = e['end_pos'] 60 | label = self.configs['label']['class2label'][e['label_type']] 61 | for i in range(start, end): 62 | if i == start: 63 | label_cate = 'B-' + label 64 | else: 65 | label_cate = 'I-' + label 66 | res_dict[i] = label_cate 67 | for indx, char in enumerate(text): 68 | char_label = res_dict.get(indx, 'O') 69 | f.write(char + '\t' + char_label + '\n') 70 | # 保证每条文本末尾都以中文句号结尾 71 | if indx == len(text)-1 and char not in ['。','?','!','!','?']: 72 | f.write('。' + '\t' + 'O' + '\n') 73 | f.close() 74 | 75 | def nervocab(self): 76 | """ 77 | 获得NER所需要的字特征 78 | """ 79 | if not os.path.exists(self.trainfile): 80 | self.txtnerlabel() 81 | words = [] 82 | counter = Counter() 83 | with open(self.trainfile, 'r', encoding = 'utf-8') as f: 84 | for line in f: 85 | words.append(line[0]) 86 | f.close() 87 | for word in words: 88 | counter[word] += 1 89 | 90 | # 将词token按词频freq排序,方便用vocab_size限制词表大小 91 | self.token_freqs = sorted(counter.items(), key = lambda tup: tup[0]) 92 | self.token_freqs.sort(key = lambda tup: tup[1], reverse = True) 93 | 94 | self.itos = [] 95 | 96 | # 剔除低频词 97 | for tok, freq in self.token_freqs: 98 | if freq < self.min_freq or len(self.itos) == self.vocab_size: 99 | break 100 | self.itos.append(tok) 101 | 102 | self.stoi = {tok: i for i, tok in enumerate(self.itos)} 103 | 104 | def get_raw_data(self): 105 | if not os.path.exists(self.trainfile): 106 | self.txtnerlabel() 107 | self.raw_data = [] 108 | with open(self.trainfile, 'r', encoding = 'utf-8') as f: 109 | chars = [] 110 | labels = [] 111 | for line in 
f: 112 | line = line.rstrip().split('\t') 113 | if not line: 114 | continue 115 | char = line[0] 116 | if not char: 117 | continue 118 | label = line[-1] 119 | chars.append(char) 120 | labels.append(label) 121 | if char in ['。','?','!','!','?']: 122 | self.raw_data.append([chars, labels]) 123 | chars = [] 124 | labels = [] 125 | f.close() 126 | 127 | def get_train_data(self): 128 | label2id = self.configs['label']['label2id'] 129 | self.data = [] 130 | for i, item in enumerate(self.raw_data): 131 | sentence = item[0] 132 | label = item[1] 133 | s2id = [self.stoi[tok] for tok in sentence] 134 | l2id = [label2id[la] for la in label] 135 | self.data.append([s2id, l2id]) 136 | 137 | def shuffle(self): 138 | random.seed(self.random_seed) 139 | random.shuffle(self.data_ids) 140 | 141 | def train_test_split(self, validation = 0.3, random_seed = 1301): 142 | self.shuffle() 143 | train_ids = self.data_ids[:math.ceil(self.data_count * (1 - validation))] 144 | test_ids = self.data_ids[math.ceil(self.data_count * (1 - validation)):] 145 | self.train_data = np.array(self.data)[train_ids] 146 | self.test_data = np.array(self.data)[test_ids] -------------------------------------------------------------------------------- /dataset/processor.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import os 5 | import codecs 6 | from dataset.conll import oov_to_vocab 7 | 8 | class InputExample(object): 9 | 10 | def __init__(self, uid, tokens, labels): 11 | self.uid = uid 12 | self.tokens = tokens 13 | self.labels = labels 14 | 15 | class DataProcessor(object): 16 | 17 | def get_train_examples(self, data_dir, tokenizer = None): 18 | raise NotImplementedError 19 | 20 | def get_dev_examples(self, data_dir, tokenizer = None): 21 | raise NotImplementedError 22 | 23 | def get_test_examples(self, data_dir, tokenizer = None): 24 | raise NotImplementedError 25 | 26 | @staticmethod 27 | def create_examples_from_conll_format_file(data_file, set_type, tokenizer = None, all_Os = False): 28 | """ 29 | input file format 为conll类型 30 | all_Os: 是否把全是标签'O'的句子当做训练集 31 | """ 32 | examples = [] 33 | index = 0 34 | with open(data_file, 'r', encoding = 'utf-8') as f: 35 | chars = [] 36 | labels = [] 37 | for line in f: 38 | line = line.rstrip().split('\t') 39 | if not line: 40 | continue 41 | char = line[0] 42 | if not char: 43 | continue 44 | label = line[-1] 45 | if tokenizer: 46 | chars.append(oov_to_vocab(char, tokenizer)) 47 | else: 48 | chars.append(char) 49 | labels.append(label) 50 | if char in ['。','?','!','!','?']: 51 | uid = "%s-%d" %(set_type, index) 52 | # 如果单句标签全为O则不作为训练样本 53 | if labels != ['O'] * len(labels) or all_Os: 54 | examples.append(InputExample(uid = uid, tokens = chars, labels = labels)) 55 | chars = [] 56 | labels = [] 57 | return examples 58 | 59 | @staticmethod 60 | def create_examples_from_zhsentence(sentence, tokenizer = None): 61 | """ 62 | 对单句做chars-examples的转化 63 | """ 64 | examples = [] 65 | chars = [] 66 | for char in sentence: 67 | if tokenizer: 68 | chars.append(oov_to_vocab(char, tokenizer)) 69 | else: 70 | chars.append(char) 71 | if char in ['。','?','!','!','?']: 72 | uid = None 73 | labels = ['O'] * len(chars) 74 | examples.append(InputExample(uid = uid, tokens = chars, labels = labels)) 75 | chars = [] 76 | labels = [] 77 | return examples 78 | 79 | @staticmethod 80 | def get_labels(): 81 | raise NotImplementedError() 82 | 83 | @staticmethod 84 | def get_labels_to_entities(): 85 | raise NotImplementedError() 86 | 87 | class 
CCKS2019Processor(DataProcessor): 88 | def get_train_examples(self, data_dir, tokenizer = None): 89 | return DataProcessor.create_examples_from_conll_format_file(os.path.join(data_dir, 'train.txt'), 'train', tokenizer = tokenizer, all_Os = True) 90 | 91 | def get_dev_examples(self, data_dir, tokenizer = None): 92 | return DataProcessor.create_examples_from_conll_format_file(os.path.join(data_dir, 'dev.txt'), 'dev', tokenizer = tokenizer, all_Os = True) 93 | 94 | def get_test_examples(self, data_dir, tokenizer = None): 95 | return DataProcessor.create_examples_from_conll_format_file(os.path.join(data_dir, 'test.txt'), 'test', tokenizer = tokenizer, all_Os = True) 96 | 97 | @staticmethod 98 | def get_labels(): 99 | label_type = ['O', 'B-LABCHECK', 'I-LABCHECK','B-PICCHECK','I-PICCHECK','B-SURGERY','I-SURGERY', 100 | 'B-DISEASE','I-DISEASE','B-DRUGS','I-DRUGS','B-ANABODY','I-ANABODY'] 101 | return label_type 102 | 103 | @staticmethod 104 | def get_labels_to_entities(): 105 | label_entities_map = { 106 | 'LABCHECK': '实验室检验', 107 | 'PICCHECK': '影像检查', 108 | 'SURGERY': '手术', 109 | 'DISEASE': '疾病和诊断', 110 | 'DRUGS': '药物', 111 | 'ANABODY': '解剖部位' 112 | } 113 | return label_entities_map 114 | 115 | class CCKS2017Processor(DataProcessor): 116 | def get_train_examples(self, data_dir, tokenizer = None): 117 | return DataProcessor.create_examples_from_conll_format_file(os.path.join(data_dir, 'train.txt'), 'train', tokenizer = tokenizer, all_Os = True) 118 | 119 | def get_dev_examples(self, data_dir, tokenizer = None): 120 | return DataProcessor.create_examples_from_conll_format_file(os.path.join(data_dir, 'dev.txt'), 'dev', tokenizer = tokenizer, all_Os = True) 121 | 122 | def get_test_examples(self, data_dir, tokenizer = None): 123 | return DataProcessor.create_examples_from_conll_format_file(os.path.join(data_dir, 'test.txt'), 'test', tokenizer = tokenizer, all_Os = True) 124 | 125 | @staticmethod 126 | def get_labels(): 127 | label_type = ['O', 'B-CHECK','I-CHECK','B-SIGNS','I-SIGNS','B-DISEASE','I-DISEASE','B-TREATMENT','I-TREATMENT', 128 | 'B-BODY','I-BODY'] 129 | return label_type 130 | 131 | @staticmethod 132 | def get_labels_to_entities(): 133 | label_entities_map = { 134 | 'CHECK': '检查和检验', 135 | 'SIGNS': '症状和体征', 136 | 'DISEASE': '疾病和诊断', 137 | 'TREATMENT': '治疗', 138 | 'BODY': '身体部位', 139 | } 140 | return label_entities_map -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | import logging 6 | 7 | import torch 8 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler 9 | 10 | class InputFeatures(object): 11 | """A single set of features of data.""" 12 | 13 | def __init__(self, input_ids, input_mask, segment_ids, label_ids): 14 | self.input_ids = input_ids 15 | self.input_mask = input_mask 16 | self.segment_ids = segment_ids 17 | self.label_ids = label_ids 18 | # self.output_mask = output_mask 19 | 20 | def convert_examples_to_features(examples, 21 | max_seq_length, 22 | tokenizer, 23 | label_list, 24 | pad_type = 'head+tail'): 25 | # 标签转换为数字 26 | label_map = {label: i for i, label in enumerate(label_list)} 27 | 28 | # load sub_vocab 29 | # sub_vocab = {} 30 | # with open(vocab_file, 'r', encoding = 'utf-8') as fr: 31 | # for line in fr: 32 | # _line = line.strip('\n') 33 | # if "##" in _line and sub_vocab.get(_line) is None: 34 | # sub_vocab[_line] = 1 35 | 36 | features = [] 37 | for ex_index, 
example in enumerate(examples): 38 | tokens = example.tokens 39 | labels = example.labels 40 | 41 | if len(tokens)==0 or len(labels)==0: 42 | continue 43 | 44 | if len(tokens) > max_seq_length - 2: 45 | if pad_type == 'head-only': 46 | tokens = tokens[:(max_seq_length-2)] 47 | labels = labels[:(max_seq_length-2)] 48 | elif pad_type == 'tail-only': 49 | tokens = tokens[(max_seq_length-2):] 50 | labels = labels[(max_seq_length-2):] 51 | elif pad_type == 'head+tail': 52 | tokens = tokens[:round((max_seq_length-2)/4)] + tokens[-round(3 * (max_seq_length-2)/4):] 53 | labels = labels[:round((max_seq_length-2)/4)] + labels[-round(3 * (max_seq_length-2)/4):] 54 | else: 55 | raise ValueError('Unknown `pad_type`: ' + str(pad_type)) 56 | # ----------------处理source-------------- 57 | ## 句子首尾加入标示符 58 | tokens = ["[CLS]"] + tokens + ["[SEP]"] 59 | 60 | # segment_ids = [sequence_a_segment_id] * len(tokens) 61 | # sequence_a_segment_id = 0, sequence_b_segment_id = 1 62 | segment_ids = [0] * len(tokens) 63 | ## 词转换成数字 64 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 65 | 66 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 67 | # tokens are attended to. 68 | input_mask = [1] * len(input_ids) 69 | 70 | padding = [0] * (max_seq_length - len(input_ids)) 71 | 72 | input_ids += padding 73 | input_mask += padding 74 | segment_ids += padding 75 | 76 | # ---------------处理target---------------- 77 | ## Notes: label_id中不包括[CLS]和[SEP] 78 | label_ids = [label_map[l] for l in labels] 79 | label_ids = [0] + label_ids + [0] 80 | label_padding = [0] * (max_seq_length-len(label_ids)) 81 | label_ids += label_padding 82 | 83 | assert len(input_ids) == max_seq_length 84 | assert len(input_mask) == max_seq_length 85 | assert len(segment_ids) == max_seq_length 86 | assert len(label_ids) == max_seq_length 87 | 88 | # label_ids = to_categorical(label_ids.cpu(), num_classes = len(label_list)) 89 | 90 | ## output_mask用来过滤bert输出中sub_word的输出,只保留单词的第一个输出(As recommended by jocob in his paper) 91 | ## 此外,也是为了适应crf 92 | # output_mask = [0 if sub_vocab.get(t) is not None else 1 for t in tokens] 93 | # output_mask = [0] + output_mask + [0] 94 | # output_mask += padding 95 | 96 | # ----------------处理后结果------------------------- 97 | # for example, in the case of max_seq_length=10: 98 | # raw_data: 春 秋 忽 代 谢le 99 | # token: [CLS] 春 秋 忽 代 谢 ##le [SEP] 100 | # input_ids: 101 2 12 13 16 14 15 102 0 0 0 101 | # input_mask: 1 1 1 1 1 1 1 1 0 0 0 102 | # label_id: T T O O O 103 | # output_mask: 0 1 1 1 1 1 0 0 0 0 0 104 | # --------------看结果是否合理------------------------ 105 | 106 | 107 | feature = InputFeatures(input_ids=input_ids, 108 | input_mask=input_mask, 109 | segment_ids=segment_ids, 110 | # output_mask=output_mask, 111 | label_ids=label_ids) 112 | features.append(feature) 113 | 114 | return features 115 | 116 | def convert_features_to_dataloader(features, batch_size): 117 | 118 | """ 119 | input_ids: size=(batch_size, max_seq_length) 120 | input_mask: size=(batch_size, max_seq_length) 121 | segment_ids: size=(batch_size, max_seq_length) 122 | label_ids: size=(batch_size, max_seq_length) 123 | """ 124 | input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 125 | segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 126 | input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 127 | label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long) 128 | # output_mask = torch.tensor([f.output_mask for f in features], dtype=torch.long) 
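    # Each tensor above has shape (num_examples, max_seq_length). TensorDataset zips them
    # row-wise, RandomSampler draws a fresh permutation of the examples every epoch, and the
    # DataLoader therefore yields batches as tuples (input_ids, segment_ids, input_mask,
    # label_ids), each of shape (batch_size, max_seq_length).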
129 | 130 | data = TensorDataset(input_ids, segment_ids, input_mask, label_ids) 131 | sampler = RandomSampler(data) 132 | dataloader = DataLoader(data, sampler = sampler, batch_size= batch_size) 133 | return dataloader -------------------------------------------------------------------------------- /db.sqlite3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kenshinpg/PytorchNER_zh/85c46c40a5482029130f65e2fdf8e003729dac24/db.sqlite3 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import os 5 | from configs.confighelper import config_loader, args_parser 6 | from dataset.preprocess import CCKS2019NER, CCKS2017NER 7 | from dataset.conll import conll_to_train_test_dev 8 | from dataset.processor import CCKS2019Processor, CCKS2017Processor 9 | from train.trainer import Trainer 10 | from train.eval import Predictor 11 | from utils.datautils import check_dir 12 | 13 | dataset_name_to_class = { 14 | 'CCKS2019': (CCKS2019NER, CCKS2019Processor, './configs/ccks2019.yml'), 15 | 'CCKS2017': (CCKS2017NER, CCKS2017Processor, './configs/ccks2017.yml') 16 | } 17 | 18 | def main(): 19 | 20 | args = args_parser() 21 | if args.task == 'train': 22 | # conll process 23 | data_vocab_class, processor_class, conll_config_path = dataset_name_to_class[args.dataset] 24 | conll_configs = config_loader(conll_config_path) 25 | if not os.path.exists(os.path.join(conll_configs['data_path'], 'train.txt')): 26 | data_vocab = data_vocab_class(conll_configs) 27 | conll_to_train_test_dev(conll_configs['label_file'], conll_configs['data_path']) 28 | 29 | # config 30 | configs = config_loader(args.config_path) 31 | configs['data_dir'] = os.path.join(configs['data_dir'], args.dataset.lower()) 32 | configs['finetune_model_dir'] = os.path.join(configs['finetune_model_dir'], args.dataset.lower()) 33 | configs['output_dir'] = os.path.join(configs['output_dir'], args.dataset.lower()) 34 | check_dir(configs['data_dir']) 35 | check_dir(configs['finetune_model_dir']) 36 | check_dir(configs['output_dir']) 37 | 38 | # train 39 | processor = processor_class() 40 | for model_class in configs['model_class']: 41 | print('Begin Training %s Model on corpus %s' %(model_class, args.dataset)) 42 | trainer = Trainer(configs, model_class, processor) 43 | trainer.train() 44 | 45 | if args.task == 'eval': 46 | data_vocab_class, processor_class, conll_config_path = dataset_name_to_class[args.dataset] 47 | conll_configs = config_loader(conll_config_path) 48 | if not os.path.exists(os.path.join(conll_configs['data_path'], 'test.txt')): 49 | data_vocab = data_vocab_class(conll_configs) 50 | conll_to_train_test_dev(conll_configs['label_file'], conll_configs['data_path']) 51 | 52 | configs = config_loader(args.config_path) 53 | configs['data_dir'] = os.path.join(configs['data_dir'], args.dataset.lower()) 54 | configs['finetune_model_dir'] = os.path.join(configs['finetune_model_dir'], args.dataset.lower()) 55 | configs['output_dir'] = os.path.join(configs['output_dir'], args.dataset.lower()) 56 | check_dir(configs['data_dir']) 57 | check_dir(configs['finetune_model_dir']) 58 | check_dir(configs['output_dir']) 59 | 60 | processor = processor_class() 61 | for model_class in configs['model_class']: 62 | print('Begin Evaluate %s Model on corpus %s' %(model_class, args.dataset)) 63 | predicter = Predictor(configs, model_class, processor) 64 | 
predicter.eval() 65 | 66 | if __name__ == '__main__': 67 | main() -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | 5 | if __name__ == '__main__': 6 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'port.settings') 7 | try: 8 | from django.core.management import execute_from_command_line 9 | except ImportError as exc: 10 | raise ImportError( 11 | "Couldn't import Django. Are you sure it's installed and " 12 | "available on your PYTHONPATH environment variable? Did you " 13 | "forget to activate a virtual environment?" 14 | ) from exc 15 | execute_from_command_line(sys.argv) 16 | -------------------------------------------------------------------------------- /model/base.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import os 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from utils.datautils import json_write, json_load 10 | 11 | CONFIG_NAME = "config.json" 12 | WEIGHTS_NAME = "pytorch_model.bin" 13 | TF_WEIGHTS_NAME = 'model.ckpt' 14 | 15 | class BaseModel(nn.Module): 16 | 17 | def save_pretrained(self, save_directory): 18 | 19 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" 20 | 21 | # Only save the model it-self if we are using distributed training 22 | model_to_save = self.module if hasattr(self, 'module') else self 23 | # Save configuration file 24 | json_write(model_to_save.configs, os.path.join(save_directory, CONFIG_NAME)) 25 | # If we save using the predefined names, we can load using `from_pretrained` 26 | output_model_file = os.path.join(save_directory, WEIGHTS_NAME) 27 | torch.save(model_to_save.state_dict(), output_model_file) 28 | 29 | @classmethod 30 | def from_pretrained(cls, pretrained_model_path, pretrained_word_embed = None): 31 | 32 | configs = json_load(os.path.join(pretrained_model_path, CONFIG_NAME)) 33 | model = cls(configs, pretrained_word_embed = pretrained_word_embed) 34 | archive_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) 35 | state_dict = torch.load(archive_file, map_location='cpu') 36 | # Load from a PyTorch state_dict 37 | missing_keys = [] 38 | unexpected_keys = [] 39 | error_msgs = [] 40 | # copy state_dict so _load_from_state_dict can modify it 41 | metadata = getattr(state_dict, '_metadata', None) 42 | state_dict = state_dict.copy() 43 | if metadata is not None: 44 | state_dict._metadata = metadata 45 | def load(module, prefix=''): 46 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 47 | module._load_from_state_dict( 48 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 49 | for name, child in module._modules.items(): 50 | if child is not None: 51 | load(child, prefix + name + '.') 52 | load(model) 53 | # Set model in evaluation mode to desactivate DropOut modules by default 54 | model.eval() 55 | return model 56 | -------------------------------------------------------------------------------- /model/bert.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import CrossEntropyLoss, MSELoss 7 | from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertModel 8 | from model.crf import CRF 9 | from 
configs.confighelper import config_loader 10 | 11 | class Bert(BertPreTrainedModel): 12 | 13 | def __init__(self, config, model_configs): 14 | super(Bert, self).__init__(config) 15 | self.num_labels = config.num_labels 16 | self.bert = BertModel(config) 17 | self.hidden_dim = config.hidden_size 18 | self.use_cuda = model_configs['use_cuda'] and torch.cuda.is_available() 19 | self.dropout = nn.Dropout(model_configs['dropout_rate']) 20 | self.hidden2label = nn.Linear(self.hidden_dim, self.num_labels) 21 | self.loss_function = CrossEntropyLoss() 22 | self.max_seq_length = model_configs['max_seq_length'] 23 | 24 | self.apply(self.init_weights) 25 | 26 | def forward(self, input_ids, segment_ids, input_mask): 27 | 28 | # outputs = sequence_output, pooled_output, (hidden_states), (attentions) 29 | # sequence_output = encoder_outputs[0] 30 | # pooled_output = pooler(sequence_output) 31 | outputs = self.bert(input_ids = input_ids, 32 | position_ids = None, 33 | token_type_ids = segment_ids, 34 | attention_mask = input_mask, 35 | head_mask = None) 36 | # bert_embeds: shape = [batch_size, max_seq_length, bert_embedding] 37 | bert_embeds = outputs[0].contiguous().view(-1, self.hidden_dim) 38 | 39 | bert_embeds = self.dropout(bert_embeds) 40 | logits = self.hidden2label(bert_embeds) 41 | 42 | return logits.view(-1, self.max_seq_length, self.num_labels) 43 | 44 | def loss_fn(self, feats, mask, labels): 45 | """ 46 | feats: size=(batch_size, max_seq_length, num_labels) 47 | mask: size=(batch_size, max_seq_length) 48 | labels: size=(batch_size, max_seq_length) 49 | """ 50 | 51 | # Only keep active parts of the loss 52 | if mask is not None: 53 | active_loss = mask.view(-1) == 1 # size=(batch_size * max_seq_length) 54 | active_logits = feats.view(-1, self.num_labels)[active_loss] # size=(batch_size * max_seq_length, num_labels) 55 | active_labels = labels.view(-1)[active_loss] 56 | loss_value = self.loss_function(active_logits, active_labels) 57 | else: 58 | loss_value = self.loss_function(feats.view(-1, self.num_labels), labels.view(-1)) 59 | return loss_value 60 | 61 | def predict(self, feats, mask = None): 62 | return feats.argmax(-1) 63 | 64 | class BertCRF(BertPreTrainedModel): 65 | 66 | def __init__(self, config, model_configs): 67 | super(BertCRF, self).__init__(config) 68 | self.num_labels = config.num_labels 69 | self.max_seq_length = model_configs['max_seq_length'] 70 | self.bert = BertModel(config) 71 | self.use_cuda = model_configs['use_cuda'] and torch.cuda.is_available() 72 | self.crf = CRF(target_size = self.num_labels, 73 | use_cuda = self.use_cuda, 74 | average_batch = False) 75 | bert_embedding = config.hidden_size 76 | # hidden_dim即输出维度 77 | # lstm的hidden_dim和init_hidden的hidden_dim是一致的 78 | # 是输出层hidden_dim的1/2 79 | self.hidden_dim = config.hidden_size 80 | self.dropout = nn.Dropout(model_configs['dropout_rate']) 81 | self.hidden2label = nn.Linear(self.hidden_dim, self.num_labels + 2) 82 | self.apply(self.init_weights) 83 | 84 | def forward(self, input_ids, segment_ids, input_mask): 85 | 86 | # outputs = sequence_output, pooled_output, (hidden_states), (attentions) 87 | # sequence_output = encoder_outputs[0] 88 | # pooled_output = pooler(sequence_output) 89 | outputs = self.bert(input_ids = input_ids, 90 | position_ids = None, 91 | token_type_ids = segment_ids, 92 | attention_mask = input_mask, 93 | head_mask = None) 94 | # bert_embeds: shape = [batch_size, max_seq_length, bert_embedding] 95 | bert_embeds = outputs[0].contiguous().view(-1, self.hidden_dim) 96 | bert_embeds = 
self.dropout(bert_embeds) 97 | logits = self.hidden2label(bert_embeds) 98 | 99 | return logits.view(-1, self.max_seq_length, self.num_labels +2) 100 | 101 | def loss_fn(self, feats, mask, labels): 102 | batch_size = feats.size(0) 103 | loss_value = self.crf.neg_log_likelihood_loss(feats, mask, labels)/float(batch_size) 104 | return loss_value 105 | 106 | def predict(self, feats, mask): 107 | 108 | path_score, best_path = self.crf(feats, mask.byte()) 109 | return best_path 110 | 111 | class BertBiLSTMCRF(BertPreTrainedModel): 112 | 113 | def __init__(self, config, model_configs): 114 | super(BertBiLSTMCRF, self).__init__(config) 115 | self.num_labels = config.num_labels 116 | self.max_seq_length = model_configs['max_seq_length'] 117 | self.bert = BertModel(config) 118 | self.use_cuda = model_configs['use_cuda'] and torch.cuda.is_available() 119 | self.crf = CRF(target_size = self.num_labels, 120 | use_cuda = self.use_cuda, 121 | average_batch = False) 122 | bert_embedding = config.hidden_size 123 | # hidden_dim即输出维度 124 | # lstm的hidden_dim和init_hidden的hidden_dim是一致的 125 | # 是输出层hidden_dim的1/2 126 | self.hidden_dim = config.hidden_size 127 | self.rnn_layers = model_configs['rnn_layers'] 128 | self.lstm = nn.LSTM(input_size = bert_embedding, # bert embedding 129 | hidden_size = self.hidden_dim, 130 | num_layers = self.rnn_layers, 131 | batch_first = True, 132 | # dropout = model_configs['train']['dropout_rate'], 133 | bidirectional = True) 134 | self.dropout = nn.Dropout(model_configs['dropout_rate']) 135 | self.hidden2label = nn.Linear(self.hidden_dim * 2, self.num_labels + 2) 136 | self.apply(self.init_weights) 137 | 138 | def rand_init_hidden(self, batch_size): 139 | """ 140 | random initialize hidden variable 141 | 双向是2,单向是1 142 | """ 143 | if self.use_cuda: 144 | return (torch.zeros(2 * self.rnn_layers, batch_size, self.hidden_dim).cuda(), 145 | torch.zeros(2 * self.rnn_layers, batch_size, self.hidden_dim).cuda()) 146 | else: 147 | return (torch.zeros(2 * self.rnn_layers, batch_size, self.hidden_dim), 148 | torch.zeros(2 * self.rnn_layers, batch_size, self.hidden_dim)) 149 | 150 | def forward(self, input_ids, segment_ids, input_mask): 151 | 152 | # outputs = sequence_output, pooled_output, (hidden_states), (attentions) 153 | # sequence_output = encoder_outputs[0] 154 | # pooled_output = pooler(sequence_output) 155 | outputs = self.bert(input_ids = input_ids, 156 | position_ids = None, 157 | token_type_ids = segment_ids, 158 | attention_mask = input_mask, 159 | head_mask = None) 160 | # bert_embeds: shape = [batch_size, max_seq_length, bert_embedding] 161 | bert_embeds = outputs[0] 162 | 163 | batch_size = input_ids.size(0) 164 | 165 | hidden = self.rand_init_hidden(batch_size) 166 | 167 | lstm_output, hidden = self.lstm(bert_embeds, hidden) 168 | lstm_output = lstm_output.contiguous().view(-1, self.hidden_dim * 2) 169 | # lstm_output = self.dropout(lstm_output) 170 | logits = self.hidden2label(lstm_output) 171 | 172 | return logits.view(-1, self.max_seq_length, self.num_labels +2) 173 | 174 | def loss_fn(self, feats, mask, labels): 175 | batch_size = feats.size(0) 176 | loss_value = self.crf.neg_log_likelihood_loss(feats, mask, labels)/float(batch_size) 177 | return loss_value 178 | 179 | def predict(self, feats, mask): 180 | 181 | path_score, best_path = self.crf(feats, mask.byte()) 182 | return best_path -------------------------------------------------------------------------------- /model/bilstm.py: -------------------------------------------------------------------------------- 1 | 
2 | # coding: utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import CrossEntropyLoss 7 | from model.crf import CRF 8 | from model.base import BaseModel 9 | # from pytorch_transformers.modeling_utils import PretrainedConfig 10 | 11 | # class BiLSTMConfig(PretrainedConfig): 12 | 13 | class BiLSTM(BaseModel): 14 | 15 | def __init__(self, configs, pretrained_word_embed = None): 16 | super(BiLSTM, self).__init__() 17 | 18 | self.configs = configs 19 | self.num_labels = configs['num_labels'] 20 | self.max_seq_length = configs['max_seq_length'] 21 | self.hidden_dim = configs['rnn_hidden'] 22 | self.use_cuda = configs['use_cuda'] and torch.cuda.is_available() 23 | 24 | # word embedding layer 25 | self.word_embedding = nn.Embedding(num_embeddings = configs['word_vocab_size'], 26 | embedding_dim = configs['word_embedding_dim']) 27 | if configs['use_pretrained_embedding']: 28 | self.word_embedding.weight.data.copy_(torch.from_numpy(pretrained_word_embed)) 29 | self.word_embedding.weight.requires_grad = configs['requires_grad'] 30 | # dropout layer 31 | self.dropout_embed = nn.Dropout(configs['dropout_rate']) 32 | self.dropout_rnn = nn.Dropout(configs['dropout_rate']) 33 | # rnn layer 34 | self.rnn_layers = configs['rnn_layers'] 35 | self.lstm = nn.LSTM(input_size = configs['word_embedding_dim'], # bert embedding 36 | hidden_size = self.hidden_dim, 37 | num_layers = self.rnn_layers, 38 | batch_first = True, 39 | bidirectional = True) 40 | self.hidden2label = nn.Linear(self.hidden_dim * 2, self.num_labels) 41 | self.loss_function = CrossEntropyLoss() 42 | 43 | def rand_init_hidden(self, batch_size): 44 | """ 45 | random initialize hidden variable 46 | 双向是2,单向是1 47 | """ 48 | if self.use_cuda: 49 | return (torch.zeros(2 * self.rnn_layers, batch_size, self.hidden_dim).cuda(), 50 | torch.zeros(2 * self.rnn_layers, batch_size, self.hidden_dim).cuda()) 51 | else: 52 | return (torch.zeros(2 * self.rnn_layers, batch_size, self.hidden_dim), 53 | torch.zeros(2 * self.rnn_layers, batch_size, self.hidden_dim)) 54 | 55 | def get_lstm_outputs(self, input_ids): 56 | 57 | word_embeds = self.word_embedding(input_ids) 58 | word_embeds = self.dropout_embed(word_embeds) 59 | 60 | batch_size = input_ids.size(0) 61 | hidden = self.rand_init_hidden(batch_size) 62 | 63 | lstm_outputs, hidden = self.lstm(word_embeds, hidden) 64 | lstm_outputs = lstm_outputs.contiguous().view(-1, self.hidden_dim * 2) 65 | return lstm_outputs 66 | 67 | def forward(self, input_ids, segment_ids, input_mask): 68 | lstm_outputs = self.get_lstm_outputs(input_ids) 69 | logits = self.hidden2label(lstm_outputs) 70 | return logits.view(-1, self.max_seq_length, self.num_labels) 71 | 72 | def loss_fn(self, feats, mask, labels): 73 | loss_value = self.loss_function(feats.view(-1, self.num_labels), labels.view(-1)) 74 | return loss_value 75 | 76 | def predict(self, feats, mask = None): 77 | return feats.argmax(-1) 78 | 79 | 80 | class BiLSTMCRF(BaseModel): 81 | 82 | def __init__(self, configs, pretrained_word_embed = None): 83 | super(BiLSTMCRF, self).__init__() 84 | 85 | self.configs = configs 86 | self.num_labels = configs['num_labels'] 87 | self.max_seq_length = configs['max_seq_length'] 88 | self.use_cuda = configs['use_cuda'] and torch.cuda.is_available() 89 | 90 | self.bilstm = BiLSTM(configs, pretrained_word_embed) 91 | self.crf = CRF(target_size = self.num_labels, 92 | use_cuda = self.use_cuda, 93 | average_batch = False) 94 | self.hidden2label = nn.Linear(self.bilstm.hidden_dim * 2, self.num_labels + 2) 95 | 96 | def 
forward(self, input_ids, segment_ids, input_mask): 97 | lstm_outputs = self.bilstm.get_lstm_outputs(input_ids) 98 | logits = self.hidden2label(lstm_outputs) 99 | return logits.view(-1, self.max_seq_length, self.num_labels + 2) 100 | 101 | def loss_fn(self, feats, mask, labels): 102 | """ 103 | Args: 104 | feats: size=(batch_size, seq_len, tag_size) 105 | mask: size=(batch_size, seq_len) 106 | tags: size=(batch_size, seq_len) 107 | """ 108 | batch_size = feats.size(0) 109 | loss_value = self.crf.neg_log_likelihood_loss(feats, mask, labels)/float(batch_size) 110 | return loss_value 111 | 112 | def predict(self, feats, mask): 113 | 114 | path_score, best_path = self.crf(feats, mask.byte()) 115 | return best_path -------------------------------------------------------------------------------- /model/crf.py: -------------------------------------------------------------------------------- 1 | 2 | # coding=utf-8 3 | import torch 4 | from torch.autograd import Variable 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | def log_sum_exp(vec, m_size): 9 | """ 10 | Args: 11 | vec: size=(batch_size, vanishing_dim, hidden_dim) 12 | m_size: hidden_dim 13 | 14 | Returns: 15 | size=(batch_size, hidden_dim) 16 | """ 17 | _, idx = torch.max(vec, 1) # B * 1 * M 18 | max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size) # B * M 19 | return max_score.view(-1, m_size) + torch.log(torch.sum( 20 | torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size) 21 | 22 | 23 | class CRF(nn.Module): 24 | 25 | def __init__(self, **kwargs): 26 | """ 27 | Args: 28 | target_size: int, target size 29 | use_cuda: bool, 是否使用gpu, default is True 30 | average_batch: bool, loss是否作平均, default is True 31 | """ 32 | super(CRF, self).__init__() 33 | for k in kwargs: 34 | self.__setattr__(k, kwargs[k]) 35 | self.START_TAG_IDX, self.END_TAG_IDX = -2, -1 36 | # +2 是因为tag比初始的Label_tag多了start_tag和end_tag 37 | init_transitions = torch.zeros(self.target_size+2, self.target_size+2) 38 | # These two statements enforce the constraint that we never transfer 39 | # to the start tag and we never transfer from the stop tag 40 | init_transitions[:, self.START_TAG_IDX] = -1000. 41 | init_transitions[self.END_TAG_IDX, :] = -1000. 42 | if self.use_cuda: 43 | init_transitions = init_transitions.cuda() 44 | self.transitions = nn.Parameter(init_transitions) 45 | 46 | def _forward_alg(self, feats, mask=None): 47 | """ 48 | Do the forward algorithm to compute the partition function (batched). 
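        Concretely, this computes the log partition function of the CRF for each sentence,
            log Z = log sum over all tag paths y of exp( sum_t emission(t, y_t) + transition(y_{t-1}, y_t) ),
        by dynamic programming, applying log_sum_exp over the previous tag at every step
        and summing the result over the batch; it is the term from which the gold path
        score is subtracted in neg_log_likelihood_loss.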
49 | 50 | Args: 51 | feats: size=(batch_size, seq_len, self.target_size+2) 52 | mask: size=(batch_size, seq_len) 53 | 54 | Returns: 55 | xxx 56 | """ 57 | batch_size = feats.size(0) 58 | seq_len = feats.size(1) 59 | tag_size = feats.size(-1) 60 | 61 | mask = mask.transpose(1, 0).contiguous() 62 | ins_num = batch_size * seq_len 63 | feats = feats.transpose(1, 0).contiguous().view( 64 | ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) 65 | 66 | scores = feats + self.transitions.view( 67 | 1, tag_size, tag_size).expand(ins_num, tag_size, tag_size) 68 | scores = scores.view(seq_len, batch_size, tag_size, tag_size) 69 | seq_iter = enumerate(scores) 70 | try: 71 | _, inivalues = seq_iter.__next__() 72 | except: 73 | _, inivalues = seq_iter.next() 74 | 75 | partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1) 76 | for idx, cur_values in seq_iter: 77 | cur_values = cur_values + partition.contiguous().view( 78 | batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 79 | cur_partition = log_sum_exp(cur_values, tag_size) 80 | mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size) 81 | masked_cur_partition = cur_partition.masked_select(mask_idx.byte()) 82 | if masked_cur_partition.dim() != 0: 83 | mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1) 84 | partition.masked_scatter_(mask_idx.byte(), masked_cur_partition) 85 | cur_values = self.transitions.view(1, tag_size, tag_size).expand( 86 | batch_size, tag_size, tag_size) + partition.contiguous().view( 87 | batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 88 | cur_partition = log_sum_exp(cur_values, tag_size) 89 | final_partition = cur_partition[:, self.END_TAG_IDX] 90 | return final_partition.sum(), scores 91 | 92 | def _viterbi_decode(self, feats, mask=None): 93 | """ 94 | Args: 95 | feats: size=(batch_size, seq_len, self.target_size+2) 96 | mask: size=(batch_size, seq_len) 97 | 98 | Returns: 99 | decode_idx: (batch_size, seq_len), viterbi decode结果 100 | path_score: size=(batch_size, 1), 每个句子的得分 101 | """ 102 | batch_size = feats.size(0) 103 | seq_len = feats.size(1) 104 | tag_size = feats.size(-1) 105 | 106 | length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long() 107 | mask = mask.transpose(1, 0).contiguous() 108 | ins_num = seq_len * batch_size 109 | feats = feats.transpose(1, 0).contiguous().view( 110 | ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) 111 | 112 | scores = feats + self.transitions.view( 113 | 1, tag_size, tag_size).expand(ins_num, tag_size, tag_size) 114 | scores = scores.view(seq_len, batch_size, tag_size, tag_size) 115 | 116 | seq_iter = enumerate(scores) 117 | # record the position of the best score 118 | back_points = list() 119 | partition_history = list() 120 | mask = (1 - mask.long()).byte() 121 | try: 122 | _, inivalues = seq_iter.__next__() 123 | except: 124 | _, inivalues = seq_iter.next() 125 | partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1) 126 | partition_history.append(partition) 127 | 128 | for idx, cur_values in seq_iter: 129 | cur_values = cur_values + partition.contiguous().view( 130 | batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 131 | partition, cur_bp = torch.max(cur_values, 1) 132 | partition_history.append(partition.unsqueeze(-1)) 133 | 134 | cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0) 135 | back_points.append(cur_bp) 136 | 137 | partition_history = torch.cat(partition_history).view( 138 | seq_len, 
batch_size, -1).transpose(1, 0).contiguous() 139 | 140 | last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1 141 | last_partition = torch.gather( 142 | partition_history, 1, last_position).view(batch_size, tag_size, 1) 143 | 144 | last_values = last_partition.expand(batch_size, tag_size, tag_size) + \ 145 | self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size) 146 | _, last_bp = torch.max(last_values, 1) 147 | pad_zero = Variable(torch.zeros(batch_size, tag_size)).long() 148 | if self.use_cuda: 149 | pad_zero = pad_zero.cuda() 150 | back_points.append(pad_zero) 151 | back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size) 152 | 153 | pointer = last_bp[:, self.END_TAG_IDX] 154 | insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size) 155 | back_points = back_points.transpose(1, 0).contiguous() 156 | 157 | back_points.scatter_(1, last_position, insert_last) 158 | 159 | back_points = back_points.transpose(1, 0).contiguous() 160 | 161 | decode_idx = Variable(torch.LongTensor(seq_len, batch_size)) 162 | if self.use_cuda: 163 | decode_idx = decode_idx.cuda() 164 | decode_idx[-1] = pointer.data 165 | for idx in range(len(back_points)-2, -1, -1): 166 | pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) 167 | decode_idx[idx] = pointer.view(-1).data 168 | path_score = None 169 | decode_idx = decode_idx.transpose(1, 0) 170 | return path_score, decode_idx 171 | 172 | def forward(self, feats, mask=None): 173 | path_score, best_path = self._viterbi_decode(feats, mask) 174 | return path_score, best_path 175 | 176 | def _score_sentence(self, scores, mask, tags): 177 | """ 178 | Args: 179 | scores: size=(seq_len, batch_size, tag_size, tag_size) 180 | mask: size=(batch_size, seq_len) 181 | tags: size=(batch_size, seq_len) 182 | 183 | Returns: 184 | score: 185 | """ 186 | batch_size = scores.size(1) 187 | seq_len = scores.size(0) 188 | tag_size = scores.size(-1) 189 | 190 | new_tags = Variable(torch.LongTensor(batch_size, seq_len)) 191 | if self.use_cuda: 192 | new_tags = new_tags.cuda() 193 | for idx in range(seq_len): 194 | if idx == 0: 195 | new_tags[:, 0] = (tag_size - 2) * tag_size + tags[:, 0] 196 | else: 197 | new_tags[:, idx] = tags[:, idx-1] * tag_size + tags[:, idx] 198 | 199 | end_transition = self.transitions[:, self.END_TAG_IDX].contiguous().view( 200 | 1, tag_size).expand(batch_size, tag_size) 201 | length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long() 202 | end_ids = torch.gather(tags, 1, length_mask-1) 203 | 204 | end_energy = torch.gather(end_transition, 1, end_ids) 205 | 206 | new_tags = new_tags.transpose(1, 0).contiguous().view(seq_len, batch_size, 1) 207 | tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view( 208 | seq_len, batch_size) 209 | tg_energy = tg_energy.masked_select(mask.transpose(1, 0)) 210 | 211 | gold_score = tg_energy.sum() + end_energy.sum() 212 | 213 | return gold_score 214 | 215 | def neg_log_likelihood_loss(self, feats, mask, tags): 216 | """ 217 | Args: 218 | feats: size=(batch_size, seq_len, tag_size) 219 | mask: size=(batch_size, seq_len) 220 | tags: size=(batch_size, seq_len) 221 | """ 222 | batch_size = feats.size(0) 223 | mask = mask.byte() 224 | forward_score, scores = self._forward_alg(feats, mask) 225 | gold_score = self._score_sentence(scores, mask, tags) 226 | if self.average_batch: 227 | return (forward_score - gold_score) / batch_size 228 | return forward_score - 
gold_score -------------------------------------------------------------------------------- /port/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Django settings for pydemo project. 3 | 4 | Generated by 'django-admin startproject' using Django 2.1. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/dev/topics/settings/ 8 | 9 | For the full list of settings and their values, see 10 | https://docs.djangoproject.com/en/dev/ref/settings/ 11 | """ 12 | 13 | import os 14 | 15 | # Build paths inside the project like this: os.path.join(BASE_DIR, ...) 16 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 17 | 18 | # Quick-start development settings - unsuitable for production 19 | # See https://docs.djangoproject.com/en/dev/howto/deployment/checklist/ 20 | 21 | # SECURITY WARNING: keep the secret key used in production secret! 22 | SECRET_KEY = 'o)24s543aa@6l329wp+*)-$10vbzx*l7lbz56mz3c(veyz)o=_' 23 | 24 | # SECURITY WARNING: don't run with debug turned on in production! 25 | DEBUG = True 26 | 27 | ALLOWED_HOSTS = ['*'] 28 | 29 | # Application definition 30 | 31 | INSTALLED_APPS = [ 32 | 'django.contrib.admin', 33 | 'django.contrib.auth', 34 | 'django.contrib.contenttypes', 35 | 'django.contrib.sessions', 36 | 'django.contrib.messages', 37 | 'django.contrib.staticfiles', 38 | ] 39 | 40 | MIDDLEWARE = [ 41 | 'django.middleware.security.SecurityMiddleware', 42 | 'django.contrib.sessions.middleware.SessionMiddleware', 43 | 'django.middleware.common.CommonMiddleware', 44 | # 'django.middleware.csrf.CsrfViewMiddleware', 45 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 46 | 'django.contrib.messages.middleware.MessageMiddleware', 47 | 'django.middleware.clickjacking.XFrameOptionsMiddleware', 48 | ] 49 | 50 | ROOT_URLCONF = 'port.urls' 51 | 52 | TEMPLATES = [ 53 | { 54 | 'BACKEND': 'django.template.backends.django.DjangoTemplates', 55 | 'DIRS': [], 56 | 'APP_DIRS': True, 57 | 'OPTIONS': { 58 | 'context_processors': [ 59 | 'django.template.context_processors.debug', 60 | 'django.template.context_processors.request', 61 | 'django.contrib.auth.context_processors.auth', 62 | 'django.contrib.messages.context_processors.messages', 63 | ], 64 | }, 65 | }, 66 | ] 67 | 68 | WSGI_APPLICATION = 'port.wsgi.application' 69 | 70 | 71 | # Database 72 | # https://docs.djangoproject.com/en/dev/ref/settings/#databases 73 | 74 | DATABASES = { 75 | 'default': { 76 | 'ENGINE': 'django.db.backends.sqlite3', 77 | 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), 78 | } 79 | } 80 | 81 | 82 | # Password validation 83 | # https://docs.djangoproject.com/en/dev/ref/settings/#auth-password-validators 84 | 85 | AUTH_PASSWORD_VALIDATORS = [ 86 | { 87 | 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 88 | }, 89 | { 90 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 91 | }, 92 | { 93 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', 94 | }, 95 | { 96 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', 97 | }, 98 | ] 99 | 100 | 101 | # Internationalization 102 | # https://docs.djangoproject.com/en/dev/topics/i18n/ 103 | 104 | LANGUAGE_CODE = 'en-us' 105 | 106 | TIME_ZONE = 'UTC' 107 | 108 | USE_I18N = True 109 | 110 | USE_L10N = True 111 | 112 | USE_TZ = True 113 | 114 | 115 | # Static files (CSS, JavaScript, Images) 116 | # https://docs.djangoproject.com/en/dev/howto/static-files/ 117 | 118 | 
STATIC_URL = '/static/' 119 | -------------------------------------------------------------------------------- /port/urls.py: -------------------------------------------------------------------------------- 1 | """pydemo URL Configuration 2 | 3 | The `urlpatterns` list routes URLs to views. For more information please see: 4 | https://docs.djangoproject.com/en/dev/topics/http/urls/ 5 | Examples: 6 | Function views 7 | 1. Add an import: from my_app import views 8 | 2. Add a URL to urlpatterns: path('', views.home, name='home') 9 | Class-based views 10 | 1. Add an import: from other_app.views import Home 11 | 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') 12 | Including another URLconf 13 | 1. Import the include() function: from django.urls import include, path 14 | 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) 15 | """ 16 | from django.contrib import admin 17 | from django.urls import path 18 | from . import view 19 | 20 | urlpatterns = [ 21 | path("ner", view.get_NER_result), 22 | path('admin/', admin.site.urls), 23 | ] -------------------------------------------------------------------------------- /port/view.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | import os 4 | import sys 5 | 6 | from django.http import HttpResponse 7 | import json 8 | from utils.datautils import json_dump 9 | from configs.confighelper import config_loader, args_parser 10 | from dataset.preprocess import CCKS2019NER, CCKS2017NER 11 | from dataset.conll import conll_to_train_test_dev 12 | from dataset.processor import CCKS2019Processor, CCKS2017Processor 13 | from train.trainer import Trainer 14 | from train.eval import Predictor 15 | from utils.datautils import check_dir 16 | 17 | dataset_name_to_class = { 18 | 'CCKS2019': (CCKS2019NER, CCKS2019Processor, './configs/ccks2019.yml'), 19 | 'CCKS2017': (CCKS2017NER, CCKS2017Processor, './configs/ccks2017.yml') 20 | } 21 | 22 | def get_NER_result(request): 23 | """ 24 | request['data']: json 25 | example:[{"sentence":"入院后完善相关辅助检查,给予口服活血止痛、调节血压药物及物理治疗,患者血脂异常,补充诊断:混合性高脂血症,给予调节血脂药物治疗;患者诉心慌、无力,急查心电图提示:心房颤动,ST段改变。急请内科会诊,考虑为:1.冠心病 不稳定型心绞痛 心律失常 室性期前收缩 房性期前收缩 心房颤动;2.高血压病3级 极高危组。给予处理:1.急查心肌酶学、离子,定期复查心电图;2.给予持续心电、血压、血氧监测3.给予吸氧、西地兰0.2mg加5%葡萄糖注射液15ml稀释后缓慢静推,给予硝酸甘油10mg加入5%葡萄糖注射液500ml以5~10ugmin缓慢静点,继续口服阿司匹林100mg日一次,辛伐他汀20mg日一次,硝酸异山梨酯10mg日三次口服,稳心颗粒1袋日三次,美托洛尔12.5mg日二次,非洛地平5mg日一次治疗,患者病情好转出院。","model_class":["BertBiLSTMCRF"],"dataset": "CCKS2019"}] 26 | """ 27 | if request.method == 'POST': 28 | json_data = json.loads(request.POST['data'], encoding = 'utf-8'); 29 | sentence = json_data[0]['sentence'] 30 | model_class = json_data[0]['model_class'] 31 | dataset = json_data[0]['dataset'] 32 | 33 | data_vocab_class, processor_class, conll_config_path = dataset_name_to_class[dataset] 34 | 35 | configs = config_loader('./configs/config.yml') 36 | configs['finetune_model_dir'] = os.path.join(configs['finetune_model_dir'], dataset.lower()) 37 | configs['output_dir'] = os.path.join(configs['output_dir'], dataset.lower()) 38 | 39 | result = {} 40 | 41 | processor = processor_class() 42 | for model_class in model_class: 43 | print('%s Model Outputs:') 44 | predicter = Predictor(configs, model_class, processor) 45 | entities_, result_ = predicter.predict_one(sentence) 46 | print(entities_) 47 | 48 | result[model_class] = result_ 49 | 50 | output = json_dump(result) 51 | 52 | return HttpResponse(json.dumps({ 53 | "status": 200, 54 | "errMsg": "", 55 | "data": output 56 | })) 57 | else: 58 | 
return HttpResponse(json.dumps({ 59 | "status": 400, 60 | "errMsg": "ValueError", 61 | "data": "" 62 | })) -------------------------------------------------------------------------------- /port/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for pydemo project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/dev/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'port.settings') 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.16.4 2 | gensim==3.8.0 3 | pytorch-transformers==1.1.0 4 | torch==1.1.0 5 | TorchSnooper==0.7 6 | Django==2.0.5 7 | scikit-learn==0.21.3 8 | tqdm==4.23.4 9 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python main.py --task train --config_path ./configs/bilstm_config.yml 2 | python main.py --task train --config_path ./configs/bert_config.yml -------------------------------------------------------------------------------- /train/device.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import torch 5 | import pynvml 6 | 7 | def check_cuda(configs): 8 | 9 | if configs['use_cuda'] and torch.cuda.is_available(): 10 | device = torch.device("cuda", torch.cuda.current_device()) 11 | use_gpu = True 12 | else: 13 | device = torch.device("cpu") 14 | use_gpu = False 15 | return device, use_gpu 16 | 17 | def gpu_memory_occupancy(configs): 18 | 19 | total_memory = configs['gpu_memory'] 20 | pynvml.nvmlInit() 21 | # 这里的0是GPU id 22 | handle = pynvml.nvmlDeviceGetHandleByIndex(0) 23 | meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) 24 | return '%.3fMiB//%dMiB' %(float(meminfo.used/1024/1024), total_memory) -------------------------------------------------------------------------------- /train/eval.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import os 5 | import torch 6 | from pytorch_transformers.tokenization_bert import BertTokenizer 7 | from model.bert import BertBiLSTMCRF 8 | from train.pretrain import load_pretrain 9 | from dataset.utils import convert_examples_to_features, convert_features_to_dataloader 10 | from train.device import check_cuda 11 | from sklearn.metrics import f1_score, classification_report 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | # import torchsnooper 16 | 17 | class Predictor(object): 18 | 19 | def __init__(self, configs, model_class, processor): 20 | 21 | self.configs = configs 22 | self.processor = processor 23 | self.label_list = self.processor.get_labels() 24 | self.label2entities = self.processor.get_labels_to_entities() 25 | self.device, self.use_cuda = check_cuda(configs) 26 | self.fine_tune_dir = os.path.join(self.configs['finetune_model_dir'], model_class) 27 | self.model, self.tokenizer = load_pretrain(configs, model_class, self.fine_tune_dir, processor, eval = True) 28 | 29 | if self.use_cuda: 30 | self.model = self.model.cuda() 31 | 32 | self.entities_list = 
set(label.split('-')[1] for label in self.label_list if label != 'O') 33 | 34 | @staticmethod 35 | def class_report(y_pred, y_true): 36 | y_true = y_true.numpy() 37 | y_pred = y_pred.numpy() 38 | classify_report = classification_report(y_true, y_pred) 39 | print('\n\nclassify_report:\n', classify_report) 40 | 41 | def eval(self): 42 | 43 | test_examples = self.processor.get_test_examples(self.configs['data_dir'], tokenizer = self.tokenizer) 44 | test_features = convert_examples_to_features(examples = test_examples, 45 | max_seq_length = self.configs['max_seq_length'], 46 | tokenizer = self.tokenizer, 47 | label_list = self.label_list) 48 | self.test_dataloader = convert_features_to_dataloader(test_features, batch_size = self.configs['batch_size']) 49 | 50 | self.model.eval() 51 | count = 0 52 | y_preds, y_labels = [], [] 53 | 54 | # with torchsnooper.snoop(): 55 | with torch.no_grad(): 56 | for batch in tqdm(self.test_dataloader, ncols=75): 57 | input_ids, segment_ids, input_mask, label_ids = tuple(t.to(self.device) for t in batch) 58 | feats = self.model(input_ids, segment_ids, input_mask) 59 | predicts = self.model.predict(feats, input_mask) 60 | y_preds.append(predicts) 61 | y_labels.append(label_ids) 62 | 63 | self.y_preds = y_preds 64 | self.y_labels = y_labels 65 | 66 | eval_predict = torch.cat(y_preds, dim = 0).view(-1).cpu() 67 | eval_label = torch.cat(y_labels, dim = 0).view(-1).cpu() 68 | self.class_report(eval_predict, eval_label) 69 | 70 | @staticmethod 71 | def convert_ids_to_labels(label_ids, label_list): 72 | id_label_map = {i: label for i, label in enumerate(label_list)} 73 | return [id_label_map[i] for i in label_ids] 74 | 75 | @staticmethod 76 | def get_entity(tag_seq, char_seq, entity): 77 | length = len(char_seq) 78 | entities = [] 79 | begin_tag = 'B-%s' %entity 80 | inter_tag = 'I-%s' %entity 81 | entity_ = [] 82 | for i, (char, tag) in enumerate(zip(char_seq, tag_seq)): 83 | if tag == begin_tag or tag == inter_tag: 84 | entity_.append(char) 85 | if (i+1 == length) or tag_seq[i+1] == 'O': # check for sequence end first to avoid an IndexError on the last token 86 | entities.append(''.join(entity_)) 87 | entity_= [] 88 | else: 89 | continue 90 | return entities 91 | 92 | def predict_one(self, sentence, special_tokens = ['[CLS]', '[PAD]', '[SEP]']): 93 | 94 | predict_examples = self.processor.create_examples_from_zhsentence(sentence, self.tokenizer) 95 | predict_features = convert_examples_to_features(examples = predict_examples, 96 | max_seq_length = self.configs['max_seq_length'], 97 | tokenizer = self.tokenizer, 98 | label_list = self.label_list) 99 | input_ids = torch.tensor([f.input_ids for f in predict_features], dtype=torch.long) 100 | segment_ids = torch.tensor([f.segment_ids for f in predict_features], dtype=torch.long) 101 | input_mask = torch.tensor([f.input_mask for f in predict_features], dtype=torch.long) 102 | label_ids = torch.tensor([f.label_ids for f in predict_features], dtype=torch.long) 103 | if self.use_cuda: 104 | input_ids = input_ids.cuda() 105 | segment_ids = segment_ids.cuda() 106 | input_mask = input_mask.cuda() 107 | label_ids = label_ids.cuda() 108 | with torch.no_grad(): 109 | feats = self.model(input_ids, segment_ids, input_mask) 110 | predicts = self.model.predict(feats, input_mask) 111 | 112 | tokens = self.tokenizer.convert_ids_to_tokens(input_ids.contiguous().view(-1).cpu().numpy()) 113 | preds = self.convert_ids_to_labels(predicts.contiguous().view(-1).cpu().numpy(), self.label_list) 114 | 115 | result = [(tok, pred) for (tok, pred) in zip(tokens, preds) if tok not in special_tokens] 116 | 117 | entities 
= {} 118 | for tag in self.entities_list: 119 | entities[self.label2entities[tag]] = self.get_entity(preds, tokens, tag) 120 | return entities, result -------------------------------------------------------------------------------- /train/optimizer.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule 5 | import torch.optim as optim 6 | 7 | def load_optimizer(configs, model): 8 | 9 | # Prepare optimizer 10 | optimizer_type = configs['optimizer_type'] 11 | 12 | lr_decay = configs['lr_decay'] 13 | learning_rate = configs['learning_rate'] 14 | momentum = configs['momentum'] 15 | l2_rate = configs['l2_rate'] 16 | 17 | if optimizer_type.lower() == 'adamw': 18 | param_optimizer = list(model.named_parameters()) 19 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 20 | optimizer_grouped_parameters = [ 21 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': configs['weight_decay']}, 22 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 23 | ] 24 | optimizer = AdamW(optimizer_grouped_parameters, 25 | lr=configs['learning_rate'], 26 | correct_bias = False) 27 | else: 28 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 29 | if optimizer_type.lower() == "sgd": 30 | optimizer = optim.SGD(parameters, lr=learning_rate, momentum=momentum, weight_decay=l2_rate) 31 | elif optimizer_type.lower() == "adagrad": 32 | optimizer = optim.Adagrad(parameters, lr = learning_rate, weight_decay=l2_rate) 33 | elif optimizer_type.lower() == "adadelta": 34 | optimizer = optim.Adadelta(parameters, lr=learning_rate, weight_decay=l2_rate) 35 | elif optimizer_type.lower() == "rmsprop": 36 | optimizer = optim.RMSprop(parameters, lr=learning_rate, weight_decay=l2_rate) 37 | elif optimizer_type.lower() == "adam": 38 | optimizer = optim.Adam(parameters, lr=learning_rate, weight_decay=l2_rate) 39 | else: 40 | print('请选择正确的optimizer: {0}'.format(optimizer_type)) 41 | 42 | # warmup_proportion:warm up 步数的比例,比如说总共学习100步, 43 | # warmup_proportion=0.1表示前10步用来warm up,warm up时以较低的学习率进行学习 44 | # (lr = global_step/num_warmup_steps * init_lr),10步之后以正常(或衰减)的学习率来学习。 45 | schedular = WarmupLinearSchedule(optimizer, 46 | warmup_steps = int(configs['num_train_steps'] * configs['warmup_proportion']), 47 | t_total = configs['num_train_steps']) 48 | return optimizer, schedular -------------------------------------------------------------------------------- /train/plot.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import os 5 | 6 | import matplotlib.pyplot as plt 7 | import seaborn as sns 8 | color = sns.color_palette() 9 | # %matplotlib inline 10 | sns.set_style("whitegrid") 11 | 12 | def eval_plot(configs, train_loss_log, dev_loss_log): 13 | 14 | plt.figure(figsize=(12,7)) 15 | epochs = list(range(len(train_loss_log))) 16 | 17 | plt.plot(epochs, train_loss_log, color = 'r') 18 | plt.plot(epochs, dev_loss_log, color = 'b') 19 | plt.xlabel('Epochs', fontsize = 15) 20 | plt.ylabel('Loss', fontsize = 15) 21 | plt.savefig(os.path.join(configs['test']['output_dir'], 'loss_eval.png')) -------------------------------------------------------------------------------- /train/pretrain.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import os 5 | import torch 6 | 
import re 7 | 8 | from utils.datautils import check_dir 9 | from model.bert import Bert, BertCRF, BertBiLSTMCRF 10 | from model.bilstm import BiLSTM, BiLSTMCRF 11 | from pytorch_transformers.modeling_bert import BertConfig 12 | from pytorch_transformers.tokenization_bert import BertTokenizer 13 | from dataset.embedding import build_word_embed 14 | 15 | def load_pretrain(configs, model_class, fine_tune_dir, processor, eval = False): 16 | """ 17 | configs: 配置文件 18 | model_class: 模型名称 19 | fine_tune_dir: 微调模型保存路径 20 | processor: DataProcessor 21 | eval: 是否验证 22 | """ 23 | 24 | model_class_map = { 25 | 'Bert': Bert, 26 | 'BertCRF': BertCRF, 27 | 'BertBiLSTMCRF': BertBiLSTMCRF, 28 | 'BiLSTM': BiLSTM, 29 | 'BiLSTMCRF': BiLSTMCRF 30 | } 31 | model_class_ = model_class_map[model_class] 32 | label_list = processor.get_labels() 33 | 34 | check_dir(fine_tune_dir) 35 | if eval: 36 | model_pretrained_path = fine_tune_dir 37 | else: 38 | model_pretrained_path = configs['pretrained_model_dir'] 39 | tokenizer = BertTokenizer.from_pretrained(model_pretrained_path, do_lower_case = configs['lower_case']) 40 | 41 | if model_class in ['Bert', 'BertCRF', 'BertBiLSTMCRF']: 42 | bert_config = BertConfig.from_pretrained(model_pretrained_path, 43 | num_labels = len(label_list), 44 | finetuning_task="ner") 45 | model = model_class_.from_pretrained(model_pretrained_path, config = bert_config, model_configs = configs) 46 | 47 | elif model_class in ['BiLSTM', 'BiLSTMCRF']: 48 | configs['num_labels'] = len(label_list) 49 | if configs['use_pretrained_embedding']: 50 | pretrained_word_embed = build_word_embed(tokenizer, 51 | pretrain_embed_file = configs['pretrain_embed_file'], 52 | pretrain_embed_pkl = configs['pretrain_embed_pkl']) 53 | configs['word_vocab_size'] = pretrained_word_embed.shape[0] 54 | configs['word_embedding_dim'] = pretrained_word_embed.shape[1] 55 | else: 56 | pretrained_word_embed = None 57 | if eval: 58 | model_pretrained_path = fine_tune_dir 59 | model = model_class_.from_pretrained(model_pretrained_path, pretrained_word_embed) 60 | else: 61 | model = model_class_(configs, pretrained_word_embed) 62 | else: 63 | raise ValueError("Invalid Model Class") 64 | return model, tokenizer -------------------------------------------------------------------------------- /train/trainer.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import os 5 | import sys 6 | import torch 7 | 8 | import logging 9 | logging.basicConfig() 10 | logger = logging.getLogger(__name__) 11 | 12 | import random 13 | import numpy as np 14 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler 15 | 16 | from train.device import check_cuda 17 | from train.pretrain import load_pretrain 18 | from train.optimizer import load_optimizer 19 | from train.plot import eval_plot 20 | from dataset.utils import convert_examples_to_features, convert_features_to_dataloader 21 | from utils.datautils import check_dir 22 | 23 | # import torchsnooper 24 | 25 | class Trainer(object): 26 | 27 | def __init__(self, configs, model_class, processor): 28 | """ 29 | configs: 配置 30 | model_clas: str, 模型名称, 'BertBiLSTMCRF' 31 | processor: DataProcessor 32 | """ 33 | self.model_class = model_class 34 | self.device, use_gpu = check_cuda(configs) 35 | self.configs = configs 36 | self.fine_tune_dir = os.path.join(self.configs['finetune_model_dir'], model_class) 37 | self.model, self.tokenizer = load_pretrain(configs, model_class, self.fine_tune_dir, processor, eval = False) 38 | 
self.model.to(self.device) 39 | 40 | self.configs = configs 41 | self.batch_size = configs['batch_size'] 42 | self.nb_epoch = configs['nb_epoch'] 43 | self.max_seq_length = configs['max_seq_length'] 44 | # 设置随机数 45 | self.random_seed = configs['random_seed'] 46 | self.set_seed(use_gpu) 47 | train_examples = processor.get_train_examples(configs['data_dir'], tokenizer = self.tokenizer) 48 | dev_examples = processor.get_dev_examples(configs['data_dir'], tokenizer = self.tokenizer) 49 | 50 | self.configs['num_train_steps'] = int(len(train_examples)/self.batch_size) * self.nb_epoch 51 | self.optimizer, self.scheduler = load_optimizer(self.configs, self.model) 52 | self.max_grad_norm = configs['max_grad_norm'] 53 | self.label_list = processor.get_labels() 54 | 55 | train_features = convert_examples_to_features(examples = train_examples, 56 | max_seq_length = self.max_seq_length, 57 | tokenizer = self.tokenizer, 58 | label_list = self.label_list) 59 | dev_features = convert_examples_to_features(examples = dev_examples, 60 | max_seq_length = self.max_seq_length, 61 | tokenizer = self.tokenizer, 62 | label_list = self.label_list) 63 | self.train_count = len(train_examples) 64 | self.train_dataloader = convert_features_to_dataloader(train_features, batch_size = configs['batch_size']) 65 | self.dev_dataloader = convert_features_to_dataloader(dev_features, batch_size = configs['batch_size']) 66 | self.max_patience = configs['max_patience'] 67 | 68 | def set_seed(self, use_gpu): 69 | random.seed(self.random_seed) 70 | np.random.seed(self.random_seed) 71 | torch.manual_seed(self.random_seed) 72 | if use_gpu: 73 | torch.cuda.manual_seed(self.random_seed) 74 | torch.cuda.manual_seed_all(self.random_seed) 75 | 76 | def train(self): 77 | 78 | best_dev_loss = 1.e8 79 | current_patience = 0 80 | for epoch in range(self.nb_epoch): 81 | 82 | torch.cuda.empty_cache() 83 | 84 | train_loss, dev_loss = 0., 0. 
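# Remainder of each epoch: one pass over train_dataloader (forward -> loss_fn -> backward -> gradient clipping -> optimizer.step() -> scheduler.step()), then a no-grad pass over dev_dataloader; the dev loss drives early stopping (max_patience) and triggers save_model() only when it improves.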
85 | train_loss_log, dev_loss_log = [], [] 86 | self.model.train() 87 | iter_variable = 0 88 | # with torchsnooper.snoop(): 89 | for i, train_batch in enumerate(self.train_dataloader): 90 | 91 | torch.cuda.empty_cache() 92 | 93 | self.optimizer.zero_grad() 94 | input_ids, segment_ids, input_mask, label_ids= tuple(t.to(self.device) for t in train_batch) 95 | feats = self.model(input_ids, segment_ids, input_mask) 96 | loss = self.model.loss_fn(feats, input_mask, label_ids) 97 | loss.backward() 98 | with torch.no_grad(): 99 | train_loss += float(loss.item()) 100 | train_loss_log.append(float(loss.item())) 101 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) 102 | self.optimizer.step() 103 | self.scheduler.step() 104 | iter_variable += self.batch_size 105 | if iter_variable > self.train_count: 106 | iter_variable = self.train_count 107 | sys.stdout.write('Epoch {0}/{1}: {2}/{3}\r'.format(epoch+1, self.nb_epoch, 108 | iter_variable, self.train_count)) 109 | # early stopping 110 | self.model.eval() 111 | with torch.no_grad(): 112 | for dev_batch in self.dev_dataloader: 113 | 114 | torch.cuda.empty_cache() 115 | 116 | input_ids, segment_ids, input_mask, label_ids = tuple(t.to(self.device) for t in dev_batch) 117 | feats = self.model(input_ids, segment_ids, input_mask) 118 | loss = self.model.loss_fn(feats, input_mask, label_ids) 119 | dev_loss += float(loss.item()) 120 | dev_loss_log.append(float(loss.item())) 121 | 122 | print('\ttrain loss: {0}, dev loss: {1}'.format(train_loss, dev_loss)) 123 | # early stopping 124 | if dev_loss < best_dev_loss: 125 | current_patience = 0 126 | best_dev_loss = dev_loss 127 | self.save_model() 128 | else: 129 | current_patience += 1 130 | if self.max_patience <= current_patience: 131 | print('finished training! (early stopping, max_patience: {0})'.format(self.max_patience)) 132 | return 133 | 134 | # eval_plot(self.configs, train_loss_log, dev_loss_log) 135 | print('finished training!') 136 | 137 | def save_model(self): 138 | self.model.save_pretrained(self.fine_tune_dir) 139 | self.tokenizer.save_pretrained(self.fine_tune_dir) -------------------------------------------------------------------------------- /utils/datautils.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import os 5 | import json 6 | import numpy as np 7 | import pickle 8 | 9 | def check_dir(dir_path): 10 | if not os.path.exists(dir_path): 11 | os.makedirs(dir_path, exist_ok = True) 12 | 13 | class MyEncoder(json.JSONEncoder): 14 | def default(self, obj): 15 | if isinstance(obj, np.integer): 16 | return int(obj) 17 | elif isinstance(obj, np.floating): 18 | return float(obj) 19 | elif isinstance(obj, np.ndarray): 20 | return obj.tolist() 21 | else: 22 | return super(MyEncoder, self).default(obj) 23 | 24 | def json_dump(dict_): 25 | return json.dumps(dict_, ensure_ascii=False, cls=MyEncoder) 26 | 27 | def json_load(file_path, encoding = 'utf-8'): 28 | with open(file_path, encoding = encoding) as json_file: 29 | return json.load(json_file) 30 | 31 | def json_write(file, file_path, encoding = 'utf-8'): 32 | with open(file_path, 'w', encoding = encoding) as json_file: 33 | json_file.write(json.dumps(file)) 34 | 35 | def pickle_save(file, file_path): 36 | with open(file_path, 'wb') as f: 37 | pickle.dump(file, f) 38 | 39 | def pickle_load(file_path): 40 | return pickle.load(open(file_path, 'rb')) --------------------------------------------------------------------------------
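
For reference, a minimal client sketch for the `ner` route wired up in `port/urls.py`: `view.get_NER_result` reads a JSON string from the POST form field `data` and replies with a JSON body holding `status`, `errMsg` and `data`; the CSRF middleware is commented out in `port/settings.py`, so no token is needed. The host/port, the `requests` package (not listed in `requirements.txt`) and the example sentence are assumptions, not part of the repository.

```python
# Hypothetical client for the NER endpoint -- host/port and the `requests`
# dependency are assumptions, not part of this repository.
import json
import requests

payload = [{
    "sentence": "患者入院后查心电图提示心房颤动,给予阿司匹林口服治疗。",  # arbitrary placeholder text
    "model_class": ["BertBiLSTMCRF"],
    "dataset": "CCKS2019",
}]

resp = requests.post(
    "http://127.0.0.1:8000/ner",  # route defined in port/urls.py
    data={"data": json.dumps(payload, ensure_ascii=False)},  # read via request.POST['data']
)
print(resp.json())  # {"status": 200, "errMsg": "", "data": "<entity dict as JSON string>"}
```

Any of the model names registered in `train/pretrain.py` (Bert, BertCRF, BertBiLSTMCRF, BiLSTM, BiLSTMCRF) can be passed in `model_class`, provided a fine-tuned checkpoint already exists under `finetune_model_dir`.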