├── LICENSE
├── README.md
├── models.py
├── utils.py
├── Preprocess.ipynb
├── BERT+Bi_LSTM+CRF.ipynb
└── main.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Xiang WU

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Chinese Medical Entity Recognition Based on BERT+Bi-LSTM+CRF

### Step 1

The dataset is shared on my Google Drive; please download the whole 'CCKS_2019_Task1' folder into the current working directory.

https://drive.google.com/drive/folders/1Z81nYCnHTvqlzQ0RnO-mFI9xRfJvCf5X?usp=sharing

(Note: there are three empty folders (data, data_test and processed_data) under the 'CCKS_2019_Task1' folder.)


### Step 2

Please open 'Preprocess.ipynb' to process the raw data.

The processed training data are saved to './CCKS_2019_Task1/data/' by default;

The processed test data are saved to './CCKS_2019_Task1/data_test/' by default.


### Step 3

Please open 'BERT+Bi_LSTM+CRF.ipynb' to run the code.

It first re-processes the preprocessed data into three '.txt' files for training, validation and testing;

The three '.txt' files are saved to './CCKS_2019_Task1/processed_data/' by default, and you can then follow the notebook cells to train and test.
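
Each line of these '.txt' files holds one character and its BIO tag, separated by a tab (the fragment below is made up purely to illustrate the format; sentences are later split on '。' when the data are loaded):

```
患	O
者	O
服	O
用	O
阿	B-DRUG
司	I-DRUG
匹	I-DRUG
林	I-DRUG
。	O
```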

### Hope you can run the code successfully and enjoy it!

You can also read a Chinese explanatory article I shared on Zhihu: https://zhuanlan.zhihu.com/p/453350271

If you have any problems, please feel free to contact me via email: xavier.wu@connect.ust.hk.

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@Author: Xavier WU
@Date: 2021-11-30
@LastEditTime: 2022-1-6
@Description: This file is for building the model.
@All Rights Reserved
'''

import torch
import torch.nn as nn
from transformers import BertModel
from torchcrf import CRF

class Bert_BiLSTM_CRF(nn.Module):

    def __init__(self, tag_to_ix, embedding_dim=768, hidden_dim=256):
        super(Bert_BiLSTM_CRF, self).__init__()
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        self.bert = BertModel.from_pretrained('bert-base-chinese')
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim//2,
                            num_layers=2, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(p=0.1)
        self.linear = nn.Linear(hidden_dim, self.tagset_size)
        self.crf = CRF(self.tagset_size, batch_first=True)

    def _get_features(self, sentence):
        # BERT is used as a frozen feature extractor here; with transformers 3.x it returns a
        # tuple whose first element is the last hidden state.
        with torch.no_grad():
            embeds, _ = self.bert(sentence)
        enc, _ = self.lstm(embeds)
        enc = self.dropout(enc)
        feats = self.linear(enc)
        return feats

    def forward(self, sentence, tags, mask, is_test=False):
        emissions = self._get_features(sentence)
        if not is_test:  # Training: return the CRF loss
            loss = -self.crf.forward(emissions, tags, mask, reduction='mean')
            return loss
        else:  # Testing: return the decoded tag sequences
            decode = self.crf.decode(emissions, mask)
            return decode
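

if __name__ == '__main__':
    # Minimal usage sketch, added for illustration only; the toy tag map and random ids below
    # are made up just to show the expected tensor shapes, and running it will download
    # 'bert-base-chinese' if it is not already cached.
    toy_tag_to_ix = {'<PAD>': 0, '[CLS]': 1, '[SEP]': 2, 'O': 3, 'B-DRUG': 4, 'I-DRUG': 5}
    model = Bert_BiLSTM_CRF(toy_tag_to_ix)
    ids = torch.randint(1, 21128, (2, 8))                 # fake token ids, shape (batch, seq_len)
    tags = torch.randint(3, len(toy_tag_to_ix), (2, 8))   # fake gold tag ids
    mask = torch.ones(2, 8, dtype=torch.bool)
    print('loss  :', model(ids, tags, mask).item())            # training mode -> scalar CRF loss
    print('decode:', model(ids, tags, mask, is_test=True))     # test mode -> best tag-id paths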

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@Author: Xavier WU
@Date: 2021-11-30
@LastEditTime: 2022-1-6
@Description: This file is for implementing the Dataset.
@All Rights Reserved
'''

import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer

bert_model = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(bert_model)
# Index 0 (the empty tag) is reserved for padding.
VOCAB = ('', '[CLS]', '[SEP]', 'O', 'B-BODY', 'I-TEST', 'I-EXAMINATIONS',
         'I-TREATMENT', 'B-DRUG', 'B-TREATMENT', 'I-DISEASES', 'B-EXAMINATIONS',
         'I-BODY', 'B-TEST', 'B-DISEASES', 'I-DRUG')

tag2idx = {tag: idx for idx, tag in enumerate(VOCAB)}
idx2tag = {idx: tag for idx, tag in enumerate(VOCAB)}
MAX_LEN = 256 - 2  # leave room for [CLS] and [SEP]

class NerDataset(Dataset):
    ''' Generate our dataset '''
    def __init__(self, f_path):
        self.sents = []
        self.tags_li = []

        with open(f_path, 'r', encoding='utf-8') as f:
            lines = [line.split('\n')[0] for line in f.readlines() if len(line.strip()) != 0]

        tags = [line.split('\t')[1] for line in lines]
        words = [line.split('\t')[0] for line in lines]

        word, tag = [], []
        for char, t in zip(words, tags):
            if char != '。':
                word.append(char)
                tag.append(t)
            else:
                if len(word) > MAX_LEN:
                    self.sents.append(['[CLS]'] + word[:MAX_LEN] + ['[SEP]'])
                    self.tags_li.append(['[CLS]'] + tag[:MAX_LEN] + ['[SEP]'])
                else:
                    self.sents.append(['[CLS]'] + word + ['[SEP]'])
                    self.tags_li.append(['[CLS]'] + tag + ['[SEP]'])
                word, tag = [], []

    def __getitem__(self, idx):
        words, tags = self.sents[idx], self.tags_li[idx]
        token_ids = tokenizer.convert_tokens_to_ids(words)
        label_ids = [tag2idx[tag] for tag in tags]
        seqlen = len(label_ids)
        return token_ids, label_ids, seqlen

    def __len__(self):
        return len(self.sents)

def PadBatch(batch):
    maxlen = max([i[2] for i in batch])
    token_tensors = torch.LongTensor([i[0] + [0] * (maxlen - len(i[0])) for i in batch])
    label_tensors = torch.LongTensor([i[1] + [0] * (maxlen - len(i[1])) for i in batch])
    mask = (token_tensors > 0)  # padded positions (token id 0) are masked out for the CRF
    return token_tensors, label_tensors, mask
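

if __name__ == '__main__':
    # Tiny illustration of the collate function, added for clarity; the ids below are made up.
    # Each dataset item is (token_ids, label_ids, seqlen). PadBatch right-pads both id lists
    # with 0 up to the longest sequence in the batch and derives the mask from the non-zero
    # token positions.
    toy_batch = [([101, 2770, 102], [1, 3, 2], 3),
                 ([101, 102], [1, 2], 2)]
    tokens, labels, mask = PadBatch(toy_batch)
    print(tokens)   # 2 x 3 LongTensor; the shorter sequence gets a trailing 0
    print(mask)     # False only at the padded slot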

--------------------------------------------------------------------------------
/Preprocess.ipynb:
--------------------------------------------------------------------------------
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Preprocess.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZSK6_URGzKHd"
      },
      "outputs": [],
      "source": [
        "cd drive/MyDrive/"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import json\n",
        "\n",
        "FILE1 = './CCKS_2019_Task1/subtask1_training_part1.txt'\n",
        "FILE2 = './CCKS_2019_Task1/subtask1_training_part2.txt'\n",
        "FILE3 = './CCKS_2019_Task1/subtask1_test_set_with_answer.json'\n",
        "\n",
        "PATH1 = './CCKS_2019_Task1/data/data1-'\n",
        "PATH2 = './CCKS_2019_Task1/data/data2-'\n",
        "PATH3 = './CCKS_2019_Task1/data_test/data-test-'\n",
        "\n",
        "def Process_File(FILENAME, PATH, enc):\n",
        "    with open(FILENAME, 'r', encoding=enc) as f:\n",
        "        i = 0\n",
        "        while True:\n",
        "            txt = f.readline()\n",
        "            if not txt: break  # end of file\n",
        "            i += 1\n",
        "            j = json.loads(txt)\n",
        "            orig = j['originalText']  # original text\n",
        "            entities = j['entities']  # entity part\n",
        "            pathO = PATH + str(i) + '-original.txt'\n",
        "            pathE = PATH + str(i) + '.txt'\n",
        "\n",
        "            with open(pathO, 'w', encoding='utf-8') as o1:  # write the original text\n",
        "                o1.write(orig)\n",
        "                o1.flush()\n",
        "\n",
        "            with open(pathE, 'w', encoding='utf-8') as o2:  # write the entity file\n",
        "                for e in entities:\n",
        "                    start = e['start_pos']  # extract start position\n",
        "                    end = e['end_pos']  # extract end position\n",
        "                    name = orig[start:end]  # entity content\n",
        "                    ty = e['label_type']  # entity label type\n",
        "                    label = '{0}\\t{1}\\t{2}\\t{3}\\n'.format(name, start, end, ty)\n",
        "                    o2.write(label)\n",
        "                    o2.flush()"
      ],
      "metadata": {
        "id": "2XGlZ1RwzQdj"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "Process_File(FILE1, PATH1, 'utf-8-sig')\n",
        "Process_File(FILE2, PATH2, 'utf-8-sig')\n",
        "Process_File(FILE3, PATH3, 'utf-8')"
      ],
      "metadata": {
        "id": "AIChvi6azqwV"
      },
      "execution_count": 3,
      "outputs": []
    }
  ]
}
--------------------------------------------------------------------------------
/BERT+Bi_LSTM+CRF.ipynb:
--------------------------------------------------------------------------------
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "github.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O5s4Uvtx0AnL"
      },
      "outputs": [],
      "source": [
        "cd drive/MyDrive/"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import os\n",
        "from sklearn.model_selection import train_test_split"
      ],
      "metadata": {
        "id": "LqyYaqfZnIZ_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "label_dict = {'药物': 'DRUG',\n",
        "              '解剖部位': 'BODY',\n",
        "              '疾病和诊断': 'DISEASES',\n",
        "              '影像检查': 'EXAMINATIONS',\n",
        "              '实验室检验': 'TEST',\n",
        "              '手术': 'TREATMENT'}\n",
        "\n",
        "TRAIN = './CCKS_2019_Task1/processed_data/train_dataset.txt'\n",
        "VALID = './CCKS_2019_Task1/processed_data/val_dataset.txt'\n",
        "TEST = './CCKS_2019_Task1/processed_data/test_dataset.txt'\n",
        "\n",
        "def sentence2BIOlabel(sentence, label_from_file):\n",
        "    \"\"\" BIO Tagging \"\"\"\n",
        "    sentence_label = ['O'] * len(sentence)\n",
        "    if label_from_file == '':\n",
        "        return sentence_label\n",
        "\n",
        "    for line in label_from_file.split('\\n'):\n",
        "        entity_info = line.strip().split('\\t')\n",
        "        start_index = int(entity_info[1])\n",
        "        end_index = int(entity_info[2])\n",
        "        entity_label = label_dict[entity_info[3]]\n",
        "        # First character of the entity: B-xx\n",
        "        sentence_label[start_index] = 'B-' + entity_label\n",
        "        # Remaining characters: I-xx\n",
        "        for i in range(start_index + 1, end_index):\n",
        "            sentence_label[i] = 'I-' + entity_label\n",
        "    return sentence_label\n",
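        "\n",
        "# Illustrative note (added): an entity line such as '阿司匹林\\t10\\t14\\t药物' marks position 10\n",
        "# as 'B-DRUG' and positions 11-13 as 'I-DRUG'; characters outside any entity stay 'O'.\n",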
        "\n",
        "def loadRawData(fileName):\n",
        "    \"\"\" Loading raw data and tagging \"\"\"\n",
        "    sentence_list = []\n",
        "    label_list = []\n",
        "\n",
        "    for file_name in os.listdir(fileName):\n",
        "\n",
        "        if '.DS_Store' == file_name:\n",
        "            continue\n",
        "\n",
        "        if 'original' in file_name:\n",
        "            org_file = fileName + file_name\n",
        "            lab_file = fileName + file_name.replace('-original', '')\n",
        "\n",
        "            with open(org_file, encoding='utf-8') as f:\n",
        "                content = f.read().strip()\n",
        "\n",
        "            with open(lab_file, encoding='utf-8') as f:\n",
        "                content_label = f.read().strip()\n",
        "\n",
        "            sentence_label = sentence2BIOlabel(content, content_label)\n",
        "            sentence_list.append(content)\n",
        "            label_list.append(sentence_label)\n",
        "\n",
        "    return sentence_list, label_list\n",
        "\n",
        "def Save_data(filename, texts, tags):\n",
        "    \"\"\" Write characters and tags to a file in the needed format \"\"\"\n",
        "    with open(filename, 'w') as f:\n",
        "        for sent, tag in zip(texts, tags):\n",
        "            size = len(sent)\n",
        "            for i in range(size):\n",
        "                f.write(sent[i])\n",
        "                f.write('\\t')\n",
        "                f.write(tag[i])\n",
        "                f.write('\\n')"
      ],
      "metadata": {
        "id": "1yw17zB0nmzb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Training data\n",
        "sentence_list, label_list = loadRawData('./CCKS_2019_Task1/data/')\n",
        "# Test data\n",
        "sentence_list_test, label_list_test = loadRawData('./CCKS_2019_Task1/data_test/')\n",
        "\n",
        "# Split dataset\n",
        "words = [list(sent) for sent in sentence_list]\n",
        "t_words = [list(sent) for sent in sentence_list_test]\n",
        "tags = label_list\n",
        "t_tags = label_list_test\n",
        "train_texts, val_texts, train_tags, val_tags = train_test_split(words, tags, test_size=.2)\n",
        "test_texts, test_tags = t_words, t_tags\n",
        "\n",
        "# Obtain the training, validation and test files\n",
        "Save_data(TRAIN, train_texts, train_tags)\n",
        "Save_data(VALID, val_texts, val_tags)\n",
        "Save_data(TEST, test_texts, test_tags)"
      ],
      "metadata": {
        "id": "fGZUEmdBo0BK"
      },
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "pip install transformers==3.4"
      ],
      "metadata": {
        "id": "EvcXRCPxo3Ln"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "pip install pytorch-crf"
      ],
      "metadata": {
        "id": "kSi5VYRNo3Nx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!python main.py --n_epochs 30"
      ],
      "metadata": {
        "id": "FonoUHo1o3Pf"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@Author: Xavier WU
@Date: 2021-11-30
@LastEditTime: 2022-1-6
@Description: This file is for training, validating and testing.
@All Rights Reserved
'''

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
import os
import warnings
import argparse
import numpy as np
from sklearn import metrics
from models import Bert_BiLSTM_CRF
from transformers import AdamW, get_linear_schedule_with_warmup
from utils import NerDataset, PadBatch, VOCAB, tokenizer, tag2idx, idx2tag

warnings.filterwarnings("ignore", category=DeprecationWarning)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def train(e, model, iterator, optimizer, scheduler, device):
    model.train()
    losses = 0.0
    step = 0
    for i, batch in enumerate(iterator):
        step += 1
        x, y, z = batch
        x = x.to(device)
        y = y.to(device)
        z = z.to(device)

        loss = model(x, y, z)
        losses += loss.item()
        """ Gradient Accumulation """
        '''
        full_loss = loss / 2              # normalize loss
        full_loss.backward()              # backward and accumulate gradients
        if step % 2 == 0:
            optimizer.step()              # update parameters
            scheduler.step()              # update the learning-rate schedule
            optimizer.zero_grad()         # clear gradients
        '''
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    print("Epoch: {}, Loss:{:.4f}".format(e, losses/step))

def validate(e, model, iterator, device):
    model.eval()
    Y, Y_hat = [], []
    losses = 0
    step = 0
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            step += 1

            x, y, z = batch
            x = x.to(device)
            y = y.to(device)
            z = z.to(device)

            y_hat = model(x, y, z, is_test=True)

            loss = model(x, y, z)
            losses += loss.item()
            # Save predictions
            for j in y_hat:
                Y_hat.extend(j)
            # Save gold labels (unmasked positions only)
            mask = (z == 1)
            y_orig = torch.masked_select(y, mask)
            Y.append(y_orig.cpu())
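            # Note (added): crf.decode already skips the masked (padded) positions, so the
            # flattened Y_hat lines up one-to-one with the labels gathered by masked_select above.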

    Y = torch.cat(Y, dim=0).numpy()
    Y_hat = np.array(Y_hat)
    acc = (Y_hat == Y).mean()*100

    print("Epoch: {}, Val Loss:{:.4f}, Val Acc:{:.3f}%".format(e, losses/step, acc))
    return model, losses/step, acc

def test(model, iterator, device):
    model.eval()
    Y, Y_hat = [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            x, y, z = batch
            x = x.to(device)
            z = z.to(device)
            y_hat = model(x, y, z, is_test=True)
            # Save predictions
            for j in y_hat:
                Y_hat.extend(j)
            # Save gold labels (unmasked positions only)
            mask = (z == 1).cpu()
            y_orig = torch.masked_select(y, mask)
            Y.append(y_orig)

    Y = torch.cat(Y, dim=0).numpy()
    y_true = [idx2tag[i] for i in Y]
    y_pred = [idx2tag[i] for i in Y_hat]

    return y_true, y_pred

if __name__ == "__main__":

    labels = ['B-BODY',
              'B-DISEASES',
              'B-DRUG',
              'B-EXAMINATIONS',
              'B-TEST',
              'B-TREATMENT',
              'I-BODY',
              'I-DISEASES',
              'I-DRUG',
              'I-EXAMINATIONS',
              'I-TEST',
              'I-TREATMENT']

    best_model = None
    _best_val_loss = 1e18
    _best_val_acc = 1e-18

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--n_epochs", type=int, default=40)
    parser.add_argument("--trainset", type=str, default="./CCKS_2019_Task1/processed_data/train_dataset.txt")
    parser.add_argument("--validset", type=str, default="./CCKS_2019_Task1/processed_data/val_dataset.txt")
    parser.add_argument("--testset", type=str, default="./CCKS_2019_Task1/processed_data/test_dataset.txt")

    ner = parser.parse_args()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Bert_BiLSTM_CRF(tag2idx).to(device)

    print('Initial model Done.')
    train_dataset = NerDataset(ner.trainset)
    eval_dataset = NerDataset(ner.validset)
    test_dataset = NerDataset(ner.testset)
    print('Load Data Done.')

    train_iter = data.DataLoader(dataset=train_dataset,
                                 batch_size=ner.batch_size,
                                 shuffle=True,
                                 num_workers=4,
                                 collate_fn=PadBatch)

    eval_iter = data.DataLoader(dataset=eval_dataset,
                                batch_size=(ner.batch_size)//2,
                                shuffle=False,
                                num_workers=4,
                                collate_fn=PadBatch)

    test_iter = data.DataLoader(dataset=test_dataset,
                                batch_size=(ner.batch_size)//2,
                                shuffle=False,
                                num_workers=4,
                                collate_fn=PadBatch)

    # optimizer = optim.Adam(model.parameters(), lr=ner.lr, weight_decay=0.01)
    optimizer = AdamW(model.parameters(), lr=ner.lr, eps=1e-6)

    # Warmup
    len_dataset = len(train_dataset)
    epoch = ner.n_epochs
    batch_size = ner.batch_size
    total_steps = (len_dataset // batch_size) * epoch if len_dataset % batch_size == 0 else (len_dataset // batch_size + 1) * epoch

    warm_up_ratio = 0.1  # use 10% of the training steps for warmup
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warm_up_ratio * total_steps, num_training_steps=total_steps)
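    # For example (illustrative numbers): with 1,000 training sentences and batch_size 64 there
    # are ceil(1000 / 64) = 16 steps per epoch, so 40 epochs give total_steps = 640 and roughly
    # 64 warmup steps before the learning rate starts its linear decay.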

    print('Start Train...')
    for epoch in range(1, ner.n_epochs+1):

        train(epoch, model, train_iter, optimizer, scheduler, device)
        candidate_model, loss, acc = validate(epoch, model, eval_iter, device)

        if loss < _best_val_loss and acc > _best_val_acc:
            # Note: validate() returns the live model object, so best_model always points at the
            # current weights; deep-copy the model or save its state_dict here if you need to
            # keep the genuinely best epoch.
            best_model = candidate_model
            _best_val_loss = loss
            _best_val_acc = acc

        print("=============================================")

    y_test, y_pred = test(best_model, test_iter, device)
    print(metrics.classification_report(y_test, y_pred, labels=labels, digits=3))

--------------------------------------------------------------------------------