├── ckps └── .gitkeep ├── data └── .gitkeep ├── datasets └── .gitkeep ├── preds └── .gitkeep ├── bert_pretrained └── .gitkeep ├── src ├── requirements.txt ├── scripts │ ├── readme.md │ ├── data.sh │ ├── sync_data.sh │ ├── inference.sh │ ├── kfold.sh │ ├── kfold_one.sh │ ├── train.sh │ └── trainx.sh ├── kfold │ ├── cvindex.py │ ├── test_avg.py │ ├── test_avg_all.py │ ├── test.py │ └── train.py ├── loss.py ├── data_raw_test.py ├── data_raw_trnval.py ├── utils.py ├── metric.py ├── get_result.py ├── test.py ├── get_sents.py ├── get_sents_fix.py ├── get_sents_fix_more.py ├── get_sents_fix_more_s.py ├── data_title_test.py ├── trainx.py └── data_title_trnval.py ├── TODO.md ├── IDEA.md ├── .gitignore ├── docs └── readme.txt └── README.md /ckps/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preds/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bert_pretrained/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch>=1.0.1 2 | ignite==0.2.1 3 | h5py==2.7.1 4 | -------------------------------------------------------------------------------- /src/scripts/readme.md: -------------------------------------------------------------------------------- 1 | + sync_data.sh 目前是仅供参考的脚本,没能形成参数化调用 2 | + train.sh 更改参数 训练 lite 和 full 3 | + inference.sh 设置 model pred result 文件名 一键 inference! 4 | -------------------------------------------------------------------------------- /src/scripts/data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eux 3 | 4 | cd ../. 5 | python -u data_raw_trnval.py &>> ../logs/data.log 6 | 7 | python -u data_raw_test.py &>> ../logs/data.log 8 | 9 | python -u data_title_trnval.py --lite &>> ../logs/data.log 10 | 11 | python -u data_title_trnval.py &>> ../logs/data.log 12 | 13 | python -u data_title_test.py &>> ../logs/data.log 14 | 15 | echo "over!" -------------------------------------------------------------------------------- /src/scripts/sync_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 服务器间同步数据集 3 | server_75="llouice@10.108.208.113" 4 | server_lzk="lzk@10.112.101.47" 5 | server_68="llouice@10.108.209.96" 6 | cd .. 7 | 8 | Sohu2019="/home/llouice/Projects/BERT/Sohu2019/datasets" 9 | target_dir="../." 10 | 11 | function init() { 12 | scp -r "${server_75}:${Sohu2019}" ${target_dir} 13 | echo "copy Sohu2019 over!" 
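    # Added note (not in the original script): scripts/readme.md describes this sync helper as
    # reference-only and not parameterised. A hypothetical parameterised variant could read the
    # source host as "$1", e.g.  scp -r "$1:${Sohu2019}" ${target_dir}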
14 | } 15 | 16 | init 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /src/scripts/inference.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eux 3 | export CUDA_VISIBLE_DEVICES=0,1 4 | test_batch_size=500 5 | best_model="best_model.pt.bak" 6 | pred="pred_new.h5" 7 | 8 | 9 | cd ../. 10 | 11 | test(){ 12 | python -u test.py \ 13 | --test_batch_size=${test_batch_size} \ 14 | --best_model=${best_model} \ 15 | --pred=${pred} 16 | } 17 | 18 | 19 | 20 | res="result_new.txt" 21 | 22 | get_res(){ 23 | python -u get_result.py \ 24 | --pred=${pred} \ 25 | --res=${res} 26 | } 27 | 28 | test 29 | get_res 30 | echo "over!" 31 | 32 | -------------------------------------------------------------------------------- /src/kfold/cvindex.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.sys.path.append("../.") 3 | from sklearn.model_selection import KFold 4 | import h5py 5 | from utils import data_dump 6 | 7 | kfold = KFold(n_splits=5) 8 | 9 | full = "../../datasets/full.h5" 10 | f = h5py.File(full, "r") 11 | trn_size = f.get("train")["input_ids"].shape[0] 12 | val_size = f.get("val")["input_ids"].shape[0] 13 | all_size = trn_size + val_size 14 | 15 | cv = 0 16 | for train_indexs, dev_indexs in kfold.split([1] * all_size): 17 | cv = cv + 1 18 | index_t_d = (train_indexs, dev_indexs) 19 | index_file = '5cv_indexs_{}'.format(cv) 20 | data_dump(index_t_d, index_file) 21 | -------------------------------------------------------------------------------- /src/scripts/kfold.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eux 3 | src=$(pwd)/.. 4 | export PYTHONPATH=:${src} 5 | export CUDA_VISIBLE_DEVICES=0,1 6 | 7 | cd ../kfold 8 | bs=124 9 | val_bs=150 10 | test_bs=2880 11 | pred="pred_5cv.h5" 12 | pred_avg="pred_avg.h5" 13 | 14 | train(){ 15 | python -u train.py \ 16 | --batch_size=${bs} \ 17 | --val_batch_size=${val_bs} \ 18 | --lite 19 | } 20 | 21 | 22 | test(){ 23 | python -u test.py \ 24 | --test_batch_size=${test_bs} \ 25 | --pred=${pred} \ 26 | --lite 27 | } 28 | 29 | avg(){ 30 | python -u test_avg.py \ 31 | --pred=${pred} \ 32 | --pred_avg=${pred_avg} 33 | } 34 | 35 | train 36 | test 37 | avg 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | # 数据集构造 2 | - [x] 混合 example + train -> 所有的新闻 news.pkl 3 | - [x] 抽取部分 -> lite_trn.txt lite_val.txt lite_test.txt -> lite.h5 包括 train val test 4 | 5 | ~~lite.h5 -> model -> trn val -> best_model.pt -> get_result -> score.py~~ 6 | ~~- [ ] score.py~~ 7 | - [ ] BIESO 等其它形式的数据集构造 label id 转换部分 8 | > python data_title_trnval.py --label_method="BIESO" 9 | # 预测! 
10 | - [ ] 取前三 11 | - [ ] 5 fold 12 | # 训练调参 13 | - [ ] lite 调参 14 | - [ ] ent 和 emo 单独训练观察收敛情况 15 | - [ ] batch 级别评估 16 | > 由于在中途就会过拟合 所以尝试取 n 个 batch(如 20) 评估一次 17 | 18 | ## ent loss 和 emo loss 的处理 19 | - [x] emo loss 权重组 20 | > 先给 emo 小权重 让 ent loss 先收敛 再加大 emo loss 权重 joint 一起收敛 21 | 22 | # inference 23 | - [x] 初赛模型 -> result.txt 24 | - [x] cv 模型 -> result.txt 25 | 26 | # 模型 27 | - [ ] freeze and unfreeze 28 | -------------------------------------------------------------------------------- /src/scripts/kfold_one.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eux 3 | src=$(pwd)/.. 4 | export PYTHONPATH=:${src} 5 | gpu=$1 6 | cv=$2 7 | export CUDA_VISIBLE_DEVICES=${gpu} 8 | 9 | cd ../kfold 10 | bs=62 11 | val_bs=90 12 | test_bs=1400 13 | pred="pred_5cv_new.h5" 14 | pred_avg="pred_avg.h5" 15 | 16 | train(){ 17 | python -u train.py \ 18 | --batch_size=${bs} \ 19 | --val_batch_size=${val_bs} \ 20 | --cv=${cv} 21 | } 22 | 23 | 24 | test(){ 25 | python -u test.py \ 26 | --test_batch_size=${test_bs} \ 27 | --pred=${pred} \ 28 | --cv=${cv} 29 | } 30 | 31 | avg(){ 32 | python -u test_avg.py \ 33 | --pred=${pred} \ 34 | --pred_avg=${pred_avg} 35 | } 36 | 37 | train 38 | test 39 | #avg 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /src/scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eux 3 | export CUDA_VISIBLE_DEVICES=0,1 4 | 5 | bs=48 6 | val_bs=48 7 | epos=5 8 | lr=3e-5 9 | alpha=1 10 | warm=0.1 11 | dp=0.2 12 | wd=0.01 13 | 14 | cd .. 15 | 16 | run_lite(){ 17 | python -u train.py \ 18 | --lr=${lr} \ 19 | --batch_size=${bs} \ 20 | --val_batch_size=${val_bs} \ 21 | --epochs=${epos} \ 22 | --alpha=${alpha} \ 23 | --warmup_proportion=${warm} \ 24 | --dp=${dp} \ 25 | --wd=${wd} \ 26 | --hyper_cfg=a_${alpha}_lr_${lr}_dp_${dp}_wu_${warm}_wd_${wd} \ 27 | --lite \ 28 | &>> ../logs/lite.log 29 | } 30 | run_full(){ 31 | python -u train.py \ 32 | --lr=${lr} \ 33 | --batch_size=${bs} \ 34 | --val_batch_size=${val_bs} \ 35 | --epochs=${epos} \ 36 | --alpha=${alpha} \ 37 | --warmup_proportion=${warm} \ 38 | --dp=${dp} \ 39 | --wd=${wd} \ 40 | --hyper_cfg=a_${alpha}_lr_${lr}_dp_${dp}_wu_${warm}_wd_${wd} \ 41 | &>> ../logs/full.log 42 | } 43 | 44 | run_lite -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ignite.utils import to_onehot 5 | 6 | 7 | class FocalLoss0(nn.Module): 8 | 9 | def __init__(self, gamma=2, eps=1e-7): 10 | super(FocalLoss0, self).__init__() 11 | self.gamma = gamma 12 | self.eps = eps 13 | 14 | def forward(self, logit, target): 15 | prob = torch.sigmoid(logit) 16 | prob = prob.clamp(self.eps, 1. - self.eps) 17 | 18 | loss = -1 * target * torch.log(prob) 19 | loss = loss * (1 - logit) ** self.gamma 20 | 21 | return loss.sum() 22 | 23 | 24 | class FocalLoss(nn.Module): 25 | 26 | def __init__(self, gamma=0, eps=1e-7, size_average=True): 27 | super(FocalLoss, self).__init__() 28 | self.gamma = gamma 29 | self.eps = eps 30 | self.size_average=size_average 31 | 32 | def forward(self, input, target): 33 | y = to_onehot(target, input.size(-1)) 34 | logit = F.softmax(input, dim=-1) 35 | logit = logit.clamp(self.eps, 1. 
- self.eps) 36 | 37 | loss = -1 * y * torch.log(logit) # cross entropy 38 | loss = loss * (1 - logit) ** self.gamma # focal loss 39 | if self.size_average: 40 | return loss.mean() 41 | else: 42 | return loss.sum() 43 | -------------------------------------------------------------------------------- /src/scripts/trainx.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eux 3 | export CUDA_VISIBLE_DEVICES=0,1 4 | 5 | bs=48 6 | val_bs=48 7 | epos=1 8 | lr=3e-5 9 | alpha=1 10 | warm=0.1 11 | dp=0.2 12 | wd=0.01 13 | 14 | cd .. 15 | 16 | run_lite(){ 17 | python -u trainx.py \ 18 | --lr=${lr} \ 19 | --batch_size=${bs} \ 20 | --val_batch_size=${val_bs} \ 21 | --epochs=${epos} \ 22 | --alpha=${alpha} \ 23 | --warmup_proportion=${warm} \ 24 | --dp=${dp} \ 25 | --wd=${wd} \ 26 | --hyper_cfg=a_${alpha}_lr_${lr}_dp_${dp}_wu_${warm}_wd_${wd} \ 27 | --lite \ 28 | &>> ../logs/lite.log 29 | } 30 | run_lite_emo(){ 31 | python -u trainx.py \ 32 | --lr=${lr} \ 33 | --batch_size=${bs} \ 34 | --val_batch_size=${val_bs} \ 35 | --epochs=${epos} \ 36 | --alpha=${alpha} \ 37 | --warmup_proportion=${warm} \ 38 | --dp=${dp} \ 39 | --wd=${wd} \ 40 | --hyper_cfg=a_${alpha}_lr_${lr}_dp_${dp}_wu_${warm}_wd_${wd} \ 41 | --lite \ 42 | --emo \ 43 | &>> ../logs/lite.log 44 | } 45 | run_full(){ 46 | python -u trainx.py \ 47 | --lr=${lr} \ 48 | --batch_size=${bs} \ 49 | --val_batch_size=${val_bs} \ 50 | --epochs=${epos} \ 51 | --alpha=${alpha} \ 52 | --warmup_proportion=${warm} \ 53 | --dp=${dp} \ 54 | --wd=${wd} \ 55 | --hyper_cfg=a_${alpha}_lr_${lr}_dp_${dp}_wu_${warm}_wd_${wd} \ 56 | &>> ../logs/full.log 57 | } 58 | 59 | run_lite 60 | run_lite_emo -------------------------------------------------------------------------------- /src/data_raw_test.py: -------------------------------------------------------------------------------- 1 | from utils import load_data, data_dump 2 | from sklearn.model_selection import train_test_split 3 | 4 | ''' 5 | 将新闻和标记写入txt文件 6 | 每一个句子空一行 每一则新闻空两行 7 | ''' 8 | file = "../datasets/news_test.pkl" 9 | data = load_data(file) 10 | print("all data: ", len(data)) 11 | 12 | 13 | def gen_news_map(data): 14 | D = {} 15 | for i, news in enumerate(data): 16 | ID = news["newsId"] 17 | D[i] = ID 18 | data_dump(D, "../datasets/news_map.pkl") 19 | 20 | 21 | def data2txt(data): 22 | print("cur data: ", len(data)) 23 | file = "../datasets/test.txt" 24 | f = open(file, "w") 25 | count = 0 26 | for news in data: 27 | title = news["title"] 28 | contents = news["content"] 29 | for t in title: 30 | line = t.strip() 31 | line = line.replace("\n", "") 32 | line = line.replace("\r", "") 33 | if len(line) > 0: 34 | f.write(line) 35 | f.write("\n") 36 | f.write("\n") 37 | for content in contents: 38 | # 避免空content 39 | if len(content)>0: 40 | for t in content: 41 | line = t.strip() 42 | line = line.replace("\n", "") 43 | line = line.replace("\r", "") 44 | if len(line)>0: 45 | f.write(line) 46 | f.write("\n") 47 | f.write("\n") 48 | # 下一新闻 49 | f.write("\n") 50 | count += 1 51 | # if count > 2: 52 | # break 53 | f.close() 54 | print(count) 55 | 56 | 57 | # all 58 | data2txt(data) 59 | gen_news_map(data) 60 | print("over") 61 | -------------------------------------------------------------------------------- /src/data_raw_trnval.py: -------------------------------------------------------------------------------- 1 | from utils import load_data 2 | from sklearn.model_selection import train_test_split 3 | 4 | ''' 5 | 将新闻和标记写入txt文件 6 | 每一个句子空一行 没一则新闻空两行 7 | ''' 
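# Added illustrative note: data2txt below writes one "char label" pair per line (space separated),
# a blank line after every sentence, and an extra blank line after every news item, e.g.
# (tokens and tags made up for illustration):
#   搜 B-POS
#   狐 I-POS
#   新 O
#   闻 O
#
#   (next sentence ...)
# The label strings themselves are presumably produced by the upstream labelling step (get_sents*.py).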
8 | file = "../datasets/news.pkl" 9 | 10 | data = load_data(file) 11 | print(len(data)) 12 | 13 | # 分出lite (40000+1305) * 0.04 = 1653-> 1322:331 14 | _, lite = train_test_split(data, test_size=0.04) 15 | 16 | # 再各自分成 trn 和 val 17 | full_trn, full_val = train_test_split(data, test_size=0.2) 18 | lite_trn, lite_val = train_test_split(lite, test_size=0.2) 19 | 20 | 21 | def data2txt(data, mode="train", size="lite"): 22 | if size == "lite": 23 | trn_txt = "../datasets/lite_trn.txt" 24 | val_txt = "../datasets/lite_val.txt" 25 | else: 26 | trn_txt = "../datasets/train.txt" 27 | val_txt = "../datasets/val.txt" 28 | if mode == "train": 29 | file = trn_txt 30 | else: 31 | file = val_txt 32 | f = open(file, "w") 33 | count = 0 34 | for news in data: 35 | title = news["title"][0] 36 | title_O = news["title"][1] 37 | contents = news["content"] 38 | for (t, tO) in zip(title, title_O): 39 | line = " ".join((t, tO)) 40 | f.write(line) 41 | f.write("\n") 42 | f.write("\n") 43 | for content in contents: 44 | if len(content) > 0: 45 | for t, tO in zip(content[0], content[1]): 46 | line = " ".join((t, tO)) 47 | f.write(line) 48 | f.write("\n") 49 | f.write("\n") 50 | # 下一新闻 51 | f.write("\n") 52 | count += 1 53 | # if count > 2: 54 | # break 55 | f.close() 56 | print(count) 57 | 58 | 59 | # lite 60 | data2txt(lite_trn) 61 | data2txt(lite_val, "val") 62 | # full 63 | data2txt(full_trn, size="full") 64 | data2txt(full_val, mode="val", size="full") 65 | print("over") 66 | -------------------------------------------------------------------------------- /IDEA.md: -------------------------------------------------------------------------------- 1 | 1. 先 freeze BERT 的底层 让上层部分改变 再放开全部训练 2 | > 初定 freeze 一个 epoch ? 3 | 4 | 2. 多级重复 loss 5 | > 即 loss(A+B) + loss(A) 6 | 这样 loss(A) 重复了 相当于加了一倍权重 7 | 8 | 3. 交叉熵的代价矩阵 9 | > 对 B I O 的 交叉熵代价赋权重 10 | 11 | 4. 成词二分类 搜狐的训练集实体 + 百度实体 -> BERT vector -> 二分类器 12 | 13 | 5. 新闻分类 14 | > 采用分类模型,或者规则分类,将训练集分类(类别尽量少),然后构建五折或者验证集的时候就可以从每个类中提取相同比例的数据来进行测试和验证。 15 | > 同时还可以把类别作为aspect来训练模型 16 | > 自己训练新闻分类模型(全网新闻数据(SogouCA)http://www.sogou.com/labs/resource/ca.php) 17 | 18 | 6. 构建预测实体质量检测模型,采用CN-DBpedia(复旦大学GDM实验室中文知识图谱)或者训练集实体或者其他词汇库【采用规则或者深度学习】 19 | > 1)CN-DBpedia(复旦大学GDM实验室中文知识图谱) http://openkg.cn/dataset/cndbpedia 20 | > 2)THUOCL(THU Open Chinese Lexicon)是由清华大学自然语言处理与社会人文计算实验室整理推出的一套高质量的中文词库,词表来自主流网站的社会标签、搜索热词、输入法词库等。https 21 | ://github.com/thunlp/THUOCL 22 | > 3)训练集+example的实体,分词后使用网络训练,要自己生成负样本,然后用这个分类模型来判断生成的测试实体的好坏 23 | > 4)公司名语料库(Company-Names-Corpus)https://github.com/wainshine/Company-Names-Corpus 24 | > 5)中文词语搭配库(SogouR) http://www.sogou.com/labs/resource/r.php 25 | > 6)壹沓科技中文新词https://github.com/1data-inc/chinese_popular_new_words 26 | > 7)中文人名语料库(Chinese-Names-Corpus)https://github.com/wainshine/Chinese-Names-Corpus 27 | > 8)各种语料https://github.com/fighting41love/funNLP 28 | > 9)神策杯2018高校算法大师赛(中文关键词提取)第二名代码方案(带有字典)https://github.com/bigzhao/Keyword_Extraction 29 | 30 | 7. 训练多个不同想法,不同参数的bert,进行融合 31 | 32 | 8. 五折stacking融合,融合两次,第一次采用DNN(可以配合传统的特征,传统的模型进行融合),第二次采用深度融合+伪标签,具体看学长的github:https://github.com/zhanzecheng/SOHU_competition 33 | 34 | 9. 规则纠正预测实体,得提取一些特征看看是否有效。比如取长词,或者取短词,或者融合能包含的词(如互联网工业,工业,互联网————要“工业、互联网”还是“互联网工业”) 35 | 36 | 10. ~~看看标点是否影响分数,中英文标点可以分别去掉一次,然后提交,如果不影响,除分句标点可以全部去掉(甚至分句标点分完句后也能去掉)~~[均有不同程度下降,不可取] 37 | 38 | 11. 情感的模型可以另外做,在预测的实体的基础上,不过准确率要大于85%才有用 39 | 40 | 12. 《》“”中的实体可能在分句中被切断了,采用哈希mask的方式保留,经代码验证需要保留[('【', '】'), ('(', ')'), ('“', '”'), ('「', '」'), ('《', '》')]这些符号中的字符 41 | 42 | 13. 
html转移字符手动转换,经代码确认测试集和训练集只有['&', ' ', '"', '·', '>', '<']这几种转移字符 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## python 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 
91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | 126 | ## pycharm 127 | .idea/ 128 | 129 | ## vim 130 | *.swp 131 | *.swo 132 | 133 | ## my 134 | ckps/* 135 | datasets/* 136 | data/* 137 | bert_pretrained/* 138 | preds/* 139 | bak/ 140 | !*.gitkeep 141 | 142 | -------------------------------------------------------------------------------- /src/kfold/test_avg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/5/7 15:57 4 | # @Author : 邵明岩 5 | # @File : test_avg.py 6 | # @Software: PyCharm 7 | 8 | import os 9 | import h5py 10 | import torch 11 | from argparse import ArgumentParser 12 | 13 | os.chdir("../.") 14 | if __name__ == '__main__': 15 | parser = ArgumentParser() 16 | parser.add_argument("--pred", 17 | default="pred_5cv.h5", 18 | type=str, required=False) 19 | parser.add_argument("--pred_avg", 20 | default="pred_avg.h5", 21 | type=str, required=False) 22 | args = parser.parse_args() 23 | 24 | fresult = h5py.File(f"../preds/{args.pred_avg}", "w") 25 | f = h5py.File(f"../preds/{args.pred}", "r") 26 | ent_raw = fresult.create_dataset("ent_raw", shape=(0, 128, 3), maxshape=(None, 128, 3), compression="gzip") 27 | emo_raw = fresult.create_dataset("emo_raw", shape=(0, 128, 4), maxshape=(None, 128, 4), compression="gzip") 28 | ent = fresult.create_dataset("ent", shape=(0, 128), maxshape=(None, 128), compression="gzip") 29 | emo = fresult.create_dataset("emo", shape=(0, 128), maxshape=(None, 128), compression="gzip") 30 | 31 | for cv in range(1, 6): 32 | print('cv : {}'.format(cv)) 33 | if cv == 1: 34 | pred_ent_conf = f.get(f"cv1/ent_raw")[()] 35 | pred_emo_conf = f.get("cv1/emo_raw")[()] 36 | else: 37 | pred_ent_cur = f.get(f"cv{cv}/ent_raw")[()] 38 | pred_emo_cur = f.get(f"cv{cv}/emo_raw")[()] 39 | pred_ent_conf += pred_ent_cur 40 | pred_emo_conf += pred_emo_cur 41 | 42 | pred_ent_conf /= 5.0 43 | pred_emo_conf /= 5.0 44 | 45 | pred_ent_t = torch.from_numpy(pred_ent_conf) 46 | pred_emo_t = torch.from_numpy(pred_emo_conf) 47 | pred_ent = torch.argmax(torch.softmax(pred_ent_t, dim=-1), dim=-1) # [-1, 128] 48 | pred_emo = torch.argmax(torch.softmax(pred_emo_t, dim=-1), dim=-1) # [-1, 128] 49 | size = pred_ent_t.shape[0] 50 | # ent.resize(size, axis=0) 51 | # emo.resize(size, axis=0) 52 | # ent[0: size] = pred_ent_t.numpy() 53 | # emo[0: size] = pred_emo_t.numpy() 54 | assert pred_ent_conf.shape[0] == pred_emo_conf.shape[0] == pred_ent.shape[0] == pred_emo.shape[0] 55 | ent_raw.resize(size, axis=0) 56 | emo_raw.resize(size, axis=0) 57 | ent.resize(size, axis=0) 58 | emo.resize(size, axis=0) 59 | ent_raw[...] = pred_ent_conf 60 | emo_raw[...] = pred_emo_conf 61 | ent[...] = pred_ent 62 | emo[...] 
= pred_emo 63 | f.close() 64 | fresult.close() 65 | print('over!') 66 | -------------------------------------------------------------------------------- /src/kfold/test_avg_all.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/5/7 15:57 4 | # @Author : 邵明岩 5 | # @File : test_avg.py 6 | # @Software: PyCharm 7 | 8 | import os 9 | import h5py 10 | import torch 11 | from argparse import ArgumentParser 12 | 13 | os.chdir("../.") 14 | if __name__ == '__main__': 15 | parser = ArgumentParser() 16 | parser.add_argument("--pred_dir", 17 | default="pred_5cv.h5", 18 | type=str, required=False) 19 | parser.add_argument("--pred_avg", 20 | default="pred_avg.h5", 21 | type=str, required=False) 22 | args = parser.parse_args() 23 | 24 | all_preds = os.listdir(args.pred_dir) 25 | fresult = h5py.File(f"../preds/{args.pred_avg}", "w") 26 | ent_raw = fresult.create_dataset("ent_raw", shape=(0, 256, 5), maxshape=(None, 256, 5), compression="gzip") 27 | emo_raw = fresult.create_dataset("emo_raw", shape=(0, 256, 4), maxshape=(None, 256, 4), compression="gzip") 28 | ent = fresult.create_dataset("ent", shape=(0, 256), maxshape=(None, 256), compression="gzip") 29 | emo = fresult.create_dataset("emo", shape=(0, 256), maxshape=(None, 256), compression="gzip") 30 | for i, name in enumerate(all_preds): 31 | f = h5py.File(os.path.join(args.pred_dir, f"{name}"), "r") 32 | if i == 0: 33 | pred_ent_conf = f.get("ent_raw")[()] 34 | pred_emo_conf = f.get("emo_raw")[()] 35 | else: 36 | pred_ent_cur = f.get(f"ent_raw")[()] 37 | pred_emo_cur = f.get(f"emo_raw")[()] 38 | pred_ent_conf += pred_ent_cur 39 | pred_emo_conf += pred_emo_cur 40 | f.close() 41 | 42 | pred_ent_conf /= len(all_preds) 43 | pred_emo_conf /= len(all_preds) 44 | 45 | pred_ent_t = torch.from_numpy(pred_ent_conf) 46 | pred_emo_t = torch.from_numpy(pred_emo_conf) 47 | pred_ent = torch.argmax(torch.softmax(pred_ent_t, dim=-1), dim=-1) # [-1, 128] 48 | pred_emo = torch.argmax(torch.softmax(pred_emo_t, dim=-1), dim=-1) # [-1, 128] 49 | size = pred_ent_t.shape[0] 50 | # ent.resize(size, axis=0) 51 | # emo.resize(size, axis=0) 52 | # ent[0: size] = pred_ent_t.numpy() 53 | # emo[0: size] = pred_emo_t.numpy() 54 | assert pred_ent_conf.shape[0] == pred_emo_conf.shape[0] == pred_ent.shape[0] == pred_emo.shape[0] 55 | ent_raw.resize(size, axis=0) 56 | emo_raw.resize(size, axis=0) 57 | ent.resize(size, axis=0) 58 | emo.resize(size, axis=0) 59 | ent_raw[...] = pred_ent_conf 60 | emo_raw[...] = pred_emo_conf 61 | ent[...] = pred_ent 62 | emo[...] 
= pred_emo 63 | fresult.close() 64 | print('over!') 65 | -------------------------------------------------------------------------------- /docs/readme.txt: -------------------------------------------------------------------------------- 1 | # inference 运行 2 | ## 配置 3 | 需要安装 pytorch ignite h5py 4 | 如果需要的话 pip install requirements.txt 进行安装对应版本 5 | 6 | bash inference.sh 进行运行 7 | 结果保存在 ../datasets/result.txt中 8 | 说明: 9 | 脚本会先运行 test.py 将模型的预测写进 ../datasets/pred.h5 10 | 然后运行 get_result.py 读取 pred.h5 取出实体和情感写入到../datasets/result.txt当中 11 | 12 | 也可以直接运行 python get_result.py 读取pred.h5生成结果 省去模型预测环节 13 | 14 | 15 | # 项目依赖 16 | + pytorch >=0.1 17 | + pytorch-pretrained-BERT 18 | + ignite https://github.com/pytorch/ignite 19 | + h5py 20 | 21 | 22 | # 引用的开源数据和代码 23 | 使用了 pytorch-pretrained-BERT(https://github.com/huggingface/pytorch-pretrained-BERT) 框架运行 BERT 24 | 使用了适用于该框架的源自 google bert 的 bert-base-chinese: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters 25 | 和其 token 模型 保存在 bert_pretrained 下 26 | 27 | # 文件结构说明 28 | src -- 所有代码源码 29 | --get_sent_and_word_list.py 从原始新闻中进行清洗 分字 标记标签 保存到 train_ner_has_emotion.pkl wenj 30 | --data_raw 从train_ner_has_emotion.pkl 读取分字和标签写入 train_full.txt dev_full.txt文件 一行一个字一个标签 31 | --data_title_trnval.py 读取 train_full.txt 和 dev_full.txt 进行 token 生成 bert 需要的输入 input_ids 32 | input_mask segment_ids, label_ent_ids label_emo_ids 写入 data_train.h5 33 | --data_title_test.py 类似上面,输入的 test 版本 写入 ../datasets/data_test.h5 34 | --models.py 网络模型定义文件 35 | --loss.py loss函数的定义文件 36 | --utils.py 用到的辅助函数汇总文件 37 | --metric.py F1分数定义文件 38 | --train.py 模型训练脚本 输出为checkpoint 39 | --test.py 进行 inference 写入 ../datasets/pred.h5 40 | --get_result.py 读取perd.h5 取出实体写入到 ../datasets/result.txt 41 | 42 | --bert_pretrained google预训练模型存放 43 | --bert-base-chinese 中文权重 44 | --bert_token_model tokenizer文件 45 | 46 | --datasets 所有的数据保存之处 47 | inference 用到的有 data_test.h5 ID2TOK_test.pkl(token是分为 UNK的字的记录) 48 | 49 | --pred_best.pth 最好的预测文件记录 50 | --result_best.txt 最好的提交结果 51 | 52 | --ckp 53 | --best_model.pth 最好的模型权重 54 | 55 | 56 | # 思路 57 | 58 | 因为新闻的标题是新闻的上下文概括,所以将每则新闻的标题拼接到该新闻的每个分句当中,形成 59 | [CLS] A句 [SEP] B句(=title),这样每个分句就能学到与整个新闻上下文相关的信息 60 | 61 | 经过 bert 得到每个字的 768 维 word embedding 和 [CLS] 即整个句子的 embeeding 拼接 62 | 到每个字的 embedding 上去 形成 768*2 维, 思路是进一步给予单字更多的上下文信息,帮助分类 63 | 64 | 然后分别使用两个一层的线性分类器做 实体的 O B I 分类 和 情感的 O POS NEG NORM 分类 65 | 66 | 且用 mask 的方法只对 A 句进行 loss_ent loss_emo 的 反传 67 | 由于 emo(情感) 是 基于 ent(实体) 的 ,所以 68 | loss = alpha * loss_ent + loss_emo, alpha 赋予大于1的值 69 | 70 | 71 | ## 待实验 72 | 我们小范围测试过拼接一则新闻的多个分句形成长句作为一个样本,同样是batch化和给予更多上下文的考量,虽然训练变慢, 73 | 但效果会有一些提升。 74 | 还有诸如给予B更多的loss权重的实验,由于时间原因没能测试 75 | 76 | 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 2019搜狐校园算法大赛 llouice 2 | 3 | --- 4 | 5 | + TODO 工作推进状况 包括已经完成和 TODO List 6 | + IDEA 各种天马行空的想法~ 7 | 8 | --- 9 | ## 文件说明 10 | - src 11 | - data_raw 12 | > news.pkl (所有的新闻及其分字列表和标签列表的集合) 13 | >> 0.1 的 lite, full lite 对应的 14 | train.txt val.txt 和 lite_trn.txt lite_val.txt 15 | - data_title_trnval.py data_title_test.py 16 | > 从 train.txt val.txt 或者 lite_trn.txt lite_val.txt test.txt 17 | >> 生成 tensor 序列化文件 train.h5 lite.h5 和 ID2TOK.pkl(trn val test 所有的 ID2TOK) 18 | >>> full.h5 包括3个数据集 train val test 19 | >>> lite.h5 包括2个数据集 trina val 20 | 21 | ## 数据集生成流程 22 | + data_raw_trnval news.pkl -> lite_trn.txt lite_val.txt train.txt val.txt 23 | + data_raw_test 
news_test.pkl -> test.txt news_mapl.pkl 24 | + data_title_trnval -> lite.h5 full.h5 ID2TOK.pkl 25 | + data_title_test -> update full.h5 ID2TOK.pkl 26 | 27 | ## inference 运行 28 | ### 配置 29 | - 需要安装 pytorch ignite h5py 30 | - 如果需要的话 pip install requirements.txt 进行安装对应版本 31 | 32 | bash inference.sh 进行运行 33 | 结果保存在 ../datasets/result.txt中 34 | 35 | 说明: 36 | - 脚本会先运行 test.py 将模型的预测写进 ../datasets/pred.h5 37 | - 然后运行 get_result.py 读取 pred.h5 取出实体和情感写入到../datasets/result.txt当中 38 | - 也可以直接运行 python get_result.py 读取pred.h5生成结果 省去模型预测环节 39 | 40 | 41 | ## 项目依赖 42 | * pytorch >=0.1 43 | * pytorch-pretrained-BERT 44 | * ignite https://github.com/pytorch/ignite 45 | * h5py 46 | 47 | 48 | ## 引用的开源数据和代码 49 | 使用了 pytorch-pretrained-BERT(https://github.com/huggingface/pytorch-pretrained-BERT) 框架运行 BERT 50 | 使用了适用于该框架的源自 google bert 的 bert-base-chinese: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters 51 | 和其 token 模型 保存在 bert_pretrained 下 52 | 53 | ## 文件结构说明 54 | - src 所有代码源码 55 | - kfold k折交叉验证 56 | - test_avg_all.py 平均所有结果 57 | - train.py 模型训练 58 | - test.py 模型测试输出结果 59 | 60 | - scripts 存放一键运行脚本 61 | 62 | - data_raw_test.py 将新闻和标记写入txt文件 63 | - data_raw_trnval.py 将新闻和标记写入txt文件 64 | - data_title_test.py 读取 train_full.txt 和 dev_full.txt 进行 token 生成 bert 需要的输入 input_ids 65 | input_mask segment_ids, label_ent_ids label_emo_ids 写入 data_train.h5 66 | - data_title_trnval.py 类似上面,输入的 test 版本 写入 ../datasets/data_test.h5 67 | - get_result.py 读取perd.h5 取出实体写入到 ../datasets/result.txt 68 | - get_sents*.py 分句 69 | - loss.py loss函数的定义文件 70 | - metric.py F1分数定义文件 71 | - models.py 网络模型定义文件,选择模型网络选项 --net 72 | - requirements.txt 项目依赖 73 | - test.py 添加选择模型网络选项 --net 2 days ago 74 | - train.py 模型训练脚本 输出为checkpoint 75 | - trainx.py 单独训练 ent 和 emo 76 | - utils.py 用到的辅助函数汇总文件 77 | 78 | - bert_pretrained google预训练模型存放 79 | - bert-base-chinese 中文权重 80 | - bert_token_model tokenizer文件 81 | 82 | - ckp 83 | - best_model.pth 最好的模型权重 84 | - datasets 所有数据保存之处 85 | 86 | - preds 所有预测输出的h5文件 87 | 88 | - pred_best.pth 最好的预测文件记录 89 | - result/best 存放最好的提交结果 90 | - result_best.txt 最好的提交结果 91 | 92 | ## 算法思路 93 | ### 数据预处理 94 | 清洗:保留数字、英文、中文、中英文标点、空格和\n并将网页的转义字符还原。去除连续多余的空格。去除链接。去除多余连续的符号(无意义的、颜文字表情)。 95 | 分句:对新闻的内容进行分句处理,由于后续输入网络的最大长度是256. 
减去BERT需要的三个字符 CLS SEP CLS 长度设为256-3-标题长; 96 | 一级:|。|!|\!|?|\?分句,然后对分完的句子进行判断,不符合的进入二级截断 97 | 二级:;|、|,|,|﹔|、 98 | 如果还是不符合,就按最大长度截断。 2000左右条句子被截断。 99 | 分字:对内容和标题都分字。把中文当成截断字符来分,英文先保留整体单词。针对每个分开的字,再通过标点符号对英文分字。 100 | 打标签:采用BIE标签,先从长实体开始打起,BIIIIIIE,如果再打标签的时候遇到了BIE这种,则不进行标签处理。 101 | 遇到问题:某些实体当中含有分句的字符,就会被切开。 通过统计发现带有,的实体都在《》“”之内,于是我们采用哈希mask的方式,将实体进行等长度的哈希之后保护替换,然后进行分句还原。保护他们不被截断。 分句将\n\r删掉了 102 | 103 | ### 创新思路 104 | 因为新闻的标题是新闻的上下文概括,所以将每则新闻的标题拼接到该新闻的每个分句当中,形成[CLS] A句 [SEP] B句(=title),这样每个分句就能学到与整个新闻上下文相关的信息。经过 bert 得到每个字的 768 维 word embedding 和 [CLS] 即整个句子的 embeeding 拼接到每个字的 embedding 上去 形成 768*2 维, 思路是进一步给予单字更多的上下文信息,帮助分类。 105 | 106 | 分别使用两个一层的线性分类器做 实体的 O B I 分类 和 情感的 O POS NEG NORM 分类。且用 mask 的方法只对 A 句进行 loss_ent loss_emo 的 反传,由于 emo(情感) 是 基于 ent(实体) 的 ,所以loss = alpha * loss_ent + loss_emo, alpha 赋予大于1的值。 107 | 108 | 针对重叠实体问题,如预测“天然气 天然气田”,“个人信息 个人信息保护”等我们采用BIESO标志,重新生成输入,调整分类数目,训练模型 109 | 110 | 针对tokenize产生[UNK]造成无法还原的问题,我们在进行tokenize时,将最终token为[UNK]的实体使用正则表达式匹配到,添加到我们自己基于BERT Vocabulary的扩展字典当中。这样预测时就能还原出每一个字。 111 | 112 | ### 神经网络结构图 113 | ![神经网络结构图](http://wx1.sinaimg.cn/large/e8f43d1ely1g39w4lg4ebj20h90a3mxc.jpg) 114 | 115 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | import re 4 | import os 5 | import h5py 6 | from collections import OrderedDict 7 | 8 | 9 | ################## pickle ################ 10 | def data_dump(data, file): 11 | with open(file, "wb") as f: 12 | pickle.dump(data, f) 13 | print("store {} successfully!".format(file)) 14 | 15 | 16 | def load_data(path): 17 | with open(path, "rb") as f: 18 | res = pickle.load(f) 19 | print("load data from {} successfully!".format(path)) 20 | return res 21 | 22 | 23 | ################## News ################ 24 | class News2(object): 25 | def __init__(self, idx, newsId, title, coreEE): 26 | self.newsID = newsId 27 | self.title = title 28 | self.sents = [] 29 | self.idx = idx 30 | self.coreEE = coreEE 31 | 32 | def add_sent(self, sent): 33 | self.sents.append(sent) 34 | 35 | 36 | def gen_news(file="/home/lzk/llouice/BERT/souhu/BERT-NER-master/att/coreEntityEmotion_example.txt"): 37 | with open(file, "rt", encoding="utf-8") as f: 38 | for line in f: 39 | news = json.loads(line) 40 | yield news 41 | 42 | def seg_char(sent): 43 | """ 44 | 把句子按字分开,不破坏英文结构 45 | """ 46 | pattern = re.compile(r'([\u4e00-\u9fa5])') 47 | chars = pattern.split(sent) 48 | chars = [w.strip() for w in chars if len(w.strip()) > 0] 49 | new_chars = [] 50 | for c in chars: 51 | if len(c) > 1: 52 | punctuation = '!?。,。"#$%&'()*+-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’“”„‟…‧' 53 | punctuation = '|'.join([p for p in punctuation]) 54 | pattern = re.compile(' |({})'.format(punctuation)) 55 | cs = pattern.split(c) 56 | for w in cs: 57 | if w and len(w.strip()) > 0: 58 | new_chars.append(w.strip()) 59 | else: 60 | new_chars.append(c) 61 | return new_chars 62 | 63 | def get_sentences(content): 64 | sentences = re.split(r'(。|!|\!|?|\?)', content) # 保留分割符 65 | new_sents = [] 66 | for i in range(int(len(sentences) / 2)): 67 | sent = sentences[2 * i] + sentences[2 * i + 1] 68 | sent = sent.strip() 69 | sent.replace("\n", "") 70 | sent.replace("\r", "") 71 | sent.replace("\u200B", "") 72 | new_sents.append(sent.strip()) 73 | return new_sents 74 | 75 | 76 | ########################################## HDF5 ################################################## 77 | def save2hdf5(filename): 78 | file_path = 
os.path.join("/home/lzk/llouice/BERT/souhu/datasets", filename) 79 | f = h5py.File(file_path, "a", libver="latest") 80 | 81 | 82 | ########################################## my token ################################################ 83 | def covert_mytokens_to_myids(TOK2ID, mytokens): 84 | myids = [] 85 | for token in mytokens: 86 | myids.append(TOK2ID[token]) 87 | return myids 88 | 89 | def covert_myids_to_mytokens(ID2TOK, myids): 90 | mytokens = [] 91 | for i in myids: 92 | mytokens.append(ID2TOK[i]) 93 | return mytokens 94 | 95 | 96 | #####################################check sent_len############################ 97 | def check_sent_len(): 98 | sent_lens = dict() 99 | for fname in ['../data/coreEntityEmotion_test_stage2.txt','../data/coreEntityEmotion_train.txt', 100 | '../data/coreEntityEmotion_example.txt']: 101 | f = open(fname, 'r') 102 | f.seek(0) 103 | for index, line in enumerate(f.readlines()): 104 | data = json.loads(line) 105 | title = data['title'] 106 | content = data['content'] 107 | if sent_lens.get(len(title)): 108 | sent_lens[len(title)] = sent_lens[len(title)] + 1 109 | else: 110 | sent_lens[len(title)] = 1 111 | for sent in get_sentences(content): 112 | if sent_lens.get(len(sent)): 113 | sent_lens[len(sent)] = sent_lens[len(sent)] + 1 114 | else: 115 | sent_lens[len(sent)] = 1 116 | 117 | return sent_lens 118 | -------------------------------------------------------------------------------- /src/metric.py: -------------------------------------------------------------------------------- 1 | from ignite.metrics.metric import Metric 2 | import re 3 | import torch 4 | from utils import covert_myids_to_mytokens, load_data, data_dump 5 | import numpy as np 6 | from collections import defaultdict, Counter 7 | 8 | ID2TOK = load_data("../datasets/ID2TOK.pkl") 9 | 10 | 11 | class FScore(Metric): 12 | """ 13 | Calculates the top-k categorical accuracy. 14 | 15 | - `update` must receive output of the form `(y_pred, y)`. 16 | """ 17 | 18 | def __init__(self, output_transform=lambda x: x, lbl_method="BIO"): 19 | self.EMOS_MAP = {"0": "OTHER", "1": "POS", "2": "NEG", "3": "NORM"} 20 | self.lbl_method = lbl_method 21 | self.ents = defaultdict(list) 22 | self.ents_pred = defaultdict(list) 23 | self.f1s_ent = [] 24 | self.f1s_emo = [] 25 | self.ones_pred = 0 26 | self.ones = 0 27 | if self.lbl_method == "BIO": 28 | self.pattern = re.compile("1[2]*") # 贪婪匹配 29 | else: 30 | self.pattern = re.compile("(1[2]*3)|(4+)") # 贪婪匹配 31 | print(f"lbl_method: {self.lbl_method} pattern: {self.pattern}") 32 | # put the super in end! 
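        # Added explanation: ignite's Metric.__init__ invokes self.reset() internally, so every
        # attribute that reset() clears (self.ents, self.f1s_ent, ...) must already exist before
        # the super().__init__ call below - hence it is deliberately placed last.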
33 | super(FScore, self).__init__(output_transform) 34 | 35 | def reset(self): 36 | self.ents.clear() 37 | self.ents_pred.clear() 38 | self.f1s_emo.clear() 39 | self.f1s_ent.clear() 40 | self.ones_pred = 0 41 | self.ones = 0 42 | 43 | def update(self, output): 44 | y_pred_ent, y_ent, y_pred_emo, y_emo, myinput_ids = output 45 | tokens = covert_myids_to_mytokens(ID2TOK, myinput_ids.tolist()) 46 | y_pred_ent = torch.argmax(torch.softmax(y_pred_ent, dim=-1), dim=-1) # [L, 1] 47 | y_pred_emo = torch.argmax(torch.softmax(y_pred_emo, dim=-1), dim=-1) # [L, 1] 48 | self._count(y_pred_ent, y_ent, y_pred_emo, y_emo, tokens) 49 | 50 | def _count(self, y_pred_ent, y_ent, y_pred_emo, y_emo, tokens): 51 | '''become str of nums the use re to match''' 52 | y_pred_ent = "".join([str(i.item()) for i in y_pred_ent]) 53 | y_ent = "".join([str(i.item()) for i in y_ent]) 54 | y_pred_emo = "".join([str(i.item()) for i in y_pred_emo]) 55 | y_emo = "".join([str(i.item()) for i in y_emo]) 56 | 57 | self._find_ents(y_pred_ent, y_pred_emo, self.pattern, tokens, self.ents_pred, "pred") 58 | self._find_ents(y_ent, y_emo, self.pattern, tokens, self.ents) 59 | ENTS_PRED = {ent for ent in self.ents_pred} 60 | ENTS = {ent for ent in self.ents} 61 | 62 | # 实体的分数 63 | f1_ent = self.cal_f1(ENTS_PRED, ENTS) 64 | self.f1s_ent.append(f1_ent) 65 | 66 | # 情感的分数 67 | EMOS_S_PRED = set() 68 | EMOS_S = set() 69 | for ent, emos in self.ents_pred.items(): 70 | c = Counter(emos).most_common(1) 71 | EMOS_S_PRED.add("{}_{}".format(ent, self.EMOS_MAP[c[0][0]])) 72 | for ent, emos in self.ents.items(): 73 | c = Counter(emos).most_common(1) 74 | EMOS_S.add("{}_{}".format(ent, self.EMOS_MAP[c[0][0]])) 75 | f1_emo = self.cal_f1(EMOS_S_PRED, EMOS_S) 76 | self.f1s_emo.append(f1_emo) 77 | self.ents_pred.clear() 78 | self.ents.clear() 79 | 80 | def compute(self): 81 | f1_ent_all = np.average(self.f1s_ent) if len(self.f1s_ent) > 0 else 0 82 | f1_emo_all = np.average(self.f1s_emo) if len(self.f1s_emo) > 0 else 0 83 | print("单字数: {}/{}".format(self.ones_pred, self.ones)) 84 | print(f"F1_ent: {f1_ent_all}\tF1_emo: {f1_emo_all}") 85 | return 0.5 * (f1_ent_all + f1_emo_all) 86 | 87 | def _find_ents(self, y_pred_ent, y_pred_emo, p, tokens, S, mode="lbl"): 88 | # 使用与取result一致的逻辑 89 | for r in p.finditer(y_pred_ent): 90 | i, j = r.span()[0], r.span()[1] 91 | res = "".join(tokens[i:j]) 92 | if len(res) == 1: 93 | if mode == "pred": 94 | self.ones_pred += 1 95 | else: 96 | self.ones += 1 97 | continue 98 | emos = y_pred_emo[i:j] 99 | emo = Counter(emos).most_common(1)[0][0] 100 | S[res].append(emo) 101 | 102 | def cal_f1(self, s1, s2): 103 | nb_correct = len(s1 & s2) 104 | nb_pred = len(s1) 105 | nb_true = len(s2) 106 | p = nb_correct / nb_pred if nb_pred > 0 else 0 107 | r = nb_correct / nb_true if nb_true > 0 else 0 108 | f1 = (2 * p * r) / (p + r) if p + r > 0 else 0 109 | return f1 110 | -------------------------------------------------------------------------------- /src/get_result.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import h5py 3 | import numpy as np 4 | import re 5 | from collections import defaultdict, Counter 6 | from utils import load_data, covert_myids_to_mytokens 7 | import time 8 | 9 | EMOS_MAP = {"0": "OTHER", "1": "POS", "2": "NEG", "3": "NORM"} 10 | ID2TOK = load_data("../datasets/ID2TOK.pkl") 11 | ###################################################################################### 12 | 13 | 14 | def _get_res(pattern ,cur_input_ids, cur_myinput_ids, cur_pred_ent, 
cur_pred_ent_conf, cur_pred_emo, cur_pred_emo_conf, S): 15 | ''' 16 | :param cur_myinput_ids: 17 | :param cur_pred_ent: 18 | :param pattern_ent: 19 | :param cur_pred_emo: 20 | :param ENTS: 21 | :return: 22 | ''' 23 | 24 | for r in pattern.finditer(cur_pred_ent): 25 | i, j = r.span()[0], r.span()[1] 26 | res = "".join(covert_myids_to_mytokens(ID2TOK, cur_myinput_ids[i:j])) 27 | res = res.replace("##", "") 28 | if "[PAD]" in res or "X" in res or "[CLS]" in res or "[SEP]" in res: 29 | continue 30 | if "[UNK]" in res: 31 | print("UNK: ", res) 32 | continue 33 | if len(res) == 1: 34 | continue 35 | emos = cur_pred_emo[i:j] 36 | # conf = cur_pred_ent_conf[i:j] 37 | # emo = Counter(emos).most_common(1)[0][0] 38 | S[res].extend(emos) 39 | 40 | 41 | def _get_ent(S): 42 | R = defaultdict() 43 | for ent, emos in S.items(): 44 | c = Counter(emos).most_common() 45 | if not c[0][0] == '0': 46 | R[ent] = c[0][0] 47 | elif len(c) > 1: 48 | R[ent] = c[1][0] 49 | else: 50 | # not certain 51 | pass 52 | return R 53 | 54 | 55 | #################################### Every News ########################################### 56 | def main(): 57 | # ------------------------------------- args ---------------------------------------- 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument("--res", 60 | default="result_new.txt", 61 | type=str, 62 | required=False, 63 | help="result file") 64 | parser.add_argument("--pred", 65 | default="pred_new.h5", 66 | type=str, required=False) 67 | parser.add_argument("--lbl_method", 68 | default="BIO", 69 | type=str, required=False) 70 | args = parser.parse_args() 71 | # ------------------------------------------------------------------------------ 72 | # ------------------------------- output file ---------------------------------- 73 | if args.lbl_method == "BIO": 74 | pattern = re.compile("1[2]*") # 贪婪匹配 75 | else: 76 | pattern = re.compile("(1[2]*3)|(4+)") # 贪婪匹配 77 | 78 | result_file = f"../results/{args.res}" 79 | pred_file = f"../preds/{args.pred}" 80 | f_result = open(result_file, "wt") 81 | f_pred = h5py.File(pred_file, "r") 82 | # ------------------------------------------------------------------------------ 83 | # -------------------------------- test data ----------------------------------- 84 | NEWS_MAP = load_data("../datasets/news_map.pkl") 85 | test_file = "../datasets/full.h5" 86 | 87 | f_test = h5py.File(test_file, "r") 88 | IDs = f_test.get("test/IDs")[()] 89 | # input_ids = f_test.get("test/input_ids")[()] 90 | myinput_ids = f_test.get("test/myinput_ids")[()] 91 | input_mask = f_test.get("test/input_mask")[()] 92 | segement_ids = f_test.get("test/segment_ids")[()] 93 | unique_IDs = np.unique(IDs) 94 | assert np.max(unique_IDs) == 79999 95 | # ------------------------------------------------------------------------------ 96 | # -------------------------------- pred data ----------------------------------- 97 | pred_ent = f_pred.get("ent")[()] 98 | pred_emo = f_pred.get("emo")[()] 99 | assert IDs.shape[0] == pred_ent.shape[0] == pred_emo.shape[0] 100 | # pred_ent_conf = f_pred.get("ent_raw") 101 | # pred_emo_conf = f_pred.get("emo_raw") 102 | # ------------------------------------------------------------------------------ 103 | 104 | idx1 = 0 105 | time0 = time.time() 106 | time1 = time.time() 107 | for id in range(80000): 108 | num = np.sum(IDs == id) 109 | idx2 = idx1 + num 110 | # [bs, 128, 3] 111 | cur_pred_ent = pred_ent[idx1:idx2, :].reshape(1, -1).squeeze().astype(np.int) # (6912,) 112 | # cur_pred_ent_conf = pred_ent_conf[idx1:idx2, :, :] 113 | # 
cur_pred_ent_conf = np.max(torch.softmax(torch.from_numpy(cur_pred_ent_conf), dim=-1).numpy(), axis=-1).reshape(1, 114 | # -1).squeeze() 115 | 116 | cur_pred_emo = pred_emo[idx1:idx2, :].reshape(1, -1).squeeze().astype(np.int) # (6912,) 117 | # cur_pred_emo_conf = pred_emo_conf[idx1:idx2, :, :] 118 | # cur_pred_emo_conf = np.max(torch.softmax(torch.from_numpy(cur_pred_emo_conf), dim=-1).numpy(), axis=-1).reshape(1, 119 | # -1).squeeze() 120 | 121 | # mask 122 | # cur_pred_emo = cur_pred_emo[cur_pred_ent == 1] 123 | # 原文 124 | # cur_input_ids = input_ids[idx1:idx2, :].reshape(1, -1).squeeze() 125 | cur_myinput_ids = myinput_ids[idx1:idx2, :].reshape(1, -1).squeeze() 126 | cur_input_mask = input_mask[idx1:idx2, :].reshape(1, -1).squeeze() 127 | cur_segment_ids = segement_ids[idx1:idx2, :].reshape(1, -1).squeeze() 128 | ################################# 2 mask: input mask + segment mask ##################### 129 | # ''' 130 | active_mask = cur_input_mask == 1 131 | active_seg = cur_segment_ids[active_mask] 132 | active_seg = active_seg == 0 133 | # cur_input_ids = cur_input_ids[active_mask][active_seg] 134 | cur_myinput_ids = cur_myinput_ids[active_mask][active_seg] 135 | cur_pred_ent = cur_pred_ent[active_mask][active_seg] 136 | # cur_pred_ent_conf = cur_pred_ent_conf[active_mask][active_seg] 137 | cur_pred_emo = cur_pred_emo[active_mask][active_seg] 138 | # cur_pred_emo_conf = cur_pred_emo_conf[active_mask][active_seg] 139 | # ''' 140 | # cur_pred_emo = cur_pred_emo[cur_pred_ent == 2] 141 | # cur_pred_emo_conf = cur_pred_emo_conf[cure_pred_ent == 2] 142 | ######################################################################################### 143 | # 10 -> 0 144 | # cur_pred_emo[cur_pred_emo == 10] = 0 145 | idx1 = idx2 146 | cur_pred_ent = "".join([str(cur_pred_ent[i]) for i in range(cur_pred_ent.shape[-1])]) 147 | cur_pred_emo = "".join([str(cur_pred_emo[i]) for i in range(cur_pred_emo.shape[-1])]) 148 | 149 | ENTS_LIST = defaultdict(list) 150 | # _get_res(cur_input_ids, cur_myinput_ids, cur_pred_ent, cur_pred_ent_conf, cur_pred_emo, cur_pred_emo_conf, ENTS_LIST) 151 | # _get_res(cur_input_ids, cur_myinput_ids, cur_pred_ent, "", cur_pred_emo, "", ENTS_LIST) 152 | _get_res(pattern, "", cur_myinput_ids, cur_pred_ent, "", cur_pred_emo, "", ENTS_LIST) 153 | 154 | ################################################################# 155 | 156 | ################### !!综合!! 
######################################## 157 | R = _get_ent(ENTS_LIST) 158 | # result1: 舍弃单个汉字 取交集作为提交版本 TODO 空集策略 标点问题{青山 青山.} 159 | 160 | ######################### write to file ############################ 161 | newsID = NEWS_MAP[id] 162 | line = "{}\t{}\t{}" 163 | ents = [] 164 | emos = [] 165 | for ent, emo in R.items(): 166 | ent = ent.replace(",", "") 167 | ents.append(ent) 168 | emos.append(EMOS_MAP[emo]) 169 | 170 | assert len(ents) == len(emos) 171 | ents = ",".join(ents) 172 | emos = ",".join(emos) 173 | answer = line.format(newsID, ents, emos) 174 | answer = answer.replace("\r", "") 175 | answer = answer.replace("\n", "") 176 | answer = answer.replace("\u200B", "") 177 | 178 | # print(answer) 179 | f_result.write(answer + "\n") 180 | # if (id + 1) % 100 == 0: 181 | # print("=" * 10, id, "=" * 10) 182 | # break 183 | if id % 1000 == 0: 184 | time2 = time.time() 185 | print("=" * 10, "time used: [{}] {}".format(time2 - time1, id), "=" * 10) 186 | time1 = time2 187 | 188 | ################################################################################# 189 | 190 | f_result.close() 191 | f_pred.close() 192 | f_test.close() 193 | 194 | print("*" * 10, "time used: [{} min]".format((time.time() - time0) / 60), "*" * 10) 195 | print("over!!!") 196 | 197 | 198 | if __name__ == '__main__': 199 | main() 200 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | from argparse import ArgumentParser 4 | import torch 5 | from ignite.engine import Engine, Events 6 | from ignite.contrib.handlers.tqdm_logger import ProgressBar 7 | from torch.utils.data import DataLoader, SequentialSampler, TensorDataset 8 | from models import NetY3, NetY1, NetY2, NetY3_fz, NetY4 9 | 10 | 11 | def get_test_dataloader(): 12 | print("get test dataloader...........") 13 | # -------------------------------read from h5------------------------- 14 | f = h5py.File("../datasets/full.h5", "r") 15 | input_ids = torch.from_numpy(f["test/input_ids"][()]) 16 | input_mask = torch.from_numpy(f["test/input_mask"][()]) 17 | segment_ids = torch.from_numpy(f["test/segment_ids"][()]) 18 | assert input_ids.size() == segment_ids.size() == input_mask.size() 19 | print("test dataset num: ", input_ids.size(0)) 20 | test_dataset = TensorDataset(input_ids, input_mask, segment_ids) 21 | f.close() 22 | print("read h5 over!") 23 | test_dataloader = DataLoader(test_dataset, sampler=SequentialSampler(test_dataset), batch_size=args.test_batch_size, 24 | num_workers=4) 25 | print("get date loader over!") 26 | return test_dataloader, len(test_dataset) 27 | 28 | 29 | def run(): 30 | ################################ Model Config ################################### 31 | if args.lbl_method == "BIO": 32 | num_labels_emo = 4 # O POS NEG NORM 33 | num_labels_ent = 3 # O B I 34 | else: 35 | num_labels_emo = 4 # O POS NEG NORM 36 | num_labels_ent = 5 # O B I E S 37 | 38 | if args.net == "3": 39 | model = NetY3.from_pretrained(args.bert_model, 40 | cache_dir="", 41 | num_labels_ent=num_labels_ent, 42 | num_labels_emo=num_labels_emo, 43 | dp=args.dp) 44 | elif args.net == "2": 45 | model = NetY2.from_pretrained(args.bert_model, 46 | cache_dir="", 47 | num_labels_ent=num_labels_ent, 48 | num_labels_emo=num_labels_emo, 49 | dp=args.dp) 50 | elif args.net == "1": 51 | model = NetY1.from_pretrained(args.bert_model, 52 | cache_dir="", 53 | num_labels_ent=num_labels_ent, 54 | num_labels_emo=num_labels_emo, 55 
| dp=args.dp) 56 | elif args.net == "4": 57 | model = NetY4.from_pretrained(args.bert_model, 58 | cache_dir="", 59 | num_labels_ent=num_labels_ent, 60 | num_labels_emo=num_labels_emo, 61 | dp=args.dp) 62 | 63 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 64 | model.to(device) 65 | model = torch.nn.DataParallel(model) 66 | 67 | # ------------------------------ load model from file ------------------------- 68 | model_file = os.path.join("../ckps", args.best_model) 69 | if os.path.exists(model_file): 70 | model.load_state_dict(torch.load(model_file)) 71 | print("load checkpoint: {} successfully!".format(model_file)) 72 | # ----------------------------------------------------------------------------- 73 | 74 | test_dataloader, test_size = get_test_dataloader() 75 | 76 | def test(engine, batch): 77 | model.eval() 78 | batch = tuple(t.to(device) for t in batch) 79 | input_ids, input_mask, segment_ids = batch 80 | 81 | with torch.no_grad(): 82 | logits_ent, logits_emo = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask) 83 | return logits_ent, logits_emo 84 | 85 | tester = Engine(test) 86 | 87 | pbar_test = ProgressBar(persist=True) 88 | pbar_test.attach(tester) 89 | 90 | # ++++++++++++++++++++++++++++++++++ Test +++++++++++++++++++++++++++++++++ 91 | f = h5py.File(f"../preds/{args.pred}", "w") 92 | if args.lbl_method == "BIO": 93 | if args.raw: 94 | ent_raw = f.create_dataset("ent_raw", shape=(0, 128, 3), maxshape=(None, 128, 3), compression="gzip") 95 | emo_raw = f.create_dataset("emo_raw", shape=(0, 128, 4), maxshape=(None, 128, 4), compression="gzip") 96 | ent = f.create_dataset("ent", shape=(0, 128), maxshape=(None, 128), compression="gzip") 97 | emo = f.create_dataset("emo", shape=(0, 128), maxshape=(None, 128), compression="gzip") 98 | else: 99 | if args.raw: 100 | ent_raw = f.create_dataset("ent_raw", shape=(0, 256, 5), maxshape=(None, 256, 5), compression="gzip") 101 | emo_raw = f.create_dataset("emo_raw", shape=(0, 256, 4), maxshape=(None, 256, 4), compression="gzip") 102 | ent = f.create_dataset("ent", shape=(0, 256), maxshape=(None, 256), compression="gzip") 103 | emo = f.create_dataset("emo", shape=(0, 256), maxshape=(None, 256), compression="gzip") 104 | 105 | @tester.on(Events.ITERATION_COMPLETED) 106 | def get_test_pred(engine): 107 | # cur_iter = engine.state.iteration 108 | batch_size = engine.state.batch[0].size(0) 109 | pred_ent_raw, pred_emo_raw = engine.state.output 110 | pred_ent = torch.argmax(torch.softmax(pred_ent_raw, dim=-1), dim=-1) # [-1, 128] 111 | pred_emo = torch.argmax(torch.softmax(pred_emo_raw, dim=-1), dim=-1) # [-1, 128] 112 | 113 | def add_io(): 114 | old_size = ent.shape[0] 115 | ent.resize(old_size + batch_size, axis=0) 116 | emo.resize(old_size + batch_size, axis=0) 117 | ent[old_size: old_size + batch_size] = pred_ent.cpu() 118 | emo[old_size: old_size + batch_size] = pred_emo.cpu() 119 | if args.raw: 120 | ent_raw.resize(old_size + batch_size, axis=0) 121 | emo_raw.resize(old_size + batch_size, axis=0) 122 | ent_raw[old_size: old_size + batch_size] = pred_ent_raw.cpu() 123 | emo_raw[old_size: old_size + batch_size] = pred_emo_raw.cpu() 124 | 125 | def add_mem(): 126 | if engine.state.metrics.get("preds_ent") is None: 127 | engine.state.metrics["preds_ent"] = [] 128 | engine.state.metrics["preds_emo"] = [] 129 | engine.state.metrics["preds_ent"].append(pred_ent.cpu()) 130 | engine.state.metrics["preds_emo"].append(pred_emo.cpu()) 131 | if args.raw: 132 | engine.state.metrics["preds_ent_raw"] = [] 133 | 
engine.state.metrics["preds_emo_raw"] = [] 134 | engine.state.metrics["preds_ent_raw"].append(pred_ent_raw.cpu()) 135 | engine.state.metrics["preds_emo_raw"].append(pred_emo_raw.cpu()) 136 | 137 | else: 138 | engine.state.metrics["preds_ent"].append(pred_ent.cpu()) 139 | engine.state.metrics["preds_emo"].append(pred_emo.cpu()) 140 | if args.raw: 141 | engine.state.metrics["preds_ent_raw"].append(pred_ent_raw.cpu()) 142 | engine.state.metrics["preds_emo_raw"].append(pred_emo_raw.cpu()) 143 | 144 | add_mem() 145 | 146 | pbar_test = ProgressBar(persist=True) 147 | pbar_test.attach(tester) 148 | 149 | @tester.on(Events.EPOCH_COMPLETED) 150 | def save_and_close(engine): 151 | if engine.state.metrics.get("preds_ent") is not None: 152 | preds_ent = torch.cat(engine.state.metrics["preds_ent"], dim=0) 153 | preds_emo = torch.cat(engine.state.metrics["preds_emo"], dim=0) 154 | assert preds_ent.size(0) == preds_emo.size(0) 155 | ent.resize(preds_ent.size(0), axis=0) 156 | emo.resize(preds_emo.size(0), axis=0) 157 | ent[...] = preds_ent 158 | emo[...] = preds_emo 159 | if args.raw: 160 | preds_ent_raw = torch.cat(engine.state.metrics["preds_ent_raw"], dim=0) 161 | preds_emo_raw = torch.cat(engine.state.metrics["preds_emo_raw"], dim=0) 162 | ent_raw.resize(preds_ent_raw.size(0), axis=0) 163 | emo_raw.resize(preds_emo_raw.size(0), axis=0) 164 | ent_raw[...] = preds_ent_raw 165 | emo_raw[...] = preds_emo_raw 166 | print("pred size: ", ent.shape) 167 | f.close() 168 | print("test over") 169 | 170 | tester.run(test_dataloader) 171 | 172 | 173 | if __name__ == '__main__': 174 | parser = ArgumentParser() 175 | parser.add_argument("--bert_model", default="../bert_pretrained/bert-base-chinese", 176 | type=str, required=False, 177 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 178 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 179 | "bert-base-multilingual-cased, bert-base-chinese.") 180 | parser.add_argument('--test_batch_size', type=int, default=1000, 181 | help='input batch size for test (default: 1000)') 182 | parser.add_argument("--best_model", 183 | default="best_model.pt", 184 | type=str, required=True) 185 | parser.add_argument("--pred", 186 | default="pred_new.h5", 187 | type=str, required=False) 188 | parser.add_argument("--dp", 189 | default=0.1, 190 | type=float, 191 | help="") 192 | parser.add_argument("--raw", 193 | action="store_true", 194 | help="是否存储置信度") 195 | parser.add_argument("--lbl_method", 196 | type=str, 197 | default="BIO", 198 | help="BIO / BIEO") 199 | parser.add_argument("--net", 200 | type=str, 201 | default="3", 202 | help="model") 203 | args = parser.parse_args() 204 | 205 | run() 206 | -------------------------------------------------------------------------------- /src/kfold/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.chdir("../.") 4 | import h5py 5 | from argparse import ArgumentParser 6 | import torch 7 | from ignite.engine import Engine, Events 8 | from ignite.contrib.handlers.tqdm_logger import ProgressBar 9 | from torch.utils.data import DataLoader, SequentialSampler, TensorDataset 10 | from models import NetY3 11 | from utils import load_data 12 | 13 | 14 | def get_all_data(): 15 | print("get all data...........") 16 | # -------------------------------read from h5------------------------- 17 | if not args.lite: 18 | f = h5py.File("../datasets/full.h5") 19 | else: 20 | f = h5py.File("../datasets/lite.h5") 21 | input_ids_trn = 
torch.from_numpy(f["train/input_ids"][()]) 22 | myinput_ids_trn = torch.from_numpy(f["train/myinput_ids"][()]) 23 | input_mask_trn = torch.from_numpy(f["train/input_mask"][()]) 24 | segment_ids_trn = torch.from_numpy(f["train/segment_ids"][()]) 25 | label_ent_ids_trn = torch.from_numpy(f["train/label_ent_ids"][()]) 26 | label_emo_ids_trn = torch.from_numpy(f["train/label_emo_ids"][()]) 27 | assert input_ids_trn.size() == segment_ids_trn.size() == label_ent_ids_trn.size() == label_emo_ids_trn.size() == myinput_ids_trn.size() 28 | 29 | input_ids_val = torch.from_numpy(f["val/input_ids"][()]) 30 | myinput_ids_val = torch.from_numpy(f["val/myinput_ids"][()]) 31 | input_mask_val = torch.from_numpy(f["val/input_mask"][()]) 32 | segment_ids_val = torch.from_numpy(f["val/segment_ids"][()]) 33 | label_ent_ids_val = torch.from_numpy(f["val/label_ent_ids"][()]) 34 | label_emo_ids_val = torch.from_numpy(f["val/label_emo_ids"][()]) 35 | assert input_ids_val.size() == segment_ids_val.size() == label_ent_ids_val.size() == label_emo_ids_val.size() == myinput_ids_val.size() 36 | f.close() 37 | print("read h5 over!") 38 | input_ids = torch.cat([input_ids_trn, input_ids_val], dim=0) 39 | myinput_ids = torch.cat([myinput_ids_trn, myinput_ids_val], dim=0) 40 | input_mask = torch.cat([input_mask_trn, input_mask_val], dim=0) 41 | segment_ids = torch.cat([segment_ids_trn, segment_ids_val], dim=0) 42 | label_ent_ids = torch.cat([label_ent_ids_trn, label_ent_ids_val], dim=0) 43 | label_emo_ids = torch.cat([label_emo_ids_trn, label_emo_ids_val], dim=0) 44 | dataset = TensorDataset(input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids) 45 | 46 | return dataset 47 | 48 | 49 | def get_val_loader(dataset, cv): 50 | print(f"get dataloader {cv}") 51 | # 从 index 中取出 trn_dataset val_dataset 52 | index_file = "kfold/5cv_indexs_{}".format(cv) 53 | if os.path.exists(index_file): 54 | _, val_index = load_data(index_file) 55 | val_dataset = [dataset[idx] for idx in val_index] 56 | else: 57 | print("Not find index file!") 58 | exit(0) 59 | # --------------------------------------------------------------------- 60 | val_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=args.val_batch_size, 61 | num_workers=args.nw, pin_memory=True) 62 | # val_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=args.val_batch_size, 63 | # pin_memory=True) 64 | print("get date loader over!") 65 | return val_dataloader, len(val_dataset) 66 | 67 | 68 | def get_test_dataloader(): 69 | print("get test dataloader...........") 70 | # -------------------------------read from h5------------------------- 71 | f = h5py.File("../datasets/full.h5", "r") 72 | input_ids = torch.from_numpy(f["test/input_ids"][()]) 73 | input_mask = torch.from_numpy(f["test/input_mask"][()]) 74 | segment_ids = torch.from_numpy(f["test/segment_ids"][()]) 75 | assert input_ids.size() == segment_ids.size() == input_mask.size() 76 | print("test dataset num: ", input_ids.size(0)) 77 | test_dataset = TensorDataset(input_ids, input_mask, segment_ids) 78 | f.close() 79 | print("read h5 over!") 80 | test_dataloader = DataLoader(test_dataset, sampler=SequentialSampler(test_dataset), batch_size=args.test_batch_size, 81 | num_workers=4) 82 | print("get date loader over!") 83 | return test_dataloader, len(test_dataset) 84 | 85 | 86 | def run(test_dataloader, cv): 87 | ################################ Model Config ################################### 88 | if args.lbl_method == "BIO": 89 | 
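        # Added note: the BIO scheme uses 3 entity tags (O/B/I); any other --lbl_method value
        # falls through to the 5-tag BIESO setup in the else branch below.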
num_labels_emo = 4 # O POS NEG NORM 90 | num_labels_ent = 3 # O B I 91 | else: 92 | num_labels_emo = 4 # O POS NEG NORM 93 | num_labels_ent = 5 # O B I E S 94 | model = NetY3.from_pretrained(args.bert_model, 95 | cache_dir="", 96 | num_labels_ent=num_labels_ent, 97 | num_labels_emo=num_labels_emo, 98 | dp=0.2) 99 | 100 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 101 | model.to(device) 102 | model = torch.nn.DataParallel(model) 103 | 104 | # ------------------------------ load model from file ------------------------- 105 | model_file = f"../ckps/cv/cv{cv}.pth" 106 | if os.path.exists(model_file): 107 | model.load_state_dict(torch.load(model_file)) 108 | print("load checkpoint: {} successfully!".format(model_file)) 109 | 110 | # ----------------------------------------------------------------------------- 111 | 112 | def test(engine, batch): 113 | model.eval() 114 | batch = tuple(t.to(device) for t in batch) 115 | input_ids, input_mask, segment_ids = batch 116 | 117 | with torch.no_grad(): 118 | logits_ent, logits_emo = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask) 119 | return logits_ent, logits_emo 120 | 121 | tester = Engine(test) 122 | 123 | pbar_test = ProgressBar(persist=True) 124 | pbar_test.attach(tester) 125 | 126 | # ++++++++++++++++++++++++++++++++++ Test +++++++++++++++++++++++++++++++++ 127 | if cv == 1: 128 | f = h5py.File(f"../preds/{args.pred}", "w") 129 | else: 130 | f = h5py.File(f"../preds/{args.pred}", "r+") 131 | if args.lbl_method == "BIO": 132 | ent_raw = f.create_dataset(f"cv{cv}/ent_raw", shape=(0, 128, 3), maxshape=(None, 128, 3), compression="gzip") 133 | emo_raw = f.create_dataset(f"cv{cv}/emo_raw", shape=(0, 128, 4), maxshape=(None, 128, 4), compression="gzip") 134 | else: 135 | ent_raw = f.create_dataset(f"cv{cv}/ent_raw", shape=(0, 256, 5), maxshape=(None, 256, 5), compression="gzip") 136 | emo_raw = f.create_dataset(f"cv{cv}/emo_raw", shape=(0, 256, 4), maxshape=(None, 256, 4), compression="gzip") 137 | 138 | 139 | # ent = f.create_dataset(f"cv{cv}/ent", shape=(0, 128), maxshape=(None, 128), compression="gzip") 140 | # emo = f.create_dataset(f"cv{cv}/emo", shape=(0, 128), maxshape=(None, 128), compression="gzip") 141 | # if cv == 1: 142 | 143 | @tester.on(Events.ITERATION_COMPLETED) 144 | def get_test_pred(engine): 145 | # cur_iter = engine.state.iteration 146 | batch_size = engine.state.batch[0].size(0) 147 | pred_ent_raw, pred_emo_raw = engine.state.output 148 | 149 | # pred_ent = torch.argmax(torch.softmax(pred_ent_raw, dim=-1), dim=-1) # [-1, 128] 150 | def add_io(): 151 | # pred_emo = torch.argmax(torch.softmax(pred_emo_raw, dim=-1), dim=-1) # [-1, 128] 152 | old_size = ent_raw.shape[0] 153 | # ent.resize(old_size + batch_size, axis=0) 154 | # emo.resize(old_size + batch_size, axis=0) 155 | # ent[old_size: old_size + batch_size] = pred_ent.cpu() 156 | # emo[old_size: old_size + batch_size] = pred_emo.cpu() 157 | ent_raw.resize(old_size + batch_size, axis=0) 158 | emo_raw.resize(old_size + batch_size, axis=0) 159 | ent_raw[old_size: old_size + batch_size] = pred_ent_raw.cpu() 160 | emo_raw[old_size: old_size + batch_size] = pred_emo_raw.cpu() 161 | # if cv == 1: 162 | 163 | def add_mem(): 164 | if engine.state.metrics.get("preds_ent") is None: 165 | engine.state.metrics["preds_ent"] = [] 166 | engine.state.metrics["preds_emo"] = [] 167 | engine.state.metrics["preds_ent"].append(pred_ent_raw.cpu()) 168 | engine.state.metrics["preds_emo"].append(pred_emo_raw.cpu()) 169 | else: 170 | 
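# On later batches the metric lists already exist, so the raw logits are simply
# appended; save_and_close() concatenates them at EPOCH_COMPLETED and writes the
# result into this fold's group (cv{cv}/ent_raw, cv{cv}/emo_raw) in the shared HDF5 file.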
engine.state.metrics["preds_ent"].append(pred_ent_raw.cpu()) 171 | engine.state.metrics["preds_emo"].append(pred_emo_raw.cpu()) 172 | 173 | add_mem() 174 | 175 | @tester.on(Events.EPOCH_COMPLETED) 176 | def save_and_close(engine): 177 | if engine.state.metrics.get("preds_ent") is not None: 178 | preds_ent = torch.cat(engine.state.metrics["preds_ent"], dim=0) 179 | preds_emo = torch.cat(engine.state.metrics["preds_emo"], dim=0) 180 | assert preds_ent.size(0) == preds_emo.size(0) 181 | ent_raw.resize(preds_ent.size(0), axis=0) 182 | emo_raw.resize(preds_emo.size(0), axis=0) 183 | ent_raw[...] = preds_ent 184 | emo_raw[...] = preds_emo 185 | print("pred size: ", ent_raw.shape) 186 | f.close() 187 | print("test over") 188 | 189 | tester.run(test_dataloader) 190 | 191 | 192 | if __name__ == '__main__': 193 | parser = ArgumentParser() 194 | parser.add_argument("--bert_model", default="../bert_pretrained/bert-base-chinese", 195 | type=str, required=False, 196 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 197 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 198 | "bert-base-multilingual-cased, bert-base-chinese.") 199 | parser.add_argument('--test_batch_size', type=int, default=1000, 200 | help='input batch size for test (default: 1000)') 201 | parser.add_argument("--best_model", 202 | default="best_model.pt", 203 | type=str, required=False) 204 | parser.add_argument("--pred", 205 | default="pred_new.h5", 206 | type=str, required=False) 207 | parser.add_argument("--raw", 208 | action="store_true", 209 | help="是否存储置信度") 210 | parser.add_argument("--lbl_method", 211 | type=str, 212 | default="BIO", 213 | help="BIO / BIEO") 214 | args = parser.parse_args() 215 | 216 | # 5 fold 217 | # dataset = get_all_data() 218 | test_dataloader, test_size = get_test_dataloader() 219 | for cv in range(1, 6): 220 | run(test_dataloader, cv) 221 | -------------------------------------------------------------------------------- /src/get_sents.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/5/15 19:25 4 | # @Author : 邵明岩 5 | # @File : get_sents.py 6 | # @Software: PyCharm 7 | 8 | import re 9 | import pickle 10 | from zhon import hanzi 11 | import string 12 | import json 13 | import random 14 | 15 | 16 | def Unicode(): 17 | val = random.randint(0x4e00, 0x9fbf) 18 | return chr(val) 19 | 20 | 21 | def hash_ch(len_text): 22 | return ''.join([Unicode() for _ in range(len_text)]) 23 | 24 | 25 | def data_dump(data, file): 26 | with open(file, "wb") as f: 27 | pickle.dump(data, f) 28 | print("store data successfully!") 29 | 30 | 31 | def get_label(sent, entity): 32 | def get_entity_len(d): 33 | return len(d['entity']) 34 | 35 | ner = ['O' for _ in range(len(sent))] 36 | 37 | entity = sorted(entity, key=get_entity_len, reverse=True) 38 | for d in entity: 39 | emotion = d['emotion'].strip() 40 | e_l = d['entity'] 41 | name = ''.join(e_l) 42 | e_n = len(e_l) 43 | for i in range(len(sent) - e_n + 1): 44 | a = ''.join(sent[i:i + e_n]) 45 | if a == name: 46 | flag = False 47 | for r in range(e_n): 48 | if ner[r + i].startswith('B') or ner[r + i].startswith('I'): 49 | flag = True 50 | break 51 | 52 | if flag is True: 53 | pass 54 | else: 55 | for r in range(e_n): 56 | if r == 0: 57 | ner[r + i] = 'B-{}'.format(emotion) 58 | else: 59 | ner[r + i] = 'I-{}'.format(emotion) 60 | return ner 61 | 62 | 63 | def get_label_no_emotion(sent, entity): 64 | def 
get_entity_len(d): 65 | return len(d['entity']) 66 | 67 | ner = ['O' for _ in range(len(sent))] 68 | 69 | entity = sorted(entity, key=get_entity_len, reverse=True) 70 | for d in entity: 71 | e_l = d['entity'] 72 | name = ''.join(e_l) 73 | e_n = len(e_l) 74 | for i in range(len(sent) - e_n + 1): 75 | a = ''.join(sent[i:i + e_n]) 76 | if a == name: 77 | flag = False 78 | for r in range(e_n): 79 | if ner[r + i].startswith('B') or ner[r + i].startswith('I'): 80 | flag = True 81 | break 82 | 83 | if flag is True: 84 | pass 85 | else: 86 | for r in range(e_n): 87 | if r == 0: 88 | ner[r + i] = 'B' 89 | else: 90 | ner[r + i] = 'I' 91 | 92 | return ner 93 | 94 | 95 | symbol_list = [('【', '】'), ('(', ')'), ('“', '”'), ('「', '」'), ('《', '》')] 96 | 97 | 98 | def get_entity_mask(text): 99 | pattern = r'{}[^{}\n\r]+{}' 100 | pattern2 = r'{}[^{}。。!!??\n\r]+{}' 101 | 102 | mask_texts_set = set() 103 | for symbol in symbol_list: 104 | symbol1 = symbol[0] 105 | symbol2 = symbol[1] 106 | if symbol1 == '“': 107 | res = re.findall(pattern2.format(symbol1, symbol2, symbol2), text) 108 | else: 109 | res = re.findall(pattern.format(symbol1, symbol2, symbol2), text) 110 | for r in res: 111 | mask_texts_set.add(r) 112 | 113 | mask_texts_set = list(mask_texts_set) 114 | mask_texts_set.sort(key=lambda x: len(x), reverse=True) 115 | mask_texts = [] 116 | for t in mask_texts_set: 117 | lent = len(t) 118 | while True: 119 | h = str(hash_ch(lent)) 120 | if h not in text: 121 | break 122 | assert len(t) == len(h) 123 | mask_texts.append((t, h)) 124 | 125 | for t, h in mask_texts: 126 | text = text.replace(t, h) 127 | 128 | return mask_texts, text 129 | 130 | 131 | def get_real_text(mask_texts, text): 132 | for t, h in mask_texts: 133 | text = text.replace(h, t) 134 | return text 135 | 136 | 137 | # 虽然是可变长的batch要的数据,但是过长的句子也是不可行的,这里 138 | # 最长用600,仅占2912458个句子中的938条 139 | 140 | def get_sentences(content): 141 | # mask一写书名号等 142 | mask_texts, content = get_entity_mask(content) 143 | 144 | # 基本分句,。|。|!|\!|?|\?,用这6个符号分 145 | sentences = re.split(r'(。|。|!|\!|?|\?)', content) # 保留分割符 146 | new_sents = [] 147 | for i in range(int(len(sentences) / 2)): 148 | sent = sentences[2 * i] + sentences[2 * i + 1] 149 | sent = sent.strip() 150 | new_sents.append(sent) 151 | res = [] 152 | max_len = 600 153 | for sent in new_sents: 154 | temp_sents = [] 155 | if len(sent) > max_len: # 大于max_len长度继续分,;|、|,|,|﹔|、,6个次级分句 156 | sents = re.split(r'(;|、|,|,|﹔|、)', sent) # 保留分割符 157 | for s in sents: 158 | if len(s) > max_len: # 子分句也大于max_len,采用截断式分句 159 | s = get_real_text(mask_texts, s) # 还原mask 160 | ss = [] 161 | j = max_len 162 | while j < len(s): 163 | ss.append(s[j - max_len:j]) 164 | j = j + max_len 165 | if len(s[j - max_len:len(s)]) < 20: 166 | temp = ss[-1] + s[j - max_len:len(s)] 167 | ss[-1] = temp[0:int(len(temp) / 2)] 168 | ss.append(temp[int(len(temp) / 2):]) 169 | else: 170 | ss.append(s[j - max_len:len(s)]) 171 | temp_sents.extend(ss) 172 | 173 | else: 174 | temp_sents.append(s) 175 | else: 176 | temp_sents.append(sent) 177 | 178 | res.extend(temp_sents) 179 | 180 | result = [] 181 | for r in res: 182 | r = get_real_text(mask_texts, r) 183 | r = r.replace('\n', '') 184 | r = r.replace('\r', '') 185 | result.append(r) 186 | 187 | return result 188 | 189 | 190 | def seg_char(sent): 191 | """ 192 | 把句子按字分开,不破坏英文结构 193 | """ 194 | sent = sent.replace('\n', '') 195 | sent = sent.replace('\r', '') 196 | pattern = re.compile(r'([\u4e00-\u9fa5])') 197 | chars = pattern.split(sent) 198 | chars = [w.strip() for w in chars if 
len(w.strip()) > 0] 199 | new_chars = [] 200 | for c in chars: 201 | if len(c) > 1: 202 | punctuation = hanzi.punctuation 203 | punctuation = '|'.join([p for p in punctuation]) 204 | pattern = re.compile(' |({})'.format(punctuation)) 205 | cs = pattern.split(c) 206 | for w in cs: 207 | if w and len(w.strip()) > 0: 208 | new_chars.append(w.strip()) 209 | else: 210 | new_chars.append(c) 211 | return new_chars 212 | 213 | 214 | def seg_char_sents(sentences): 215 | results = [] 216 | for sent in sentences: 217 | res = seg_char(sent) 218 | if len(res) > 0: 219 | results.append(res) 220 | return results 221 | 222 | 223 | def get_core_entityemotions(entityemotions): 224 | results = [] 225 | for ee in entityemotions: 226 | result = {} 227 | result['entity'] = seg_char(clean_text(ee['entity'])) 228 | result['emotion'] = ee['emotion'] 229 | results.append(result) 230 | 231 | return results 232 | 233 | 234 | def ishan(char): 235 | # for python 3.x 236 | # sample: ishan('一') == True, ishan('我&&你') == False 237 | return '\u4e00' <= char <= '\u9fff' 238 | 239 | 240 | def clean_text(text): 241 | new_text = [] 242 | for char in text: 243 | if ishan(char) or char in string.digits or char in string.ascii_letters or char in ( 244 | hanzi.punctuation + string.punctuation): 245 | new_text.append(char) 246 | elif char == '\t' or char == ' ': 247 | new_text.append(' ') 248 | elif char == '\r' or char == '\n': 249 | new_text.append('\n') 250 | else: 251 | continue 252 | 253 | new_text = ''.join(new_text) 254 | # html转移字符 255 | new_text = re.sub(r'"', '"', new_text) 256 | new_text = re.sub(r'&', '&', new_text) 257 | new_text = re.sub(r'<', '<', new_text) 258 | new_text = re.sub(r'>', '>', new_text) 259 | new_text = re.sub(r' ', ' ', new_text) 260 | new_text = re.sub(r'·', '·', new_text) 261 | # 去除多余空格 262 | new_text = re.sub(r' +', ' ', new_text) 263 | # 去除html链接 264 | new_text = re.sub( 265 | r'(http|ftp)s?://([^\u4e00-\u9fa5"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、\u3000、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·!?。。])*', 266 | '', new_text) 267 | # 去除多余连续符号,比如颜文字表情 268 | new_text = re.sub(r'[#%&\'()*+-./:;<=>?@[\]^_`{|}~]{2,}', '', new_text) 269 | 270 | return new_text 271 | 272 | 273 | if __name__ == '__main__': 274 | 275 | f = open('../data/coreEntityEmotion_example.txt', 'r') 276 | datas = [] 277 | datas_em = [] 278 | all_index = len(f.readlines()) 279 | f.seek(0) 280 | for index, line in enumerate(f.readlines()): 281 | data = json.loads(line) 282 | print('{}/{}'.format(index, all_index)) 283 | new_data = {} 284 | new_data_em = {} 285 | new_data['newsId'] = data['newsId'] 286 | new_data_em['newsId'] = data['newsId'] 287 | 288 | new_data['coreEntityEmotions'] = get_core_entityemotions(data['coreEntityEmotions']) 289 | new_data_em['coreEntityEmotions'] = new_data['coreEntityEmotions'] 290 | 291 | title = clean_text(data['title'].strip()) 292 | 293 | if len(title) > 125: 294 | title = title[:125] 295 | print('warning:标题被截断!!') 296 | title = seg_char(title) 297 | title_labels = get_label_no_emotion(title, new_data['coreEntityEmotions']) 298 | assert len(title) == len(title_labels) 299 | new_data['title'] = (title, title_labels) 300 | title_labels = get_label(title, new_data['coreEntityEmotions']) 301 | assert len(title) == len(title_labels) 302 | new_data_em['title'] = (title, title_labels) 303 | data['content'] = clean_text(data['content'].strip()) 304 | if len(data['content'].strip()) == 0: 305 | new_data['content'] = [] 306 | new_data_em['content'] = [] 307 | pass 308 | else: 309 | if data['content'][-1] not in 
'。。!!??': 310 | data['content'] = data['content'] + '。' 311 | sentences = get_sentences(data['content']) 312 | sentences = seg_char_sents(sentences) 313 | content = [] 314 | content_em = [] 315 | for sent in sentences: 316 | sent_labels = get_label_no_emotion(sent, new_data['coreEntityEmotions']) 317 | assert len(sent) == len(sent_labels) 318 | content.append((sent, sent_labels)) 319 | sent_labels = get_label(sent, new_data['coreEntityEmotions']) 320 | assert len(sent) == len(sent_labels) 321 | content_em.append((sent, sent_labels)) 322 | new_data['content'] = content 323 | new_data_em['content'] = content_em 324 | 325 | datas.append(new_data) 326 | datas_em.append(new_data_em) 327 | 328 | f.close() 329 | data_dump(datas, '../datasets/variable_data/example_ner_no_emotion.pkl') 330 | data_dump(datas_em, '../datasets/variable_data/example_ner_has_emotion.pkl') 331 | 332 | f = open('../data/coreEntityEmotion_train.txt', 'r') 333 | datas = [] 334 | datas_em = [] 335 | all_index = len(f.readlines()) 336 | f.seek(0) 337 | for index, line in enumerate(f.readlines()): 338 | data = json.loads(line) 339 | print('{}/{}'.format(index, all_index)) 340 | new_data = {} 341 | new_data_em = {} 342 | new_data['newsId'] = data['newsId'] 343 | new_data_em['newsId'] = data['newsId'] 344 | 345 | new_data['coreEntityEmotions'] = get_core_entityemotions(data['coreEntityEmotions']) 346 | new_data_em['coreEntityEmotions'] = new_data['coreEntityEmotions'] 347 | 348 | title = clean_text(data['title'].strip()) 349 | 350 | if len(title) > 125: 351 | title = title[:125] 352 | print('warning:标题被截断!!') 353 | title = seg_char(title) 354 | title_labels = get_label_no_emotion(title, new_data['coreEntityEmotions']) 355 | assert len(title) == len(title_labels) 356 | new_data['title'] = (title, title_labels) 357 | title_labels = get_label(title, new_data['coreEntityEmotions']) 358 | assert len(title) == len(title_labels) 359 | new_data_em['title'] = (title, title_labels) 360 | data['content'] = clean_text(data['content'].strip()) 361 | if len(data['content'].strip()) == 0: 362 | new_data['content'] = [] 363 | new_data_em['content'] = [] 364 | pass 365 | else: 366 | if data['content'][-1] not in '。。!!??': 367 | data['content'] = data['content'] + '。' 368 | sentences = get_sentences(data['content']) 369 | sentences = seg_char_sents(sentences) 370 | content = [] 371 | content_em = [] 372 | for sent in sentences: 373 | sent_labels = get_label_no_emotion(sent, new_data['coreEntityEmotions']) 374 | assert len(sent) == len(sent_labels) 375 | content.append((sent, sent_labels)) 376 | sent_labels = get_label(sent, new_data['coreEntityEmotions']) 377 | assert len(sent) == len(sent_labels) 378 | content_em.append((sent, sent_labels)) 379 | new_data['content'] = content 380 | new_data_em['content'] = content_em 381 | 382 | datas.append(new_data) 383 | datas_em.append(new_data_em) 384 | 385 | f.close() 386 | data_dump(datas, '../datasets/variable_data/train_ner_no_emotion.pkl') 387 | data_dump(datas_em, '../datasets/variable_data/train_ner_has_emotion.pkl') 388 | 389 | f = open('../data/coreEntityEmotion_test_stage2.txt', 'r') 390 | datas = [] 391 | all_index = len(f.readlines()) 392 | f.seek(0) 393 | for index, line in enumerate(f.readlines()): 394 | data = json.loads(line) 395 | print('{}/{}'.format(index, all_index)) 396 | new_data = {} 397 | new_data['newsId'] = data['newsId'] 398 | 399 | data['title'] = clean_text(data['title'].strip()) 400 | if len(data['title']) > 125: 401 | data['title'] = data['title'][:125] 402 | 
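# Titles are capped at 125 characters, presumably so that the title plus the
# [CLS]/[SEP] special tokens still fits a 128-token input.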
print('warning:标题被截断!!') 403 | title = seg_char(data['title']) 404 | new_data['title'] = title 405 | data['content'] = clean_text(data['content'].strip()) 406 | if len(data['content'].strip()) == 0: 407 | new_data['content'] = [] 408 | else: 409 | if data['content'][-1] not in '。。!!??': 410 | data['content'] = data['content'] + '。' 411 | sentences = get_sentences(data['content']) 412 | sentences = seg_char_sents(sentences) 413 | content = [] 414 | for sent in sentences: 415 | content.append(sent) 416 | new_data['content'] = content 417 | 418 | datas.append(new_data) 419 | 420 | f.close() 421 | data_dump(datas, '../datasets/variable_data/test_ner2.pkl') 422 | -------------------------------------------------------------------------------- /src/get_sents_fix.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/5/15 19:20 4 | # @Author : 邵明岩 5 | # @File : get_sents_fix.py 6 | # @Software: PyCharm 7 | 8 | import re 9 | import pickle 10 | from zhon import hanzi 11 | import string 12 | import json 13 | import random 14 | 15 | cut_num = 0 16 | 17 | 18 | def Unicode(): 19 | val = random.randint(0x4e00, 0x9fbf) 20 | return chr(val) 21 | 22 | 23 | def hash_ch(len_text): 24 | return ''.join([Unicode() for _ in range(len_text)]) 25 | 26 | 27 | def data_dump(data, file): 28 | with open(file, "wb") as f: 29 | pickle.dump(data, f) 30 | print("store data successfully!") 31 | 32 | 33 | def get_label(sent, entity): 34 | def get_entity_len(d): 35 | return len(d['entity']) 36 | 37 | ner = ['O' for _ in range(len(sent))] 38 | 39 | entity = sorted(entity, key=get_entity_len, reverse=True) 40 | for d in entity: 41 | emotion = d['emotion'].strip() 42 | e_l = d['entity'] 43 | name = ''.join(e_l) 44 | e_n = len(e_l) 45 | for i in range(len(sent) - e_n + 1): 46 | a = ''.join(sent[i:i + e_n]) 47 | if a == name: 48 | flag = False 49 | for r in range(e_n): 50 | if ner[r + i].startswith('B') or ner[r + i].startswith('I'): 51 | flag = True 52 | break 53 | 54 | if flag is True: 55 | pass 56 | else: 57 | for r in range(e_n): 58 | if r == 0: 59 | ner[r + i] = 'B-{}'.format(emotion) 60 | else: 61 | ner[r + i] = 'I-{}'.format(emotion) 62 | return ner 63 | 64 | 65 | def get_label_no_emotion(sent, entity): 66 | def get_entity_len(d): 67 | return len(d['entity']) 68 | 69 | ner = ['O' for _ in range(len(sent))] 70 | 71 | entity = sorted(entity, key=get_entity_len, reverse=True) 72 | for d in entity: 73 | e_l = d['entity'] 74 | name = ''.join(e_l) 75 | e_n = len(e_l) 76 | for i in range(len(sent) - e_n + 1): 77 | a = ''.join(sent[i:i + e_n]) 78 | if a == name: 79 | flag = False 80 | for r in range(e_n): 81 | if ner[r + i].startswith('B') or ner[r + i].startswith('I'): 82 | flag = True 83 | break 84 | 85 | if flag is True: 86 | pass 87 | else: 88 | for r in range(e_n): 89 | if r == 0: 90 | ner[r + i] = 'B' 91 | else: 92 | ner[r + i] = 'I' 93 | 94 | return ner 95 | 96 | 97 | symbol_list = [('【', '】'), ('(', ')'), ('“', '”'), ('「', '」'), ('《', '》')] 98 | 99 | 100 | def get_entity_mask(text): 101 | pattern = r'{}[^{}\n\r]+{}' 102 | pattern2 = r'{}[^{}。。!!??\n\r]+{}' 103 | 104 | mask_texts_set = set() 105 | for symbol in symbol_list: 106 | symbol1 = symbol[0] 107 | symbol2 = symbol[1] 108 | if symbol1 == '“': 109 | res = re.findall(pattern2.format(symbol1, symbol2, symbol2), text) 110 | else: 111 | res = re.findall(pattern.format(symbol1, symbol2, symbol2), text) 112 | for r in res: 113 | mask_texts_set.add(r) 114 | 115 | 
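# Each matched span (book titles, brackets, quoted text) is swapped below for an
# equally long string of random CJK characters that does not occur in the text,
# so the sentence splitter cannot cut inside it; get_real_text() restores the
# original spans afterwards.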
mask_texts_set = list(mask_texts_set) 116 | mask_texts_set.sort(key=lambda x: len(x), reverse=True) 117 | mask_texts = [] 118 | for t in mask_texts_set: 119 | lent = len(t) 120 | while True: 121 | h = str(hash_ch(lent)) 122 | if h not in text: 123 | break 124 | assert len(t) == len(h) 125 | mask_texts.append((t, h)) 126 | 127 | for t, h in mask_texts: 128 | text = text.replace(t, h) 129 | 130 | return mask_texts, text 131 | 132 | 133 | def get_real_text(mask_texts, text): 134 | for t, h in mask_texts: 135 | text = text.replace(h, t) 136 | return text 137 | 138 | 139 | def get_sentences(content): 140 | global cut_num 141 | # mask一些书名号等 142 | mask_texts, content = get_entity_mask(content) 143 | 144 | # 基本分句,。|。|!|\!|?|\?,用这6个符号分 145 | sentences = re.split(r'(。|。|!|\!|?|\?)', content) # 保留分割符 146 | new_sents = [] 147 | for i in range(int(len(sentences) / 2)): 148 | sent = sentences[2 * i] + sentences[2 * i + 1] 149 | sent = sent.strip() 150 | new_sents.append(sent) 151 | res = [] 152 | sentence = '' 153 | max_len = 100 154 | for sent in new_sents: 155 | temp_sents = [] 156 | if len(sent) > max_len: # 大于max_len长度继续分,;|、|,|,|﹔|、,6个次级分句 157 | sents = re.split(r'(;|、|,|,|﹔|、)', sent) # 保留分割符 158 | for s in sents: 159 | if len(s) > max_len: # 子分句也大于max_len,采用截断式分句 160 | cut_num = cut_num + 1 161 | s = get_real_text(mask_texts, s) # 还原mask 162 | ss = [] 163 | j = max_len 164 | while j < len(s): 165 | ss.append(s[j - max_len:j]) 166 | j = j + max_len 167 | if len(s[j - max_len:len(s)]) < 20: 168 | temp = ss[-1] + s[j - max_len:len(s)] 169 | ss[-1] = temp[0:int(len(temp) / 2)] 170 | ss.append(temp[int(len(temp) / 2):]) 171 | else: 172 | ss.append(s[j - max_len:len(s)]) 173 | temp_sents.extend(ss) 174 | 175 | else: 176 | temp_sents.append(s) 177 | else: 178 | temp_sents.append(sent) 179 | 180 | # temp_sents获得了所有子句,将子句尽可能组成max_len长度的长句,减少训练时间 181 | for temp in temp_sents: 182 | if len(sentence + temp) <= max_len: 183 | sentence = sentence + temp 184 | else: 185 | res.append(sentence) 186 | sentence = temp 187 | 188 | if sentence != '': 189 | res.append(sentence) 190 | 191 | result = [] 192 | for r in res: 193 | r = get_real_text(mask_texts, r) 194 | r = r.replace('\n', '') 195 | r = r.replace('\r', '') 196 | result.append(r) 197 | 198 | return result 199 | 200 | 201 | def seg_char(sent): 202 | """ 203 | 把句子按字分开,不破坏英文结构 204 | """ 205 | sent = sent.replace('\n', '') 206 | sent = sent.replace('\r', '') 207 | pattern = re.compile(r'([\u4e00-\u9fa5])') 208 | chars = pattern.split(sent) 209 | chars = [w.strip() for w in chars if len(w.strip()) > 0] 210 | new_chars = [] 211 | for c in chars: 212 | if len(c) > 1: 213 | punctuation = hanzi.punctuation 214 | punctuation = '|'.join([p for p in punctuation]) 215 | pattern = re.compile(' |({})'.format(punctuation)) 216 | cs = pattern.split(c) 217 | for w in cs: 218 | if w and len(w.strip()) > 0: 219 | new_chars.append(w.strip()) 220 | else: 221 | new_chars.append(c) 222 | return new_chars 223 | 224 | 225 | def seg_char_sents(sentences): 226 | results = [] 227 | for sent in sentences: 228 | res = seg_char(sent) 229 | if len(res) > 0: 230 | results.append(res) 231 | return results 232 | 233 | 234 | def get_core_entityemotions(entityemotions): 235 | results = [] 236 | for ee in entityemotions: 237 | result = {} 238 | result['entity'] = seg_char(clean_text(ee['entity'])) 239 | result['emotion'] = ee['emotion'] 240 | results.append(result) 241 | 242 | return results 243 | 244 | 245 | def ishan(char): 246 | # for python 3.x 247 | # sample: ishan('一') == True, 
ishan('我&&你') == False 248 | return '\u4e00' <= char <= '\u9fff' 249 | 250 | 251 | def clean_text(text): 252 | new_text = [] 253 | for char in text: 254 | if ishan(char) or char in string.digits or char in string.ascii_letters or char in ( 255 | hanzi.punctuation + string.punctuation): 256 | new_text.append(char) 257 | elif char == '\t' or char == ' ': 258 | new_text.append(' ') 259 | elif char == '\r' or char == '\n': 260 | new_text.append('\n') 261 | else: 262 | continue 263 | 264 | new_text = ''.join(new_text) 265 | # html转移字符 266 | new_text = re.sub(r'"', '"', new_text) 267 | new_text = re.sub(r'&', '&', new_text) 268 | new_text = re.sub(r'<', '<', new_text) 269 | new_text = re.sub(r'>', '>', new_text) 270 | new_text = re.sub(r' ', ' ', new_text) 271 | new_text = re.sub(r'·', '·', new_text) 272 | # 去除多余空格 273 | new_text = re.sub(r' +', ' ', new_text) 274 | # 去除html链接 275 | new_text = re.sub( 276 | r'(http|ftp)s?://([^\u4e00-\u9fa5"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、\u3000、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·!?。。])*', 277 | '', new_text) 278 | # 去除多余连续符号,比如颜文字表情 279 | new_text = re.sub(r'[#%&\'()*+-./:;<=>?@[\]^_`{|}~]{2,}', '', new_text) 280 | 281 | return new_text 282 | 283 | 284 | if __name__ == '__main__': 285 | f = open('../data/coreEntityEmotion_example.txt', 'r') 286 | datas = [] 287 | datas_em = [] 288 | all_index = len(f.readlines()) 289 | f.seek(0) 290 | for index, line in enumerate(f.readlines()): 291 | data = json.loads(line) 292 | print('{}/{}'.format(index, all_index)) 293 | new_data = {} 294 | new_data_em = {} 295 | new_data['newsId'] = data['newsId'] 296 | new_data_em['newsId'] = data['newsId'] 297 | 298 | new_data['coreEntityEmotions'] = get_core_entityemotions(data['coreEntityEmotions']) 299 | new_data_em['coreEntityEmotions'] = new_data['coreEntityEmotions'] 300 | 301 | title = clean_text(data['title'].strip()) 302 | 303 | if len(title) > 125: 304 | title = title[:125] 305 | print('warning:标题被截断!!') 306 | title = seg_char(title) 307 | title_labels = get_label_no_emotion(title, new_data['coreEntityEmotions']) 308 | assert len(title) == len(title_labels) 309 | new_data['title'] = (title, title_labels) 310 | title_labels = get_label(title, new_data['coreEntityEmotions']) 311 | assert len(title) == len(title_labels) 312 | new_data_em['title'] = (title, title_labels) 313 | data['content'] = clean_text(data['content'].strip()) 314 | if len(data['content'].strip()) == 0: 315 | new_data['content'] = [] 316 | new_data_em['content'] = [] 317 | pass 318 | else: 319 | if data['content'][-1] not in '。。!!??': 320 | data['content'] = data['content'] + '。' 321 | sentences = get_sentences(data['content']) 322 | sentences = seg_char_sents(sentences) 323 | content = [] 324 | content_em = [] 325 | for sent in sentences: 326 | sent_labels = get_label_no_emotion(sent, new_data['coreEntityEmotions']) 327 | assert len(sent) == len(sent_labels) 328 | content.append((sent, sent_labels)) 329 | sent_labels = get_label(sent, new_data['coreEntityEmotions']) 330 | assert len(sent) == len(sent_labels) 331 | content_em.append((sent, sent_labels)) 332 | new_data['content'] = content 333 | new_data_em['content'] = content_em 334 | 335 | datas.append(new_data) 336 | datas_em.append(new_data_em) 337 | 338 | f.close() 339 | data_dump(datas, '../datasets/example_ner_no_emotion.pkl') 340 | data_dump(datas_em, '../datasets/example_ner_has_emotion.pkl') 341 | 342 | f = open('../data/coreEntityEmotion_train.txt', 'r') 343 | datas = [] 344 | datas_em = [] 345 | all_index = len(f.readlines()) 346 | f.seek(0) 
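# readlines() above consumed the whole file just to count its lines, so seek(0)
# rewinds it before iterating again; each JSON line is then cleaned, split into
# character-level sentences, and tagged twice (with and without emotion labels).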
347 | for index, line in enumerate(f.readlines()): 348 | data = json.loads(line) 349 | print('{}/{}'.format(index, all_index)) 350 | new_data = {} 351 | new_data_em = {} 352 | new_data['newsId'] = data['newsId'] 353 | new_data_em['newsId'] = data['newsId'] 354 | 355 | new_data['coreEntityEmotions'] = get_core_entityemotions(data['coreEntityEmotions']) 356 | new_data_em['coreEntityEmotions'] = new_data['coreEntityEmotions'] 357 | 358 | title = clean_text(data['title'].strip()) 359 | 360 | if len(title) > 125: 361 | title = title[:125] 362 | print('warning:标题被截断!!') 363 | title = seg_char(title) 364 | title_labels = get_label_no_emotion(title, new_data['coreEntityEmotions']) 365 | assert len(title) == len(title_labels) 366 | new_data['title'] = (title, title_labels) 367 | title_labels = get_label(title, new_data['coreEntityEmotions']) 368 | assert len(title) == len(title_labels) 369 | new_data_em['title'] = (title, title_labels) 370 | data['content'] = clean_text(data['content'].strip()) 371 | if len(data['content'].strip()) == 0: 372 | new_data['content'] = [] 373 | new_data_em['content'] = [] 374 | pass 375 | else: 376 | if data['content'][-1] not in '。。!!??': 377 | data['content'] = data['content'] + '。' 378 | sentences = get_sentences(data['content']) 379 | sentences = seg_char_sents(sentences) 380 | content = [] 381 | content_em = [] 382 | for sent in sentences: 383 | sent_labels = get_label_no_emotion(sent, new_data['coreEntityEmotions']) 384 | assert len(sent) == len(sent_labels) 385 | content.append((sent, sent_labels)) 386 | sent_labels = get_label(sent, new_data['coreEntityEmotions']) 387 | assert len(sent) == len(sent_labels) 388 | content_em.append((sent, sent_labels)) 389 | new_data['content'] = content 390 | new_data_em['content'] = content_em 391 | 392 | datas.append(new_data) 393 | datas_em.append(new_data_em) 394 | 395 | f.close() 396 | data_dump(datas, '../datasets/train_ner_no_emotion.pkl') 397 | data_dump(datas_em, '../datasets/train_ner_has_emotion.pkl') 398 | 399 | f = open('../data/coreEntityEmotion_test_stage2.txt', 'r') 400 | datas = [] 401 | all_index = len(f.readlines()) 402 | f.seek(0) 403 | for index, line in enumerate(f.readlines()): 404 | data = json.loads(line) 405 | print('{}/{}'.format(index, all_index)) 406 | new_data = {} 407 | new_data['newsId'] = data['newsId'] 408 | 409 | data['title'] = clean_text(data['title'].strip()) 410 | if len(data['title']) > 125: 411 | data['title'] = data['title'][:125] 412 | print('warning:标题被截断!!') 413 | title = seg_char(data['title']) 414 | new_data['title'] = title 415 | data['content'] = clean_text(data['content'].strip()) 416 | if len(data['content'].strip()) == 0: 417 | new_data['content'] = [] 418 | else: 419 | if data['content'][-1] not in '。。!!??': 420 | data['content'] = data['content'] + '。' 421 | sentences = get_sentences(data['content']) 422 | sentences = seg_char_sents(sentences) 423 | content = [] 424 | for sent in sentences: 425 | content.append(sent) 426 | new_data['content'] = content 427 | 428 | datas.append(new_data) 429 | 430 | f.close() 431 | data_dump(datas, '../datasets/test_ner2.pkl') 432 | 433 | print('截断总数{}'.format(cut_num)) 434 | -------------------------------------------------------------------------------- /src/get_sents_fix_more.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/5/15 19:20 4 | # @Author : 邵明岩 5 | # @File : get_sents_fix.py 6 | # @Software: PyCharm 7 | 8 | import re 9 | 
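# Same pipeline as get_sents_fix.py, but entities are tagged with a B/I/E
# boundary scheme, the per-chunk budget becomes 250 minus the title length
# (instead of a fixed 100 characters), and the pickles are written to
# ../datasets/256/.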
import pickle 10 | from zhon import hanzi 11 | import string 12 | import json 13 | import random 14 | 15 | cut_num = 0 16 | 17 | 18 | def Unicode(): 19 | val = random.randint(0x4e00, 0x9fbf) 20 | return chr(val) 21 | 22 | 23 | def hash_ch(len_text): 24 | return ''.join([Unicode() for _ in range(len_text)]) 25 | 26 | 27 | def data_dump(data, file): 28 | with open(file, "wb") as f: 29 | pickle.dump(data, f) 30 | print("store data successfully!") 31 | 32 | 33 | def get_label(sent, entity): 34 | def get_entity_len(d): 35 | return len(d['entity']) 36 | 37 | ner = ['O' for _ in range(len(sent))] 38 | 39 | entity = sorted(entity, key=get_entity_len, reverse=True) 40 | for d in entity: 41 | emotion = d['emotion'].strip() 42 | e_l = d['entity'] 43 | name = ''.join(e_l) 44 | e_n = len(e_l) 45 | for i in range(len(sent) - e_n + 1): 46 | a = ''.join(sent[i:i + e_n]) 47 | if a == name: 48 | flag = False 49 | for r in range(e_n): 50 | if ner[r + i].startswith('B') or ner[r + i].startswith('I') or ner[r + i].startswith('E'): 51 | flag = True 52 | break 53 | 54 | if flag is True: 55 | pass 56 | else: 57 | for r in range(e_n): 58 | if r == 0: 59 | ner[r + i] = 'B-{}'.format(emotion) 60 | elif r == e_n - 1: 61 | ner[r + i] = 'E-{}'.format(emotion) 62 | else: 63 | ner[r + i] = 'I-{}'.format(emotion) 64 | return ner 65 | 66 | 67 | def get_label_no_emotion(sent, entity): 68 | def get_entity_len(d): 69 | return len(d['entity']) 70 | 71 | ner = ['O' for _ in range(len(sent))] 72 | 73 | entity = sorted(entity, key=get_entity_len, reverse=True) 74 | for d in entity: 75 | e_l = d['entity'] 76 | name = ''.join(e_l) 77 | e_n = len(e_l) 78 | for i in range(len(sent) - e_n + 1): 79 | a = ''.join(sent[i:i + e_n]) 80 | if a == name: 81 | flag = False 82 | for r in range(e_n): 83 | if ner[r + i].startswith('B') or ner[r + i].startswith('I') or ner[r + i].startswith('E'): 84 | flag = True 85 | break 86 | 87 | if flag is True: 88 | pass 89 | else: 90 | for r in range(e_n): 91 | if r == 0: 92 | ner[r + i] = 'B' 93 | elif r == e_n - 1: 94 | ner[r + i] = 'E' 95 | else: 96 | ner[r + i] = 'I' 97 | 98 | return ner 99 | 100 | 101 | symbol_list = [('【', '】'), ('(', ')'), ('“', '”'), ('「', '」'), ('《', '》')] 102 | 103 | 104 | def get_entity_mask(text): 105 | pattern = r'{}[^{}\n\r]+{}' 106 | pattern2 = r'{}[^{}。。!!??\n\r]+{}' 107 | 108 | mask_texts_set = set() 109 | for symbol in symbol_list: 110 | symbol1 = symbol[0] 111 | symbol2 = symbol[1] 112 | if symbol1 == '“': 113 | res = re.findall(pattern2.format(symbol1, symbol2, symbol2), text) 114 | else: 115 | res = re.findall(pattern.format(symbol1, symbol2, symbol2), text) 116 | for r in res: 117 | mask_texts_set.add(r) 118 | 119 | mask_texts_set = list(mask_texts_set) 120 | mask_texts_set.sort(key=lambda x: len(x), reverse=True) 121 | mask_texts = [] 122 | for t in mask_texts_set: 123 | lent = len(t) 124 | while True: 125 | h = str(hash_ch(lent)) 126 | if h not in text: 127 | break 128 | assert len(t) == len(h) 129 | mask_texts.append((t, h)) 130 | 131 | for t, h in mask_texts: 132 | text = text.replace(t, h) 133 | 134 | return mask_texts, text 135 | 136 | 137 | def get_real_text(mask_texts, text): 138 | for t, h in mask_texts: 139 | text = text.replace(h, t) 140 | return text 141 | 142 | 143 | def get_sentences(content, title_len): 144 | global cut_num 145 | # mask一些书名号等 146 | mask_texts, content = get_entity_mask(content) 147 | 148 | # 基本分句,。|。|!|\!|?|\?,用这6个符号分 149 | sentences = re.split(r'(。|。|!|\!|?|\?)', content) # 保留分割符 150 | new_sents = [] 151 | for i in 
range(int(len(sentences) / 2)): 152 | sent = sentences[2 * i] + sentences[2 * i + 1] 153 | sent = sent.strip() 154 | new_sents.append(sent) 155 | res = [] 156 | sentence = '' 157 | max_len = 250 - title_len 158 | for sent in new_sents: 159 | temp_sents = [] 160 | if len(sent) > max_len: # 大于max_len长度继续分,;|、|,|,|﹔|、,6个次级分句 161 | sents = re.split(r'(;|、|,|,|﹔|、)', sent) # 保留分割符 162 | for s in sents: 163 | if len(s) > max_len: # 子分句也大于max_len,采用截断式分句 164 | cut_num = cut_num + 1 165 | s = get_real_text(mask_texts, s) # 还原mask 166 | ss = [] 167 | j = max_len 168 | while j < len(s): 169 | ss.append(s[j - max_len:j]) 170 | j = j + max_len 171 | if len(s[j - max_len:len(s)]) < 20: 172 | temp = ss[-1] + s[j - max_len:len(s)] 173 | ss[-1] = temp[0:int(len(temp) / 2)] 174 | ss.append(temp[int(len(temp) / 2):]) 175 | else: 176 | ss.append(s[j - max_len:len(s)]) 177 | temp_sents.extend(ss) 178 | 179 | else: 180 | temp_sents.append(s) 181 | else: 182 | temp_sents.append(sent) 183 | 184 | # temp_sents获得了所有子句,将子句尽可能组成max_len长度的长句,减少训练时间 185 | for temp in temp_sents: 186 | if len(sentence + temp) <= max_len: 187 | sentence = sentence + temp 188 | else: 189 | res.append(sentence) 190 | sentence = temp 191 | 192 | if sentence != '': 193 | res.append(sentence) 194 | 195 | result = [] 196 | for r in res: 197 | r = get_real_text(mask_texts, r) 198 | r = r.replace('\n', '') 199 | r = r.replace('\r', '') 200 | result.append(r) 201 | 202 | return result 203 | 204 | 205 | def seg_char(sent): 206 | """ 207 | 把句子按字分开,不破坏英文结构 208 | """ 209 | sent = sent.replace('\n', '') 210 | sent = sent.replace('\r', '') 211 | pattern = re.compile(r'([\u4e00-\u9fa5])') 212 | chars = pattern.split(sent) 213 | chars = [w.strip() for w in chars if len(w.strip()) > 0] 214 | new_chars = [] 215 | for c in chars: 216 | if len(c) > 1: 217 | punctuation = hanzi.punctuation 218 | punctuation = '|'.join([p for p in punctuation]) 219 | pattern = re.compile(' |({})'.format(punctuation)) 220 | cs = pattern.split(c) 221 | for w in cs: 222 | if w and len(w.strip()) > 0: 223 | new_chars.append(w.strip()) 224 | else: 225 | new_chars.append(c) 226 | return new_chars 227 | 228 | 229 | def seg_char_sents(sentences): 230 | results = [] 231 | for sent in sentences: 232 | res = seg_char(sent) 233 | if len(res) > 0: 234 | results.append(res) 235 | return results 236 | 237 | 238 | def get_core_entityemotions(entityemotions): 239 | results = [] 240 | for ee in entityemotions: 241 | result = {} 242 | result['entity'] = seg_char(clean_text(ee['entity'])) 243 | result['emotion'] = ee['emotion'] 244 | results.append(result) 245 | 246 | return results 247 | 248 | 249 | def ishan(char): 250 | # for python 3.x 251 | # sample: ishan('一') == True, ishan('我&&你') == False 252 | return '\u4e00' <= char <= '\u9fff' 253 | 254 | 255 | def clean_text(text): 256 | new_text = [] 257 | for char in text: 258 | if ishan(char) or char in string.digits or char in string.ascii_letters or char in ( 259 | hanzi.punctuation + string.punctuation): 260 | new_text.append(char) 261 | elif char == '\t' or char == ' ': 262 | new_text.append(' ') 263 | elif char == '\r' or char == '\n': 264 | new_text.append('\n') 265 | else: 266 | continue 267 | 268 | new_text = ''.join(new_text) 269 | # html转移字符 270 | new_text = re.sub(r'"', '"', new_text) 271 | new_text = re.sub(r'&', '&', new_text) 272 | new_text = re.sub(r'<', '<', new_text) 273 | new_text = re.sub(r'>', '>', new_text) 274 | new_text = re.sub(r' ', ' ', new_text) 275 | new_text = re.sub(r'·', '·', new_text) 276 | # 去除多余空格 277 | new_text 
= re.sub(r' +', ' ', new_text) 278 | # 去除html链接 279 | new_text = re.sub( 280 | r'(http|ftp)s?://([^\u4e00-\u9fa5"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、\u3000、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·!?。。])*', 281 | '', new_text) 282 | # 去除多余连续符号,比如颜文字表情 283 | new_text = re.sub(r'[#%&\'()*+-./:;<=>?@[\]^_`{|}~]{2,}', '', new_text) 284 | 285 | return new_text 286 | 287 | 288 | if __name__ == '__main__': 289 | f = open('../data/coreEntityEmotion_example.txt', 'r') 290 | datas = [] 291 | datas_em = [] 292 | all_index = len(f.readlines()) 293 | f.seek(0) 294 | for index, line in enumerate(f.readlines()): 295 | data = json.loads(line) 296 | if index % 100 == 0: 297 | print('{}/{}'.format(index, all_index)) 298 | new_data = {} 299 | new_data_em = {} 300 | new_data['newsId'] = data['newsId'] 301 | new_data_em['newsId'] = data['newsId'] 302 | 303 | new_data['coreEntityEmotions'] = get_core_entityemotions(data['coreEntityEmotions']) 304 | new_data_em['coreEntityEmotions'] = new_data['coreEntityEmotions'] 305 | 306 | title = clean_text(data['title'].strip()) 307 | 308 | if len(title) > 125: 309 | title = title[:125] 310 | print('warning:标题被截断!!') 311 | title = seg_char(title) 312 | title_labels = get_label_no_emotion(title, new_data['coreEntityEmotions']) 313 | assert len(title) == len(title_labels) 314 | new_data['title'] = (title, title_labels) 315 | title_labels = get_label(title, new_data['coreEntityEmotions']) 316 | assert len(title) == len(title_labels) 317 | new_data_em['title'] = (title, title_labels) 318 | data['content'] = clean_text(data['content'].strip()) 319 | if len(data['content'].strip()) == 0: 320 | new_data['content'] = [] 321 | new_data_em['content'] = [] 322 | pass 323 | else: 324 | if data['content'][-1] not in '。。!!??': 325 | data['content'] = data['content'] + '。' 326 | sentences = get_sentences(data['content'], len(''.join(title))) 327 | sentences = seg_char_sents(sentences) 328 | content = [] 329 | content_em = [] 330 | for sent in sentences: 331 | sent_labels = get_label_no_emotion(sent, new_data['coreEntityEmotions']) 332 | assert len(sent) == len(sent_labels) 333 | content.append((sent, sent_labels)) 334 | sent_labels = get_label(sent, new_data['coreEntityEmotions']) 335 | assert len(sent) == len(sent_labels) 336 | content_em.append((sent, sent_labels)) 337 | new_data['content'] = content 338 | new_data_em['content'] = content_em 339 | 340 | datas.append(new_data) 341 | datas_em.append(new_data_em) 342 | 343 | f.close() 344 | data_dump(datas, '../datasets/256/example_ner_no_emotion.pkl') 345 | data_dump(datas_em, '../datasets/256/example_ner_has_emotion.pkl') 346 | 347 | f = open('../data/coreEntityEmotion_train.txt', 'r') 348 | datas = [] 349 | datas_em = [] 350 | all_index = len(f.readlines()) 351 | f.seek(0) 352 | for index, line in enumerate(f.readlines()): 353 | data = json.loads(line) 354 | if index % 100 == 0: 355 | print('{}/{}'.format(index, all_index)) 356 | new_data = {} 357 | new_data_em = {} 358 | new_data['newsId'] = data['newsId'] 359 | new_data_em['newsId'] = data['newsId'] 360 | 361 | new_data['coreEntityEmotions'] = get_core_entityemotions(data['coreEntityEmotions']) 362 | new_data_em['coreEntityEmotions'] = new_data['coreEntityEmotions'] 363 | 364 | title = clean_text(data['title'].strip()) 365 | 366 | if len(title) > 125: 367 | title = title[:125] 368 | print('warning:标题被截断!!') 369 | title = seg_char(title) 370 | title_labels = get_label_no_emotion(title, new_data['coreEntityEmotions']) 371 | assert len(title) == len(title_labels) 372 | new_data['title'] = 
(title, title_labels) 373 | title_labels = get_label(title, new_data['coreEntityEmotions']) 374 | assert len(title) == len(title_labels) 375 | new_data_em['title'] = (title, title_labels) 376 | data['content'] = clean_text(data['content'].strip()) 377 | if len(data['content'].strip()) == 0: 378 | new_data['content'] = [] 379 | new_data_em['content'] = [] 380 | pass 381 | else: 382 | if data['content'][-1] not in '。。!!??': 383 | data['content'] = data['content'] + '。' 384 | sentences = get_sentences(data['content'], len(''.join(title))) 385 | sentences = seg_char_sents(sentences) 386 | content = [] 387 | content_em = [] 388 | for sent in sentences: 389 | sent_labels = get_label_no_emotion(sent, new_data['coreEntityEmotions']) 390 | assert len(sent) == len(sent_labels) 391 | content.append((sent, sent_labels)) 392 | sent_labels = get_label(sent, new_data['coreEntityEmotions']) 393 | assert len(sent) == len(sent_labels) 394 | content_em.append((sent, sent_labels)) 395 | new_data['content'] = content 396 | new_data_em['content'] = content_em 397 | 398 | datas.append(new_data) 399 | datas_em.append(new_data_em) 400 | 401 | f.close() 402 | data_dump(datas, '../datasets/256/train_ner_no_emotion.pkl') 403 | data_dump(datas_em, '../datasets/256/train_ner_has_emotion.pkl') 404 | 405 | f = open('../data/coreEntityEmotion_test_stage2.txt', 'r') 406 | datas = [] 407 | all_index = len(f.readlines()) 408 | f.seek(0) 409 | for index, line in enumerate(f.readlines()): 410 | data = json.loads(line) 411 | if index % 100 == 0: 412 | print('{}/{}'.format(index, all_index)) 413 | new_data = {} 414 | new_data['newsId'] = data['newsId'] 415 | 416 | data['title'] = clean_text(data['title'].strip()) 417 | if len(data['title']) > 125: 418 | data['title'] = data['title'][:125] 419 | print('warning:标题被截断!!') 420 | title = seg_char(data['title']) 421 | new_data['title'] = title 422 | data['content'] = clean_text(data['content'].strip()) 423 | if len(data['content'].strip()) == 0: 424 | new_data['content'] = [] 425 | else: 426 | if data['content'][-1] not in '。。!!??': 427 | data['content'] = data['content'] + '。' 428 | sentences = get_sentences(data['content'], len(''.join(title))) 429 | sentences = seg_char_sents(sentences) 430 | content = [] 431 | for sent in sentences: 432 | content.append(sent) 433 | new_data['content'] = content 434 | 435 | datas.append(new_data) 436 | 437 | f.close() 438 | data_dump(datas, '../datasets/256/test_ner2.pkl') 439 | 440 | print('截断总数{}'.format(cut_num)) 441 | -------------------------------------------------------------------------------- /src/get_sents_fix_more_s.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/5/15 19:20 4 | # @Author : 邵明岩 5 | # @File : get_sents_fix.py 6 | # @Software: PyCharm 7 | 8 | import re 9 | import pickle 10 | from zhon import hanzi 11 | import string 12 | import json 13 | import random 14 | 15 | cut_num = 0 16 | 17 | 18 | def Unicode(): 19 | val = random.randint(0x4e00, 0x9fbf) 20 | return chr(val) 21 | 22 | 23 | def hash_ch(len_text): 24 | return ''.join([Unicode() for _ in range(len_text)]) 25 | 26 | 27 | def data_dump(data, file): 28 | with open(file, "wb") as f: 29 | pickle.dump(data, f) 30 | print("store data successfully!") 31 | 32 | 33 | def get_label(sent, entity): 34 | def get_entity_len(d): 35 | return len(d['entity']) 36 | 37 | ner = ['O' for _ in range(len(sent))] 38 | 39 | entity = sorted(entity, key=get_entity_len, reverse=True) 40 | 
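# BIES tagging with sentiment: multi-character entities become B-<emotion> ...
# E-<emotion> (with I-<emotion> in between), single-character entities become
# S-<emotion>; longer entities are matched first so shorter overlapping matches
# cannot overwrite an existing span.
# Illustrative example: for sent = ['我', '爱', '北', '京'] and
# entity = [{'entity': ['北', '京'], 'emotion': 'POS'}], get_label() returns
# ['O', 'O', 'B-POS', 'E-POS'].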
for d in entity: 41 | emotion = d['emotion'].strip() 42 | e_l = d['entity'] 43 | name = ''.join(e_l) 44 | e_n = len(e_l) 45 | for i in range(len(sent) - e_n + 1): 46 | a = ''.join(sent[i:i + e_n]) 47 | if a == name: 48 | flag = False 49 | for r in range(e_n): 50 | if ner[r + i].startswith('B') or ner[r + i].startswith('I') \ 51 | or ner[r + i].startswith('E') or ner[r + i].startswith('S'): 52 | flag = True 53 | break 54 | 55 | if flag is True: 56 | pass 57 | else: 58 | if e_n == 1: 59 | ner[i] = 'S-{}'.format(emotion) 60 | else: 61 | for r in range(e_n): 62 | if r == 0: 63 | ner[r + i] = 'B-{}'.format(emotion) 64 | elif r == e_n - 1: 65 | ner[r + i] = 'E-{}'.format(emotion) 66 | else: 67 | ner[r + i] = 'I-{}'.format(emotion) 68 | return ner 69 | 70 | 71 | def get_label_no_emotion(sent, entity): 72 | def get_entity_len(d): 73 | return len(d['entity']) 74 | 75 | ner = ['O' for _ in range(len(sent))] 76 | 77 | entity = sorted(entity, key=get_entity_len, reverse=True) 78 | for d in entity: 79 | e_l = d['entity'] 80 | name = ''.join(e_l) 81 | e_n = len(e_l) 82 | for i in range(len(sent) - e_n + 1): 83 | a = ''.join(sent[i:i + e_n]) 84 | if a == name: 85 | flag = False 86 | for r in range(e_n): 87 | if ner[r + i].startswith('B') or ner[r + i].startswith('I') \ 88 | or ner[r + i].startswith('E') or ner[r + i].startswith('S'): 89 | flag = True 90 | break 91 | 92 | if flag is True: 93 | pass 94 | else: 95 | if e_n == 1: 96 | ner[i] = 'S' 97 | else: 98 | for r in range(e_n): 99 | if r == 0: 100 | ner[r + i] = 'B' 101 | elif r == e_n - 1: 102 | ner[r + i] = 'E' 103 | else: 104 | ner[r + i] = 'I' 105 | 106 | return ner 107 | 108 | 109 | symbol_list = [('【', '】'), ('(', ')'), ('“', '”'), ('「', '」'), ('《', '》')] 110 | 111 | 112 | def get_entity_mask(text): 113 | pattern = r'{}[^{}\n\r]+{}' 114 | pattern2 = r'{}[^{}。。!!??\n\r]+{}' 115 | 116 | mask_texts_set = set() 117 | for symbol in symbol_list: 118 | symbol1 = symbol[0] 119 | symbol2 = symbol[1] 120 | if symbol1 == '“': 121 | res = re.findall(pattern2.format(symbol1, symbol2, symbol2), text) 122 | else: 123 | res = re.findall(pattern.format(symbol1, symbol2, symbol2), text) 124 | for r in res: 125 | mask_texts_set.add(r) 126 | 127 | mask_texts_set = list(mask_texts_set) 128 | mask_texts_set.sort(key=lambda x: len(x), reverse=True) 129 | mask_texts = [] 130 | for t in mask_texts_set: 131 | lent = len(t) 132 | while True: 133 | h = str(hash_ch(lent)) 134 | if h not in text: 135 | break 136 | assert len(t) == len(h) 137 | mask_texts.append((t, h)) 138 | 139 | for t, h in mask_texts: 140 | text = text.replace(t, h) 141 | 142 | return mask_texts, text 143 | 144 | 145 | def get_real_text(mask_texts, text): 146 | for t, h in mask_texts: 147 | text = text.replace(h, t) 148 | return text 149 | 150 | 151 | def get_sentences(content, title_len): 152 | global cut_num 153 | # mask一些书名号等 154 | mask_texts, content = get_entity_mask(content) 155 | 156 | # 基本分句,。|。|!|\!|?|\?,用这6个符号分 157 | sentences = re.split(r'(。|。|!|\!|?|\?)', content) # 保留分割符 158 | new_sents = [] 159 | for i in range(int(len(sentences) / 2)): 160 | sent = sentences[2 * i] + sentences[2 * i + 1] 161 | sent = sent.strip() 162 | new_sents.append(sent) 163 | res = [] 164 | sentence = '' 165 | max_len = 250 - title_len 166 | for sent in new_sents: 167 | temp_sents = [] 168 | if len(sent) > max_len: # 大于max_len长度继续分,;|、|,|,|﹔|、,6个次级分句 169 | sents = re.split(r'(;|、|,|,|﹔|、)', sent) # 保留分割符 170 | for s in sents: 171 | if len(s) > max_len: # 子分句也大于max_len,采用截断式分句 172 | cut_num = cut_num + 1 173 | s = 
get_real_text(mask_texts, s) # 还原mask 174 | ss = [] 175 | j = max_len 176 | while j < len(s): 177 | ss.append(s[j - max_len:j]) 178 | j = j + max_len 179 | if len(s[j - max_len:len(s)]) < 20: 180 | temp = ss[-1] + s[j - max_len:len(s)] 181 | ss[-1] = temp[0:int(len(temp) / 2)] 182 | ss.append(temp[int(len(temp) / 2):]) 183 | else: 184 | ss.append(s[j - max_len:len(s)]) 185 | temp_sents.extend(ss) 186 | 187 | else: 188 | temp_sents.append(s) 189 | else: 190 | temp_sents.append(sent) 191 | 192 | # temp_sents获得了所有子句,将子句尽可能组成max_len长度的长句,减少训练时间 193 | for temp in temp_sents: 194 | if len(sentence + temp) <= max_len: 195 | sentence = sentence + temp 196 | else: 197 | res.append(sentence) 198 | sentence = temp 199 | 200 | if sentence != '': 201 | res.append(sentence) 202 | 203 | result = [] 204 | for r in res: 205 | r = get_real_text(mask_texts, r) 206 | r = r.replace('\n', '') 207 | r = r.replace('\r', '') 208 | result.append(r) 209 | 210 | return result 211 | 212 | 213 | def seg_char(sent): 214 | """ 215 | 把句子按字分开,不破坏英文结构 216 | """ 217 | sent = sent.replace('\n', '') 218 | sent = sent.replace('\r', '') 219 | pattern = re.compile(r'([\u4e00-\u9fa5])') 220 | chars = pattern.split(sent) 221 | chars = [w.strip() for w in chars if len(w.strip()) > 0] 222 | new_chars = [] 223 | for c in chars: 224 | if len(c) > 1: 225 | punctuation = hanzi.punctuation 226 | punctuation = '|'.join([p for p in punctuation]) 227 | pattern = re.compile(' |({})'.format(punctuation)) 228 | cs = pattern.split(c) 229 | for w in cs: 230 | if w and len(w.strip()) > 0: 231 | new_chars.append(w.strip()) 232 | else: 233 | new_chars.append(c) 234 | return new_chars 235 | 236 | 237 | def seg_char_sents(sentences): 238 | results = [] 239 | for sent in sentences: 240 | res = seg_char(sent) 241 | if len(res) > 0: 242 | results.append(res) 243 | return results 244 | 245 | 246 | def get_core_entityemotions(entityemotions): 247 | results = [] 248 | for ee in entityemotions: 249 | result = {} 250 | result['entity'] = seg_char(clean_text(ee['entity'])) 251 | result['emotion'] = ee['emotion'] 252 | results.append(result) 253 | 254 | return results 255 | 256 | 257 | def ishan(char): 258 | # for python 3.x 259 | # sample: ishan('一') == True, ishan('我&&你') == False 260 | return '\u4e00' <= char <= '\u9fff' 261 | 262 | 263 | def clean_text(text): 264 | new_text = [] 265 | for char in text: 266 | if ishan(char) or char in string.digits or char in string.ascii_letters or char in ( 267 | hanzi.punctuation + string.punctuation): 268 | new_text.append(char) 269 | elif char == '\t' or char == ' ': 270 | new_text.append(' ') 271 | elif char == '\r' or char == '\n': 272 | new_text.append('\n') 273 | else: 274 | continue 275 | 276 | new_text = ''.join(new_text) 277 | # html转移字符 278 | new_text = re.sub(r'"', '"', new_text) 279 | new_text = re.sub(r'&', '&', new_text) 280 | new_text = re.sub(r'<', '<', new_text) 281 | new_text = re.sub(r'>', '>', new_text) 282 | new_text = re.sub(r' ', ' ', new_text) 283 | new_text = re.sub(r'·', '·', new_text) 284 | # 去除多余空格 285 | new_text = re.sub(r' +', ' ', new_text) 286 | # 去除html链接 287 | new_text = re.sub( 288 | r'(http|ftp)s?://([^\u4e00-\u9fa5"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、\u3000、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·!?。。])*', 289 | '', new_text) 290 | # 去除多余连续符号,比如颜文字表情 291 | new_text = re.sub(r'[#%&\'()*+-./:;<=>?@[\]^_`{|}~]{2,}', '', new_text) 292 | 293 | return new_text 294 | 295 | 296 | if __name__ == '__main__': 297 | f = open('../data/coreEntityEmotion_example.txt', 'r') 298 | datas = [] 299 | 
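# Two parallel outputs are built for every split: `datas` holds the boundary-only
# tags (*_ner_no_emotion.pkl) and `datas_em` holds the emotion-qualified tags
# (*_ner_has_emotion.pkl).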
datas_em = [] 300 | all_index = len(f.readlines()) 301 | f.seek(0) 302 | for index, line in enumerate(f.readlines()): 303 | data = json.loads(line) 304 | if index % 100 == 0: 305 | print('{}/{}'.format(index, all_index)) 306 | new_data = {} 307 | new_data_em = {} 308 | new_data['newsId'] = data['newsId'] 309 | new_data_em['newsId'] = data['newsId'] 310 | 311 | new_data['coreEntityEmotions'] = get_core_entityemotions(data['coreEntityEmotions']) 312 | new_data_em['coreEntityEmotions'] = new_data['coreEntityEmotions'] 313 | 314 | title = clean_text(data['title'].strip()) 315 | 316 | if len(title) > 125: 317 | title = title[:125] 318 | print('warning:标题被截断!!') 319 | title = seg_char(title) 320 | title_labels = get_label_no_emotion(title, new_data['coreEntityEmotions']) 321 | assert len(title) == len(title_labels) 322 | new_data['title'] = (title, title_labels) 323 | title_labels = get_label(title, new_data['coreEntityEmotions']) 324 | assert len(title) == len(title_labels) 325 | new_data_em['title'] = (title, title_labels) 326 | data['content'] = clean_text(data['content'].strip()) 327 | if len(data['content'].strip()) == 0: 328 | new_data['content'] = [] 329 | new_data_em['content'] = [] 330 | pass 331 | else: 332 | if data['content'][-1] not in '。。!!??': 333 | data['content'] = data['content'] + '。' 334 | sentences = get_sentences(data['content'], len(''.join(title))) 335 | sentences = seg_char_sents(sentences) 336 | content = [] 337 | content_em = [] 338 | for sent in sentences: 339 | sent_labels = get_label_no_emotion(sent, new_data['coreEntityEmotions']) 340 | assert len(sent) == len(sent_labels) 341 | content.append((sent, sent_labels)) 342 | sent_labels = get_label(sent, new_data['coreEntityEmotions']) 343 | assert len(sent) == len(sent_labels) 344 | content_em.append((sent, sent_labels)) 345 | new_data['content'] = content 346 | new_data_em['content'] = content_em 347 | 348 | datas.append(new_data) 349 | datas_em.append(new_data_em) 350 | 351 | f.close() 352 | data_dump(datas, '../datasets/256_s/example_ner_no_emotion.pkl') 353 | data_dump(datas_em, '../datasets/256_s/example_ner_has_emotion.pkl') 354 | 355 | f = open('../data/coreEntityEmotion_train.txt', 'r') 356 | datas = [] 357 | datas_em = [] 358 | all_index = len(f.readlines()) 359 | f.seek(0) 360 | for index, line in enumerate(f.readlines()): 361 | data = json.loads(line) 362 | if index % 100 == 0: 363 | print('{}/{}'.format(index, all_index)) 364 | new_data = {} 365 | new_data_em = {} 366 | new_data['newsId'] = data['newsId'] 367 | new_data_em['newsId'] = data['newsId'] 368 | 369 | new_data['coreEntityEmotions'] = get_core_entityemotions(data['coreEntityEmotions']) 370 | new_data_em['coreEntityEmotions'] = new_data['coreEntityEmotions'] 371 | 372 | title = clean_text(data['title'].strip()) 373 | 374 | if len(title) > 125: 375 | title = title[:125] 376 | print('warning:标题被截断!!') 377 | title = seg_char(title) 378 | title_labels = get_label_no_emotion(title, new_data['coreEntityEmotions']) 379 | assert len(title) == len(title_labels) 380 | new_data['title'] = (title, title_labels) 381 | title_labels = get_label(title, new_data['coreEntityEmotions']) 382 | assert len(title) == len(title_labels) 383 | new_data_em['title'] = (title, title_labels) 384 | data['content'] = clean_text(data['content'].strip()) 385 | if len(data['content'].strip()) == 0: 386 | new_data['content'] = [] 387 | new_data_em['content'] = [] 388 | pass 389 | else: 390 | if data['content'][-1] not in '。。!!??': 391 | data['content'] = data['content'] + '。' 392 | 
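# The character length of the segmented title is passed in so get_sentences()
# can budget each content chunk to at most 250 - title_len characters, likely to
# keep title + sentence (plus special tokens) within the 256-token inputs used
# for the 256_s dataset.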
sentences = get_sentences(data['content'], len(''.join(title))) 393 | sentences = seg_char_sents(sentences) 394 | content = [] 395 | content_em = [] 396 | for sent in sentences: 397 | sent_labels = get_label_no_emotion(sent, new_data['coreEntityEmotions']) 398 | assert len(sent) == len(sent_labels) 399 | content.append((sent, sent_labels)) 400 | sent_labels = get_label(sent, new_data['coreEntityEmotions']) 401 | assert len(sent) == len(sent_labels) 402 | content_em.append((sent, sent_labels)) 403 | new_data['content'] = content 404 | new_data_em['content'] = content_em 405 | 406 | datas.append(new_data) 407 | datas_em.append(new_data_em) 408 | 409 | f.close() 410 | data_dump(datas, '../datasets/256_s/train_ner_no_emotion.pkl') 411 | data_dump(datas_em, '../datasets/256_s/train_ner_has_emotion.pkl') 412 | 413 | f = open('../data/coreEntityEmotion_test_stage2.txt', 'r') 414 | datas = [] 415 | all_index = len(f.readlines()) 416 | f.seek(0) 417 | for index, line in enumerate(f.readlines()): 418 | data = json.loads(line) 419 | if index % 100 == 0: 420 | print('{}/{}'.format(index, all_index)) 421 | new_data = {} 422 | new_data['newsId'] = data['newsId'] 423 | 424 | data['title'] = clean_text(data['title'].strip()) 425 | if len(data['title']) > 125: 426 | data['title'] = data['title'][:125] 427 | print('warning:标题被截断!!') 428 | title = seg_char(data['title']) 429 | new_data['title'] = title 430 | data['content'] = clean_text(data['content'].strip()) 431 | if len(data['content'].strip()) == 0: 432 | new_data['content'] = [] 433 | else: 434 | if data['content'][-1] not in '。。!!??': 435 | data['content'] = data['content'] + '。' 436 | sentences = get_sentences(data['content'], len(''.join(title))) 437 | sentences = seg_char_sents(sentences) 438 | content = [] 439 | for sent in sentences: 440 | content.append(sent) 441 | new_data['content'] = content 442 | 443 | datas.append(new_data) 444 | 445 | f.close() 446 | data_dump(datas, '../datasets/256_s/test_ner2.pkl') 447 | 448 | print('截断总数{}'.format(cut_num)) 449 | -------------------------------------------------------------------------------- /src/data_title_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import logging 5 | import os 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | from pytorch_pretrained_bert.tokenization import BertTokenizer 11 | import h5py 12 | import gc 13 | from collections import OrderedDict 14 | import re 15 | from utils import covert_mytokens_to_myids, data_dump, load_data 16 | 17 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 18 | datefmt='%m/%d/%Y %H:%M:%S', 19 | level=logging.INFO) 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class InputExample(object): 24 | """A single training/test example for simple sequence classification.""" 25 | 26 | def __init__(self, guid, text, text_title=None, label_sent=None, label_title=None): 27 | ''' 28 | :param guid: example id, may be is newId 29 | :param text: cur sentence text 30 | :param text_title: the news title text 31 | :parm label: label text [B-POS, I-POS, I-POS] 32 | 在 fetures 中 label => 33 | label_ent: entity label [0,1,2...10] 34 | label_emo: emotion label [0,1,2...7] can transform from the label_ent 35 | 情感的标记到底应该时什么 ? 
[pos pos pos] or [B-pos, I-pos, I-pos] 36 | ''' 37 | self.guid = guid 38 | self.text = text 39 | self.text_title = text_title 40 | self.label_text = label_sent 41 | self.label_title = label_title 42 | 43 | 44 | class InputFeatures(object): 45 | """A single set of features of data.""" 46 | 47 | def __init__(self, ID, input_ids, myinput_ids, input_mask, segment_ids): 48 | self.ID = ID 49 | self.input_ids = input_ids 50 | self.myinput_ids = myinput_ids 51 | self.input_mask = input_mask 52 | self.segment_ids = segment_ids 53 | 54 | 55 | def readfile(filename): 56 | logger.info("read file:{}....".format(filename)) 57 | data = [] 58 | 59 | def _read(): 60 | f = open(filename, encoding="utf-8") 61 | title = None 62 | sentence = [] 63 | label_sent = [] 64 | label_title = [] 65 | count = 0 66 | ID = 0 67 | has_title = False 68 | for line in f: 69 | if len(line) == 0 or line[0] == "\n": 70 | if len(sentence) > 0: 71 | if not has_title: 72 | # the title 73 | title = sentence 74 | has_title = True 75 | data.append((ID, sentence, title, label_sent, label_title)) 76 | # refresh 77 | sentence = [] 78 | label_sent = [] 79 | else: 80 | # 第二个空行 sentence 已空 81 | has_title = False 82 | ID += 1 83 | count += 1 84 | if count % 50000 == 0: 85 | print(count) 86 | continue 87 | sentence.append(line[:-1]) 88 | 89 | 90 | # 防止因最后一行非空行而没加入最后一个 91 | if len(sentence) > 0: 92 | if not has_title: 93 | # the title 94 | title = sentence 95 | label_title = label_sent 96 | has_title = True 97 | data.append((ID, sentence, title, label_sent, label_title)) 98 | # refresh 99 | # sentence = [] 100 | # label_sent = [] 101 | f.close() 102 | 103 | _read() 104 | # 最后一句查看 105 | print("The Last Sentence....") 106 | print(data[-1][0]) 107 | print(data[-1][1]) 108 | print("sentence num: ", len(data)) 109 | return data 110 | 111 | 112 | class DataProcessor(object): 113 | """Base class for data converters for sequence classification data sets.""" 114 | 115 | def get_train_examples(self, data_dir): 116 | """Gets a collection of `InputExample`s for the train set.""" 117 | raise NotImplementedError() 118 | 119 | def get_dev_examples(self, data_dir): 120 | """Gets a collection of `InputExample`s for the dev set.""" 121 | raise NotImplementedError() 122 | 123 | def get_labels(self): 124 | """Gets the list of labels for this data set.""" 125 | raise NotImplementedError() 126 | 127 | @classmethod 128 | def _read_tsv(cls, input_file, quotechar=None): 129 | """Reads a tab separated value file.""" 130 | return readfile(input_file) 131 | 132 | 133 | class NerProcessor(DataProcessor): 134 | """Processor for the CoNLL-2003 data set.""" 135 | 136 | def get_test_examples(self, data_dir): 137 | """See base class.""" 138 | return self._create_examples( 139 | self._read_tsv(os.path.join(data_dir, "test.txt")), "test") 140 | 141 | def get_labels(self): 142 | return ["O", "B-POS", "I-POS", "B-NEG", "I-NEG", "B-NORM", "I-NORM", "X", "[CLS]", "[SEP]"] 143 | 144 | def _create_examples(self, lines, set_type): 145 | examples = [] 146 | for i, (ID, sentence, title, label_sent, label_title) in enumerate(lines): 147 | guid = ID 148 | text = ' '.join(sentence) 149 | text_title = ' '.join(title) 150 | examples.append(InputExample(guid=guid, text=text, text_title=text_title, label_sent=label_sent, 151 | label_title=label_title)) 152 | return examples 153 | 154 | 155 | def _truncate_seq_pair(tokens_a, tokens_b, mytokens_a, mytokens_b, labels_A, labels_B, max_length): 156 | """Truncates a sequence pair in place to the maximum length.""" 157 | 158 | # This is a simple 
heuristic which will always truncate the longer sequence 159 | # one token at a time. This makes more sense than truncating an equal percent 160 | # of tokens from each, since if one sequence is very short then each token 161 | # that's truncated likely contains more information than a longer sequence. 162 | while True: 163 | total_length = len(tokens_a) + len(tokens_b) 164 | if total_length <= max_length: 165 | break 166 | if len(tokens_a) > len(tokens_b): 167 | tokens_a.pop() 168 | mytokens_a.pop() 169 | else: 170 | tokens_b.pop() 171 | mytokens_b.pop() 172 | 173 | 174 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, TOK2ID): 175 | """Loads a data file into a list of `InputBatch`s.""" 176 | 177 | logger.info("gen features...") 178 | label_map = {label: i for i, label in enumerate(label_list, 1)} 179 | label_map["PAD"] = 0 180 | 181 | features = [] 182 | logger.info("prepare features.....") 183 | count = 0 184 | for (ex_index, example) in enumerate(examples): 185 | textlist_A = example.text.split(' ') 186 | labellist_A = example.label_text 187 | textlist_B = example.text_title.split(' ') 188 | labellist_B = example.label_title 189 | tokens_A = [] 190 | mytokens_A = [] 191 | labels_A = [] 192 | tokens_B = [] 193 | mytokens_B = [] 194 | labels_B = [] 195 | 196 | # if count == 529320: 197 | # pdb.set_trace() 198 | ####################### tokenize ########################## 199 | def _tokenize(textlist, labellist, tokens, mytokens, labels): 200 | ''' 201 | # 重新检查textlist 除去多个汉字并联 202 | new_textlist = [] 203 | for word in textlist: 204 | if len(word) == 1: 205 | new_textlist.append(word) 206 | else: 207 | new_word = seg_char(word) 208 | if len(new_word) > 1: 209 | print("!wrong seg:", new_word) 210 | new_textlist.extend(new_word) 211 | else: 212 | new_textlist.append(word) 213 | textlist = new_textlist 214 | ''' 215 | SAMESPLIT = False 216 | for i, word in enumerate(textlist): 217 | SAMESPLIT = False 218 | token = tokenizer.tokenize(word) 219 | mytoken = token.copy() 220 | # debug [UNK] 221 | # 检查 [UNK], 编入 自定义 字典 222 | if "[UNK]" in token: 223 | assert "##" not in word 224 | # ## 225 | sharp_index = [i for i, tok in enumerate(token) if tok.startswith("##")] 226 | for si in sharp_index: 227 | no_sharp_tok = mytoken[si][2:] 228 | if no_sharp_tok not in TOK2ID: 229 | TOK2ID[no_sharp_tok] = len(TOK2ID) 230 | mytoken[si] = no_sharp_tok 231 | 232 | unks_index = [i for i, tok in enumerate(mytoken) if tok == "[UNK]"] 233 | unks_index.reverse() 234 | not_unks = [tok for tok in mytoken if tok != "[UNK]"] 235 | if not_unks: 236 | not_unks = [re.escape(nu) for nu in not_unks] 237 | if not (len(set(not_unks)) == 1 and len(not_unks) > 1): # 全是同一个字符,会无限循环 238 | pattern = "(.*)".join(not_unks) 239 | pattern = re.compile("(.*)" + pattern + "(.*)") 240 | else: 241 | SAMESPLIT = True 242 | print(word) 243 | f = "([^{}]*)".format(not_unks[0]) 244 | pattern = f.join(not_unks) 245 | pattern = re.compile(f + pattern + f) 246 | for res in pattern.findall(word): 247 | for r in res: 248 | if len(r) > 0 and r != "\u202a": # whitespace!! 
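# Note on the block below: the capture groups between the escaped known sub-tokens correspond
# to the spans the BERT tokenizer collapsed into [UNK]; each non-empty captured span (skipping
# the \u202a control character flagged above as whitespace) is interned into the custom TOK2ID
# vocabulary and written back over the pending [UNK] slots left to right, since unks_index was
# reversed before popping.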
249 | if r not in TOK2ID: 250 | TOK2ID[r] = len(TOK2ID) 251 | mytoken[unks_index[-1]] = r 252 | unks_index.pop() 253 | 254 | else: 255 | # 理论上应该是单个 如 4G 但是很奇怪分字有问题这里 如不了 256 | # assert len(token) == 1 257 | if len(token) == 1: 258 | if word not in TOK2ID: 259 | TOK2ID[word] = len(TOK2ID) 260 | mytoken[0] = word 261 | else: 262 | # ?不处理 263 | # '不了' 264 | mytoken = list(word) 265 | print("BUG!:", word, token, mytoken) 266 | for mytok in mytoken: 267 | if mytok not in TOK2ID: 268 | TOK2ID[mytok] = len(TOK2ID) 269 | if SAMESPLIT: 270 | print(word, token, mytoken, sep="\t") 271 | tokens.extend(token) 272 | mytokens.extend(mytoken) 273 | assert len(tokens) == len(mytokens) 274 | ''' 275 | label_1 = labellist[i] 276 | for m in range(len(token)): 277 | if m == 0: 278 | labels.append(label_1) 279 | else: 280 | labels.append("X") 281 | ''' 282 | 283 | _tokenize(textlist_A, labellist_A, tokens_A, mytokens_A, labels_A) 284 | _tokenize(textlist_B, labellist_B, tokens_B, mytokens_B, labels_B) 285 | ###################################################### 286 | 287 | ################### 处理tokenize后新增的位置 以及 合并 截断 ##################### 288 | # [CLS] A [SEP] B [SEP] 289 | _truncate_seq_pair(tokens_A, tokens_B, mytokens_A, mytokens_B, labels_A, labels_B, max_seq_length - 3) 290 | tokens = ["[CLS]"] + tokens_A + ["[SEP]"] + tokens_B + ["[SEP]"] 291 | mytokens = ["[CLS]"] + mytokens_A + ["[SEP]"] + mytokens_B + ["[SEP]"] 292 | assert len(tokens) == len(mytokens) 293 | segment_ids = [0] * (len(tokens_A) + 2) + [1] * (len(tokens_B) + 1) 294 | 295 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 296 | myinput_ids = covert_mytokens_to_myids(TOK2ID, mytokens) 297 | input_mask = [1] * len(input_ids) 298 | # ------------------------------- PAD ----------------------------------- 299 | padding = [0] * (max_seq_length - len(input_ids)) 300 | input_ids += padding 301 | myinput_ids += padding 302 | input_mask += padding 303 | segment_ids += padding 304 | 305 | ####################################################################### 306 | assert len(input_ids) == len(myinput_ids) 307 | assert len(input_ids) == max_seq_length 308 | assert len(input_mask) == max_seq_length 309 | assert len(segment_ids) == max_seq_length 310 | 311 | if ex_index < 5: 312 | logger.info("*** Example ***") 313 | logger.info("guid: %s" % (example.guid)) 314 | logger.info("tokens: %s" % " ".join( 315 | [str(x) for x in tokens])) 316 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 317 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 318 | logger.info( 319 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 320 | features.append( 321 | InputFeatures(ID=example.guid, 322 | input_ids=input_ids, 323 | myinput_ids=myinput_ids, 324 | input_mask=input_mask, 325 | segment_ids=segment_ids)) 326 | count += 1 327 | if count % 20000 == 0: 328 | print("gen example {} feature".format(count)) 329 | 330 | logger.info("finish features gen") 331 | 332 | return features 333 | 334 | 335 | def main(): 336 | parser = argparse.ArgumentParser() 337 | parser.add_argument("--data_dir", 338 | default="../datasets", 339 | type=str, 340 | required=False, 341 | help="The input data dir. 
Should contain the .tsv files (or other data files) for the task.") 342 | parser.add_argument("--bert_token_model", 343 | default="../bert_pretrained/bert_token_model", 344 | type=str, required=False) 345 | parser.add_argument("--task_name", 346 | default="ner", 347 | type=str, 348 | required=False, 349 | help="The name of the task to train.") 350 | 351 | ## Other parameters 352 | parser.add_argument("--cache_dir", 353 | default="", 354 | type=str, 355 | help="Where do you want to store the pre-trained models downloaded from s3") 356 | parser.add_argument("--max_seq_length", 357 | default=128, 358 | type=int, 359 | help="The maximum total input sequence length after WordPiece tokenization. \n" 360 | "Sequences longer than this will be truncated, and sequences shorter \n" 361 | "than this will be padded.") 362 | parser.add_argument("--do_lower_case", 363 | action='store_true', 364 | help="Set this flag if you are using an uncased model.") 365 | parser.add_argument('--seed', 366 | type=int, 367 | default=42, 368 | help="random seed for initialization") 369 | args = parser.parse_args() 370 | logger.info("data_dir: {}".format(args.data_dir)) 371 | 372 | processors = {"ner": NerProcessor} 373 | 374 | random.seed(args.seed) 375 | np.random.seed(args.seed) 376 | torch.manual_seed(args.seed) 377 | 378 | task_name = args.task_name.lower() 379 | if task_name not in processors: 380 | raise ValueError("Task not found: %s" % (task_name)) 381 | 382 | processor = processors[task_name]() 383 | 384 | label_list = processor.get_labels() 385 | tokenizer = BertTokenizer.from_pretrained(args.bert_token_model, do_lower_case=args.do_lower_case) 386 | 387 | # ID2TOK = tokenizer.ids_to_tokens 388 | ID2TOK = load_data("../datasets/ID2TOK.pkl") 389 | TOK2ID = OrderedDict((tok, id) for id, tok in ID2TOK.items()) 390 | 391 | logger.info("in do_test now") 392 | test_examples = processor.get_test_examples(args.data_dir) 393 | test_features = convert_examples_to_features( 394 | test_examples, label_list, args.max_seq_length, tokenizer, TOK2ID) 395 | logger.info("***** Running on Test *****") 396 | logger.info(" Num examples = %d", len(test_examples)) 397 | # logger.info(" Batch size = %d", args.test_batch_size) 398 | all_IDs = torch.tensor([f.ID for f in test_features], dtype=torch.long) 399 | all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long) 400 | all_myinput_ids = torch.tensor([f.myinput_ids for f in test_features], dtype=torch.long) 401 | all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long) 402 | 403 | f = h5py.File("../datasets/full.h5", "r+", libver="latest") 404 | # for add again del first 405 | if "test" in f.keys(): 406 | del f["test"] 407 | f.create_dataset("test/IDs", data=all_IDs, compression="gzip") 408 | f.create_dataset("test/input_ids", data=all_input_ids, compression="gzip") 409 | f.create_dataset("test/myinput_ids", data=all_myinput_ids, compression="gzip") 410 | f.create_dataset("test/input_mask", data=all_input_mask, compression="gzip") 411 | del all_input_ids, all_input_mask 412 | gc.collect() 413 | all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long) 414 | f.create_dataset("test/segment_ids", data=all_segment_ids, compression="gzip") 415 | f.close() 416 | # data_dump(TOK2ID, "../datasets/TOK2ID_test.pkl") 417 | ID2TOK = OrderedDict((id, tok) for tok, id in TOK2ID.items()) 418 | data_dump(ID2TOK, "../datasets/ID2TOK.pkl") 419 | print("save to h5 over!") 420 | 421 | 422 | if __name__ == "__main__": 
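# A minimal sketch of how the "test" group written above could be read back for inference,
# following the same h5py -> torch.from_numpy -> TensorDataset pattern used by the training
# scripts in this repo; the actual consumption in test.py may differ:
#
#     with h5py.File("../datasets/full.h5", "r") as hf:
#         all_IDs         = torch.from_numpy(hf["test/IDs"][()])
#         all_input_ids   = torch.from_numpy(hf["test/input_ids"][()])
#         all_myinput_ids = torch.from_numpy(hf["test/myinput_ids"][()])
#         all_input_mask  = torch.from_numpy(hf["test/input_mask"][()])
#         all_segment_ids = torch.from_numpy(hf["test/segment_ids"][()])
#     test_data = TensorDataset(all_IDs, all_input_ids, all_myinput_ids,
#                               all_input_mask, all_segment_ids)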
423 | main() 424 | -------------------------------------------------------------------------------- /src/kfold/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.chdir("../.") 4 | import h5py 5 | from argparse import ArgumentParser 6 | import torch 7 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear 8 | from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator 9 | from ignite.handlers import ModelCheckpoint, EarlyStopping 10 | from ignite.contrib.handlers import CustomPeriodicEvent 11 | from ignite.contrib.handlers.tqdm_logger import ProgressBar 12 | # from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | # from ignite.contrib.handlers.param_scheduler import (ConcatScheduler, 14 | # CosineAnnealingScheduler, 15 | # LinearCyclicalScheduler, 16 | # CyclicalScheduler, 17 | # LRScheduler) 18 | # from torch.optim.lr_scheduler import ExponentialLR 19 | 20 | from ignite.metrics import RunningAverage 21 | from ignite.contrib.handlers.tensorboard_logger import * 22 | from metric import FScore 23 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset 24 | from models import NetY3 25 | from loss import FocalLoss 26 | from utils import load_data 27 | 28 | 29 | # torch.random.manual_seed(42) 30 | 31 | 32 | def get_all_data(): 33 | print("get all data...........") 34 | # -------------------------------read from h5------------------------- 35 | if not args.lite: 36 | f = h5py.File("../datasets/full.h5") 37 | else: 38 | f = h5py.File("../datasets/lite.h5") 39 | input_ids_trn = torch.from_numpy(f["train/input_ids"][()]) 40 | myinput_ids_trn = torch.from_numpy(f["train/myinput_ids"][()]) 41 | input_mask_trn = torch.from_numpy(f["train/input_mask"][()]) 42 | segment_ids_trn = torch.from_numpy(f["train/segment_ids"][()]) 43 | label_ent_ids_trn = torch.from_numpy(f["train/label_ent_ids"][()]) 44 | label_emo_ids_trn = torch.from_numpy(f["train/label_emo_ids"][()]) 45 | assert input_ids_trn.size() == segment_ids_trn.size() == label_ent_ids_trn.size() == label_emo_ids_trn.size() == myinput_ids_trn.size() 46 | 47 | input_ids_val = torch.from_numpy(f["val/input_ids"][()]) 48 | myinput_ids_val = torch.from_numpy(f["val/myinput_ids"][()]) 49 | input_mask_val = torch.from_numpy(f["val/input_mask"][()]) 50 | segment_ids_val = torch.from_numpy(f["val/segment_ids"][()]) 51 | label_ent_ids_val = torch.from_numpy(f["val/label_ent_ids"][()]) 52 | label_emo_ids_val = torch.from_numpy(f["val/label_emo_ids"][()]) 53 | assert input_ids_val.size() == segment_ids_val.size() == label_ent_ids_val.size() == label_emo_ids_val.size() == myinput_ids_val.size() 54 | f.close() 55 | print("read h5 over!") 56 | input_ids = torch.cat([input_ids_trn, input_ids_val], dim=0) 57 | myinput_ids = torch.cat([myinput_ids_trn, myinput_ids_val], dim=0) 58 | input_mask = torch.cat([input_mask_trn, input_mask_val], dim=0) 59 | segment_ids = torch.cat([segment_ids_trn, segment_ids_val], dim=0) 60 | label_ent_ids = torch.cat([label_ent_ids_trn, label_ent_ids_val], dim=0) 61 | label_emo_ids = torch.cat([label_emo_ids_trn, label_emo_ids_val], dim=0) 62 | dataset = TensorDataset(input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids) 63 | 64 | return dataset 65 | 66 | 67 | def get_data_loader(dataset, cv): 68 | print(f"get dataloader {cv}") 69 | # 从 index 中取出 trn_dataset val_dataset 70 | if not args.lite: 71 | index_file = "kfold/5cv_indexs_{}".format(cv) 72 | else: 
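# Each fold's index file loads as a (trn_index, val_index) pair below, and the list
# comprehensions that follow materialize both splits as Python lists of tensor tuples.
# A lazier equivalent, sketched here assuming torch.utils.data.Subset is acceptable:
#
#     from torch.utils.data import Subset
#     trn_dataset = Subset(dataset, trn_index)
#     val_dataset = Subset(dataset, val_index)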
73 | index_file = "kfold/5cv_indexs_{}_lite".format(cv) 74 | 75 | if os.path.exists(index_file): 76 | trn_index, val_index = load_data(index_file) 77 | trn_dataset = [dataset[idx] for idx in trn_index] 78 | val_dataset = [dataset[idx] for idx in val_index] 79 | else: 80 | print("Not find index file!") 81 | os._exit(-1) 82 | # --------------------------------------------------------------------- 83 | trn_dataloader = DataLoader(trn_dataset, sampler=RandomSampler(trn_dataset), batch_size=args.batch_size, 84 | num_workers=args.nw, pin_memory=True) 85 | # trn_dataloader = DataLoader(trn_dataset, sampler=RandomSampler(trn_dataset), batch_size=args.batch_size, 86 | # pin_memory=True) 87 | val_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=args.val_batch_size, 88 | num_workers=args.nw, pin_memory=True) 89 | # val_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=args.val_batch_size, 90 | # pin_memory=True) 91 | print("get date loader over!") 92 | return trn_dataloader, val_dataloader, len(trn_dataset) 93 | 94 | 95 | def train(dataset, cv): 96 | print(f"training cv: {cv}") 97 | 98 | ################################ Model Config ################################### 99 | if args.lbl_method == "BIO": 100 | num_labels_emo = 4 # O POS NEG NORM 101 | num_labels_ent = 3 # O B I 102 | else: 103 | num_labels_emo = 4 # O POS NEG NORM 104 | num_labels_ent = 5 # O B I E S 105 | 106 | model = NetY3.from_pretrained(args.bert_model, 107 | cache_dir="", 108 | num_labels_ent=num_labels_ent, 109 | num_labels_emo=num_labels_emo, 110 | dp=args.dp) 111 | 112 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 113 | model.to(device) 114 | model = torch.nn.DataParallel(model) 115 | 116 | ################################# hyper parameters ########################### 117 | # alpha = 0.5 # 0.44 118 | # alpha = 0.6 # 0.42 119 | # alpha = 0.7 120 | # alpha = 1.2 121 | # alpha = 0.8 122 | # alpha = 0.7 123 | # alphas = [1, 0.9, 0.8, 0.8, 0.8, 0.8, 0.8] 124 | alpha = args.alpha 125 | # alphas = [2,1,1,0.8,0.8,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5] 126 | # ------------------------------ load model from file ------------------------- 127 | model_file = os.path.join(args.checkpoint_model_dir, args.ckp) 128 | if os.path.exists(model_file): 129 | model.load_state_dict(torch.load(model_file)) 130 | print("load checkpoint: {} successfully!".format(model_file)) 131 | # ----------------------------------------------------------------------------- 132 | 133 | trn_dataloader, val_dataloader, trn_size = get_data_loader(dataset, cv) 134 | 135 | ############################## Optimizer ################################### 136 | param_optimizer = list(model.named_parameters()) 137 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 138 | optimizer_grouped_parameters = [ 139 | {'params': [p for n, p in param_optimizer if (not any(nd in n for nd in no_decay)) and p.requires_grad], 140 | 'weight_decay': args.wd}, 141 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) and p.requires_grad], 142 | 'weight_decay': 0.0} 143 | ] 144 | # num_train_optimization_steps = int( trn_size / args.batch_size / args.gradient_accumulation_steps) * args.epochs 145 | num_train_optimization_steps = int(trn_size / args.batch_size) * args.epochs + 5 146 | optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, 147 | warmup=args.warmup_proportion, 148 | t_total=num_train_optimization_steps) 149 | # optimizer = 
Adam(filter(lambda p:p.requires_grad, model.parameters()), args.lr, weight_decay=5e-3) 150 | ###################################################################### 151 | if not args.focal: 152 | criterion = torch.nn.CrossEntropyLoss() 153 | else: 154 | criterion = FocalLoss(args.gamma) 155 | 156 | def step(engine, batch): 157 | model.train() 158 | batch = tuple(t.to(device) for t in batch) 159 | input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids = batch 160 | 161 | optimizer.zero_grad() 162 | act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids = model( 163 | input_ids, myinput_ids, segment_ids, input_mask, 164 | label_ent_ids, label_emo_ids) 165 | # Only keep active parts of the loss 166 | loss_ent = criterion(act_logits_ent, act_y_ent) 167 | 168 | loss_emo = criterion(act_logits_emo, act_y_emo) 169 | # loss = alphas[engine.state.epoch-1] * loss_ent + loss_emo 170 | if not args.multi: 171 | loss = alpha * loss_ent + loss_emo 172 | else: 173 | alphas = [1e-5, 1e-2, 1e-1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 174 | print("alphas: ", " ".join(map(str, alpha))) 175 | loss = loss_ent + alphas[engine.state.epoch - 1] * loss_emo 176 | 177 | if engine.state.metrics.get("total_loss") is None: 178 | engine.state.metrics["total_loss"] = 0 179 | engine.state.metrics["ent_loss"] = 0 180 | engine.state.metrics["emo_loss"] = 0 181 | else: 182 | engine.state.metrics["total_loss"] += loss.item() 183 | engine.state.metrics["ent_loss"] += loss_ent.item() 184 | engine.state.metrics["emo_loss"] += loss_emo.item() if act_logits_emo.size(0) > 0 else 0 185 | 186 | loss.backward() 187 | optimizer.step() 188 | return loss.item(), act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids # [-1, 11] 189 | 190 | def infer(engine, batch): 191 | model.eval() 192 | batch = tuple(t.to(device) for t in batch) 193 | input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids = batch 194 | 195 | with torch.no_grad(): 196 | act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids = model( 197 | input_ids, myinput_ids, segment_ids, 198 | input_mask, 199 | label_ent_ids, label_emo_ids) 200 | # Only keep active parts of the loss 201 | loss_ent = criterion(act_logits_ent, act_y_ent) 202 | # loss_emo = criterion_fl(act_logits_emo, act_y_emo) 203 | loss_emo = criterion(act_logits_emo, act_y_emo) 204 | # loss = alphas[engine.state.epoch-1] * loss_ent + loss_emo 205 | loss = alpha * loss_ent + loss_emo 206 | 207 | if engine.state.metrics.get("total_loss") is None: 208 | engine.state.metrics["total_loss"] = 0 209 | engine.state.metrics["ent_loss"] = 0 210 | engine.state.metrics["emo_loss"] = 0 if loss_emo else 0 211 | else: 212 | engine.state.metrics["total_loss"] += loss.item() 213 | engine.state.metrics["ent_loss"] += loss_ent.item() 214 | engine.state.metrics["emo_loss"] += loss_emo.item() if act_logits_emo.size(0) > 0 else 0 215 | # act_logits = torch.argmax(torch.softmax(act_logits, dim=-1), dim=-1) # [-1, 1] 216 | # loss = loss.mean() 217 | return loss.item(), act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids # [-1, 11] 218 | 219 | trainer = Engine(step) 220 | trn_evaluator = Engine(infer) 221 | val_evaluator = Engine(infer) 222 | 223 | ############################## Custom Period Event ################################### 224 | cpe1 = CustomPeriodicEvent(n_epochs=1) 225 | cpe1.attach(trainer) 226 | cpe2 = CustomPeriodicEvent(n_epochs=2) 227 | cpe2.attach(trainer) 228 | cpe3 = CustomPeriodicEvent(n_epochs=3) 229 | 
cpe3.attach(trainer) 230 | cpe5 = CustomPeriodicEvent(n_epochs=5) 231 | cpe5.attach(trainer) 232 | 233 | ############################## My F1 ################################### 234 | F1 = FScore(output_transform=lambda x: [x[1], x[2], x[3], x[4], x[-1]], lbl_method="BIEOS") 235 | F1.attach(val_evaluator, "F1") 236 | 237 | ##################################### progress bar ######################### 238 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'batch_loss') 239 | pbar = ProgressBar(persist=True) 240 | pbar.attach(trainer, metric_names=["batch_loss"]) 241 | 242 | ##################################### Evaluate ######################### 243 | 244 | @trainer.on(Events.EPOCH_COMPLETED) 245 | def compute_val_metric(engine): 246 | # trainer engine 247 | engine.state.metrics["total_loss"] /= engine.state.iteration 248 | engine.state.metrics["ent_loss"] /= engine.state.iteration 249 | engine.state.metrics["emo_loss"] /= engine.state.iteration 250 | pbar.log_message( 251 | "Training - total_loss: {:.4f} ent_loss: {:.4f} emo_loss: {:.4f}".format(engine.state.metrics["total_loss"], 252 | engine.state.metrics["ent_loss"], 253 | engine.state.metrics["emo_loss"])) 254 | 255 | val_evaluator.run(val_dataloader) 256 | 257 | metrics = val_evaluator.state.metrics 258 | ent_loss = metrics["ent_loss"] 259 | emo_loss = metrics["emo_loss"] 260 | f1 = metrics['F1'] 261 | pbar.log_message( 262 | "Validation Results - Epoch: {} Ent_loss: {:.4f}, Emo_loss: {:.4f}, F1: {:.4f}" 263 | .format(engine.state.epoch, ent_loss, emo_loss, f1)) 264 | 265 | pbar.n = pbar.last_print_n = 0 266 | 267 | @val_evaluator.on(Events.EPOCH_COMPLETED) 268 | def reduct_step(engine): 269 | engine.state.metrics["total_loss"] /= engine.state.iteration 270 | engine.state.metrics["ent_loss"] /= engine.state.iteration 271 | engine.state.metrics["emo_loss"] /= engine.state.iteration 272 | pbar.log_message( 273 | "Validation - total_loss: {:.4f} ent_loss: {:.4f} emo_loss: {:.4f}".format( 274 | engine.state.metrics["total_loss"], 275 | engine.state.metrics["ent_loss"], 276 | engine.state.metrics["emo_loss"])) 277 | # Save a trained model and the associated configuration 278 | # model_to_save = model.module if hasattr(model, 279 | # 'module') else model # Only save the model it-self 280 | # output_model_file = f"../ckps/cv/cv{cv}.pth" 281 | # torch.save(model.state_dict(), output_model_file) 282 | # print(f"save {output_model_file} successfully!") 283 | 284 | ###################################################################### 285 | 286 | ############################## checkpoint ################################### 287 | def best_f1(engine): 288 | f1 = engine.state.metrics["F1"] 289 | # loss = engine.state.metrics["loss"] 290 | return f1 291 | 292 | if not args.lite: 293 | ckp_dir = os.path.join(args.checkpoint_model_dir, "full", "cv", str(cv), args.hyper_cfg) 294 | else: 295 | ckp_dir = os.path.join(args.checkpoint_model_dir, "lite", "cv", str(cv), args.hyper_cfg) 296 | 297 | checkpoint_handler = ModelCheckpoint(ckp_dir, 298 | 'ckp', 299 | # save_interval=args.checkpoint_interval, 300 | score_function=best_f1, 301 | score_name="F1", 302 | n_saved=5, 303 | require_empty=False, create_dir=True) 304 | 305 | # trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, 306 | # to_save={'model_3FC': model}) 307 | 308 | val_evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, 309 | to_save={'model_title': model}) 310 | 311 | 
###################################################################### 312 | 313 | ############################## earlystopping ################################### 314 | stopping_handler = EarlyStopping(patience=2, score_function=best_f1, trainer=trainer) 315 | val_evaluator.add_event_handler(Events.COMPLETED, stopping_handler) 316 | 317 | ###################################################################### 318 | 319 | #################################### tb logger ################################## 320 | # 在已经在对应基础上计算了 metric 的值 (compute_metric) 后 取值 log 321 | if not args.lite: 322 | tb_logger = TensorboardLogger(log_dir=os.path.join(args.log_dir, "full", "cv", str(cv), args.hyper_cfg)) 323 | else: 324 | tb_logger = TensorboardLogger(log_dir=os.path.join(args.log_dir, "lite", "cv", str(cv), args.hyper_cfg)) 325 | 326 | tb_logger.attach(trainer, 327 | log_handler=OutputHandler(tag="training", output_transform=lambda x: {'batchloss': x[0]}), 328 | event_name=Events.ITERATION_COMPLETED) 329 | 330 | tb_logger.attach(val_evaluator, 331 | log_handler=OutputHandler(tag="validation", output_transform=lambda x: {'batchloss': x[0]}), 332 | event_name=Events.ITERATION_COMPLETED) 333 | 334 | tb_logger.attach(trainer, 335 | log_handler=OutputHandler(tag="training", metric_names=["total_loss", "ent_loss", "emo_loss"]), 336 | event_name=Events.EPOCH_COMPLETED) 337 | # tb_logger.attach(trainer, 338 | # log_handler=OutputHandler(tag="training", output_transform=lambda x: {'loss': x[0]}), 339 | # event_name=Events.EPOCH_COMPLETED) 340 | 341 | ''' 342 | tb_logger.attach(trn_evaluator, 343 | log_handler=OutputHandler(tag="training", 344 | metric_names=["F1"], 345 | another_engine=trainer), 346 | event_name=Events.EPOCH_COMPLETED) 347 | 348 | ''' 349 | tb_logger.attach(val_evaluator, 350 | log_handler=OutputHandler(tag="validation", 351 | metric_names=["total_loss", "ent_loss", "emo_loss", "F1"], 352 | another_engine=trainer), 353 | event_name=Events.EPOCH_COMPLETED) 354 | 355 | tb_logger.attach(trainer, 356 | log_handler=OptimizerParamsHandler(optimizer, "lr"), 357 | event_name=Events.EPOCH_COMPLETED) 358 | ''' 359 | 360 | tb_logger.attach(trainer, 361 | log_handler=WeightsScalarHandler(model), 362 | event_name=Events.ITERATION_COMPLETED) 363 | 364 | tb_logger.attach(trainer, 365 | log_handler=WeightsHistHandler(model), 366 | event_name=Events.EPOCH_COMPLETED) 367 | 368 | # tb_logger.attach(trainer, 369 | # log_handler=GradsScalarHandler(model), 370 | # event_name=Events.ITERATION_COMPLETED) 371 | 372 | tb_logger.attach(trainer, 373 | log_handler=GradsHistHandler(model), 374 | event_name=Events.EPOCH_COMPLETED) 375 | ''' 376 | 377 | # lr_find() 378 | trainer.run(trn_dataloader, max_epochs=args.epochs) 379 | tb_logger.close() 380 | 381 | 382 | if __name__ == '__main__': 383 | parser = ArgumentParser() 384 | parser.add_argument("--bert_model", default="../bert_pretrained/bert-base-chinese", type=str, 385 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 386 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 387 | "bert-base-multilingual-cased, bert-base-chinese.") 388 | parser.add_argument('--batch_size', type=int, default=48, 389 | help='input batch size for training8 (default: 64)') 390 | parser.add_argument('--val_batch_size', type=int, default=48, 391 | help='input batch size for validation (default: 1000)') 392 | parser.add_argument('--epochs', type=int, default=1, 393 | help='number of epochs to train (default: 10)') 394 | 
parser.add_argument('--nw', type=int, default=4, 395 | help='number of nw') 396 | parser.add_argument('--lr', type=float, default=3e-5, 397 | help='learning rate (default: 0.01)') 398 | parser.add_argument("--alpha", type=float, default=3, help="alpha") 399 | parser.add_argument("--wd", type=float, default=0.1, help="weight decay") 400 | parser.add_argument('--momentum', type=float, default=0.5, 401 | help='SGD momentum (default: 0.5)') 402 | parser.add_argument('--log_interval', type=int, default=10, 403 | help='how many batches to wait before logging training status') 404 | parser.add_argument("--log_dir", type=str, default="../tbs", 405 | help="log directory for Tensorboard log output") 406 | parser.add_argument("--checkpoint_model_dir", type=str, default='../ckps', 407 | help="path to folder where checkpoints of trained models will be saved") 408 | parser.add_argument("--ckp", type=str, default='None', 409 | help="ckp file") 410 | parser.add_argument("--hyper_cfg", type=str, default='default', 411 | help="config path to folder where checkpoints of trained models will be saved") 412 | parser.add_argument("--checkpoint_interval", type=int, default=1, 413 | help="number of batches after which a checkpoint of trained model will be created") 414 | parser.add_argument("--warmup_proportion", 415 | default=0.1, 416 | type=float, 417 | help="Proportion of training to perform linear learning rate warmup for. " 418 | "E.g., 0.1 = 10%% of training.") 419 | parser.add_argument("--dp", 420 | default=0.5, 421 | type=float, 422 | help="") 423 | parser.add_argument("--gamma", 424 | default=2, 425 | type=float, 426 | help="") 427 | parser.add_argument("--focal", 428 | action="store_true", 429 | help="") 430 | parser.add_argument("--lite", 431 | action="store_true", 432 | help="") 433 | parser.add_argument("--multi", 434 | action="store_true", 435 | help="multi alpha or not") 436 | parser.add_argument("--lbl_method", 437 | type=str, 438 | default="BIO", 439 | help="BIO / BIEO") 440 | 441 | args = parser.parse_args() 442 | 443 | # 5 fold 444 | dataset = get_all_data() 445 | for cv in range(1, 6): 446 | train(dataset, cv) 447 | -------------------------------------------------------------------------------- /src/trainx.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | from argparse import ArgumentParser 4 | import torch 5 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear 6 | from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator 7 | from ignite.handlers import ModelCheckpoint, EarlyStopping 8 | from ignite.contrib.handlers import CustomPeriodicEvent 9 | from ignite.contrib.handlers.tqdm_logger import ProgressBar 10 | # from torch.optim.lr_scheduler import ReduceLROnPlateau 11 | # from ignite.contrib.handlers.param_scheduler import (ConcatScheduler, 12 | # CosineAnnealingScheduler, 13 | # LinearCyclicalScheduler, 14 | # CyclicalScheduler, 15 | # LRScheduler) 16 | # from torch.optim.lr_scheduler import ExponentialLR 17 | 18 | from ignite.metrics import RunningAverage 19 | from ignite.contrib.handlers.tensorboard_logger import * 20 | from metric import FScore 21 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset 22 | from models import NetX3 23 | from loss import FocalLoss 24 | 25 | 26 | # torch.random.manual_seed(42) 27 | 28 | 29 | def get_data_loader(): 30 | print("get data loader...........") 31 | # 
-------------------------------read from h5------------------------- 32 | if not args.lite: 33 | f = h5py.File("../datasets/full.h5") 34 | else: 35 | f = h5py.File("../datasets/lite.h5") 36 | input_ids = torch.from_numpy(f["train/input_ids"][()]) 37 | myinput_ids = torch.from_numpy(f["train/myinput_ids"][()]) 38 | input_mask = torch.from_numpy(f["train/input_mask"][()]) 39 | segment_ids = torch.from_numpy(f["train/segment_ids"][()]) 40 | label_ent_ids = torch.from_numpy(f["train/label_ent_ids"][()]) 41 | label_emo_ids = torch.from_numpy(f["train/label_emo_ids"][()]) 42 | assert input_ids.size() == segment_ids.size() == label_ent_ids.size() == label_emo_ids.size() == myinput_ids.size() 43 | print("train dataset num: ", input_ids.size(0)) 44 | trn_dataset = TensorDataset(input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids) 45 | 46 | input_ids = torch.from_numpy(f["val/input_ids"][()]) 47 | myinput_ids = torch.from_numpy(f["val/myinput_ids"][()]) 48 | input_mask = torch.from_numpy(f["val/input_mask"][()]) 49 | segment_ids = torch.from_numpy(f["val/segment_ids"][()]) 50 | label_ent_ids = torch.from_numpy(f["val/label_ent_ids"][()]) 51 | label_emo_ids = torch.from_numpy(f["val/label_emo_ids"][()]) 52 | assert input_ids.size() == segment_ids.size() == label_ent_ids.size() == label_emo_ids.size() == myinput_ids.size() 53 | val_dataset = TensorDataset(input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids) 54 | print("val dataset num: ", input_ids.size(0)) 55 | f.close() 56 | print("read h5 over!") 57 | print(f"bs: {args.batch_size}\nval_bs: {args.val_batch_size}") 58 | # ------------------------------- 可以从 H5Dataset2 继承优化 ------------------------- 59 | # trn_dataset = H5Dataset3("../datasets/data_full_title.h5", "train") 60 | # val_dataset = H5Dataset3("../datasets/data_full_title.h5", "val") 61 | 62 | # --------------------------------------------------------------------- 63 | trn_dataloader = DataLoader(trn_dataset, sampler=RandomSampler(trn_dataset), batch_size=args.batch_size, 64 | num_workers=args.nw, pin_memory=True) 65 | # trn_dataloader = DataLoader(trn_dataset, sampler=RandomSampler(trn_dataset), batch_size=args.batch_size, 66 | # pin_memory=True) 67 | val_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=args.val_batch_size, 68 | num_workers=args.nw, pin_memory=True) 69 | # val_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=args.val_batch_size, 70 | # pin_memory=True) 71 | print("get date loader over!") 72 | return trn_dataloader, val_dataloader, len(trn_dataset) 73 | 74 | 75 | def train(): 76 | ################################ Model Config ################################### 77 | num_labels_emo = 4 # O POS NEG NORM 78 | num_labels_ent = 3 # O B I 79 | 80 | model = NetX3.from_pretrained(args.bert_model, 81 | cache_dir="", 82 | num_labels_ent=num_labels_ent, 83 | num_labels_emo=num_labels_emo, 84 | dp=args.dp) 85 | 86 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 87 | model.to(device) 88 | model = torch.nn.DataParallel(model) 89 | 90 | ################################# hyper parameters ########################### 91 | # alpha = 0.5 # 0.44 92 | # alpha = 0.6 # 0.42 93 | # alpha = 0.7 94 | # alpha = 1.2 95 | # alpha = 0.8 96 | # alpha = 0.7 97 | # alphas = [1, 0.9, 0.8, 0.8, 0.8, 0.8, 0.8] 98 | if not args.multi: 99 | alpha = args.alpha 100 | else: 101 | #alphas = [1e-5, 1e-2, 1e-1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 102 | alphas = [1e-1, 
1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 103 | print("alphas: ", " ".join(map(str, alphas))) 104 | # alphas = [2,1,1,0.8,0.8,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5] 105 | # ------------------------------ load model from file ------------------------- 106 | model_file = os.path.join(args.checkpoint_model_dir, args.ckp) 107 | if os.path.exists(model_file): 108 | model.load_state_dict(torch.load(model_file)) 109 | print("load checkpoint: {} successfully!".format(model_file)) 110 | # ----------------------------------------------------------------------------- 111 | 112 | trn_dataloader, val_dataloader, trn_size = get_data_loader() 113 | 114 | ############################## Optimizer ################################### 115 | param_optimizer = list(model.named_parameters()) 116 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 117 | optimizer_grouped_parameters = [ 118 | {'params': [p for n, p in param_optimizer if (not any(nd in n for nd in no_decay)) and p.requires_grad], 119 | 'weight_decay': args.wd}, 120 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) and p.requires_grad], 121 | 'weight_decay': 0.0} 122 | ] 123 | # num_train_optimization_steps = int( trn_size / args.batch_size / args.gradient_accumulation_steps) * args.epochs 124 | num_train_optimization_steps = int(trn_size / args.batch_size) * args.epochs + 5 125 | # num_train_optimization_steps = int(trn_size / args.batch_size) * args.epochs 126 | optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, 127 | warmup=args.warmup_proportion, 128 | t_total=num_train_optimization_steps) 129 | # optimizer = Adam(filter(lambda p:p.requires_grad, model.parameters()), args.lr, weight_decay=5e-3) 130 | ###################################################################### 131 | if not args.focal: 132 | criterion = torch.nn.CrossEntropyLoss() 133 | else: 134 | criterion = FocalLoss(args.gamma) 135 | 136 | 137 | def step_ent(engine, batch): 138 | model.train() 139 | batch = tuple(t.to(device) for t in batch) 140 | input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids = batch 141 | 142 | optimizer.zero_grad() 143 | act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, mask_logits_emo, mask_y_emo, act_myinput_ids = model( 144 | input_ids, myinput_ids, segment_ids, input_mask, 145 | label_ent_ids, label_emo_ids) 146 | # Only keep active parts of the loss 147 | loss_ent = criterion(act_logits_ent, act_y_ent) 148 | 149 | loss = loss_ent 150 | 151 | if engine.state.metrics.get("ent_loss") is None: 152 | engine.state.metrics["ent_loss"] = loss_ent.item() 153 | else: 154 | engine.state.metrics["ent_loss"] += loss_ent.item() 155 | 156 | loss.backward() 157 | optimizer.step() 158 | return loss.item(), act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids # [-1, 11] 159 | 160 | def step_emo(engine, batch): 161 | model.train() 162 | batch = tuple(t.to(device) for t in batch) 163 | input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids = batch 164 | 165 | optimizer.zero_grad() 166 | act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, mask_logits_emo, mask_y_emo, act_myinput_ids = model( 167 | input_ids, myinput_ids, segment_ids, input_mask, 168 | label_ent_ids, label_emo_ids) 169 | # Only keep active parts of the loss 170 | loss_emo = criterion(act_logits_emo, act_y_emo) 171 | loss = loss_emo 172 | 173 | if engine.state.metrics.get("emo_loss") is None: 174 | engine.state.metrics["emo_loss"] = loss_emo.item() 175 | else: 176 | 
engine.state.metrics["emo_loss"] += loss_emo.item() 177 | 178 | loss.backward() 179 | optimizer.step() 180 | return loss.item(), act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids # [-1, 11] 181 | 182 | 183 | 184 | def infer_ent(engine, batch): 185 | model.eval() 186 | batch = tuple(t.to(device) for t in batch) 187 | input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids = batch 188 | 189 | with torch.no_grad(): 190 | act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, mask_logits_emo, mask_y_emo, act_myinput_ids = model( 191 | input_ids, myinput_ids, segment_ids, 192 | input_mask, 193 | label_ent_ids, label_emo_ids) 194 | # Only keep active parts of the loss 195 | loss_ent = criterion(act_logits_ent, act_y_ent) 196 | loss = loss_ent 197 | 198 | if engine.state.metrics.get("ent_loss") is None: 199 | engine.state.metrics["ent_loss"] = loss.item() 200 | else: 201 | engine.state.metrics["ent_loss"] += loss_ent.item() 202 | return loss.item(), act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, mask_logits_emo, mask_y_emo, act_myinput_ids # [-1, 11] 203 | 204 | def infer_emo(engine, batch): 205 | model.eval() 206 | batch = tuple(t.to(device) for t in batch) 207 | input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids = batch 208 | 209 | with torch.no_grad(): 210 | act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, mask_logits_emo, mask_y_emo, act_myinput_ids = model( 211 | input_ids, myinput_ids, segment_ids, 212 | input_mask, 213 | label_ent_ids, label_emo_ids) 214 | # Only keep active parts of the loss 215 | loss_emo = criterion(act_logits_emo, act_y_emo) 216 | loss = loss_emo 217 | 218 | if engine.state.metrics.get("emo_loss") is None: 219 | engine.state.metrics["emo_loss"] = loss_emo.item() 220 | else: 221 | engine.state.metrics["emo_loss"] += loss_emo.item() 222 | return loss.item(), act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, mask_logits_emo, mask_y_emo, act_myinput_ids # [-1, 11] 223 | 224 | 225 | trainer_ent = Engine(step_ent) 226 | trainer_emo = Engine(step_emo) 227 | 228 | # trn_evaluator = Engine(infer) 229 | val_evaluator_ent = Engine(infer_ent) 230 | val_evaluator_emo = Engine(infer_emo) 231 | 232 | 233 | ############################## My F1 ################################### 234 | F1 = FScore(output_transform=lambda x: [x[1], x[2], x[3], x[4], x[-1]]) 235 | F1.attach(val_evaluator_ent, "F1") 236 | F1.attach(val_evaluator_emo, "F1") 237 | 238 | ##################################### progress bar ######################### 239 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer_ent, 'batch_loss') 240 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer_emo, 'batch_loss') 241 | pbar = ProgressBar(persist=True) 242 | pbar.attach(trainer_ent, metric_names=["batch_loss"]) 243 | pbar.attach(trainer_emo, metric_names=["batch_loss"]) 244 | 245 | ##################################### Evaluate ######################### 246 | 247 | @trainer_ent.on(Events.EPOCH_COMPLETED) 248 | def compute_val_metric_ent(engine): 249 | # trainer engine 250 | engine.state.metrics["ent_loss"] /= engine.state.iteration 251 | pbar.log_message( 252 | "Training - ent_loss: {:.4f}".format(engine.state.metrics["ent_loss"] )) 253 | 254 | val_evaluator_ent.run(val_dataloader) 255 | 256 | metrics = val_evaluator_ent.state.metrics 257 | ent_loss = metrics["ent_loss"] 258 | f1 = metrics['F1'] 259 | pbar.log_message( 260 | "Validation Results - Epoch: {} Ent_loss: {:.4f}, F1: {:.4f}" 261 | .format(engine.state.epoch, 
ent_loss, f1)) 262 | 263 | pbar.n = pbar.last_print_n = 0 264 | 265 | @trainer_emo.on(Events.EPOCH_COMPLETED) 266 | def compute_val_metric_emo(engine): 267 | # trainer engine 268 | engine.state.metrics["emo_loss"] /= engine.state.iteration 269 | pbar.log_message( 270 | "Training - emo_loss: {:.4f}".format(engine.state.metrics["emo_loss"] )) 271 | 272 | val_evaluator_emo.run(val_dataloader) 273 | 274 | metrics = val_evaluator_emo.state.metrics 275 | emo_loss = metrics["emo_loss"] 276 | f1 = metrics['F1'] 277 | pbar.log_message( 278 | "Validation Results - Epoch: {} , Emo_loss: {:.4f}, F1: {:.4f}" 279 | .format(engine.state.epoch, emo_loss, f1)) 280 | 281 | pbar.n = pbar.last_print_n = 0 282 | 283 | @val_evaluator_ent.on(Events.EPOCH_COMPLETED) 284 | def reduct_step_ent(engine): 285 | engine.state.metrics["ent_loss"] /= engine.state.iteration 286 | pbar.log_message( 287 | "Validation - ent_loss: {:.4f}".format( 288 | engine.state.metrics["ent_loss"])) 289 | 290 | @val_evaluator_emo.on(Events.EPOCH_COMPLETED) 291 | def reduct_step_emo(engine): 292 | engine.state.metrics["emo_loss"] /= engine.state.iteration 293 | pbar.log_message( 294 | "Validation - emo_loss: {:.4f}".format( 295 | engine.state.metrics["emo_loss"])) 296 | 297 | 298 | ###################################################################### 299 | 300 | ############################## checkpoint ################################### 301 | def best_f1(engine): 302 | f1 = engine.state.metrics["F1"] 303 | # loss = engine.state.metrics["loss"] 304 | return f1 305 | 306 | if not args.lite: 307 | ckp_dir = os.path.join(args.checkpoint_model_dir, "fullx", args.hyper_cfg) 308 | else: 309 | ckp_dir = os.path.join(args.checkpoint_model_dir, "litex", args.hyper_cfg) 310 | checkpoint_handler = ModelCheckpoint(ckp_dir, 311 | 'ckp', 312 | # save_interval=args.checkpoint_interval, 313 | score_function=best_f1, 314 | score_name="F1", 315 | n_saved=5, 316 | require_empty=False, create_dir=True) 317 | 318 | # trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, 319 | # to_save={'model_3FC': model}) 320 | # val_evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, 321 | # to_save={'model_title': model}) 322 | 323 | ###################################################################### 324 | 325 | ############################## earlystopping ################################### 326 | # stopping_handler = EarlyStopping(patience=2, score_function=best_f1, trainer=trainer) 327 | # val_evaluator.add_event_handler(Events.COMPLETED, stopping_handler) 328 | 329 | ###################################################################### 330 | 331 | #################################### tb logger ################################## 332 | # 在已经在对应基础上计算了 metric 的值 (compute_metric) 后 取值 log 333 | if not args.lite: 334 | tb_logger = TensorboardLogger(log_dir=os.path.join(args.log_dir, "fullx", args.hyper_cfg)) 335 | else: 336 | tb_logger = TensorboardLogger(log_dir=os.path.join(args.log_dir, "litex", args.hyper_cfg)) 337 | 338 | tb_logger.attach(trainer_ent, 339 | log_handler=OutputHandler(tag="training_ent", output_transform=lambda x: {'batchloss': x[0]}), 340 | event_name=Events.ITERATION_COMPLETED) 341 | 342 | tb_logger.attach(trainer_emo, 343 | log_handler=OutputHandler(tag="training_emo", output_transform=lambda x: {'batchloss': x[0]}), 344 | event_name=Events.ITERATION_COMPLETED) 345 | 346 | 347 | tb_logger.attach(val_evaluator_ent, 348 | log_handler=OutputHandler(tag="validation_ent", 
output_transform=lambda x: {'batchloss': x[0]}), 349 | event_name=Events.ITERATION_COMPLETED) 350 | 351 | tb_logger.attach(val_evaluator_emo, 352 | log_handler=OutputHandler(tag="validation_emo", output_transform=lambda x: {'batchloss': x[0]}), 353 | event_name=Events.ITERATION_COMPLETED) 354 | # tb_logger.attach(trainer, 355 | # log_handler=OutputHandler(tag="training", output_transform=lambda x: {'loss': x[0]}), 356 | # event_name=Events.EPOCH_COMPLETED) 357 | 358 | ''' 359 | tb_logger.attach(trn_evaluator, 360 | log_handler=OutputHandler(tag="training", 361 | metric_names=["F1"], 362 | another_engine=trainer), 363 | event_name=Events.EPOCH_COMPLETED) 364 | 365 | 366 | tb_logger.attach(trainer, 367 | log_handler=OutputHandler(tag="training", output_transform=lambda x: {'batchloss_ent': x[-2]}), 368 | event_name=Events.ITERATION_COMPLETED) 369 | tb_logger.attach(trainer, 370 | log_handler=OutputHandler(tag="training", output_transform=lambda x: {'batchloss_emo': x[-1]}), 371 | event_name=Events.ITERATION_COMPLETED) 372 | 373 | tb_logger.attach(val_evaluator, 374 | log_handler=OutputHandler(tag="validation", output_transform=lambda x: {'batchloss': x[0]}), 375 | event_name=Events.ITERATION_COMPLETED) 376 | 377 | tb_logger.attach(val_evaluator, 378 | log_handler=OutputHandler(tag="validation", output_transform=lambda x: {'batchloss_ent': x[-2]}), 379 | event_name=Events.ITERATION_COMPLETED) 380 | tb_logger.attach(val_evaluator, 381 | log_handler=OutputHandler(tag="validation", output_transform=lambda x: {'batchloss': x[-1]}), 382 | event_name=Events.ITERATION_COMPLETED) 383 | tb_logger.attach(trainer, 384 | log_handler=OutputHandler(tag="training", metric_names=["total_loss", "ent_loss", "emo_loss"]), 385 | event_name=Events.EPOCH_COMPLETED) 386 | # tb_logger.attach(trainer, 387 | # log_handler=OutputHandler(tag="training", output_transform=lambda x: {'loss': x[0]}), 388 | # event_name=Events.EPOCH_COMPLETED) 389 | 390 | ''' 391 | 392 | tb_logger.attach(val_evaluator_ent, 393 | log_handler=OutputHandler(tag="validation_ent", 394 | metric_names=["ent_loss", "F1"], 395 | another_engine=trainer_ent), 396 | event_name=Events.EPOCH_COMPLETED) 397 | 398 | tb_logger.attach(val_evaluator_emo, 399 | log_handler=OutputHandler(tag="validation_emo", 400 | metric_names=["emo_loss", "F1"], 401 | another_engine=trainer_emo), 402 | event_name=Events.EPOCH_COMPLETED) 403 | 404 | tb_logger.attach(trainer_ent, 405 | log_handler=OptimizerParamsHandler(optimizer, "lr"), 406 | event_name=Events.EPOCH_COMPLETED) 407 | tb_logger.attach(trainer_emo, 408 | log_handler=OptimizerParamsHandler(optimizer, "lr"), 409 | event_name=Events.EPOCH_COMPLETED) 410 | ''' 411 | 412 | 413 | # tb_logger.attach(trainer, 414 | # log_handler=GradsScalarHandler(model), 415 | # event_name=Events.ITERATION_COMPLETED) 416 | 417 | tb_logger.attach(trainer, 418 | log_handler=GradsHistHandler(model), 419 | event_name=Events.EPOCH_COMPLETED) 420 | ''' 421 | 422 | # lr_find() 423 | trainer_ent.run(trn_dataloader, max_epochs=args.epochs) 424 | trainer_emo.run(trn_dataloader, max_epochs=args.epochs) 425 | tb_logger.close() 426 | 427 | 428 | if __name__ == '__main__': 429 | parser = ArgumentParser() 430 | parser.add_argument("--bert_model", default="../bert_pretrained/bert-base-chinese", type=str, 431 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 432 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 433 | "bert-base-multilingual-cased, bert-base-chinese.") 434 | 
parser.add_argument('--batch_size', type=int, default=48, 435 | help='input batch size for training8 (default: 64)') 436 | parser.add_argument('--val_batch_size', type=int, default=48, 437 | help='input batch size for validation (default: 1000)') 438 | parser.add_argument('--epochs', type=int, default=3, 439 | help='number of epochs to train (default: 10)') 440 | parser.add_argument('--nw', type=int, default=4, 441 | help='number of nw') 442 | parser.add_argument('--lr', type=float, default=3e-5, 443 | help='learning rate (default: 0.01)') 444 | parser.add_argument("--alpha", type=float, default=1, help="alpha") 445 | parser.add_argument("--wd", type=float, default=0.01, help="weight decay") 446 | parser.add_argument('--momentum', type=float, default=0.5, 447 | help='SGD momentum (default: 0.5)') 448 | parser.add_argument('--log_interval', type=int, default=10, 449 | help='how many batches to wait before logging training status') 450 | parser.add_argument("--log_dir", type=str, default="../tbs", 451 | help="log directory for Tensorboard log output") 452 | parser.add_argument("--checkpoint_model_dir", type=str, default='../ckps', 453 | help="path to folder where checkpoints of trained models will be saved") 454 | parser.add_argument("--ckp", type=str, default='None', 455 | help="ckp file") 456 | parser.add_argument("--hyper_cfg", type=str, default='default', 457 | help="config path to folder where checkpoints of trained models will be saved") 458 | parser.add_argument("--checkpoint_interval", type=int, default=1, 459 | help="number of batches after which a checkpoint of trained model will be created") 460 | parser.add_argument("--warmup_proportion", 461 | default=0.1, 462 | type=float, 463 | help="Proportion of training to perform linear learning rate warmup for. 
" 464 | "E.g., 0.1 = 10%% of training.") 465 | parser.add_argument("--dp", 466 | default=0.1, 467 | type=float, 468 | help="") 469 | parser.add_argument("--gamma", 470 | default=2, 471 | type=float, 472 | help="") 473 | parser.add_argument("--focal", 474 | action="store_true", 475 | help="") 476 | parser.add_argument("--lite", 477 | action="store_true", 478 | help="") 479 | parser.add_argument("--multi", 480 | action="store_true", 481 | help="multi alpha or not") 482 | 483 | args = parser.parse_args() 484 | 485 | train() 486 | -------------------------------------------------------------------------------- /src/data_title_trnval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import logging 5 | import os 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | from pytorch_pretrained_bert.tokenization import BertTokenizer 11 | import h5py 12 | import gc 13 | from collections import OrderedDict 14 | import re 15 | from utils import covert_mytokens_to_myids, load_data, data_dump 16 | 17 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 18 | datefmt='%m/%d/%Y %H:%M:%S', 19 | level=logging.INFO) 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class InputExample(object): 24 | """A single training/test example for simple sequence classification.""" 25 | 26 | def __init__(self, guid, text, text_title=None, label_sent=None, label_title=None): 27 | ''' 28 | :param guid: example id, may be is newId 29 | :param text: cur sentence text 30 | :param text_title: the news title text 31 | :parm label: label text [B-POS, I-POS, I-POS] 32 | 在 fetures 中 label => 33 | label_ent: entity label [0,1,2...10] 34 | label_emo: emotion label [0,1,2...7] can transform from the label_ent 35 | 情感的标记到底应该时什么 ? 
[pos pos pos] or [B-pos, I-pos, I-pos] 36 | ''' 37 | self.guid = guid 38 | self.text = text 39 | self.text_title = text_title 40 | self.label_text = label_sent 41 | self.label_title = label_title 42 | 43 | 44 | class InputFeatures(object): 45 | """A single set of features of data.""" 46 | 47 | def __init__(self, input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids): 48 | self.input_ids = input_ids 49 | self.myinput_ids = myinput_ids 50 | self.input_mask = input_mask 51 | self.segment_ids = segment_ids 52 | self.label_ent_ids = label_ent_ids 53 | self.label_emo_ids = label_emo_ids 54 | 55 | 56 | def readfile(filename): 57 | logger.info("read file:{}....".format(filename)) 58 | data = [] 59 | 60 | def _read(): 61 | f = open(filename, encoding="utf-8") 62 | title = None 63 | label_title = None 64 | sentence = [] 65 | label_sent = [] 66 | count = 0 67 | has_title = False 68 | for line in f: 69 | if len(line) == 0 or line[0] == "\n": 70 | if len(sentence) > 0: 71 | if not has_title: 72 | # the title 73 | title = sentence 74 | label_title = label_sent 75 | has_title = True 76 | data.append((sentence, title, label_sent, label_title)) 77 | # refresh 78 | sentence = [] 79 | label_sent = [] 80 | else: 81 | # 第二个空行 sentence 已空 82 | has_title = False 83 | count += 1 84 | if count % 50000 == 0: 85 | print(count) 86 | continue 87 | splits = line.split(' ') 88 | sentence.append(splits[0]) 89 | label_sent.append(splits[-1][:-1]) 90 | 91 | # 防止因最后一行非空行而没加入最后一个 92 | if len(sentence) > 0: 93 | if not has_title: 94 | # the title 95 | title = sentence 96 | label_title = label_sent 97 | has_title = True 98 | data.append((sentence, title, label_sent, label_title)) 99 | # refresh 100 | # sentence = [] 101 | # label_sent = [] 102 | f.close() 103 | 104 | _read() 105 | # 最后一句查看 106 | print("The Last Sentence....") 107 | print(data[-1][0]) 108 | print(data[-1][1]) 109 | print("sentence num: ", len(data)) 110 | return data 111 | 112 | 113 | class DataProcessor(object): 114 | """Base class for data converters for sequence classification data sets.""" 115 | 116 | def get_train_examples(self, data_dir, islite): 117 | """Gets a collection of `InputExample`s for the train set.""" 118 | raise NotImplementedError() 119 | 120 | def get_dev_examples(self, data_dir, islite): 121 | """Gets a collection of `InputExample`s for the dev set.""" 122 | raise NotImplementedError() 123 | 124 | def get_labels(self, label_method): 125 | """Gets the list of labels for this data set.""" 126 | raise NotImplementedError() 127 | 128 | @classmethod 129 | def _read_tsv(cls, input_file, quotechar=None): 130 | """Reads a tab separated value file.""" 131 | return readfile(input_file) 132 | 133 | 134 | class NerProcessor(DataProcessor): 135 | """Processor for the CoNLL-2003 data set.""" 136 | 137 | def get_train_examples(self, data_dir, islite): 138 | """See base class.""" 139 | if not islite: 140 | return self._create_examples( 141 | self._read_tsv(os.path.join(data_dir, "train.txt")), "train") 142 | else: 143 | return self._create_examples( 144 | self._read_tsv(os.path.join(data_dir, "lite_trn.txt")), "train") 145 | 146 | def get_dev_examples(self, data_dir, islite): 147 | """See base class.""" 148 | if not islite: 149 | return self._create_examples( 150 | self._read_tsv(os.path.join(data_dir, "val.txt")), "train") 151 | else: 152 | return self._create_examples( 153 | self._read_tsv(os.path.join(data_dir, "lite_val.txt")), "train") 154 | 155 | def get_labels(self, label_method): 156 | if label_method == "BIO": 157 | return 
["O", "B-POS", "I-POS", "B-NEG", "I-NEG", "B-NORM", "I-NORM", "[CLS]", "[SEP]"] 158 | else: 159 | # BIESO 160 | return ["O", "B-POS", "I-POS", "E-POS", "B-NEG", "I-NEG", "E-NEG", "B-NORM", "I-NORM", "E-NORM", 161 | "S-POS", "S-NEG", "S-NORM", 162 | "[CLS]", "[SEP]"] 163 | 164 | def _create_examples(self, lines, set_type): 165 | examples = [] 166 | for i, (sentence, title, label_sent, label_title) in enumerate(lines): 167 | guid = "%s-%s" % (set_type, i) 168 | text = ' '.join(sentence) 169 | text_title = ' '.join(title) 170 | examples.append(InputExample(guid=guid, text=text, text_title=text_title, label_sent=label_sent, 171 | label_title=label_title)) 172 | return examples 173 | 174 | 175 | def _truncate_seq_pair(tokens_a, tokens_b, mytokens_a, mytokens_b, labels_A, labels_B, max_length): 176 | """Truncates a sequence pair in place to the maximum length.""" 177 | 178 | # This is a simple heuristic which will always truncate the longer sequence 179 | # one token at a time. This makes more sense than truncating an equal percent 180 | # of tokens from each, since if one sequence is very short then each token 181 | # that's truncated likely contains more information than a longer sequence. 182 | while True: 183 | total_length = len(tokens_a) + len(tokens_b) 184 | if total_length <= max_length: 185 | break 186 | if len(tokens_a) > len(tokens_b): 187 | tokens_a.pop() 188 | mytokens_a.pop() 189 | labels_A.pop() 190 | else: 191 | tokens_b.pop() 192 | mytokens_b.pop() 193 | labels_B.pop() 194 | 195 | 196 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, TOK2ID, lbl_method): 197 | """Loads a data file into a list of `InputBatch`s.""" 198 | 199 | logger.info("gen features...") 200 | label_map = {label: i for i, label in enumerate(label_list, 1)} 201 | label_map["PAD"] = 0 202 | label_map_reversed = {v: k for k, v in label_map.items()} 203 | 204 | features = [] 205 | logger.info("prepare features.....") 206 | count = 0 207 | for (ex_index, example) in enumerate(examples): 208 | textlist_A = example.text.split(' ') 209 | labellist_A = example.label_text 210 | textlist_B = example.text_title.split(' ') 211 | labellist_B = example.label_title 212 | tokens_A = [] 213 | mytokens_A = [] 214 | labels_A = [] 215 | tokens_B = [] 216 | mytokens_B = [] 217 | labels_B = [] 218 | 219 | # if count == 529320: 220 | # pdb.set_trace() 221 | ####################### tokenize ########################## 222 | def _tokenize(textlist, labellist, tokens, mytokens, labels): 223 | ''' 224 | # 重新检查textlist 除去多个汉字并联 225 | new_textlist = [] 226 | for word in textlist: 227 | if len(word) == 1: 228 | new_textlist.append(word) 229 | else: 230 | new_word = seg_char(word) 231 | if len(new_word) > 1: 232 | print("!wrong seg:", new_word) 233 | new_textlist.extend(new_word) 234 | else: 235 | new_textlist.append(word) 236 | textlist = new_textlist 237 | ''' 238 | SAMESPLIT = False 239 | for i, word in enumerate(textlist): 240 | SAMESPLIT = False 241 | token = tokenizer.tokenize(word) 242 | mytoken = token.copy() 243 | # debug [UNK] 244 | # 检查 [UNK], 编入 自定义 字典 245 | 246 | sharp_index = [i for i, tok in enumerate(token) if tok.startswith("##")] 247 | for si in sharp_index: 248 | no_sharp_tok = mytoken[si][2:] 249 | if no_sharp_tok not in TOK2ID: 250 | TOK2ID[no_sharp_tok] = len(TOK2ID) 251 | mytoken[si] = no_sharp_tok 252 | if "[UNK]" in token: 253 | assert "##" not in word 254 | # ## 255 | unks_index = [i for i, tok in enumerate(mytoken) if tok == "[UNK]"] 256 | unks_index.reverse() 257 | not_unks = [tok 
                    not_unks = [tok for tok in mytoken if tok != "[UNK]"]
                    if not_unks:
                        not_unks = [re.escape(nu) for nu in not_unks]
                        if not (len(set(not_unks)) == 1 and len(not_unks) > 1):  # all-identical tokens would make the greedy pattern loop forever
                            pattern = "(.*)".join(not_unks)
                            pattern = re.compile("(.*)" + pattern + "(.*)")
                        else:
                            SAMESPLIT = True
                            print(word)
                            f = "([^{}]*)".format(not_unks[0])
                            pattern = f.join(not_unks)
                            pattern = re.compile(f + pattern + f)
                        for res in pattern.findall(word):
                            for r in res:
                                if len(r) > 0 and r != "\u202a":  # \u202a is an invisible direction-control character
                                    if r not in TOK2ID:
                                        TOK2ID[r] = len(TOK2ID)
                                    mytoken[unks_index[-1]] = r
                                    unks_index.pop()

                else:
                    # in theory this should be a single token (e.g. '4G'), but the tokenizer sometimes splits oddly (e.g. '不了')
                    # assert len(token) == 1
                    if len(token) == 1:
                        if word not in TOK2ID:
                            TOK2ID[word] = len(TOK2ID)
                        mytoken[0] = word
                    else:
                        # not handled any further, e.g. '不了'
                        mytoken = list(word)
                        print("BUG!:", word, token, mytoken)
                        for mytok in mytoken:
                            if mytok not in TOK2ID:
                                TOK2ID[mytok] = len(TOK2ID)
                if SAMESPLIT:
                    print(f"guid: {example.guid}, title: {example.text_title}", word, token, mytoken, sep="\t")
                tokens.extend(token)
                mytokens.extend(mytoken)
                assert len(tokens) == len(mytokens)
                label_1 = labellist[i]
                for m in range(len(token)):
                    # if m == 0:
                    # else:
                    #     labels.append("X")
                    if label_1 == "B-POS":
                        if m == 0:
                            labels.append(label_1)
                        else:
                            labels.append("I-POS")
                    elif label_1 == "B-NEG":
                        if m == 0:
                            labels.append(label_1)
                        else:
                            labels.append("I-NEG")
                    elif label_1 == "B-NORM":
                        if m == 0:
                            labels.append(label_1)
                        else:
                            labels.append("I-NORM")
                    else:
                        # remaining tags: O, I-*, E-*, S-*
                        if lbl_method == "BIO":
                            labels.append(label_1)
                        else:
                            if len(token) > 1 and label_1[0] == "E":
                                if label_1 == "E-POS":
                                    if m < len(token) - 1:
                                        labels.append("I-POS")
                                    else:
                                        labels.append("E-POS")
                                if label_1 == "E-NEG":
                                    if m < len(token) - 1:
                                        labels.append("I-NEG")
                                    else:
                                        labels.append("E-NEG")
                                if label_1 == "E-NORM":
                                    if m < len(token) - 1:
                                        labels.append("I-NORM")
                                    else:
                                        labels.append("E-NORM")
                            else:
                                labels.append(label_1)

        _tokenize(textlist_A, labellist_A, tokens_A, mytokens_A, labels_A)
        _tokenize(textlist_B, labellist_B, tokens_B, mytokens_B, labels_B)
        ######################################################

        ########## handle positions added by tokenization, then merge and truncate ##########
        # [CLS] A [SEP] B [SEP]
        _truncate_seq_pair(tokens_A, tokens_B, mytokens_A, mytokens_B, labels_A, labels_B, max_seq_length - 3)
        tokens = ["[CLS]"] + tokens_A + ["[SEP]"] + tokens_B + ["[SEP]"]
        mytokens = ["[CLS]"] + mytokens_A + ["[SEP]"] + mytokens_B + ["[SEP]"]
        assert len(tokens) == len(mytokens)
        labels = ["[CLS]"] + labels_A + ["[SEP]"] + labels_B + ["[SEP]"]
        segment_ids = [0] * (len(tokens_A) + 2) + [1] * (len(tokens_B) + 1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        myinput_ids = covert_mytokens_to_myids(TOK2ID, mytokens)
        input_mask = [1] * len(input_ids)
        label_ids = [label_map[lbl] for lbl in labels]
        # ------------------------------- PAD -----------------------------------
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        myinput_ids += padding
        input_mask += padding
        segment_ids += padding
        label_ids += padding
        # -----------------------------------------------------------------------
        ##################### split label_ids into two id spaces #####################
        # The emotion label ids still need experiments: B-POS/I-POS vs. plain POS/POS.
        # Raw label_map for BIO (enumerate from 1, PAD added as 0):
        #   PAD O B-POS I-POS B-NEG I-NEG B-NORM I-NORM [CLS] [SEP]
        #   0   1   2     3     4     5     6      7      8     9
        #
        # Emotion ids: O / PAD / [CLS] / [SEP] -> 0, POS -> 1, NEG -> 2, NORM -> 3
        # Entity ids (fewest classes): O / PAD / [CLS] / [SEP] -> 0, B -> 1, I -> 2
        #                              (BIESO additionally uses E -> 3, S -> 4)
        #
        # label_emo_ids = label_ids
        # Two possible emotion labellings: tag only the B position, or tag both B and I;
        # the latter is used here, as shown above.
        L = np.array(label_ids)

        if lbl_method == "BIO":
            L[L == 10] = 0
            L[L == 9] = 0
            L[L == 8] = 0
            L[L == 1] = 0

            L[L == 2] = 1
            L[L == 3] = 1
            L[L == 4] = 2
            L[L == 5] = 2
            L[L == 6] = 3
            L[L == 7] = 3
            label_emo_ids = L.tolist()
            # ------------------
            L = np.array(label_ids)
            '''
            labelling scheme that keeps PAD, X and [CLS] as separate classes (disabled)
            L[L==4]=2
            L[L==6]=2
            L[L==5]=3
            L[L==7]=3
            L[L==8]=4
            L[L==9]=5
            L[L==10]=6
            '''
            L[L == 1] = 0
            L[L == 2] = 1
            L[L == 4] = 1
            L[L == 6] = 1
            L[L == 3] = 2
            L[L == 5] = 2
            L[L == 7] = 2
            L[L == 8] = 0
            L[L == 9] = 0
            L[L == 10] = 0
            label_ent_ids = L.tolist()
        elif lbl_method in ("BIESO", "BIEOS"):  # accept both spellings used in the repo
            L[L == 14] = 0
            L[L == 15] = 0
            L[L == 1] = 0

            L[L == 2] = 1
            L[L == 3] = 1
            L[L == 4] = 1
            L[L == 11] = 1

            L[L == 5] = 2
            L[L == 6] = 2
            L[L == 7] = 2
            L[L == 12] = 2

            L[L == 8] = 3
            L[L == 9] = 3
            L[L == 10] = 3
            L[L == 13] = 3
            label_emo_ids = L.tolist()
            # ------------------
            L = np.array(label_ids)
            L[L == 14] = 0
            L[L == 15] = 0
            L[L == 1] = 0

            L[L == 2] = 1  # B
            L[L == 5] = 1  # B
            L[L == 8] = 1  # B

            L[L == 3] = 2  # I
            L[L == 6] = 2  # I
            L[L == 9] = 2  # I

            L[L == 4] = 3  # E
            L[L == 7] = 3  # E
            L[L == 10] = 3  # E

            L[L == 11] = 4  # S
            L[L == 12] = 4  # S
            L[L == 13] = 4  # S
            label_ent_ids = L.tolist()
        else:
            raise ValueError("unknown label_method: {}".format(lbl_method))

        #######################################################################

        assert len(input_ids) == len(myinput_ids)
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ent_ids) == max_seq_length
        assert len(label_emo_ids) == max_seq_length

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("tokens: %s" % " ".join(
                [str(x) for x in tokens]))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info(
                "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            logger.info(
                "label: %s" % " ".join([label_map_reversed[x] for x in label_ids]))
        features.append(
            InputFeatures(input_ids=input_ids,
                          myinput_ids=myinput_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_ent_ids=label_ent_ids,
                          label_emo_ids=label_emo_ids))
        count += 1
        if count % 20000 == 0:
            print("gen example {} feature".format(count))

    logger.info("finish features gen")

    return features

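
# --- Illustrative reference (not part of the pipeline): the in-place NumPy remapping in
# convert_examples_to_features is equivalent to looking each raw id up in two explicit
# tables. This is only a sketch for the BIO scheme and assumes the label_map built from
# get_labels("BIO") with enumerate(label_list, 1): PAD=0, O=1, B-POS=2, I-POS=3, B-NEG=4,
# I-NEG=5, B-NORM=6, I-NORM=7, [CLS]=8, [SEP]=9. The helper name is new here and nothing
# in the repo calls it.
def _bio_ids_to_ent_emo(label_ids):
    """Map raw BIO label ids to (label_ent_ids, label_emo_ids); documentation / sanity check only."""
    # entity space: 0 = O/PAD/[CLS]/[SEP], 1 = B, 2 = I
    ent_table = {0: 0, 1: 0, 2: 1, 3: 2, 4: 1, 5: 2, 6: 1, 7: 2, 8: 0, 9: 0}
    # emotion space: 0 = O/PAD/[CLS]/[SEP], 1 = POS, 2 = NEG, 3 = NORM
    emo_table = {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2, 6: 3, 7: 3, 8: 0, 9: 0}
    label_ent_ids = [ent_table[i] for i in label_ids]
    label_emo_ids = [emo_table[i] for i in label_ids]
    return label_ent_ids, label_emo_ids
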
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir",
                        default="../datasets",
                        type=str,
                        required=False,
                        help="The input data dir. Should contain the CoNLL-style .txt files "
                             "(train.txt, val.txt, lite_trn.txt, lite_val.txt) for the task.")

    parser.add_argument("--bert_token_model",
                        default="../bert_pretrained/bert_token_model",
                        type=str, required=False)
    parser.add_argument("--task_name",
                        default="ner",
                        type=str,
                        required=False,
                        help="The name of the task to train.")

    ## Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument("--lite",
                        action='store_true',
                        help="build the lite subset (lite_trn.txt / lite_val.txt -> lite.h5) instead of the full data")
    parser.add_argument("--label_method",
                        default="BIO",
                        type=str,
                        required=False,
                        help="labelling scheme: BIO or BIESO")
    args = parser.parse_args()
    logger.info("data_dir: {}".format(args.data_dir))

    processors = {"ner": NerProcessor}

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    task_name = args.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels(args.label_method)
    tokenizer = BertTokenizer.from_pretrained(args.bert_token_model, do_lower_case=args.do_lower_case)

    if not os.path.exists("../datasets/ID2TOK.pkl"):
        ID2TOK = tokenizer.ids_to_tokens
    else:
        ID2TOK = load_data("../datasets/ID2TOK.pkl")
    TOK2ID = OrderedDict((tok, id) for id, tok in ID2TOK.items())

    ################################# train data #################################
    train_examples = processor.get_train_examples(args.data_dir, args.lite)
    train_features = convert_examples_to_features(
        train_examples, label_list, args.max_seq_length, tokenizer, TOK2ID, args.label_method)
    logger.info("***** Train features *****")
    logger.info(" Num examples = %d", len(train_features))
    ################################ save2h5 ########################################
    if not args.lite:
        f = h5py.File("../datasets/full.h5", "w", libver="latest")
    else:
        f = h5py.File("../datasets/lite.h5", "w", libver="latest")
    # free memory!!
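    # The feature columns are written to HDF5 group by group (train/val) and column by
    # column, with del + gc.collect() in between, so peak memory stays bounded while the
    # gzip-compressed datasets are created. A hedged sketch of reading one split back
    # (the lite.h5 path below is just an example):
    #
    #   with h5py.File("../datasets/lite.h5", "r") as h5:
    #       input_ids = torch.tensor(h5["train/input_ids"][:], dtype=torch.long)
    #       label_ent_ids = torch.tensor(h5["train/label_ent_ids"][:], dtype=torch.long)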
    gc.collect()
    all_input_ids = torch.tensor([feat.input_ids for feat in train_features], dtype=torch.long)
    all_myinput_ids = torch.tensor([feat.myinput_ids for feat in train_features], dtype=torch.long)
    f.create_dataset("train/input_ids", data=all_input_ids, compression="gzip")
    f.create_dataset("train/myinput_ids", data=all_myinput_ids, compression="gzip")
    all_input_mask = torch.tensor([feat.input_mask for feat in train_features], dtype=torch.long)
    f.create_dataset("train/input_mask", data=all_input_mask, compression="gzip")
    del all_input_ids, all_myinput_ids, all_input_mask
    gc.collect()
    all_segment_ids = torch.tensor([feat.segment_ids for feat in train_features], dtype=torch.long)
    f.create_dataset("train/segment_ids", data=all_segment_ids, compression="gzip")
    all_label_ent_ids = torch.tensor([feat.label_ent_ids for feat in train_features], dtype=torch.long)
    f.create_dataset("train/label_ent_ids", data=all_label_ent_ids, compression="gzip")
    all_label_emo_ids = torch.tensor([feat.label_emo_ids for feat in train_features], dtype=torch.long)
    f.create_dataset("train/label_emo_ids", data=all_label_emo_ids, compression="gzip")
    del all_segment_ids, all_label_ent_ids, all_label_emo_ids
    gc.collect()

    ################################# val data #################################
    eval_examples = processor.get_dev_examples(args.data_dir, args.lite)
    eval_features = convert_examples_to_features(
        eval_examples, label_list, args.max_seq_length, tokenizer, TOK2ID, args.label_method)

    logger.info("***** Val features *****")
    logger.info(" Num examples = %d", len(eval_features))
    all_input_ids = torch.tensor([feat.input_ids for feat in eval_features], dtype=torch.long)
    all_myinput_ids = torch.tensor([feat.myinput_ids for feat in eval_features], dtype=torch.long)
    f.create_dataset("val/input_ids", data=all_input_ids, compression="gzip")
    f.create_dataset("val/myinput_ids", data=all_myinput_ids, compression="gzip")
    all_input_mask = torch.tensor([feat.input_mask for feat in eval_features], dtype=torch.long)
    f.create_dataset("val/input_mask", data=all_input_mask, compression="gzip")
    del all_input_ids, all_myinput_ids, all_input_mask
    gc.collect()
    all_segment_ids = torch.tensor([feat.segment_ids for feat in eval_features], dtype=torch.long)
    f.create_dataset("val/segment_ids", data=all_segment_ids, compression="gzip")
    all_label_ent_ids = torch.tensor([feat.label_ent_ids for feat in eval_features], dtype=torch.long)
    f.create_dataset("val/label_ent_ids", data=all_label_ent_ids, compression="gzip")
    all_label_emo_ids = torch.tensor([feat.label_emo_ids for feat in eval_features], dtype=torch.long)
    f.create_dataset("val/label_emo_ids", data=all_label_emo_ids, compression="gzip")
    del all_segment_ids, all_label_ent_ids, all_label_emo_ids
    gc.collect()
    f.close()
    ID2TOK = OrderedDict((id, tok) for tok, id in TOK2ID.items())
    data_dump(ID2TOK, "../datasets/ID2TOK.pkl")

    print("save to h5 over!")


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------