├── README.md ├── __init__.py ├── bert_config.py ├── checkpoints ├── cnews │ └── 占位.txt └── cpws │ └── 占位.txt ├── data ├── cnews │ ├── labels.txt │ └── process.py └── cpws │ ├── labels.txt │ └── process.py ├── data_loader.py ├── logs ├── main.log └── preprocess.log ├── main.py ├── main_dataparallel.py ├── main_distributed.py ├── main_mp_distributed.py ├── models.py ├── nvidia.bat ├── requirements.txt ├── test_ddp.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-37.pyc └── utils.cpython-37.pyc └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # 基于pytorch+bert的中文文本分类 2 | 3 | 本项目是基于pytorch+bert的中文文本分类。 4 | 5 | ## 使用依赖 6 | ```python 7 | torch==1.6.0 8 | transformers==4.5.1 9 | ``` 10 | ## 相关说明 11 | ``` 12 | --logs:存放日志 13 | --checkpoints:存放保存的模型 14 | --data:存放数据 15 | --utils:存放辅助函数 16 | --bert_config.py:相关配置 17 | --data_loader.py:制作数据为torch所需的格式 18 | --models.py:存放模型代码 19 | --main.py:主运行程序,包含训练、验证、测试、预测以及相关评价指标的计算 20 | ``` 21 | 22 | 在hugging face上预先下载好预训练的bert模型,放在和该项目同级下的model_hub文件夹下。 23 | 24 | ## 裁判文书分类 25 | 26 | 数据集地址:[裁判文书网NLP文本分类数据集 - Heywhale.com](https://www.heywhale.com/mw/dataset/625869115fe0ad0017c6a7f7/file) 27 | 28 | #### 一般步骤 29 | 30 | - 1、在data下新建一个存放该数据集的文件夹,这里是cpws,然后将数据放在该文件夹下。在该文件夹下新建一个process.py,主要是获取标签并存储在labels.txt中。 31 | - 2、在data_loader.py里面新建一个类,类可参考CPWSDataset,主要是返回一个列表:[(文本,标签)]。 32 | - 3、在main.py里面的datasets、train_files、test_files里面添加上属于该数据集的一些信息,最后运行main.py即可。 33 | - 4、在运行main.py时,可通过指定--do_train、--do_test、--do_predict来选择训练、测试或预测。 34 | 35 | #### 运行指令 36 | 37 | ```python 38 | python main.py \ 39 | --bert_dir="../model_hub/chinese-bert-wwm-ext/" \ 40 | --data_dir="./data/cpws/" \ 41 | --data_name="cpws" \ 42 | --log_dir="./logs/" \ 43 | --output_dir="./checkpoints/" \ 44 | --num_tags=5 \ 45 | --seed=123 \ 46 | --gpu_ids="0" \ 47 | --max_seq_len=256 \ 48 | --lr=3e-5 \ 49 | --train_batch_size=16 \ 50 | --train_epochs=5 \ 51 | --eval_batch_size=16 \ 52 | --do_train \ 53 | --do_test \ 54 | --do_predict 55 | ``` 56 | 57 | #### 结果 58 | 59 | 这里运行了300步之后手动停止了。 60 | 61 | ```python 62 | {'盗窃罪': 0, '交通肇事罪': 1, '诈骗罪': 2, '故意伤害罪': 3, '危险驾驶罪': 4} 63 | ========进行测试======== 64 | 【test】 loss:22.040314 accuracy:0.9823 micro_f1:0.9823 macro_f1:0.9822 65 | precision recall f1-score support 66 | 67 | 盗窃罪 0.97 0.99 0.98 1998 68 | 交通肇事罪 0.97 0.99 0.98 1996 69 | 诈骗罪 0.99 0.98 0.99 1998 70 | 故意伤害罪 0.99 0.99 0.99 1999 71 | 危险驾驶罪 0.99 0.95 0.97 2000 72 | 73 | accuracy 0.98 9991 74 | macro avg 0.98 0.98 0.98 9991 75 | weighted avg 0.98 0.98 0.98 9991 76 | 公诉机关指控:1、2015年3月18日18时许,被告人余某窜至漳州市芗城区丹霞路欣隆盛世小区2期工地内,趁工作人员不注意盗走工地内的脚手架扣件70个(价值人民币252元)。2、2015年3月19日13时和17时,被告人余某分两次窜至漳州市芗城区丹霞路欣隆盛世小区2期工地内一楼房一层的中间配电室内,利用随身携带的铁钳盗走该配电室内的电缆线(共计574米,价值人民币4707元)。3、2015年3月21日7时30分许,被告人余某窜至漳州市芗城区丹霞路欣隆盛世小区2期工地内一楼房一层靠东边的配电室内,利用随身携带的铁钳要将该配电室内的电缆线(共156米,价值人民币1279元)盗走时被工地负责人洪某某发现,后被工地保安吴某某抓获并扭送公安机关。公诉机关认为被告人余某的行为已构成××,本案第三起盗窃系犯罪未遂,建议对被告人余某在××至一年六个月的幅度内处以刑罚,并处罚金。 77 | 预测标签: 盗窃罪 78 | 真实标签: 盗窃罪 79 | ========================== 80 | ``` 81 | 82 | # 新闻文本分类 83 | 84 | 使用的数据集是THUCNews,数据地址:THUCNews 85 | 86 | #### 一般步骤 87 | 88 | - 1、在data下新建一个存放该数据集的文件夹,这里是cnews,然后将数据放在该文件夹下。在该文件夹下新建一个process.py,主要是获取标签并存储在labels.txt中。 89 | - 2、在data_loader.py里面新建一个类,类可参考CNEWSDataset,主要是返回一个列表:[(文本,标签)]。 90 | - 3、在main.py里面的datasets、train_files、test_files里面添加上属于该数据集的一些信息,最后运行main.py即可。 91 | - 4、在运行main.py时,可通过指定--do_train、--do_test、--do_predict来选择训练、测试或预测。 92 | 93 | #### 运行 94 | 95 | ```python 96 | python main.py \ 97 | --bert_dir="../model_hub/chinese-bert-wwm-ext/" \ 98 | --data_dir="./data/cnews/" \ 99 | --data_name="cnews" \ 100 | --log_dir="./logs/" \ 101 | --output_dir="./checkpoints/" \ 102 | --num_tags=10 \ 103 | --seed=123 \ 104 | --gpu_ids="0" \ 105 | --max_seq_len=512 \ 106 | --lr=3e-5 \ 107 | --train_batch_size=16 \ 108 | --train_epochs=5 \ 109 | --eval_batch_size=16 \ 110 | --do_predict 111 | ``` 112 | #### 结果 113 | 114 | 这里运行了800步手动停止了。 115 | 116 | ```python 117 | {'房产': 0, '娱乐': 1, '教育': 2, '体育': 3, '家居': 4, '时政': 5, '财经': 6, '时尚': 7, '游戏': 8, '科技': 9} 118 | ========进行测试======== 119 | 【test】 loss:76.024950 accuracy:0.9697 micro_f1:0.9697 macro_f1:0.9696 120 | precision recall f1-score support 121 | 122 | 房产 0.91 0.92 0.92 1000 123 | 娱乐 0.99 0.99 0.99 1000 124 | 教育 0.97 0.96 0.97 1000 125 | 体育 1.00 1.00 1.00 1000 126 | 家居 0.98 0.91 0.94 1000 127 | 时政 0.98 0.94 0.96 1000 128 | 财经 0.96 0.99 0.97 1000 129 | 时尚 0.94 1.00 0.97 1000 130 | 游戏 0.99 0.99 0.99 1000 131 | 科技 0.98 0.99 0.99 1000 132 | 133 | accuracy 0.97 10000 134 | macro avg 0.97 0.97 0.97 10000 135 | weighted avg 0.97 0.97 0.97 10000 136 | 137 | 鲍勃库西奖归谁属? NCAA最强控卫是坎巴还是弗神新浪体育讯如今,本赛季的NCAA进入到了末段,各项奖项的评选结果也即将出炉,其中评选最佳控卫的鲍勃-库西奖就将在下周最终四强战时公布,鲍勃-库西奖是由奈史密斯篮球名人堂提供,旨在奖励年度最佳大学控卫。最终获奖的球员也即将在以下几名热门人选中产生。〈〈〈 NCAA疯狂三月专题主页上线,点击链接查看精彩内容吉梅尔-弗雷戴特,杨百翰大学“弗神”吉梅尔-弗雷戴特一直都备受关注,他不仅仅是一名射手,他会用“终结对手脚踝”一样的变向过掉面前的防守者,并且他可以用任意一支手完成得分,如果他被犯规了,可以提前把这两份划入他的帐下了,因为他是一名命中率高达90%的罚球手。弗雷戴特具有所有伟大控卫都具备的一点特质,他是一位赢家也是一位领导者。“他整个赛季至始至终的稳定领导着球队前进,这是无可比拟的。”杨百翰大学主教练戴夫-罗斯称赞道,“他的得分能力毋庸置疑,但是我认为他带领球队获胜的能力才是他最重要的控卫职责。我们在主场之外的比赛(客场或中立场)共取胜19场,他都表现的很棒。”弗雷戴特能否在NBA取得成功?当然,但是有很多专业人士比我们更有资格去做出这样的判断。“我喜爱他。”凯尔特人主教练多克-里弗斯说道,“他很棒,我看过ESPN的片段剪辑,从剪辑来看,他是个超级巨星,我认为他很成为一名优秀的NBA球员。”诺兰-史密斯,杜克大学当赛季初,球队宣布大一天才控卫凯瑞-厄尔文因脚趾的伤病缺席赛季大部分比赛后,诺兰-史密斯便开始接管球权,他在进攻端上足发条,在ACC联盟(杜克大学所在分区)的得分榜上名列前茅,但同时他在分区助攻榜上也占据头名,这在众强林立的ACC联盟前无古人。“我不认为全美有其他的球员能在凯瑞-厄尔文受伤后,如此好的接管球队,并且之前毫无准备。”杜克主教练迈克-沙舍夫斯基赞扬道,“他会将比赛带入自己的节奏,得分,组织,领导球队,无所不能。而且他现在是攻防俱佳,对持球人的防守很有提高。总之他拥有了辉煌的赛季。”坎巴-沃克,康涅狄格大学坎巴-沃克带领康涅狄格在赛季初的毛伊岛邀请赛一路力克密歇根州大和肯塔基等队夺冠,他场均30分4助攻得到最佳球员。在大东赛区锦标赛和全国锦标赛中,他场均27.1分,6.1个篮板,5.1次助攻,依旧如此给力。他以疯狂的表现开始这个赛季,也将以疯狂的表现结束这个赛季。“我们在全国锦标赛中前进着,并且之前曾经5天连赢5场,赢得了大东赛区锦标赛的冠军,这些都归功于坎巴-沃克。”康涅狄格大学主教练吉姆-卡洪称赞道,“他是一名纯正的控卫而且能为我们得分,他有过单场42分,有过单场17助攻,也有过单场15篮板。这些都是一名6英尺175镑的球员所完成的啊!我们有很多好球员,但他才是最好的领导者,为球队所做的贡献也是最大。”乔丹-泰勒,威斯康辛大学全美没有一个持球者能像乔丹-泰勒一样很少失误,他4.26的助攻失误在全美遥遥领先,在大十赛区的比赛中,他平均35.8分钟才会有一次失误。他还是名很出色的得分手,全场砍下39分击败印第安纳大学的比赛就是最好的证明,其中下半场他曾经连拿18分。“那个夜晚他证明自己值得首轮顺位。”当时的见证者印第安纳大学主教练汤姆-克雷恩说道。“对一名控卫的所有要求不过是领导球队、使球队变的更好、带领球队成功,乔丹-泰勒全做到了。”威斯康辛教练博-莱恩说道。诺里斯-科尔,克利夫兰州大诺里斯-科尔的草根传奇正在上演,默默无闻的他被克利夫兰州大招募后便开始刻苦地训练,去年夏天他曾加练上千次跳投,来提高这个可能的弱点。他在本赛季与杨斯顿州大的比赛中得到40分20篮板和9次助攻,在他之前,过去15年只有一位球员曾经在NCAA一级联盟做到过40+20,他的名字是布雷克-格里芬。“他可以很轻松地防下对方王牌。”克利夫兰州大主教练加里-沃特斯如此称赞自己的弟子,“同时他还能得分,并为球队助攻,他几乎能做到一个成功的团队所有需要的事。”这其中四名球员都带领自己的球队进入到了甜蜜16强,虽然有3个球员和他们各自的球队被挡在8强的大门之外,但是他们已经表现的足够出色,不远的将来他们很可能出现在一所你熟悉的NBA球馆里。(clay) 138 | 预测标签: 体育 139 | 真实标签: 体育 140 | ========================== 141 | ``` 142 | 143 | # Dataparallel分布式训练 144 | 145 | ```python 146 | python main_dataparallel.py \ 147 | --bert_dir="../model_hub/chinese-bert-wwm-ext/" \ 148 | --data_dir="./data/cnews/" \ 149 | --data_name="cnews" \ 150 | --log_dir="./logs/" \ 151 | --output_dir="./checkpoints/" \ 152 | --num_tags=10 \ 153 | --seed=123 \ 154 | --gpu_ids="0,1,3" \ 155 | --max_seq_len=512 \ 156 | --lr=3e-5 \ 157 | --train_batch_size=64 \ 158 | --train_epochs=1 \ 159 | --eval_batch_size=64 \ 160 | --do_train \ 161 | --do_predict \ 162 | --do_test 163 | ``` 164 | 165 | # Distributed单机多卡分布式训练(windows下) 166 | 167 | linux下没有测试过。运行需要在powershell里面运行,右键点击开始菜单,选择powershell。nvidia.bat用于监控运行之后GPU的使用情况。需要pytorch版本至少大于1.7,这里使用的是pytorch==1.12。 168 | 169 | ### 使用torch.distributed.launch启动 170 | 171 | ```python 172 | python -m torch.distributed.launch --nnode=1 --node_rank=0 --nproc_per_node=4 main_distributed.py --local_world_size=4 --bert_dir="../model_hub/chinese-bert-wwm-ext/" --data_dir="./data/cnews/" --data_name="cnews" --log_dir="./logs/" --output_dir="./checkpoints/" --num_tags=10 --seed=123 --max_seq_len=512 --lr=3e-5 --train_batch_size=64 --train_epochs=1 --eval_batch_size=64 --do_train --do_predict --do_test 173 | ``` 174 | 175 | **说明**:文件里面通过```os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,3'```来选择使用的GPU。nproc_per_node为使用的GPU的数目,local_world_size为使用的GPU的数目。 176 | 177 | ### 使用torch.multiprocessing启动 178 | 179 | ```python 180 | python main_mp_distributed.py --local_world_size=4 --bert_dir="../model_hub/chinese-bert-wwm-ext/" --data_dir="./data/cnews/" --data_name="cnews" --log_dir="./logs/" --output_dir="./checkpoints/" --num_tags=10 --seed=123 --max_seq_len=512 --lr=3e-5 --train_batch_size=64 --train_epochs=1 --eval_batch_size=64 --do_train --do_predict --do_test 181 | ``` 182 | 183 | **说明**:文件里面通过```os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,3'```来选择使用的GPU。local_world_size为使用的GPU的数目。 184 | 185 | # 补充 186 | 187 | Q:怎么训练自己的数据集?
188 | 189 | A:按照样例的一般步骤里面进行即可。
190 | 191 | # 更新日志 192 | 193 | - 2022-08-08:重构了代码,使得总体结构更加简单,更易于用于不同的数据集上。 194 | - 2022-08-09:新增是否加载模型继续训练,运行参数加上--retrain。 195 | 196 | - 2023-03-30:新增基于dataparallel的分布式训练。 197 | 198 | - 2023-04-02:新增基于distributed的分布式训练。 199 | 200 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_bert_chinese_text_classification/8e3166b98c784f972d17df17ddeb56e3e494184b/__init__.py -------------------------------------------------------------------------------- /bert_config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | class Args: 5 | @staticmethod 6 | def parse(): 7 | parser = argparse.ArgumentParser() 8 | return parser 9 | 10 | @staticmethod 11 | def initialize(parser): 12 | # args for path 13 | parser.add_argument('--output_dir', default='../checkpoints/', 14 | help='the output dir for model checkpoints') 15 | 16 | parser.add_argument('--bert_dir', default='../../model_hub/bert-base-case/', 17 | help='bert dir for uer') 18 | parser.add_argument('--data_dir', default='../data/cnews/', 19 | help='data dir for uer') 20 | parser.add_argument('--log_dir', default='../logs/', 21 | help='log dir for uer') 22 | parser.add_argument('--data_name', default='cnews', 23 | help='数据集的名称') 24 | 25 | # other args 26 | parser.add_argument('--num_tags', default=65, type=int, 27 | help='number of tags') 28 | parser.add_argument('--seed', type=int, default=123, help='random seed') 29 | 30 | parser.add_argument('--gpu_ids', type=str, default='0', 31 | help='gpu ids to use, -1 for cpu, "0,1" for multi gpu') 32 | 33 | parser.add_argument('--max_seq_len', default=256, type=int) 34 | 35 | parser.add_argument('--eval_batch_size', default=12, type=int) 36 | 37 | parser.add_argument('--swa_start', default=3, type=int, 38 | help='the epoch when swa start') 39 | 40 | # train args 41 | # This is passed in via launch.py 42 | parser.add_argument("--local_rank", type=int, default=0) 43 | # This needs to be explicitly passed in 44 | parser.add_argument("--local_world_size", type=int, default=1) 45 | parser.add_argument('--train_epochs', default=15, type=int, 46 | help='Max training epoch') 47 | 48 | parser.add_argument('--dropout_prob', default=0.1, type=float, 49 | help='drop out probability') 50 | 51 | # 2e-5 52 | parser.add_argument('--lr', default=3e-5, type=float, 53 | help='learning rate for the bert module') 54 | # 2e-3 55 | parser.add_argument('--other_lr', default=3e-4, type=float, 56 | help='learning rate for the module except bert') 57 | # 0.5 58 | parser.add_argument('--max_grad_norm', default=1, type=float, 59 | help='max grad clip') 60 | 61 | parser.add_argument('--warmup_proportion', default=0.1, type=float) 62 | 63 | parser.add_argument('--weight_decay', default=0.01, type=float) 64 | 65 | parser.add_argument('--adam_epsilon', default=1e-8, type=float) 66 | 67 | parser.add_argument('--train_batch_size', default=32, type=int) 68 | 69 | # parser.add_argument('--eval_model', default=True, action='store_true', 70 | # help='whether to eval model after training') 71 | parser.add_argument('--do_train', action='store_true', 72 | help='是否训练') 73 | parser.add_argument('--do_test', action='store_true', 74 | help='是否测试') 75 | parser.add_argument('--do_predict', action='store_true', 76 | help='是否预测') 77 | parser.add_argument('--retrain', action='store_true', 78 | help='是否加载模型继续训练') 79 | 80 | return parser 81 | 82 | def get_parser(self): 83 | parser = self.parse() 84 | parser = self.initialize(parser) 85 | return parser.parse_args() 86 | -------------------------------------------------------------------------------- /checkpoints/cnews/占位.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_bert_chinese_text_classification/8e3166b98c784f972d17df17ddeb56e3e494184b/checkpoints/cnews/占位.txt -------------------------------------------------------------------------------- /checkpoints/cpws/占位.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_bert_chinese_text_classification/8e3166b98c784f972d17df17ddeb56e3e494184b/checkpoints/cpws/占位.txt -------------------------------------------------------------------------------- /data/cnews/labels.txt: -------------------------------------------------------------------------------- 1 | 教育 2 | 娱乐 3 | 家居 4 | 房产 5 | 科技 6 | 时尚 7 | 体育 8 | 财经 9 | 时政 10 | 游戏 11 | -------------------------------------------------------------------------------- /data/cnews/process.py: -------------------------------------------------------------------------------- 1 | labels = [] 2 | 3 | with open('cnews.train.txt','r') as fp: 4 | lines = fp.read().strip().split('\n') 5 | for line in lines: 6 | line = line.split('\t') 7 | labels.append(line[0]) 8 | 9 | labels = set(labels) 10 | with open('./labels.txt','w') as fp: 11 | fp.write("\n".join(labels)) -------------------------------------------------------------------------------- /data/cpws/labels.txt: -------------------------------------------------------------------------------- 1 | 盗窃罪 2 | 交通肇事罪 3 | 诈骗罪 4 | 故意伤害罪 5 | 危险驾驶罪 -------------------------------------------------------------------------------- /data/cpws/process.py: -------------------------------------------------------------------------------- 1 | filename = 'train_data.txt' 2 | labels = set() 3 | with open(filename, 'r', encoding='utf-8') as f: 4 | raw_data = f.readlines() 5 | for d in raw_data: 6 | d = d.strip() 7 | d = d.split("\t") 8 | if len(d) == 2: 9 | labels.add(d[0]) 10 | 11 | with open('../final_data/labels.txt', 'w', encoding='utf-8') as f: 12 | f.write("\n".join(list(labels))) -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import json 3 | import torch 4 | import numpy as np 5 | from torch.utils.data import DataLoader, Dataset 6 | 7 | 8 | class ListDataset(Dataset): 9 | def __init__(self, file_path=None, data=None, tokenizer=None, max_len=None, **kwargs): 10 | self.kwargs = kwargs 11 | if isinstance(file_path, (str, list)): 12 | self.data = self.load_data(file_path) 13 | elif isinstance(data, list): 14 | self.data = data 15 | else: 16 | raise ValueError('The input args shall be str format file_path / list format dataset') 17 | 18 | def __len__(self): 19 | return len(self.data) 20 | 21 | def __getitem__(self, index): 22 | return self.data[index] 23 | 24 | @staticmethod 25 | def load_data(file_path): 26 | return file_path 27 | 28 | 29 | # 加载数据集 30 | class CNEWSDataset(ListDataset): 31 | @staticmethod 32 | def load_data(filename): 33 | data = [] 34 | with open(filename, encoding='utf-8') as f: 35 | raw_data = f.readlines() 36 | for d in raw_data: 37 | d = d.strip().split('\t') 38 | text = d[1] 39 | label = d[0] 40 | data.append((text, label)) 41 | return data 42 | 43 | 44 | class CPWSDataset(ListDataset): 45 | @staticmethod 46 | def load_data(filename): 47 | data = [] 48 | with open(filename, encoding='utf-8') as f: 49 | raw_data = f.readlines() 50 | for d in raw_data: 51 | d = d.strip() 52 | d = d.split("\t") 53 | if len(d) == 2: 54 | data.append((d[1], d[0])) 55 | return data 56 | 57 | 58 | 59 | 60 | class Collate: 61 | def __init__(self, tokenizer, max_len, tag2id): 62 | self.tokenizer = tokenizer 63 | self.maxlen = max_len 64 | self.tag2id = tag2id 65 | 66 | def collate_fn(self, batch): 67 | batch_labels = [] 68 | batch_token_ids = [] 69 | batch_attention_mask = [] 70 | batch_token_type_ids = [] 71 | for i, (text, label) in enumerate(batch): 72 | output = self.tokenizer.encode_plus( 73 | text=text, 74 | max_length=self.maxlen, 75 | padding="max_length", 76 | truncation='longest_first', 77 | return_token_type_ids=True, 78 | return_attention_mask=True 79 | ) 80 | token_ids = output["input_ids"] 81 | token_type_ids = output["token_type_ids"] 82 | attention_mask = output["attention_mask"] 83 | batch_token_ids.append(token_ids) # 前面已经限制了长度 84 | batch_attention_mask.append(attention_mask) 85 | batch_token_type_ids.append(token_type_ids) 86 | batch_labels.append(self.tag2id[label]) 87 | batch_token_ids = torch.tensor(batch_token_ids, dtype=torch.long) 88 | attention_mask = torch.tensor(batch_attention_mask, dtype=torch.long) 89 | token_type_ids = torch.tensor(batch_token_type_ids, dtype=torch.long) 90 | batch_labels = torch.tensor(batch_labels, dtype=torch.long) 91 | batch_data = { 92 | "token_ids":batch_token_ids, 93 | "attention_masks":attention_mask, 94 | "token_type_ids":token_type_ids, 95 | "labels":batch_labels 96 | } 97 | return batch_data 98 | 99 | 100 | if __name__ == "__main__": 101 | from transformers import BertTokenizer 102 | 103 | max_len = 512 104 | tokenizer = BertTokenizer.from_pretrained('../model_hub/chinese-bert-wwm-ext') 105 | train_dataset = CNEWSDataset(file_path='data/cnews/cnews.train.txt') 106 | print(train_dataset[0]) 107 | 108 | with open('data/cnews/labels.txt', 'r', encoding="utf-8") as fp: 109 | labels = fp.read().strip().split("\n") 110 | id2tag = {} 111 | tag2id = {} 112 | for i, label in enumerate(labels): 113 | id2tag[i] = label 114 | tag2id[label] = i 115 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 116 | collate = Collate(tokenizer=tokenizer, max_len=max_len, tag2id=tag2id, device=device) 117 | batch_size = 2 118 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate.collate_fn) 119 | 120 | for i, batch in enumerate(train_dataloader): 121 | print(batch["token_ids"].shape) 122 | print(batch["attention_masks"].shape) 123 | print(batch["token_type_ids"].shape) 124 | print(batch["labels"].shape) 125 | break 126 | -------------------------------------------------------------------------------- /logs/preprocess.log: -------------------------------------------------------------------------------- 1 | 2021-07-16 16:48:43,343 - INFO - preprocess.py - - 150 - {'output_dir': '../checkpoints/', 'bert_dir': '../model_hub/bert-base-chinese/', 'data_dir': '../data/cnews/', 'log_dir': './logs/', 'num_tags': 65, 'seed': 123, 'gpu_ids': '0', 'max_seq_len': 256, 'eval_batch_size': 12, 'swa_start': 3, 'train_epochs': 15, 'dropout_prob': 0.1, 'lr': 3e-05, 'other_lr': 0.0003, 'max_grad_norm': 1, 'warmup_proportion': 0.1, 'weight_decay': 0.01, 'adam_epsilon': 1e-08, 'train_batch_size': 32, 'eval_model': True} 2 | 2021-07-16 16:48:43,815 - INFO - preprocess.py - convert_examples_to_features - 106 - Convert 50000 examples to features 3 | 2021-07-16 16:48:43,822 - INFO - preprocess.py - convert_bert_example - 84 - *** train_example-0 *** 4 | 2021-07-16 16:48:43,823 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 马 晓 旭 意 外 受 伤 让 国 奥 警 惕 无 奈 大 雨 格 外 青 睐 殷 家 军 记 者 傅 亚 雨 沈 阳 报 道 来 到 沈 阳 , 国 奥 队 依 然 没 有 摆 脱 雨 水 的 困 扰 。 7 月 31 日 下 午 6 点 , 国 奥 队 的 日 常 训 练 再 度 受 到 大 雨 的 干 扰 , 无 奈 之 下 队 员 们 只 慢 跑 了 25 分 钟 就 草 草 收 场 。 31 日 上 午 10 点 , 国 奥 队 在 奥 体 中 心 外 场 训 练 的 时 候 , 天 就 是 阴 沉 沉 的 , 气 象 预 报 显 示 当 天 下 午 沈 阳 就 有 大 雨 , 但 幸 好 队 伍 上 午 的 训 练 并 没 有 受 到 任 何 干 扰 。 下 午 6 点 , 当 球 队 抵 达 训 练 场 时 , 大 雨 已 经 下 了 几 个 小 时 , 而 且 丝 毫 没 有 停 下 来 的 意 思 。 抱 着 试 一 试 的 态 度 , 球 队 开 始 了 当 天 下 午 的 例 行 训 练 , 25 分 钟 过 去 了 , 天 气 没 有 任 何 转 好 的 迹 象 , 为 了 保 护 球 [SEP] 5 | 2021-07-16 16:48:43,823 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 7716, 3236, 3195, 2692, 1912, 1358, 839, 6375, 1744, 1952, 6356, 2664, 3187, 1937, 1920, 7433, 3419, 1912, 7471, 4712, 3668, 2157, 1092, 6381, 5442, 987, 762, 7433, 3755, 7345, 2845, 6887, 3341, 1168, 3755, 7345, 8024, 1744, 1952, 7339, 898, 4197, 3766, 3300, 3030, 5564, 7433, 3717, 4638, 1737, 2817, 511, 128, 3299, 8176, 3189, 678, 1286, 127, 4157, 8024, 1744, 1952, 7339, 4638, 3189, 2382, 6378, 5298, 1086, 2428, 1358, 1168, 1920, 7433, 4638, 2397, 2817, 8024, 3187, 1937, 722, 678, 7339, 1447, 812, 1372, 2714, 6651, 749, 8132, 1146, 7164, 2218, 5770, 5770, 3119, 1767, 511, 8176, 3189, 677, 1286, 8108, 4157, 8024, 1744, 1952, 7339, 1762, 1952, 860, 704, 2552, 1912, 1767, 6378, 5298, 4638, 3198, 952, 8024, 1921, 2218, 3221, 7346, 3756, 3756, 4638, 8024, 3698, 6496, 7564, 2845, 3227, 4850, 2496, 1921, 678, 1286, 3755, 7345, 2218, 3300, 1920, 7433, 8024, 852, 2401, 1962, 7339, 824, 677, 1286, 4638, 6378, 5298, 2400, 3766, 3300, 1358, 1168, 818, 862, 2397, 2817, 511, 678, 1286, 127, 4157, 8024, 2496, 4413, 7339, 2850, 6809, 6378, 5298, 1767, 3198, 8024, 1920, 7433, 2347, 5307, 678, 749, 1126, 702, 2207, 3198, 8024, 5445, 684, 692, 3690, 3766, 3300, 977, 678, 3341, 4638, 2692, 2590, 511, 2849, 4708, 6407, 671, 6407, 4638, 2578, 2428, 8024, 4413, 7339, 2458, 1993, 749, 2496, 1921, 678, 1286, 4638, 891, 6121, 6378, 5298, 8024, 8132, 1146, 7164, 6814, 1343, 749, 8024, 1921, 3698, 3766, 3300, 818, 862, 6760, 1962, 4638, 6839, 6496, 8024, 711, 749, 924, 2844, 4413, 102] 6 | 2021-07-16 16:48:43,823 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 7 | 2021-07-16 16:48:43,823 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 8 | 2021-07-16 16:48:43,823 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}] 9 | 2021-07-16 16:48:43,838 - INFO - preprocess.py - convert_bert_example - 84 - *** train_example-1 *** 10 | 2021-07-16 16:48:43,838 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 商 瑞 华 首 战 复 仇 心 切 中 国 玫 瑰 要 用 美 国 方 式 攻 克 瑞 典 多 曼 来 了 , 瑞 典 来 了 , 商 瑞 华 首 战 求 3 分 的 信 心 也 来 了 。 距 离 首 战 72 小 时 当 口 , 中 国 女 足 彻 底 从 [UNK] 恐 瑞 症 [UNK] 当 中 获 得 解 脱 , 因 为 商 瑞 华 已 经 找 到 了 瑞 典 人 的 软 肋 。 找 到 软 肋 , 保 密 4 月 20 日 奥 运 会 分 组 抽 签 结 果 出 来 后 , 中 国 姑 娘 就 把 瑞 典 锁 定 为 关 乎 奥 运 成 败 的 头 号 劲 敌 , 因 为 除 了 浦 玮 等 个 别 老 将 之 外 , 现 役 女 足 将 士 竟 然 没 有 人 尝 过 击 败 瑞 典 的 滋 味 。 在 中 瑞 两 队 共 计 15 次 交 锋 的 历 史 上 , 中 国 队 6 胜 3 平 6 负 与 瑞 典 队 平 分 秋 色 , 但 从 2001 年 起 至 今 近 8 年 时 间 , 中 国 在 同 瑞 典 连 续 5 次 交 锋 中 均 未 尝 胜 绩 , 战 绩 为 2 平 3 负 。 尽 管 八 年 [SEP] 11 | 2021-07-16 16:48:43,839 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 1555, 4448, 1290, 7674, 2773, 1908, 790, 2552, 1147, 704, 1744, 4382, 4456, 6206, 4500, 5401, 1744, 3175, 2466, 3122, 1046, 4448, 1073, 1914, 3294, 3341, 749, 8024, 4448, 1073, 3341, 749, 8024, 1555, 4448, 1290, 7674, 2773, 3724, 124, 1146, 4638, 928, 2552, 738, 3341, 749, 511, 6655, 4895, 7674, 2773, 8325, 2207, 3198, 2496, 1366, 8024, 704, 1744, 1957, 6639, 2515, 2419, 794, 100, 2607, 4448, 4568, 100, 2496, 704, 5815, 2533, 6237, 5564, 8024, 1728, 711, 1555, 4448, 1290, 2347, 5307, 2823, 1168, 749, 4448, 1073, 782, 4638, 6763, 5490, 511, 2823, 1168, 6763, 5490, 8024, 924, 2166, 125, 3299, 8113, 3189, 1952, 6817, 833, 1146, 5299, 2853, 5041, 5310, 3362, 1139, 3341, 1400, 8024, 704, 1744, 1996, 2023, 2218, 2828, 4448, 1073, 7219, 2137, 711, 1068, 725, 1952, 6817, 2768, 6571, 4638, 1928, 1384, 1226, 3127, 8024, 1728, 711, 7370, 749, 3855, 4383, 5023, 702, 1166, 5439, 2199, 722, 1912, 8024, 4385, 2514, 1957, 6639, 2199, 1894, 4994, 4197, 3766, 3300, 782, 2214, 6814, 1140, 6571, 4448, 1073, 4638, 3996, 1456, 511, 1762, 704, 4448, 697, 7339, 1066, 6369, 8115, 3613, 769, 7226, 4638, 1325, 1380, 677, 8024, 704, 1744, 7339, 127, 5526, 124, 2398, 127, 6566, 680, 4448, 1073, 7339, 2398, 1146, 4904, 5682, 8024, 852, 794, 8285, 2399, 6629, 5635, 791, 6818, 129, 2399, 3198, 7313, 8024, 704, 1744, 1762, 1398, 4448, 1073, 6825, 5330, 126, 3613, 769, 7226, 704, 1772, 3313, 2214, 5526, 5327, 8024, 2773, 5327, 711, 123, 2398, 124, 6566, 511, 2226, 5052, 1061, 2399, 102] 12 | 2021-07-16 16:48:43,839 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 13 | 2021-07-16 16:48:43,839 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 14 | 2021-07-16 16:48:43,839 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}] 15 | 2021-07-16 16:48:43,857 - INFO - preprocess.py - convert_bert_example - 84 - *** train_example-2 *** 16 | 2021-07-16 16:48:43,857 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 冠 军 球 队 迎 新 欢 乐 派 对 黄 旭 获 大 奖 张 军 赢 下 pk 赛 新 浪 体 育 讯 12 月 27 日 晚 , [UNK] 冠 军 高 尔 夫 球 队 迎 新 高 球 欢 乐 派 对 [UNK] 活 动 在 北 京 都 市 名 人 高 尔 夫 俱 乐 部 举 行 。 邢 傲 伟 、 黄 旭 、 邹 凯 、 滕 海 滨 、 杨 凌 、 马 燕 红 、 张 军 、 王 丽 萍 、 杨 影 、 阎 森 、 毕 文 静 、 夏 煊 泽 、 张 璇 、 叶 乔 波 、 莫 慧 兰 、 周 雅 菲 、 胡 妮 、 戴 菲 菲 、 马 健 等 奥 运 冠 军 、 世 界 冠 军 参 加 了 当 晚 的 活 动 。 [UNK] 以 球 会 友 、 联 谊 交 流 [UNK] 为 宗 旨 的 冠 军 高 尔 夫 球 队 , 从 创 立 之 初 就 秉 承 关 爱 社 会 、 回 馈 社 会 的 价 值 观 。 此 次 迎 新 派 对 活 动 的 举 行 , 标 志 着 这 个 为 中 国 金 牌 运 动 员 提 供 的 放 松 身 心 、 以 球 会 友 的 平 台 正 式 全 面 启 动 。 球 队 已 经 与 都 市 名 人 高 [SEP] 17 | 2021-07-16 16:48:43,857 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 1094, 1092, 4413, 7339, 6816, 3173, 3614, 727, 3836, 2190, 7942, 3195, 5815, 1920, 1946, 2476, 1092, 6617, 678, 8465, 6612, 3173, 3857, 860, 5509, 6380, 8110, 3299, 8149, 3189, 3241, 8024, 100, 1094, 1092, 7770, 2209, 1923, 4413, 7339, 6816, 3173, 7770, 4413, 3614, 727, 3836, 2190, 100, 3833, 1220, 1762, 1266, 776, 6963, 2356, 1399, 782, 7770, 2209, 1923, 936, 727, 6956, 715, 6121, 511, 6928, 1000, 836, 510, 7942, 3195, 510, 6941, 1132, 510, 4001, 3862, 4012, 510, 3342, 1119, 510, 7716, 4242, 5273, 510, 2476, 1092, 510, 4374, 714, 5847, 510, 3342, 2512, 510, 7330, 3481, 510, 3684, 3152, 7474, 510, 1909, 4201, 3813, 510, 2476, 4462, 510, 1383, 730, 3797, 510, 5811, 2716, 1065, 510, 1453, 7414, 5838, 510, 5529, 1984, 510, 2785, 5838, 5838, 510, 7716, 978, 5023, 1952, 6817, 1094, 1092, 510, 686, 4518, 1094, 1092, 1346, 1217, 749, 2496, 3241, 4638, 3833, 1220, 511, 100, 809, 4413, 833, 1351, 510, 5468, 6449, 769, 3837, 100, 711, 2134, 3192, 4638, 1094, 1092, 7770, 2209, 1923, 4413, 7339, 8024, 794, 1158, 4989, 722, 1159, 2218, 4903, 2824, 1068, 4263, 4852, 833, 510, 1726, 7668, 4852, 833, 4638, 817, 966, 6225, 511, 3634, 3613, 6816, 3173, 3836, 2190, 3833, 1220, 4638, 715, 6121, 8024, 3403, 2562, 4708, 6821, 702, 711, 704, 1744, 7032, 4277, 6817, 1220, 1447, 2990, 897, 4638, 3123, 3351, 6716, 2552, 510, 809, 4413, 833, 1351, 4638, 2398, 1378, 3633, 2466, 1059, 7481, 1423, 1220, 511, 4413, 7339, 2347, 5307, 680, 6963, 2356, 1399, 782, 7770, 102] 18 | 2021-07-16 16:48:43,857 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 19 | 2021-07-16 16:48:43,857 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 20 | 2021-07-16 16:48:43,858 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}] 21 | 2021-07-16 16:55:52,209 - INFO - preprocess.py - convert_examples_to_features - 121 - Build 50000 features 22 | 2021-07-16 16:55:52,275 - INFO - preprocess.py - convert_examples_to_features - 106 - Convert 5000 examples to features 23 | 2021-07-16 16:55:52,277 - INFO - preprocess.py - convert_bert_example - 84 - *** dev_example-0 *** 24 | 2021-07-16 16:55:52,277 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 黄 蜂 vs 湖 人 首 发 : 科 比 带 伤 战 保 罗 加 索 尔 救 赎 之 战 新 浪 体 育 讯 北 京 时 间 4 月 27 日 , nba 季 后 赛 首 轮 洛 杉 矶 湖 人 主 场 迎 战 新 奥 尔 良 黄 蜂 , 此 前 的 比 赛 中 , 双 方 战 成 2 - 2 平 , 因 此 本 场 比 赛 对 于 两 支 球 队 来 说 都 非 常 重 要 , 赛 前 双 方 也 公 布 了 首 发 阵 容 : 湖 人 队 : 费 舍 尔 、 科 比 、 阿 泰 斯 特 、 加 索 尔 、 拜 纳 姆 黄 蜂 队 : 保 罗 、 贝 里 内 利 、 阿 里 扎 、 兰 德 里 、 奥 卡 福 [ 新 浪 nba 官 方 微 博 ] [ 新 浪 nba 湖 人 新 闻 动 态 微 博 ] [ 新 浪 nba 专 题 ] [ 黄 蜂 vs 湖 人 图 文 直 播 室 ] ( 新 浪 体 育 ) [SEP] 25 | 2021-07-16 16:55:52,278 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 7942, 6044, 8349, 3959, 782, 7674, 1355, 8038, 4906, 3683, 2372, 839, 2773, 924, 5384, 1217, 5164, 2209, 3131, 6604, 722, 2773, 3173, 3857, 860, 5509, 6380, 1266, 776, 3198, 7313, 125, 3299, 8149, 3189, 8024, 8391, 2108, 1400, 6612, 7674, 6762, 3821, 3329, 4768, 3959, 782, 712, 1767, 6816, 2773, 3173, 1952, 2209, 5679, 7942, 6044, 8024, 3634, 1184, 4638, 3683, 6612, 704, 8024, 1352, 3175, 2773, 2768, 123, 118, 123, 2398, 8024, 1728, 3634, 3315, 1767, 3683, 6612, 2190, 754, 697, 3118, 4413, 7339, 3341, 6432, 6963, 7478, 2382, 7028, 6206, 8024, 6612, 1184, 1352, 3175, 738, 1062, 2357, 749, 7674, 1355, 7347, 2159, 8038, 3959, 782, 7339, 8038, 6589, 5650, 2209, 510, 4906, 3683, 510, 7350, 3805, 3172, 4294, 510, 1217, 5164, 2209, 510, 2876, 5287, 1990, 7942, 6044, 7339, 8038, 924, 5384, 510, 6564, 7027, 1079, 1164, 510, 7350, 7027, 2799, 510, 1065, 2548, 7027, 510, 1952, 1305, 4886, 138, 3173, 3857, 8391, 2135, 3175, 2544, 1300, 140, 138, 3173, 3857, 8391, 3959, 782, 3173, 7319, 1220, 2578, 2544, 1300, 140, 138, 3173, 3857, 8391, 683, 7579, 140, 138, 7942, 6044, 8349, 3959, 782, 1745, 3152, 4684, 3064, 2147, 140, 113, 3173, 3857, 860, 5509, 114, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 26 | 2021-07-16 16:55:52,278 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 27 | 2021-07-16 16:55:52,278 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 28 | 2021-07-16 16:55:52,278 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}] 29 | 2021-07-16 16:55:52,286 - INFO - preprocess.py - convert_bert_example - 84 - *** dev_example-1 *** 30 | 2021-07-16 16:55:52,286 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 1. 7 秒 神 之 一 击 救 马 刺 王 朝 于 危 难 这 个 新 秀 有 点 牛 ! 新 浪 体 育 讯 在 刚 刚 结 束 的 比 赛 中 , 回 到 主 场 的 马 刺 通 过 加 时 以 110 - 103 惊 险 地 战 胜 了 灰 熊 , 避 免 了 让 主 场 观 众 见 证 黑 八 的 尴 尬 。 在 常 规 时 间 的 最 后 关 头 , 加 里 - 尼 尔 命 中 一 记 大 号 三 分 , 这 个 进 球 也 帮 助 马 刺 把 比 赛 带 进 了 加 时 , 并 最 终 翻 盘 成 功 。 [UNK] 波 波 维 奇 教 练 安 排 得 很 详 细 , 他 告 诉 我 有 机 会 就 要 出 手 。 邓 肯 的 掩 护 做 得 很 好 , 他 帮 我 完 全 挡 住 了 对 方 的 防 守 球 员 , 我 投 篮 的 视 野 非 常 好 , 于 是 我 就 出 手 了 。 [UNK] 即 使 命 中 了 这 记 价 值 连 城 的 球 , 尼 尔 依 然 保 持 着 低 调 。 在 被 问 及 这 是 不 是 他 职 业 生 涯 最 重 要 的 投 篮 时 , 他 说 : [UNK] 没 错 , 到 [SEP] 31 | 2021-07-16 16:55:52,286 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 122, 119, 128, 4907, 4868, 722, 671, 1140, 3131, 7716, 1173, 4374, 3308, 754, 1314, 7410, 6821, 702, 3173, 4899, 3300, 4157, 4281, 8013, 3173, 3857, 860, 5509, 6380, 1762, 1157, 1157, 5310, 3338, 4638, 3683, 6612, 704, 8024, 1726, 1168, 712, 1767, 4638, 7716, 1173, 6858, 6814, 1217, 3198, 809, 8406, 118, 8615, 2661, 7372, 1765, 2773, 5526, 749, 4129, 4220, 8024, 6912, 1048, 749, 6375, 712, 1767, 6225, 830, 6224, 6395, 7946, 1061, 4638, 2219, 2217, 511, 1762, 2382, 6226, 3198, 7313, 4638, 3297, 1400, 1068, 1928, 8024, 1217, 7027, 118, 2225, 2209, 1462, 704, 671, 6381, 1920, 1384, 676, 1146, 8024, 6821, 702, 6822, 4413, 738, 2376, 1221, 7716, 1173, 2828, 3683, 6612, 2372, 6822, 749, 1217, 3198, 8024, 2400, 3297, 5303, 5436, 4669, 2768, 1216, 511, 100, 3797, 3797, 5335, 1936, 3136, 5298, 2128, 2961, 2533, 2523, 6422, 5301, 8024, 800, 1440, 6401, 2769, 3300, 3322, 833, 2218, 6206, 1139, 2797, 511, 6924, 5507, 4638, 2973, 2844, 976, 2533, 2523, 1962, 8024, 800, 2376, 2769, 2130, 1059, 2913, 857, 749, 2190, 3175, 4638, 7344, 2127, 4413, 1447, 8024, 2769, 2832, 5074, 4638, 6228, 7029, 7478, 2382, 1962, 8024, 754, 3221, 2769, 2218, 1139, 2797, 749, 511, 100, 1315, 886, 1462, 704, 749, 6821, 6381, 817, 966, 6825, 1814, 4638, 4413, 8024, 2225, 2209, 898, 4197, 924, 2898, 4708, 856, 6444, 511, 1762, 6158, 7309, 1350, 6821, 3221, 679, 3221, 800, 5466, 689, 4495, 3889, 3297, 7028, 6206, 4638, 2832, 5074, 3198, 8024, 800, 6432, 8038, 100, 3766, 7231, 8024, 1168, 102] 32 | 2021-07-16 16:55:52,286 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 33 | 2021-07-16 16:55:52,286 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 34 | 2021-07-16 16:55:52,286 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}] 35 | 2021-07-16 16:55:52,299 - INFO - preprocess.py - convert_bert_example - 84 - *** dev_example-2 *** 36 | 2021-07-16 16:55:52,299 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 1 人 灭 掘 金 ! 神 般 杜 兰 特 ! 他 想 要 分 的 时 候 没 人 能 挡 新 浪 体 育 讯 在 nba 的 世 界 里 , 真 的 猛 男 , 敢 于 直 面 惨 淡 的 手 感 , 敢 于 正 视 落 后 的 局 面 , 然 后 用 一 己 之 力 , 力 挽 狂 澜 , 点 燃 球 迷 激 情 , 最 后 微 微 一 笑 , 带 领 球 队 在 季 后 赛 的 战 场 上 赢 下 比 赛 , 并 进 入 下 一 轮 , 今 日 雷 霆 凯 文 - 杜 兰 特 所 做 的 , 无 非 就 是 这 样 的 事 情 。 巨 星 这 个 东 西 很 难 定 义 , 有 时 候 你 就 是 30 分 30 板 也 未 必 能 得 到 一 个 巨 星 名 头 , 反 而 会 有 可 能 会 被 称 为 刷 子 , 而 巨 星 不 仅 是 数 据 上 能 够 出 类 拔 萃 , 也 不 仅 仅 是 能 够 帮 助 球 队 赢 球 , 从 意 志 力 层 面 上 来 讲 , 母 队 比 赛 快 输 了 , 队 友 的 腿 都 开 始 抖 了 , 所 有 人 都 在 看 着 你 , 球 都 到 了 你 [SEP] 37 | 2021-07-16 16:55:52,299 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 122, 782, 4127, 2963, 7032, 8013, 4868, 5663, 3336, 1065, 4294, 8013, 800, 2682, 6206, 1146, 4638, 3198, 952, 3766, 782, 5543, 2913, 3173, 3857, 860, 5509, 6380, 1762, 8391, 4638, 686, 4518, 7027, 8024, 4696, 4638, 4338, 4511, 8024, 3140, 754, 4684, 7481, 2673, 3909, 4638, 2797, 2697, 8024, 3140, 754, 3633, 6228, 5862, 1400, 4638, 2229, 7481, 8024, 4197, 1400, 4500, 671, 2346, 722, 1213, 8024, 1213, 2924, 4312, 4073, 8024, 4157, 4234, 4413, 6837, 4080, 2658, 8024, 3297, 1400, 2544, 2544, 671, 5010, 8024, 2372, 7566, 4413, 7339, 1762, 2108, 1400, 6612, 4638, 2773, 1767, 677, 6617, 678, 3683, 6612, 8024, 2400, 6822, 1057, 678, 671, 6762, 8024, 791, 3189, 7440, 7447, 1132, 3152, 118, 3336, 1065, 4294, 2792, 976, 4638, 8024, 3187, 7478, 2218, 3221, 6821, 3416, 4638, 752, 2658, 511, 2342, 3215, 6821, 702, 691, 6205, 2523, 7410, 2137, 721, 8024, 3300, 3198, 952, 872, 2218, 3221, 8114, 1146, 8114, 3352, 738, 3313, 2553, 5543, 2533, 1168, 671, 702, 2342, 3215, 1399, 1928, 8024, 1353, 5445, 833, 3300, 1377, 5543, 833, 6158, 4917, 711, 1170, 2094, 8024, 5445, 2342, 3215, 679, 788, 3221, 3144, 2945, 677, 5543, 1916, 1139, 5102, 2869, 5842, 8024, 738, 679, 788, 788, 3221, 5543, 1916, 2376, 1221, 4413, 7339, 6617, 4413, 8024, 794, 2692, 2562, 1213, 2231, 7481, 677, 3341, 6382, 8024, 3678, 7339, 3683, 6612, 2571, 6783, 749, 8024, 7339, 1351, 4638, 5597, 6963, 2458, 1993, 2833, 749, 8024, 2792, 3300, 782, 6963, 1762, 4692, 4708, 872, 8024, 4413, 6963, 1168, 749, 872, 102] 38 | 2021-07-16 16:55:52,299 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 39 | 2021-07-16 16:55:52,299 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 40 | 2021-07-16 16:55:52,299 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}] 41 | 2021-07-16 16:56:29,964 - INFO - preprocess.py - convert_examples_to_features - 121 - Build 5000 features 42 | 2021-07-16 16:56:30,087 - INFO - preprocess.py - convert_examples_to_features - 106 - Convert 10000 examples to features 43 | 2021-07-16 16:56:30,119 - INFO - preprocess.py - convert_bert_example - 84 - *** test_example-0 *** 44 | 2021-07-16 16:56:30,119 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 鲍 勃 库 西 奖 归 谁 属 ? ncaa 最 强 控 卫 是 坎 巴 还 是 弗 神 新 浪 体 育 讯 如 今 , 本 赛 季 的 ncaa 进 入 到 了 末 段 , 各 项 奖 项 的 评 选 结 果 也 即 将 出 炉 , 其 中 评 选 最 佳 控 卫 的 鲍 勃 - 库 西 奖 就 将 在 下 周 最 终 四 强 战 时 公 布 , 鲍 勃 - 库 西 奖 是 由 奈 史 密 斯 篮 球 名 人 堂 提 供 , 旨 在 奖 励 年 度 最 佳 大 学 控 卫 。 最 终 获 奖 的 球 员 也 即 将 在 以 下 几 名 热 门 人 选 中 产 生 。 〈 〈 〈 ncaa 疯 狂 三 月 专 题 主 页 上 线 , 点 击 链 接 查 看 精 彩 内 容 吉 梅 尔 - 弗 雷 戴 特 , 杨 百 翰 大 学 [UNK] 弗 神 [UNK] 吉 梅 尔 - 弗 雷 戴 特 一 直 都 备 受 关 注 , 他 不 仅 仅 是 一 名 射 手 , 他 会 用 [UNK] 终 结 对 手 脚 踝 [UNK] 一 样 的 变 向 过 掉 面 前 的 防 守 者 , 并 且 他 可 以 用 任 意 一 支 手 完 成 得 分 , [SEP] 45 | 2021-07-16 16:56:30,119 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 7828, 1234, 2417, 6205, 1946, 2495, 6443, 2247, 8043, 12394, 10226, 3297, 2487, 2971, 1310, 3221, 1775, 2349, 6820, 3221, 2472, 4868, 3173, 3857, 860, 5509, 6380, 1963, 791, 8024, 3315, 6612, 2108, 4638, 12394, 10226, 6822, 1057, 1168, 749, 3314, 3667, 8024, 1392, 7555, 1946, 7555, 4638, 6397, 6848, 5310, 3362, 738, 1315, 2199, 1139, 4140, 8024, 1071, 704, 6397, 6848, 3297, 881, 2971, 1310, 4638, 7828, 1234, 118, 2417, 6205, 1946, 2218, 2199, 1762, 678, 1453, 3297, 5303, 1724, 2487, 2773, 3198, 1062, 2357, 8024, 7828, 1234, 118, 2417, 6205, 1946, 3221, 4507, 1937, 1380, 2166, 3172, 5074, 4413, 1399, 782, 1828, 2990, 897, 8024, 3192, 1762, 1946, 1225, 2399, 2428, 3297, 881, 1920, 2110, 2971, 1310, 511, 3297, 5303, 5815, 1946, 4638, 4413, 1447, 738, 1315, 2199, 1762, 809, 678, 1126, 1399, 4178, 7305, 782, 6848, 704, 772, 4495, 511, 515, 515, 515, 12394, 10226, 4556, 4312, 676, 3299, 683, 7579, 712, 7552, 677, 5296, 8024, 4157, 1140, 7216, 2970, 3389, 4692, 5125, 2506, 1079, 2159, 1395, 3449, 2209, 118, 2472, 7440, 2785, 4294, 8024, 3342, 4636, 5432, 1920, 2110, 100, 2472, 4868, 100, 1395, 3449, 2209, 118, 2472, 7440, 2785, 4294, 671, 4684, 6963, 1906, 1358, 1068, 3800, 8024, 800, 679, 788, 788, 3221, 671, 1399, 2198, 2797, 8024, 800, 833, 4500, 100, 5303, 5310, 2190, 2797, 5558, 6674, 100, 671, 3416, 4638, 1359, 1403, 6814, 2957, 7481, 1184, 4638, 7344, 2127, 5442, 8024, 2400, 684, 800, 1377, 809, 4500, 818, 2692, 671, 3118, 2797, 2130, 2768, 2533, 1146, 8024, 102] 46 | 2021-07-16 16:56:30,119 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 47 | 2021-07-16 16:56:30,119 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 48 | 2021-07-16 16:56:30,119 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}] 49 | 2021-07-16 16:56:30,134 - INFO - preprocess.py - convert_bert_example - 84 - *** test_example-1 *** 50 | 2021-07-16 16:56:30,134 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 麦 基 砍 28 + 18 + 5 却 充 满 寂 寞 纪 录 之 夜 他 的 痛 阿 联 最 懂 新 浪 体 育 讯 上 天 对 每 个 人 都 是 公 平 的 , 贾 维 尔 - 麦 基 也 不 例 外 。 今 天 华 盛 顿 奇 才 客 场 104 - 114 负 于 金 州 勇 士 , 麦 基 就 好 不 容 易 等 到 [UNK] 捏 软 柿 子 [UNK] 的 机 会 , 上 半 场 打 出 现 象 级 表 现 , 只 可 惜 无 法 一 以 贯 之 。 最 终 , 麦 基 12 投 9 中 , 得 到 生 涯 最 高 的 28 分 , 以 及 平 生 涯 最 佳 的 18 个 篮 板 , 另 有 5 次 封 盖 。 此 外 , 他 11 次 罚 球 命 中 10 个 , 这 两 项 也 均 为 生 涯 最 高 。 如 果 在 赛 前 搞 个 竞 猜 , 上 半 场 谁 会 是 奇 才 阵 中 罚 球 次 数 最 多 的 球 员 ? 若 有 人 答 曰 [UNK] 麦 基 [UNK] , 不 是 恶 搞 就 是 脑 残 。 但 半 场 结 束 , 麦 基 竟 砍 下 22 分 ( 第 二 节 砍 下 14 分 ) 。 更 罕 见 的 , 则 是 [SEP] 51 | 2021-07-16 16:56:30,134 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 7931, 1825, 4775, 8143, 116, 8123, 116, 126, 1316, 1041, 4007, 2163, 2174, 5279, 2497, 722, 1915, 800, 4638, 4578, 7350, 5468, 3297, 2743, 3173, 3857, 860, 5509, 6380, 677, 1921, 2190, 3680, 702, 782, 6963, 3221, 1062, 2398, 4638, 8024, 6593, 5335, 2209, 118, 7931, 1825, 738, 679, 891, 1912, 511, 791, 1921, 1290, 4670, 7561, 1936, 2798, 2145, 1767, 8503, 118, 8866, 6566, 754, 7032, 2336, 1235, 1894, 8024, 7931, 1825, 2218, 1962, 679, 2159, 3211, 5023, 1168, 100, 2934, 6763, 3398, 2094, 100, 4638, 3322, 833, 8024, 677, 1288, 1767, 2802, 1139, 4385, 6496, 5277, 6134, 4385, 8024, 1372, 1377, 2667, 3187, 3791, 671, 809, 6581, 722, 511, 3297, 5303, 8024, 7931, 1825, 8110, 2832, 130, 704, 8024, 2533, 1168, 4495, 3889, 3297, 7770, 4638, 8143, 1146, 8024, 809, 1350, 2398, 4495, 3889, 3297, 881, 4638, 8123, 702, 5074, 3352, 8024, 1369, 3300, 126, 3613, 2196, 4667, 511, 3634, 1912, 8024, 800, 8111, 3613, 5385, 4413, 1462, 704, 8108, 702, 8024, 6821, 697, 7555, 738, 1772, 711, 4495, 3889, 3297, 7770, 511, 1963, 3362, 1762, 6612, 1184, 3018, 702, 4993, 4339, 8024, 677, 1288, 1767, 6443, 833, 3221, 1936, 2798, 7347, 704, 5385, 4413, 3613, 3144, 3297, 1914, 4638, 4413, 1447, 8043, 5735, 3300, 782, 5031, 3288, 100, 7931, 1825, 100, 8024, 679, 3221, 2626, 3018, 2218, 3221, 5554, 3655, 511, 852, 1288, 1767, 5310, 3338, 8024, 7931, 1825, 4994, 4775, 678, 8130, 1146, 113, 5018, 753, 5688, 4775, 678, 8122, 1146, 114, 511, 3291, 5383, 6224, 4638, 8024, 1156, 3221, 102] 52 | 2021-07-16 16:56:30,135 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 53 | 2021-07-16 16:56:30,135 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 54 | 2021-07-16 16:56:30,135 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}] 55 | 2021-07-16 16:56:30,138 - INFO - preprocess.py - convert_bert_example - 84 - *** test_example-2 *** 56 | 2021-07-16 16:56:30,139 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 黄 蜂 vs 湖 人 首 发 : 科 比 冲 击 七 连 胜 火 箭 两 旧 将 登 场 新 浪 体 育 讯 北 京 时 间 3 月 28 日 , nba 常 规 赛 洛 杉 矶 湖 人 主 场 迎 战 新 奥 尔 良 黄 蜂 , 赛 前 双 方 也 公 布 了 首 发 阵 容 : 点 击 进 入 新 浪 体 育 视 频 直 播 室 点 击 进 入 新 浪 体 育 图 文 直 播 室 点 击 进 入 新 浪 体 育 nba 专 题 点 击 进 入 新 浪 nba 官 方 微 博 双 方 首 发 阵 容 : 湖 人 队 : 德 里 克 - 费 舍 尔 、 科 比 - 布 莱 恩 特 、 罗 恩 - 阿 泰 斯 特 、 保 罗 - 加 索 尔 、 安 德 鲁 - 拜 纳 姆 黄 蜂 队 : 克 里 斯 - 保 罗 、 马 科 - 贝 里 内 利 、 特 雷 沃 - 阿 里 扎 、 卡 尔 - 兰 德 里 、 埃 梅 卡 - 奥 卡 福 ( 新 浪 体 育 ) [SEP] 57 | 2021-07-16 16:56:30,139 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 7942, 6044, 8349, 3959, 782, 7674, 1355, 8038, 4906, 3683, 1103, 1140, 673, 6825, 5526, 4125, 5055, 697, 3191, 2199, 4633, 1767, 3173, 3857, 860, 5509, 6380, 1266, 776, 3198, 7313, 124, 3299, 8143, 3189, 8024, 8391, 2382, 6226, 6612, 3821, 3329, 4768, 3959, 782, 712, 1767, 6816, 2773, 3173, 1952, 2209, 5679, 7942, 6044, 8024, 6612, 1184, 1352, 3175, 738, 1062, 2357, 749, 7674, 1355, 7347, 2159, 8038, 4157, 1140, 6822, 1057, 3173, 3857, 860, 5509, 6228, 7574, 4684, 3064, 2147, 4157, 1140, 6822, 1057, 3173, 3857, 860, 5509, 1745, 3152, 4684, 3064, 2147, 4157, 1140, 6822, 1057, 3173, 3857, 860, 5509, 8391, 683, 7579, 4157, 1140, 6822, 1057, 3173, 3857, 8391, 2135, 3175, 2544, 1300, 1352, 3175, 7674, 1355, 7347, 2159, 8038, 3959, 782, 7339, 8038, 2548, 7027, 1046, 118, 6589, 5650, 2209, 510, 4906, 3683, 118, 2357, 5812, 2617, 4294, 510, 5384, 2617, 118, 7350, 3805, 3172, 4294, 510, 924, 5384, 118, 1217, 5164, 2209, 510, 2128, 2548, 7826, 118, 2876, 5287, 1990, 7942, 6044, 7339, 8038, 1046, 7027, 3172, 118, 924, 5384, 510, 7716, 4906, 118, 6564, 7027, 1079, 1164, 510, 4294, 7440, 3753, 118, 7350, 7027, 2799, 510, 1305, 2209, 118, 1065, 2548, 7027, 510, 1812, 3449, 1305, 118, 1952, 1305, 4886, 113, 3173, 3857, 860, 5509, 114, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 58 | 2021-07-16 16:56:30,139 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 59 | 2021-07-16 16:56:30,139 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 60 | 2021-07-16 16:56:30,139 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}] 61 | 2021-07-16 16:58:00,802 - INFO - preprocess.py - convert_examples_to_features - 121 - Build 10000 features 62 | 2021-07-16 17:01:09,034 - INFO - preprocess.py - - 150 - {'output_dir': '../checkpoints/', 'bert_dir': '../model_hub/bert-base-chinese/', 'data_dir': '../data/cnews/', 'log_dir': './logs/', 'num_tags': 65, 'seed': 123, 'gpu_ids': '0', 'max_seq_len': 256, 'eval_batch_size': 12, 'swa_start': 3, 'train_epochs': 15, 'dropout_prob': 0.1, 'lr': 3e-05, 'other_lr': 0.0003, 'max_grad_norm': 1, 'warmup_proportion': 0.1, 'weight_decay': 0.01, 'adam_epsilon': 1e-08, 'train_batch_size': 32, 'eval_model': True} 63 | 2021-07-16 17:01:09,505 - INFO - preprocess.py - convert_examples_to_features - 106 - Convert 50000 examples to features 64 | 2021-07-16 17:01:09,512 - INFO - preprocess.py - convert_bert_example - 84 - *** train_example-0 *** 65 | 2021-07-16 17:01:09,512 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 马 晓 旭 意 外 受 伤 让 国 奥 警 惕 无 奈 大 雨 格 外 青 睐 殷 家 军 记 者 傅 亚 雨 沈 阳 报 道 来 到 沈 阳 , 国 奥 队 依 然 没 有 摆 脱 雨 水 的 困 扰 。 7 月 31 日 下 午 6 点 , 国 奥 队 的 日 常 训 练 再 度 受 到 大 雨 的 干 扰 , 无 奈 之 下 队 员 们 只 慢 跑 了 25 分 钟 就 草 草 收 场 。 31 日 上 午 10 点 , 国 奥 队 在 奥 体 中 心 外 场 训 练 的 时 候 , 天 就 是 阴 沉 沉 的 , 气 象 预 报 显 示 当 天 下 午 沈 阳 就 有 大 雨 , 但 幸 好 队 伍 上 午 的 训 练 并 没 有 受 到 任 何 干 扰 。 下 午 6 点 , 当 球 队 抵 达 训 练 场 时 , 大 雨 已 经 下 了 几 个 小 时 , 而 且 丝 毫 没 有 停 下 来 的 意 思 。 抱 着 试 一 试 的 态 度 , 球 队 开 始 了 当 天 下 午 的 例 行 训 练 , 25 分 钟 过 去 了 , 天 气 没 有 任 何 转 好 的 迹 象 , 为 了 保 护 球 [SEP] 66 | 2021-07-16 17:01:09,513 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 7716, 3236, 3195, 2692, 1912, 1358, 839, 6375, 1744, 1952, 6356, 2664, 3187, 1937, 1920, 7433, 3419, 1912, 7471, 4712, 3668, 2157, 1092, 6381, 5442, 987, 762, 7433, 3755, 7345, 2845, 6887, 3341, 1168, 3755, 7345, 8024, 1744, 1952, 7339, 898, 4197, 3766, 3300, 3030, 5564, 7433, 3717, 4638, 1737, 2817, 511, 128, 3299, 8176, 3189, 678, 1286, 127, 4157, 8024, 1744, 1952, 7339, 4638, 3189, 2382, 6378, 5298, 1086, 2428, 1358, 1168, 1920, 7433, 4638, 2397, 2817, 8024, 3187, 1937, 722, 678, 7339, 1447, 812, 1372, 2714, 6651, 749, 8132, 1146, 7164, 2218, 5770, 5770, 3119, 1767, 511, 8176, 3189, 677, 1286, 8108, 4157, 8024, 1744, 1952, 7339, 1762, 1952, 860, 704, 2552, 1912, 1767, 6378, 5298, 4638, 3198, 952, 8024, 1921, 2218, 3221, 7346, 3756, 3756, 4638, 8024, 3698, 6496, 7564, 2845, 3227, 4850, 2496, 1921, 678, 1286, 3755, 7345, 2218, 3300, 1920, 7433, 8024, 852, 2401, 1962, 7339, 824, 677, 1286, 4638, 6378, 5298, 2400, 3766, 3300, 1358, 1168, 818, 862, 2397, 2817, 511, 678, 1286, 127, 4157, 8024, 2496, 4413, 7339, 2850, 6809, 6378, 5298, 1767, 3198, 8024, 1920, 7433, 2347, 5307, 678, 749, 1126, 702, 2207, 3198, 8024, 5445, 684, 692, 3690, 3766, 3300, 977, 678, 3341, 4638, 2692, 2590, 511, 2849, 4708, 6407, 671, 6407, 4638, 2578, 2428, 8024, 4413, 7339, 2458, 1993, 749, 2496, 1921, 678, 1286, 4638, 891, 6121, 6378, 5298, 8024, 8132, 1146, 7164, 6814, 1343, 749, 8024, 1921, 3698, 3766, 3300, 818, 862, 6760, 1962, 4638, 6839, 6496, 8024, 711, 749, 924, 2844, 4413, 102] 67 | 2021-07-16 17:01:09,513 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 68 | 2021-07-16 17:01:09,513 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 69 | 2021-07-16 17:01:09,513 - INFO - preprocess.py - convert_bert_example - 89 - labels: [6] 70 | 2021-07-16 17:01:09,528 - INFO - preprocess.py - convert_bert_example - 84 - *** train_example-1 *** 71 | 2021-07-16 17:01:09,528 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 商 瑞 华 首 战 复 仇 心 切 中 国 玫 瑰 要 用 美 国 方 式 攻 克 瑞 典 多 曼 来 了 , 瑞 典 来 了 , 商 瑞 华 首 战 求 3 分 的 信 心 也 来 了 。 距 离 首 战 72 小 时 当 口 , 中 国 女 足 彻 底 从 [UNK] 恐 瑞 症 [UNK] 当 中 获 得 解 脱 , 因 为 商 瑞 华 已 经 找 到 了 瑞 典 人 的 软 肋 。 找 到 软 肋 , 保 密 4 月 20 日 奥 运 会 分 组 抽 签 结 果 出 来 后 , 中 国 姑 娘 就 把 瑞 典 锁 定 为 关 乎 奥 运 成 败 的 头 号 劲 敌 , 因 为 除 了 浦 玮 等 个 别 老 将 之 外 , 现 役 女 足 将 士 竟 然 没 有 人 尝 过 击 败 瑞 典 的 滋 味 。 在 中 瑞 两 队 共 计 15 次 交 锋 的 历 史 上 , 中 国 队 6 胜 3 平 6 负 与 瑞 典 队 平 分 秋 色 , 但 从 2001 年 起 至 今 近 8 年 时 间 , 中 国 在 同 瑞 典 连 续 5 次 交 锋 中 均 未 尝 胜 绩 , 战 绩 为 2 平 3 负 。 尽 管 八 年 [SEP] 72 | 2021-07-16 17:01:09,528 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 1555, 4448, 1290, 7674, 2773, 1908, 790, 2552, 1147, 704, 1744, 4382, 4456, 6206, 4500, 5401, 1744, 3175, 2466, 3122, 1046, 4448, 1073, 1914, 3294, 3341, 749, 8024, 4448, 1073, 3341, 749, 8024, 1555, 4448, 1290, 7674, 2773, 3724, 124, 1146, 4638, 928, 2552, 738, 3341, 749, 511, 6655, 4895, 7674, 2773, 8325, 2207, 3198, 2496, 1366, 8024, 704, 1744, 1957, 6639, 2515, 2419, 794, 100, 2607, 4448, 4568, 100, 2496, 704, 5815, 2533, 6237, 5564, 8024, 1728, 711, 1555, 4448, 1290, 2347, 5307, 2823, 1168, 749, 4448, 1073, 782, 4638, 6763, 5490, 511, 2823, 1168, 6763, 5490, 8024, 924, 2166, 125, 3299, 8113, 3189, 1952, 6817, 833, 1146, 5299, 2853, 5041, 5310, 3362, 1139, 3341, 1400, 8024, 704, 1744, 1996, 2023, 2218, 2828, 4448, 1073, 7219, 2137, 711, 1068, 725, 1952, 6817, 2768, 6571, 4638, 1928, 1384, 1226, 3127, 8024, 1728, 711, 7370, 749, 3855, 4383, 5023, 702, 1166, 5439, 2199, 722, 1912, 8024, 4385, 2514, 1957, 6639, 2199, 1894, 4994, 4197, 3766, 3300, 782, 2214, 6814, 1140, 6571, 4448, 1073, 4638, 3996, 1456, 511, 1762, 704, 4448, 697, 7339, 1066, 6369, 8115, 3613, 769, 7226, 4638, 1325, 1380, 677, 8024, 704, 1744, 7339, 127, 5526, 124, 2398, 127, 6566, 680, 4448, 1073, 7339, 2398, 1146, 4904, 5682, 8024, 852, 794, 8285, 2399, 6629, 5635, 791, 6818, 129, 2399, 3198, 7313, 8024, 704, 1744, 1762, 1398, 4448, 1073, 6825, 5330, 126, 3613, 769, 7226, 704, 1772, 3313, 2214, 5526, 5327, 8024, 2773, 5327, 711, 123, 2398, 124, 6566, 511, 2226, 5052, 1061, 2399, 102] 73 | 2021-07-16 17:01:09,528 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 74 | 2021-07-16 17:01:09,528 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 75 | 2021-07-16 17:01:09,528 - INFO - preprocess.py - convert_bert_example - 89 - labels: [6] 76 | 2021-07-16 17:01:09,547 - INFO - preprocess.py - convert_bert_example - 84 - *** train_example-2 *** 77 | 2021-07-16 17:01:09,547 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 冠 军 球 队 迎 新 欢 乐 派 对 黄 旭 获 大 奖 张 军 赢 下 pk 赛 新 浪 体 育 讯 12 月 27 日 晚 , [UNK] 冠 军 高 尔 夫 球 队 迎 新 高 球 欢 乐 派 对 [UNK] 活 动 在 北 京 都 市 名 人 高 尔 夫 俱 乐 部 举 行 。 邢 傲 伟 、 黄 旭 、 邹 凯 、 滕 海 滨 、 杨 凌 、 马 燕 红 、 张 军 、 王 丽 萍 、 杨 影 、 阎 森 、 毕 文 静 、 夏 煊 泽 、 张 璇 、 叶 乔 波 、 莫 慧 兰 、 周 雅 菲 、 胡 妮 、 戴 菲 菲 、 马 健 等 奥 运 冠 军 、 世 界 冠 军 参 加 了 当 晚 的 活 动 。 [UNK] 以 球 会 友 、 联 谊 交 流 [UNK] 为 宗 旨 的 冠 军 高 尔 夫 球 队 , 从 创 立 之 初 就 秉 承 关 爱 社 会 、 回 馈 社 会 的 价 值 观 。 此 次 迎 新 派 对 活 动 的 举 行 , 标 志 着 这 个 为 中 国 金 牌 运 动 员 提 供 的 放 松 身 心 、 以 球 会 友 的 平 台 正 式 全 面 启 动 。 球 队 已 经 与 都 市 名 人 高 [SEP] 78 | 2021-07-16 17:01:09,547 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 1094, 1092, 4413, 7339, 6816, 3173, 3614, 727, 3836, 2190, 7942, 3195, 5815, 1920, 1946, 2476, 1092, 6617, 678, 8465, 6612, 3173, 3857, 860, 5509, 6380, 8110, 3299, 8149, 3189, 3241, 8024, 100, 1094, 1092, 7770, 2209, 1923, 4413, 7339, 6816, 3173, 7770, 4413, 3614, 727, 3836, 2190, 100, 3833, 1220, 1762, 1266, 776, 6963, 2356, 1399, 782, 7770, 2209, 1923, 936, 727, 6956, 715, 6121, 511, 6928, 1000, 836, 510, 7942, 3195, 510, 6941, 1132, 510, 4001, 3862, 4012, 510, 3342, 1119, 510, 7716, 4242, 5273, 510, 2476, 1092, 510, 4374, 714, 5847, 510, 3342, 2512, 510, 7330, 3481, 510, 3684, 3152, 7474, 510, 1909, 4201, 3813, 510, 2476, 4462, 510, 1383, 730, 3797, 510, 5811, 2716, 1065, 510, 1453, 7414, 5838, 510, 5529, 1984, 510, 2785, 5838, 5838, 510, 7716, 978, 5023, 1952, 6817, 1094, 1092, 510, 686, 4518, 1094, 1092, 1346, 1217, 749, 2496, 3241, 4638, 3833, 1220, 511, 100, 809, 4413, 833, 1351, 510, 5468, 6449, 769, 3837, 100, 711, 2134, 3192, 4638, 1094, 1092, 7770, 2209, 1923, 4413, 7339, 8024, 794, 1158, 4989, 722, 1159, 2218, 4903, 2824, 1068, 4263, 4852, 833, 510, 1726, 7668, 4852, 833, 4638, 817, 966, 6225, 511, 3634, 3613, 6816, 3173, 3836, 2190, 3833, 1220, 4638, 715, 6121, 8024, 3403, 2562, 4708, 6821, 702, 711, 704, 1744, 7032, 4277, 6817, 1220, 1447, 2990, 897, 4638, 3123, 3351, 6716, 2552, 510, 809, 4413, 833, 1351, 4638, 2398, 1378, 3633, 2466, 1059, 7481, 1423, 1220, 511, 4413, 7339, 2347, 5307, 680, 6963, 2356, 1399, 782, 7770, 102] 79 | 2021-07-16 17:01:09,547 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 80 | 2021-07-16 17:01:09,548 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 81 | 2021-07-16 17:01:09,548 - INFO - preprocess.py - convert_bert_example - 89 - labels: [6] 82 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append(r"./") 4 | """ 5 | 该文件使用的data_loader.py里面的数据加载方式。 6 | """ 7 | # coding=utf-8 8 | import json 9 | import random 10 | from pprint import pprint 11 | import os 12 | import logging 13 | import shutil 14 | from sklearn.metrics import accuracy_score, f1_score, classification_report 15 | import torch 16 | import torch.nn as nn 17 | import numpy as np 18 | import pickle 19 | from torch.utils.data import DataLoader, RandomSampler 20 | from transformers import BertTokenizer 21 | 22 | import bert_config 23 | import models 24 | from utils import utils 25 | from data_loader import CNEWSDataset, Collate, CPWSDataset 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | class Trainer: 31 | def __init__(self, args, train_loader, dev_loader, test_loader, device, model, optimizer): 32 | self.args = args 33 | self.device = device 34 | self.model = model 35 | self.optimizer = optimizer 36 | self.criterion = nn.CrossEntropyLoss() 37 | self.train_loader = train_loader 38 | self.dev_loader = dev_loader 39 | self.test_loader = test_loader 40 | self.model.to(self.device) 41 | 42 | def load_ckp(self, model, optimizer, checkpoint_path): 43 | checkpoint = torch.load(checkpoint_path) 44 | model.load_state_dict(checkpoint['state_dict']) 45 | optimizer.load_state_dict(checkpoint['optimizer']) 46 | epoch = checkpoint['epoch'] 47 | loss = checkpoint['loss'] 48 | return model, optimizer, epoch, loss 49 | 50 | def load_model(self, model, checkpoint_path): 51 | checkpoint = torch.load(checkpoint_path) 52 | model.load_state_dict(checkpoint['state_dict']) 53 | return model 54 | 55 | def save_ckp(self, state, checkpoint_path): 56 | torch.save(state, checkpoint_path) 57 | 58 | """ 59 | def save_ckp(self, state, is_best, checkpoint_path, best_model_path): 60 | tmp_checkpoint_path = checkpoint_path 61 | torch.save(state, tmp_checkpoint_path) 62 | if is_best: 63 | tmp_best_model_path = best_model_path 64 | shutil.copyfile(tmp_checkpoint_path, tmp_best_model_path) 65 | """ 66 | 67 | def train(self): 68 | total_step = len(self.train_loader) * self.args.train_epochs 69 | global_step = 0 70 | eval_step = 100 71 | best_dev_micro_f1 = 0.0 72 | for epoch in range(args.train_epochs): 73 | for train_step, train_data in enumerate(self.train_loader): 74 | self.model.train() 75 | token_ids = train_data['token_ids'].to(self.device) 76 | attention_masks = train_data['attention_masks'].to(self.device) 77 | token_type_ids = train_data['token_type_ids'].to(self.device) 78 | labels = train_data['labels'].to(self.device) 79 | train_outputs = self.model(token_ids, attention_masks, token_type_ids) 80 | loss = self.criterion(train_outputs, labels) 81 | self.optimizer.zero_grad() 82 | loss.backward() 83 | self.optimizer.step() 84 | logger.info( 85 | "【train】 epoch:{} step:{}/{} loss:{:.6f}".format(epoch, global_step, total_step, loss.item())) 86 | global_step += 1 87 | if global_step % eval_step == 0: 88 | dev_loss, dev_outputs, dev_targets = self.dev() 89 | accuracy, micro_f1, macro_f1 = self.get_metrics(dev_outputs, dev_targets) 90 | logger.info( 91 | "【dev】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(dev_loss, accuracy, 92 | micro_f1, macro_f1)) 93 | if macro_f1 > best_dev_micro_f1: 94 | logger.info("------------>保存当前最好的模型") 95 | checkpoint = { 96 | 'epoch': epoch, 97 | 'loss': dev_loss, 98 | 'state_dict': self.model.state_dict(), 99 | 'optimizer': self.optimizer.state_dict(), 100 | } 101 | best_dev_micro_f1 = macro_f1 102 | save_path = os.path.join(self.args.output_dir, args.data_name) 103 | if not os.path.exists(save_path): 104 | os.makedirs(save_path) 105 | checkpoint_path = os.path.join(save_path, 'best.pt') 106 | self.save_ckp(checkpoint, checkpoint_path) 107 | 108 | def dev(self): 109 | self.model.eval() 110 | total_loss = 0.0 111 | dev_outputs = [] 112 | dev_targets = [] 113 | with torch.no_grad(): 114 | for dev_step, dev_data in enumerate(self.dev_loader): 115 | token_ids = dev_data['token_ids'].to(self.device) 116 | attention_masks = dev_data['attention_masks'].to(self.device) 117 | token_type_ids = dev_data['token_type_ids'].to(self.device) 118 | labels = dev_data['labels'].to(self.device) 119 | outputs = self.model(token_ids, attention_masks, token_type_ids) 120 | loss = self.criterion(outputs, labels) 121 | # val_loss = val_loss + ((1 / (dev_step + 1))) * (loss.item() - val_loss) 122 | total_loss += loss.item() 123 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten() 124 | dev_outputs.extend(outputs.tolist()) 125 | dev_targets.extend(labels.cpu().detach().numpy().tolist()) 126 | 127 | return total_loss, dev_outputs, dev_targets 128 | 129 | def test(self, model): 130 | model.eval() 131 | model.to(self.device) 132 | total_loss = 0.0 133 | test_outputs = [] 134 | test_targets = [] 135 | with torch.no_grad(): 136 | for test_step, test_data in enumerate(self.test_loader): 137 | token_ids = test_data['token_ids'].to(self.device) 138 | attention_masks = test_data['attention_masks'].to(self.device) 139 | token_type_ids = test_data['token_type_ids'].to(self.device) 140 | labels = test_data['labels'].to(self.device) 141 | outputs = model(token_ids, attention_masks, token_type_ids) 142 | loss = self.criterion(outputs, labels) 143 | # val_loss = val_loss + ((1 / (dev_step + 1))) * (loss.item() - val_loss) 144 | total_loss += loss.item() 145 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten() 146 | test_outputs.extend(outputs.tolist()) 147 | test_targets.extend(labels.cpu().detach().numpy().tolist()) 148 | 149 | return total_loss, test_outputs, test_targets 150 | 151 | def predict(self, tokenizer, text, id2label, args, model): 152 | model.eval() 153 | model.to(self.device) 154 | with torch.no_grad(): 155 | inputs = tokenizer.encode_plus(text=text, 156 | add_special_tokens=True, 157 | max_length=args.max_seq_len, 158 | truncation='longest_first', 159 | padding="max_length", 160 | return_token_type_ids=True, 161 | return_attention_mask=True, 162 | return_tensors='pt') 163 | token_ids = inputs['input_ids'].to(self.device) 164 | attention_masks = inputs['attention_mask'].to(self.device) 165 | token_type_ids = inputs['token_type_ids'].to(self.device) 166 | outputs = model(token_ids, attention_masks, token_type_ids) 167 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten().tolist() 168 | if len(outputs) != 0: 169 | outputs = [id2label[i] for i in outputs] 170 | return outputs 171 | else: 172 | return '不好意思,我没有识别出来' 173 | 174 | def get_metrics(self, outputs, targets): 175 | accuracy = accuracy_score(targets, outputs) 176 | micro_f1 = f1_score(targets, outputs, average='micro') 177 | macro_f1 = f1_score(targets, outputs, average='macro') 178 | return accuracy, micro_f1, macro_f1 179 | 180 | def get_classification_report(self, outputs, targets, labels): 181 | report = classification_report(targets, outputs, target_names=labels) 182 | return report 183 | 184 | 185 | datasets = { 186 | "cnews": CNEWSDataset, 187 | "cpws": CPWSDataset 188 | } 189 | 190 | train_files = { 191 | "cnews": "cnews.train.txt", 192 | "cpws": "train_data.txt" 193 | } 194 | 195 | test_files = { 196 | "cnews": "cnews.test.txt", 197 | "cpws": "test_data.txt" 198 | } 199 | 200 | 201 | def main(args, tokenizer, device): 202 | dataset = datasets.get(args.data_name, None) 203 | train_file, test_file = train_files.get(args.data_name, None), test_files.get(args.data_name, None) 204 | if dataset is None: 205 | raise Exception("请输入正确的数据集名称") 206 | label2id = {} 207 | id2label = {} 208 | with open('./data/{}/labels.txt'.format(args.data_name), 'r', encoding="utf-8") as fp: 209 | labels = fp.read().strip().split('\n') 210 | for i, label in enumerate(labels): 211 | label2id[label] = i 212 | id2label[i] = label 213 | print(label2id) 214 | 215 | collate = Collate(tokenizer=tokenizer, max_len=args.max_seq_len, tag2id=label2id) 216 | 217 | train_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, train_file)) 218 | train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, 219 | collate_fn=collate.collate_fn) 220 | test_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, test_file)) 221 | test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, 222 | collate_fn=collate.collate_fn) 223 | 224 | model = models.BertForSequenceClassification(args) 225 | optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr) 226 | 227 | if args.retrain: 228 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name) 229 | checkpoint = torch.load(checkpoint_path) 230 | model.load_state_dict(checkpoint['state_dict']) 231 | model.to(device) 232 | optimizer.load_state_dict(checkpoint['optimizer']) 233 | 234 | epoch = checkpoint['epoch'] 235 | loss = checkpoint['loss'] 236 | logger.info("加载模型继续训练,epoch:{} loss:{}".format(epoch, loss)) 237 | 238 | trainer = Trainer(args, train_loader, test_loader, test_loader, device, model, optimizer) 239 | 240 | if args.do_train: 241 | # 训练和验证 242 | trainer.train() 243 | 244 | # 测试 245 | if args.do_test: 246 | logger.info('========进行测试========') 247 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name) 248 | model = trainer.load_model(model, checkpoint_path) 249 | total_loss, test_outputs, test_targets = trainer.test(model) 250 | accuracy, micro_f1, macro_f1 = trainer.get_metrics(test_outputs, test_targets) 251 | logger.info( 252 | "【test】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(total_loss, accuracy, micro_f1, 253 | macro_f1)) 254 | report = trainer.get_classification_report(test_outputs, test_targets, labels) 255 | logger.info(report) 256 | 257 | # 预测 258 | if args.do_predict: 259 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name) 260 | model = trainer.load_model(model, checkpoint_path) 261 | line = test_dataset[0] 262 | text = line[0] 263 | print(text) 264 | result = trainer.predict(tokenizer, text, id2label, args, model) 265 | print("预测标签:", result[0]) 266 | print("真实标签:", line[1]) 267 | print("==========================") 268 | 269 | 270 | if __name__ == '__main__': 271 | args = bert_config.Args().get_parser() 272 | utils.set_seed(args.seed) 273 | utils.set_logger(os.path.join(args.log_dir, 'main.log')) 274 | 275 | # processor = preprocess.Processor() 276 | 277 | tokenizer = BertTokenizer.from_pretrained(args.bert_dir) 278 | gpu_ids = args.gpu_ids.split(',') 279 | device = torch.device("cpu" if gpu_ids[0] == '-1' else "cuda:" + gpu_ids[0]) 280 | 281 | main(args, tokenizer, device) 282 | -------------------------------------------------------------------------------- /main_dataparallel.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append(r"./") 4 | """ 5 | 该文件使用的data_loader.py里面的数据加载方式。 6 | """ 7 | # coding=utf-8 8 | import json 9 | import random 10 | from pprint import pprint 11 | import os 12 | import logging 13 | import shutil 14 | from sklearn.metrics import accuracy_score, f1_score, classification_report 15 | import torch 16 | import torch.nn as nn 17 | import numpy as np 18 | import pickle 19 | from torch.utils.data import DataLoader, RandomSampler 20 | from transformers import BertTokenizer 21 | 22 | import bert_config 23 | import models 24 | from utils import utils 25 | from data_loader import CNEWSDataset, Collate, CPWSDataset 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | args = bert_config.Args().get_parser() 30 | utils.set_seed(args.seed) 31 | utils.set_logger(os.path.join(args.log_dir, 'main.log')) 32 | 33 | gpu_ids = args.gpu_ids.split(',') 34 | gpu_ids = [int(i) for i in gpu_ids] 35 | torch.cuda.set_device('cuda:{}'.format(gpu_ids[0])) 36 | 37 | 38 | class Trainer: 39 | def __init__(self, args, optimizer): 40 | self.args = args 41 | self.optimizer = optimizer 42 | self.criterion = nn.CrossEntropyLoss() 43 | 44 | def load_ckp(self, model, optimizer, checkpoint_path): 45 | checkpoint = torch.load(checkpoint_path) 46 | model.load_state_dict(checkpoint['state_dict']) 47 | optimizer.load_state_dict(checkpoint['optimizer']) 48 | epoch = checkpoint['epoch'] 49 | loss = checkpoint['loss'] 50 | return model, optimizer, epoch, loss 51 | 52 | def load_model(self, model, checkpoint_path): 53 | checkpoint = torch.load(checkpoint_path) 54 | # new_start_dict = {} 55 | # for k, v in checkpoint['state_dict'].items(): 56 | # new_start_dict["module." + k] = v 57 | # model.load_state_dict(new_start_dict) 58 | model.load_state_dict(checkpoint["state_dict"]) 59 | return model 60 | 61 | def save_ckp(self, state, checkpoint_path): 62 | torch.save(state, checkpoint_path) 63 | 64 | """ 65 | def save_ckp(self, state, is_best, checkpoint_path, best_model_path): 66 | tmp_checkpoint_path = checkpoint_path 67 | torch.save(state, tmp_checkpoint_path) 68 | if is_best: 69 | tmp_best_model_path = best_model_path 70 | shutil.copyfile(tmp_checkpoint_path, tmp_best_model_path) 71 | """ 72 | 73 | def train(self, model, train_loader, dev_loader=None): 74 | self.dev_loader = dev_loader 75 | self.model = model 76 | total_step = len(train_loader) * self.args.train_epochs 77 | global_step = 0 78 | eval_step = 100 79 | best_dev_micro_f1 = 0.0 80 | for epoch in range(self.args.train_epochs): 81 | for train_step, train_data in enumerate(train_loader): 82 | self.model.train() 83 | token_ids = train_data['token_ids'].cuda() 84 | attention_masks = train_data['attention_masks'].cuda() 85 | token_type_ids = train_data['token_type_ids'].cuda() 86 | labels = train_data['labels'].cuda() 87 | train_outputs = self.model(token_ids, attention_masks, token_type_ids) 88 | loss = self.criterion(train_outputs, labels) 89 | self.optimizer.zero_grad() 90 | loss.backward() 91 | self.optimizer.step() 92 | logger.info( 93 | "【train】 epoch:{} step:{}/{} loss:{:.6f}".format(epoch, global_step, total_step, loss.item())) 94 | global_step += 1 95 | if global_step % eval_step == 0: 96 | dev_loss, dev_outputs, dev_targets = self.dev() 97 | accuracy, micro_f1, macro_f1 = self.get_metrics(dev_outputs, dev_targets) 98 | logger.info( 99 | "【dev】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(dev_loss, accuracy, 100 | micro_f1, macro_f1)) 101 | if macro_f1 > best_dev_micro_f1: 102 | logger.info("------------>保存当前最好的模型") 103 | checkpoint = { 104 | 'epoch': epoch, 105 | 'loss': dev_loss, 106 | 'state_dict': self.model.state_dict(), 107 | 'optimizer': self.optimizer.state_dict(), 108 | } 109 | best_dev_micro_f1 = macro_f1 110 | save_path = os.path.join(self.args.output_dir, args.data_name) 111 | if not os.path.exists(save_path): 112 | os.makedirs(save_path) 113 | checkpoint_path = os.path.join(save_path, 'best.pt') 114 | self.save_ckp(checkpoint, checkpoint_path) 115 | 116 | def dev(self): 117 | self.model.eval() 118 | total_loss = 0.0 119 | dev_outputs = [] 120 | dev_targets = [] 121 | with torch.no_grad(): 122 | for dev_step, dev_data in enumerate(self.dev_loader): 123 | token_ids = dev_data['token_ids'].cuda() 124 | attention_masks = dev_data['attention_masks'].cuda() 125 | token_type_ids = dev_data['token_type_ids'].cuda() 126 | labels = dev_data['labels'].cuda() 127 | outputs = self.model(token_ids, attention_masks, token_type_ids) 128 | loss = self.criterion(outputs, labels) 129 | # val_loss = val_loss + ((1 / (dev_step + 1))) * (loss.item() - val_loss) 130 | total_loss += loss.item() 131 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten() 132 | dev_outputs.extend(outputs.tolist()) 133 | dev_targets.extend(labels.cpu().detach().numpy().tolist()) 134 | 135 | return total_loss, dev_outputs, dev_targets 136 | 137 | def test(self, model, test_loader): 138 | model.eval() 139 | total_loss = 0.0 140 | test_outputs = [] 141 | test_targets = [] 142 | with torch.no_grad(): 143 | for test_step, test_data in enumerate(test_loader): 144 | token_ids = test_data['token_ids'].cuda() 145 | attention_masks = test_data['attention_masks'].cuda() 146 | token_type_ids = test_data['token_type_ids'].cuda() 147 | labels = test_data['labels'].cuda() 148 | outputs = model(token_ids, attention_masks, token_type_ids) 149 | 150 | loss = self.criterion(outputs, labels) 151 | # val_loss = val_loss + ((1 / (dev_step + 1))) * (loss.item() - val_loss) 152 | total_loss += loss.item() 153 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten() 154 | test_outputs.extend(outputs.tolist()) 155 | test_targets.extend(labels.cpu().detach().numpy().tolist()) 156 | 157 | return total_loss, test_outputs, test_targets 158 | 159 | def predict(self, tokenizer, text, id2label, args, model): 160 | model.eval() 161 | with torch.no_grad(): 162 | inputs = tokenizer.encode_plus(text=text, 163 | add_special_tokens=True, 164 | max_length=args.max_seq_len, 165 | truncation='longest_first', 166 | padding="max_length", 167 | return_token_type_ids=True, 168 | return_attention_mask=True, 169 | return_tensors='pt') 170 | token_ids = inputs['input_ids'].cuda() 171 | attention_masks = inputs['attention_mask'].cuda() 172 | token_type_ids = inputs['token_type_ids'].cuda() 173 | outputs = model(token_ids, attention_masks, token_type_ids) 174 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten().tolist() 175 | if len(outputs) != 0: 176 | outputs = [id2label[i] for i in outputs] 177 | return outputs 178 | else: 179 | return '不好意思,我没有识别出来' 180 | 181 | def get_metrics(self, outputs, targets): 182 | accuracy = accuracy_score(targets, outputs) 183 | micro_f1 = f1_score(targets, outputs, average='micro') 184 | macro_f1 = f1_score(targets, outputs, average='macro') 185 | return accuracy, micro_f1, macro_f1 186 | 187 | def get_classification_report(self, outputs, targets, labels): 188 | report = classification_report(targets, outputs, target_names=labels) 189 | return report 190 | 191 | 192 | datasets = { 193 | "cnews": CNEWSDataset, 194 | "cpws": CPWSDataset 195 | } 196 | 197 | train_files = { 198 | "cnews": "cnews.train.txt", 199 | "cpws": "train_data.txt" 200 | } 201 | 202 | test_files = { 203 | "cnews": "cnews.test.txt", 204 | "cpws": "test_data.txt" 205 | } 206 | 207 | 208 | def main(args, tokenizer): 209 | dataset = datasets.get(args.data_name, None) 210 | train_file, test_file = train_files.get(args.data_name, None), test_files.get(args.data_name, None) 211 | if dataset is None: 212 | raise Exception("请输入正确的数据集名称") 213 | label2id = {} 214 | id2label = {} 215 | with open('./data/{}/labels.txt'.format(args.data_name), 'r', encoding="utf-8") as fp: 216 | labels = fp.read().strip().split('\n') 217 | for i, label in enumerate(labels): 218 | label2id[label] = i 219 | id2label[i] = label 220 | print(label2id) 221 | 222 | collate = Collate(tokenizer=tokenizer, max_len=args.max_seq_len, tag2id=label2id) 223 | 224 | train_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, train_file)) 225 | train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, 226 | collate_fn=collate.collate_fn, num_workers=4) 227 | test_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, test_file)) 228 | test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, 229 | collate_fn=collate.collate_fn, num_workers=4) 230 | 231 | model = models.BertForSequenceClassification(args) 232 | optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr) 233 | 234 | trainer = Trainer(args, optimizer) 235 | 236 | if args.do_train: 237 | if args.retrain: 238 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name) 239 | checkpoint = torch.load(checkpoint_path) 240 | model.cuda() 241 | r_model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0]) 242 | r_model.load_state_dict(checkpoint['state_dict']) 243 | optimizer.load_state_dict(checkpoint['optimizer']) 244 | 245 | epoch = checkpoint['epoch'] 246 | loss = checkpoint['loss'] 247 | logger.info("加载模型继续训练,epoch:{} loss:{}".format(epoch, loss)) 248 | # 训练和验证 249 | trainer.train(r_model, train_loader, dev_loader=test_loader) 250 | else: 251 | model.cuda() 252 | r_model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0]) 253 | # 训练和验证 254 | trainer.train(r_model, train_loader, dev_loader=test_loader) 255 | 256 | # 测试 257 | if args.do_test: 258 | logger.info('========进行测试========') 259 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name) 260 | # 多卡预测要先模型并行 261 | model.cuda() 262 | model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0]) 263 | # 再加载模型 264 | model = trainer.load_model(model, checkpoint_path) 265 | 266 | total_loss, test_outputs, test_targets = trainer.test(model, test_loader) 267 | accuracy, micro_f1, macro_f1 = trainer.get_metrics(test_outputs, test_targets) 268 | logger.info( 269 | "【test】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(total_loss, accuracy, micro_f1, 270 | macro_f1)) 271 | report = trainer.get_classification_report(test_outputs, test_targets, labels) 272 | logger.info(report) 273 | 274 | # 预测 275 | if args.do_predict: 276 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name) 277 | model = trainer.load_model(model, checkpoint_path) 278 | model.cuda() 279 | model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0]) 280 | line = test_dataset[0] 281 | text = line[0] 282 | print(text) 283 | result = trainer.predict(tokenizer, text, id2label, args, model) 284 | print("预测标签:", result[0]) 285 | print("真实标签:", line[1]) 286 | print("==========================") 287 | 288 | 289 | if __name__ == '__main__': 290 | # processor = preprocess.Processor() 291 | tokenizer = BertTokenizer.from_pretrained(args.bert_dir) 292 | main(args, tokenizer) 293 | -------------------------------------------------------------------------------- /main_distributed.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append(r"./") 4 | """ 5 | 该文件使用的data_loader.py里面的数据加载方式。 6 | """ 7 | # coding=utf-8 8 | import os 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,3' 11 | import tempfile 12 | import json 13 | import random 14 | from pprint import pprint 15 | import os 16 | import logging 17 | import shutil 18 | from sklearn.metrics import accuracy_score, f1_score, classification_report 19 | import torch 20 | import torch.nn as nn 21 | import numpy as np 22 | import pickle 23 | from torch.utils.data import DataLoader, RandomSampler 24 | from transformers import BertTokenizer 25 | import torch.distributed as dist 26 | 27 | import bert_config 28 | import models 29 | from utils import utils 30 | from data_loader import CNEWSDataset, Collate, CPWSDataset 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | # 单机多卡 36 | # word_size:机器一共有几张卡 37 | # rank:第几块GPU 38 | # local_rank:第几块GPU,和rank相同 39 | # print(torch.cuda.device_count()) 40 | 41 | # local_rank = torch.distributed.get_rank() 42 | # args.local_rank = local_rank 43 | # print(args.local_rank) 44 | # dist.init_process_group(backend='gloo', 45 | # init_method=r"file:///D://Code//project//pytorch-distributed training//tmp", 46 | # rank=0, 47 | # world_size=1) 48 | # torch.cuda.set_device(args.local_rank) 49 | 50 | 51 | class Trainer: 52 | def __init__(self, args, optimizer): 53 | self.args = args 54 | self.optimizer = optimizer 55 | self.criterion = nn.CrossEntropyLoss().cuda(self.args.local_rank) 56 | 57 | def load_ckp(self, model, optimizer, checkpoint_path): 58 | checkpoint = torch.load(checkpoint_path) 59 | model.load_state_dict(checkpoint['state_dict']) 60 | optimizer.load_state_dict(checkpoint['optimizer']) 61 | epoch = checkpoint['epoch'] 62 | loss = checkpoint['loss'] 63 | return model, optimizer, epoch, loss 64 | 65 | def load_model(self, model, checkpoint_path): 66 | checkpoint = torch.load(checkpoint_path) 67 | # new_start_dict = {} 68 | # for k, v in checkpoint['state_dict'].items(): 69 | # new_start_dict["module." + k] = v 70 | # model.load_state_dict(new_start_dict) 71 | model.load_state_dict(checkpoint["state_dict"]) 72 | return model 73 | 74 | def save_ckp(self, state, checkpoint_path): 75 | torch.save(state, checkpoint_path) 76 | 77 | """ 78 | def save_ckp(self, state, is_best, checkpoint_path, best_model_path): 79 | tmp_checkpoint_path = checkpoint_path 80 | torch.save(state, tmp_checkpoint_path) 81 | if is_best: 82 | tmp_best_model_path = best_model_path 83 | shutil.copyfile(tmp_checkpoint_path, tmp_best_model_path) 84 | """ 85 | 86 | def train(self, model, train_loader, train_sampler, dev_loader=None): 87 | self.dev_loader = dev_loader 88 | self.model = model 89 | total_step = len(train_loader) * self.args.train_epochs 90 | global_step = 0 91 | eval_step = 10 92 | best_dev_micro_f1 = 0.0 93 | for epoch in range(self.args.train_epochs): 94 | train_sampler.set_epoch(epoch) 95 | for train_step, train_data in enumerate(train_loader): 96 | self.model.train() 97 | token_ids = train_data['token_ids'].cuda(self.args.local_rank) 98 | attention_masks = train_data['attention_masks'].cuda(self.args.local_rank) 99 | token_type_ids = train_data['token_type_ids'].cuda(self.args.local_rank) 100 | labels = train_data['labels'].cuda(self.args.local_rank) 101 | train_outputs = self.model(token_ids, attention_masks, token_type_ids) 102 | 103 | loss = self.criterion(train_outputs, labels) 104 | 105 | torch.distributed.barrier() 106 | 107 | self.optimizer.zero_grad() 108 | loss.backward() 109 | self.optimizer.step() 110 | 111 | loss = self.loss_reduce(loss) 112 | if args.local_rank == 0: 113 | logger.info( 114 | "【train】 epoch:{} step:{}/{} loss:{:.6f}".format(epoch, global_step, total_step, loss)) 115 | global_step += 1 116 | if dev_loader is not None and self.args.local_rank == 0: 117 | if global_step % eval_step == 0: 118 | dev_loss, dev_outputs, dev_targets = self.dev() 119 | accuracy, micro_f1, macro_f1 = self.get_metrics(dev_outputs, dev_targets) 120 | logger.info( 121 | "【dev】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(dev_loss, 122 | accuracy, 123 | micro_f1, 124 | macro_f1)) 125 | if macro_f1 > best_dev_micro_f1: 126 | logger.info("------------>保存当前最好的模型") 127 | checkpoint = { 128 | 'epoch': epoch, 129 | 'loss': dev_loss, 130 | 'state_dict': self.model.state_dict(), 131 | 'optimizer': self.optimizer.state_dict(), 132 | } 133 | best_dev_micro_f1 = macro_f1 134 | save_path = os.path.join(self.args.output_dir, args.data_name) 135 | if not os.path.exists(save_path): 136 | os.makedirs(save_path) 137 | checkpoint_path = os.path.join(save_path, 'best.pt') 138 | self.save_ckp(checkpoint, checkpoint_path) 139 | if dev_loader is None and self.args.local_rank == 0: 140 | checkpoint = { 141 | 'epoch': epoch, 142 | 'state_dict': self.model.state_dict(), 143 | 'optimizer': self.optimizer.state_dict(), 144 | } 145 | save_path = os.path.join(self.args.output_dir, args.data_name) 146 | if not os.path.exists(save_path): 147 | os.makedirs(save_path) 148 | checkpoint_path = os.path.join(save_path, 'best.pt') 149 | self.save_ckp(checkpoint, checkpoint_path) 150 | 151 | def output_reduce(self, outputs, targets): 152 | output_gather_list = [torch.zeros_like(outputs) for _ in range(self.args.local_world_size)] 153 | # 把每一个GPU的输出聚合起来 154 | dist.all_gather(output_gather_list, outputs) 155 | 156 | outputs = torch.cat(output_gather_list, dim=0) 157 | target_gather_list = [torch.zeros_like(targets) for _ in range(self.args.local_world_size)] 158 | # 把每一个GPU的输出聚合起来 159 | dist.all_gather(target_gather_list, targets) 160 | targets = torch.cat(target_gather_list, dim=0) 161 | return outputs, targets 162 | 163 | def loss_reduce(self, loss): 164 | rt = loss.clone() 165 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 166 | rt /= self.args.local_world_size 167 | return rt 168 | 169 | def dev(self): 170 | self.model.eval() 171 | total_loss = 0.0 172 | dev_outputs = [] 173 | dev_targets = [] 174 | with torch.no_grad(): 175 | for dev_step, dev_data in enumerate(self.dev_loader): 176 | token_ids = dev_data['token_ids'].cuda(self.args.local_rank) 177 | attention_masks = dev_data['attention_masks'].cuda(self.args.local_rank) 178 | token_type_ids = dev_data['token_type_ids'].cuda(self.args.local_rank) 179 | labels = dev_data['labels'].cuda(self.args.local_rank) 180 | outputs = self.model(token_ids, attention_masks, token_type_ids) 181 | 182 | loss = self.criterion(outputs, labels) 183 | torch.distributed.barrier() 184 | # total_loss += loss.item() 185 | loss = self.loss_reduce(loss) 186 | total_loss += loss 187 | outputs, targets = self.output_reduce(outputs, labels) 188 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten() 189 | dev_outputs.extend(outputs.tolist()) 190 | dev_targets.extend(targets.cpu().detach().numpy().tolist()) 191 | print(len(dev_outputs), len(dev_targets)) 192 | return total_loss, dev_outputs, dev_targets 193 | 194 | def test(self, model, test_loader): 195 | model.eval() 196 | total_loss = 0.0 197 | test_outputs = [] 198 | test_targets = [] 199 | with torch.no_grad(): 200 | for test_step, test_data in enumerate(test_loader): 201 | token_ids = test_data['token_ids'].cuda(self.args.local_rank) 202 | attention_masks = test_data['attention_masks'].cuda(self.args.local_rank) 203 | token_type_ids = test_data['token_type_ids'].cuda(self.args.local_rank) 204 | labels = test_data['labels'].cuda(self.args.local_rank) 205 | outputs = model(token_ids, attention_masks, token_type_ids) 206 | 207 | loss = self.criterion(outputs, labels) 208 | torch.distributed.barrier() 209 | loss = self.loss_reduce(loss) 210 | total_loss += loss 211 | outputs, targets = self.output_reduce(outputs, labels) 212 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten() 213 | test_outputs.extend(outputs.tolist()) 214 | test_targets.extend(targets.cpu().detach().numpy().tolist()) 215 | # total_loss += loss.item() 216 | # outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten() 217 | # test_outputs.extend(outputs.tolist()) 218 | # test_targets.extend(labels.cpu().detach().numpy().tolist()) 219 | 220 | return total_loss, test_outputs, test_targets 221 | 222 | def predict(self, tokenizer, text, id2label, args, model): 223 | model.eval() 224 | with torch.no_grad(): 225 | inputs = tokenizer.encode_plus(text=text, 226 | add_special_tokens=True, 227 | max_length=args.max_seq_len, 228 | truncation='longest_first', 229 | padding="max_length", 230 | return_token_type_ids=True, 231 | return_attention_mask=True, 232 | return_tensors='pt') 233 | token_ids = inputs['input_ids'].cuda(self.args.local_rank) 234 | attention_masks = inputs['attention_mask'].cuda(self.args.local_rank) 235 | token_type_ids = inputs['token_type_ids'].cuda(self.args.local_rank) 236 | outputs = model(token_ids, attention_masks, token_type_ids) 237 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten().tolist() 238 | if len(outputs) != 0: 239 | outputs = [id2label[i] for i in outputs] 240 | return outputs 241 | else: 242 | return '不好意思,我没有识别出来' 243 | 244 | def get_metrics(self, outputs, targets): 245 | accuracy = accuracy_score(targets, outputs) 246 | micro_f1 = f1_score(targets, outputs, average='micro') 247 | macro_f1 = f1_score(targets, outputs, average='macro') 248 | return accuracy, micro_f1, macro_f1 249 | 250 | def get_classification_report(self, outputs, targets, labels): 251 | report = classification_report(targets, outputs, target_names=labels) 252 | return report 253 | 254 | 255 | datasets = { 256 | "cnews": CNEWSDataset, 257 | "cpws": CPWSDataset 258 | } 259 | 260 | train_files = { 261 | "cnews": "cnews.train.txt", 262 | "cpws": "train_data.txt" 263 | } 264 | 265 | test_files = { 266 | "cnews": "cnews.test.txt", 267 | "cpws": "test_data.txt" 268 | } 269 | 270 | 271 | def main(args, tokenizer, local_rank, local_world_size): 272 | n = torch.cuda.device_count() // local_world_size 273 | device_ids = list(range(local_rank * n, (local_rank + 1) * n)) 274 | 275 | print( 276 | f"[{os.getpid()}] rank = {dist.get_rank()}, " 277 | + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids} \n", end='' 278 | ) 279 | dataset = datasets.get(args.data_name, None) 280 | train_file, test_file = train_files.get(args.data_name, None), test_files.get(args.data_name, None) 281 | if dataset is None: 282 | raise Exception("请输入正确的数据集名称") 283 | label2id = {} 284 | id2label = {} 285 | with open('./data/{}/labels.txt'.format(args.data_name), 'r', encoding="utf-8") as fp: 286 | labels = fp.read().strip().split('\n') 287 | for i, label in enumerate(labels): 288 | label2id[label] = i 289 | id2label[i] = label 290 | if args.local_rank == 0: 291 | print(label2id) 292 | 293 | args.train_batch_size = int(args.train_batch_size / torch.cuda.device_count()) 294 | args.eval_batch_size = int(args.eval_batch_size / torch.cuda.device_count()) 295 | 296 | collate = Collate(tokenizer=tokenizer, max_len=args.max_seq_len, tag2id=label2id) 297 | 298 | train_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, train_file)) 299 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 300 | train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, 301 | collate_fn=collate.collate_fn, num_workers=4, sampler=train_sampler) 302 | test_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, test_file)) 303 | 304 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) 305 | 306 | test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, 307 | collate_fn=collate.collate_fn, num_workers=4, sampler=test_sampler) 308 | # test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, 309 | # collate_fn=collate.collate_fn, num_workers=4) 310 | 311 | model = models.BertForSequenceClassification(args) 312 | optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr) 313 | trainer = Trainer(args, optimizer) 314 | 315 | if args.do_train: 316 | if args.retrain: 317 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name) 318 | checkpoint = torch.load(checkpoint_path) 319 | # trainer.optimizer.load_state_dict(checkpoint['optimizer']) 320 | model.cuda(args.local_rank) 321 | r_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids) 322 | r_model.load_state_dict(checkpoint['state_dict']) 323 | if args.local_rank == 0: 324 | logger.info("加载模型继续训练") 325 | # 训练和验证 326 | trainer.train(r_model, train_loader, train_sampler, dev_loader=test_loader) 327 | else: 328 | 329 | model.cuda(args.local_rank) 330 | r_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids) 331 | # 训练和验证 332 | trainer.train(r_model, train_loader, train_sampler, dev_loader=test_loader) 333 | 334 | # 测试 335 | if args.do_test: 336 | if args.local_rank == 0: 337 | logger.info('========进行测试========') 338 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name) 339 | # 多卡预测要先模型并行 340 | model.cuda(args.local_rank) 341 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids) 342 | # 再加载模型 343 | model = trainer.load_model(model, checkpoint_path) 344 | 345 | total_loss, test_outputs, test_targets = trainer.test(model, test_loader) 346 | accuracy, micro_f1, macro_f1 = trainer.get_metrics(test_outputs, test_targets) 347 | if args.local_rank == 0: 348 | logger.info( 349 | "【test】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(total_loss, accuracy, 350 | micro_f1, 351 | macro_f1)) 352 | report = trainer.get_classification_report(test_outputs, test_targets, labels) 353 | logger.info(report) 354 | 355 | # 预测 356 | if args.do_predict: 357 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name) 358 | model.cuda(args.local_rank) 359 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids) 360 | model = trainer.load_model(model, checkpoint_path) 361 | line = test_dataset[0] 362 | text = line[0] 363 | result = trainer.predict(tokenizer, text, id2label, args, model) 364 | if args.local_rank == 0: 365 | print(text) 366 | print("预测标签:", result[0]) 367 | print("真实标签:", line[1]) 368 | print("==========================") 369 | 370 | 371 | def spmd_main(local_world_size, local_rank, init_method): 372 | env_dict = { 373 | key: os.environ[key] 374 | for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE") 375 | } 376 | if local_rank == 0: 377 | for k, v in env_dict.items(): 378 | print(k, v) 379 | if sys.platform == "win32": 380 | # Distributed package only covers collective communications with Gloo 381 | # backend and FileStore on Windows platform. Set init_method parameter 382 | # in init_process_group to a local file. 383 | if "INIT_METHOD" in os.environ.keys(): 384 | print(f"init_method is {os.environ['INIT_METHOD']}") 385 | url_obj = urlparse(os.environ["INIT_METHOD"]) 386 | if url_obj.scheme.lower() != "file": 387 | raise ValueError("Windows only supports FileStore") 388 | else: 389 | init_method = os.environ["INIT_METHOD"] 390 | else: 391 | # It is a example application, For convience, we create a file in temp dir. 392 | # current_work_dir = os.getcwd() 393 | # init_method = f"file:///{os.path.join(current_work_dir, 'ddp_example')}" 394 | init_method = init_method 395 | print(init_method) 396 | dist.init_process_group(backend="gloo", init_method=init_method, rank=int(env_dict["RANK"]), 397 | world_size=int(env_dict["WORLD_SIZE"])) 398 | else: 399 | print(f"[{os.getpid()}] Initializing process group with: {env_dict}") 400 | dist.init_process_group(backend="nccl") 401 | 402 | tokenizer = BertTokenizer.from_pretrained(args.bert_dir) 403 | main(args, tokenizer, local_rank, local_world_size) 404 | 405 | dist.destroy_process_group() 406 | 407 | 408 | if __name__ == '__main__': 409 | # processor = preprocess.Processor() 410 | args = bert_config.Args().get_parser() 411 | utils.set_seed(args.seed) 412 | utils.set_logger(os.path.join(args.log_dir, 'main.log')) 413 | # The main entry point is called directly without using subprocess 414 | current_work_dir = os.getcwd() 415 | init_method = f"file:///{os.path.join(current_work_dir, 'ddp_example')}" 416 | spmd_main(args.local_world_size, args.local_rank, init_method) 417 | -------------------------------------------------------------------------------- /main_mp_distributed.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append(r"./") 4 | """ 5 | 该文件使用的data_loader.py里面的数据加载方式。 6 | """ 7 | # coding=utf-8 8 | import os 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,3' 11 | import tempfile 12 | import json 13 | import random 14 | from pprint import pprint 15 | import os 16 | import logging 17 | import shutil 18 | from sklearn.metrics import accuracy_score, f1_score, classification_report 19 | import torch 20 | import torch.nn as nn 21 | import numpy as np 22 | import pickle 23 | from torch.utils.data import DataLoader, RandomSampler 24 | from transformers import BertTokenizer 25 | import torch.distributed as dist 26 | import torch.multiprocessing as mp 27 | 28 | import bert_config 29 | import models 30 | from utils import utils 31 | from data_loader import CNEWSDataset, Collate, CPWSDataset 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | # 单机多卡 37 | # word_size:机器一共有几张卡 38 | # rank:第几块GPU 39 | # local_rank:第几块GPU,和rank相同 40 | # print(torch.cuda.device_count()) 41 | 42 | # local_rank = torch.distributed.get_rank() 43 | # args.local_rank = local_rank 44 | # print(args.local_rank) 45 | # dist.init_process_group(backend='gloo', 46 | # init_method=r"file:///D://Code//project//pytorch-distributed training//tmp", 47 | # rank=0, 48 | # world_size=1) 49 | # torch.cuda.set_device(args.local_rank) 50 | 51 | 52 | class Trainer: 53 | def __init__(self, args, optimizer): 54 | self.args = args 55 | self.optimizer = optimizer 56 | self.criterion = nn.CrossEntropyLoss().cuda(self.args.local_rank) 57 | 58 | def load_ckp(self, model, optimizer, checkpoint_path): 59 | checkpoint = torch.load(checkpoint_path) 60 | model.load_state_dict(checkpoint['state_dict']) 61 | optimizer.load_state_dict(checkpoint['optimizer']) 62 | epoch = checkpoint['epoch'] 63 | loss = checkpoint['loss'] 64 | return model, optimizer, epoch, loss 65 | 66 | def load_model(self, model, checkpoint_path): 67 | checkpoint = torch.load(checkpoint_path) 68 | # new_start_dict = {} 69 | # for k, v in checkpoint['state_dict'].items(): 70 | # new_start_dict["module." + k] = v 71 | # model.load_state_dict(new_start_dict) 72 | model.load_state_dict(checkpoint["state_dict"]) 73 | return model 74 | 75 | def save_ckp(self, state, checkpoint_path): 76 | torch.save(state, checkpoint_path) 77 | 78 | """ 79 | def save_ckp(self, state, is_best, checkpoint_path, best_model_path): 80 | tmp_checkpoint_path = checkpoint_path 81 | torch.save(state, tmp_checkpoint_path) 82 | if is_best: 83 | tmp_best_model_path = best_model_path 84 | shutil.copyfile(tmp_checkpoint_path, tmp_best_model_path) 85 | """ 86 | 87 | def train(self, model, train_loader, train_sampler, dev_loader=None): 88 | self.dev_loader = dev_loader 89 | self.model = model 90 | total_step = len(train_loader) * self.args.train_epochs 91 | global_step = 0 92 | eval_step = 10 93 | best_dev_micro_f1 = 0.0 94 | for epoch in range(self.args.train_epochs): 95 | train_sampler.set_epoch(epoch) 96 | for train_step, train_data in enumerate(train_loader): 97 | self.model.train() 98 | token_ids = train_data['token_ids'].cuda(self.args.local_rank) 99 | attention_masks = train_data['attention_masks'].cuda(self.args.local_rank) 100 | token_type_ids = train_data['token_type_ids'].cuda(self.args.local_rank) 101 | labels = train_data['labels'].cuda(self.args.local_rank) 102 | train_outputs = self.model(token_ids, attention_masks, token_type_ids) 103 | 104 | loss = self.criterion(train_outputs, labels) 105 | 106 | torch.distributed.barrier() 107 | 108 | self.optimizer.zero_grad() 109 | loss.backward() 110 | self.optimizer.step() 111 | 112 | loss = self.loss_reduce(loss) 113 | if self.args.local_rank == 0: 114 | print( 115 | "【train】 epoch:{} step:{}/{} loss:{:.6f}".format(epoch, global_step, total_step, loss)) 116 | global_step += 1 117 | if dev_loader is not None and self.args.local_rank == 0: 118 | if global_step % eval_step == 0: 119 | dev_loss, dev_outputs, dev_targets = self.dev() 120 | accuracy, micro_f1, macro_f1 = self.get_metrics(dev_outputs, dev_targets) 121 | print( 122 | "【dev】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(dev_loss, 123 | accuracy, 124 | micro_f1, 125 | macro_f1)) 126 | if macro_f1 > best_dev_micro_f1: 127 | print("------------>保存当前最好的模型") 128 | checkpoint = { 129 | 'epoch': epoch, 130 | 'loss': dev_loss, 131 | 'state_dict': self.model.state_dict(), 132 | 'optimizer': self.optimizer.state_dict(), 133 | } 134 | best_dev_micro_f1 = macro_f1 135 | save_path = os.path.join(self.args.output_dir, self.args.data_name) 136 | if not os.path.exists(save_path): 137 | os.makedirs(save_path) 138 | checkpoint_path = os.path.join(save_path, 'best.pt') 139 | self.save_ckp(checkpoint, checkpoint_path) 140 | if dev_loader is None and self.args.local_rank == 0: 141 | checkpoint = { 142 | 'epoch': epoch, 143 | 'state_dict': self.model.state_dict(), 144 | 'optimizer': self.optimizer.state_dict(), 145 | } 146 | save_path = os.path.join(self.args.output_dir, self.args.data_name) 147 | if not os.path.exists(save_path): 148 | os.makedirs(save_path) 149 | checkpoint_path = os.path.join(save_path, 'best.pt') 150 | self.save_ckp(checkpoint, checkpoint_path) 151 | 152 | def output_reduce(self, outputs, targets): 153 | output_gather_list = [torch.zeros_like(outputs) for _ in range(self.args.local_world_size)] 154 | # 把每一个GPU的输出聚合起来 155 | dist.all_gather(output_gather_list, outputs) 156 | 157 | outputs = torch.cat(output_gather_list, dim=0) 158 | target_gather_list = [torch.zeros_like(targets) for _ in range(self.args.local_world_size)] 159 | # 把每一个GPU的输出聚合起来 160 | dist.all_gather(target_gather_list, targets) 161 | targets = torch.cat(target_gather_list, dim=0) 162 | return outputs, targets 163 | 164 | def loss_reduce(self, loss): 165 | rt = loss.clone() 166 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 167 | rt /= self.args.local_world_size 168 | return rt 169 | 170 | def dev(self): 171 | self.model.eval() 172 | total_loss = 0.0 173 | dev_outputs = [] 174 | dev_targets = [] 175 | with torch.no_grad(): 176 | for dev_step, dev_data in enumerate(self.dev_loader): 177 | token_ids = dev_data['token_ids'].cuda(self.args.local_rank) 178 | attention_masks = dev_data['attention_masks'].cuda(self.args.local_rank) 179 | token_type_ids = dev_data['token_type_ids'].cuda(self.args.local_rank) 180 | labels = dev_data['labels'].cuda(self.args.local_rank) 181 | outputs = self.model(token_ids, attention_masks, token_type_ids) 182 | 183 | loss = self.criterion(outputs, labels) 184 | torch.distributed.barrier() 185 | # total_loss += loss.item() 186 | loss = self.loss_reduce(loss) 187 | total_loss += loss 188 | outputs, targets = self.output_reduce(outputs, labels) 189 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten() 190 | dev_outputs.extend(outputs.tolist()) 191 | dev_targets.extend(targets.cpu().detach().numpy().tolist()) 192 | return total_loss, dev_outputs, dev_targets 193 | 194 | def test(self, model, test_loader): 195 | model.eval() 196 | total_loss = 0.0 197 | test_outputs = [] 198 | test_targets = [] 199 | with torch.no_grad(): 200 | for test_step, test_data in enumerate(test_loader): 201 | token_ids = test_data['token_ids'].cuda(self.args.local_rank) 202 | attention_masks = test_data['attention_masks'].cuda(self.args.local_rank) 203 | token_type_ids = test_data['token_type_ids'].cuda(self.args.local_rank) 204 | labels = test_data['labels'].cuda(self.args.local_rank) 205 | outputs = model(token_ids, attention_masks, token_type_ids) 206 | 207 | loss = self.criterion(outputs, labels) 208 | torch.distributed.barrier() 209 | loss = self.loss_reduce(loss) 210 | total_loss += loss 211 | outputs, targets = self.output_reduce(outputs, labels) 212 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten() 213 | test_outputs.extend(outputs.tolist()) 214 | test_targets.extend(targets.cpu().detach().numpy().tolist()) 215 | # total_loss += loss.item() 216 | # outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten() 217 | # test_outputs.extend(outputs.tolist()) 218 | # test_targets.extend(labels.cpu().detach().numpy().tolist()) 219 | 220 | return total_loss, test_outputs, test_targets 221 | 222 | def predict(self, tokenizer, text, id2label, args, model): 223 | model.eval() 224 | with torch.no_grad(): 225 | inputs = tokenizer.encode_plus(text=text, 226 | add_special_tokens=True, 227 | max_length=args.max_seq_len, 228 | truncation='longest_first', 229 | padding="max_length", 230 | return_token_type_ids=True, 231 | return_attention_mask=True, 232 | return_tensors='pt') 233 | token_ids = inputs['input_ids'].cuda(self.args.local_rank) 234 | attention_masks = inputs['attention_mask'].cuda(self.args.local_rank) 235 | token_type_ids = inputs['token_type_ids'].cuda(self.args.local_rank) 236 | outputs = model(token_ids, attention_masks, token_type_ids) 237 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten().tolist() 238 | if len(outputs) != 0: 239 | outputs = [id2label[i] for i in outputs] 240 | return outputs 241 | else: 242 | return '不好意思,我没有识别出来' 243 | 244 | def get_metrics(self, outputs, targets): 245 | accuracy = accuracy_score(targets, outputs) 246 | micro_f1 = f1_score(targets, outputs, average='micro') 247 | macro_f1 = f1_score(targets, outputs, average='macro') 248 | return accuracy, micro_f1, macro_f1 249 | 250 | def get_classification_report(self, outputs, targets, labels): 251 | report = classification_report(targets, outputs, target_names=labels) 252 | return report 253 | 254 | 255 | datasets = { 256 | "cnews": CNEWSDataset, 257 | "cpws": CPWSDataset 258 | } 259 | 260 | train_files = { 261 | "cnews": "cnews.train.txt", 262 | "cpws": "train_data.txt" 263 | } 264 | 265 | test_files = { 266 | "cnews": "cnews.test.txt", 267 | "cpws": "test_data.txt" 268 | } 269 | 270 | 271 | def main_worker(local_rank, args): 272 | args.local_rank = local_rank 273 | print(type(args.local_rank)) 274 | tokenizer = args.tokenizer 275 | local_world_size = args.local_world_size 276 | # The main entry point is called directly without using subprocess 277 | current_work_dir = os.getcwd() 278 | init_method = f"file:///{os.path.join(current_work_dir, 'ddp_example')}" 279 | if sys.platform == "win32": 280 | # Distributed package only covers collective communications with Gloo 281 | # backend and FileStore on Windows platform. Set init_method parameter 282 | # in init_process_group to a local file. 283 | if "INIT_METHOD" in os.environ.keys(): 284 | print(f"init_method is {os.environ['INIT_METHOD']}") 285 | url_obj = urlparse(os.environ["INIT_METHOD"]) 286 | if url_obj.scheme.lower() != "file": 287 | raise ValueError("Windows only supports FileStore") 288 | else: 289 | init_method = os.environ["INIT_METHOD"] 290 | else: 291 | # It is a example application, For convience, we create a file in temp dir. 292 | # current_work_dir = os.getcwd() 293 | # init_method = f"file:///{os.path.join(current_work_dir, 'ddp_example')}" 294 | init_method = init_method 295 | print(init_method) 296 | dist.init_process_group(backend="gloo", init_method=init_method, rank=local_rank, 297 | world_size=args.local_world_size) 298 | dist.barrier() 299 | else: 300 | # print(f"[{os.getpid()}] Initializing process group with: {env_dict}") 301 | dist.init_process_group(backend="nccl") 302 | n = torch.cuda.device_count() // local_world_size 303 | device_ids = list(range(local_rank * n, (local_rank + 1) * n)) 304 | 305 | print( 306 | f"[{os.getpid()}] rank = {dist.get_rank()}, " 307 | + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids} \n", end='' 308 | ) 309 | # device_ids = [local_rank] 310 | # print(device_ids) 311 | 312 | dataset = datasets.get(args.data_name, None) 313 | train_file, test_file = train_files.get(args.data_name, None), test_files.get(args.data_name, None) 314 | if dataset is None: 315 | raise Exception("请输入正确的数据集名称") 316 | label2id = {} 317 | id2label = {} 318 | with open('./data/{}/labels.txt'.format(args.data_name), 'r', encoding="utf-8") as fp: 319 | labels = fp.read().strip().split('\n') 320 | for i, label in enumerate(labels): 321 | label2id[label] = i 322 | id2label[i] = label 323 | if args.local_rank == 0: 324 | print(label2id) 325 | if args.local_rank == 0: 326 | print('========加载数据集========') 327 | args.train_batch_size = int(args.train_batch_size / torch.cuda.device_count()) 328 | args.eval_batch_size = int(args.eval_batch_size / torch.cuda.device_count()) 329 | 330 | collate = Collate(tokenizer=tokenizer, max_len=args.max_seq_len, tag2id=label2id) 331 | 332 | train_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, train_file)) 333 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 334 | train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, 335 | collate_fn=collate.collate_fn, num_workers=4, sampler=train_sampler) 336 | test_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, test_file)) 337 | 338 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) 339 | 340 | test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, 341 | collate_fn=collate.collate_fn, num_workers=4, sampler=test_sampler) 342 | # test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, 343 | # collate_fn=collate.collate_fn, num_workers=4) 344 | if args.local_rank == 0: 345 | print('========定义模型和优化器========') 346 | model = models.BertForSequenceClassification(args) 347 | optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr) 348 | trainer = Trainer(args, optimizer) 349 | 350 | if args.do_train: 351 | if args.retrain: 352 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name) 353 | checkpoint = torch.load(checkpoint_path) 354 | # trainer.optimizer.load_state_dict(checkpoint['optimizer']) 355 | model.cuda(args.local_rank) 356 | r_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids) 357 | r_model.load_state_dict(checkpoint['state_dict']) 358 | if args.local_rank == 0: 359 | print("========加载模型继续训练========") 360 | # 训练和验证 361 | trainer.train(r_model, train_loader, train_sampler, dev_loader=None) 362 | else: 363 | if args.local_rank == 0: 364 | print('========进行训练========') 365 | model.cuda(args.local_rank) 366 | r_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids) 367 | # 训练和验证 368 | trainer.train(r_model, train_loader, train_sampler, dev_loader=None) 369 | 370 | # 测试 371 | if args.do_test: 372 | if args.local_rank == 0: 373 | print('========进行测试========') 374 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name) 375 | # 多卡预测要先模型并行 376 | model.cuda(args.local_rank) 377 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids) 378 | # 再加载模型 379 | model = trainer.load_model(model, checkpoint_path) 380 | 381 | total_loss, test_outputs, test_targets = trainer.test(model, test_loader) 382 | accuracy, micro_f1, macro_f1 = trainer.get_metrics(test_outputs, test_targets) 383 | if args.local_rank == 0: 384 | print( 385 | "【test】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(total_loss, accuracy, 386 | micro_f1, 387 | macro_f1)) 388 | report = trainer.get_classification_report(test_outputs, test_targets, labels) 389 | print(report) 390 | 391 | # 预测 392 | if args.do_predict: 393 | if args.local_rank == 0: 394 | print('========进行预测========') 395 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name) 396 | model.cuda(args.local_rank) 397 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids) 398 | model = trainer.load_model(model, checkpoint_path) 399 | line = test_dataset[0] 400 | text = line[0] 401 | result = trainer.predict(tokenizer, text, id2label, args, model) 402 | if args.local_rank == 0: 403 | print(text) 404 | print("预测标签:", result[0]) 405 | print("真实标签:", line[1]) 406 | print("==========================") 407 | 408 | dist.destroy_process_group() 409 | 410 | 411 | if __name__ == '__main__': 412 | args = bert_config.Args().get_parser() 413 | utils.set_seed(args.seed) 414 | utils.set_logger(os.path.join(args.log_dir, 'main.log')) 415 | 416 | tokenizer = BertTokenizer.from_pretrained(args.bert_dir) 417 | args.tokenizer = tokenizer 418 | args.nprocs = args.local_world_size 419 | mp.spawn(main_worker, nprocs=args.nprocs, args=(args,)) 420 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from transformers import BertModel 2 | import torch.nn as nn 3 | 4 | 5 | class BertForSequenceClassification(nn.Module): 6 | def __init__(self, args): 7 | super(BertForSequenceClassification, self).__init__() 8 | self.bert = BertModel.from_pretrained(args.bert_dir) 9 | self.bert_config = self.bert.config 10 | out_dims = self.bert_config.hidden_size 11 | self.dropout = nn.Dropout(0.3) 12 | self.linear = nn.Linear(out_dims, args.num_tags) 13 | 14 | def forward(self, token_ids, attention_masks, token_type_ids): 15 | bert_outputs = self.bert( 16 | input_ids = token_ids, 17 | attention_mask = attention_masks, 18 | token_type_ids = token_type_ids, 19 | ) 20 | seq_out = bert_outputs[1] 21 | seq_out = self.dropout(seq_out) 22 | seq_out = self.linear(seq_out) 23 | return seq_out -------------------------------------------------------------------------------- /nvidia.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | :: 執行的指令 ( 請注意若要使用 pipe 符號必須在之前加上一個 ^ 符號 ) 4 | SET ExecuteCommand=nvidia-smi 5 | 6 | :: 單位: 秒 7 | SET ExecutePeriod=1 8 | 9 | 10 | SETLOCAL EnableDelayedExpansion 11 | 12 | :loop 13 | 14 | cls 15 | 16 | echo !date! !time! 17 | echo 每 !ExecutePeriod! 秒執行一次,指令^: !ExecuteCommand! 18 | 19 | echo. 20 | 21 | %ExecuteCommand% 22 | 23 | timeout /t %ExecutePeriod% > nul 24 | 25 | goto loop -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tokenizers==0.10.2 2 | torch==1.6.0 3 | transformers==4.5.1 4 | pytorch-lightning==0.9.0 5 | tensorboard==2.2.0 6 | -------------------------------------------------------------------------------- /test_ddp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os 4 | import torch 5 | import torchvision 6 | from torch import distributed as dist 7 | from torchvision.models import resnet18 8 | from torch.utils.data import DataLoader 9 | from torchvision.datasets import MNIST 10 | from torchvision.transforms import ToTensor 11 | from torch.nn.parallel import DistributedDataParallel as DDP 12 | from torch.utils.data.distributed import DistributedSampler 13 | import numpy as np 14 | 15 | 16 | def reduce_loss(tensor, rank, world_size): 17 | with torch.no_grad(): 18 | dist.reduce(tensor, dst=0) 19 | if rank == 0: 20 | tensor /= world_size 21 | 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--local_rank', type=int, help="local gpu id") 25 | parser.add_argument('--world_size', type=int, help="total nodes") 26 | args = parser.parse_args() 27 | 28 | # world_size = os.environ["world_size"] 29 | 30 | batch_size = 128 31 | epochs = 5 32 | lr = 0.001 33 | n = torch.cuda.device_count() // args.world_size 34 | device_ids = list(range(args.local_rank * n, (args.local_rank + 1) * n)) 35 | print("初始化1") 36 | init_method_path = "D:\\Code\\project\\pytorch-distributed training\\pytorch_bert_chinese_text_classification\\tmp\\" 37 | if os.path.exists(init_method_path): 38 | os.remove(init_method_path) 39 | print("已删除init_method_path") 40 | dist.init_process_group(backend='gloo', 41 | init_method='file:///{}'.format(init_method_path), 42 | rank=int(args.local_rank), world_size=int(args.world_size)) 43 | torch.cuda.set_device(args.local_rank) 44 | global_rank = dist.get_rank() 45 | print("初始化2") 46 | 47 | print( 48 | f"[{os.getpid()}] rank = {dist.get_rank()}, " 49 | + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids} \n", end='' 50 | ) 51 | 52 | from torchvision.models.resnet import ResNet, BasicBlock 53 | 54 | 55 | class MnistResNet(ResNet): 56 | def __init__(self): 57 | super(MnistResNet, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10) 58 | self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 59 | 60 | def forward(self, x): 61 | return torch.softmax(super(MnistResNet, self).forward(x), dim=-1) 62 | 63 | 64 | # net = resnet18() 65 | net = MnistResNet() 66 | net.cuda() 67 | net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) 68 | net = DDP(net, device_ids=[args.local_rank], output_device=args.local_rank) 69 | 70 | 71 | class ToNumpy(object): 72 | def __call__(self, sample): 73 | return np.array(sample) 74 | 75 | 76 | data_root = 'dataset' 77 | trainset = MNIST(root=data_root, 78 | download=True, 79 | train=True, 80 | transform=torchvision.transforms.Compose( 81 | [ToNumpy(), torchvision.transforms.ToTensor()]) 82 | ) 83 | 84 | valset = MNIST(root=data_root, 85 | download=True, 86 | train=False, 87 | transform=torchvision.transforms.Compose( 88 | [ToNumpy(), torchvision.transforms.ToTensor()]) 89 | ) 90 | 91 | sampler = DistributedSampler(trainset) 92 | train_loader = DataLoader(trainset, 93 | batch_size=batch_size, 94 | shuffle=False, 95 | pin_memory=True, 96 | sampler=sampler, 97 | ) 98 | 99 | val_loader = DataLoader(valset, 100 | batch_size=batch_size, 101 | shuffle=False, 102 | pin_memory=True, 103 | ) 104 | 105 | criterion = torch.nn.CrossEntropyLoss() 106 | opt = torch.optim.Adam(net.parameters(), lr=lr) 107 | 108 | net.train() 109 | for e in range(epochs): 110 | # DistributedSampler deterministically shuffle data 111 | # by seting random seed be current number epoch 112 | # so if do not call set_epoch when start of one epoch 113 | # the order of shuffled data will be always same 114 | sampler.set_epoch(e) 115 | for idx, (imgs, labels) in enumerate(train_loader): 116 | imgs = imgs.cuda() 117 | labels = labels.cuda() 118 | output = net(imgs) 119 | loss = criterion(output, labels) 120 | opt.zero_grad() 121 | loss.backward() 122 | opt.step() 123 | reduce_loss(loss, global_rank, args.world_size) 124 | if idx % 10 == 0 and global_rank == 0: 125 | print('Epoch: {} step: {} loss: {}'.format(e, idx, loss.item())) 126 | net.eval() 127 | with torch.no_grad(): 128 | cnt = 0 129 | total = len(val_loader.dataset) 130 | for imgs, labels in val_loader: 131 | imgs, labels = imgs.cuda(), labels.cuda() 132 | output = net(imgs) 133 | predict = torch.argmax(output, dim=1) 134 | cnt += (predict == labels).sum().item() 135 | 136 | if global_rank == 0: 137 | print('eval accuracy: {}'.format(cnt / total)) 138 | 139 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_bert_chinese_text_classification/8e3166b98c784f972d17df17ddeb56e3e494184b/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_bert_chinese_text_classification/8e3166b98c784f972d17df17ddeb56e3e494184b/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_bert_chinese_text_classification/8e3166b98c784f972d17df17ddeb56e3e494184b/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import random 3 | import logging 4 | import time 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def timer(func): 10 | """ 11 | 函数计时器 12 | :param func: 13 | :return: 14 | """ 15 | 16 | @functools.wraps(func) 17 | def wrapper(*args, **kwargs): 18 | start = time.time() 19 | res = func(*args, **kwargs) 20 | end = time.time() 21 | print("{}共耗时约{:.4f}秒".format(func.__name__, end - start)) 22 | return res 23 | 24 | return wrapper 25 | 26 | 27 | def set_seed(seed=123): 28 | """ 29 | 设置随机数种子,保证实验可重现 30 | :param seed: 31 | :return: 32 | """ 33 | random.seed(seed) 34 | torch.manual_seed(seed) 35 | np.random.seed(seed) 36 | torch.cuda.manual_seed_all(seed) 37 | 38 | 39 | def set_logger(log_path): 40 | """ 41 | 配置log 42 | :param log_path:s 43 | :return: 44 | """ 45 | logger = logging.getLogger() 46 | logger.setLevel(logging.INFO) 47 | 48 | # 由于每调用一次set_logger函数,就会创建一个handler,会造成重复打印的问题,因此需要判断root logger中是否已有该handler 49 | if not any(handler.__class__ == logging.FileHandler for handler in logger.handlers): 50 | file_handler = logging.FileHandler(log_path) 51 | formatter = logging.Formatter( 52 | '%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s') 53 | file_handler.setFormatter(formatter) 54 | logger.addHandler(file_handler) 55 | 56 | if not any(handler.__class__ == logging.StreamHandler for handler in logger.handlers): 57 | stream_handler = logging.StreamHandler() 58 | stream_handler.setFormatter(logging.Formatter('%(message)s')) 59 | logger.addHandler(stream_handler) 60 | 61 | --------------------------------------------------------------------------------