├── .gitignore
├── README.md
├── albert.py
├── albert_base_zh_additional_36k_steps
│   ├── albert_config_base.json
│   └── vocab.txt
├── albert_tiny
│   ├── albert_config_tiny.json
│   └── vocab.txt
├── albert_xlarge_zh_183k
│   ├── albert_config_xlarge.json
│   └── vocab.txt
├── data
│   └── sougou_mini
│       ├── test.csv
│       └── train.csv
├── model_evaluate.py
├── model_predict.py
├── model_train.py
├── requirements.txt
└── 模型参数对比.md

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea
label.json
__pycache__
albert_tiny/albert_model*
albert_base_zh_additional_36k_steps/albert_model*
albert_base_zh_additional_36k_steps/checkpoint
albert_xlarge_zh_183k/albert_model*
albert_xlarge_zh_183k/checkpoint
albert_test.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
This project implements multi-class Chinese text classification with Keras and ALBERT, fine-tuning the pre-trained ALBERT weights.

### Maintainer

- jclian91

### Dataset

The sougou_mini classification dataset has 5 classes: 体育 (sports), 健康 (health), 军事 (military), 教育 (education) and 汽车 (automobile).

It is split into a training set with 800 samples per class and a test set with 100 samples per class.

### Code structure

```
.
├── albert_tiny (ALBERT tiny pre-trained model)
│   ├── albert_config_tiny.json
│   ├── albert_model.ckpt.data-00000-of-00001
│   ├── albert_model.ckpt.index
│   ├── albert_model.ckpt.meta
│   ├── checkpoint
│   └── vocab.txt
├── albert_base_zh_additional_36k_steps (ALBERT base pre-trained model)
│   ├── albert_config_base.json
│   ├── albert_model.ckpt.data-00000-of-00001
│   ├── albert_model.ckpt.index
│   ├── albert_model.ckpt.meta
│   ├── checkpoint
│   └── vocab.txt
├── albert_xlarge_zh_183k (ALBERT xlarge pre-trained model)
│   ├── albert_config_xlarge.json
│   ├── albert_model.ckpt.data-00000-of-00001
│   ├── albert_model.ckpt.index
│   ├── albert_model.ckpt.meta
│   ├── checkpoint
│   └── vocab.txt
├── albert.py (ALBERT model-building script, from an open-source project)
├── albert_test.py (ALBERT model loading test script)
├── data (dataset)
│   └── sougou_mini
│       ├── test.csv
│       └── train.csv
├── label.json (label dictionary)
├── model_evaluate.py (model evaluation script)
├── model_predict.py (model prediction script)
├── model_train.py (model training script)
├── README.md
└── requirements.txt (third-party modules)
```

### Model results

- Sougou dataset, albert_tiny

Model parameters: batch_size = 8, maxlen = 300, epochs = 3

Evaluation results:

```
              precision    recall  f1-score   support

        体育     0.9700    0.9798    0.9749        99
        健康     0.9278    0.9091    0.9184        99
        军事     0.9899    0.9899    0.9899        99
        教育     0.8585    0.9192    0.8878        99
        汽车     1.0000    0.9394    0.9688        99

    accuracy                         0.9475       495
   macro avg     0.9492    0.9475    0.9479       495
weighted avg     0.9492    0.9475    0.9479       495
```

- Sougou dataset, albert_base_zh_additional_36k_steps

Model parameters: batch_size = 8, maxlen = 300, epochs = 3

Evaluation results:

```
              precision    recall  f1-score   support

        体育     0.9802    1.0000    0.9900        99
        健康     0.9684    0.9293    0.9485        99
        军事     1.0000    0.9899    0.9949        99
        教育     0.8739    0.9798    0.9238        99
        汽车     1.0000    0.9091    0.9524        99

    accuracy                         0.9616       495
   macro avg     0.9645    0.9616    0.9619       495
weighted avg     0.9645    0.9616    0.9619       495
```

- Sougou dataset, albert_xlarge_zh_183k

Model parameters: batch_size = 2, maxlen = 300, epochs = 3

Evaluation results:

```
              precision    recall  f1-score   support

        体育     0.9898    0.9798    0.9848        99
        健康     0.9412    0.9697    0.9552        99
        军事     0.9706    1.0000    0.9851        99
        教育     0.9300    0.9394    0.9347        99
        汽车     0.9892    0.9293    0.9583        99

    accuracy                         0.9636       495
   macro avg     0.9642    0.9636    0.9636       495
weighted avg     0.9642    0.9636    0.9636       495
```

### Getting started

1. Place the pre-trained Chinese ALBERT models in the corresponding folders.
2. Install the third-party Python modules listed in requirements.txt.
3. Prepare the data you want to classify in the format of data/sougou_mini (see the sketch below).
4. Adjust the model parameters and run model_train.py to train the model.
5. Run model_evaluate.py to evaluate the model.
6. Run model_predict.py to predict labels for new text.
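
The CSV files put the class label in the first column and the text in the second, as in data/sougou_mini. A minimal sketch for producing such a file (the text-column name `content` and the output path are illustrative assumptions; only the `label` column name is required by model_train.py):

```python
import pandas as pd

# Two columns: the first must be named "label", the second holds the raw text.
df = pd.DataFrame({
    "label": ["体育", "汽车"],
    "content": ["球队在昨晚的比赛中逆转取胜……", "这款SUV的越野性能十分出色……"],
})
df.to_csv("data/my_dataset/train.csv", index=False, encoding="utf-8")
```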

--------------------------------------------------------------------------------
/albert.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import os
import json

import tensorflow as tf
import numpy as np

from keras_bert.backend import keras
from keras_transformer import gelu
from keras_bert import get_custom_objects as get_bert_custom_objects
from keras_bert.layers import Masked, Extract
from keras_pos_embd import PositionEmbedding
from keras_layer_normalization import LayerNormalization
from keras_multi_head import MultiHeadAttention
from keras_position_wise_feed_forward import FeedForward
from keras_adaptive_softmax import AdaptiveEmbedding, AdaptiveSoftmax


__all__ = [
    'get_custom_objects', 'build_albert',
    'load_brightmart_albert_zh_checkpoint',
]


def get_custom_objects():
    custom_objects = get_bert_custom_objects()
    custom_objects['AdaptiveEmbedding'] = AdaptiveEmbedding
    custom_objects['AdaptiveSoftmax'] = AdaptiveSoftmax
    return custom_objects
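
# Typical use (as in model_evaluate.py and model_predict.py): pass these
# objects to keras.models.load_model when deserializing a fine-tuned model,
# e.g. load_model('albert_large_cls_sougou.h5',
#                 custom_objects=get_custom_objects())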


def build_albert(token_num,
                 pos_num=512,
                 seq_len=512,
                 embed_dim=128,
                 hidden_dim=768,
                 transformer_num=12,
                 head_num=12,
                 feed_forward_dim=3072,
                 dropout_rate=0.1,
                 attention_activation=None,
                 feed_forward_activation='gelu',
                 training=True,
                 trainable=None,
                 output_layers=None):
    """Get ALBERT model.
    See: https://arxiv.org/pdf/1909.11942.pdf
    :param token_num: Number of tokens.
    :param pos_num: Maximum position.
    :param seq_len: Maximum length of the input sequence or None.
    :param embed_dim: Dimensions of embeddings.
    :param hidden_dim: Dimensions of hidden layers.
    :param transformer_num: Number of transformers.
    :param head_num: Number of heads in multi-head attention
                     in each transformer.
    :param feed_forward_dim: Dimension of the feed forward layer
                             in each transformer.
    :param dropout_rate: Dropout rate.
    :param attention_activation: Activation for attention layers.
    :param feed_forward_activation: Activation for feed-forward layers.
    :param training: A built model with MLM and NSP outputs will be returned
                     if it is `True`, otherwise the input layers and the last
                     feature extraction layer will be returned.
    :param trainable: Whether the model is trainable.
    :param output_layers: A list of indices of output layers.
    """
    if attention_activation == 'gelu':
        attention_activation = gelu
    if feed_forward_activation == 'gelu':
        feed_forward_activation = gelu
    if trainable is None:
        trainable = training

    def _trainable(_layer):
        if isinstance(trainable, (list, tuple, set)):
            for prefix in trainable:
                if _layer.name.startswith(prefix):
                    return True
            return False
        return trainable

    # Build inputs
    input_token = keras.layers.Input(shape=(seq_len,), name='Input-Token')
    input_segment = keras.layers.Input(shape=(seq_len,), name='Input-Segment')
    inputs = [input_token, input_segment]

    # Build embeddings
    embed_token, embed_weights, embed_projection = AdaptiveEmbedding(
        input_dim=token_num,
        output_dim=hidden_dim,
        embed_dim=embed_dim,
        mask_zero=True,
        trainable=trainable,
        return_embeddings=True,
        return_projections=True,
        name='Embed-Token',
    )(input_token)
    embed_segment = keras.layers.Embedding(
        input_dim=2,
        output_dim=hidden_dim,
        trainable=trainable,
        name='Embed-Segment',
    )(input_segment)
    embed_layer = keras.layers.Add(name='Embed-Token-Segment')(
        [embed_token, embed_segment])
    embed_layer = PositionEmbedding(
        input_dim=pos_num,
        output_dim=hidden_dim,
        mode=PositionEmbedding.MODE_ADD,
        trainable=trainable,
        name='Embedding-Position',
    )(embed_layer)

    if dropout_rate > 0.0:
        dropout_layer = keras.layers.Dropout(
            rate=dropout_rate,
            name='Embedding-Dropout',
        )(embed_layer)
    else:
        dropout_layer = embed_layer
    embed_layer = LayerNormalization(
        trainable=trainable,
        name='Embedding-Norm',
    )(dropout_layer)

    # Build shared transformer
    attention_layer = MultiHeadAttention(
        head_num=head_num,
        activation=attention_activation,
        name='Attention',
    )
    attention_normal = LayerNormalization(name='Attention-Normal')
    feed_forward_layer = FeedForward(
        units=feed_forward_dim,
        activation=feed_forward_activation,
        name='Feed-Forward'
    )
    feed_forward_normal = LayerNormalization(name='Feed-Forward-Normal')

    transformed = embed_layer
    transformed_layers = []
    for i in range(transformer_num):
        attention_input = transformed
        transformed = attention_layer(transformed)
        if dropout_rate > 0.0:
            transformed = keras.layers.Dropout(
                rate=dropout_rate,
                name='Attention-Dropout-{}'.format(i + 1),
            )(transformed)
        transformed = keras.layers.Add(
            name='Attention-Add-{}'.format(i + 1),
        )([attention_input, transformed])
        transformed = attention_normal(transformed)

        feed_forward_input = transformed
        transformed = feed_forward_layer(transformed)
        if dropout_rate > 0.0:
            transformed = keras.layers.Dropout(
                rate=dropout_rate,
                name='Feed-Forward-Dropout-{}'.format(i + 1),
            )(transformed)
        transformed = keras.layers.Add(
            name='Feed-Forward-Add-{}'.format(i + 1),
        )([feed_forward_input, transformed])
        transformed = feed_forward_normal(transformed)
        transformed_layers.append(transformed)
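
    # Note: the loop above reuses the *same* attention and feed-forward layer
    # objects at every depth, so all transformer blocks share one set of
    # weights -- this is ALBERT's cross-layer parameter sharing.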

    if training:
        # Build tasks
        mlm_dense_layer = keras.layers.Dense(
            units=hidden_dim,
            activation=feed_forward_activation,
            name='MLM-Dense',
        )(transformed)
        mlm_norm_layer = LayerNormalization(name='MLM-Norm')(mlm_dense_layer)
        mlm_pred_layer = AdaptiveSoftmax(
            input_dim=hidden_dim,
            output_dim=token_num,
            embed_dim=embed_dim,
            bind_embeddings=True,
            bind_projections=True,
            name='MLM-Sim',
        )([mlm_norm_layer, embed_weights, embed_projection])
        masked_layer = Masked(name='MLM')([mlm_pred_layer, inputs[-1]])
        extract_layer = Extract(index=0, name='Extract')(transformed)
        nsp_dense_layer = keras.layers.Dense(
            units=hidden_dim,
            activation='tanh',
            name='SOP-Dense',
        )(extract_layer)
        nsp_pred_layer = keras.layers.Dense(
            units=2,
            activation='softmax',
            name='SOP',
        )(nsp_dense_layer)
        model = keras.models.Model(
            inputs=inputs,
            outputs=[masked_layer, nsp_pred_layer])
        for layer in model.layers:
            layer.trainable = _trainable(layer)
        return model
    if output_layers is not None:
        if isinstance(output_layers, list):
            output_layers = [
                transformed_layers[index] for index in output_layers]
            output = keras.layers.Concatenate(
                name='Output',
            )(output_layers)
        else:
            output = transformed_layers[output_layers]
        model = keras.models.Model(inputs=inputs, outputs=output)
        return model
    # The temporary model below is only used to set layer trainability;
    # callers wrap the returned tensors in a Model themselves
    # (see load_brightmart_albert_zh_checkpoint).
    model = keras.models.Model(inputs=inputs, outputs=transformed)
    for layer in model.layers:
        layer.trainable = _trainable(layer)
    return inputs, transformed


def load_brightmart_albert_zh_checkpoint(checkpoint_path, **kwargs):
    """Load checkpoint from https://github.com/brightmart/albert_zh
    :param checkpoint_path: path to checkpoint folder.
    :param kwargs: arguments for albert model.
    :return: the built model with the checkpoint weights loaded.
    """
    config = {}
    for file_name in os.listdir(checkpoint_path):
        if file_name.startswith('albert_config'):
            with open(os.path.join(checkpoint_path, file_name)) as reader:
                config = json.load(reader)
            break

    def _set_if_not_existed(key, value):
        if key not in kwargs:
            kwargs[key] = value

    _set_if_not_existed('training', True)
    training = kwargs['training']
    _set_if_not_existed('token_num', config['vocab_size'])
    _set_if_not_existed('pos_num', config['max_position_embeddings'])
    _set_if_not_existed('seq_len', config['max_position_embeddings'])
    _set_if_not_existed('embed_dim', config['embedding_size'])
    _set_if_not_existed('hidden_dim', config['hidden_size'])
    _set_if_not_existed('transformer_num', config['num_hidden_layers'])
    _set_if_not_existed('head_num', config['num_attention_heads'])
    _set_if_not_existed('feed_forward_dim', config['intermediate_size'])
    _set_if_not_existed('dropout_rate', config['hidden_dropout_prob'])
    _set_if_not_existed('feed_forward_activation', config['hidden_act'])

    model = build_albert(**kwargs)
    if not training:
        inputs, outputs = model
        model = keras.models.Model(inputs, outputs)

    def _checkpoint_loader(checkpoint_file):
        def _loader(name):
            return tf.train.load_variable(checkpoint_file, name)
        return _loader

    loader = _checkpoint_loader(
        os.path.join(checkpoint_path, 'albert_model.ckpt'))

    model.get_layer(name='Embed-Token').set_weights([
        loader('bert/embeddings/word_embeddings'),
        loader('bert/embeddings/word_embeddings_2'),
    ])
    model.get_layer(name='Embed-Segment').set_weights([
        loader('bert/embeddings/token_type_embeddings'),
    ])
    model.get_layer(name='Embedding-Position').set_weights([
        loader('bert/embeddings/position_embeddings'),
    ])
    model.get_layer(name='Embedding-Norm').set_weights([
        loader('bert/embeddings/LayerNorm/gamma'),
        loader('bert/embeddings/LayerNorm/beta'),
    ])

    model.get_layer(name='Attention').set_weights([
        loader('bert/encoder/layer_shared/attention/self/query/kernel'),
        loader('bert/encoder/layer_shared/attention/self/query/bias'),
        loader('bert/encoder/layer_shared/attention/self/key/kernel'),
        loader('bert/encoder/layer_shared/attention/self/key/bias'),
        loader('bert/encoder/layer_shared/attention/self/value/kernel'),
        loader('bert/encoder/layer_shared/attention/self/value/bias'),
        loader('bert/encoder/layer_shared/attention/output/dense/kernel'),
        loader('bert/encoder/layer_shared/attention/output/dense/bias'),
    ])
    model.get_layer(name='Attention-Normal').set_weights([
        loader('bert/encoder/layer_shared/attention/output/LayerNorm/gamma'),
        loader('bert/encoder/layer_shared/attention/output/LayerNorm/beta'),
    ])
    model.get_layer(name='Feed-Forward').set_weights([
        loader('bert/encoder/layer_shared/intermediate/dense/kernel'),
        loader('bert/encoder/layer_shared/intermediate/dense/bias'),
        loader('bert/encoder/layer_shared/output/dense/kernel'),
        loader('bert/encoder/layer_shared/output/dense/bias'),
    ])
    model.get_layer(name='Feed-Forward-Normal').set_weights([
        loader('bert/encoder/layer_shared/output/LayerNorm/gamma'),
        loader('bert/encoder/layer_shared/output/LayerNorm/beta'),
    ])

    if training:
        model.get_layer(name='MLM-Dense').set_weights([
            loader('cls/predictions/transform/dense/kernel'),
            loader('cls/predictions/transform/dense/bias'),
        ])
        model.get_layer(name='MLM-Norm').set_weights([
            loader('cls/predictions/transform/LayerNorm/gamma'),
            loader('cls/predictions/transform/LayerNorm/beta'),
        ])
        model.get_layer(name='MLM-Sim').set_weights([
            loader('cls/predictions/output_bias'),
        ])

        model.get_layer(name='SOP-Dense').set_weights([
            loader('bert/pooler/dense/kernel'),
            loader('bert/pooler/dense/bias'),
        ])
        model.get_layer(name='SOP').set_weights([
            np.transpose(loader('cls/seq_relationship/output_weights')),
            loader('cls/seq_relationship/output_bias'),
        ])

    return model
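

if __name__ == '__main__':
    # Minimal usage sketch: build the feature-extraction model from the tiny
    # checkpoint and inspect it. Assumes the albert_tiny folder is populated
    # with the brightmart checkpoint files (they are excluded by .gitignore).
    albert = load_brightmart_albert_zh_checkpoint('albert_tiny', training=False)
    albert.summary()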

--------------------------------------------------------------------------------
/albert_base_zh_additional_36k_steps/albert_config_base.json:
--------------------------------------------------------------------------------
{
  "attention_probs_dropout_prob": 0.0,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "embedding_size": 128,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,

  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 21128,
  "ln_type": "postln"
}

--------------------------------------------------------------------------------
/albert_tiny/albert_config_tiny.json:
--------------------------------------------------------------------------------
{
  "attention_probs_dropout_prob": 0.0,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 312,
  "embedding_size": 128,
  "initializer_range": 0.02,
  "intermediate_size": 1248,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 4,

  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 21128,
  "ln_type": "postln"
}

--------------------------------------------------------------------------------
/albert_xlarge_zh_183k/albert_config_xlarge.json:
--------------------------------------------------------------------------------
{
  "attention_probs_dropout_prob": 0.0,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 2048,
  "embedding_size": 128,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 512,
  "num_attention_heads": 32,
  "num_hidden_layers": 24,

  "pooler_fc_size": 1024,
  "pooler_num_attention_heads": 64,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 21128,
  "ln_type": "postln"
}

--------------------------------------------------------------------------------
/model_evaluate.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2020/12/23 15:28
# @Author : Jclian91
# @File : model_evaluate.py
# @Place : Yangpu, Shanghai
# Model evaluation script
import json
import numpy as np
import pandas as pd
from keras.models import load_model
# from keras_bert import get_custom_objects
from albert import get_custom_objects
from sklearn.metrics import classification_report

from model_train import token_dict, OurTokenizer

maxlen = 300

# Load the trained model
model = load_model("albert_large_cls_sougou.h5", custom_objects=get_custom_objects())
tokenizer = OurTokenizer(token_dict)
with open("label.json", "r", encoding="utf-8") as f:
    label_dict = json.loads(f.read())


# Predict a single text
def predict_single_text(text):
    # Tokenize with the BERT tokenizer; [CLS]/[SEP] can push the sequence to
    # maxlen + 2 tokens, which is fine since the model accepts variable length
    text = text[:maxlen]
    x1, x2 = tokenizer.encode(first=text)
    X1 = x1 + [0] * (maxlen - len(x1)) if len(x1) < maxlen else x1
    X2 = x2 + [0] * (maxlen - len(x2)) if len(x2) < maxlen else x2

    # Run the model and map the argmax back to a label
    predicted = model.predict([[X1], [X2]])
    y = np.argmax(predicted[0])
    return label_dict[str(y)]


# Evaluate the model on the test set
def evaluate():
    test_df = pd.read_csv("data/sougou_mini/test.csv").fillna(value="")
    true_y_list, pred_y_list = [], []
    for i in range(test_df.shape[0]):
        print("predicting sample %d" % (i + 1))
        true_y, content = test_df.iloc[i, :]
        pred_y = predict_single_text(content)
        true_y_list.append(true_y)
        pred_y_list.append(pred_y)

    return classification_report(true_y_list, pred_y_list, digits=4)


output_data = evaluate()
print("model evaluation result:\n")
print(output_data)
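
# Optional extension (a sketch, not wired in): classification_report gives
# per-class precision/recall; a confusion matrix additionally shows which
# classes get mistaken for which. If evaluate() were changed to return
# (true_y_list, pred_y_list), it could be used as:
#   from sklearn.metrics import confusion_matrix
#   print(confusion_matrix(true_y_list, pred_y_list))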

--------------------------------------------------------------------------------
/model_predict.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2020/12/23 15:28
# @Author : Jclian91
# @File : model_predict.py
# @Place : Yangpu, Shanghai
# Model prediction script

import time
import json
import numpy as np

from model_train import token_dict, OurTokenizer
from keras.models import load_model
# from keras_bert import get_custom_objects
from albert import get_custom_objects

maxlen = 300

# Load the trained model (the file saved by model_train.py)
model = load_model("albert_large_cls_sougou.h5", custom_objects=get_custom_objects())
tokenizer = OurTokenizer(token_dict)
with open("label.json", "r", encoding="utf-8") as f:
    label_dict = json.loads(f.read())

s_time = time.time()
# Sample text for prediction
text = "说到硬派越野SUV,你会想起哪些车型?是被称为“霸道”的丰田 普拉多 (配置 | 询价) ,还是被叫做“山猫”的帕杰罗,亦或者是“渣男专车”奔驰大G、" \
       "“沙漠王子”途乐。总之,随着世界各国越来越重视对环境的保护,那些大排量的越野SUV在不久的将来也会渐渐消失在我们的视线之中,所以与其错过," \
       "不如趁着还年轻,在有生之年里赶紧去入手一台能让你心仪的硬派越野SUV。而今天我想要来跟你们聊的,正是全球公认的十大硬派越野SUV," \
       "越野迷们看完之后也不妨思考一下,到底哪款才是你的菜,下面话不多说,赶紧开始吧。"


# Tokenize with the BERT tokenizer
text = text[:maxlen]
x1, x2 = tokenizer.encode(first=text)

X1 = x1 + [0] * (maxlen - len(x1)) if len(x1) < maxlen else x1
X2 = x2 + [0] * (maxlen - len(x2)) if len(x2) < maxlen else x2

# Run the model and output the prediction
predicted = model.predict([[X1], [X2]])
y = np.argmax(predicted[0])


print("Text: %s" % text)
print("Predicted label: %s" % label_dict[str(y)])
e_time = time.time()
print("cost time:", e_time - s_time)

--------------------------------------------------------------------------------
/model_train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2020/12/23 14:19
# @Author : Jclian91
# @File : model_train.py
# @Place : Yangpu, Shanghai
import codecs
import json
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import numpy as np
import pandas as pd
from keras.layers import *
from keras.models import Model
from keras.optimizers import Adam
from keras_bert import Tokenizer
from albert import load_brightmart_albert_zh_checkpoint, build_albert


# Recommended length <= 510 (512 positions minus [CLS] and [SEP])
maxlen = 300
BATCH_SIZE = 8
# Vocabulary of the checkpoint being fine-tuned; change this path together
# with the folder name in create_cls_model when switching checkpoints
dict_path = './albert_xlarge_zh_183k/vocab.txt'


token_dict = {}
with codecs.open(dict_path, 'r', 'utf-8') as reader:
    for line in reader:
        token = line.strip()
        token_dict[token] = len(token_dict)


class OurTokenizer(Tokenizer):
    def _tokenize(self, text):
        R = []
        for c in text:
            if c in self._token_dict:
                R.append(c)
            else:
                R.append('[UNK]')  # Map any remaining character to [UNK]
        return R


tokenizer = OurTokenizer(token_dict)


def seq_padding(X, padding=0):
    L = [len(x) for x in X]
    ML = max(L)
    return np.array([
        np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X
    ])
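
# For illustration: sequences are padded with zeros to the longest item in
# the batch, not to a global maxlen, e.g.
#   seq_padding([[1, 2, 3], [4, 5]]) -> array([[1, 2, 3],
#                                              [4, 5, 0]])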


class DataGenerator:

    def __init__(self, data, batch_size=BATCH_SIZE):
        self.data = data
        self.batch_size = batch_size
        self.steps = len(self.data) // self.batch_size
        if len(self.data) % self.batch_size != 0:
            self.steps += 1

    def __len__(self):
        return self.steps

    def __iter__(self):
        while True:
            idxs = list(range(len(self.data)))
            np.random.shuffle(idxs)
            X1, X2, Y = [], [], []
            for i in idxs:
                d = self.data[i]
                text = d[0][:maxlen]
                x1, x2 = tokenizer.encode(first=text)
                y = d[1]
                X1.append(x1)
                X2.append(x2)
                Y.append(y)
                if len(X1) == self.batch_size or i == idxs[-1]:
                    X1 = seq_padding(X1)
                    X2 = seq_padding(X2)
                    Y = seq_padding(Y)
                    yield [X1, X2], Y
                    X1, X2, Y = [], [], []


# Build the classification model
def create_cls_model(num_labels):
    albert_model = load_brightmart_albert_zh_checkpoint('albert_xlarge_zh_183k', training=False)

    for layer in albert_model.layers:
        layer.trainable = True

    x1_in = Input(shape=(None,))
    x2_in = Input(shape=(None,))

    x = albert_model([x1_in, x2_in])
    x = Lambda(lambda x: x[:, 0])(x)  # Take the [CLS] vector for classification
    p = Dense(num_labels, activation='softmax')(x)  # Multi-class softmax output

    model = Model([x1_in, x2_in], p)
    model.compile(
        loss='categorical_crossentropy',
        optimizer=Adam(1e-5),  # Use a sufficiently small learning rate
        metrics=['accuracy']
    )
    model.summary()

    return model


if __name__ == '__main__':

    # Data processing: read the training and test sets
    print("begin data processing...")
    train_df = pd.read_csv("data/sougou_mini/train.csv").fillna(value="")
    test_df = pd.read_csv("data/sougou_mini/test.csv").fillna(value="")

    labels = train_df["label"].unique()
    with open("label.json", "w", encoding="utf-8") as f:
        f.write(json.dumps(dict(zip(range(len(labels)), labels)), ensure_ascii=False, indent=2))

    # Convert labels to one-hot vectors
    train_data = []
    test_data = []
    for i in range(train_df.shape[0]):
        label, content = train_df.iloc[i, :]
        label_id = [0] * len(labels)
        for j, _ in enumerate(labels):
            if _ == label:
                label_id[j] = 1
        train_data.append((content, label_id))

    for i in range(test_df.shape[0]):
        label, content = test_df.iloc[i, :]
        label_id = [0] * len(labels)
        for j, _ in enumerate(labels):
            if _ == label:
                label_id[j] = 1
        test_data.append((content, label_id))

    print("finish data processing!")

    # Model training
    model = create_cls_model(len(labels))
    train_D = DataGenerator(train_data)
    test_D = DataGenerator(test_data)

    print("begin model training...")
    model.fit_generator(
        train_D.__iter__(),
        steps_per_epoch=len(train_D),
        epochs=3,
        validation_data=test_D.__iter__(),
        validation_steps=len(test_D)
    )

    print("finish model training!")

    # Save the model
    model.save('albert_large_cls_sougou.h5')
    print("Model saved!")

    result = model.evaluate_generator(test_D.__iter__(), steps=len(test_D))
    print("Model evaluation result:", result)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
keras_adaptive_softmax==0.6.0
numpy==1.16.4
keras_multi_head==0.27.0
keras_pos_embd==0.11.0
keras_layer_normalization==0.14.0
tensorflow==1.14.0
keras_transformer==0.38.0
keras_bert==0.83.0
keras_position_wise_feed_forward==0.6.0
pandas==0.23.4
Keras==2.2.4
scikit_learn==0.24.0

--------------------------------------------------------------------------------
/模型参数对比.md:
--------------------------------------------------------------------------------
Parameter comparison on the sougou-mini dataset

- albert_tiny

  Total params: 4,079,061
  Trainable params: 4,079,061
- albert_base_zh_additional_36k_steps

  Total params: 10,290,693
  Trainable params: 10,290,693

- albert_xlarge_zh_183k

  Total params: 54,391,813
  Trainable params: 54,391,813

- chinese_L-12_H-768_A-12 (Google's Chinese BERT base, for comparison)

  Total params: 101,680,901
  Trainable params: 101,680,901
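Counts like these are reported by Keras's model.summary(). A minimal sketch for inspecting a backbone's size (assuming the albert_tiny checkpoint folder is populated; this count excludes the small classification head that model_train.py adds on top):

```python
from albert import load_brightmart_albert_zh_checkpoint

# Build the feature-extraction model (no MLM/SOP heads), as model_train.py does.
albert = load_brightmart_albert_zh_checkpoint("albert_tiny", training=False)
albert.summary()  # per-layer table plus total/trainable parameter counts
print("Backbone params:", albert.count_params())
```

--------------------------------------------------------------------------------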