├── requirements.txt ├── data └── README.md ├── README.md ├── text-classifizer.py ├── sequence-label.py ├── process.py ├── utils.py ├── .gitignore ├── config.py ├── test.py ├── LICENSE ├── trainer.py └── model.py /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wireless911/bert-text/HEAD/requirements.txt -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data directory 2 | Text multi-class classification 3 | 4 | ``` 5 | CSV files 6 | ···· data/ 7 | text-classifizer/ 8 | train.csv 9 | eval.csv 10 | Format 11 | label,text 12 | 0,这个东西真不错 13 | 1,用了再来评,感觉一般 14 | 2,这个产品一点儿都不好用 15 | 16 | ``` 17 | 18 | Sequence labeling task 19 | 20 | ``` 21 | CSV files 22 | ···· data/ 23 | squence-label/ 24 | train.csv 25 | eval.csv 26 | 27 | Format 28 | text,label 29 | 小 明 是 个 好 人,B-person I-person O O O O 30 | ``` 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BERT text multi-class classification and BERT-BiLSTM-CRF sequence labeling 2 | Chinese sentiment classification based on BERT 3 | Chinese sequence labeling based on BERT, LSTM, and CRF 4 | 5 | #### Text classification 6 | bert-dense 7 | ``` 8 | The classification here is multi-class; for a binary task you can substitute your own loss function 9 | ``` 10 | 11 | 12 | #### Sequence labeling 13 | bert-bilstm-crf sequence labeling task 14 | 15 | Fine-tunes a BERT model in PyTorch for downstream classification and sequence labeling tasks. 16 | The BERT module uses the third-party library [transformers](https://huggingface.co/transformers/) released by Hugging Face; 17 | the CRF module is adapted from [pytorch-crf](https://pytorch-crf.readthedocs.io/en/stable/), with small modifications that make accuracy easier to compute 18 | 19 | 20 | ##### Environment setup 21 | ``` 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | ##### Training configuration 26 | ``` 27 | config.py 28 | 29 | For sequence labeling, define your own tag set in SequenceLabelConfig.TAG_TO_ID 30 | ``` 31 | 32 | 33 | ##### Data preparation 34 | ``` 35 | See the data/README.md file 36 | 37 | ``` 38 | 39 | ##### Training 40 | ``` 41 | Text classification: python text-classifizer.py 42 | Sequence labeling: python sequence-label.py 43 | ``` 44 | 45 | ##### Viewing training logs (TensorBoard) 46 | ``` 47 | tensorboard --logdir=logs 48 | ``` 49 | 50 | 51 | ##### Training and validation results: 52 | | | training_acc | training_loss | eval_acc | eval_loss | 53 | | ---- | ---- | ----| ----| ----| 54 | | Text classification | 0.9766 | 0.07909 | 0.9922 | 0.0868 | 55 | | Sequence labeling | 0.9838 | 19.706 | 0.9175 | 38.77 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /text-classifizer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from typing import Optional, Text 4 | from torch.utils.data import DataLoader 5 | from transformers import BertTokenizer 6 | from utils import CustomTextClassifizerDataset 7 | from trainer import TextClassifizerTrainer 8 | from model import TextClassificationModel 9 | from config import TextClassifizerConfig 10 | 11 | # load config from object 12 | config = TextClassifizerConfig() 13 | 14 | # tokenizer 15 | tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') 16 | 17 | # load dataset 18 | train_datasets = CustomTextClassifizerDataset(config.train_data, tokenizer, config.max_sequence_length) 19 | eval_datasets = CustomTextClassifizerDataset(config.eval_data, tokenizer, config.max_sequence_length) 20 | 21 | # dataloader 22 | train_dataloader = DataLoader(train_datasets, batch_size=config.batch_size, shuffle=True) 23 | eval_dataloader = DataLoader(eval_datasets, batch_size=config.batch_size, shuffle=True) 24 | 25 | # create model 26 | model = TextClassificationModel(config.max_sequence_length, 
config.num_classes) 27 | # create trainer 28 | trainer = TextClassifizerTrainer( 29 | model=model, 30 | args=None, 31 | train_dataloader=train_dataloader, 32 | eval_dataloader=eval_dataloader, 33 | epochs=config.epochs, 34 | learning_rate=config.learning_rate, 35 | device=config.device 36 | 37 | ) 38 | 39 | # train model 40 | trainer.train() 41 | -------------------------------------------------------------------------------- /sequence-label.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import time 3 | import torch 4 | from typing import Optional, Text 5 | from torch.utils.data import DataLoader 6 | from torch.utils.tensorboard import SummaryWriter 7 | from transformers import BertTokenizer 8 | 9 | from trainer import SequenceLabelTrainer 10 | from config import SequenceLabelConfig 11 | from utils import CustomSequenceLabelDataset 12 | from model import BiLSTM_CRF 13 | 14 | config = SequenceLabelConfig() 15 | tag_to_ix = SequenceLabelConfig.TAG_TO_ID 16 | # tokenizer 17 | tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') 18 | 19 | train_datasets = CustomSequenceLabelDataset(config.train_data, tokenizer, config) 20 | eval_datasets = CustomSequenceLabelDataset(config.eval_data, tokenizer, config) 21 | 22 | # create model 23 | model = BiLSTM_CRF(tag_to_ix, config.max_sequence_length, config.hidden_dim,config.device) 24 | model.summuary() 25 | 26 | optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=0.1) 27 | 28 | # dataloader 29 | train_dataloader = DataLoader(train_datasets, batch_size=config.batch_size, shuffle=True) 30 | eval_dataloader = DataLoader(eval_datasets, batch_size=config.batch_size, shuffle=True) 31 | 32 | # create trainer 33 | trainer = SequenceLabelTrainer( 34 | model=model, 35 | args=None, 36 | train_dataloader=train_dataloader, 37 | eval_dataloader=eval_dataloader, 38 | epochs=config.epochs, 39 | learning_rate=config.learning_rate, 40 | device=config.device, 41 | padding_tag=config.TAG_TO_ID[config.PAD_TAG] 42 | ) 43 | 44 | # train model 45 | trainer.train() 46 | -------------------------------------------------------------------------------- /process.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | dataframe = pd.read_excel("data/addressInfo.xlsx") 4 | dataframe = dataframe.where(dataframe.notnull(), None) 5 | data = dataframe.itertuples(index=False) 6 | 7 | data_list = [] 8 | 9 | key_mapping = { 10 | "person":"person", 11 | "mobile":"mobile", 12 | "city":"cities", 13 | "province":"provin", 14 | "county":"county", 15 | "street":"street", 16 | "detail":"detail" 17 | } 18 | 19 | 20 | for r in data: 21 | mapping = dict( 22 | person=r.person.strip() if r.person else None, 23 | mobile=str(int(r.mobile)) if not pd.isnull(r.mobile) else None, 24 | city=r.city.strip() if r.city else None, 25 | province=r.province.strip() if r.province else None, 26 | county=r.county.strip() if r.county else None, 27 | street=r.street.strip() if r.street else None, 28 | detail=r.detail.strip() if r.detail else None 29 | ) 30 | address = r.address 31 | if address: 32 | address = address.replace(" ", "") 33 | address = address.replace(" ", "") 34 | text = address 35 | tags = [] 36 | 37 | log = 0 38 | for k, x in enumerate(text): 39 | change = False 40 | if k < log: 41 | continue 42 | for key, a in mapping.items(): 43 | if a is None: 44 | continue 45 | elif text[k:(k + len(a))] == a: 46 | start = f"B-{key_mapping[key]}" 47 
| end = f"I-{key_mapping[key]}" 48 | arr = [start] + [end] * (len(a) - 1) 49 | tags.extend(arr) 50 | log = k + len(a) 51 | change = True 52 | break 53 | else: 54 | continue 55 | if not change: 56 | tags.append("O") 57 | 58 | text = " ".join([x for x in text if x != " "]) 59 | label = " ".join(tags) 60 | res = (text, label) 61 | data_list.append(res) 62 | import random 63 | random.shuffle(data_list) 64 | df = pd.DataFrame(data_list, columns=["text", "label"]) 65 | df.to_csv("data/train.csv",index=False) 66 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import re 4 | from typing import Text, Optional, Dict, Set 5 | import pandas as pd 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | from config import SequenceLabelConfig 10 | 11 | 12 | class CustomTextClassifizerDataset(Dataset): 13 | """classifizer dataset""" 14 | 15 | def __init__(self, filepath, tokenizer, max_length): 16 | self.dataframe = pd.read_csv(filepath) 17 | self.text_dir = filepath 18 | self.tokenizer = tokenizer 19 | self.max_length = max_length 20 | 21 | def __len__(self): 22 | return len(self.dataframe) 23 | 24 | def __getitem__(self, idx): 25 | labels = self.dataframe.iloc[idx, 0] 26 | text = self.dataframe.iloc[idx, 1] 27 | token = self.tokenizer(text, return_tensors='pt', padding="max_length", max_length=self.max_length, 28 | truncation=True) 29 | item = {"labels": torch.tensor(labels, dtype=torch.long), "token": token} 30 | return item 31 | 32 | @property 33 | def num_classes(self) -> int: 34 | return len(set(self.dataframe["label"])) 35 | 36 | 37 | class CustomSequenceLabelDataset(Dataset): 38 | """sequence label dataset""" 39 | 40 | def __init__(self, filepath, tokenizer, config: SequenceLabelConfig): 41 | self.dataframe = pd.read_csv(filepath) 42 | self.text_dir = filepath 43 | self.max_length = config.max_length 44 | self.tag2idx = config.TAG_TO_ID 45 | self.pad_tag = config.PAD_TAG 46 | self.tokenizer = tokenizer 47 | 48 | def __len__(self): 49 | return len(self.dataframe) 50 | 51 | def __getitem__(self, idx): 52 | text = self.dataframe.iloc[idx, 0] 53 | token = self.tokenizer(text, return_tensors='pt', padding="max_length", max_length=self.max_length, 54 | truncation=True) 55 | 56 | labels = self.dataframe.iloc[idx, 1] 57 | labels = [self.tag2idx[t] for t in labels.split(" ")] 58 | padding_length = self.max_length - len(labels) 59 | padding_list = [self.tag2idx[self.pad_tag]] * padding_length 60 | pad = torch.LongTensor(padding_list) 61 | labels = torch.cat((torch.LongTensor(labels), pad), dim=-1) 62 | item = {"labels": labels, "token": token} 63 | return item 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos 
into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | models/ 131 | logs/ 132 | .idea/ 133 | data/text-classifizer 134 | data/squence-label 135 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from typing import Text, Dict, Optional 2 | 3 | import torch 4 | 5 | 6 | class TextClassifizerConfig(object): 7 | """Configuration for `TextClassifizer`.""" 8 | 9 | def __init__( 10 | self, 11 | num_classes: int = 3, 12 | batch_size: int = 4, 13 | learning_rate: float = 1e-6, 14 | epochs: int = 20, 15 | max_sequence_length: int = 100, 16 | train_data: Text = "data/text-classifizer/train.csv", 17 | eval_data: Text = "data/text-classifizer/dev.csv" 18 | ): 19 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 20 | print('Using {} device'.format(self.device)) 21 | if torch.cuda.is_available(): 22 | torch.cuda.empty_cache() 23 | self.num_classes = num_classes 24 | self.train_data = train_data 25 | self.eval_data = eval_data 26 | self.batch_size = batch_size 27 | self.learning_rate = learning_rate 28 | self.epochs = epochs 29 | self.max_sequence_length = max_sequence_length 30 | 31 | 32 | class SequenceLabelConfig(object): 33 | """Configuration for `SequenceLabelConfig`.""" 34 | PAD_TAG = "[PAD]" 35 | MAX_LENGTH = 250 36 | TAG_TO_ID = { 37 | PAD_TAG: 0, 38 | "B-person": 1, 39 | "I-person": 2, 40 | "B-mobile": 3, 41 | "I-mobile": 4, 42 | "B-provin": 5, 43 | "I-provin": 6, 44 | "B-cities": 7, 45 | "I-cities": 8, 46 | "B-county": 9, 47 | "I-county": 10, 48 | "B-street": 11, 49 | "I-street": 12, 50 | "B-detail": 13, 51 | "I-detail": 14, 52 | "O": 15, 53 | } 54 | 55 | def __init__( 56 | self, 57 | batch_size: int = 8, 58 | 
learning_rate: float = 5e-6, 59 | epochs: int = 50, 60 | max_length=MAX_LENGTH, 61 | hidden_dim=50, 62 | train_data: Text = "data/squence-label/train.csv", 63 | eval_data: Text = "data/squence-label/dev.csv", 64 | albert_vocab_file: Optional[Text] = "albert_base_zh/vocab_chinese.txt", 65 | albert_hidden_size: Optional[int] = 768, 66 | albert_pytorch_model_path: Optional[Text] = "models" 67 | 68 | ): 69 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 70 | print('Using {} device'.format(self.device)) 71 | if torch.cuda.is_available(): 72 | torch.cuda.empty_cache() 73 | self.train_data = train_data 74 | self.eval_data = eval_data 75 | self.batch_size = batch_size 76 | self.learning_rate = learning_rate 77 | self.epochs = epochs 78 | self.max_length = max_length 79 | self.hidden_dim = hidden_dim 80 | self.albert_vocab_file = albert_vocab_file 81 | self.albert_hidden_size = albert_hidden_size 82 | self.albert_pytorch_model_path = albert_pytorch_model_path 83 | self.tag_to_id = self.TAG_TO_ID -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Text 2 | 3 | import torch 4 | from transformers import BertTokenizer 5 | 6 | from config import TextClassifizerConfig, SequenceLabelConfig 7 | from model import TextClassificationModel, BiLSTM_CRF 8 | import torch.nn.functional as F 9 | import pandas as pd 10 | import re 11 | # load config from object 12 | # config = TextClassifizerConfig() 13 | # 14 | # model = TextClassificationModel(config.max_sequence_length, 3) 15 | # 16 | # model.load_state_dict(torch.load('models/model-B128-E20-L2e-05.pkl')) 17 | # 18 | # model.eval() 19 | # with torch.no_grad(): 20 | # dataframe = pd.read_csv("data/text-classifizer/test.csv") 21 | # text_list = dataframe["text"] 22 | # label = dataframe["label"] 23 | # for idx,text in enumerate(text_list): 24 | # pred = model(text) 25 | # res = pred.argmax(1).item() 26 | # scores = F.softmax(pred) 27 | # print(f'lable{label[idx]} {res},scores:{scores.squeeze(0)[res]}') 28 | 29 | # tokenizer 30 | tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') 31 | config = SequenceLabelConfig() 32 | tag_to_ix = SequenceLabelConfig.TAG_TO_ID 33 | model = BiLSTM_CRF(tag_to_ix, config.max_length, config.hidden_dim,config.device) 34 | model.summuary() 35 | model.to(config.device) 36 | # 37 | model.load_state_dict(torch.load('models/sequence-label-checkpoint_7_epoch.pkl')["model_state_dict"]) 38 | # 39 | ix_to_tag = {v: k for k, v in tag_to_ix.items()} 40 | 41 | model.eval() 42 | # with torch.no_grad(): 43 | # dataframe = pd.read_csv("data/squence-label/dev.csv") 44 | # text_list = dataframe["text"] 45 | # label = dataframe["label"] 46 | # for idx, text in enumerate(text_list): 47 | # pred, padding_count = model(text) 48 | # res = [ix_to_tag[x] for x in pred.squeeze(0).tolist()[:config.max_length - padding_count]] 49 | # text = text.split(" ") 50 | # print([f"{a}/{b}" for a, b in zip(text, res)]) 51 | 52 | def predict(text:Optional[Text]): 53 | with torch.no_grad(): 54 | curr_text = text 55 | curr_text = curr_text.replace(" ","") 56 | text_list = [x for x in curr_text] 57 | text = " ".join(text_list) 58 | token = tokenizer(text, return_tensors='pt', padding="max_length", max_length=config.max_length, 59 | truncation=True) 60 | input_ids = token["input_ids"].squeeze(1).to(config.device) 61 | attention_mask = token["attention_mask"].squeeze(1).to(config.device) 62 | 
token_type_ids = token["token_type_ids"].squeeze(1).to(config.device) 63 | 64 | # Compute prediction and loss 65 | y_pred = model(input_ids, attention_mask, token_type_ids) 66 | res = [ix_to_tag[x] for x in y_pred.squeeze(0).tolist()] 67 | # print([f"{a}/{b}" for a, b in zip(text_list, res)]) 68 | label = "".join(res) 69 | label = label.replace("O", "U-meamea") 70 | pattern = re.compile( 71 | "B-person*(I-person)*|B-mobile*(I-mobile)|B-provin*(I-provin)*|B-cities*(I-cities)*|B-county*(I-county)*|B-street*(I-street)*|B-detail*(I-detail)*") 72 | resutlt = re.finditer(pattern, label) 73 | result_word = {"province": [], "city": [], "county": [], "street": [], 74 | "detail": [], "person": [], "cellphones": []} 75 | shiti_dict = {"province": [-1, -1], "city": [-1, -1], "county": [-1, -1], "street": [-1, -1], 76 | "detail": [-1, -1], "person": [-1, -1], "mobile": [-1, -1]} 77 | 78 | for i in resutlt: 79 | start_index = int(i.span(0)[0] / 8) 80 | end_index = int((i.span(0)[1] - i.span(0)[0]) / 8) + start_index 81 | if i.group(0)[0:8] == "B-provin": 82 | shiti_dict["province"][0] = start_index 83 | shiti_dict["province"][1] = end_index 84 | if i.group(0)[0:8] == "B-cities": 85 | shiti_dict["city"][0] = start_index 86 | shiti_dict["city"][1] = end_index 87 | if i.group(0)[0:8] == "B-county": 88 | shiti_dict["county"][0] = start_index 89 | shiti_dict["county"][1] = end_index 90 | if i.group(0)[0:8] == "B-street": 91 | shiti_dict["street"][0] = start_index 92 | shiti_dict["street"][1] = end_index 93 | if i.group(0)[0:8] == "B-detail": 94 | shiti_dict["detail"][0] = start_index 95 | shiti_dict["detail"][1] = end_index 96 | if i.group(0)[0:8] == "B-person": 97 | shiti_dict["person"][0] = start_index 98 | shiti_dict["person"][1] = end_index 99 | if i.group(0)[0:8] == "B-mobile": 100 | shiti_dict["mobile"][0] = start_index 101 | shiti_dict["mobile"][1] = end_index 102 | if shiti_dict["province"][0] != -1: 103 | pro = "".join(text_list[shiti_dict["province"][0]:shiti_dict["province"][1]]) 104 | result_word["province"].append(pro) 105 | if shiti_dict["city"][0] != -1: 106 | cit = "".join(text_list[shiti_dict["city"][0]:shiti_dict["city"][1]]) 107 | result_word["city"].append(cit) 108 | if shiti_dict["county"][0] != -1: 109 | cou = "".join(text_list[shiti_dict["county"][0]:shiti_dict["county"][1]]) 110 | result_word["county"].append(cou) 111 | if shiti_dict["street"][0] != -1: 112 | stre = "".join(text_list[shiti_dict["street"][0]:shiti_dict["street"][1]]) 113 | result_word["street"].append(stre) 114 | if shiti_dict["detail"][0] != -1: 115 | det = "".join(text_list[shiti_dict["detail"][0]:shiti_dict["detail"][1]]) 116 | result_word["detail"].append(det) 117 | if shiti_dict["person"][0] != -1: 118 | per = "".join(text_list[shiti_dict["person"][0]:shiti_dict["person"][1]]) 119 | result_word["person"].append(per) 120 | if shiti_dict["mobile"][0] != -1: 121 | mob = "".join(text_list[shiti_dict["mobile"][0]:shiti_dict["mobile"][1]]) 122 | res = re.compile('(13\d{9}|14[5|7]\d{8}|15\d{9}|166{\d{8}|17[3|6|7]{\d{8}|18\d{9})') 123 | s = re.findall(res, curr_text) 124 | try: 125 | mob = s[0] 126 | except: 127 | mob = mob[0:11] 128 | result_word["cellphones"].append(mob) 129 | print("".join(text_list)) 130 | print(result_word) 131 | 132 | 133 | dataframe = pd.read_csv("data/squence-label/dev.csv") 134 | text_list = dataframe["text"] 135 | label = dataframe["label"] 136 | 137 | for idx, text in enumerate(text_list): 138 | predict(text) -------------------------------------------------------------------------------- 
/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. 
Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from functools import partial 4 | import torch 5 | from typing import Union, Optional, Text, Tuple 6 | from torch.utils.data import Dataset, DataLoader 7 | from torch.utils.tensorboard import SummaryWriter 8 | import abc 9 | 10 | 11 | class Trainer(object): 12 | def name(self) -> Text: 13 | raise NotImplementedError 14 | 15 | def save(self, model: torch.nn.Module = None, optimizer=None, epoch: Optional[int] = None): 16 | """save model state dict into model path""" 17 | checkpoint = {"model_state_dict": model.state_dict(), 18 | "optimizer_state_dict": optimizer.state_dict(), 19 | "epoch": epoch} 20 | path_checkpoint = f"models/{self.name()}-checkpoint_{epoch}_epoch.pkl" 21 | torch.save(checkpoint, path_checkpoint) 22 | 23 | 24 | class TextClassifizerTrainer(Trainer): 25 | 26 | def name(self) -> Text: 27 | return "text_classifizer" 28 | 29 | def __init__( 30 | self, model: torch.nn.Module = None, 31 | args: Optional[Tuple] = None, 32 | train_dataloader: DataLoader = None, 33 | eval_dataloader: DataLoader = None, 34 | epochs: Optional[int] = 30, 35 | learning_rate: Optional[float] = 1e-5, 36 | device: Optional[Text] = "cpu" 37 | ): 38 | self.writer = SummaryWriter( 39 | f'logs/text-classifier-B-{train_dataloader.batch_size}-E{epochs}-L{learning_rate}-{time.time()}') 40 | self.writer.flush() 41 | 42 | if model is None: 43 | raise RuntimeError("`Trainer` requires a `model` ") 44 | self.epochs = epochs 45 | self.learning_rate = learning_rate 46 | self.model = model 47 | self.train_dataloader = train_dataloader 48 | self.eval_dataloader = eval_dataloader 49 | self.args = args 50 | self.device = device 51 | self.model.to(device) 52 | 53 | def train(self): 54 | loss_fn = torch.nn.CrossEntropyLoss() 55 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=0.1) 56 | 57 | for epoch in range(1, self.epochs + 1): 58 | epoch_start_time = time.time() 59 | self.train_loop(epoch, loss_fn, optimizer) 60 | accu_val, loss = self.eval_loop(epoch, loss_fn) 61 | print('-' * 59) 62 | print('| end of epoch {:3d} | time: {:5.2f}s | ' 63 | 'valid accuracy {:8.3f} ' 64 | 'valid loss {:8.3f} ' 65 | .format(epoch, 66 | 
time.time() - epoch_start_time, 67 | accu_val, loss)) 68 | print('-' * 59) 69 | self.save(model=self.model, optimizer=optimizer, epoch=epoch) 70 | 71 | def train_loop(self, epoch, loss_fn, optimizer): 72 | self.model.train() 73 | total_acc, total_count = 0, 0 74 | 75 | for batch, data in enumerate(self.train_dataloader): 76 | y = data["labels"].to(self.device) 77 | token = data["token"] 78 | input_ids = token["input_ids"].squeeze(1).to(self.device) 79 | attention_mask = token["attention_mask"].squeeze(1).to(self.device) 80 | token_type_ids = token["token_type_ids"].squeeze(1).to(self.device) 81 | 82 | # Compute prediction and loss 83 | pred = self.model(input_ids, attention_mask, token_type_ids) 84 | loss = loss_fn(pred, y) 85 | 86 | # Backpropagation 87 | optimizer.zero_grad() 88 | loss.backward() 89 | optimizer.step() 90 | 91 | current_acc = (pred.argmax(1) == y).sum().item() 92 | current_count = y.size(0) 93 | loss, current = loss.item(), batch * len(token) 94 | 95 | total_acc += current_acc 96 | total_count += current_count 97 | 98 | # ...log the running loss 99 | self.writer.add_scalar('training loss', 100 | loss, 101 | (epoch - 1) * len(self.train_dataloader) + batch) 102 | 103 | # ...log a Matplotlib Figure showing the model's predictions on a 104 | # random mini-batch 105 | # ...log the running loss 106 | self.writer.add_scalar('training acc', 107 | current_acc / current_count, 108 | (epoch - 1) * len(self.train_dataloader) + batch) 109 | 110 | print('| epoch {:3d} | {:5d}/{:5d} batches ' 111 | '| accuracy {:8.3f}' 112 | '| loss {:8.3f}' 113 | .format(epoch, batch, len(self.train_dataloader), 114 | current_acc / current_count, loss)) 115 | 116 | def eval_loop(self, epoch, loss_fn): 117 | self.model.eval() 118 | total_acc, total_count = 0, 0 119 | loss = 0 120 | 121 | with torch.no_grad(): 122 | for batch, data in enumerate(self.eval_dataloader): 123 | y = data["labels"].to(self.device) 124 | token = data["token"] 125 | input_ids = token["input_ids"].squeeze(1).to(self.device) 126 | attention_mask = token["attention_mask"].squeeze(1).to(self.device) 127 | token_type_ids = token["token_type_ids"].squeeze(1).to(self.device) 128 | 129 | # Compute prediction and loss 130 | pred = self.model(input_ids, attention_mask, token_type_ids) 131 | loss = loss_fn(pred, y) 132 | loss, current = loss.item(), batch * len(token) 133 | current_acc = (pred.argmax(1) == y).sum().item() 134 | current_count = y.size(0) 135 | 136 | total_acc += current_acc 137 | total_count += current_count 138 | 139 | # ...log the running loss 140 | self.writer.add_scalar('eval loss', 141 | loss, 142 | (epoch - 1) * len(self.eval_dataloader) + batch) 143 | 144 | self.writer.add_scalar('eval acc', 145 | current_acc / current_count, 146 | (epoch - 1) * len(self.eval_dataloader) + batch) 147 | 148 | return total_acc / total_count, loss 149 | 150 | 151 | class SequenceLabelTrainer(Trainer): 152 | 153 | def name(self) -> Text: 154 | return "sequence-label" 155 | 156 | def __init__( 157 | self, model: torch.nn.Module = None, 158 | args: Optional[Tuple] = None, 159 | train_dataloader: DataLoader = None, 160 | eval_dataloader: DataLoader = None, 161 | epochs: Optional[int] = 30, 162 | learning_rate: Optional[float] = 1e-5, 163 | device: Optional[Text] = "cpu", 164 | padding_tag:Optional[int]=0 165 | ): 166 | self.writer = SummaryWriter( 167 | f'logs/sequence-label-B-{train_dataloader.batch_size}-E{epochs}-L{learning_rate}-{time.time()}') 168 | self.writer.flush() 169 | 170 | if model is None: 171 | raise RuntimeError("`Trainer` 
requires a `model` ") 172 | self.epochs = epochs 173 | self.learning_rate = learning_rate 174 | self.model = model 175 | self.train_dataloader = train_dataloader 176 | self.eval_dataloader = eval_dataloader 177 | self.args = args 178 | self.device = device 179 | self.model.to(self.device) 180 | self.padding_tag = padding_tag 181 | 182 | def train(self): 183 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=0.1) 184 | 185 | for epoch in range(1, self.epochs + 1): 186 | epoch_start_time = time.time() 187 | self.train_loop(epoch, optimizer) 188 | accu_val, loss = self.eval_loop(epoch) 189 | print('-' * 59) 190 | print('| end of epoch {:3d} | time: {:5.2f}s | ' 191 | 'valid accuracy {:8.3f} ' 192 | 'valid loss {:8.3f} ' 193 | .format(epoch, 194 | time.time() - epoch_start_time, 195 | accu_val, loss)) 196 | self.save(model=self.model, optimizer=optimizer, epoch=epoch) 197 | print('-' * 59) 198 | 199 | def train_loop(self, epoch, optimizer): 200 | self.model.train() 201 | for batch, data in enumerate(self.train_dataloader): 202 | y = data["labels"].to(self.device) 203 | token = data["token"] 204 | input_ids = token["input_ids"].squeeze(1).to(self.device) 205 | attention_mask = token["attention_mask"].squeeze(1).to(self.device) 206 | token_type_ids = token["token_type_ids"].squeeze(1).to(self.device) 207 | 208 | # Compute prediction and loss 209 | y_pred = self.model(input_ids, attention_mask, token_type_ids) 210 | loss = self.model.loss(input_ids, attention_mask, token_type_ids, y) 211 | 212 | # Backpropagation 213 | optimizer.zero_grad() 214 | loss.backward() 215 | optimizer.step() 216 | padding_count = (y_pred == self.padding_tag).sum() 217 | current_acc = (y_pred == y).sum().item() - padding_count 218 | current = y.size(0) * y.size(1) - padding_count 219 | 220 | loss = loss.item() 221 | 222 | # ...log the running loss 223 | self.writer.add_scalar('training loss', 224 | loss, 225 | (epoch - 1) * len(self.train_dataloader) + batch) 226 | 227 | # ...log a Matplotlib Figure showing the model's predictions on a 228 | # random mini-batch 229 | # ...log the running loss 230 | self.writer.add_scalar('training acc', 231 | current_acc / current, 232 | (epoch - 1) * len(self.train_dataloader) + batch) 233 | 234 | print('| epoch {:3d} | {:5d}/{:5d} batches ' 235 | '| accuracy {:8.3f}' 236 | '| loss {:8.3f}' 237 | .format(epoch, batch, len(self.train_dataloader), 238 | current_acc / current, loss)) 239 | 240 | def eval_loop(self, epoch): 241 | self.model.eval() 242 | total_acc, total_count = 0, 0 243 | loss = 0 244 | 245 | with torch.no_grad(): 246 | for batch, data in enumerate(self.eval_dataloader): 247 | y = data["labels"].to(self.device) 248 | token = data["token"] 249 | input_ids = token["input_ids"].squeeze(1).to(self.device) 250 | attention_mask = token["attention_mask"].squeeze(1).to(self.device) 251 | token_type_ids = token["token_type_ids"].squeeze(1).to(self.device) 252 | 253 | # Compute prediction and loss 254 | y_pred = self.model(input_ids, attention_mask, token_type_ids) 255 | loss = self.model.loss(input_ids, attention_mask, token_type_ids, y) 256 | padding_count = (y_pred == self.padding_tag).sum() 257 | current_acc = (y_pred == y).sum().item() - padding_count 258 | current = y.size(0) * y.size(1) - padding_count 259 | 260 | loss = loss.item() 261 | 262 | total_acc += current_acc 263 | total_count += current 264 | 265 | # ...log the running loss 266 | self.writer.add_scalar('eval loss', 267 | loss, 268 | (epoch - 1) * len(self.eval_dataloader) + 
batch) 269 | 270 | self.writer.add_scalar('eval acc', 271 | current_acc / current, 272 | (epoch - 1) * len(self.eval_dataloader) + batch) 273 | 274 | return total_acc / total_count, loss 275 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from typing import Optional, Text, List, Tuple 5 | 6 | from torch import Tensor 7 | from torch.autograd import Variable 8 | from transformers import BertTokenizer, BertModel 9 | from config import SequenceLabelConfig 10 | 11 | 12 | class TextClassificationModel(torch.nn.Module): 13 | """ rnn text classification""" 14 | 15 | def __init__(self, max_length: Optional[int] = None, num_class: Optional[int] = None): 16 | super(TextClassificationModel, self).__init__() 17 | self.max_length = max_length 18 | self.num_class = num_class 19 | self.bert_dim = 768 20 | self.bert = BertModel.from_pretrained('bert-base-chinese') 21 | self.classifizer = torch.nn.Linear(self.bert_dim, num_class) 22 | 23 | def forward(self, input_ids: Optional[Tensor], attention_mask: Optional[Tensor], 24 | token_type_ids: Optional[Tensor]) -> Tensor: 25 | outputs = self.bert( 26 | input_ids, 27 | attention_mask=attention_mask, 28 | token_type_ids=token_type_ids 29 | ) 30 | return self.classifizer(outputs.pooler_output) 31 | 32 | def summuary(self): 33 | print("Model structure: ", self, "\n\n") 34 | 35 | for name, param in self.named_parameters(): 36 | print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n") 37 | 38 | 39 | class CRF(torch.nn.Module): 40 | """Conditional random field. 41 | This module implements a conditional random field [LMP01]_. The forward computation 42 | of this class computes the log likelihood of the given sequence of tags and 43 | emission score tensor. This class also has `~CRF.decode` method which finds 44 | the best tag sequence given an emission score tensor using `Viterbi algorithm`_. 45 | Args: 46 | num_tags: Number of tags. 47 | batch_first: Whether the first dimension corresponds to the size of a minibatch. 48 | Attributes: 49 | start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size 50 | ``(num_tags,)``. 51 | end_transitions (`~torch.nn.Parameter`): End transition score tensor of size 52 | ``(num_tags,)``. 53 | transitions (`~torch.nn.Parameter`): Transition score tensor of size 54 | ``(num_tags, num_tags)``. 55 | .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001). 56 | "Conditional random fields: Probabilistic models for segmenting and 57 | labeling sequence data". *Proc. 18th International Conf. on Machine 58 | Learning*. Morgan Kaufmann. pp. 282–289. 59 | .. 
_Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm 60 | """ 61 | 62 | def __init__(self, num_tags: int, tag_to_ix: dict, max_length: int = 100, batch_first: bool = False, 63 | device: Text = "cpu") -> None: 64 | if num_tags <= 0: 65 | raise ValueError(f'invalid number of tags: {num_tags}') 66 | super().__init__() 67 | self.num_tags = num_tags 68 | self.tag_to_ix = tag_to_ix 69 | self.max_length = max_length 70 | self.batch_first = batch_first 71 | self.start_transitions = torch.nn.Parameter(torch.empty(num_tags).to(device)) 72 | self.end_transitions = torch.nn.Parameter(torch.empty(num_tags).to(device)) 73 | self.transitions = torch.nn.Parameter(torch.empty(num_tags, num_tags).to(device)) 74 | self.device = device 75 | 76 | self.reset_parameters() 77 | 78 | def reset_parameters(self) -> None: 79 | """Initialize the transition parameters. 80 | The parameters will be initialized randomly from a uniform distribution 81 | between -0.1 and 0.1. 82 | """ 83 | torch.nn.init.uniform_(self.start_transitions, -0.1, 0.1) 84 | torch.nn.init.uniform_(self.end_transitions, -0.1, 0.1) 85 | torch.nn.init.uniform_(self.transitions, -0.1, 0.1) 86 | 87 | def __repr__(self) -> str: 88 | return f'{self.__class__.__name__}(num_tags={self.num_tags})' 89 | 90 | def neg_log_likelihood_loss( 91 | self, 92 | emissions: torch.Tensor, 93 | tags: torch.LongTensor, 94 | mask: Optional[torch.ByteTensor] = None, 95 | reduction: str = 'sum', 96 | ) -> torch.Tensor: 97 | """Compute the conditional log likelihood of a sequence of tags given emission scores. 98 | Args: 99 | emissions (`~torch.Tensor`): Emission score tensor of size 100 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, 101 | ``(batch_size, seq_length, num_tags)`` otherwise. 102 | tags (`~torch.LongTensor`): Sequence of tags tensor of size 103 | ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, 104 | ``(batch_size, seq_length)`` otherwise. 105 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` 106 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. 107 | reduction: Specifies the reduction to apply to the output: 108 | ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. 109 | ``sum``: the output will be summed over batches. ``mean``: the output will be 110 | averaged over batches. ``token_mean``: the output will be averaged over tokens. 111 | Returns: 112 | `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if 113 | reduction is ``none``, ``()`` otherwise. 
114 | """ 115 | self._validate(emissions, tags=tags, mask=mask) 116 | if reduction not in ('none', 'sum', 'mean', 'token_mean'): 117 | raise ValueError(f'invalid reduction: {reduction}') 118 | if mask is None: 119 | mask = torch.ones_like(tags, dtype=torch.uint8) 120 | 121 | if self.batch_first: 122 | emissions = emissions.transpose(0, 1) 123 | tags = tags.transpose(0, 1) 124 | mask = mask.transpose(0, 1) 125 | 126 | # shape: (batch_size,) 127 | numerator = self._compute_score(emissions, tags, mask) 128 | # shape: (batch_size,) 129 | denominator = self._compute_normalizer(emissions, mask) 130 | # shape: (batch_size,) 131 | llh = numerator - denominator 132 | 133 | if reduction == 'none': 134 | return llh 135 | if reduction == 'sum': 136 | return llh.sum() 137 | if reduction == 'mean': 138 | return llh.mean() 139 | assert reduction == 'token_mean' 140 | return llh.sum() / mask.type_as(emissions).sum() 141 | 142 | def forward(self, emissions: torch.Tensor, 143 | mask: Optional[torch.ByteTensor] = None) -> List[List[int]]: 144 | """Find the most likely tag sequence using Viterbi algorithm. 145 | Args: 146 | emissions (`~torch.Tensor`): Emission score tensor of size 147 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, 148 | ``(batch_size, seq_length, num_tags)`` otherwise. 149 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` 150 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. 151 | Returns: 152 | List of list containing the best tag sequence for each batch. 153 | """ 154 | self._validate(emissions, mask=mask) 155 | if mask is None: 156 | mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8) 157 | 158 | if self.batch_first: 159 | emissions = emissions.transpose(0, 1) 160 | mask = mask.transpose(0, 1) 161 | 162 | return self._viterbi_decode(emissions, mask) 163 | 164 | def _validate( 165 | self, 166 | emissions: torch.Tensor, 167 | tags: Optional[torch.LongTensor] = None, 168 | mask: Optional[torch.ByteTensor] = None) -> None: 169 | if emissions.dim() != 3: 170 | raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}') 171 | if emissions.size(2) != self.num_tags: 172 | raise ValueError( 173 | f'expected last dimension of emissions is {self.num_tags}, ' 174 | f'got {emissions.size(2)}') 175 | 176 | if tags is not None: 177 | if emissions.shape[:2] != tags.shape: 178 | raise ValueError( 179 | 'the first two dimensions of emissions and tags must match, ' 180 | f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}') 181 | 182 | if mask is not None: 183 | if emissions.shape[:2] != mask.shape: 184 | raise ValueError( 185 | 'the first two dimensions of emissions and mask must match, ' 186 | f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}') 187 | no_empty_seq = not self.batch_first and mask[0].all() 188 | no_empty_seq_bf = self.batch_first and mask[:, 0].all() 189 | if not no_empty_seq and not no_empty_seq_bf: 190 | raise ValueError('mask of the first timestep must all be on') 191 | 192 | def _compute_score( 193 | self, emissions: torch.Tensor, tags: torch.LongTensor, 194 | mask: torch.ByteTensor) -> torch.Tensor: 195 | # emissions: (seq_length, batch_size, num_tags) 196 | # tags: (seq_length, batch_size) 197 | # mask: (seq_length, batch_size) 198 | assert emissions.dim() == 3 and tags.dim() == 2 199 | assert emissions.shape[:2] == tags.shape 200 | assert emissions.size(2) == self.num_tags 201 | assert mask.shape == tags.shape 202 | assert mask[0].all() 203 | 204 | 
seq_length, batch_size = tags.shape 205 | mask = mask.type_as(emissions) 206 | 207 | # Start transition score and first emission 208 | # shape: (batch_size,) 209 | score = self.start_transitions[tags[0]] 210 | score += emissions[0, torch.arange(batch_size), tags[0]] 211 | 212 | for i in range(1, seq_length): 213 | # Transition score to next tag, only added if next timestep is valid (mask == 1) 214 | # shape: (batch_size,) 215 | score += self.transitions[tags[i - 1], tags[i]] * mask[i] 216 | 217 | # Emission score for next tag, only added if next timestep is valid (mask == 1) 218 | # shape: (batch_size,) 219 | score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] 220 | 221 | # End transition score 222 | # shape: (batch_size,) 223 | seq_ends = mask.long().sum(dim=0) - 1 224 | # shape: (batch_size,) 225 | last_tags = tags[seq_ends, torch.arange(batch_size)] 226 | # shape: (batch_size,) 227 | score += self.end_transitions[last_tags] 228 | 229 | return score 230 | 231 | def _compute_normalizer( 232 | self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor: 233 | # emissions: (seq_length, batch_size, num_tags) 234 | # mask: (seq_length, batch_size) 235 | assert emissions.dim() == 3 and mask.dim() == 2 236 | assert emissions.shape[:2] == mask.shape 237 | assert emissions.size(2) == self.num_tags 238 | assert mask[0].all() 239 | 240 | seq_length = emissions.size(0) 241 | 242 | # Start transition score and first emission; score has size of 243 | # (batch_size, num_tags) where for each batch, the j-th column stores 244 | # the score that the first timestep has tag j 245 | # shape: (batch_size, num_tags) 246 | score = self.start_transitions + emissions[0] 247 | 248 | for i in range(1, seq_length): 249 | # Broadcast score for every possible next tag 250 | # shape: (batch_size, num_tags, 1) 251 | broadcast_score = score.unsqueeze(2) 252 | 253 | # Broadcast emission score for every possible current tag 254 | # shape: (batch_size, 1, num_tags) 255 | broadcast_emissions = emissions[i].unsqueeze(1) 256 | 257 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where 258 | # for each sample, entry at row i and column j stores the sum of scores of all 259 | # possible tag sequences so far that end with transitioning from tag i to tag j 260 | # and emitting 261 | # shape: (batch_size, num_tags, num_tags) 262 | next_score = broadcast_score + self.transitions + broadcast_emissions 263 | 264 | # Sum over all possible current tags, but we're in score space, so a sum 265 | # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of 266 | # all possible tag sequences so far, that end in tag i 267 | # shape: (batch_size, num_tags) 268 | next_score = torch.logsumexp(next_score, dim=1) 269 | 270 | # Set score to the next score if this timestep is valid (mask == 1) 271 | # shape: (batch_size, num_tags) 272 | score = torch.where(mask[i].unsqueeze(1), next_score, score) 273 | 274 | # End transition score 275 | # shape: (batch_size, num_tags) 276 | score += self.end_transitions 277 | 278 | # Sum (log-sum-exp) over all possible tags 279 | # shape: (batch_size,) 280 | return torch.logsumexp(score, dim=1) 281 | 282 | def _viterbi_decode(self, emissions: torch.FloatTensor, 283 | mask: torch.ByteTensor) -> List[List[int]]: 284 | # emissions: (seq_length, batch_size, num_tags) 285 | # mask: (seq_length, batch_size) 286 | assert emissions.dim() == 3 and mask.dim() == 2 287 | assert emissions.shape[:2] == mask.shape 288 | assert emissions.size(2) == self.num_tags 289 
| assert mask[0].all() 290 | 291 | seq_length, batch_size = mask.shape 292 | 293 | # Start transition and first emission 294 | # shape: (batch_size, num_tags) 295 | score = self.start_transitions + emissions[0] 296 | history = [] 297 | 298 | # score is a tensor of size (batch_size, num_tags) where for every batch, 299 | # value at column j stores the score of the best tag sequence so far that ends 300 | # with tag j 301 | # history saves where the best tags candidate transitioned from; this is used 302 | # when we trace back the best tag sequence 303 | 304 | # Viterbi algorithm recursive case: we compute the score of the best tag sequence 305 | # for every possible next tag 306 | for i in range(1, seq_length): 307 | # Broadcast viterbi score for every possible next tag 308 | # shape: (batch_size, num_tags, 1) 309 | broadcast_score = score.unsqueeze(2) 310 | 311 | # Broadcast emission score for every possible current tag 312 | # shape: (batch_size, 1, num_tags) 313 | broadcast_emission = emissions[i].unsqueeze(1) 314 | 315 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where 316 | # for each sample, entry at row i and column j stores the score of the best 317 | # tag sequence so far that ends with transitioning from tag i to tag j and emitting 318 | # shape: (batch_size, num_tags, num_tags) 319 | next_score = broadcast_score + self.transitions + broadcast_emission 320 | 321 | # Find the maximum score over all possible current tag 322 | # shape: (batch_size, num_tags) 323 | next_score, indices = next_score.max(dim=1) 324 | 325 | # Set score to the next score if this timestep is valid (mask == 1) 326 | # and save the index that produces the next score 327 | # shape: (batch_size, num_tags) 328 | score = torch.where(mask[i].unsqueeze(1), next_score, score) 329 | history.append(indices) 330 | 331 | # End transition score 332 | # shape: (batch_size, num_tags) 333 | score += self.end_transitions 334 | 335 | # Now, compute the best path for each sample 336 | 337 | # shape: (batch_size,) 338 | seq_ends = mask.long().sum(dim=0) - 1 339 | best_tags_list = [] 340 | 341 | for idx in range(batch_size): 342 | # Find the tag which maximizes the score at the last timestep; this is our best tag 343 | # for the last timestep 344 | _, best_last_tag = score[idx].max(dim=0) 345 | best_tags = [best_last_tag.item()] 346 | 347 | # We trace back where the best last tag comes from, append that to our best tag 348 | # sequence, and trace it back again, and so on 349 | for hist in reversed(history[:seq_ends[idx]]): 350 | best_last_tag = hist[idx][best_tags[-1]] 351 | best_tags.append(best_last_tag.item()) 352 | 353 | # Reverse the order because we start from the last timestep 354 | best_tags.reverse() 355 | best_tags_list.append(best_tags) 356 | 357 | for idx, best_tags in enumerate(best_tags_list): 358 | padding_length = self.max_length - len(best_tags) 359 | best_tags.extend([self.tag_to_ix[SequenceLabelConfig.PAD_TAG]] * padding_length) 360 | 361 | return best_tags_list 362 | 363 | 364 | class BiLSTM_CRF(torch.nn.Module): 365 | """bilstm crf model""" 366 | 367 | def __init__(self, tag_to_ix, max_length, hidden_dim, device): 368 | super(BiLSTM_CRF, self).__init__() 369 | self.embedding_dim = 768 370 | self.max_length = max_length 371 | self.hidden_dim = hidden_dim 372 | self.tag_to_ix = tag_to_ix 373 | self.tagset_size = len(tag_to_ix) 374 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') 375 | self.bert = BertModel.from_pretrained('bert-base-chinese') 376 | self.lstm = 
torch.nn.LSTM(self.embedding_dim, hidden_dim // 2, 377 | num_layers=1, bidirectional=True, batch_first=True) 378 | # self.crf = CRF(self.tagset_size, tag_to_ix) 379 | self.crf = CRF(self.tagset_size, self.tag_to_ix, max_length=self.max_length, batch_first=True, device=device) 380 | 381 | # Maps the output of the LSTM into tag space. 382 | self.hidden2tag = torch.nn.Linear(hidden_dim, self.tagset_size) 383 | 384 | self.device = device 385 | 386 | def init_hidden(self, batch_size): 387 | return (torch.randn(2, batch_size, self.hidden_dim // 2).to(self.device), 388 | torch.randn(2, batch_size, self.hidden_dim // 2).to(self.device)) 389 | 390 | def _get_lstm_features(self, input_ids: Optional[Tensor], attention_mask: Optional[Tensor], 391 | token_type_ids: Optional[Tensor]): 392 | # Get the emission scores from the BiLSTM 393 | outputs = self.bert( 394 | input_ids, 395 | attention_mask=attention_mask, 396 | token_type_ids=token_type_ids 397 | ) 398 | embedding = outputs.last_hidden_state 399 | 400 | batch_size, sequece_length, embedding_dim = embedding.shape 401 | self.hidden = self.init_hidden(batch_size) 402 | lstm_out, self.hidden = self.lstm(embedding, self.hidden) 403 | lstm_feats = self.hidden2tag(lstm_out) 404 | return lstm_feats, attention_mask.byte() 405 | 406 | def forward(self, input_ids: Optional[Tensor], attention_mask: Optional[Tensor], 407 | token_type_ids: Optional[Tensor]): # dont confuse this with _forward_alg above. 408 | lstm_feats, mask = self._get_lstm_features(input_ids, attention_mask, token_type_ids) 409 | # Find the best path, given the features. 410 | tag_seq = self.crf(lstm_feats, mask=mask) 411 | return torch.tensor(tag_seq).to(self.device) 412 | 413 | def loss(self, input_ids: Optional[Tensor], attention_mask: Optional[Tensor], 414 | token_type_ids: Optional[Tensor], tags: Optional[Tensor]): 415 | """ 416 | feats: size=(batch_size, seq_len, tag_size) 417 | mask: size=(batch_size, seq_len) 418 | tags: size=(batch_size, seq_len) 419 | :return: 420 | """ 421 | lstm_feats, mask = self._get_lstm_features(input_ids, attention_mask, token_type_ids) 422 | loss_value = self.crf.neg_log_likelihood_loss(lstm_feats, tags, mask=mask) 423 | batch_size = lstm_feats.size(0) 424 | loss_value /= float(batch_size) 425 | return -loss_value 426 | 427 | def summuary(self): 428 | print("Model structure: ", self, "\n\n") 429 | 430 | for name, param in self.named_parameters(): 431 | print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n") 432 | --------------------------------------------------------------------------------
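For reference, here is a minimal usage sketch (not part of the repository) that exercises the CRF module defined in model.py on random emission scores, wired up the same way BiLSTM_CRF does with `batch_first=True`. The tag inventory comes from `SequenceLabelConfig.TAG_TO_ID`; the batch size, sequence length, and CPU device are arbitrary illustration values.

```
import torch

from config import SequenceLabelConfig
from model import CRF

tag_to_ix = SequenceLabelConfig.TAG_TO_ID
num_tags = len(tag_to_ix)                 # 16 tags, including "[PAD]" and "O"
batch_size, seq_len = 2, 6

crf = CRF(num_tags, tag_to_ix, max_length=seq_len, batch_first=True, device="cpu")

emissions = torch.randn(batch_size, seq_len, num_tags)    # (batch, seq, num_tags)
tags = torch.randint(0, num_tags, (batch_size, seq_len))  # gold tag ids
mask = torch.ones(batch_size, seq_len, dtype=torch.bool)  # every position is valid

# Log likelihood of the gold sequences; BiLSTM_CRF.loss() negates this value and
# averages it over the batch to obtain the training loss.
llh = crf.neg_log_likelihood_loss(emissions, tags, mask=mask)
loss = -llh / batch_size

# Viterbi decoding returns one best tag-id sequence per sample, padded to max_length.
best_paths = crf(emissions, mask=mask)
print(loss.item(), best_paths)
```

Because `reduction='sum'` is the default, `neg_log_likelihood_loss` returns the summed log likelihood of the gold paths; `BiLSTM_CRF.loss()` negates it and divides by the batch size, which is the quantity the trainers log as the loss.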