├── README.md
├── __init__.py
├── bert_config.py
├── checkpoints
├── cnews
│ └── 占位.txt
└── cpws
│ └── 占位.txt
├── data
├── cnews
│ ├── labels.txt
│ └── process.py
└── cpws
│ ├── labels.txt
│ └── process.py
├── data_loader.py
├── logs
├── main.log
└── preprocess.log
├── main.py
├── main_dataparallel.py
├── main_distributed.py
├── main_mp_distributed.py
├── models.py
├── nvidia.bat
├── requirements.txt
├── test_ddp.py
└── utils
├── __init__.py
├── __pycache__
├── __init__.cpython-37.pyc
└── utils.cpython-37.pyc
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # 基于pytorch+bert的中文文本分类
2 |
3 | 本项目是基于pytorch+bert的中文文本分类。
4 |
5 | ## 使用依赖
6 | ```python
7 | torch==1.6.0
8 | transformers==4.5.1
9 | ```
10 | ## 相关说明
11 | ```
12 | --logs:存放日志
13 | --checkpoints:存放保存的模型
14 | --data:存放数据
15 | --utils:存放辅助函数
16 | --bert_config.py:相关配置
17 | --data_loader.py:制作数据为torch所需的格式
18 | --models.py:存放模型代码
19 | --main.py:主运行程序,包含训练、验证、测试、预测以及相关评价指标的计算
20 | ```
21 |
22 | 在hugging face上预先下载好预训练的bert模型,放在和该项目同级下的model_hub文件夹下。
23 |
24 | ## 裁判文书分类
25 |
26 | 数据集地址:[裁判文书网NLP文本分类数据集 - Heywhale.com](https://www.heywhale.com/mw/dataset/625869115fe0ad0017c6a7f7/file)
27 |
28 | #### 一般步骤
29 |
30 | - 1、在data下新建一个存放该数据集的文件夹,这里是cpws,然后将数据放在该文件夹下。在该文件夹下新建一个process.py,主要是获取标签并存储在labels.txt中。
31 | - 2、在data_loader.py里面新建一个类,类可参考CPWSDataset,主要是返回一个列表:[(文本,标签)]。
32 | - 3、在main.py里面的datasets、train_files、test_files里面添加上属于该数据集的一些信息,最后运行main.py即可。
33 | - 4、在运行main.py时,可通过指定--do_train、--do_test、--do_predict来选择训练、测试或预测。
34 |
35 | #### 运行指令
36 |
37 | ```python
38 | python main.py \
39 | --bert_dir="../model_hub/chinese-bert-wwm-ext/" \
40 | --data_dir="./data/cpws/" \
41 | --data_name="cpws" \
42 | --log_dir="./logs/" \
43 | --output_dir="./checkpoints/" \
44 | --num_tags=5 \
45 | --seed=123 \
46 | --gpu_ids="0" \
47 | --max_seq_len=256 \
48 | --lr=3e-5 \
49 | --train_batch_size=16 \
50 | --train_epochs=5 \
51 | --eval_batch_size=16 \
52 | --do_train \
53 | --do_test \
54 | --do_predict
55 | ```
56 |
57 | #### 结果
58 |
59 | 这里运行了300步之后手动停止了。
60 |
61 | ```python
62 | {'盗窃罪': 0, '交通肇事罪': 1, '诈骗罪': 2, '故意伤害罪': 3, '危险驾驶罪': 4}
63 | ========进行测试========
64 | 【test】 loss:22.040314 accuracy:0.9823 micro_f1:0.9823 macro_f1:0.9822
65 | precision recall f1-score support
66 |
67 | 盗窃罪 0.97 0.99 0.98 1998
68 | 交通肇事罪 0.97 0.99 0.98 1996
69 | 诈骗罪 0.99 0.98 0.99 1998
70 | 故意伤害罪 0.99 0.99 0.99 1999
71 | 危险驾驶罪 0.99 0.95 0.97 2000
72 |
73 | accuracy 0.98 9991
74 | macro avg 0.98 0.98 0.98 9991
75 | weighted avg 0.98 0.98 0.98 9991
76 | 公诉机关指控:1、2015年3月18日18时许,被告人余某窜至漳州市芗城区丹霞路欣隆盛世小区2期工地内,趁工作人员不注意盗走工地内的脚手架扣件70个(价值人民币252元)。2、2015年3月19日13时和17时,被告人余某分两次窜至漳州市芗城区丹霞路欣隆盛世小区2期工地内一楼房一层的中间配电室内,利用随身携带的铁钳盗走该配电室内的电缆线(共计574米,价值人民币4707元)。3、2015年3月21日7时30分许,被告人余某窜至漳州市芗城区丹霞路欣隆盛世小区2期工地内一楼房一层靠东边的配电室内,利用随身携带的铁钳要将该配电室内的电缆线(共156米,价值人民币1279元)盗走时被工地负责人洪某某发现,后被工地保安吴某某抓获并扭送公安机关。公诉机关认为被告人余某的行为已构成××,本案第三起盗窃系犯罪未遂,建议对被告人余某在××至一年六个月的幅度内处以刑罚,并处罚金。
77 | 预测标签: 盗窃罪
78 | 真实标签: 盗窃罪
79 | ==========================
80 | ```
81 |
82 | # 新闻文本分类
83 |
84 | 使用的数据集是THUCNews,数据地址:THUCNews
85 |
86 | #### 一般步骤
87 |
88 | - 1、在data下新建一个存放该数据集的文件夹,这里是cnews,然后将数据放在该文件夹下。在该文件夹下新建一个process.py,主要是获取标签并存储在labels.txt中。
89 | - 2、在data_loader.py里面新建一个类,类可参考CNEWSDataset,主要是返回一个列表:[(文本,标签)]。
90 | - 3、在main.py里面的datasets、train_files、test_files里面添加上属于该数据集的一些信息,最后运行main.py即可。
91 | - 4、在运行main.py时,可通过指定--do_train、--do_test、--do_predict来选择训练、测试或预测。
92 |
93 | #### 运行
94 |
95 | ```python
96 | python main.py \
97 | --bert_dir="../model_hub/chinese-bert-wwm-ext/" \
98 | --data_dir="./data/cnews/" \
99 | --data_name="cnews" \
100 | --log_dir="./logs/" \
101 | --output_dir="./checkpoints/" \
102 | --num_tags=10 \
103 | --seed=123 \
104 | --gpu_ids="0" \
105 | --max_seq_len=512 \
106 | --lr=3e-5 \
107 | --train_batch_size=16 \
108 | --train_epochs=5 \
109 | --eval_batch_size=16 \
110 | --do_predict
111 | ```
112 | #### 结果
113 |
114 | 这里运行了800步手动停止了。
115 |
116 | ```python
117 | {'房产': 0, '娱乐': 1, '教育': 2, '体育': 3, '家居': 4, '时政': 5, '财经': 6, '时尚': 7, '游戏': 8, '科技': 9}
118 | ========进行测试========
119 | 【test】 loss:76.024950 accuracy:0.9697 micro_f1:0.9697 macro_f1:0.9696
120 | precision recall f1-score support
121 |
122 | 房产 0.91 0.92 0.92 1000
123 | 娱乐 0.99 0.99 0.99 1000
124 | 教育 0.97 0.96 0.97 1000
125 | 体育 1.00 1.00 1.00 1000
126 | 家居 0.98 0.91 0.94 1000
127 | 时政 0.98 0.94 0.96 1000
128 | 财经 0.96 0.99 0.97 1000
129 | 时尚 0.94 1.00 0.97 1000
130 | 游戏 0.99 0.99 0.99 1000
131 | 科技 0.98 0.99 0.99 1000
132 |
133 | accuracy 0.97 10000
134 | macro avg 0.97 0.97 0.97 10000
135 | weighted avg 0.97 0.97 0.97 10000
136 |
137 | 鲍勃库西奖归谁属? NCAA最强控卫是坎巴还是弗神新浪体育讯如今,本赛季的NCAA进入到了末段,各项奖项的评选结果也即将出炉,其中评选最佳控卫的鲍勃-库西奖就将在下周最终四强战时公布,鲍勃-库西奖是由奈史密斯篮球名人堂提供,旨在奖励年度最佳大学控卫。最终获奖的球员也即将在以下几名热门人选中产生。〈〈〈 NCAA疯狂三月专题主页上线,点击链接查看精彩内容吉梅尔-弗雷戴特,杨百翰大学“弗神”吉梅尔-弗雷戴特一直都备受关注,他不仅仅是一名射手,他会用“终结对手脚踝”一样的变向过掉面前的防守者,并且他可以用任意一支手完成得分,如果他被犯规了,可以提前把这两份划入他的帐下了,因为他是一名命中率高达90%的罚球手。弗雷戴特具有所有伟大控卫都具备的一点特质,他是一位赢家也是一位领导者。“他整个赛季至始至终的稳定领导着球队前进,这是无可比拟的。”杨百翰大学主教练戴夫-罗斯称赞道,“他的得分能力毋庸置疑,但是我认为他带领球队获胜的能力才是他最重要的控卫职责。我们在主场之外的比赛(客场或中立场)共取胜19场,他都表现的很棒。”弗雷戴特能否在NBA取得成功?当然,但是有很多专业人士比我们更有资格去做出这样的判断。“我喜爱他。”凯尔特人主教练多克-里弗斯说道,“他很棒,我看过ESPN的片段剪辑,从剪辑来看,他是个超级巨星,我认为他很成为一名优秀的NBA球员。”诺兰-史密斯,杜克大学当赛季初,球队宣布大一天才控卫凯瑞-厄尔文因脚趾的伤病缺席赛季大部分比赛后,诺兰-史密斯便开始接管球权,他在进攻端上足发条,在ACC联盟(杜克大学所在分区)的得分榜上名列前茅,但同时他在分区助攻榜上也占据头名,这在众强林立的ACC联盟前无古人。“我不认为全美有其他的球员能在凯瑞-厄尔文受伤后,如此好的接管球队,并且之前毫无准备。”杜克主教练迈克-沙舍夫斯基赞扬道,“他会将比赛带入自己的节奏,得分,组织,领导球队,无所不能。而且他现在是攻防俱佳,对持球人的防守很有提高。总之他拥有了辉煌的赛季。”坎巴-沃克,康涅狄格大学坎巴-沃克带领康涅狄格在赛季初的毛伊岛邀请赛一路力克密歇根州大和肯塔基等队夺冠,他场均30分4助攻得到最佳球员。在大东赛区锦标赛和全国锦标赛中,他场均27.1分,6.1个篮板,5.1次助攻,依旧如此给力。他以疯狂的表现开始这个赛季,也将以疯狂的表现结束这个赛季。“我们在全国锦标赛中前进着,并且之前曾经5天连赢5场,赢得了大东赛区锦标赛的冠军,这些都归功于坎巴-沃克。”康涅狄格大学主教练吉姆-卡洪称赞道,“他是一名纯正的控卫而且能为我们得分,他有过单场42分,有过单场17助攻,也有过单场15篮板。这些都是一名6英尺175镑的球员所完成的啊!我们有很多好球员,但他才是最好的领导者,为球队所做的贡献也是最大。”乔丹-泰勒,威斯康辛大学全美没有一个持球者能像乔丹-泰勒一样很少失误,他4.26的助攻失误在全美遥遥领先,在大十赛区的比赛中,他平均35.8分钟才会有一次失误。他还是名很出色的得分手,全场砍下39分击败印第安纳大学的比赛就是最好的证明,其中下半场他曾经连拿18分。“那个夜晚他证明自己值得首轮顺位。”当时的见证者印第安纳大学主教练汤姆-克雷恩说道。“对一名控卫的所有要求不过是领导球队、使球队变的更好、带领球队成功,乔丹-泰勒全做到了。”威斯康辛教练博-莱恩说道。诺里斯-科尔,克利夫兰州大诺里斯-科尔的草根传奇正在上演,默默无闻的他被克利夫兰州大招募后便开始刻苦地训练,去年夏天他曾加练上千次跳投,来提高这个可能的弱点。他在本赛季与杨斯顿州大的比赛中得到40分20篮板和9次助攻,在他之前,过去15年只有一位球员曾经在NCAA一级联盟做到过40+20,他的名字是布雷克-格里芬。“他可以很轻松地防下对方王牌。”克利夫兰州大主教练加里-沃特斯如此称赞自己的弟子,“同时他还能得分,并为球队助攻,他几乎能做到一个成功的团队所有需要的事。”这其中四名球员都带领自己的球队进入到了甜蜜16强,虽然有3个球员和他们各自的球队被挡在8强的大门之外,但是他们已经表现的足够出色,不远的将来他们很可能出现在一所你熟悉的NBA球馆里。(clay)
138 | 预测标签: 体育
139 | 真实标签: 体育
140 | ==========================
141 | ```
142 |
143 | # Dataparallel分布式训练
144 |
145 | ```python
146 | python main_dataparallel.py \
147 | --bert_dir="../model_hub/chinese-bert-wwm-ext/" \
148 | --data_dir="./data/cnews/" \
149 | --data_name="cnews" \
150 | --log_dir="./logs/" \
151 | --output_dir="./checkpoints/" \
152 | --num_tags=10 \
153 | --seed=123 \
154 | --gpu_ids="0,1,3" \
155 | --max_seq_len=512 \
156 | --lr=3e-5 \
157 | --train_batch_size=64 \
158 | --train_epochs=1 \
159 | --eval_batch_size=64 \
160 | --do_train \
161 | --do_predict \
162 | --do_test
163 | ```
164 |
165 | # Distributed单机多卡分布式训练(windows下)
166 |
167 | linux下没有测试过。运行需要在powershell里面运行,右键点击开始菜单,选择powershell。nvidia.bat用于监控运行之后GPU的使用情况。需要pytorch版本至少大于1.7,这里使用的是pytorch==1.12。
168 |
169 | ### 使用torch.distributed.launch启动
170 |
171 | ```python
172 | python -m torch.distributed.launch --nnode=1 --node_rank=0 --nproc_per_node=4 main_distributed.py --local_world_size=4 --bert_dir="../model_hub/chinese-bert-wwm-ext/" --data_dir="./data/cnews/" --data_name="cnews" --log_dir="./logs/" --output_dir="./checkpoints/" --num_tags=10 --seed=123 --max_seq_len=512 --lr=3e-5 --train_batch_size=64 --train_epochs=1 --eval_batch_size=64 --do_train --do_predict --do_test
173 | ```
174 |
175 | **说明**:文件里面通过```os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,3'```来选择使用的GPU。nproc_per_node为使用的GPU的数目,local_world_size为使用的GPU的数目。
176 |
177 | ### 使用torch.multiprocessing启动
178 |
179 | ```python
180 | python main_mp_distributed.py --local_world_size=4 --bert_dir="../model_hub/chinese-bert-wwm-ext/" --data_dir="./data/cnews/" --data_name="cnews" --log_dir="./logs/" --output_dir="./checkpoints/" --num_tags=10 --seed=123 --max_seq_len=512 --lr=3e-5 --train_batch_size=64 --train_epochs=1 --eval_batch_size=64 --do_train --do_predict --do_test
181 | ```
182 |
183 | **说明**:文件里面通过```os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,3'```来选择使用的GPU。local_world_size为使用的GPU的数目。
184 |
185 | # 补充
186 |
187 | Q:怎么训练自己的数据集?
188 |
189 | A:按照样例的一般步骤里面进行即可。
190 |
191 | # 更新日志
192 |
193 | - 2022-08-08:重构了代码,使得总体结构更加简单,更易于用于不同的数据集上。
194 | - 2022-08-09:新增是否加载模型继续训练,运行参数加上--retrain。
195 |
196 | - 2023-03-30:新增基于dataparallel的分布式训练。
197 |
198 | - 2023-04-02:新增基于distributed的分布式训练。
199 |
200 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taishan1994/pytorch_bert_chinese_text_classification/8e3166b98c784f972d17df17ddeb56e3e494184b/__init__.py
--------------------------------------------------------------------------------
/bert_config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | class Args:
5 | @staticmethod
6 | def parse():
7 | parser = argparse.ArgumentParser()
8 | return parser
9 |
10 | @staticmethod
11 | def initialize(parser):
12 | # args for path
13 | parser.add_argument('--output_dir', default='../checkpoints/',
14 | help='the output dir for model checkpoints')
15 |
16 | parser.add_argument('--bert_dir', default='../../model_hub/bert-base-case/',
17 | help='bert dir for uer')
18 | parser.add_argument('--data_dir', default='../data/cnews/',
19 | help='data dir for uer')
20 | parser.add_argument('--log_dir', default='../logs/',
21 | help='log dir for uer')
22 | parser.add_argument('--data_name', default='cnews',
23 | help='数据集的名称')
24 |
25 | # other args
26 | parser.add_argument('--num_tags', default=65, type=int,
27 | help='number of tags')
28 | parser.add_argument('--seed', type=int, default=123, help='random seed')
29 |
30 | parser.add_argument('--gpu_ids', type=str, default='0',
31 | help='gpu ids to use, -1 for cpu, "0,1" for multi gpu')
32 |
33 | parser.add_argument('--max_seq_len', default=256, type=int)
34 |
35 | parser.add_argument('--eval_batch_size', default=12, type=int)
36 |
37 | parser.add_argument('--swa_start', default=3, type=int,
38 | help='the epoch when swa start')
39 |
40 | # train args
41 | # This is passed in via launch.py
42 | parser.add_argument("--local_rank", type=int, default=0)
43 | # This needs to be explicitly passed in
44 | parser.add_argument("--local_world_size", type=int, default=1)
45 | parser.add_argument('--train_epochs', default=15, type=int,
46 | help='Max training epoch')
47 |
48 | parser.add_argument('--dropout_prob', default=0.1, type=float,
49 | help='drop out probability')
50 |
51 | # 2e-5
52 | parser.add_argument('--lr', default=3e-5, type=float,
53 | help='learning rate for the bert module')
54 | # 2e-3
55 | parser.add_argument('--other_lr', default=3e-4, type=float,
56 | help='learning rate for the module except bert')
57 | # 0.5
58 | parser.add_argument('--max_grad_norm', default=1, type=float,
59 | help='max grad clip')
60 |
61 | parser.add_argument('--warmup_proportion', default=0.1, type=float)
62 |
63 | parser.add_argument('--weight_decay', default=0.01, type=float)
64 |
65 | parser.add_argument('--adam_epsilon', default=1e-8, type=float)
66 |
67 | parser.add_argument('--train_batch_size', default=32, type=int)
68 |
69 | # parser.add_argument('--eval_model', default=True, action='store_true',
70 | # help='whether to eval model after training')
71 | parser.add_argument('--do_train', action='store_true',
72 | help='是否训练')
73 | parser.add_argument('--do_test', action='store_true',
74 | help='是否测试')
75 | parser.add_argument('--do_predict', action='store_true',
76 | help='是否预测')
77 | parser.add_argument('--retrain', action='store_true',
78 | help='是否加载模型继续训练')
79 |
80 | return parser
81 |
82 | def get_parser(self):
83 | parser = self.parse()
84 | parser = self.initialize(parser)
85 | return parser.parse_args()
86 |
--------------------------------------------------------------------------------
/checkpoints/cnews/占位.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taishan1994/pytorch_bert_chinese_text_classification/8e3166b98c784f972d17df17ddeb56e3e494184b/checkpoints/cnews/占位.txt
--------------------------------------------------------------------------------
/checkpoints/cpws/占位.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taishan1994/pytorch_bert_chinese_text_classification/8e3166b98c784f972d17df17ddeb56e3e494184b/checkpoints/cpws/占位.txt
--------------------------------------------------------------------------------
/data/cnews/labels.txt:
--------------------------------------------------------------------------------
1 | 教育
2 | 娱乐
3 | 家居
4 | 房产
5 | 科技
6 | 时尚
7 | 体育
8 | 财经
9 | 时政
10 | 游戏
11 |
--------------------------------------------------------------------------------
/data/cnews/process.py:
--------------------------------------------------------------------------------
1 | labels = []
2 |
3 | with open('cnews.train.txt','r') as fp:
4 | lines = fp.read().strip().split('\n')
5 | for line in lines:
6 | line = line.split('\t')
7 | labels.append(line[0])
8 |
9 | labels = set(labels)
10 | with open('./labels.txt','w') as fp:
11 | fp.write("\n".join(labels))
--------------------------------------------------------------------------------
/data/cpws/labels.txt:
--------------------------------------------------------------------------------
1 | 盗窃罪
2 | 交通肇事罪
3 | 诈骗罪
4 | 故意伤害罪
5 | 危险驾驶罪
--------------------------------------------------------------------------------
/data/cpws/process.py:
--------------------------------------------------------------------------------
1 | filename = 'train_data.txt'
2 | labels = set()
3 | with open(filename, 'r', encoding='utf-8') as f:
4 | raw_data = f.readlines()
5 | for d in raw_data:
6 | d = d.strip()
7 | d = d.split("\t")
8 | if len(d) == 2:
9 | labels.add(d[0])
10 |
11 | with open('../final_data/labels.txt', 'w', encoding='utf-8') as f:
12 | f.write("\n".join(list(labels)))
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import json
3 | import torch
4 | import numpy as np
5 | from torch.utils.data import DataLoader, Dataset
6 |
7 |
8 | class ListDataset(Dataset):
9 | def __init__(self, file_path=None, data=None, tokenizer=None, max_len=None, **kwargs):
10 | self.kwargs = kwargs
11 | if isinstance(file_path, (str, list)):
12 | self.data = self.load_data(file_path)
13 | elif isinstance(data, list):
14 | self.data = data
15 | else:
16 | raise ValueError('The input args shall be str format file_path / list format dataset')
17 |
18 | def __len__(self):
19 | return len(self.data)
20 |
21 | def __getitem__(self, index):
22 | return self.data[index]
23 |
24 | @staticmethod
25 | def load_data(file_path):
26 | return file_path
27 |
28 |
29 | # 加载数据集
30 | class CNEWSDataset(ListDataset):
31 | @staticmethod
32 | def load_data(filename):
33 | data = []
34 | with open(filename, encoding='utf-8') as f:
35 | raw_data = f.readlines()
36 | for d in raw_data:
37 | d = d.strip().split('\t')
38 | text = d[1]
39 | label = d[0]
40 | data.append((text, label))
41 | return data
42 |
43 |
44 | class CPWSDataset(ListDataset):
45 | @staticmethod
46 | def load_data(filename):
47 | data = []
48 | with open(filename, encoding='utf-8') as f:
49 | raw_data = f.readlines()
50 | for d in raw_data:
51 | d = d.strip()
52 | d = d.split("\t")
53 | if len(d) == 2:
54 | data.append((d[1], d[0]))
55 | return data
56 |
57 |
58 |
59 |
60 | class Collate:
61 | def __init__(self, tokenizer, max_len, tag2id):
62 | self.tokenizer = tokenizer
63 | self.maxlen = max_len
64 | self.tag2id = tag2id
65 |
66 | def collate_fn(self, batch):
67 | batch_labels = []
68 | batch_token_ids = []
69 | batch_attention_mask = []
70 | batch_token_type_ids = []
71 | for i, (text, label) in enumerate(batch):
72 | output = self.tokenizer.encode_plus(
73 | text=text,
74 | max_length=self.maxlen,
75 | padding="max_length",
76 | truncation='longest_first',
77 | return_token_type_ids=True,
78 | return_attention_mask=True
79 | )
80 | token_ids = output["input_ids"]
81 | token_type_ids = output["token_type_ids"]
82 | attention_mask = output["attention_mask"]
83 | batch_token_ids.append(token_ids) # 前面已经限制了长度
84 | batch_attention_mask.append(attention_mask)
85 | batch_token_type_ids.append(token_type_ids)
86 | batch_labels.append(self.tag2id[label])
87 | batch_token_ids = torch.tensor(batch_token_ids, dtype=torch.long)
88 | attention_mask = torch.tensor(batch_attention_mask, dtype=torch.long)
89 | token_type_ids = torch.tensor(batch_token_type_ids, dtype=torch.long)
90 | batch_labels = torch.tensor(batch_labels, dtype=torch.long)
91 | batch_data = {
92 | "token_ids":batch_token_ids,
93 | "attention_masks":attention_mask,
94 | "token_type_ids":token_type_ids,
95 | "labels":batch_labels
96 | }
97 | return batch_data
98 |
99 |
100 | if __name__ == "__main__":
101 | from transformers import BertTokenizer
102 |
103 | max_len = 512
104 | tokenizer = BertTokenizer.from_pretrained('../model_hub/chinese-bert-wwm-ext')
105 | train_dataset = CNEWSDataset(file_path='data/cnews/cnews.train.txt')
106 | print(train_dataset[0])
107 |
108 | with open('data/cnews/labels.txt', 'r', encoding="utf-8") as fp:
109 | labels = fp.read().strip().split("\n")
110 | id2tag = {}
111 | tag2id = {}
112 | for i, label in enumerate(labels):
113 | id2tag[i] = label
114 | tag2id[label] = i
115 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116 | collate = Collate(tokenizer=tokenizer, max_len=max_len, tag2id=tag2id, device=device)
117 | batch_size = 2
118 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate.collate_fn)
119 |
120 | for i, batch in enumerate(train_dataloader):
121 | print(batch["token_ids"].shape)
122 | print(batch["attention_masks"].shape)
123 | print(batch["token_type_ids"].shape)
124 | print(batch["labels"].shape)
125 | break
126 |
--------------------------------------------------------------------------------
/logs/preprocess.log:
--------------------------------------------------------------------------------
1 | 2021-07-16 16:48:43,343 - INFO - preprocess.py - - 150 - {'output_dir': '../checkpoints/', 'bert_dir': '../model_hub/bert-base-chinese/', 'data_dir': '../data/cnews/', 'log_dir': './logs/', 'num_tags': 65, 'seed': 123, 'gpu_ids': '0', 'max_seq_len': 256, 'eval_batch_size': 12, 'swa_start': 3, 'train_epochs': 15, 'dropout_prob': 0.1, 'lr': 3e-05, 'other_lr': 0.0003, 'max_grad_norm': 1, 'warmup_proportion': 0.1, 'weight_decay': 0.01, 'adam_epsilon': 1e-08, 'train_batch_size': 32, 'eval_model': True}
2 | 2021-07-16 16:48:43,815 - INFO - preprocess.py - convert_examples_to_features - 106 - Convert 50000 examples to features
3 | 2021-07-16 16:48:43,822 - INFO - preprocess.py - convert_bert_example - 84 - *** train_example-0 ***
4 | 2021-07-16 16:48:43,823 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 马 晓 旭 意 外 受 伤 让 国 奥 警 惕 无 奈 大 雨 格 外 青 睐 殷 家 军 记 者 傅 亚 雨 沈 阳 报 道 来 到 沈 阳 , 国 奥 队 依 然 没 有 摆 脱 雨 水 的 困 扰 。 7 月 31 日 下 午 6 点 , 国 奥 队 的 日 常 训 练 再 度 受 到 大 雨 的 干 扰 , 无 奈 之 下 队 员 们 只 慢 跑 了 25 分 钟 就 草 草 收 场 。 31 日 上 午 10 点 , 国 奥 队 在 奥 体 中 心 外 场 训 练 的 时 候 , 天 就 是 阴 沉 沉 的 , 气 象 预 报 显 示 当 天 下 午 沈 阳 就 有 大 雨 , 但 幸 好 队 伍 上 午 的 训 练 并 没 有 受 到 任 何 干 扰 。 下 午 6 点 , 当 球 队 抵 达 训 练 场 时 , 大 雨 已 经 下 了 几 个 小 时 , 而 且 丝 毫 没 有 停 下 来 的 意 思 。 抱 着 试 一 试 的 态 度 , 球 队 开 始 了 当 天 下 午 的 例 行 训 练 , 25 分 钟 过 去 了 , 天 气 没 有 任 何 转 好 的 迹 象 , 为 了 保 护 球 [SEP]
5 | 2021-07-16 16:48:43,823 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 7716, 3236, 3195, 2692, 1912, 1358, 839, 6375, 1744, 1952, 6356, 2664, 3187, 1937, 1920, 7433, 3419, 1912, 7471, 4712, 3668, 2157, 1092, 6381, 5442, 987, 762, 7433, 3755, 7345, 2845, 6887, 3341, 1168, 3755, 7345, 8024, 1744, 1952, 7339, 898, 4197, 3766, 3300, 3030, 5564, 7433, 3717, 4638, 1737, 2817, 511, 128, 3299, 8176, 3189, 678, 1286, 127, 4157, 8024, 1744, 1952, 7339, 4638, 3189, 2382, 6378, 5298, 1086, 2428, 1358, 1168, 1920, 7433, 4638, 2397, 2817, 8024, 3187, 1937, 722, 678, 7339, 1447, 812, 1372, 2714, 6651, 749, 8132, 1146, 7164, 2218, 5770, 5770, 3119, 1767, 511, 8176, 3189, 677, 1286, 8108, 4157, 8024, 1744, 1952, 7339, 1762, 1952, 860, 704, 2552, 1912, 1767, 6378, 5298, 4638, 3198, 952, 8024, 1921, 2218, 3221, 7346, 3756, 3756, 4638, 8024, 3698, 6496, 7564, 2845, 3227, 4850, 2496, 1921, 678, 1286, 3755, 7345, 2218, 3300, 1920, 7433, 8024, 852, 2401, 1962, 7339, 824, 677, 1286, 4638, 6378, 5298, 2400, 3766, 3300, 1358, 1168, 818, 862, 2397, 2817, 511, 678, 1286, 127, 4157, 8024, 2496, 4413, 7339, 2850, 6809, 6378, 5298, 1767, 3198, 8024, 1920, 7433, 2347, 5307, 678, 749, 1126, 702, 2207, 3198, 8024, 5445, 684, 692, 3690, 3766, 3300, 977, 678, 3341, 4638, 2692, 2590, 511, 2849, 4708, 6407, 671, 6407, 4638, 2578, 2428, 8024, 4413, 7339, 2458, 1993, 749, 2496, 1921, 678, 1286, 4638, 891, 6121, 6378, 5298, 8024, 8132, 1146, 7164, 6814, 1343, 749, 8024, 1921, 3698, 3766, 3300, 818, 862, 6760, 1962, 4638, 6839, 6496, 8024, 711, 749, 924, 2844, 4413, 102]
6 | 2021-07-16 16:48:43,823 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
7 | 2021-07-16 16:48:43,823 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
8 | 2021-07-16 16:48:43,823 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}]
9 | 2021-07-16 16:48:43,838 - INFO - preprocess.py - convert_bert_example - 84 - *** train_example-1 ***
10 | 2021-07-16 16:48:43,838 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 商 瑞 华 首 战 复 仇 心 切 中 国 玫 瑰 要 用 美 国 方 式 攻 克 瑞 典 多 曼 来 了 , 瑞 典 来 了 , 商 瑞 华 首 战 求 3 分 的 信 心 也 来 了 。 距 离 首 战 72 小 时 当 口 , 中 国 女 足 彻 底 从 [UNK] 恐 瑞 症 [UNK] 当 中 获 得 解 脱 , 因 为 商 瑞 华 已 经 找 到 了 瑞 典 人 的 软 肋 。 找 到 软 肋 , 保 密 4 月 20 日 奥 运 会 分 组 抽 签 结 果 出 来 后 , 中 国 姑 娘 就 把 瑞 典 锁 定 为 关 乎 奥 运 成 败 的 头 号 劲 敌 , 因 为 除 了 浦 玮 等 个 别 老 将 之 外 , 现 役 女 足 将 士 竟 然 没 有 人 尝 过 击 败 瑞 典 的 滋 味 。 在 中 瑞 两 队 共 计 15 次 交 锋 的 历 史 上 , 中 国 队 6 胜 3 平 6 负 与 瑞 典 队 平 分 秋 色 , 但 从 2001 年 起 至 今 近 8 年 时 间 , 中 国 在 同 瑞 典 连 续 5 次 交 锋 中 均 未 尝 胜 绩 , 战 绩 为 2 平 3 负 。 尽 管 八 年 [SEP]
11 | 2021-07-16 16:48:43,839 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 1555, 4448, 1290, 7674, 2773, 1908, 790, 2552, 1147, 704, 1744, 4382, 4456, 6206, 4500, 5401, 1744, 3175, 2466, 3122, 1046, 4448, 1073, 1914, 3294, 3341, 749, 8024, 4448, 1073, 3341, 749, 8024, 1555, 4448, 1290, 7674, 2773, 3724, 124, 1146, 4638, 928, 2552, 738, 3341, 749, 511, 6655, 4895, 7674, 2773, 8325, 2207, 3198, 2496, 1366, 8024, 704, 1744, 1957, 6639, 2515, 2419, 794, 100, 2607, 4448, 4568, 100, 2496, 704, 5815, 2533, 6237, 5564, 8024, 1728, 711, 1555, 4448, 1290, 2347, 5307, 2823, 1168, 749, 4448, 1073, 782, 4638, 6763, 5490, 511, 2823, 1168, 6763, 5490, 8024, 924, 2166, 125, 3299, 8113, 3189, 1952, 6817, 833, 1146, 5299, 2853, 5041, 5310, 3362, 1139, 3341, 1400, 8024, 704, 1744, 1996, 2023, 2218, 2828, 4448, 1073, 7219, 2137, 711, 1068, 725, 1952, 6817, 2768, 6571, 4638, 1928, 1384, 1226, 3127, 8024, 1728, 711, 7370, 749, 3855, 4383, 5023, 702, 1166, 5439, 2199, 722, 1912, 8024, 4385, 2514, 1957, 6639, 2199, 1894, 4994, 4197, 3766, 3300, 782, 2214, 6814, 1140, 6571, 4448, 1073, 4638, 3996, 1456, 511, 1762, 704, 4448, 697, 7339, 1066, 6369, 8115, 3613, 769, 7226, 4638, 1325, 1380, 677, 8024, 704, 1744, 7339, 127, 5526, 124, 2398, 127, 6566, 680, 4448, 1073, 7339, 2398, 1146, 4904, 5682, 8024, 852, 794, 8285, 2399, 6629, 5635, 791, 6818, 129, 2399, 3198, 7313, 8024, 704, 1744, 1762, 1398, 4448, 1073, 6825, 5330, 126, 3613, 769, 7226, 704, 1772, 3313, 2214, 5526, 5327, 8024, 2773, 5327, 711, 123, 2398, 124, 6566, 511, 2226, 5052, 1061, 2399, 102]
12 | 2021-07-16 16:48:43,839 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
13 | 2021-07-16 16:48:43,839 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
14 | 2021-07-16 16:48:43,839 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}]
15 | 2021-07-16 16:48:43,857 - INFO - preprocess.py - convert_bert_example - 84 - *** train_example-2 ***
16 | 2021-07-16 16:48:43,857 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 冠 军 球 队 迎 新 欢 乐 派 对 黄 旭 获 大 奖 张 军 赢 下 pk 赛 新 浪 体 育 讯 12 月 27 日 晚 , [UNK] 冠 军 高 尔 夫 球 队 迎 新 高 球 欢 乐 派 对 [UNK] 活 动 在 北 京 都 市 名 人 高 尔 夫 俱 乐 部 举 行 。 邢 傲 伟 、 黄 旭 、 邹 凯 、 滕 海 滨 、 杨 凌 、 马 燕 红 、 张 军 、 王 丽 萍 、 杨 影 、 阎 森 、 毕 文 静 、 夏 煊 泽 、 张 璇 、 叶 乔 波 、 莫 慧 兰 、 周 雅 菲 、 胡 妮 、 戴 菲 菲 、 马 健 等 奥 运 冠 军 、 世 界 冠 军 参 加 了 当 晚 的 活 动 。 [UNK] 以 球 会 友 、 联 谊 交 流 [UNK] 为 宗 旨 的 冠 军 高 尔 夫 球 队 , 从 创 立 之 初 就 秉 承 关 爱 社 会 、 回 馈 社 会 的 价 值 观 。 此 次 迎 新 派 对 活 动 的 举 行 , 标 志 着 这 个 为 中 国 金 牌 运 动 员 提 供 的 放 松 身 心 、 以 球 会 友 的 平 台 正 式 全 面 启 动 。 球 队 已 经 与 都 市 名 人 高 [SEP]
17 | 2021-07-16 16:48:43,857 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 1094, 1092, 4413, 7339, 6816, 3173, 3614, 727, 3836, 2190, 7942, 3195, 5815, 1920, 1946, 2476, 1092, 6617, 678, 8465, 6612, 3173, 3857, 860, 5509, 6380, 8110, 3299, 8149, 3189, 3241, 8024, 100, 1094, 1092, 7770, 2209, 1923, 4413, 7339, 6816, 3173, 7770, 4413, 3614, 727, 3836, 2190, 100, 3833, 1220, 1762, 1266, 776, 6963, 2356, 1399, 782, 7770, 2209, 1923, 936, 727, 6956, 715, 6121, 511, 6928, 1000, 836, 510, 7942, 3195, 510, 6941, 1132, 510, 4001, 3862, 4012, 510, 3342, 1119, 510, 7716, 4242, 5273, 510, 2476, 1092, 510, 4374, 714, 5847, 510, 3342, 2512, 510, 7330, 3481, 510, 3684, 3152, 7474, 510, 1909, 4201, 3813, 510, 2476, 4462, 510, 1383, 730, 3797, 510, 5811, 2716, 1065, 510, 1453, 7414, 5838, 510, 5529, 1984, 510, 2785, 5838, 5838, 510, 7716, 978, 5023, 1952, 6817, 1094, 1092, 510, 686, 4518, 1094, 1092, 1346, 1217, 749, 2496, 3241, 4638, 3833, 1220, 511, 100, 809, 4413, 833, 1351, 510, 5468, 6449, 769, 3837, 100, 711, 2134, 3192, 4638, 1094, 1092, 7770, 2209, 1923, 4413, 7339, 8024, 794, 1158, 4989, 722, 1159, 2218, 4903, 2824, 1068, 4263, 4852, 833, 510, 1726, 7668, 4852, 833, 4638, 817, 966, 6225, 511, 3634, 3613, 6816, 3173, 3836, 2190, 3833, 1220, 4638, 715, 6121, 8024, 3403, 2562, 4708, 6821, 702, 711, 704, 1744, 7032, 4277, 6817, 1220, 1447, 2990, 897, 4638, 3123, 3351, 6716, 2552, 510, 809, 4413, 833, 1351, 4638, 2398, 1378, 3633, 2466, 1059, 7481, 1423, 1220, 511, 4413, 7339, 2347, 5307, 680, 6963, 2356, 1399, 782, 7770, 102]
18 | 2021-07-16 16:48:43,857 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
19 | 2021-07-16 16:48:43,857 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
20 | 2021-07-16 16:48:43,858 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}]
21 | 2021-07-16 16:55:52,209 - INFO - preprocess.py - convert_examples_to_features - 121 - Build 50000 features
22 | 2021-07-16 16:55:52,275 - INFO - preprocess.py - convert_examples_to_features - 106 - Convert 5000 examples to features
23 | 2021-07-16 16:55:52,277 - INFO - preprocess.py - convert_bert_example - 84 - *** dev_example-0 ***
24 | 2021-07-16 16:55:52,277 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 黄 蜂 vs 湖 人 首 发 : 科 比 带 伤 战 保 罗 加 索 尔 救 赎 之 战 新 浪 体 育 讯 北 京 时 间 4 月 27 日 , nba 季 后 赛 首 轮 洛 杉 矶 湖 人 主 场 迎 战 新 奥 尔 良 黄 蜂 , 此 前 的 比 赛 中 , 双 方 战 成 2 - 2 平 , 因 此 本 场 比 赛 对 于 两 支 球 队 来 说 都 非 常 重 要 , 赛 前 双 方 也 公 布 了 首 发 阵 容 : 湖 人 队 : 费 舍 尔 、 科 比 、 阿 泰 斯 特 、 加 索 尔 、 拜 纳 姆 黄 蜂 队 : 保 罗 、 贝 里 内 利 、 阿 里 扎 、 兰 德 里 、 奥 卡 福 [ 新 浪 nba 官 方 微 博 ] [ 新 浪 nba 湖 人 新 闻 动 态 微 博 ] [ 新 浪 nba 专 题 ] [ 黄 蜂 vs 湖 人 图 文 直 播 室 ] ( 新 浪 体 育 ) [SEP]
25 | 2021-07-16 16:55:52,278 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 7942, 6044, 8349, 3959, 782, 7674, 1355, 8038, 4906, 3683, 2372, 839, 2773, 924, 5384, 1217, 5164, 2209, 3131, 6604, 722, 2773, 3173, 3857, 860, 5509, 6380, 1266, 776, 3198, 7313, 125, 3299, 8149, 3189, 8024, 8391, 2108, 1400, 6612, 7674, 6762, 3821, 3329, 4768, 3959, 782, 712, 1767, 6816, 2773, 3173, 1952, 2209, 5679, 7942, 6044, 8024, 3634, 1184, 4638, 3683, 6612, 704, 8024, 1352, 3175, 2773, 2768, 123, 118, 123, 2398, 8024, 1728, 3634, 3315, 1767, 3683, 6612, 2190, 754, 697, 3118, 4413, 7339, 3341, 6432, 6963, 7478, 2382, 7028, 6206, 8024, 6612, 1184, 1352, 3175, 738, 1062, 2357, 749, 7674, 1355, 7347, 2159, 8038, 3959, 782, 7339, 8038, 6589, 5650, 2209, 510, 4906, 3683, 510, 7350, 3805, 3172, 4294, 510, 1217, 5164, 2209, 510, 2876, 5287, 1990, 7942, 6044, 7339, 8038, 924, 5384, 510, 6564, 7027, 1079, 1164, 510, 7350, 7027, 2799, 510, 1065, 2548, 7027, 510, 1952, 1305, 4886, 138, 3173, 3857, 8391, 2135, 3175, 2544, 1300, 140, 138, 3173, 3857, 8391, 3959, 782, 3173, 7319, 1220, 2578, 2544, 1300, 140, 138, 3173, 3857, 8391, 683, 7579, 140, 138, 7942, 6044, 8349, 3959, 782, 1745, 3152, 4684, 3064, 2147, 140, 113, 3173, 3857, 860, 5509, 114, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
26 | 2021-07-16 16:55:52,278 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
27 | 2021-07-16 16:55:52,278 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
28 | 2021-07-16 16:55:52,278 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}]
29 | 2021-07-16 16:55:52,286 - INFO - preprocess.py - convert_bert_example - 84 - *** dev_example-1 ***
30 | 2021-07-16 16:55:52,286 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 1. 7 秒 神 之 一 击 救 马 刺 王 朝 于 危 难 这 个 新 秀 有 点 牛 ! 新 浪 体 育 讯 在 刚 刚 结 束 的 比 赛 中 , 回 到 主 场 的 马 刺 通 过 加 时 以 110 - 103 惊 险 地 战 胜 了 灰 熊 , 避 免 了 让 主 场 观 众 见 证 黑 八 的 尴 尬 。 在 常 规 时 间 的 最 后 关 头 , 加 里 - 尼 尔 命 中 一 记 大 号 三 分 , 这 个 进 球 也 帮 助 马 刺 把 比 赛 带 进 了 加 时 , 并 最 终 翻 盘 成 功 。 [UNK] 波 波 维 奇 教 练 安 排 得 很 详 细 , 他 告 诉 我 有 机 会 就 要 出 手 。 邓 肯 的 掩 护 做 得 很 好 , 他 帮 我 完 全 挡 住 了 对 方 的 防 守 球 员 , 我 投 篮 的 视 野 非 常 好 , 于 是 我 就 出 手 了 。 [UNK] 即 使 命 中 了 这 记 价 值 连 城 的 球 , 尼 尔 依 然 保 持 着 低 调 。 在 被 问 及 这 是 不 是 他 职 业 生 涯 最 重 要 的 投 篮 时 , 他 说 : [UNK] 没 错 , 到 [SEP]
31 | 2021-07-16 16:55:52,286 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 122, 119, 128, 4907, 4868, 722, 671, 1140, 3131, 7716, 1173, 4374, 3308, 754, 1314, 7410, 6821, 702, 3173, 4899, 3300, 4157, 4281, 8013, 3173, 3857, 860, 5509, 6380, 1762, 1157, 1157, 5310, 3338, 4638, 3683, 6612, 704, 8024, 1726, 1168, 712, 1767, 4638, 7716, 1173, 6858, 6814, 1217, 3198, 809, 8406, 118, 8615, 2661, 7372, 1765, 2773, 5526, 749, 4129, 4220, 8024, 6912, 1048, 749, 6375, 712, 1767, 6225, 830, 6224, 6395, 7946, 1061, 4638, 2219, 2217, 511, 1762, 2382, 6226, 3198, 7313, 4638, 3297, 1400, 1068, 1928, 8024, 1217, 7027, 118, 2225, 2209, 1462, 704, 671, 6381, 1920, 1384, 676, 1146, 8024, 6821, 702, 6822, 4413, 738, 2376, 1221, 7716, 1173, 2828, 3683, 6612, 2372, 6822, 749, 1217, 3198, 8024, 2400, 3297, 5303, 5436, 4669, 2768, 1216, 511, 100, 3797, 3797, 5335, 1936, 3136, 5298, 2128, 2961, 2533, 2523, 6422, 5301, 8024, 800, 1440, 6401, 2769, 3300, 3322, 833, 2218, 6206, 1139, 2797, 511, 6924, 5507, 4638, 2973, 2844, 976, 2533, 2523, 1962, 8024, 800, 2376, 2769, 2130, 1059, 2913, 857, 749, 2190, 3175, 4638, 7344, 2127, 4413, 1447, 8024, 2769, 2832, 5074, 4638, 6228, 7029, 7478, 2382, 1962, 8024, 754, 3221, 2769, 2218, 1139, 2797, 749, 511, 100, 1315, 886, 1462, 704, 749, 6821, 6381, 817, 966, 6825, 1814, 4638, 4413, 8024, 2225, 2209, 898, 4197, 924, 2898, 4708, 856, 6444, 511, 1762, 6158, 7309, 1350, 6821, 3221, 679, 3221, 800, 5466, 689, 4495, 3889, 3297, 7028, 6206, 4638, 2832, 5074, 3198, 8024, 800, 6432, 8038, 100, 3766, 7231, 8024, 1168, 102]
32 | 2021-07-16 16:55:52,286 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
33 | 2021-07-16 16:55:52,286 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
34 | 2021-07-16 16:55:52,286 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}]
35 | 2021-07-16 16:55:52,299 - INFO - preprocess.py - convert_bert_example - 84 - *** dev_example-2 ***
36 | 2021-07-16 16:55:52,299 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 1 人 灭 掘 金 ! 神 般 杜 兰 特 ! 他 想 要 分 的 时 候 没 人 能 挡 新 浪 体 育 讯 在 nba 的 世 界 里 , 真 的 猛 男 , 敢 于 直 面 惨 淡 的 手 感 , 敢 于 正 视 落 后 的 局 面 , 然 后 用 一 己 之 力 , 力 挽 狂 澜 , 点 燃 球 迷 激 情 , 最 后 微 微 一 笑 , 带 领 球 队 在 季 后 赛 的 战 场 上 赢 下 比 赛 , 并 进 入 下 一 轮 , 今 日 雷 霆 凯 文 - 杜 兰 特 所 做 的 , 无 非 就 是 这 样 的 事 情 。 巨 星 这 个 东 西 很 难 定 义 , 有 时 候 你 就 是 30 分 30 板 也 未 必 能 得 到 一 个 巨 星 名 头 , 反 而 会 有 可 能 会 被 称 为 刷 子 , 而 巨 星 不 仅 是 数 据 上 能 够 出 类 拔 萃 , 也 不 仅 仅 是 能 够 帮 助 球 队 赢 球 , 从 意 志 力 层 面 上 来 讲 , 母 队 比 赛 快 输 了 , 队 友 的 腿 都 开 始 抖 了 , 所 有 人 都 在 看 着 你 , 球 都 到 了 你 [SEP]
37 | 2021-07-16 16:55:52,299 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 122, 782, 4127, 2963, 7032, 8013, 4868, 5663, 3336, 1065, 4294, 8013, 800, 2682, 6206, 1146, 4638, 3198, 952, 3766, 782, 5543, 2913, 3173, 3857, 860, 5509, 6380, 1762, 8391, 4638, 686, 4518, 7027, 8024, 4696, 4638, 4338, 4511, 8024, 3140, 754, 4684, 7481, 2673, 3909, 4638, 2797, 2697, 8024, 3140, 754, 3633, 6228, 5862, 1400, 4638, 2229, 7481, 8024, 4197, 1400, 4500, 671, 2346, 722, 1213, 8024, 1213, 2924, 4312, 4073, 8024, 4157, 4234, 4413, 6837, 4080, 2658, 8024, 3297, 1400, 2544, 2544, 671, 5010, 8024, 2372, 7566, 4413, 7339, 1762, 2108, 1400, 6612, 4638, 2773, 1767, 677, 6617, 678, 3683, 6612, 8024, 2400, 6822, 1057, 678, 671, 6762, 8024, 791, 3189, 7440, 7447, 1132, 3152, 118, 3336, 1065, 4294, 2792, 976, 4638, 8024, 3187, 7478, 2218, 3221, 6821, 3416, 4638, 752, 2658, 511, 2342, 3215, 6821, 702, 691, 6205, 2523, 7410, 2137, 721, 8024, 3300, 3198, 952, 872, 2218, 3221, 8114, 1146, 8114, 3352, 738, 3313, 2553, 5543, 2533, 1168, 671, 702, 2342, 3215, 1399, 1928, 8024, 1353, 5445, 833, 3300, 1377, 5543, 833, 6158, 4917, 711, 1170, 2094, 8024, 5445, 2342, 3215, 679, 788, 3221, 3144, 2945, 677, 5543, 1916, 1139, 5102, 2869, 5842, 8024, 738, 679, 788, 788, 3221, 5543, 1916, 2376, 1221, 4413, 7339, 6617, 4413, 8024, 794, 2692, 2562, 1213, 2231, 7481, 677, 3341, 6382, 8024, 3678, 7339, 3683, 6612, 2571, 6783, 749, 8024, 7339, 1351, 4638, 5597, 6963, 2458, 1993, 2833, 749, 8024, 2792, 3300, 782, 6963, 1762, 4692, 4708, 872, 8024, 4413, 6963, 1168, 749, 872, 102]
38 | 2021-07-16 16:55:52,299 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
39 | 2021-07-16 16:55:52,299 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
40 | 2021-07-16 16:55:52,299 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}]
41 | 2021-07-16 16:56:29,964 - INFO - preprocess.py - convert_examples_to_features - 121 - Build 5000 features
42 | 2021-07-16 16:56:30,087 - INFO - preprocess.py - convert_examples_to_features - 106 - Convert 10000 examples to features
43 | 2021-07-16 16:56:30,119 - INFO - preprocess.py - convert_bert_example - 84 - *** test_example-0 ***
44 | 2021-07-16 16:56:30,119 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 鲍 勃 库 西 奖 归 谁 属 ? ncaa 最 强 控 卫 是 坎 巴 还 是 弗 神 新 浪 体 育 讯 如 今 , 本 赛 季 的 ncaa 进 入 到 了 末 段 , 各 项 奖 项 的 评 选 结 果 也 即 将 出 炉 , 其 中 评 选 最 佳 控 卫 的 鲍 勃 - 库 西 奖 就 将 在 下 周 最 终 四 强 战 时 公 布 , 鲍 勃 - 库 西 奖 是 由 奈 史 密 斯 篮 球 名 人 堂 提 供 , 旨 在 奖 励 年 度 最 佳 大 学 控 卫 。 最 终 获 奖 的 球 员 也 即 将 在 以 下 几 名 热 门 人 选 中 产 生 。 〈 〈 〈 ncaa 疯 狂 三 月 专 题 主 页 上 线 , 点 击 链 接 查 看 精 彩 内 容 吉 梅 尔 - 弗 雷 戴 特 , 杨 百 翰 大 学 [UNK] 弗 神 [UNK] 吉 梅 尔 - 弗 雷 戴 特 一 直 都 备 受 关 注 , 他 不 仅 仅 是 一 名 射 手 , 他 会 用 [UNK] 终 结 对 手 脚 踝 [UNK] 一 样 的 变 向 过 掉 面 前 的 防 守 者 , 并 且 他 可 以 用 任 意 一 支 手 完 成 得 分 , [SEP]
45 | 2021-07-16 16:56:30,119 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 7828, 1234, 2417, 6205, 1946, 2495, 6443, 2247, 8043, 12394, 10226, 3297, 2487, 2971, 1310, 3221, 1775, 2349, 6820, 3221, 2472, 4868, 3173, 3857, 860, 5509, 6380, 1963, 791, 8024, 3315, 6612, 2108, 4638, 12394, 10226, 6822, 1057, 1168, 749, 3314, 3667, 8024, 1392, 7555, 1946, 7555, 4638, 6397, 6848, 5310, 3362, 738, 1315, 2199, 1139, 4140, 8024, 1071, 704, 6397, 6848, 3297, 881, 2971, 1310, 4638, 7828, 1234, 118, 2417, 6205, 1946, 2218, 2199, 1762, 678, 1453, 3297, 5303, 1724, 2487, 2773, 3198, 1062, 2357, 8024, 7828, 1234, 118, 2417, 6205, 1946, 3221, 4507, 1937, 1380, 2166, 3172, 5074, 4413, 1399, 782, 1828, 2990, 897, 8024, 3192, 1762, 1946, 1225, 2399, 2428, 3297, 881, 1920, 2110, 2971, 1310, 511, 3297, 5303, 5815, 1946, 4638, 4413, 1447, 738, 1315, 2199, 1762, 809, 678, 1126, 1399, 4178, 7305, 782, 6848, 704, 772, 4495, 511, 515, 515, 515, 12394, 10226, 4556, 4312, 676, 3299, 683, 7579, 712, 7552, 677, 5296, 8024, 4157, 1140, 7216, 2970, 3389, 4692, 5125, 2506, 1079, 2159, 1395, 3449, 2209, 118, 2472, 7440, 2785, 4294, 8024, 3342, 4636, 5432, 1920, 2110, 100, 2472, 4868, 100, 1395, 3449, 2209, 118, 2472, 7440, 2785, 4294, 671, 4684, 6963, 1906, 1358, 1068, 3800, 8024, 800, 679, 788, 788, 3221, 671, 1399, 2198, 2797, 8024, 800, 833, 4500, 100, 5303, 5310, 2190, 2797, 5558, 6674, 100, 671, 3416, 4638, 1359, 1403, 6814, 2957, 7481, 1184, 4638, 7344, 2127, 5442, 8024, 2400, 684, 800, 1377, 809, 4500, 818, 2692, 671, 3118, 2797, 2130, 2768, 2533, 1146, 8024, 102]
46 | 2021-07-16 16:56:30,119 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
47 | 2021-07-16 16:56:30,119 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
48 | 2021-07-16 16:56:30,119 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}]
49 | 2021-07-16 16:56:30,134 - INFO - preprocess.py - convert_bert_example - 84 - *** test_example-1 ***
50 | 2021-07-16 16:56:30,134 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 麦 基 砍 28 + 18 + 5 却 充 满 寂 寞 纪 录 之 夜 他 的 痛 阿 联 最 懂 新 浪 体 育 讯 上 天 对 每 个 人 都 是 公 平 的 , 贾 维 尔 - 麦 基 也 不 例 外 。 今 天 华 盛 顿 奇 才 客 场 104 - 114 负 于 金 州 勇 士 , 麦 基 就 好 不 容 易 等 到 [UNK] 捏 软 柿 子 [UNK] 的 机 会 , 上 半 场 打 出 现 象 级 表 现 , 只 可 惜 无 法 一 以 贯 之 。 最 终 , 麦 基 12 投 9 中 , 得 到 生 涯 最 高 的 28 分 , 以 及 平 生 涯 最 佳 的 18 个 篮 板 , 另 有 5 次 封 盖 。 此 外 , 他 11 次 罚 球 命 中 10 个 , 这 两 项 也 均 为 生 涯 最 高 。 如 果 在 赛 前 搞 个 竞 猜 , 上 半 场 谁 会 是 奇 才 阵 中 罚 球 次 数 最 多 的 球 员 ? 若 有 人 答 曰 [UNK] 麦 基 [UNK] , 不 是 恶 搞 就 是 脑 残 。 但 半 场 结 束 , 麦 基 竟 砍 下 22 分 ( 第 二 节 砍 下 14 分 ) 。 更 罕 见 的 , 则 是 [SEP]
51 | 2021-07-16 16:56:30,134 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 7931, 1825, 4775, 8143, 116, 8123, 116, 126, 1316, 1041, 4007, 2163, 2174, 5279, 2497, 722, 1915, 800, 4638, 4578, 7350, 5468, 3297, 2743, 3173, 3857, 860, 5509, 6380, 677, 1921, 2190, 3680, 702, 782, 6963, 3221, 1062, 2398, 4638, 8024, 6593, 5335, 2209, 118, 7931, 1825, 738, 679, 891, 1912, 511, 791, 1921, 1290, 4670, 7561, 1936, 2798, 2145, 1767, 8503, 118, 8866, 6566, 754, 7032, 2336, 1235, 1894, 8024, 7931, 1825, 2218, 1962, 679, 2159, 3211, 5023, 1168, 100, 2934, 6763, 3398, 2094, 100, 4638, 3322, 833, 8024, 677, 1288, 1767, 2802, 1139, 4385, 6496, 5277, 6134, 4385, 8024, 1372, 1377, 2667, 3187, 3791, 671, 809, 6581, 722, 511, 3297, 5303, 8024, 7931, 1825, 8110, 2832, 130, 704, 8024, 2533, 1168, 4495, 3889, 3297, 7770, 4638, 8143, 1146, 8024, 809, 1350, 2398, 4495, 3889, 3297, 881, 4638, 8123, 702, 5074, 3352, 8024, 1369, 3300, 126, 3613, 2196, 4667, 511, 3634, 1912, 8024, 800, 8111, 3613, 5385, 4413, 1462, 704, 8108, 702, 8024, 6821, 697, 7555, 738, 1772, 711, 4495, 3889, 3297, 7770, 511, 1963, 3362, 1762, 6612, 1184, 3018, 702, 4993, 4339, 8024, 677, 1288, 1767, 6443, 833, 3221, 1936, 2798, 7347, 704, 5385, 4413, 3613, 3144, 3297, 1914, 4638, 4413, 1447, 8043, 5735, 3300, 782, 5031, 3288, 100, 7931, 1825, 100, 8024, 679, 3221, 2626, 3018, 2218, 3221, 5554, 3655, 511, 852, 1288, 1767, 5310, 3338, 8024, 7931, 1825, 4994, 4775, 678, 8130, 1146, 113, 5018, 753, 5688, 4775, 678, 8122, 1146, 114, 511, 3291, 5383, 6224, 4638, 8024, 1156, 3221, 102]
52 | 2021-07-16 16:56:30,135 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
53 | 2021-07-16 16:56:30,135 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
54 | 2021-07-16 16:56:30,135 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}]
55 | 2021-07-16 16:56:30,138 - INFO - preprocess.py - convert_bert_example - 84 - *** test_example-2 ***
56 | 2021-07-16 16:56:30,139 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 黄 蜂 vs 湖 人 首 发 : 科 比 冲 击 七 连 胜 火 箭 两 旧 将 登 场 新 浪 体 育 讯 北 京 时 间 3 月 28 日 , nba 常 规 赛 洛 杉 矶 湖 人 主 场 迎 战 新 奥 尔 良 黄 蜂 , 赛 前 双 方 也 公 布 了 首 发 阵 容 : 点 击 进 入 新 浪 体 育 视 频 直 播 室 点 击 进 入 新 浪 体 育 图 文 直 播 室 点 击 进 入 新 浪 体 育 nba 专 题 点 击 进 入 新 浪 nba 官 方 微 博 双 方 首 发 阵 容 : 湖 人 队 : 德 里 克 - 费 舍 尔 、 科 比 - 布 莱 恩 特 、 罗 恩 - 阿 泰 斯 特 、 保 罗 - 加 索 尔 、 安 德 鲁 - 拜 纳 姆 黄 蜂 队 : 克 里 斯 - 保 罗 、 马 科 - 贝 里 内 利 、 特 雷 沃 - 阿 里 扎 、 卡 尔 - 兰 德 里 、 埃 梅 卡 - 奥 卡 福 ( 新 浪 体 育 ) [SEP]
57 | 2021-07-16 16:56:30,139 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 7942, 6044, 8349, 3959, 782, 7674, 1355, 8038, 4906, 3683, 1103, 1140, 673, 6825, 5526, 4125, 5055, 697, 3191, 2199, 4633, 1767, 3173, 3857, 860, 5509, 6380, 1266, 776, 3198, 7313, 124, 3299, 8143, 3189, 8024, 8391, 2382, 6226, 6612, 3821, 3329, 4768, 3959, 782, 712, 1767, 6816, 2773, 3173, 1952, 2209, 5679, 7942, 6044, 8024, 6612, 1184, 1352, 3175, 738, 1062, 2357, 749, 7674, 1355, 7347, 2159, 8038, 4157, 1140, 6822, 1057, 3173, 3857, 860, 5509, 6228, 7574, 4684, 3064, 2147, 4157, 1140, 6822, 1057, 3173, 3857, 860, 5509, 1745, 3152, 4684, 3064, 2147, 4157, 1140, 6822, 1057, 3173, 3857, 860, 5509, 8391, 683, 7579, 4157, 1140, 6822, 1057, 3173, 3857, 8391, 2135, 3175, 2544, 1300, 1352, 3175, 7674, 1355, 7347, 2159, 8038, 3959, 782, 7339, 8038, 2548, 7027, 1046, 118, 6589, 5650, 2209, 510, 4906, 3683, 118, 2357, 5812, 2617, 4294, 510, 5384, 2617, 118, 7350, 3805, 3172, 4294, 510, 924, 5384, 118, 1217, 5164, 2209, 510, 2128, 2548, 7826, 118, 2876, 5287, 1990, 7942, 6044, 7339, 8038, 1046, 7027, 3172, 118, 924, 5384, 510, 7716, 4906, 118, 6564, 7027, 1079, 1164, 510, 4294, 7440, 3753, 118, 7350, 7027, 2799, 510, 1305, 2209, 118, 1065, 2548, 7027, 510, 1812, 3449, 1305, 118, 1952, 1305, 4886, 113, 3173, 3857, 860, 5509, 114, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
58 | 2021-07-16 16:56:30,139 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
59 | 2021-07-16 16:56:30,139 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
60 | 2021-07-16 16:56:30,139 - INFO - preprocess.py - convert_bert_example - 89 - labels: [{'教育': 0, '娱乐': 1, '家居': 2, '房产': 3, '科技': 4, '时尚': 5, '体育': 6, '财经': 7, '时政': 8, '游戏': 9}]
61 | 2021-07-16 16:58:00,802 - INFO - preprocess.py - convert_examples_to_features - 121 - Build 10000 features
62 | 2021-07-16 17:01:09,034 - INFO - preprocess.py - - 150 - {'output_dir': '../checkpoints/', 'bert_dir': '../model_hub/bert-base-chinese/', 'data_dir': '../data/cnews/', 'log_dir': './logs/', 'num_tags': 65, 'seed': 123, 'gpu_ids': '0', 'max_seq_len': 256, 'eval_batch_size': 12, 'swa_start': 3, 'train_epochs': 15, 'dropout_prob': 0.1, 'lr': 3e-05, 'other_lr': 0.0003, 'max_grad_norm': 1, 'warmup_proportion': 0.1, 'weight_decay': 0.01, 'adam_epsilon': 1e-08, 'train_batch_size': 32, 'eval_model': True}
63 | 2021-07-16 17:01:09,505 - INFO - preprocess.py - convert_examples_to_features - 106 - Convert 50000 examples to features
64 | 2021-07-16 17:01:09,512 - INFO - preprocess.py - convert_bert_example - 84 - *** train_example-0 ***
65 | 2021-07-16 17:01:09,512 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 马 晓 旭 意 外 受 伤 让 国 奥 警 惕 无 奈 大 雨 格 外 青 睐 殷 家 军 记 者 傅 亚 雨 沈 阳 报 道 来 到 沈 阳 , 国 奥 队 依 然 没 有 摆 脱 雨 水 的 困 扰 。 7 月 31 日 下 午 6 点 , 国 奥 队 的 日 常 训 练 再 度 受 到 大 雨 的 干 扰 , 无 奈 之 下 队 员 们 只 慢 跑 了 25 分 钟 就 草 草 收 场 。 31 日 上 午 10 点 , 国 奥 队 在 奥 体 中 心 外 场 训 练 的 时 候 , 天 就 是 阴 沉 沉 的 , 气 象 预 报 显 示 当 天 下 午 沈 阳 就 有 大 雨 , 但 幸 好 队 伍 上 午 的 训 练 并 没 有 受 到 任 何 干 扰 。 下 午 6 点 , 当 球 队 抵 达 训 练 场 时 , 大 雨 已 经 下 了 几 个 小 时 , 而 且 丝 毫 没 有 停 下 来 的 意 思 。 抱 着 试 一 试 的 态 度 , 球 队 开 始 了 当 天 下 午 的 例 行 训 练 , 25 分 钟 过 去 了 , 天 气 没 有 任 何 转 好 的 迹 象 , 为 了 保 护 球 [SEP]
66 | 2021-07-16 17:01:09,513 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 7716, 3236, 3195, 2692, 1912, 1358, 839, 6375, 1744, 1952, 6356, 2664, 3187, 1937, 1920, 7433, 3419, 1912, 7471, 4712, 3668, 2157, 1092, 6381, 5442, 987, 762, 7433, 3755, 7345, 2845, 6887, 3341, 1168, 3755, 7345, 8024, 1744, 1952, 7339, 898, 4197, 3766, 3300, 3030, 5564, 7433, 3717, 4638, 1737, 2817, 511, 128, 3299, 8176, 3189, 678, 1286, 127, 4157, 8024, 1744, 1952, 7339, 4638, 3189, 2382, 6378, 5298, 1086, 2428, 1358, 1168, 1920, 7433, 4638, 2397, 2817, 8024, 3187, 1937, 722, 678, 7339, 1447, 812, 1372, 2714, 6651, 749, 8132, 1146, 7164, 2218, 5770, 5770, 3119, 1767, 511, 8176, 3189, 677, 1286, 8108, 4157, 8024, 1744, 1952, 7339, 1762, 1952, 860, 704, 2552, 1912, 1767, 6378, 5298, 4638, 3198, 952, 8024, 1921, 2218, 3221, 7346, 3756, 3756, 4638, 8024, 3698, 6496, 7564, 2845, 3227, 4850, 2496, 1921, 678, 1286, 3755, 7345, 2218, 3300, 1920, 7433, 8024, 852, 2401, 1962, 7339, 824, 677, 1286, 4638, 6378, 5298, 2400, 3766, 3300, 1358, 1168, 818, 862, 2397, 2817, 511, 678, 1286, 127, 4157, 8024, 2496, 4413, 7339, 2850, 6809, 6378, 5298, 1767, 3198, 8024, 1920, 7433, 2347, 5307, 678, 749, 1126, 702, 2207, 3198, 8024, 5445, 684, 692, 3690, 3766, 3300, 977, 678, 3341, 4638, 2692, 2590, 511, 2849, 4708, 6407, 671, 6407, 4638, 2578, 2428, 8024, 4413, 7339, 2458, 1993, 749, 2496, 1921, 678, 1286, 4638, 891, 6121, 6378, 5298, 8024, 8132, 1146, 7164, 6814, 1343, 749, 8024, 1921, 3698, 3766, 3300, 818, 862, 6760, 1962, 4638, 6839, 6496, 8024, 711, 749, 924, 2844, 4413, 102]
67 | 2021-07-16 17:01:09,513 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
68 | 2021-07-16 17:01:09,513 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
69 | 2021-07-16 17:01:09,513 - INFO - preprocess.py - convert_bert_example - 89 - labels: [6]
70 | 2021-07-16 17:01:09,528 - INFO - preprocess.py - convert_bert_example - 84 - *** train_example-1 ***
71 | 2021-07-16 17:01:09,528 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 商 瑞 华 首 战 复 仇 心 切 中 国 玫 瑰 要 用 美 国 方 式 攻 克 瑞 典 多 曼 来 了 , 瑞 典 来 了 , 商 瑞 华 首 战 求 3 分 的 信 心 也 来 了 。 距 离 首 战 72 小 时 当 口 , 中 国 女 足 彻 底 从 [UNK] 恐 瑞 症 [UNK] 当 中 获 得 解 脱 , 因 为 商 瑞 华 已 经 找 到 了 瑞 典 人 的 软 肋 。 找 到 软 肋 , 保 密 4 月 20 日 奥 运 会 分 组 抽 签 结 果 出 来 后 , 中 国 姑 娘 就 把 瑞 典 锁 定 为 关 乎 奥 运 成 败 的 头 号 劲 敌 , 因 为 除 了 浦 玮 等 个 别 老 将 之 外 , 现 役 女 足 将 士 竟 然 没 有 人 尝 过 击 败 瑞 典 的 滋 味 。 在 中 瑞 两 队 共 计 15 次 交 锋 的 历 史 上 , 中 国 队 6 胜 3 平 6 负 与 瑞 典 队 平 分 秋 色 , 但 从 2001 年 起 至 今 近 8 年 时 间 , 中 国 在 同 瑞 典 连 续 5 次 交 锋 中 均 未 尝 胜 绩 , 战 绩 为 2 平 3 负 。 尽 管 八 年 [SEP]
72 | 2021-07-16 17:01:09,528 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 1555, 4448, 1290, 7674, 2773, 1908, 790, 2552, 1147, 704, 1744, 4382, 4456, 6206, 4500, 5401, 1744, 3175, 2466, 3122, 1046, 4448, 1073, 1914, 3294, 3341, 749, 8024, 4448, 1073, 3341, 749, 8024, 1555, 4448, 1290, 7674, 2773, 3724, 124, 1146, 4638, 928, 2552, 738, 3341, 749, 511, 6655, 4895, 7674, 2773, 8325, 2207, 3198, 2496, 1366, 8024, 704, 1744, 1957, 6639, 2515, 2419, 794, 100, 2607, 4448, 4568, 100, 2496, 704, 5815, 2533, 6237, 5564, 8024, 1728, 711, 1555, 4448, 1290, 2347, 5307, 2823, 1168, 749, 4448, 1073, 782, 4638, 6763, 5490, 511, 2823, 1168, 6763, 5490, 8024, 924, 2166, 125, 3299, 8113, 3189, 1952, 6817, 833, 1146, 5299, 2853, 5041, 5310, 3362, 1139, 3341, 1400, 8024, 704, 1744, 1996, 2023, 2218, 2828, 4448, 1073, 7219, 2137, 711, 1068, 725, 1952, 6817, 2768, 6571, 4638, 1928, 1384, 1226, 3127, 8024, 1728, 711, 7370, 749, 3855, 4383, 5023, 702, 1166, 5439, 2199, 722, 1912, 8024, 4385, 2514, 1957, 6639, 2199, 1894, 4994, 4197, 3766, 3300, 782, 2214, 6814, 1140, 6571, 4448, 1073, 4638, 3996, 1456, 511, 1762, 704, 4448, 697, 7339, 1066, 6369, 8115, 3613, 769, 7226, 4638, 1325, 1380, 677, 8024, 704, 1744, 7339, 127, 5526, 124, 2398, 127, 6566, 680, 4448, 1073, 7339, 2398, 1146, 4904, 5682, 8024, 852, 794, 8285, 2399, 6629, 5635, 791, 6818, 129, 2399, 3198, 7313, 8024, 704, 1744, 1762, 1398, 4448, 1073, 6825, 5330, 126, 3613, 769, 7226, 704, 1772, 3313, 2214, 5526, 5327, 8024, 2773, 5327, 711, 123, 2398, 124, 6566, 511, 2226, 5052, 1061, 2399, 102]
73 | 2021-07-16 17:01:09,528 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
74 | 2021-07-16 17:01:09,528 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
75 | 2021-07-16 17:01:09,528 - INFO - preprocess.py - convert_bert_example - 89 - labels: [6]
76 | 2021-07-16 17:01:09,547 - INFO - preprocess.py - convert_bert_example - 84 - *** train_example-2 ***
77 | 2021-07-16 17:01:09,547 - INFO - preprocess.py - convert_bert_example - 85 - text: [CLS] 冠 军 球 队 迎 新 欢 乐 派 对 黄 旭 获 大 奖 张 军 赢 下 pk 赛 新 浪 体 育 讯 12 月 27 日 晚 , [UNK] 冠 军 高 尔 夫 球 队 迎 新 高 球 欢 乐 派 对 [UNK] 活 动 在 北 京 都 市 名 人 高 尔 夫 俱 乐 部 举 行 。 邢 傲 伟 、 黄 旭 、 邹 凯 、 滕 海 滨 、 杨 凌 、 马 燕 红 、 张 军 、 王 丽 萍 、 杨 影 、 阎 森 、 毕 文 静 、 夏 煊 泽 、 张 璇 、 叶 乔 波 、 莫 慧 兰 、 周 雅 菲 、 胡 妮 、 戴 菲 菲 、 马 健 等 奥 运 冠 军 、 世 界 冠 军 参 加 了 当 晚 的 活 动 。 [UNK] 以 球 会 友 、 联 谊 交 流 [UNK] 为 宗 旨 的 冠 军 高 尔 夫 球 队 , 从 创 立 之 初 就 秉 承 关 爱 社 会 、 回 馈 社 会 的 价 值 观 。 此 次 迎 新 派 对 活 动 的 举 行 , 标 志 着 这 个 为 中 国 金 牌 运 动 员 提 供 的 放 松 身 心 、 以 球 会 友 的 平 台 正 式 全 面 启 动 。 球 队 已 经 与 都 市 名 人 高 [SEP]
78 | 2021-07-16 17:01:09,547 - INFO - preprocess.py - convert_bert_example - 86 - token_ids: [101, 1094, 1092, 4413, 7339, 6816, 3173, 3614, 727, 3836, 2190, 7942, 3195, 5815, 1920, 1946, 2476, 1092, 6617, 678, 8465, 6612, 3173, 3857, 860, 5509, 6380, 8110, 3299, 8149, 3189, 3241, 8024, 100, 1094, 1092, 7770, 2209, 1923, 4413, 7339, 6816, 3173, 7770, 4413, 3614, 727, 3836, 2190, 100, 3833, 1220, 1762, 1266, 776, 6963, 2356, 1399, 782, 7770, 2209, 1923, 936, 727, 6956, 715, 6121, 511, 6928, 1000, 836, 510, 7942, 3195, 510, 6941, 1132, 510, 4001, 3862, 4012, 510, 3342, 1119, 510, 7716, 4242, 5273, 510, 2476, 1092, 510, 4374, 714, 5847, 510, 3342, 2512, 510, 7330, 3481, 510, 3684, 3152, 7474, 510, 1909, 4201, 3813, 510, 2476, 4462, 510, 1383, 730, 3797, 510, 5811, 2716, 1065, 510, 1453, 7414, 5838, 510, 5529, 1984, 510, 2785, 5838, 5838, 510, 7716, 978, 5023, 1952, 6817, 1094, 1092, 510, 686, 4518, 1094, 1092, 1346, 1217, 749, 2496, 3241, 4638, 3833, 1220, 511, 100, 809, 4413, 833, 1351, 510, 5468, 6449, 769, 3837, 100, 711, 2134, 3192, 4638, 1094, 1092, 7770, 2209, 1923, 4413, 7339, 8024, 794, 1158, 4989, 722, 1159, 2218, 4903, 2824, 1068, 4263, 4852, 833, 510, 1726, 7668, 4852, 833, 4638, 817, 966, 6225, 511, 3634, 3613, 6816, 3173, 3836, 2190, 3833, 1220, 4638, 715, 6121, 8024, 3403, 2562, 4708, 6821, 702, 711, 704, 1744, 7032, 4277, 6817, 1220, 1447, 2990, 897, 4638, 3123, 3351, 6716, 2552, 510, 809, 4413, 833, 1351, 4638, 2398, 1378, 3633, 2466, 1059, 7481, 1423, 1220, 511, 4413, 7339, 2347, 5307, 680, 6963, 2356, 1399, 782, 7770, 102]
79 | 2021-07-16 17:01:09,547 - INFO - preprocess.py - convert_bert_example - 87 - attention_masks: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
80 | 2021-07-16 17:01:09,548 - INFO - preprocess.py - convert_bert_example - 88 - token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
81 | 2021-07-16 17:01:09,548 - INFO - preprocess.py - convert_bert_example - 89 - labels: [6]
82 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append(r"./")
4 | """
5 | 该文件使用的data_loader.py里面的数据加载方式。
6 | """
7 | # coding=utf-8
8 | import json
9 | import random
10 | from pprint import pprint
11 | import os
12 | import logging
13 | import shutil
14 | from sklearn.metrics import accuracy_score, f1_score, classification_report
15 | import torch
16 | import torch.nn as nn
17 | import numpy as np
18 | import pickle
19 | from torch.utils.data import DataLoader, RandomSampler
20 | from transformers import BertTokenizer
21 |
22 | import bert_config
23 | import models
24 | from utils import utils
25 | from data_loader import CNEWSDataset, Collate, CPWSDataset
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 |
30 | class Trainer:
31 | def __init__(self, args, train_loader, dev_loader, test_loader, device, model, optimizer):
32 | self.args = args
33 | self.device = device
34 | self.model = model
35 | self.optimizer = optimizer
36 | self.criterion = nn.CrossEntropyLoss()
37 | self.train_loader = train_loader
38 | self.dev_loader = dev_loader
39 | self.test_loader = test_loader
40 | self.model.to(self.device)
41 |
42 | def load_ckp(self, model, optimizer, checkpoint_path):
43 | checkpoint = torch.load(checkpoint_path)
44 | model.load_state_dict(checkpoint['state_dict'])
45 | optimizer.load_state_dict(checkpoint['optimizer'])
46 | epoch = checkpoint['epoch']
47 | loss = checkpoint['loss']
48 | return model, optimizer, epoch, loss
49 |
50 | def load_model(self, model, checkpoint_path):
51 | checkpoint = torch.load(checkpoint_path)
52 | model.load_state_dict(checkpoint['state_dict'])
53 | return model
54 |
55 | def save_ckp(self, state, checkpoint_path):
56 | torch.save(state, checkpoint_path)
57 |
58 | """
59 | def save_ckp(self, state, is_best, checkpoint_path, best_model_path):
60 | tmp_checkpoint_path = checkpoint_path
61 | torch.save(state, tmp_checkpoint_path)
62 | if is_best:
63 | tmp_best_model_path = best_model_path
64 | shutil.copyfile(tmp_checkpoint_path, tmp_best_model_path)
65 | """
66 |
67 | def train(self):
68 | total_step = len(self.train_loader) * self.args.train_epochs
69 | global_step = 0
70 | eval_step = 100
71 | best_dev_micro_f1 = 0.0
72 | for epoch in range(args.train_epochs):
73 | for train_step, train_data in enumerate(self.train_loader):
74 | self.model.train()
75 | token_ids = train_data['token_ids'].to(self.device)
76 | attention_masks = train_data['attention_masks'].to(self.device)
77 | token_type_ids = train_data['token_type_ids'].to(self.device)
78 | labels = train_data['labels'].to(self.device)
79 | train_outputs = self.model(token_ids, attention_masks, token_type_ids)
80 | loss = self.criterion(train_outputs, labels)
81 | self.optimizer.zero_grad()
82 | loss.backward()
83 | self.optimizer.step()
84 | logger.info(
85 | "【train】 epoch:{} step:{}/{} loss:{:.6f}".format(epoch, global_step, total_step, loss.item()))
86 | global_step += 1
87 | if global_step % eval_step == 0:
88 | dev_loss, dev_outputs, dev_targets = self.dev()
89 | accuracy, micro_f1, macro_f1 = self.get_metrics(dev_outputs, dev_targets)
90 | logger.info(
91 | "【dev】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(dev_loss, accuracy,
92 | micro_f1, macro_f1))
93 | if macro_f1 > best_dev_micro_f1:
94 | logger.info("------------>保存当前最好的模型")
95 | checkpoint = {
96 | 'epoch': epoch,
97 | 'loss': dev_loss,
98 | 'state_dict': self.model.state_dict(),
99 | 'optimizer': self.optimizer.state_dict(),
100 | }
101 | best_dev_micro_f1 = macro_f1
102 | save_path = os.path.join(self.args.output_dir, args.data_name)
103 | if not os.path.exists(save_path):
104 | os.makedirs(save_path)
105 | checkpoint_path = os.path.join(save_path, 'best.pt')
106 | self.save_ckp(checkpoint, checkpoint_path)
107 |
108 | def dev(self):
109 | self.model.eval()
110 | total_loss = 0.0
111 | dev_outputs = []
112 | dev_targets = []
113 | with torch.no_grad():
114 | for dev_step, dev_data in enumerate(self.dev_loader):
115 | token_ids = dev_data['token_ids'].to(self.device)
116 | attention_masks = dev_data['attention_masks'].to(self.device)
117 | token_type_ids = dev_data['token_type_ids'].to(self.device)
118 | labels = dev_data['labels'].to(self.device)
119 | outputs = self.model(token_ids, attention_masks, token_type_ids)
120 | loss = self.criterion(outputs, labels)
121 | # val_loss = val_loss + ((1 / (dev_step + 1))) * (loss.item() - val_loss)
122 | total_loss += loss.item()
123 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten()
124 | dev_outputs.extend(outputs.tolist())
125 | dev_targets.extend(labels.cpu().detach().numpy().tolist())
126 |
127 | return total_loss, dev_outputs, dev_targets
128 |
129 | def test(self, model):
130 | model.eval()
131 | model.to(self.device)
132 | total_loss = 0.0
133 | test_outputs = []
134 | test_targets = []
135 | with torch.no_grad():
136 | for test_step, test_data in enumerate(self.test_loader):
137 | token_ids = test_data['token_ids'].to(self.device)
138 | attention_masks = test_data['attention_masks'].to(self.device)
139 | token_type_ids = test_data['token_type_ids'].to(self.device)
140 | labels = test_data['labels'].to(self.device)
141 | outputs = model(token_ids, attention_masks, token_type_ids)
142 | loss = self.criterion(outputs, labels)
143 | # val_loss = val_loss + ((1 / (dev_step + 1))) * (loss.item() - val_loss)
144 | total_loss += loss.item()
145 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten()
146 | test_outputs.extend(outputs.tolist())
147 | test_targets.extend(labels.cpu().detach().numpy().tolist())
148 |
149 | return total_loss, test_outputs, test_targets
150 |
151 | def predict(self, tokenizer, text, id2label, args, model):
152 | model.eval()
153 | model.to(self.device)
154 | with torch.no_grad():
155 | inputs = tokenizer.encode_plus(text=text,
156 | add_special_tokens=True,
157 | max_length=args.max_seq_len,
158 | truncation='longest_first',
159 | padding="max_length",
160 | return_token_type_ids=True,
161 | return_attention_mask=True,
162 | return_tensors='pt')
163 | token_ids = inputs['input_ids'].to(self.device)
164 | attention_masks = inputs['attention_mask'].to(self.device)
165 | token_type_ids = inputs['token_type_ids'].to(self.device)
166 | outputs = model(token_ids, attention_masks, token_type_ids)
167 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten().tolist()
168 | if len(outputs) != 0:
169 | outputs = [id2label[i] for i in outputs]
170 | return outputs
171 | else:
172 | return '不好意思,我没有识别出来'
173 |
174 | def get_metrics(self, outputs, targets):
175 | accuracy = accuracy_score(targets, outputs)
176 | micro_f1 = f1_score(targets, outputs, average='micro')
177 | macro_f1 = f1_score(targets, outputs, average='macro')
178 | return accuracy, micro_f1, macro_f1
179 |
180 | def get_classification_report(self, outputs, targets, labels):
181 | report = classification_report(targets, outputs, target_names=labels)
182 | return report
183 |
184 |
185 | datasets = {
186 | "cnews": CNEWSDataset,
187 | "cpws": CPWSDataset
188 | }
189 |
190 | train_files = {
191 | "cnews": "cnews.train.txt",
192 | "cpws": "train_data.txt"
193 | }
194 |
195 | test_files = {
196 | "cnews": "cnews.test.txt",
197 | "cpws": "test_data.txt"
198 | }
199 |
200 |
201 | def main(args, tokenizer, device):
202 | dataset = datasets.get(args.data_name, None)
203 | train_file, test_file = train_files.get(args.data_name, None), test_files.get(args.data_name, None)
204 | if dataset is None:
205 | raise Exception("请输入正确的数据集名称")
206 | label2id = {}
207 | id2label = {}
208 | with open('./data/{}/labels.txt'.format(args.data_name), 'r', encoding="utf-8") as fp:
209 | labels = fp.read().strip().split('\n')
210 | for i, label in enumerate(labels):
211 | label2id[label] = i
212 | id2label[i] = label
213 | print(label2id)
214 |
215 | collate = Collate(tokenizer=tokenizer, max_len=args.max_seq_len, tag2id=label2id)
216 |
217 | train_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, train_file))
218 | train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True,
219 | collate_fn=collate.collate_fn)
220 | test_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, test_file))
221 | test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False,
222 | collate_fn=collate.collate_fn)
223 |
224 | model = models.BertForSequenceClassification(args)
225 | optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
226 |
227 | if args.retrain:
228 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name)
229 | checkpoint = torch.load(checkpoint_path)
230 | model.load_state_dict(checkpoint['state_dict'])
231 | model.to(device)
232 | optimizer.load_state_dict(checkpoint['optimizer'])
233 |
234 | epoch = checkpoint['epoch']
235 | loss = checkpoint['loss']
236 | logger.info("加载模型继续训练,epoch:{} loss:{}".format(epoch, loss))
237 |
238 | trainer = Trainer(args, train_loader, test_loader, test_loader, device, model, optimizer)
239 |
240 | if args.do_train:
241 | # 训练和验证
242 | trainer.train()
243 |
244 | # 测试
245 | if args.do_test:
246 | logger.info('========进行测试========')
247 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name)
248 | model = trainer.load_model(model, checkpoint_path)
249 | total_loss, test_outputs, test_targets = trainer.test(model)
250 | accuracy, micro_f1, macro_f1 = trainer.get_metrics(test_outputs, test_targets)
251 | logger.info(
252 | "【test】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(total_loss, accuracy, micro_f1,
253 | macro_f1))
254 | report = trainer.get_classification_report(test_outputs, test_targets, labels)
255 | logger.info(report)
256 |
257 | # 预测
258 | if args.do_predict:
259 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name)
260 | model = trainer.load_model(model, checkpoint_path)
261 | line = test_dataset[0]
262 | text = line[0]
263 | print(text)
264 | result = trainer.predict(tokenizer, text, id2label, args, model)
265 | print("预测标签:", result[0])
266 | print("真实标签:", line[1])
267 | print("==========================")
268 |
269 |
270 | if __name__ == '__main__':
271 | args = bert_config.Args().get_parser()
272 | utils.set_seed(args.seed)
273 | utils.set_logger(os.path.join(args.log_dir, 'main.log'))
274 |
275 | # processor = preprocess.Processor()
276 |
277 | tokenizer = BertTokenizer.from_pretrained(args.bert_dir)
278 | gpu_ids = args.gpu_ids.split(',')
279 | device = torch.device("cpu" if gpu_ids[0] == '-1' else "cuda:" + gpu_ids[0])
280 |
281 | main(args, tokenizer, device)
282 |
--------------------------------------------------------------------------------
/main_dataparallel.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append(r"./")
4 | """
5 | 该文件使用的data_loader.py里面的数据加载方式。
6 | """
7 | # coding=utf-8
8 | import json
9 | import random
10 | from pprint import pprint
11 | import os
12 | import logging
13 | import shutil
14 | from sklearn.metrics import accuracy_score, f1_score, classification_report
15 | import torch
16 | import torch.nn as nn
17 | import numpy as np
18 | import pickle
19 | from torch.utils.data import DataLoader, RandomSampler
20 | from transformers import BertTokenizer
21 |
22 | import bert_config
23 | import models
24 | from utils import utils
25 | from data_loader import CNEWSDataset, Collate, CPWSDataset
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | args = bert_config.Args().get_parser()
30 | utils.set_seed(args.seed)
31 | utils.set_logger(os.path.join(args.log_dir, 'main.log'))
32 |
33 | gpu_ids = args.gpu_ids.split(',')
34 | gpu_ids = [int(i) for i in gpu_ids]
35 | torch.cuda.set_device('cuda:{}'.format(gpu_ids[0]))
36 |
37 |
38 | class Trainer:
39 | def __init__(self, args, optimizer):
40 | self.args = args
41 | self.optimizer = optimizer
42 | self.criterion = nn.CrossEntropyLoss()
43 |
44 | def load_ckp(self, model, optimizer, checkpoint_path):
45 | checkpoint = torch.load(checkpoint_path)
46 | model.load_state_dict(checkpoint['state_dict'])
47 | optimizer.load_state_dict(checkpoint['optimizer'])
48 | epoch = checkpoint['epoch']
49 | loss = checkpoint['loss']
50 | return model, optimizer, epoch, loss
51 |
52 | def load_model(self, model, checkpoint_path):
53 | checkpoint = torch.load(checkpoint_path)
54 | # new_start_dict = {}
55 | # for k, v in checkpoint['state_dict'].items():
56 | # new_start_dict["module." + k] = v
57 | # model.load_state_dict(new_start_dict)
58 | model.load_state_dict(checkpoint["state_dict"])
59 | return model
60 |
61 | def save_ckp(self, state, checkpoint_path):
62 | torch.save(state, checkpoint_path)
63 |
64 | """
65 | def save_ckp(self, state, is_best, checkpoint_path, best_model_path):
66 | tmp_checkpoint_path = checkpoint_path
67 | torch.save(state, tmp_checkpoint_path)
68 | if is_best:
69 | tmp_best_model_path = best_model_path
70 | shutil.copyfile(tmp_checkpoint_path, tmp_best_model_path)
71 | """
72 |
73 | def train(self, model, train_loader, dev_loader=None):
74 | self.dev_loader = dev_loader
75 | self.model = model
76 | total_step = len(train_loader) * self.args.train_epochs
77 | global_step = 0
78 | eval_step = 100
79 | best_dev_micro_f1 = 0.0
80 | for epoch in range(self.args.train_epochs):
81 | for train_step, train_data in enumerate(train_loader):
82 | self.model.train()
83 | token_ids = train_data['token_ids'].cuda()
84 | attention_masks = train_data['attention_masks'].cuda()
85 | token_type_ids = train_data['token_type_ids'].cuda()
86 | labels = train_data['labels'].cuda()
87 | train_outputs = self.model(token_ids, attention_masks, token_type_ids)
88 | loss = self.criterion(train_outputs, labels)
89 | self.optimizer.zero_grad()
90 | loss.backward()
91 | self.optimizer.step()
92 | logger.info(
93 | "【train】 epoch:{} step:{}/{} loss:{:.6f}".format(epoch, global_step, total_step, loss.item()))
94 | global_step += 1
95 | if global_step % eval_step == 0:
96 | dev_loss, dev_outputs, dev_targets = self.dev()
97 | accuracy, micro_f1, macro_f1 = self.get_metrics(dev_outputs, dev_targets)
98 | logger.info(
99 | "【dev】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(dev_loss, accuracy,
100 | micro_f1, macro_f1))
101 | if macro_f1 > best_dev_micro_f1:
102 | logger.info("------------>保存当前最好的模型")
103 | checkpoint = {
104 | 'epoch': epoch,
105 | 'loss': dev_loss,
106 | 'state_dict': self.model.state_dict(),
107 | 'optimizer': self.optimizer.state_dict(),
108 | }
109 | best_dev_micro_f1 = macro_f1
110 | save_path = os.path.join(self.args.output_dir, args.data_name)
111 | if not os.path.exists(save_path):
112 | os.makedirs(save_path)
113 | checkpoint_path = os.path.join(save_path, 'best.pt')
114 | self.save_ckp(checkpoint, checkpoint_path)
115 |
116 | def dev(self):
117 | self.model.eval()
118 | total_loss = 0.0
119 | dev_outputs = []
120 | dev_targets = []
121 | with torch.no_grad():
122 | for dev_step, dev_data in enumerate(self.dev_loader):
123 | token_ids = dev_data['token_ids'].cuda()
124 | attention_masks = dev_data['attention_masks'].cuda()
125 | token_type_ids = dev_data['token_type_ids'].cuda()
126 | labels = dev_data['labels'].cuda()
127 | outputs = self.model(token_ids, attention_masks, token_type_ids)
128 | loss = self.criterion(outputs, labels)
129 | # val_loss = val_loss + ((1 / (dev_step + 1))) * (loss.item() - val_loss)
130 | total_loss += loss.item()
131 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten()
132 | dev_outputs.extend(outputs.tolist())
133 | dev_targets.extend(labels.cpu().detach().numpy().tolist())
134 |
135 | return total_loss, dev_outputs, dev_targets
136 |
137 | def test(self, model, test_loader):
138 | model.eval()
139 | total_loss = 0.0
140 | test_outputs = []
141 | test_targets = []
142 | with torch.no_grad():
143 | for test_step, test_data in enumerate(test_loader):
144 | token_ids = test_data['token_ids'].cuda()
145 | attention_masks = test_data['attention_masks'].cuda()
146 | token_type_ids = test_data['token_type_ids'].cuda()
147 | labels = test_data['labels'].cuda()
148 | outputs = model(token_ids, attention_masks, token_type_ids)
149 |
150 | loss = self.criterion(outputs, labels)
151 | # val_loss = val_loss + ((1 / (dev_step + 1))) * (loss.item() - val_loss)
152 | total_loss += loss.item()
153 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten()
154 | test_outputs.extend(outputs.tolist())
155 | test_targets.extend(labels.cpu().detach().numpy().tolist())
156 |
157 | return total_loss, test_outputs, test_targets
158 |
159 | def predict(self, tokenizer, text, id2label, args, model):
160 | model.eval()
161 | with torch.no_grad():
162 | inputs = tokenizer.encode_plus(text=text,
163 | add_special_tokens=True,
164 | max_length=args.max_seq_len,
165 | truncation='longest_first',
166 | padding="max_length",
167 | return_token_type_ids=True,
168 | return_attention_mask=True,
169 | return_tensors='pt')
170 | token_ids = inputs['input_ids'].cuda()
171 | attention_masks = inputs['attention_mask'].cuda()
172 | token_type_ids = inputs['token_type_ids'].cuda()
173 | outputs = model(token_ids, attention_masks, token_type_ids)
174 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten().tolist()
175 | if len(outputs) != 0:
176 | outputs = [id2label[i] for i in outputs]
177 | return outputs
178 | else:
179 | return '不好意思,我没有识别出来'
180 |
181 | def get_metrics(self, outputs, targets):
182 | accuracy = accuracy_score(targets, outputs)
183 | micro_f1 = f1_score(targets, outputs, average='micro')
184 | macro_f1 = f1_score(targets, outputs, average='macro')
185 | return accuracy, micro_f1, macro_f1
186 |
187 | def get_classification_report(self, outputs, targets, labels):
188 | report = classification_report(targets, outputs, target_names=labels)
189 | return report
190 |
191 |
192 | datasets = {
193 | "cnews": CNEWSDataset,
194 | "cpws": CPWSDataset
195 | }
196 |
197 | train_files = {
198 | "cnews": "cnews.train.txt",
199 | "cpws": "train_data.txt"
200 | }
201 |
202 | test_files = {
203 | "cnews": "cnews.test.txt",
204 | "cpws": "test_data.txt"
205 | }
206 |
207 |
208 | def main(args, tokenizer):
209 | dataset = datasets.get(args.data_name, None)
210 | train_file, test_file = train_files.get(args.data_name, None), test_files.get(args.data_name, None)
211 | if dataset is None:
212 | raise Exception("请输入正确的数据集名称")
213 | label2id = {}
214 | id2label = {}
215 | with open('./data/{}/labels.txt'.format(args.data_name), 'r', encoding="utf-8") as fp:
216 | labels = fp.read().strip().split('\n')
217 | for i, label in enumerate(labels):
218 | label2id[label] = i
219 | id2label[i] = label
220 | print(label2id)
221 |
222 | collate = Collate(tokenizer=tokenizer, max_len=args.max_seq_len, tag2id=label2id)
223 |
224 | train_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, train_file))
225 | train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True,
226 | collate_fn=collate.collate_fn, num_workers=4)
227 | test_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, test_file))
228 | test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False,
229 | collate_fn=collate.collate_fn, num_workers=4)
230 |
231 | model = models.BertForSequenceClassification(args)
232 | optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
233 |
234 | trainer = Trainer(args, optimizer)
235 |
236 | if args.do_train:
237 | if args.retrain:
238 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name)
239 | checkpoint = torch.load(checkpoint_path)
240 | model.cuda()
241 | r_model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0])
242 | r_model.load_state_dict(checkpoint['state_dict'])
243 | optimizer.load_state_dict(checkpoint['optimizer'])
244 |
245 | epoch = checkpoint['epoch']
246 | loss = checkpoint['loss']
247 | logger.info("加载模型继续训练,epoch:{} loss:{}".format(epoch, loss))
248 | # 训练和验证
249 | trainer.train(r_model, train_loader, dev_loader=test_loader)
250 | else:
251 | model.cuda()
252 | r_model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0])
253 | # 训练和验证
254 | trainer.train(r_model, train_loader, dev_loader=test_loader)
255 |
256 | # 测试
257 | if args.do_test:
258 | logger.info('========进行测试========')
259 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name)
260 | # 多卡预测要先模型并行
261 | model.cuda()
262 | model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0])
263 | # 再加载模型
264 | model = trainer.load_model(model, checkpoint_path)
265 |
266 | total_loss, test_outputs, test_targets = trainer.test(model, test_loader)
267 | accuracy, micro_f1, macro_f1 = trainer.get_metrics(test_outputs, test_targets)
268 | logger.info(
269 | "【test】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(total_loss, accuracy, micro_f1,
270 | macro_f1))
271 | report = trainer.get_classification_report(test_outputs, test_targets, labels)
272 | logger.info(report)
273 |
274 | # 预测
275 | if args.do_predict:
276 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name)
277 | model = trainer.load_model(model, checkpoint_path)
278 | model.cuda()
279 | model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0])
280 | line = test_dataset[0]
281 | text = line[0]
282 | print(text)
283 | result = trainer.predict(tokenizer, text, id2label, args, model)
284 | print("预测标签:", result[0])
285 | print("真实标签:", line[1])
286 | print("==========================")
287 |
288 |
289 | if __name__ == '__main__':
290 | # processor = preprocess.Processor()
291 | tokenizer = BertTokenizer.from_pretrained(args.bert_dir)
292 | main(args, tokenizer)
293 |
--------------------------------------------------------------------------------
/main_distributed.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append(r"./")
4 | """
5 | 该文件使用的data_loader.py里面的数据加载方式。
6 | """
7 | # coding=utf-8
8 | import os
9 |
10 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,3'
11 | import tempfile
12 | import json
13 | import random
14 | from pprint import pprint
15 | import os
16 | import logging
17 | import shutil
18 | from sklearn.metrics import accuracy_score, f1_score, classification_report
19 | import torch
20 | import torch.nn as nn
21 | import numpy as np
22 | import pickle
23 | from torch.utils.data import DataLoader, RandomSampler
24 | from transformers import BertTokenizer
25 | import torch.distributed as dist
26 |
27 | import bert_config
28 | import models
29 | from utils import utils
30 | from data_loader import CNEWSDataset, Collate, CPWSDataset
31 |
32 | logger = logging.getLogger(__name__)
33 |
34 |
35 | # 单机多卡
36 | # word_size:机器一共有几张卡
37 | # rank:第几块GPU
38 | # local_rank:第几块GPU,和rank相同
39 | # print(torch.cuda.device_count())
40 |
41 | # local_rank = torch.distributed.get_rank()
42 | # args.local_rank = local_rank
43 | # print(args.local_rank)
44 | # dist.init_process_group(backend='gloo',
45 | # init_method=r"file:///D://Code//project//pytorch-distributed training//tmp",
46 | # rank=0,
47 | # world_size=1)
48 | # torch.cuda.set_device(args.local_rank)
49 |
50 |
51 | class Trainer:
52 | def __init__(self, args, optimizer):
53 | self.args = args
54 | self.optimizer = optimizer
55 | self.criterion = nn.CrossEntropyLoss().cuda(self.args.local_rank)
56 |
57 | def load_ckp(self, model, optimizer, checkpoint_path):
58 | checkpoint = torch.load(checkpoint_path)
59 | model.load_state_dict(checkpoint['state_dict'])
60 | optimizer.load_state_dict(checkpoint['optimizer'])
61 | epoch = checkpoint['epoch']
62 | loss = checkpoint['loss']
63 | return model, optimizer, epoch, loss
64 |
65 | def load_model(self, model, checkpoint_path):
66 | checkpoint = torch.load(checkpoint_path)
67 | # new_start_dict = {}
68 | # for k, v in checkpoint['state_dict'].items():
69 | # new_start_dict["module." + k] = v
70 | # model.load_state_dict(new_start_dict)
71 | model.load_state_dict(checkpoint["state_dict"])
72 | return model
73 |
74 | def save_ckp(self, state, checkpoint_path):
75 | torch.save(state, checkpoint_path)
76 |
77 | """
78 | def save_ckp(self, state, is_best, checkpoint_path, best_model_path):
79 | tmp_checkpoint_path = checkpoint_path
80 | torch.save(state, tmp_checkpoint_path)
81 | if is_best:
82 | tmp_best_model_path = best_model_path
83 | shutil.copyfile(tmp_checkpoint_path, tmp_best_model_path)
84 | """
85 |
86 | def train(self, model, train_loader, train_sampler, dev_loader=None):
87 | self.dev_loader = dev_loader
88 | self.model = model
89 | total_step = len(train_loader) * self.args.train_epochs
90 | global_step = 0
91 | eval_step = 10
92 | best_dev_micro_f1 = 0.0
93 | for epoch in range(self.args.train_epochs):
94 | train_sampler.set_epoch(epoch)
95 | for train_step, train_data in enumerate(train_loader):
96 | self.model.train()
97 | token_ids = train_data['token_ids'].cuda(self.args.local_rank)
98 | attention_masks = train_data['attention_masks'].cuda(self.args.local_rank)
99 | token_type_ids = train_data['token_type_ids'].cuda(self.args.local_rank)
100 | labels = train_data['labels'].cuda(self.args.local_rank)
101 | train_outputs = self.model(token_ids, attention_masks, token_type_ids)
102 |
103 | loss = self.criterion(train_outputs, labels)
104 |
105 | torch.distributed.barrier()
106 |
107 | self.optimizer.zero_grad()
108 | loss.backward()
109 | self.optimizer.step()
110 |
111 | loss = self.loss_reduce(loss)
112 | if args.local_rank == 0:
113 | logger.info(
114 | "【train】 epoch:{} step:{}/{} loss:{:.6f}".format(epoch, global_step, total_step, loss))
115 | global_step += 1
116 | if dev_loader is not None and self.args.local_rank == 0:
117 | if global_step % eval_step == 0:
118 | dev_loss, dev_outputs, dev_targets = self.dev()
119 | accuracy, micro_f1, macro_f1 = self.get_metrics(dev_outputs, dev_targets)
120 | logger.info(
121 | "【dev】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(dev_loss,
122 | accuracy,
123 | micro_f1,
124 | macro_f1))
125 | if macro_f1 > best_dev_micro_f1:
126 | logger.info("------------>保存当前最好的模型")
127 | checkpoint = {
128 | 'epoch': epoch,
129 | 'loss': dev_loss,
130 | 'state_dict': self.model.state_dict(),
131 | 'optimizer': self.optimizer.state_dict(),
132 | }
133 | best_dev_micro_f1 = macro_f1
134 | save_path = os.path.join(self.args.output_dir, args.data_name)
135 | if not os.path.exists(save_path):
136 | os.makedirs(save_path)
137 | checkpoint_path = os.path.join(save_path, 'best.pt')
138 | self.save_ckp(checkpoint, checkpoint_path)
139 | if dev_loader is None and self.args.local_rank == 0:
140 | checkpoint = {
141 | 'epoch': epoch,
142 | 'state_dict': self.model.state_dict(),
143 | 'optimizer': self.optimizer.state_dict(),
144 | }
145 | save_path = os.path.join(self.args.output_dir, args.data_name)
146 | if not os.path.exists(save_path):
147 | os.makedirs(save_path)
148 | checkpoint_path = os.path.join(save_path, 'best.pt')
149 | self.save_ckp(checkpoint, checkpoint_path)
150 |
151 | def output_reduce(self, outputs, targets):
152 | output_gather_list = [torch.zeros_like(outputs) for _ in range(self.args.local_world_size)]
153 | # 把每一个GPU的输出聚合起来
154 | dist.all_gather(output_gather_list, outputs)
155 |
156 | outputs = torch.cat(output_gather_list, dim=0)
157 | target_gather_list = [torch.zeros_like(targets) for _ in range(self.args.local_world_size)]
158 | # 把每一个GPU的输出聚合起来
159 | dist.all_gather(target_gather_list, targets)
160 | targets = torch.cat(target_gather_list, dim=0)
161 | return outputs, targets
162 |
163 | def loss_reduce(self, loss):
164 | rt = loss.clone()
165 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
166 | rt /= self.args.local_world_size
167 | return rt
168 |
169 | def dev(self):
170 | self.model.eval()
171 | total_loss = 0.0
172 | dev_outputs = []
173 | dev_targets = []
174 | with torch.no_grad():
175 | for dev_step, dev_data in enumerate(self.dev_loader):
176 | token_ids = dev_data['token_ids'].cuda(self.args.local_rank)
177 | attention_masks = dev_data['attention_masks'].cuda(self.args.local_rank)
178 | token_type_ids = dev_data['token_type_ids'].cuda(self.args.local_rank)
179 | labels = dev_data['labels'].cuda(self.args.local_rank)
180 | outputs = self.model(token_ids, attention_masks, token_type_ids)
181 |
182 | loss = self.criterion(outputs, labels)
183 | torch.distributed.barrier()
184 | # total_loss += loss.item()
185 | loss = self.loss_reduce(loss)
186 | total_loss += loss
187 | outputs, targets = self.output_reduce(outputs, labels)
188 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten()
189 | dev_outputs.extend(outputs.tolist())
190 | dev_targets.extend(targets.cpu().detach().numpy().tolist())
191 | print(len(dev_outputs), len(dev_targets))
192 | return total_loss, dev_outputs, dev_targets
193 |
194 | def test(self, model, test_loader):
195 | model.eval()
196 | total_loss = 0.0
197 | test_outputs = []
198 | test_targets = []
199 | with torch.no_grad():
200 | for test_step, test_data in enumerate(test_loader):
201 | token_ids = test_data['token_ids'].cuda(self.args.local_rank)
202 | attention_masks = test_data['attention_masks'].cuda(self.args.local_rank)
203 | token_type_ids = test_data['token_type_ids'].cuda(self.args.local_rank)
204 | labels = test_data['labels'].cuda(self.args.local_rank)
205 | outputs = model(token_ids, attention_masks, token_type_ids)
206 |
207 | loss = self.criterion(outputs, labels)
208 | torch.distributed.barrier()
209 | loss = self.loss_reduce(loss)
210 | total_loss += loss
211 | outputs, targets = self.output_reduce(outputs, labels)
212 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten()
213 | test_outputs.extend(outputs.tolist())
214 | test_targets.extend(targets.cpu().detach().numpy().tolist())
215 | # total_loss += loss.item()
216 | # outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten()
217 | # test_outputs.extend(outputs.tolist())
218 | # test_targets.extend(labels.cpu().detach().numpy().tolist())
219 |
220 | return total_loss, test_outputs, test_targets
221 |
222 | def predict(self, tokenizer, text, id2label, args, model):
223 | model.eval()
224 | with torch.no_grad():
225 | inputs = tokenizer.encode_plus(text=text,
226 | add_special_tokens=True,
227 | max_length=args.max_seq_len,
228 | truncation='longest_first',
229 | padding="max_length",
230 | return_token_type_ids=True,
231 | return_attention_mask=True,
232 | return_tensors='pt')
233 | token_ids = inputs['input_ids'].cuda(self.args.local_rank)
234 | attention_masks = inputs['attention_mask'].cuda(self.args.local_rank)
235 | token_type_ids = inputs['token_type_ids'].cuda(self.args.local_rank)
236 | outputs = model(token_ids, attention_masks, token_type_ids)
237 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten().tolist()
238 | if len(outputs) != 0:
239 | outputs = [id2label[i] for i in outputs]
240 | return outputs
241 | else:
242 | return '不好意思,我没有识别出来'
243 |
244 | def get_metrics(self, outputs, targets):
245 | accuracy = accuracy_score(targets, outputs)
246 | micro_f1 = f1_score(targets, outputs, average='micro')
247 | macro_f1 = f1_score(targets, outputs, average='macro')
248 | return accuracy, micro_f1, macro_f1
249 |
250 | def get_classification_report(self, outputs, targets, labels):
251 | report = classification_report(targets, outputs, target_names=labels)
252 | return report
253 |
254 |
255 | datasets = {
256 | "cnews": CNEWSDataset,
257 | "cpws": CPWSDataset
258 | }
259 |
260 | train_files = {
261 | "cnews": "cnews.train.txt",
262 | "cpws": "train_data.txt"
263 | }
264 |
265 | test_files = {
266 | "cnews": "cnews.test.txt",
267 | "cpws": "test_data.txt"
268 | }
269 |
270 |
271 | def main(args, tokenizer, local_rank, local_world_size):
272 | n = torch.cuda.device_count() // local_world_size
273 | device_ids = list(range(local_rank * n, (local_rank + 1) * n))
274 |
275 | print(
276 | f"[{os.getpid()}] rank = {dist.get_rank()}, "
277 | + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids} \n", end=''
278 | )
279 | dataset = datasets.get(args.data_name, None)
280 | train_file, test_file = train_files.get(args.data_name, None), test_files.get(args.data_name, None)
281 | if dataset is None:
282 | raise Exception("请输入正确的数据集名称")
283 | label2id = {}
284 | id2label = {}
285 | with open('./data/{}/labels.txt'.format(args.data_name), 'r', encoding="utf-8") as fp:
286 | labels = fp.read().strip().split('\n')
287 | for i, label in enumerate(labels):
288 | label2id[label] = i
289 | id2label[i] = label
290 | if args.local_rank == 0:
291 | print(label2id)
292 |
293 | args.train_batch_size = int(args.train_batch_size / torch.cuda.device_count())
294 | args.eval_batch_size = int(args.eval_batch_size / torch.cuda.device_count())
295 |
296 | collate = Collate(tokenizer=tokenizer, max_len=args.max_seq_len, tag2id=label2id)
297 |
298 | train_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, train_file))
299 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
300 | train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size,
301 | collate_fn=collate.collate_fn, num_workers=4, sampler=train_sampler)
302 | test_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, test_file))
303 |
304 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
305 |
306 | test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size,
307 | collate_fn=collate.collate_fn, num_workers=4, sampler=test_sampler)
308 | # test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False,
309 | # collate_fn=collate.collate_fn, num_workers=4)
310 |
311 | model = models.BertForSequenceClassification(args)
312 | optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
313 | trainer = Trainer(args, optimizer)
314 |
315 | if args.do_train:
316 | if args.retrain:
317 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name)
318 | checkpoint = torch.load(checkpoint_path)
319 | # trainer.optimizer.load_state_dict(checkpoint['optimizer'])
320 | model.cuda(args.local_rank)
321 | r_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
322 | r_model.load_state_dict(checkpoint['state_dict'])
323 | if args.local_rank == 0:
324 | logger.info("加载模型继续训练")
325 | # 训练和验证
326 | trainer.train(r_model, train_loader, train_sampler, dev_loader=test_loader)
327 | else:
328 |
329 | model.cuda(args.local_rank)
330 | r_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
331 | # 训练和验证
332 | trainer.train(r_model, train_loader, train_sampler, dev_loader=test_loader)
333 |
334 | # 测试
335 | if args.do_test:
336 | if args.local_rank == 0:
337 | logger.info('========进行测试========')
338 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name)
339 | # 多卡预测要先模型并行
340 | model.cuda(args.local_rank)
341 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
342 | # 再加载模型
343 | model = trainer.load_model(model, checkpoint_path)
344 |
345 | total_loss, test_outputs, test_targets = trainer.test(model, test_loader)
346 | accuracy, micro_f1, macro_f1 = trainer.get_metrics(test_outputs, test_targets)
347 | if args.local_rank == 0:
348 | logger.info(
349 | "【test】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(total_loss, accuracy,
350 | micro_f1,
351 | macro_f1))
352 | report = trainer.get_classification_report(test_outputs, test_targets, labels)
353 | logger.info(report)
354 |
355 | # 预测
356 | if args.do_predict:
357 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name)
358 | model.cuda(args.local_rank)
359 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
360 | model = trainer.load_model(model, checkpoint_path)
361 | line = test_dataset[0]
362 | text = line[0]
363 | result = trainer.predict(tokenizer, text, id2label, args, model)
364 | if args.local_rank == 0:
365 | print(text)
366 | print("预测标签:", result[0])
367 | print("真实标签:", line[1])
368 | print("==========================")
369 |
370 |
371 | def spmd_main(local_world_size, local_rank, init_method):
372 | env_dict = {
373 | key: os.environ[key]
374 | for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
375 | }
376 | if local_rank == 0:
377 | for k, v in env_dict.items():
378 | print(k, v)
379 | if sys.platform == "win32":
380 | # Distributed package only covers collective communications with Gloo
381 | # backend and FileStore on Windows platform. Set init_method parameter
382 | # in init_process_group to a local file.
383 | if "INIT_METHOD" in os.environ.keys():
384 | print(f"init_method is {os.environ['INIT_METHOD']}")
385 | url_obj = urlparse(os.environ["INIT_METHOD"])
386 | if url_obj.scheme.lower() != "file":
387 | raise ValueError("Windows only supports FileStore")
388 | else:
389 | init_method = os.environ["INIT_METHOD"]
390 | else:
391 | # It is a example application, For convience, we create a file in temp dir.
392 | # current_work_dir = os.getcwd()
393 | # init_method = f"file:///{os.path.join(current_work_dir, 'ddp_example')}"
394 | init_method = init_method
395 | print(init_method)
396 | dist.init_process_group(backend="gloo", init_method=init_method, rank=int(env_dict["RANK"]),
397 | world_size=int(env_dict["WORLD_SIZE"]))
398 | else:
399 | print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
400 | dist.init_process_group(backend="nccl")
401 |
402 | tokenizer = BertTokenizer.from_pretrained(args.bert_dir)
403 | main(args, tokenizer, local_rank, local_world_size)
404 |
405 | dist.destroy_process_group()
406 |
407 |
408 | if __name__ == '__main__':
409 | # processor = preprocess.Processor()
410 | args = bert_config.Args().get_parser()
411 | utils.set_seed(args.seed)
412 | utils.set_logger(os.path.join(args.log_dir, 'main.log'))
413 | # The main entry point is called directly without using subprocess
414 | current_work_dir = os.getcwd()
415 | init_method = f"file:///{os.path.join(current_work_dir, 'ddp_example')}"
416 | spmd_main(args.local_world_size, args.local_rank, init_method)
417 |
--------------------------------------------------------------------------------
/main_mp_distributed.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append(r"./")
4 | """
5 | 该文件使用的data_loader.py里面的数据加载方式。
6 | """
7 | # coding=utf-8
8 | import os
9 |
10 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,3'
11 | import tempfile
12 | import json
13 | import random
14 | from pprint import pprint
15 | import os
16 | import logging
17 | import shutil
18 | from sklearn.metrics import accuracy_score, f1_score, classification_report
19 | import torch
20 | import torch.nn as nn
21 | import numpy as np
22 | import pickle
23 | from torch.utils.data import DataLoader, RandomSampler
24 | from transformers import BertTokenizer
25 | import torch.distributed as dist
26 | import torch.multiprocessing as mp
27 |
28 | import bert_config
29 | import models
30 | from utils import utils
31 | from data_loader import CNEWSDataset, Collate, CPWSDataset
32 |
33 | logger = logging.getLogger(__name__)
34 |
35 |
36 | # 单机多卡
37 | # word_size:机器一共有几张卡
38 | # rank:第几块GPU
39 | # local_rank:第几块GPU,和rank相同
40 | # print(torch.cuda.device_count())
41 |
42 | # local_rank = torch.distributed.get_rank()
43 | # args.local_rank = local_rank
44 | # print(args.local_rank)
45 | # dist.init_process_group(backend='gloo',
46 | # init_method=r"file:///D://Code//project//pytorch-distributed training//tmp",
47 | # rank=0,
48 | # world_size=1)
49 | # torch.cuda.set_device(args.local_rank)
50 |
51 |
52 | class Trainer:
53 | def __init__(self, args, optimizer):
54 | self.args = args
55 | self.optimizer = optimizer
56 | self.criterion = nn.CrossEntropyLoss().cuda(self.args.local_rank)
57 |
58 | def load_ckp(self, model, optimizer, checkpoint_path):
59 | checkpoint = torch.load(checkpoint_path)
60 | model.load_state_dict(checkpoint['state_dict'])
61 | optimizer.load_state_dict(checkpoint['optimizer'])
62 | epoch = checkpoint['epoch']
63 | loss = checkpoint['loss']
64 | return model, optimizer, epoch, loss
65 |
66 | def load_model(self, model, checkpoint_path):
67 | checkpoint = torch.load(checkpoint_path)
68 | # new_start_dict = {}
69 | # for k, v in checkpoint['state_dict'].items():
70 | # new_start_dict["module." + k] = v
71 | # model.load_state_dict(new_start_dict)
72 | model.load_state_dict(checkpoint["state_dict"])
73 | return model
74 |
75 | def save_ckp(self, state, checkpoint_path):
76 | torch.save(state, checkpoint_path)
77 |
78 | """
79 | def save_ckp(self, state, is_best, checkpoint_path, best_model_path):
80 | tmp_checkpoint_path = checkpoint_path
81 | torch.save(state, tmp_checkpoint_path)
82 | if is_best:
83 | tmp_best_model_path = best_model_path
84 | shutil.copyfile(tmp_checkpoint_path, tmp_best_model_path)
85 | """
86 |
87 | def train(self, model, train_loader, train_sampler, dev_loader=None):
88 | self.dev_loader = dev_loader
89 | self.model = model
90 | total_step = len(train_loader) * self.args.train_epochs
91 | global_step = 0
92 | eval_step = 10
93 | best_dev_micro_f1 = 0.0
94 | for epoch in range(self.args.train_epochs):
95 | train_sampler.set_epoch(epoch)
96 | for train_step, train_data in enumerate(train_loader):
97 | self.model.train()
98 | token_ids = train_data['token_ids'].cuda(self.args.local_rank)
99 | attention_masks = train_data['attention_masks'].cuda(self.args.local_rank)
100 | token_type_ids = train_data['token_type_ids'].cuda(self.args.local_rank)
101 | labels = train_data['labels'].cuda(self.args.local_rank)
102 | train_outputs = self.model(token_ids, attention_masks, token_type_ids)
103 |
104 | loss = self.criterion(train_outputs, labels)
105 |
106 | torch.distributed.barrier()
107 |
108 | self.optimizer.zero_grad()
109 | loss.backward()
110 | self.optimizer.step()
111 |
112 | loss = self.loss_reduce(loss)
113 | if self.args.local_rank == 0:
114 | print(
115 | "【train】 epoch:{} step:{}/{} loss:{:.6f}".format(epoch, global_step, total_step, loss))
116 | global_step += 1
117 | if dev_loader is not None and self.args.local_rank == 0:
118 | if global_step % eval_step == 0:
119 | dev_loss, dev_outputs, dev_targets = self.dev()
120 | accuracy, micro_f1, macro_f1 = self.get_metrics(dev_outputs, dev_targets)
121 | print(
122 | "【dev】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(dev_loss,
123 | accuracy,
124 | micro_f1,
125 | macro_f1))
126 | if macro_f1 > best_dev_micro_f1:
127 | print("------------>保存当前最好的模型")
128 | checkpoint = {
129 | 'epoch': epoch,
130 | 'loss': dev_loss,
131 | 'state_dict': self.model.state_dict(),
132 | 'optimizer': self.optimizer.state_dict(),
133 | }
134 | best_dev_micro_f1 = macro_f1
135 | save_path = os.path.join(self.args.output_dir, self.args.data_name)
136 | if not os.path.exists(save_path):
137 | os.makedirs(save_path)
138 | checkpoint_path = os.path.join(save_path, 'best.pt')
139 | self.save_ckp(checkpoint, checkpoint_path)
140 | if dev_loader is None and self.args.local_rank == 0:
141 | checkpoint = {
142 | 'epoch': epoch,
143 | 'state_dict': self.model.state_dict(),
144 | 'optimizer': self.optimizer.state_dict(),
145 | }
146 | save_path = os.path.join(self.args.output_dir, self.args.data_name)
147 | if not os.path.exists(save_path):
148 | os.makedirs(save_path)
149 | checkpoint_path = os.path.join(save_path, 'best.pt')
150 | self.save_ckp(checkpoint, checkpoint_path)
151 |
152 | def output_reduce(self, outputs, targets):
153 | output_gather_list = [torch.zeros_like(outputs) for _ in range(self.args.local_world_size)]
154 | # 把每一个GPU的输出聚合起来
155 | dist.all_gather(output_gather_list, outputs)
156 |
157 | outputs = torch.cat(output_gather_list, dim=0)
158 | target_gather_list = [torch.zeros_like(targets) for _ in range(self.args.local_world_size)]
159 | # 把每一个GPU的输出聚合起来
160 | dist.all_gather(target_gather_list, targets)
161 | targets = torch.cat(target_gather_list, dim=0)
162 | return outputs, targets
163 |
164 | def loss_reduce(self, loss):
165 | rt = loss.clone()
166 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
167 | rt /= self.args.local_world_size
168 | return rt
169 |
170 | def dev(self):
171 | self.model.eval()
172 | total_loss = 0.0
173 | dev_outputs = []
174 | dev_targets = []
175 | with torch.no_grad():
176 | for dev_step, dev_data in enumerate(self.dev_loader):
177 | token_ids = dev_data['token_ids'].cuda(self.args.local_rank)
178 | attention_masks = dev_data['attention_masks'].cuda(self.args.local_rank)
179 | token_type_ids = dev_data['token_type_ids'].cuda(self.args.local_rank)
180 | labels = dev_data['labels'].cuda(self.args.local_rank)
181 | outputs = self.model(token_ids, attention_masks, token_type_ids)
182 |
183 | loss = self.criterion(outputs, labels)
184 | torch.distributed.barrier()
185 | # total_loss += loss.item()
186 | loss = self.loss_reduce(loss)
187 | total_loss += loss
188 | outputs, targets = self.output_reduce(outputs, labels)
189 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten()
190 | dev_outputs.extend(outputs.tolist())
191 | dev_targets.extend(targets.cpu().detach().numpy().tolist())
192 | return total_loss, dev_outputs, dev_targets
193 |
194 | def test(self, model, test_loader):
195 | model.eval()
196 | total_loss = 0.0
197 | test_outputs = []
198 | test_targets = []
199 | with torch.no_grad():
200 | for test_step, test_data in enumerate(test_loader):
201 | token_ids = test_data['token_ids'].cuda(self.args.local_rank)
202 | attention_masks = test_data['attention_masks'].cuda(self.args.local_rank)
203 | token_type_ids = test_data['token_type_ids'].cuda(self.args.local_rank)
204 | labels = test_data['labels'].cuda(self.args.local_rank)
205 | outputs = model(token_ids, attention_masks, token_type_ids)
206 |
207 | loss = self.criterion(outputs, labels)
208 | torch.distributed.barrier()
209 | loss = self.loss_reduce(loss)
210 | total_loss += loss
211 | outputs, targets = self.output_reduce(outputs, labels)
212 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten()
213 | test_outputs.extend(outputs.tolist())
214 | test_targets.extend(targets.cpu().detach().numpy().tolist())
215 | # total_loss += loss.item()
216 | # outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten()
217 | # test_outputs.extend(outputs.tolist())
218 | # test_targets.extend(labels.cpu().detach().numpy().tolist())
219 |
220 | return total_loss, test_outputs, test_targets
221 |
222 | def predict(self, tokenizer, text, id2label, args, model):
223 | model.eval()
224 | with torch.no_grad():
225 | inputs = tokenizer.encode_plus(text=text,
226 | add_special_tokens=True,
227 | max_length=args.max_seq_len,
228 | truncation='longest_first',
229 | padding="max_length",
230 | return_token_type_ids=True,
231 | return_attention_mask=True,
232 | return_tensors='pt')
233 | token_ids = inputs['input_ids'].cuda(self.args.local_rank)
234 | attention_masks = inputs['attention_mask'].cuda(self.args.local_rank)
235 | token_type_ids = inputs['token_type_ids'].cuda(self.args.local_rank)
236 | outputs = model(token_ids, attention_masks, token_type_ids)
237 | outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten().tolist()
238 | if len(outputs) != 0:
239 | outputs = [id2label[i] for i in outputs]
240 | return outputs
241 | else:
242 | return '不好意思,我没有识别出来'
243 |
244 | def get_metrics(self, outputs, targets):
245 | accuracy = accuracy_score(targets, outputs)
246 | micro_f1 = f1_score(targets, outputs, average='micro')
247 | macro_f1 = f1_score(targets, outputs, average='macro')
248 | return accuracy, micro_f1, macro_f1
249 |
250 | def get_classification_report(self, outputs, targets, labels):
251 | report = classification_report(targets, outputs, target_names=labels)
252 | return report
253 |
254 |
255 | datasets = {
256 | "cnews": CNEWSDataset,
257 | "cpws": CPWSDataset
258 | }
259 |
260 | train_files = {
261 | "cnews": "cnews.train.txt",
262 | "cpws": "train_data.txt"
263 | }
264 |
265 | test_files = {
266 | "cnews": "cnews.test.txt",
267 | "cpws": "test_data.txt"
268 | }
269 |
270 |
271 | def main_worker(local_rank, args):
272 | args.local_rank = local_rank
273 | print(type(args.local_rank))
274 | tokenizer = args.tokenizer
275 | local_world_size = args.local_world_size
276 | # The main entry point is called directly without using subprocess
277 | current_work_dir = os.getcwd()
278 | init_method = f"file:///{os.path.join(current_work_dir, 'ddp_example')}"
279 | if sys.platform == "win32":
280 | # Distributed package only covers collective communications with Gloo
281 | # backend and FileStore on Windows platform. Set init_method parameter
282 | # in init_process_group to a local file.
283 | if "INIT_METHOD" in os.environ.keys():
284 | print(f"init_method is {os.environ['INIT_METHOD']}")
285 | url_obj = urlparse(os.environ["INIT_METHOD"])
286 | if url_obj.scheme.lower() != "file":
287 | raise ValueError("Windows only supports FileStore")
288 | else:
289 | init_method = os.environ["INIT_METHOD"]
290 | else:
291 | # It is a example application, For convience, we create a file in temp dir.
292 | # current_work_dir = os.getcwd()
293 | # init_method = f"file:///{os.path.join(current_work_dir, 'ddp_example')}"
294 | init_method = init_method
295 | print(init_method)
296 | dist.init_process_group(backend="gloo", init_method=init_method, rank=local_rank,
297 | world_size=args.local_world_size)
298 | dist.barrier()
299 | else:
300 | # print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
301 | dist.init_process_group(backend="nccl")
302 | n = torch.cuda.device_count() // local_world_size
303 | device_ids = list(range(local_rank * n, (local_rank + 1) * n))
304 |
305 | print(
306 | f"[{os.getpid()}] rank = {dist.get_rank()}, "
307 | + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids} \n", end=''
308 | )
309 | # device_ids = [local_rank]
310 | # print(device_ids)
311 |
312 | dataset = datasets.get(args.data_name, None)
313 | train_file, test_file = train_files.get(args.data_name, None), test_files.get(args.data_name, None)
314 | if dataset is None:
315 | raise Exception("请输入正确的数据集名称")
316 | label2id = {}
317 | id2label = {}
318 | with open('./data/{}/labels.txt'.format(args.data_name), 'r', encoding="utf-8") as fp:
319 | labels = fp.read().strip().split('\n')
320 | for i, label in enumerate(labels):
321 | label2id[label] = i
322 | id2label[i] = label
323 | if args.local_rank == 0:
324 | print(label2id)
325 | if args.local_rank == 0:
326 | print('========加载数据集========')
327 | args.train_batch_size = int(args.train_batch_size / torch.cuda.device_count())
328 | args.eval_batch_size = int(args.eval_batch_size / torch.cuda.device_count())
329 |
330 | collate = Collate(tokenizer=tokenizer, max_len=args.max_seq_len, tag2id=label2id)
331 |
332 | train_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, train_file))
333 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
334 | train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size,
335 | collate_fn=collate.collate_fn, num_workers=4, sampler=train_sampler)
336 | test_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, test_file))
337 |
338 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
339 |
340 | test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size,
341 | collate_fn=collate.collate_fn, num_workers=4, sampler=test_sampler)
342 | # test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False,
343 | # collate_fn=collate.collate_fn, num_workers=4)
344 | if args.local_rank == 0:
345 | print('========定义模型和优化器========')
346 | model = models.BertForSequenceClassification(args)
347 | optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
348 | trainer = Trainer(args, optimizer)
349 |
350 | if args.do_train:
351 | if args.retrain:
352 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name)
353 | checkpoint = torch.load(checkpoint_path)
354 | # trainer.optimizer.load_state_dict(checkpoint['optimizer'])
355 | model.cuda(args.local_rank)
356 | r_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
357 | r_model.load_state_dict(checkpoint['state_dict'])
358 | if args.local_rank == 0:
359 | print("========加载模型继续训练========")
360 | # 训练和验证
361 | trainer.train(r_model, train_loader, train_sampler, dev_loader=None)
362 | else:
363 | if args.local_rank == 0:
364 | print('========进行训练========')
365 | model.cuda(args.local_rank)
366 | r_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
367 | # 训练和验证
368 | trainer.train(r_model, train_loader, train_sampler, dev_loader=None)
369 |
370 | # 测试
371 | if args.do_test:
372 | if args.local_rank == 0:
373 | print('========进行测试========')
374 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name)
375 | # 多卡预测要先模型并行
376 | model.cuda(args.local_rank)
377 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
378 | # 再加载模型
379 | model = trainer.load_model(model, checkpoint_path)
380 |
381 | total_loss, test_outputs, test_targets = trainer.test(model, test_loader)
382 | accuracy, micro_f1, macro_f1 = trainer.get_metrics(test_outputs, test_targets)
383 | if args.local_rank == 0:
384 | print(
385 | "【test】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(total_loss, accuracy,
386 | micro_f1,
387 | macro_f1))
388 | report = trainer.get_classification_report(test_outputs, test_targets, labels)
389 | print(report)
390 |
391 | # 预测
392 | if args.do_predict:
393 | if args.local_rank == 0:
394 | print('========进行预测========')
395 | checkpoint_path = './checkpoints/{}/best.pt'.format(args.data_name)
396 | model.cuda(args.local_rank)
397 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
398 | model = trainer.load_model(model, checkpoint_path)
399 | line = test_dataset[0]
400 | text = line[0]
401 | result = trainer.predict(tokenizer, text, id2label, args, model)
402 | if args.local_rank == 0:
403 | print(text)
404 | print("预测标签:", result[0])
405 | print("真实标签:", line[1])
406 | print("==========================")
407 |
408 | dist.destroy_process_group()
409 |
410 |
411 | if __name__ == '__main__':
412 | args = bert_config.Args().get_parser()
413 | utils.set_seed(args.seed)
414 | utils.set_logger(os.path.join(args.log_dir, 'main.log'))
415 |
416 | tokenizer = BertTokenizer.from_pretrained(args.bert_dir)
417 | args.tokenizer = tokenizer
418 | args.nprocs = args.local_world_size
419 | mp.spawn(main_worker, nprocs=args.nprocs, args=(args,))
420 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | from transformers import BertModel
2 | import torch.nn as nn
3 |
4 |
5 | class BertForSequenceClassification(nn.Module):
6 | def __init__(self, args):
7 | super(BertForSequenceClassification, self).__init__()
8 | self.bert = BertModel.from_pretrained(args.bert_dir)
9 | self.bert_config = self.bert.config
10 | out_dims = self.bert_config.hidden_size
11 | self.dropout = nn.Dropout(0.3)
12 | self.linear = nn.Linear(out_dims, args.num_tags)
13 |
14 | def forward(self, token_ids, attention_masks, token_type_ids):
15 | bert_outputs = self.bert(
16 | input_ids = token_ids,
17 | attention_mask = attention_masks,
18 | token_type_ids = token_type_ids,
19 | )
20 | seq_out = bert_outputs[1]
21 | seq_out = self.dropout(seq_out)
22 | seq_out = self.linear(seq_out)
23 | return seq_out
--------------------------------------------------------------------------------
/nvidia.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | :: 執行的指令 ( 請注意若要使用 pipe 符號必須在之前加上一個 ^ 符號 )
4 | SET ExecuteCommand=nvidia-smi
5 |
6 | :: 單位: 秒
7 | SET ExecutePeriod=1
8 |
9 |
10 | SETLOCAL EnableDelayedExpansion
11 |
12 | :loop
13 |
14 | cls
15 |
16 | echo !date! !time!
17 | echo 每 !ExecutePeriod! 秒執行一次,指令^: !ExecuteCommand!
18 |
19 | echo.
20 |
21 | %ExecuteCommand%
22 |
23 | timeout /t %ExecutePeriod% > nul
24 |
25 | goto loop
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tokenizers==0.10.2
2 | torch==1.6.0
3 | transformers==4.5.1
4 | pytorch-lightning==0.9.0
5 | tensorboard==2.2.0
6 |
--------------------------------------------------------------------------------
/test_ddp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import os
4 | import torch
5 | import torchvision
6 | from torch import distributed as dist
7 | from torchvision.models import resnet18
8 | from torch.utils.data import DataLoader
9 | from torchvision.datasets import MNIST
10 | from torchvision.transforms import ToTensor
11 | from torch.nn.parallel import DistributedDataParallel as DDP
12 | from torch.utils.data.distributed import DistributedSampler
13 | import numpy as np
14 |
15 |
16 | def reduce_loss(tensor, rank, world_size):
17 | with torch.no_grad():
18 | dist.reduce(tensor, dst=0)
19 | if rank == 0:
20 | tensor /= world_size
21 |
22 |
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument('--local_rank', type=int, help="local gpu id")
25 | parser.add_argument('--world_size', type=int, help="total nodes")
26 | args = parser.parse_args()
27 |
28 | # world_size = os.environ["world_size"]
29 |
30 | batch_size = 128
31 | epochs = 5
32 | lr = 0.001
33 | n = torch.cuda.device_count() // args.world_size
34 | device_ids = list(range(args.local_rank * n, (args.local_rank + 1) * n))
35 | print("初始化1")
36 | init_method_path = "D:\\Code\\project\\pytorch-distributed training\\pytorch_bert_chinese_text_classification\\tmp\\"
37 | if os.path.exists(init_method_path):
38 | os.remove(init_method_path)
39 | print("已删除init_method_path")
40 | dist.init_process_group(backend='gloo',
41 | init_method='file:///{}'.format(init_method_path),
42 | rank=int(args.local_rank), world_size=int(args.world_size))
43 | torch.cuda.set_device(args.local_rank)
44 | global_rank = dist.get_rank()
45 | print("初始化2")
46 |
47 | print(
48 | f"[{os.getpid()}] rank = {dist.get_rank()}, "
49 | + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids} \n", end=''
50 | )
51 |
52 | from torchvision.models.resnet import ResNet, BasicBlock
53 |
54 |
55 | class MnistResNet(ResNet):
56 | def __init__(self):
57 | super(MnistResNet, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
58 | self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
59 |
60 | def forward(self, x):
61 | return torch.softmax(super(MnistResNet, self).forward(x), dim=-1)
62 |
63 |
64 | # net = resnet18()
65 | net = MnistResNet()
66 | net.cuda()
67 | net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
68 | net = DDP(net, device_ids=[args.local_rank], output_device=args.local_rank)
69 |
70 |
71 | class ToNumpy(object):
72 | def __call__(self, sample):
73 | return np.array(sample)
74 |
75 |
76 | data_root = 'dataset'
77 | trainset = MNIST(root=data_root,
78 | download=True,
79 | train=True,
80 | transform=torchvision.transforms.Compose(
81 | [ToNumpy(), torchvision.transforms.ToTensor()])
82 | )
83 |
84 | valset = MNIST(root=data_root,
85 | download=True,
86 | train=False,
87 | transform=torchvision.transforms.Compose(
88 | [ToNumpy(), torchvision.transforms.ToTensor()])
89 | )
90 |
91 | sampler = DistributedSampler(trainset)
92 | train_loader = DataLoader(trainset,
93 | batch_size=batch_size,
94 | shuffle=False,
95 | pin_memory=True,
96 | sampler=sampler,
97 | )
98 |
99 | val_loader = DataLoader(valset,
100 | batch_size=batch_size,
101 | shuffle=False,
102 | pin_memory=True,
103 | )
104 |
105 | criterion = torch.nn.CrossEntropyLoss()
106 | opt = torch.optim.Adam(net.parameters(), lr=lr)
107 |
108 | net.train()
109 | for e in range(epochs):
110 | # DistributedSampler deterministically shuffle data
111 | # by seting random seed be current number epoch
112 | # so if do not call set_epoch when start of one epoch
113 | # the order of shuffled data will be always same
114 | sampler.set_epoch(e)
115 | for idx, (imgs, labels) in enumerate(train_loader):
116 | imgs = imgs.cuda()
117 | labels = labels.cuda()
118 | output = net(imgs)
119 | loss = criterion(output, labels)
120 | opt.zero_grad()
121 | loss.backward()
122 | opt.step()
123 | reduce_loss(loss, global_rank, args.world_size)
124 | if idx % 10 == 0 and global_rank == 0:
125 | print('Epoch: {} step: {} loss: {}'.format(e, idx, loss.item()))
126 | net.eval()
127 | with torch.no_grad():
128 | cnt = 0
129 | total = len(val_loader.dataset)
130 | for imgs, labels in val_loader:
131 | imgs, labels = imgs.cuda(), labels.cuda()
132 | output = net(imgs)
133 | predict = torch.argmax(output, dim=1)
134 | cnt += (predict == labels).sum().item()
135 |
136 | if global_rank == 0:
137 | print('eval accuracy: {}'.format(cnt / total))
138 |
139 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taishan1994/pytorch_bert_chinese_text_classification/8e3166b98c784f972d17df17ddeb56e3e494184b/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taishan1994/pytorch_bert_chinese_text_classification/8e3166b98c784f972d17df17ddeb56e3e494184b/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taishan1994/pytorch_bert_chinese_text_classification/8e3166b98c784f972d17df17ddeb56e3e494184b/utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import random
3 | import logging
4 | import time
5 | import numpy as np
6 | import torch
7 |
8 |
9 | def timer(func):
10 | """
11 | 函数计时器
12 | :param func:
13 | :return:
14 | """
15 |
16 | @functools.wraps(func)
17 | def wrapper(*args, **kwargs):
18 | start = time.time()
19 | res = func(*args, **kwargs)
20 | end = time.time()
21 | print("{}共耗时约{:.4f}秒".format(func.__name__, end - start))
22 | return res
23 |
24 | return wrapper
25 |
26 |
27 | def set_seed(seed=123):
28 | """
29 | 设置随机数种子,保证实验可重现
30 | :param seed:
31 | :return:
32 | """
33 | random.seed(seed)
34 | torch.manual_seed(seed)
35 | np.random.seed(seed)
36 | torch.cuda.manual_seed_all(seed)
37 |
38 |
39 | def set_logger(log_path):
40 | """
41 | 配置log
42 | :param log_path:s
43 | :return:
44 | """
45 | logger = logging.getLogger()
46 | logger.setLevel(logging.INFO)
47 |
48 | # 由于每调用一次set_logger函数,就会创建一个handler,会造成重复打印的问题,因此需要判断root logger中是否已有该handler
49 | if not any(handler.__class__ == logging.FileHandler for handler in logger.handlers):
50 | file_handler = logging.FileHandler(log_path)
51 | formatter = logging.Formatter(
52 | '%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s')
53 | file_handler.setFormatter(formatter)
54 | logger.addHandler(file_handler)
55 |
56 | if not any(handler.__class__ == logging.StreamHandler for handler in logger.handlers):
57 | stream_handler = logging.StreamHandler()
58 | stream_handler.setFormatter(logging.Formatter('%(message)s'))
59 | logger.addHandler(stream_handler)
60 |
61 |
--------------------------------------------------------------------------------