├── LICENSE ├── ReadMe.md ├── app.py ├── benchmark.py ├── bert-as-service-master.zip ├── bert ├── __init__.py ├── extract_features.py ├── modeling.py ├── modeling_test.py ├── optimization.py ├── optimization_test.py ├── tokenization.py └── tokenization_test.py ├── bert_serving ├── __init__.py └── client │ └── __init__.py ├── data ├── __init__.py ├── bert_sen.txt ├── bert_sentence.csv ├── bert_word.csv ├── bert_word.txt ├── cnews │ ├── cnews.test.txt │ ├── cnews.train.txt │ ├── cnews.vocab.txt │ └── vectors.txt ├── cnews_loader.py ├── cross_sen.csv ├── new_para.txt ├── predict.txt ├── predict_lstm_atten.txt ├── predict_lstm_early_atten.txt ├── predict_rnn.txt ├── predict_rnn_atten.txt ├── test_x.npy ├── test_xs.npy ├── test_y.npy ├── test_ys.npy ├── train_x.npy ├── train_xs.npy ├── train_y.npy ├── train_ys.npy └── word.txt ├── doc_classfier_bert.py ├── doc_textLoad.py ├── docker ├── Dockerfile └── entrypoint.sh ├── gpu_env.py ├── helper.py ├── helper_text ├── __init__.py ├── cnews_group.py └── copy_data.sh ├── images ├── acc_loss.png ├── acc_loss_rnn.png ├── cnn_architecture.png └── rnn_architecture.png ├── predict.py ├── requirements.client.txt ├── requirements.gpu.txt ├── requirements.txt ├── rnn_model.py ├── run_pre.py ├── run_rnn.py ├── run_rnn_bert.py ├── service ├── __init__.py ├── client.py └── server.py ├── tensorboard └── textlstm │ ├── events.out.tfevents.1543996656.ubuntu │ ├── events.out.tfevents.1544006771.ubuntu │ ├── events.out.tfevents.1544060846.ubuntu │ ├── events.out.tfevents.1544086349.ubuntu │ ├── events.out.tfevents.1544098320.ubuntu │ ├── events.out.tfevents.1544100290.ubuntu │ ├── events.out.tfevents.1544100390.ubuntu │ ├── events.out.tfevents.1544145498.ubuntu │ ├── events.out.tfevents.1544145532.ubuntu │ ├── events.out.tfevents.1544145759.ubuntu │ ├── events.out.tfevents.1544146122.ubuntu │ ├── events.out.tfevents.1544146235.ubuntu │ ├── events.out.tfevents.1544260427.ubuntu │ ├── events.out.tfevents.1544260511.ubuntu │ ├── events.out.tfevents.1544320131.ubuntu │ ├── events.out.tfevents.1544882354.ubuntu │ ├── events.out.tfevents.1544882966.ubuntu │ ├── events.out.tfevents.1544884075.ubuntu │ ├── events.out.tfevents.1544884383.ubuntu │ ├── events.out.tfevents.1544884648.ubuntu │ ├── events.out.tfevents.1544886089.ubuntu │ ├── events.out.tfevents.1544924477.ubuntu │ ├── events.out.tfevents.1544925170.ubuntu │ ├── events.out.tfevents.1544925237.ubuntu │ ├── events.out.tfevents.1544925420.ubuntu │ ├── events.out.tfevents.1544927106.ubuntu │ └── events.out.tfevents.1544929233.ubuntu └── test.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 dzkang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ReadMe.md: -------------------------------------------------------------------------------- 1 | # Text Classification with RNN--2018CCFBDCI汽车用户观点提取 2 | 3 | 汽车用户观点提取,使用bert模型的词向量作为RNN的初始化,其中data的train_x.npy表示的是bert的输入格式 4 | 而原始的数据集是经过word2id以及padding的,y不需要变化,rnn和加bert的rnn都可以用。具体参考text_Loader下的process file函数。 5 | 6 | 7 | 使用循环神经网络进行中文文本分类 8 | 9 | ## 环境 10 | 11 | - Python 2/3 12 | - TensorFlow 1.3以上 13 | - numpy 14 | - scikit-learn 15 | - scipy 16 | 17 | ## 数据集 18 | 19 | 使用汽车用户观点提取的任务进行训练与测试,数据集请自行到2018CCFBCI(https://www.datafountain.cn/competitions/329/details)下载,请遵循数据提供方的开源协议。 20 | 21 | 本次训练使用了其中的10个分类 22 | 23 | ## 预处理 24 | 25 | `data/cnews_loader.py`为数据的预处理文件。 26 | 27 | - `read_file()`: 读取文件数据; 28 | - `build_vocab()`: 构建词汇表,使用字符级的表示,这一函数会将词汇表存储下来,避免每一次重复处理; 29 | - `read_vocab()`: 读取上一步存储的词汇表,转换为`{词:id}`表示; 30 | - `read_category()`: 将分类目录固定,转换为`{类别: id}`表示; 31 | - `to_words()`: 将一条由id表示的数据重新转换为文字; 32 | - `process_file()`: 将数据集从文字转换为固定长度的id序列表示; 33 | - `batch_iter()`: 为神经网络的训练准备经过shuffle的批次的数据。 34 | 35 | ## RNN循环神经网络 36 | 37 | ### 配置项 38 | 39 | RNN可配置的参数如下所示,在`rnn_model.py`中。 40 | 41 | ```python 42 | class TRNNConfig(object): 43 | """RNN配置参数""" 44 | 45 | # 模型参数 46 | embedding_dim = 64 # 词向量维度 47 | seq_length = 600 # 序列长度 48 | num_classes = 10 # 类别数 49 | vocab_size = 5000 # 词汇表达小 50 | 51 | num_layers= 2 # 隐藏层层数 52 | hidden_dim = 128 # 隐藏层神经元 53 | rnn = 'gru' # lstm 或 gru 54 | 55 | dropout_keep_prob = 0.8 # dropout保留比例 56 | learning_rate = 1e-3 # 学习率 57 | 58 | batch_size = 128 # 每批训练大小 59 | num_epochs = 10 # 总迭代轮次 60 | 61 | print_per_batch = 100 # 每多少轮输出一次结果 62 | save_per_batch = 10 # 每多少轮存入tensorboard 63 | ``` 64 | 65 | ### RNN-bert模型 66 | 67 | 具体参看`run_rnn_bert.py`的实现。 68 | 69 | 关于RNN-bert模型--清华新浪新闻数据集的实现见github(https://github.com/a414351664/Bert-THUCNews )具体如下! 70 | 71 | # Text Classification with RNN 72 | 有bert后缀的都是在原来的基础上,进行改造,得到的效果能到97.5 73 | 无bert后缀的加入了Attention以及多层或者双向模型 74 | 75 | 使用循环神经网络进行中文文本分类 76 | 77 | 78 | ## 环境 79 | 80 | - Python 2/3 81 | - TensorFlow 1.3以上 82 | - numpy 83 | - scikit-learn 84 | - scipy 85 | 86 | ## 数据集 87 | 88 | 使用THUCNews的一个子集进行训练与测试,数据集请自行到[THUCTC:一个高效的中文文本分类工具包](http://thuctc.thunlp.org/)下载,请遵循数据提供方的开源协议。 89 | 90 | 本次训练使用了其中的10个分类,每个分类5000条数据。相关数据下载地址:链接: https://pan.baidu.com/s/1GmBFZfDKsXBMFEYQWFdrfQ 提取码: 9ixg 91 | 92 | 类别如下: 93 | 94 | ``` 95 | 体育, 财经, 房产, 家居, 教育, 科技, 时尚, 时政, 游戏, 娱乐 96 | ``` 97 | 98 | 这个子集可以在此下载:链接: https://pan.baidu.com/s/1hugrfRu 密码: qfud 99 | 100 | 数据集划分如下: 101 | 102 | - 训练集: 5000*10 103 | - 验证集: 500*10 104 | - 测试集: 1000*10 105 | 106 | 从原数据集生成子集的过程请参看`helper`下的两个脚本。其中,`copy_data.sh`用于从每个分类拷贝6500个文件,`cnews_group.py`用于将多个文件整合到一个文件中。执行该文件后,得到三个数据文件: 107 | 108 | - cnews.train.txt: 训练集(50000条) 109 | - cnews.val.txt: 验证集(5000条) 110 | - cnews.test.txt: 测试集(10000条) 111 | 112 | ## 预处理 113 | 114 | `data/cnews_loader.py`为数据的预处理文件。 115 | 116 | - `read_file()`: 读取文件数据; 117 | - `build_vocab()`: 构建词汇表,使用字符级的表示,这一函数会将词汇表存储下来,避免每一次重复处理; 118 | - `read_vocab()`: 读取上一步存储的词汇表,转换为`{词:id}`表示; 119 | - `read_category()`: 将分类目录固定,转换为`{类别: id}`表示; 120 | - `to_words()`: 将一条由id表示的数据重新转换为文字; 121 | - `process_file()`: 将数据集从文字转换为固定长度的id序列表示; 122 | - `batch_iter()`: 为神经网络的训练准备经过shuffle的批次的数据。 123 | 124 | 经过数据预处理,数据的格式如下: 125 | 126 | | Data | Shape | Data | Shape | 127 | | :---------- | :---------- | :---------- | :---------- | 128 | | x_train | [50000, 600] | y_train | [50000, 10] | 129 | | x_val | [5000, 600] | y_val | [5000, 10] | 130 | | x_test | [10000, 600] | y_test | [10000, 10] | 131 | 132 | ## RNN循环神经网络 133 | 134 | ### 配置项 135 | 136 | RNN可配置的参数如下所示,在`rnn_model.py`中。 137 | 138 | ```python 139 | class TRNNConfig(object): 140 | """RNN配置参数""" 141 | 142 | # 模型参数 143 | embedding_dim = 768 # 词向量维度 144 | seq_length = 512 # 序列长度 145 | num_classes = 10 # 类别数 146 | vocab_size = 5000 # 词汇表达小 147 | 148 | num_layers = 1 # 隐藏层层数 149 | hidden_dim = 512 # 隐藏层神经元 150 | rnn = 'gru' # lstm 或 gru 151 | 152 | attention_dim = 512 153 | l2_reg_lambda = 0.01 154 | 155 | dropout_keep_prob = 0.5 # dropout保留比例 156 | learning_rate = 1e-3 # 学习率 157 | 158 | batch_size = 128 # 每批训练大小 159 | num_epochs = 10 # 总迭代轮次 160 | 161 | print_per_batch = 100 # 每多少轮输出一次结果 162 | save_per_batch = 10 # 每多少轮存入tensorboard 163 | ``` 164 | 165 | ### RNN模型 166 | 167 | 具体参看`rnn_model_bert.py`的实现。 168 | 169 | 大致结构如下: 170 | 171 | ![images/rnn_architecture](images/rnn_architecture.png) 172 | 173 | 174 | 175 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Han Xiao 4 | 5 | import argparse 6 | import sys 7 | 8 | from bert.extract_features import PoolingStrategy 9 | from service.server import BertServer 10 | 11 | 12 | def get_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('-model_dir', type=str, required=True, 15 | help='directory of a pretrained BERT model') 16 | parser.add_argument('-max_seq_len', type=int, default=128, 17 | help='maximum length of a sequence') 18 | parser.add_argument('-num_worker', type=int, default=1, 19 | help='number of server instances') 20 | parser.add_argument('-max_batch_size', type=int, default=256, 21 | help='maximum number of sequences handled by each worker') 22 | parser.add_argument('-port', '-port_in', '-port_data', type=int, default=5555, 23 | help='server port for receiving data from client') 24 | parser.add_argument('-port_out', '-port_result', type=int, default=5556, 25 | help='server port for outputting result to client') 26 | parser.add_argument('-pooling_layer', type=int, nargs='+', default=[-2], 27 | help='the encoder layer(s) that receives pooling. ' 28 | 'Give a list in order to concatenate several layers into 1.') 29 | parser.add_argument('-pooling_strategy', type=PoolingStrategy.from_string, 30 | default=PoolingStrategy.REDUCE_MEAN, choices=list(PoolingStrategy), 31 | help='the pooling strategy for generating encoding vectors') 32 | parser.add_argument('-gpu_memory_fraction', type=float, default=0.5, 33 | help='determines the fraction of the overall amount of memory ' 34 | 'that each visible GPU should be allocated per worker. ' 35 | 'Should be in range [0.0, 1.0]') 36 | args = parser.parse_args() 37 | param_str = '\n'.join(['%20s = %s' % (k, v) for k, v in sorted(vars(args).items())]) 38 | print('usage: %s\n%20s %s\n%s\n%s\n' % (' '.join(sys.argv), 'ARG', 'VALUE', '_' * 50, param_str)) 39 | return args 40 | 41 | 42 | if __name__ == '__main__': 43 | args = get_args() 44 | server = BertServer(args) 45 | server.start() 46 | server.join() 47 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | import sys 4 | import threading 5 | import time 6 | from collections import namedtuple 7 | 8 | from numpy import mean 9 | 10 | from bert.extract_features import PoolingStrategy 11 | from service.client import BertClient 12 | from service.server import BertServer 13 | 14 | PORT = 6666 15 | PORT_OUT = 6667 16 | 17 | 18 | def tprint(msg): 19 | """like print, but won't get newlines confused with multiple threads""" 20 | sys.stdout.write(msg + '\n') 21 | sys.stdout.flush() 22 | 23 | 24 | class BenchmarkClient(threading.Thread): 25 | def __init__(self, args): 26 | super().__init__() 27 | self.batch = [''.join(random.choices(string.ascii_uppercase + string.digits, 28 | k=args.max_seq_len)) for _ in range(args.client_batch_size)] 29 | 30 | self.num_repeat = args.num_repeat 31 | self.avg_time = 0 32 | 33 | def run(self): 34 | time_all = [] 35 | bc = BertClient(port=PORT, port_out=PORT_OUT, show_server_config=False) 36 | for _ in range(self.num_repeat): 37 | start_t = time.perf_counter() 38 | bc.encode(self.batch) 39 | time_all.append(time.perf_counter() - start_t) 40 | print(time_all) 41 | self.avg_time = mean(time_all) 42 | 43 | 44 | if __name__ == '__main__': 45 | common = { 46 | 'model_dir': '/data/cips/data/lab/data/model/chinese_L-12_H-768_A-12', 47 | 'num_worker': 1, 48 | 'num_repeat': 5, 49 | 'port': PORT, 50 | 'port_out': PORT_OUT, 51 | 'max_seq_len': 40, 52 | 'client_batch_size': 2048, 53 | 'max_batch_size': 256, 54 | 'num_client': 1, 55 | 'pooling_strategy': PoolingStrategy.REDUCE_MEAN, 56 | 'pooling_layer': [-2], 57 | 'gpu_memory_fraction': 0.5 58 | } 59 | experiments = { 60 | 'client_batch_size': [1, 4, 8, 16, 64, 256, 512, 1024, 2048, 4096], 61 | 'max_batch_size': [32, 64, 128, 256, 512], 62 | 'max_seq_len': [20, 40, 80, 160, 320], 63 | 'num_client': [2, 4, 8, 16, 32], 64 | 'pooling_layer': [[-j] for j in range(1, 13)] 65 | } 66 | 67 | fp = open('benchmark-%d.result' % common['num_worker'], 'w') 68 | for var_name, var_lst in experiments.items(): 69 | # set common args 70 | args = namedtuple('args_namedtuple', ','.join(common.keys())) 71 | for k, v in common.items(): 72 | setattr(args, k, v) 73 | 74 | avg_speed = [] 75 | for var in var_lst: 76 | # override exp args 77 | setattr(args, var_name, var) 78 | server = BertServer(args) 79 | server.start() 80 | 81 | # sleep until server is ready 82 | time.sleep(15) 83 | all_clients = [BenchmarkClient(args) for _ in range(args.num_client)] 84 | 85 | tprint('num_client: %d' % len(all_clients)) 86 | for bc in all_clients: 87 | bc.start() 88 | 89 | all_thread_speed = [] 90 | for bc in all_clients: 91 | bc.join() 92 | cur_speed = args.client_batch_size / bc.avg_time 93 | all_thread_speed.append(cur_speed) 94 | 95 | max_speed = int(max(all_thread_speed)) 96 | min_speed = int(min(all_thread_speed)) 97 | t_avg_speed = int(mean(all_thread_speed)) 98 | 99 | tprint('%s: %s\t%.3f\t%d/s' % (var_name, var, bc.avg_time, t_avg_speed)) 100 | tprint('max speed: %d\t min speed: %d' % (max_speed, min_speed)) 101 | avg_speed.append(t_avg_speed) 102 | server.close() 103 | 104 | fp.write('#### Speed wrt. `%s`\n\n' % var_name) 105 | fp.write('|`%s`|seqs/s|\n' % var_name) 106 | fp.write('|---|---|\n') 107 | for i, j in zip(var_lst, avg_speed): 108 | fp.write('|%s|%d|\n' % (i, j)) 109 | fp.flush() 110 | fp.close() 111 | -------------------------------------------------------------------------------- /bert-as-service-master.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/bert-as-service-master.zip -------------------------------------------------------------------------------- /bert/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /bert/extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import re 16 | from enum import Enum 17 | 18 | import tensorflow as tf 19 | from tensorflow.python.estimator.model_fn import EstimatorSpec 20 | 21 | from bert import tokenization, modeling 22 | 23 | 24 | class PoolingStrategy(Enum): 25 | NONE = 0 26 | REDUCE_MAX = 1 27 | REDUCE_MEAN = 2 28 | REDUCE_MEAN_MAX = 3 29 | FIRST_TOKEN = 4 # corresponds to [CLS] for single sequences 30 | LAST_TOKEN = 5 # corresponds to [SEP] for single sequences 31 | CLS_TOKEN = 4 # corresponds to the first token for single seq. 32 | SEP_TOKEN = 5 # corresponds to the last token for single seq. 33 | 34 | def __str__(self): 35 | return self.name 36 | 37 | @staticmethod 38 | def from_string(s): 39 | try: 40 | return PoolingStrategy[s] 41 | except KeyError: 42 | raise ValueError() 43 | 44 | 45 | class InputExample(object): 46 | 47 | def __init__(self, unique_id, text_a, text_b): 48 | self.unique_id = unique_id 49 | self.text_a = text_a 50 | self.text_b = text_b 51 | 52 | 53 | class InputFeatures(object): 54 | """A single set of features of data.""" 55 | 56 | def __init__(self, input_ids, input_mask, input_type_ids): 57 | # self.unique_id = unique_id 58 | # self.tokens = tokens 59 | self.input_ids = input_ids 60 | self.input_mask = input_mask 61 | self.input_type_ids = input_type_ids 62 | 63 | 64 | def model_fn_builder(bert_config, init_checkpoint, use_one_hot_embeddings=False, 65 | pooling_strategy=PoolingStrategy.REDUCE_MEAN, 66 | pooling_layer=-2): 67 | """Returns `model_fn` closure for TPUEstimator.""" 68 | 69 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 70 | """The `model_fn` for TPUEstimator.""" 71 | 72 | client_id = features["client_id"] 73 | input_ids = features["input_ids"] 74 | input_mask = features["input_mask"] 75 | input_type_ids = features["input_type_ids"] 76 | 77 | model = modeling.BertModel( 78 | config=bert_config, 79 | is_training=False, 80 | input_ids=input_ids, 81 | input_mask=input_mask, 82 | token_type_ids=input_type_ids, 83 | use_one_hot_embeddings=use_one_hot_embeddings) 84 | 85 | if mode != tf.estimator.ModeKeys.PREDICT: 86 | raise ValueError("Only PREDICT modes are supported: %s" % (mode)) 87 | 88 | tvars = tf.trainable_variables() 89 | (assignment_map, initialized_variable_names 90 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 91 | 92 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 93 | 94 | all_layers = [] 95 | if len(pooling_layer) == 1: 96 | encoder_layer = model.all_encoder_layers[pooling_layer[-1]] 97 | else: 98 | for layer in pooling_layer: 99 | all_layers.append(model.all_encoder_layers[layer]) 100 | encoder_layer = tf.concat(all_layers, -1) 101 | 102 | if pooling_strategy == PoolingStrategy.REDUCE_MEAN: 103 | pooled = tf.reduce_mean(encoder_layer, axis=1) 104 | elif pooling_strategy == PoolingStrategy.REDUCE_MAX: 105 | pooled = tf.reduce_max(encoder_layer, axis=1) 106 | elif pooling_strategy == PoolingStrategy.REDUCE_MEAN_MAX: 107 | pooled = tf.concat([tf.reduce_mean(encoder_layer, axis=1), 108 | tf.reduce_max(encoder_layer, axis=1)], axis=1) 109 | elif pooling_strategy == PoolingStrategy.FIRST_TOKEN or pooling_strategy == PoolingStrategy.CLS_TOKEN: 110 | pooled = tf.squeeze(encoder_layer[:, 0:1, :], axis=1) 111 | elif pooling_strategy == PoolingStrategy.LAST_TOKEN or pooling_strategy == PoolingStrategy.SEP_TOKEN: 112 | seq_len = tf.cast(tf.reduce_sum(input_mask, axis=1), tf.int32) 113 | rng = tf.range(0, tf.shape(seq_len)[0]) 114 | indexes = tf.stack([rng, seq_len - 1], 1) 115 | pooled = tf.gather_nd(encoder_layer, indexes) 116 | elif pooling_strategy == PoolingStrategy.NONE: 117 | pooled = encoder_layer 118 | else: 119 | raise NotImplementedError() 120 | 121 | predictions = { 122 | 'client_id': client_id, 123 | 'encodes': pooled 124 | } 125 | 126 | return EstimatorSpec(mode=mode, predictions=predictions) 127 | 128 | return model_fn 129 | 130 | 131 | def convert_lst_to_features(lst_str, seq_length, tokenizer): 132 | """Loads a data file into a list of `InputBatch`s.""" 133 | 134 | for (ex_index, example) in enumerate(read_examples(lst_str)): 135 | tokens_a = tokenizer.tokenize(example.text_a) 136 | 137 | tokens_b = None 138 | if example.text_b: 139 | tokens_b = tokenizer.tokenize(example.text_b) 140 | 141 | if tokens_b: 142 | # Modifies `tokens_a` and `tokens_b` in place so that the total 143 | # length is less than the specified length. 144 | # Account for [CLS], [SEP], [SEP] with "- 3" 145 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 146 | else: 147 | # Account for [CLS] and [SEP] with "- 2" 148 | if len(tokens_a) > seq_length - 2: 149 | tokens_a = tokens_a[0:(seq_length - 2)] 150 | 151 | # The convention in BERT is: 152 | # (a) For sequence pairs: 153 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 154 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 155 | # (b) For single sequences: 156 | # tokens: [CLS] the dog is hairy . [SEP] 157 | # type_ids: 0 0 0 0 0 0 0 158 | # 159 | # Where "type_ids" are used to indicate whether this is the first 160 | # sequence or the second sequence. The embedding vectors for `type=0` and 161 | # `type=1` were learned during pre-training and are added to the wordpiece 162 | # embedding vector (and position vector). This is not *strictly* necessary 163 | # since the [SEP] token unambiguously separates the sequences, but it makes 164 | # it easier for the model to learn the concept of sequences. 165 | # 166 | # For classification tasks, the first vector (corresponding to [CLS]) is 167 | # used as as the "sentence vector". Note that this only makes sense because 168 | # the entire model is fine-tuned. 169 | tokens = [] 170 | input_type_ids = [] 171 | tokens.append("[CLS]") 172 | input_type_ids.append(0) 173 | for token in tokens_a: 174 | tokens.append(token) 175 | input_type_ids.append(0) 176 | tokens.append("[SEP]") 177 | input_type_ids.append(0) 178 | 179 | if tokens_b: 180 | for token in tokens_b: 181 | tokens.append(token) 182 | input_type_ids.append(1) 183 | tokens.append("[SEP]") 184 | input_type_ids.append(1) 185 | 186 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 187 | 188 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 189 | # tokens are attended to. 190 | input_mask = [1] * len(input_ids) 191 | 192 | # Zero-pad up to the sequence length. 193 | while len(input_ids) < seq_length: 194 | input_ids.append(0) 195 | input_mask.append(0) 196 | input_type_ids.append(0) 197 | 198 | assert len(input_ids) == seq_length 199 | assert len(input_mask) == seq_length 200 | assert len(input_type_ids) == seq_length 201 | 202 | # if ex_index < 5: 203 | # tf.logging.info("*** Example ***") 204 | # tf.logging.info("unique_id: %s" % (example.unique_id)) 205 | # tf.logging.info("tokens: %s" % " ".join( 206 | # [tokenization.printable_text(x) for x in tokens])) 207 | # tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 208 | # tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 209 | # tf.logging.info( 210 | # "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 211 | 212 | yield InputFeatures( 213 | # unique_id=example.unique_id, 214 | # tokens=tokens, 215 | input_ids=input_ids, 216 | input_mask=input_mask, 217 | input_type_ids=input_type_ids) 218 | 219 | 220 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 221 | """Truncates a sequence pair in place to the maximum length.""" 222 | 223 | # This is a simple heuristic which will always truncate the longer sequence 224 | # one token at a time. This makes more sense than truncating an equal percent 225 | # of tokens from each, since if one sequence is very short then each token 226 | # that's truncated likely contains more information than a longer sequence. 227 | while True: 228 | total_length = len(tokens_a) + len(tokens_b) 229 | if total_length <= max_length: 230 | break 231 | if len(tokens_a) > len(tokens_b): 232 | tokens_a.pop() 233 | else: 234 | tokens_b.pop() 235 | 236 | 237 | def read_examples(lst_strs): 238 | """Read a list of `InputExample`s from a list of strings.""" 239 | unique_id = 0 240 | for ss in lst_strs: 241 | line = tokenization.convert_to_unicode(ss) 242 | if not line: 243 | continue 244 | line = line.strip() 245 | text_a = None 246 | text_b = None 247 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 248 | if m is None: 249 | text_a = line 250 | else: 251 | text_a = m.group(1) 252 | text_b = m.group(2) 253 | yield InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b) 254 | unique_id += 1 255 | -------------------------------------------------------------------------------- /bert/modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import collections 20 | import json 21 | import random 22 | import re 23 | 24 | import six 25 | import tensorflow as tf 26 | 27 | from bert import modeling 28 | 29 | 30 | class BertModelTest(tf.test.TestCase): 31 | class BertModelTester(object): 32 | 33 | def __init__(self, 34 | parent, 35 | batch_size=13, 36 | seq_length=7, 37 | is_training=True, 38 | use_input_mask=True, 39 | use_token_type_ids=True, 40 | vocab_size=99, 41 | hidden_size=32, 42 | num_hidden_layers=5, 43 | num_attention_heads=4, 44 | intermediate_size=37, 45 | hidden_act="gelu", 46 | hidden_dropout_prob=0.1, 47 | attention_probs_dropout_prob=0.1, 48 | max_position_embeddings=512, 49 | type_vocab_size=16, 50 | initializer_range=0.02, 51 | scope=None): 52 | self.parent = parent 53 | self.batch_size = batch_size 54 | self.seq_length = seq_length 55 | self.is_training = is_training 56 | self.use_input_mask = use_input_mask 57 | self.use_token_type_ids = use_token_type_ids 58 | self.vocab_size = vocab_size 59 | self.hidden_size = hidden_size 60 | self.num_hidden_layers = num_hidden_layers 61 | self.num_attention_heads = num_attention_heads 62 | self.intermediate_size = intermediate_size 63 | self.hidden_act = hidden_act 64 | self.hidden_dropout_prob = hidden_dropout_prob 65 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 66 | self.max_position_embeddings = max_position_embeddings 67 | self.type_vocab_size = type_vocab_size 68 | self.initializer_range = initializer_range 69 | self.scope = scope 70 | 71 | def create_model(self): 72 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], 73 | self.vocab_size) 74 | 75 | input_mask = None 76 | if self.use_input_mask: 77 | input_mask = BertModelTest.ids_tensor( 78 | [self.batch_size, self.seq_length], vocab_size=2) 79 | 80 | token_type_ids = None 81 | if self.use_token_type_ids: 82 | token_type_ids = BertModelTest.ids_tensor( 83 | [self.batch_size, self.seq_length], self.type_vocab_size) 84 | 85 | config = modeling.BertConfig( 86 | vocab_size=self.vocab_size, 87 | hidden_size=self.hidden_size, 88 | num_hidden_layers=self.num_hidden_layers, 89 | num_attention_heads=self.num_attention_heads, 90 | intermediate_size=self.intermediate_size, 91 | hidden_act=self.hidden_act, 92 | hidden_dropout_prob=self.hidden_dropout_prob, 93 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 94 | max_position_embeddings=self.max_position_embeddings, 95 | type_vocab_size=self.type_vocab_size, 96 | initializer_range=self.initializer_range) 97 | 98 | model = modeling.BertModel( 99 | config=config, 100 | is_training=self.is_training, 101 | input_ids=input_ids, 102 | input_mask=input_mask, 103 | token_type_ids=token_type_ids, 104 | scope=self.scope) 105 | 106 | outputs = { 107 | "embedding_output": model.get_embedding_output(), 108 | "sequence_output": model.get_sequence_output(), 109 | "pooled_output": model.get_pooled_output(), 110 | "all_encoder_layers": model.get_all_encoder_layers(), 111 | } 112 | return outputs 113 | 114 | def check_output(self, result): 115 | self.parent.assertAllEqual( 116 | result["embedding_output"].shape, 117 | [self.batch_size, self.seq_length, self.hidden_size]) 118 | 119 | self.parent.assertAllEqual( 120 | result["sequence_output"].shape, 121 | [self.batch_size, self.seq_length, self.hidden_size]) 122 | 123 | self.parent.assertAllEqual(result["pooled_output"].shape, 124 | [self.batch_size, self.hidden_size]) 125 | 126 | def test_default(self): 127 | self.run_tester(BertModelTest.BertModelTester(self)) 128 | 129 | def test_config_to_json_string(self): 130 | config = modeling.BertConfig(vocab_size=99, hidden_size=37) 131 | obj = json.loads(config.to_json_string()) 132 | self.assertEqual(obj["vocab_size"], 99) 133 | self.assertEqual(obj["hidden_size"], 37) 134 | 135 | def run_tester(self, tester): 136 | with self.test_session() as sess: 137 | ops = tester.create_model() 138 | init_op = tf.group(tf.global_variables_initializer(), 139 | tf.local_variables_initializer()) 140 | sess.run(init_op) 141 | output_result = sess.run(ops) 142 | tester.check_output(output_result) 143 | 144 | self.assert_all_tensors_reachable(sess, [init_op, ops]) 145 | 146 | @classmethod 147 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 148 | """Creates a random int32 tensor of the shape within the vocab size.""" 149 | if rng is None: 150 | rng = random.Random() 151 | 152 | total_dims = 1 153 | for dim in shape: 154 | total_dims *= dim 155 | 156 | values = [] 157 | for _ in range(total_dims): 158 | values.append(rng.randint(0, vocab_size - 1)) 159 | 160 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name) 161 | 162 | def assert_all_tensors_reachable(self, sess, outputs): 163 | """Checks that all the tensors in the graph are reachable from outputs.""" 164 | graph = sess.graph 165 | 166 | ignore_strings = [ 167 | "^.*/dilation_rate$", 168 | "^.*/Tensordot/concat$", 169 | "^.*/Tensordot/concat/axis$", 170 | "^testing/.*$", 171 | ] 172 | 173 | ignore_regexes = [re.compile(x) for x in ignore_strings] 174 | 175 | unreachable = self.get_unreachable_ops(graph, outputs) 176 | filtered_unreachable = [] 177 | for x in unreachable: 178 | do_ignore = False 179 | for r in ignore_regexes: 180 | m = r.match(x.name) 181 | if m is not None: 182 | do_ignore = True 183 | if do_ignore: 184 | continue 185 | filtered_unreachable.append(x) 186 | unreachable = filtered_unreachable 187 | 188 | self.assertEqual( 189 | len(unreachable), 0, "The following ops are unreachable: %s" % 190 | (" ".join([x.name for x in unreachable]))) 191 | 192 | @classmethod 193 | def get_unreachable_ops(cls, graph, outputs): 194 | """Finds all of the tensors in graph that are unreachable from outputs.""" 195 | outputs = cls.flatten_recursive(outputs) 196 | output_to_op = collections.defaultdict(list) 197 | op_to_all = collections.defaultdict(list) 198 | assign_out_to_in = collections.defaultdict(list) 199 | 200 | for op in graph.get_operations(): 201 | for x in op.inputs: 202 | op_to_all[op.name].append(x.name) 203 | for y in op.outputs: 204 | output_to_op[y.name].append(op.name) 205 | op_to_all[op.name].append(y.name) 206 | if str(op.type) == "Assign": 207 | for y in op.outputs: 208 | for x in op.inputs: 209 | assign_out_to_in[y.name].append(x.name) 210 | 211 | assign_groups = collections.defaultdict(list) 212 | for out_name in assign_out_to_in.keys(): 213 | name_group = assign_out_to_in[out_name] 214 | for n1 in name_group: 215 | assign_groups[n1].append(out_name) 216 | for n2 in name_group: 217 | if n1 != n2: 218 | assign_groups[n1].append(n2) 219 | 220 | seen_tensors = {} 221 | stack = [x.name for x in outputs] 222 | while stack: 223 | name = stack.pop() 224 | if name in seen_tensors: 225 | continue 226 | seen_tensors[name] = True 227 | 228 | if name in output_to_op: 229 | for op_name in output_to_op[name]: 230 | if op_name in op_to_all: 231 | for input_name in op_to_all[op_name]: 232 | if input_name not in stack: 233 | stack.append(input_name) 234 | 235 | expanded_names = [] 236 | if name in assign_groups: 237 | for assign_name in assign_groups[name]: 238 | expanded_names.append(assign_name) 239 | 240 | for expanded_name in expanded_names: 241 | if expanded_name not in stack: 242 | stack.append(expanded_name) 243 | 244 | unreachable_ops = [] 245 | for op in graph.get_operations(): 246 | is_unreachable = False 247 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs] 248 | for name in all_names: 249 | if name not in seen_tensors: 250 | is_unreachable = True 251 | if is_unreachable: 252 | unreachable_ops.append(op) 253 | return unreachable_ops 254 | 255 | @classmethod 256 | def flatten_recursive(cls, item): 257 | """Flattens (potentially nested) a tuple/dictionary/list to a list.""" 258 | output = [] 259 | if isinstance(item, list): 260 | output.extend(item) 261 | elif isinstance(item, tuple): 262 | output.extend(list(item)) 263 | elif isinstance(item, dict): 264 | for (_, v) in six.iteritems(item): 265 | output.append(v) 266 | else: 267 | return [item] 268 | 269 | flat_output = [] 270 | for x in output: 271 | flat_output.extend(cls.flatten_recursive(x)) 272 | return flat_output 273 | 274 | 275 | if __name__ == "__main__": 276 | tf.test.main() 277 | -------------------------------------------------------------------------------- /bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import re 20 | 21 | import tensorflow as tf 22 | 23 | 24 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 25 | """Creates an optimizer training op.""" 26 | global_step = tf.train.get_or_create_global_step() 27 | 28 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 29 | 30 | # Implements linear decay of the learning rate. 31 | learning_rate = tf.train.polynomial_decay( 32 | learning_rate, 33 | global_step, 34 | num_train_steps, 35 | end_learning_rate=0.0, 36 | power=1.0, 37 | cycle=False) 38 | 39 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 40 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 41 | if num_warmup_steps: 42 | global_steps_int = tf.cast(global_step, tf.int32) 43 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 44 | 45 | global_steps_float = tf.cast(global_steps_int, tf.float32) 46 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 47 | 48 | warmup_percent_done = global_steps_float / warmup_steps_float 49 | warmup_learning_rate = init_lr * warmup_percent_done 50 | 51 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 52 | learning_rate = ( 53 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 54 | 55 | # It is recommended that you use this optimizer for fine tuning, since this 56 | # is how the model was trained (note that the Adam m/v variables are NOT 57 | # loaded from init_checkpoint.) 58 | optimizer = AdamWeightDecayOptimizer( 59 | learning_rate=learning_rate, 60 | weight_decay_rate=0.01, 61 | beta_1=0.9, 62 | beta_2=0.999, 63 | epsilon=1e-6, 64 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 65 | 66 | if use_tpu: 67 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 68 | 69 | tvars = tf.trainable_variables() 70 | grads = tf.gradients(loss, tvars) 71 | 72 | # This is how the model was pre-trained. 73 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 74 | 75 | train_op = optimizer.apply_gradients( 76 | zip(grads, tvars), global_step=global_step) 77 | 78 | new_global_step = global_step + 1 79 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 80 | return train_op 81 | 82 | 83 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 84 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 85 | 86 | def __init__(self, 87 | learning_rate, 88 | weight_decay_rate=0.0, 89 | beta_1=0.9, 90 | beta_2=0.999, 91 | epsilon=1e-6, 92 | exclude_from_weight_decay=None, 93 | name="AdamWeightDecayOptimizer"): 94 | """Constructs a AdamWeightDecayOptimizer.""" 95 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 96 | 97 | self.learning_rate = learning_rate 98 | self.weight_decay_rate = weight_decay_rate 99 | self.beta_1 = beta_1 100 | self.beta_2 = beta_2 101 | self.epsilon = epsilon 102 | self.exclude_from_weight_decay = exclude_from_weight_decay 103 | 104 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 105 | """See base class.""" 106 | assignments = [] 107 | for (grad, param) in grads_and_vars: 108 | if grad is None or param is None: 109 | continue 110 | 111 | param_name = self._get_variable_name(param.name) 112 | 113 | m = tf.get_variable( 114 | name=param_name + "/adam_m", 115 | shape=param.shape.as_list(), 116 | dtype=tf.float32, 117 | trainable=False, 118 | initializer=tf.zeros_initializer()) 119 | v = tf.get_variable( 120 | name=param_name + "/adam_v", 121 | shape=param.shape.as_list(), 122 | dtype=tf.float32, 123 | trainable=False, 124 | initializer=tf.zeros_initializer()) 125 | 126 | # Standard Adam update. 127 | next_m = ( 128 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 129 | next_v = ( 130 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 131 | tf.square(grad))) 132 | 133 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 134 | 135 | # Just adding the square of the weights to the loss function is *not* 136 | # the correct way of using L2 regularization/weight decay with Adam, 137 | # since that will interact with the m and v parameters in strange ways. 138 | # 139 | # Instead we want ot decay the weights in a manner that doesn't interact 140 | # with the m/v parameters. This is equivalent to adding the square 141 | # of the weights to the loss with plain (non-momentum) SGD. 142 | if self._do_use_weight_decay(param_name): 143 | update += self.weight_decay_rate * param 144 | 145 | update_with_lr = self.learning_rate * update 146 | 147 | next_param = param - update_with_lr 148 | 149 | assignments.extend( 150 | [param.assign(next_param), 151 | m.assign(next_m), 152 | v.assign(next_v)]) 153 | return tf.group(*assignments, name=name) 154 | 155 | def _do_use_weight_decay(self, param_name): 156 | """Whether to use L2 weight decay for `param_name`.""" 157 | if not self.weight_decay_rate: 158 | return False 159 | if self.exclude_from_weight_decay: 160 | for r in self.exclude_from_weight_decay: 161 | if re.search(r, param_name) is not None: 162 | return False 163 | return True 164 | 165 | def _get_variable_name(self, param_name): 166 | """Get the variable name from the tensor name.""" 167 | m = re.match("^(.*):\\d+$", param_name) 168 | if m is not None: 169 | param_name = m.group(1) 170 | return param_name 171 | -------------------------------------------------------------------------------- /bert/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | 21 | from bert import optimization 22 | 23 | 24 | class OptimizationTest(tf.test.TestCase): 25 | 26 | def test_adam(self): 27 | with self.test_session() as sess: 28 | w = tf.get_variable( 29 | "w", 30 | shape=[3], 31 | initializer=tf.constant_initializer([0.1, -0.2, -0.1])) 32 | x = tf.constant([0.4, 0.2, -0.5]) 33 | loss = tf.reduce_mean(tf.square(x - w)) 34 | tvars = tf.trainable_variables() 35 | grads = tf.gradients(loss, tvars) 36 | global_step = tf.train.get_or_create_global_step() 37 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2) 38 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step) 39 | init_op = tf.group(tf.global_variables_initializer(), 40 | tf.local_variables_initializer()) 41 | sess.run(init_op) 42 | for _ in range(100): 43 | sess.run(train_op) 44 | w_np = sess.run(w) 45 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) 46 | 47 | 48 | if __name__ == "__main__": 49 | tf.test.main() 50 | -------------------------------------------------------------------------------- /bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def convert_to_unicode(text): 29 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 30 | if six.PY3: 31 | if isinstance(text, str): 32 | return text 33 | elif isinstance(text, bytes): 34 | return text.decode("utf-8", "ignore") 35 | else: 36 | raise ValueError("Unsupported string type: %s" % (type(text))) 37 | elif six.PY2: 38 | if isinstance(text, str): 39 | return text.decode("utf-8", "ignore") 40 | elif isinstance(text, unicode): 41 | return text 42 | else: 43 | raise ValueError("Unsupported string type: %s" % (type(text))) 44 | else: 45 | raise ValueError("Not running on Python2 or Python 3?") 46 | 47 | 48 | def printable_text(text): 49 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 50 | 51 | # These functions want `str` for both Python2 and Python3, but in one case 52 | # it's a Unicode string and in the other it's a byte string. 53 | if six.PY3: 54 | if isinstance(text, str): 55 | return text 56 | elif isinstance(text, bytes): 57 | return text.decode("utf-8", "ignore") 58 | else: 59 | raise ValueError("Unsupported string type: %s" % (type(text))) 60 | elif six.PY2: 61 | if isinstance(text, str): 62 | return text 63 | elif isinstance(text, unicode): 64 | return text.encode("utf-8") 65 | else: 66 | raise ValueError("Unsupported string type: %s" % (type(text))) 67 | else: 68 | raise ValueError("Not running on Python2 or Python 3?") 69 | 70 | 71 | def load_vocab(vocab_file): 72 | """Loads a vocabulary file into a dictionary.""" 73 | vocab = collections.OrderedDict() 74 | index = 0 75 | with tf.gfile.GFile(vocab_file, "r") as reader: 76 | while True: 77 | token = convert_to_unicode(reader.readline()) 78 | if not token: 79 | break 80 | token = token.strip() 81 | vocab[token] = index 82 | index += 1 83 | return vocab 84 | 85 | 86 | def convert_by_vocab(vocab, items): 87 | """Converts a sequence of [tokens|ids] using the vocab.""" 88 | output = [] 89 | for item in items: 90 | output.append(vocab[item]) 91 | return output 92 | 93 | 94 | def convert_tokens_to_ids(vocab, tokens): 95 | return convert_by_vocab(vocab, tokens) 96 | 97 | 98 | def convert_ids_to_tokens(inv_vocab, ids): 99 | return convert_by_vocab(inv_vocab, ids) 100 | 101 | 102 | def whitespace_tokenize(text): 103 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 104 | text = text.strip() 105 | if not text: 106 | return [] 107 | tokens = text.split() 108 | return tokens 109 | 110 | 111 | class FullTokenizer(object): 112 | """Runs end-to-end tokenziation.""" 113 | 114 | def __init__(self, vocab_file, do_lower_case=True): 115 | self.vocab = load_vocab(vocab_file) 116 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 117 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 118 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 119 | 120 | def tokenize(self, text): 121 | split_tokens = [] 122 | for token in self.basic_tokenizer.tokenize(text): 123 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 124 | split_tokens.append(sub_token) 125 | 126 | return split_tokens 127 | 128 | def convert_tokens_to_ids(self, tokens): 129 | return convert_by_vocab(self.vocab, tokens) 130 | 131 | def convert_ids_to_tokens(self, ids): 132 | return convert_by_vocab(self.inv_vocab, ids) 133 | 134 | 135 | class BasicTokenizer(object): 136 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 137 | 138 | def __init__(self, do_lower_case=True): 139 | """Constructs a BasicTokenizer. 140 | 141 | Args: 142 | do_lower_case: Whether to lower case the input. 143 | """ 144 | self.do_lower_case = do_lower_case 145 | 146 | def tokenize(self, text): 147 | """Tokenizes a piece of text.""" 148 | text = convert_to_unicode(text) 149 | text = self._clean_text(text) 150 | 151 | # This was added on November 1st, 2018 for the multilingual and Chinese 152 | # models. This is also applied to the English models now, but it doesn't 153 | # matter since the English models were not trained on any Chinese data 154 | # and generally don't have any Chinese data in them (there are Chinese 155 | # characters in the vocabulary because Wikipedia does have some Chinese 156 | # words in the English Wikipedia.). 157 | text = self._tokenize_chinese_chars(text) 158 | 159 | orig_tokens = whitespace_tokenize(text) 160 | split_tokens = [] 161 | for token in orig_tokens: 162 | if self.do_lower_case: 163 | token = token.lower() 164 | token = self._run_strip_accents(token) 165 | split_tokens.extend(self._run_split_on_punc(token)) 166 | 167 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 168 | return output_tokens 169 | 170 | def _run_strip_accents(self, text): 171 | """Strips accents from a piece of text.""" 172 | text = unicodedata.normalize("NFD", text) 173 | output = [] 174 | for char in text: 175 | cat = unicodedata.category(char) 176 | if cat == "Mn": 177 | continue 178 | output.append(char) 179 | return "".join(output) 180 | 181 | def _run_split_on_punc(self, text): 182 | """Splits punctuation on a piece of text.""" 183 | chars = list(text) 184 | i = 0 185 | start_new_word = True 186 | output = [] 187 | while i < len(chars): 188 | char = chars[i] 189 | if _is_punctuation(char): 190 | output.append([char]) 191 | start_new_word = True 192 | else: 193 | if start_new_word: 194 | output.append([]) 195 | start_new_word = False 196 | output[-1].append(char) 197 | i += 1 198 | 199 | return ["".join(x) for x in output] 200 | 201 | def _tokenize_chinese_chars(self, text): 202 | """Adds whitespace around any CJK character.""" 203 | output = [] 204 | for char in text: 205 | cp = ord(char) 206 | if self._is_chinese_char(cp): 207 | output.append(" ") 208 | output.append(char) 209 | output.append(" ") 210 | else: 211 | output.append(char) 212 | return "".join(output) 213 | 214 | def _is_chinese_char(self, cp): 215 | """Checks whether CP is the codepoint of a CJK character.""" 216 | # This defines a "chinese character" as anything in the CJK Unicode block: 217 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 218 | # 219 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 220 | # despite its name. The modern Korean Hangul alphabet is a different block, 221 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 222 | # space-separated words, so they are not treated specially and handled 223 | # like the all of the other languages. 224 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 225 | (cp >= 0x3400 and cp <= 0x4DBF) or # 226 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 227 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 228 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 229 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 230 | (cp >= 0xF900 and cp <= 0xFAFF) or # 231 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 232 | return True 233 | 234 | return False 235 | 236 | def _clean_text(self, text): 237 | """Performs invalid character removal and whitespace cleanup on text.""" 238 | output = [] 239 | for char in text: 240 | cp = ord(char) 241 | if cp == 0 or cp == 0xfffd or _is_control(char): 242 | continue 243 | if _is_whitespace(char): 244 | output.append(" ") 245 | else: 246 | output.append(char) 247 | return "".join(output) 248 | 249 | 250 | class WordpieceTokenizer(object): 251 | """Runs WordPiece tokenziation.""" 252 | 253 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 254 | self.vocab = vocab 255 | self.unk_token = unk_token 256 | self.max_input_chars_per_word = max_input_chars_per_word 257 | 258 | def tokenize(self, text): 259 | """Tokenizes a piece of text into its word pieces. 260 | 261 | This uses a greedy longest-match-first algorithm to perform tokenization 262 | using the given vocabulary. 263 | 264 | For example: 265 | input = "unaffable" 266 | output = ["un", "##aff", "##able"] 267 | 268 | Args: 269 | text: A single token or whitespace separated tokens. This should have 270 | already been passed through `BasicTokenizer. 271 | 272 | Returns: 273 | A list of wordpiece tokens. 274 | """ 275 | 276 | text = convert_to_unicode(text) 277 | 278 | output_tokens = [] 279 | for token in whitespace_tokenize(text): 280 | chars = list(token) 281 | if len(chars) > self.max_input_chars_per_word: 282 | output_tokens.append(self.unk_token) 283 | continue 284 | 285 | is_bad = False 286 | start = 0 287 | sub_tokens = [] 288 | while start < len(chars): 289 | end = len(chars) 290 | cur_substr = None 291 | while start < end: 292 | substr = "".join(chars[start:end]) 293 | if start > 0: 294 | substr = "##" + substr 295 | if substr in self.vocab: 296 | cur_substr = substr 297 | break 298 | end -= 1 299 | if cur_substr is None: 300 | is_bad = True 301 | break 302 | sub_tokens.append(cur_substr) 303 | start = end 304 | 305 | if is_bad: 306 | output_tokens.append(self.unk_token) 307 | else: 308 | output_tokens.extend(sub_tokens) 309 | return output_tokens 310 | 311 | 312 | def _is_whitespace(char): 313 | """Checks whether `chars` is a whitespace character.""" 314 | # \t, \n, and \r are technically contorl characters but we treat them 315 | # as whitespace since they are generally considered as such. 316 | if char == " " or char == "\t" or char == "\n" or char == "\r": 317 | return True 318 | cat = unicodedata.category(char) 319 | if cat == "Zs": 320 | return True 321 | return False 322 | 323 | 324 | def _is_control(char): 325 | """Checks whether `chars` is a control character.""" 326 | # These are technically control characters but we count them as whitespace 327 | # characters. 328 | if char == "\t" or char == "\n" or char == "\r": 329 | return False 330 | cat = unicodedata.category(char) 331 | if cat.startswith("C"): 332 | return True 333 | return False 334 | 335 | 336 | def _is_punctuation(char): 337 | """Checks whether `chars` is a punctuation character.""" 338 | cp = ord(char) 339 | # We treat all non-letter/number ASCII as punctuation. 340 | # Characters such as "^", "$", and "`" are not in the Unicode 341 | # Punctuation class but we treat them as punctuation anyways, for 342 | # consistency. 343 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 344 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 345 | return True 346 | cat = unicodedata.category(char) 347 | if cat.startswith("P"): 348 | return True 349 | return False 350 | -------------------------------------------------------------------------------- /bert/tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import tempfile 21 | 22 | import tensorflow as tf 23 | 24 | from bert import tokenization 25 | 26 | 27 | class TokenizationTest(tf.test.TestCase): 28 | 29 | def test_full_tokenizer(self): 30 | vocab_tokens = [ 31 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 32 | "##ing", "," 33 | ] 34 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: 35 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 36 | 37 | vocab_file = vocab_writer.name 38 | 39 | tokenizer = tokenization.FullTokenizer(vocab_file) 40 | os.unlink(vocab_file) 41 | 42 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 43 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 44 | 45 | self.assertAllEqual( 46 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 47 | 48 | def test_chinese(self): 49 | tokenizer = tokenization.BasicTokenizer() 50 | 51 | self.assertAllEqual( 52 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 53 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 54 | 55 | def test_basic_tokenizer_lower(self): 56 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 57 | 58 | self.assertAllEqual( 59 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 60 | ["hello", "!", "how", "are", "you", "?"]) 61 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 62 | 63 | def test_basic_tokenizer_no_lower(self): 64 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 65 | 66 | self.assertAllEqual( 67 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 68 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 69 | 70 | def test_wordpiece_tokenizer(self): 71 | vocab_tokens = [ 72 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 73 | "##ing" 74 | ] 75 | 76 | vocab = {} 77 | for (i, token) in enumerate(vocab_tokens): 78 | vocab[token] = i 79 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 80 | 81 | self.assertAllEqual(tokenizer.tokenize(""), []) 82 | 83 | self.assertAllEqual( 84 | tokenizer.tokenize("unwanted running"), 85 | ["un", "##want", "##ed", "runn", "##ing"]) 86 | 87 | self.assertAllEqual( 88 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 89 | 90 | def test_convert_tokens_to_ids(self): 91 | vocab_tokens = [ 92 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 93 | "##ing" 94 | ] 95 | 96 | vocab = {} 97 | for (i, token) in enumerate(vocab_tokens): 98 | vocab[token] = i 99 | 100 | self.assertAllEqual( 101 | tokenization.convert_tokens_to_ids( 102 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 103 | 104 | def test_is_whitespace(self): 105 | self.assertTrue(tokenization._is_whitespace(u" ")) 106 | self.assertTrue(tokenization._is_whitespace(u"\t")) 107 | self.assertTrue(tokenization._is_whitespace(u"\r")) 108 | self.assertTrue(tokenization._is_whitespace(u"\n")) 109 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 110 | 111 | self.assertFalse(tokenization._is_whitespace(u"A")) 112 | self.assertFalse(tokenization._is_whitespace(u"-")) 113 | 114 | def test_is_control(self): 115 | self.assertTrue(tokenization._is_control(u"\u0005")) 116 | 117 | self.assertFalse(tokenization._is_control(u"A")) 118 | self.assertFalse(tokenization._is_control(u" ")) 119 | self.assertFalse(tokenization._is_control(u"\t")) 120 | self.assertFalse(tokenization._is_control(u"\r")) 121 | 122 | def test_is_punctuation(self): 123 | self.assertTrue(tokenization._is_punctuation(u"-")) 124 | self.assertTrue(tokenization._is_punctuation(u"$")) 125 | self.assertTrue(tokenization._is_punctuation(u"`")) 126 | self.assertTrue(tokenization._is_punctuation(u".")) 127 | 128 | self.assertFalse(tokenization._is_punctuation(u"A")) 129 | self.assertFalse(tokenization._is_punctuation(u" ")) 130 | 131 | 132 | if __name__ == "__main__": 133 | tf.test.main() 134 | -------------------------------------------------------------------------------- /bert_serving/__init__.py: -------------------------------------------------------------------------------- 1 | __path__ = __import__('pkgutil').extend_path(__path__, __name__) 2 | -------------------------------------------------------------------------------- /bert_serving/client/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Han Xiao 4 | 5 | import sys 6 | import threading 7 | import time 8 | import uuid 9 | from collections import namedtuple 10 | 11 | import numpy as np 12 | import zmq 13 | from zmq.utils import jsonapi 14 | 15 | __all__ = ['__version__', 'BertClient'] 16 | 17 | # in the future client version must match with server version 18 | __version__ = '1.6.2' 19 | 20 | if sys.version_info >= (3, 0): 21 | _py2 = False 22 | _str = str 23 | _buffer = memoryview 24 | _unicode = lambda x: x 25 | else: 26 | # make it compatible for py2 27 | _py2 = True 28 | _str = basestring 29 | _buffer = buffer 30 | _unicode = lambda x: [BertClient._force_to_unicode(y) for y in x] 31 | 32 | Response = namedtuple('Response', ['id', 'content']) 33 | 34 | 35 | class BertClient: 36 | def __init__(self, ip='localhost', port=5555, port_out=5556, 37 | output_fmt='ndarray', show_server_config=False, 38 | identity=None, check_version=True, check_length=True, 39 | timeout=5000): 40 | """ A client object connected to a BertServer 41 | 42 | Create a BertClient that connects to a BertServer. 43 | Note, server must be ready at the moment you are calling this function. 44 | If you are not sure whether the server is ready, then please set `check_version=False` 45 | 46 | You can also use it as a context manager: 47 | 48 | .. highlight:: python 49 | .. code-block:: python 50 | 51 | with BertClient() as bc: 52 | bc.encode(...) 53 | 54 | # bc is automatically closed out of the context 55 | 56 | :type timeout: int 57 | :type check_version: bool 58 | :type check_length: bool 59 | :type identity: str 60 | :type show_server_config: bool 61 | :type output_fmt: str 62 | :type port_out: int 63 | :type port: int 64 | :type ip: str 65 | :param ip: the ip address of the server 66 | :param port: port for pushing data from client to server, must be consistent with the server side config 67 | :param port_out: port for publishing results from server to client, must be consistent with the server side config 68 | :param output_fmt: the output format of the sentence encodes, either in numpy array or python List[List[float]] (ndarray/list) 69 | :param show_server_config: whether to show server configs when first connected 70 | :param identity: the UUID of this client 71 | :param check_version: check if server has the same version as client, raise AttributeError if not the same 72 | :param check_length: check if server `max_seq_len` is less than the sentence length before sent 73 | :param timeout: set the timeout (milliseconds) for receive operation on the client 74 | """ 75 | 76 | self.context = zmq.Context() 77 | self.sender = self.context.socket(zmq.PUSH) 78 | self.identity = identity or str(uuid.uuid4()).encode('ascii') 79 | self.sender.connect('tcp://%s:%d' % (ip, port)) 80 | 81 | self.receiver = self.context.socket(zmq.SUB) 82 | self.receiver.setsockopt(zmq.SUBSCRIBE, self.identity) 83 | self.receiver.connect('tcp://%s:%d' % (ip, port_out)) 84 | 85 | self.request_id = 0 86 | self.timeout = timeout 87 | self.pending_request = set() 88 | 89 | if output_fmt == 'ndarray': 90 | self.formatter = lambda x: x 91 | elif output_fmt == 'list': 92 | self.formatter = lambda x: x.tolist() 93 | else: 94 | raise AttributeError('"output_fmt" must be "ndarray" or "list"') 95 | 96 | self.output_fmt = output_fmt 97 | self.port = port 98 | self.port_out = port_out 99 | self.ip = ip 100 | self.length_limit = 0 101 | 102 | if check_version or show_server_config or check_length: 103 | s_status = self.server_status 104 | 105 | if check_version and s_status['server_version'] != self.status['client_version']: 106 | raise AttributeError('version mismatch! server version is %s but client version is %s!\n' 107 | 'consider "pip install -U bert-serving-server bert-serving-client"\n' 108 | 'or disable version-check by "BertClient(check_version=False)"' % ( 109 | s_status['server_version'], self.status['client_version'])) 110 | 111 | if show_server_config: 112 | self._print_dict(s_status, 'server config:') 113 | 114 | if check_length: 115 | self.length_limit = int(s_status['max_seq_len']) 116 | 117 | def close(self): 118 | """ 119 | Gently close all connections of the client. If you are using BertClient as context manager, 120 | then this is not necessary. 121 | 122 | """ 123 | self.sender.close() 124 | self.receiver.close() 125 | self.context.term() 126 | 127 | def _send(self, msg, msg_len=0): 128 | self.sender.send_multipart([self.identity, msg, b'%d' % self.request_id, b'%d' % msg_len]) 129 | self.pending_request.add(self.request_id) 130 | self.request_id += 1 131 | 132 | def _recv(self): 133 | response = self.receiver.recv_multipart() 134 | request_id = int(response[-1]) 135 | self.pending_request.remove(request_id) 136 | return Response(request_id, response) 137 | 138 | def _recv_ndarray(self): 139 | request_id, response = self._recv() 140 | arr_info, arr_val = jsonapi.loads(response[1]), response[2] 141 | X = np.frombuffer(_buffer(arr_val), dtype=str(arr_info['dtype'])) 142 | return Response(request_id, self.formatter(X.reshape(arr_info['shape']))) 143 | 144 | @property 145 | def status(self): 146 | """ 147 | Get the status of this BertClient instance 148 | 149 | :rtype: dict[str, str] 150 | :return: a dictionary contains the status of this BertClient instance 151 | 152 | """ 153 | return { 154 | 'identity': self.identity, 155 | 'num_request': self.request_id, 156 | 'num_pending_request': len(self.pending_request), 157 | 'pending_request': self.pending_request, 158 | 'output_fmt': self.output_fmt, 159 | 'port': self.port, 160 | 'port_out': self.port_out, 161 | 'server_ip': self.ip, 162 | 'client_version': __version__, 163 | 'timeout': self.timeout 164 | } 165 | 166 | @property 167 | def server_status(self): 168 | """ 169 | Get the current status of the server connected to this client 170 | 171 | :return: a dictionary contains the current status of the server connected to this client 172 | :rtype: dict[str, str] 173 | 174 | """ 175 | try: 176 | self.receiver.setsockopt(zmq.RCVTIMEO, self.timeout) 177 | self._send(b'SHOW_CONFIG') 178 | return jsonapi.loads(self._recv().content[1]) 179 | except zmq.error.Again as _e: 180 | t_e = TimeoutError( 181 | 'no response from the server (with "timeout"=%d ms), ' 182 | 'is the server on-line? is network broken? are "port" and "port_out" correct?' % self.timeout) 183 | if _py2: 184 | raise t_e 185 | else: 186 | raise t_e from _e 187 | finally: 188 | self.receiver.setsockopt(zmq.RCVTIMEO, -1) 189 | 190 | def encode(self, texts, blocking=True, is_tokenized=False): 191 | """ Encode a list of strings to a list of vectors 192 | 193 | `texts` should be a list of strings, each of which represents a sentence. 194 | If `is_tokenized` is set to True, then `texts` should be list[list[str]], 195 | outer list represents sentence and inner list represent tokens in the sentence. 196 | Note that if `blocking` is set to False, then you need to fetch the result manually afterwards. 197 | 198 | .. highlight:: python 199 | .. code-block:: python 200 | 201 | with BertClient() as bc: 202 | # encode untokenized sentences 203 | bc.encode(['First do it', 204 | 'then do it right', 205 | 'then do it better']) 206 | 207 | # encode tokenized sentences 208 | bc.encode([['First', 'do', 'it'], 209 | ['then', 'do', 'it', 'right'], 210 | ['then', 'do', 'it', 'better']], is_tokenized=True) 211 | 212 | :type is_tokenized: bool 213 | :type blocking: bool 214 | :type texts: list[str] or list[list[str]] 215 | :param is_tokenized: whether the input texts is already tokenized 216 | :param texts: list of sentence to be encoded. Larger list for better efficiency. 217 | :param blocking: wait until the encoded result is returned from the server. If false, will immediately return. 218 | :return: encoded sentence/token-level embeddings, rows correspond to sentences 219 | :rtype: numpy.ndarray or list[list[float]] 220 | 221 | """ 222 | if is_tokenized: 223 | self._check_input_lst_lst_str(texts) 224 | else: 225 | self._check_input_lst_str(texts) 226 | 227 | if self.length_limit and not self._check_length(texts, self.length_limit, is_tokenized): 228 | print('some of your sentences have more tokens than "max_seq_len=%d" set on the server, ' 229 | 'as consequence you may get less-accurate or truncated embeddings.\n' 230 | 'here is what you can do:\n' 231 | '- disable the length-check by create a new "BertClient(check_length=False)" ' 232 | 'when you just want to ignore this warning\n' 233 | '- or, start a new server with a larger "max_seq_len"' % self.length_limit) 234 | 235 | texts = _unicode(texts) 236 | self._send(jsonapi.dumps(texts), len(texts)) 237 | return self._recv_ndarray().content if blocking else None 238 | 239 | def fetch(self, delay=.0): 240 | """ Fetch the encoded vectors from server, use it with `encode(blocking=False)` 241 | 242 | Use it after `encode(texts, blocking=False)`. If there is no pending requests, will return None. 243 | Note that `fetch()` does not preserve the order of the requests! Say you have two non-blocking requests, 244 | R1 and R2, where R1 with 256 samples, R2 with 1 samples. It could be that R2 returns first. 245 | 246 | To fetch all results in the original sending order, please use `fetch_all(sort=True)` 247 | 248 | :type delay: float 249 | :param delay: delay in seconds and then run fetcher 250 | :return: a generator that yields request id and encoded vector in a tuple, where the request id can be used to determine the order 251 | :rtype: Iterator[tuple(int, numpy.ndarray)] 252 | 253 | """ 254 | time.sleep(delay) 255 | while self.pending_request: 256 | yield self._recv_ndarray() 257 | 258 | def fetch_all(self, sort=True, concat=False): 259 | """ Fetch all encoded vectors from server, use it with `encode(blocking=False)` 260 | 261 | Use it `encode(texts, blocking=False)`. If there is no pending requests, it will return None. 262 | 263 | :type sort: bool 264 | :type concat: bool 265 | :param sort: sort results by their request ids. It should be True if you want to preserve the sending order 266 | :param concat: concatenate all results into one ndarray 267 | :return: encoded sentence/token-level embeddings in sending order 268 | :rtype: numpy.ndarray or list[list[float]] 269 | 270 | """ 271 | if self.pending_request: 272 | tmp = list(self.fetch()) 273 | if sort: 274 | tmp = sorted(tmp, key=lambda v: v.id) 275 | tmp = [v.content for v in tmp] 276 | if concat: 277 | if self.output_fmt == 'ndarray': 278 | tmp = np.concatenate(tmp, axis=0) 279 | elif self.output_fmt == 'list': 280 | tmp = [vv for v in tmp for vv in v] 281 | return tmp 282 | 283 | def encode_async(self, batch_generator, max_num_batch=None, delay=0.1, is_tokenized=False): 284 | """ Async encode batches from a generator 285 | 286 | :param is_tokenized: whether batch_generator generates tokenized sentences 287 | :param delay: delay in seconds and then run fetcher 288 | :param batch_generator: a generator that yields list[str] or list[list[str]] (for `is_tokenized=True`) every time 289 | :param max_num_batch: stop after encoding this number of batches 290 | :return: a generator that yields encoded vectors in ndarray, where the request id can be used to determine the order 291 | :rtype: Iterator[tuple(int, numpy.ndarray)] 292 | 293 | """ 294 | 295 | def run(): 296 | cnt = 0 297 | for texts in batch_generator: 298 | self.encode(texts, blocking=False, is_tokenized=is_tokenized) 299 | cnt += 1 300 | if max_num_batch and cnt == max_num_batch: 301 | break 302 | 303 | t = threading.Thread(target=run) 304 | t.start() 305 | return self.fetch(delay) 306 | 307 | @staticmethod 308 | def _check_length(texts, len_limit, tokenized): 309 | if tokenized: 310 | # texts is already tokenized as list of str 311 | return all(len(t) <= len_limit for t in texts) 312 | else: 313 | # do a simple whitespace tokenizer 314 | return all(len(t.split()) <= len_limit for t in texts) 315 | 316 | @staticmethod 317 | def _check_input_lst_str(texts): 318 | if not isinstance(texts, list): 319 | raise TypeError('"%s" must be %s, but received %s' % (texts, type([]), type(texts))) 320 | if not len(texts): 321 | raise ValueError( 322 | '"%s" must be a non-empty list, but received %s with %d elements' % (texts, type(texts), len(texts))) 323 | for idx, s in enumerate(texts): 324 | if not isinstance(s, _str): 325 | raise TypeError('all elements in the list must be %s, but element %d is %s' % (type(''), idx, type(s))) 326 | if not s.strip(): 327 | raise ValueError( 328 | 'all elements in the list must be non-empty string, but element %d is %s' % (idx, repr(s))) 329 | 330 | @staticmethod 331 | def _check_input_lst_lst_str(texts): 332 | if not isinstance(texts, list): 333 | raise TypeError('"texts" must be %s, but received %s' % (type([]), type(texts))) 334 | if not len(texts): 335 | raise ValueError( 336 | '"texts" must be a non-empty list, but received %s with %d elements' % (type(texts), len(texts))) 337 | for s in texts: 338 | BertClient._check_input_lst_str(s) 339 | 340 | @staticmethod 341 | def _force_to_unicode(text): 342 | return text if isinstance(text, unicode) else text.decode('utf-8') 343 | 344 | @staticmethod 345 | def _print_dict(x, title=None): 346 | if title: 347 | print(title) 348 | for k, v in x.items(): 349 | print('%30s\t=\t%-30s' % (k, v)) 350 | 351 | def __enter__(self): 352 | return self 353 | 354 | def __exit__(self, exc_type, exc_val, exc_tb): 355 | self.close() 356 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/data/__init__.py -------------------------------------------------------------------------------- /data/bert_sen.txt: -------------------------------------------------------------------------------- 1 | subject 2 | 价格 3 | 价格 4 | 价格 5 | 价格 6 | 价格 7 | 价格 8 | 价格 9 | 价格 10 | 价格 11 | 价格 12 | 配置 13 | 配置 14 | 配置 15 | 配置 16 | 配置 17 | 安全性 18 | 油耗 19 | 配置 20 | 油耗 21 | 油耗 22 | 油耗 23 | 油耗 24 | 动力 25 | 油耗 26 | 油耗 27 | 油耗 28 | 舒适性 29 | 动力 30 | 动力 31 | 内饰 32 | 外观 33 | 内饰 34 | 安全性 35 | 内饰 36 | 安全性 37 | 安全性 38 | 安全性 39 | 安全性 40 | 动力 41 | 安全性 42 | 安全性 43 | 操控 44 | 操控 45 | 操控 46 | 动力 47 | 动力 48 | 动力 49 | 动力 50 | 动力 51 | 动力 52 | 动力 53 | 动力 54 | 动力 55 | 动力 56 | 油耗 57 | 动力 58 | 动力 59 | 动力 60 | 动力 61 | 动力 62 | 动力 63 | 动力 64 | 动力 65 | 空间 66 | 舒适性 67 | 外观 68 | 操控 69 | 价格 70 | 外观 71 | 安全性 72 | 油耗 73 | 油耗 74 | 操控 75 | 外观 76 | 动力 77 | 舒适性 78 | 价格 79 | 操控 80 | 外观 81 | 油耗 82 | 动力 83 | 操控 84 | 操控 85 | 油耗 86 | 油耗 87 | 油耗 88 | 安全性 89 | 价格 90 | 动力 91 | 油耗 92 | 价格 93 | 油耗 94 | 动力 95 | 动力 96 | 动力 97 | 安全性 98 | 动力 99 | 配置 100 | 舒适性 101 | 舒适性 102 | 舒适性 103 | 空间 104 | 舒适性 105 | 动力 106 | 价格 107 | 动力 108 | 价格 109 | 安全性 110 | 价格 111 | 外观 112 | 动力 113 | 价格 114 | 动力 115 | 配置 116 | 配置 117 | 油耗 118 | 配置 119 | 价格 120 | 配置 121 | 油耗 122 | 动力 123 | 舒适性 124 | 舒适性 125 | 价格 126 | 内饰 127 | 动力 128 | 舒适性 129 | 油耗 130 | 舒适性 131 | 动力 132 | 外观 133 | 外观 134 | 油耗 135 | 动力 136 | 油耗 137 | 外观 138 | 动力 139 | 价格 140 | 动力 141 | 舒适性 142 | 动力 143 | 动力 144 | 外观 145 | 油耗 146 | 操控 147 | 安全性 148 | 动力 149 | 外观 150 | 空间 151 | 动力 152 | 动力 153 | 动力 154 | 安全性 155 | 价格 156 | 舒适性 157 | 动力 158 | 油耗 159 | 油耗 160 | 价格 161 | 动力 162 | 安全性 163 | 油耗 164 | 动力 165 | 动力 166 | 价格 167 | 价格 168 | 动力 169 | 油耗 170 | 动力 171 | 动力 172 | 舒适性 173 | 动力 174 | 动力 175 | 安全性 176 | 动力 177 | 动力 178 | 操控 179 | 价格 180 | 内饰 181 | 价格 182 | 油耗 183 | 外观 184 | 动力 185 | 内饰 186 | 动力 187 | 操控 188 | 操控 189 | 配置 190 | 操控 191 | 舒适性 192 | 油耗 193 | 舒适性 194 | 油耗 195 | 配置 196 | 空间 197 | 动力 198 | 动力 199 | 价格 200 | 动力 201 | 外观 202 | 动力 203 | 价格 204 | 外观 205 | 舒适性 206 | 动力 207 | 价格 208 | 动力 209 | 动力 210 | 配置 211 | 空间 212 | 油耗 213 | 动力 214 | 油耗 215 | 配置 216 | 内饰 217 | 动力 218 | 油耗 219 | 动力 220 | 外观 221 | 动力 222 | 动力 223 | 价格 224 | 舒适性 225 | 操控 226 | 操控 227 | 油耗 228 | 内饰 229 | 内饰 230 | 动力 231 | 配置 232 | 动力 233 | 价格 234 | 配置 235 | 配置 236 | 外观 237 | 油耗 238 | 配置 239 | 油耗 240 | 操控 241 | 价格 242 | 内饰 243 | 外观 244 | 内饰 245 | 外观 246 | 操控 247 | 动力 248 | 价格 249 | 价格 250 | 价格 251 | 动力 252 | 动力 253 | 价格 254 | 价格 255 | 动力 256 | 动力 257 | 动力 258 | 动力 259 | 动力 260 | 动力 261 | 动力 262 | 动力 263 | 操控 264 | 动力 265 | 动力 266 | 动力 267 | 动力 268 | 动力 269 | 动力 270 | 油耗 271 | 动力 272 | 油耗 273 | 油耗 274 | 油耗 275 | 油耗 276 | 动力 277 | 动力 278 | 动力 279 | 动力 280 | 安全性 281 | 动力 282 | 动力 283 | 动力 284 | 动力 285 | 动力 286 | 安全性 287 | 操控 288 | 舒适性 289 | 动力 290 | 动力 291 | 动力 292 | 动力 293 | 动力 294 | 动力 295 | 动力 296 | 动力 297 | 动力 298 | 安全性 299 | 动力 300 | 动力 301 | 舒适性 302 | 舒适性 303 | 舒适性 304 | 操控 305 | 动力 306 | 外观 307 | 价格 308 | 空间 309 | 内饰 310 | 空间 311 | 外观 312 | 舒适性 313 | 动力 314 | 安全性 315 | 安全性 316 | 内饰 317 | 配置 318 | 动力 319 | 油耗 320 | 舒适性 321 | 动力 322 | 动力 323 | 空间 324 | 动力 325 | 舒适性 326 | 动力 327 | 油耗 328 | 价格 329 | 价格 330 | 动力 331 | 动力 332 | 油耗 333 | 动力 334 | 动力 335 | 配置 336 | 动力 337 | 舒适性 338 | 价格 339 | 价格 340 | 空间 341 | 动力 342 | 油耗 343 | 油耗 344 | 油耗 345 | 操控 346 | 动力 347 | 动力 348 | 动力 349 | 动力 350 | 动力 351 | 价格 352 | 动力 353 | 动力 354 | 动力 355 | 动力 356 | 价格 357 | 舒适性 358 | 动力 359 | 动力 360 | 操控 361 | 油耗 362 | 操控 363 | 价格 364 | 动力 365 | 动力 366 | 外观 367 | 价格 368 | 油耗 369 | 动力 370 | 油耗 371 | 舒适性 372 | 油耗 373 | 外观 374 | 动力 375 | 操控 376 | 动力 377 | 价格 378 | 动力 379 | 价格 380 | 动力 381 | 价格 382 | 动力 383 | 价格 384 | 动力 385 | 外观 386 | 动力 387 | 价格 388 | 油耗 389 | 价格 390 | 动力 391 | 外观 392 | 油耗 393 | 安全性 394 | 动力 395 | 操控 396 | 油耗 397 | 操控 398 | 动力 399 | 动力 400 | 价格 401 | 油耗 402 | 油耗 403 | 油耗 404 | 操控 405 | 油耗 406 | 油耗 407 | 油耗 408 | 价格 409 | 动力 410 | 动力 411 | 操控 412 | 外观 413 | 配置 414 | 动力 415 | 价格 416 | 安全性 417 | 动力 418 | 动力 419 | 油耗 420 | 油耗 421 | 动力 422 | 配置 423 | 操控 424 | 油耗 425 | 动力 426 | 油耗 427 | 价格 428 | 油耗 429 | 动力 430 | 外观 431 | 操控 432 | 价格 433 | 价格 434 | 油耗 435 | 动力 436 | 动力 437 | 操控 438 | 动力 439 | 价格 440 | 价格 441 | 油耗 442 | 油耗 443 | 价格 444 | 动力 445 | 动力 446 | 价格 447 | 外观 448 | 动力 449 | 舒适性 450 | 舒适性 451 | 价格 452 | 价格 453 | 舒适性 454 | 价格 455 | 动力 456 | 安全性 457 | 油耗 458 | 舒适性 459 | 舒适性 460 | 操控 461 | 油耗 462 | 价格 463 | 动力 464 | 安全性 465 | 油耗 466 | 动力 467 | 动力 468 | 价格 469 | 动力 470 | 动力 471 | 操控 472 | 动力 473 | 动力 474 | 内饰 475 | 油耗 476 | 价格 477 | 油耗 478 | 空间 479 | 动力 480 | 油耗 481 | 安全性 482 | 价格 483 | 动力 484 | 动力 485 | 外观 486 | 动力 487 | 外观 488 | 动力 489 | 价格 490 | 外观 491 | 外观 492 | 动力 493 | 油耗 494 | 动力 495 | 动力 496 | 动力 497 | 舒适性 498 | 价格 499 | 价格 500 | 动力 501 | 价格 502 | 外观 503 | 动力 504 | 外观 505 | 内饰 506 | 动力 507 | 动力 508 | 油耗 509 | 空间 510 | 动力 511 | 价格 512 | 动力 513 | 操控 514 | 油耗 515 | 动力 516 | 动力 517 | 油耗 518 | 配置 519 | 价格 520 | 动力 521 | 油耗 522 | 动力 523 | 价格 524 | 价格 525 | 配置 526 | 动力 527 | 舒适性 528 | 动力 529 | 价格 530 | 价格 531 | 价格 532 | 内饰 533 | 外观 534 | 舒适性 535 | 价格 536 | 价格 537 | 动力 538 | 动力 539 | 价格 540 | 动力 541 | 操控 542 | 操控 543 | 动力 544 | 油耗 545 | 价格 546 | 舒适性 547 | 安全性 548 | 价格 549 | 操控 550 | 操控 551 | 内饰 552 | 价格 553 | 价格 554 | 动力 555 | 舒适性 556 | 动力 557 | 价格 558 | 价格 559 | 配置 560 | 舒适性 561 | 油耗 562 | 外观 563 | 油耗 564 | 操控 565 | 油耗 566 | 油耗 567 | 价格 568 | 动力 569 | 配置 570 | 配置 571 | 油耗 572 | 动力 573 | 价格 574 | 动力 575 | 动力 576 | 油耗 577 | 舒适性 578 | 动力 579 | 油耗 580 | 油耗 581 | 动力 582 | 动力 583 | 动力 584 | 价格 585 | 油耗 586 | 价格 587 | 内饰 588 | 安全性 589 | 舒适性 590 | 动力 591 | 动力 592 | 配置 593 | 油耗 594 | 配置 595 | 舒适性 596 | 动力 597 | 动力 598 | 动力 599 | 动力 600 | 价格 601 | 舒适性 602 | 油耗 603 | 价格 604 | 油耗 605 | 价格 606 | 舒适性 607 | 动力 608 | 油耗 609 | 动力 610 | 价格 611 | 油耗 612 | 动力 613 | 外观 614 | 内饰 615 | 安全性 616 | 配置 617 | 舒适性 618 | 动力 619 | 动力 620 | 安全性 621 | 操控 622 | 价格 623 | 外观 624 | 配置 625 | 配置 626 | 空间 627 | 配置 628 | 动力 629 | 空间 630 | 油耗 631 | 价格 632 | 动力 633 | 动力 634 | 油耗 635 | 油耗 636 | 外观 637 | 动力 638 | 操控 639 | 价格 640 | 油耗 641 | 安全性 642 | 动力 643 | 配置 644 | 安全性 645 | 价格 646 | 油耗 647 | 动力 648 | 动力 649 | 动力 650 | 舒适性 651 | 内饰 652 | 动力 653 | 配置 654 | 油耗 655 | 操控 656 | 操控 657 | 动力 658 | 油耗 659 | 内饰 660 | 动力 661 | 动力 662 | 动力 663 | 动力 664 | 动力 665 | 动力 666 | 动力 667 | 油耗 668 | 外观 669 | 安全性 670 | 配置 671 | 动力 672 | 动力 673 | 外观 674 | 动力 675 | 动力 676 | 价格 677 | 动力 678 | 动力 679 | 动力 680 | 动力 681 | 外观 682 | 动力 683 | 动力 684 | 动力 685 | 动力 686 | 动力 687 | 动力 688 | 动力 689 | 动力 690 | 动力 691 | 操控 692 | 操控 693 | 动力 694 | 动力 695 | 外观 696 | 油耗 697 | 动力 698 | 油耗 699 | 油耗 700 | 动力 701 | 操控 702 | 操控 703 | 油耗 704 | 价格 705 | 价格 706 | 动力 707 | 价格 708 | 动力 709 | 动力 710 | 外观 711 | 动力 712 | 价格 713 | 价格 714 | 动力 715 | 操控 716 | 动力 717 | 动力 718 | 油耗 719 | 油耗 720 | 油耗 721 | 动力 722 | 内饰 723 | 油耗 724 | 空间 725 | 动力 726 | 操控 727 | 舒适性 728 | 外观 729 | 价格 730 | 内饰 731 | 动力 732 | 动力 733 | 动力 734 | 动力 735 | 价格 736 | 安全性 737 | 动力 738 | 安全性 739 | 操控 740 | 动力 741 | 舒适性 742 | 舒适性 743 | 空间 744 | 安全性 745 | 操控 746 | 舒适性 747 | 安全性 748 | 舒适性 749 | 动力 750 | 价格 751 | 安全性 752 | 舒适性 753 | 油耗 754 | 油耗 755 | 动力 756 | 动力 757 | 动力 758 | 外观 759 | 内饰 760 | 外观 761 | 配置 762 | 动力 763 | 油耗 764 | 安全性 765 | 操控 766 | 舒适性 767 | 油耗 768 | 配置 769 | 动力 770 | 安全性 771 | 动力 772 | 舒适性 773 | 舒适性 774 | 舒适性 775 | 安全性 776 | 动力 777 | 外观 778 | 动力 779 | 内饰 780 | 空间 781 | 舒适性 782 | 动力 783 | 外观 784 | 价格 785 | 舒适性 786 | 安全性 787 | 价格 788 | 价格 789 | 油耗 790 | 动力 791 | 油耗 792 | 油耗 793 | 油耗 794 | 配置 795 | 外观 796 | 外观 797 | 外观 798 | 价格 799 | 外观 800 | 价格 801 | 价格 802 | 价格 803 | 空间 804 | 舒适性 805 | 油耗 806 | 操控 807 | 动力 808 | 舒适性 809 | 舒适性 810 | 油耗 811 | 操控 812 | 价格 813 | 价格 814 | 油耗 815 | 价格 816 | 油耗 817 | 价格 818 | 动力 819 | 内饰 820 | 操控 821 | 外观 822 | 价格 823 | 动力 824 | 操控 825 | 动力 826 | 操控 827 | 动力 828 | 价格 829 | 动力 830 | 油耗 831 | 油耗 832 | 安全性 833 | 配置 834 | 动力 835 | 内饰 836 | 内饰 837 | 操控 838 | 动力 839 | 配置 840 | 操控 841 | 舒适性 842 | 配置 843 | 舒适性 844 | 配置 845 | 安全性 846 | 动力 847 | 操控 848 | 动力 849 | 动力 850 | 动力 851 | 油耗 852 | 外观 853 | 外观 854 | 配置 855 | 动力 856 | 操控 857 | 动力 858 | 安全性 859 | 操控 860 | 舒适性 861 | 动力 862 | 舒适性 863 | 油耗 864 | 价格 865 | 价格 866 | 操控 867 | 操控 868 | 动力 869 | 动力 870 | 配置 871 | 配置 872 | 动力 873 | 动力 874 | 价格 875 | 价格 876 | 价格 877 | 动力 878 | 油耗 879 | 内饰 880 | 安全性 881 | 价格 882 | 动力 883 | 内饰 884 | 操控 885 | 价格 886 | 安全性 887 | 油耗 888 | 动力 889 | 动力 890 | 油耗 891 | 配置 892 | 配置 893 | 动力 894 | 动力 895 | 价格 896 | 动力 897 | 价格 898 | 动力 899 | 配置 900 | 价格 901 | 动力 902 | 价格 903 | 外观 904 | 动力 905 | 价格 906 | 价格 907 | 动力 908 | 舒适性 909 | 外观 910 | 动力 911 | 油耗 912 | 内饰 913 | 油耗 914 | 外观 915 | 操控 916 | 动力 917 | 动力 918 | 舒适性 919 | 动力 920 | 安全性 921 | 舒适性 922 | 操控 923 | 油耗 924 | 油耗 925 | 外观 926 | 油耗 927 | 内饰 928 | 配置 929 | 空间 930 | 油耗 931 | 动力 932 | 舒适性 933 | 外观 934 | 动力 935 | 动力 936 | 动力 937 | 油耗 938 | 动力 939 | 内饰 940 | 价格 941 | 价格 942 | 动力 943 | 舒适性 944 | 内饰 945 | 空间 946 | 价格 947 | 价格 948 | 操控 949 | 舒适性 950 | 动力 951 | 安全性 952 | 动力 953 | 外观 954 | 外观 955 | 操控 956 | 价格 957 | 操控 958 | 动力 959 | 安全性 960 | 价格 961 | 动力 962 | 价格 963 | 价格 964 | 动力 965 | 价格 966 | 价格 967 | 安全性 968 | 价格 969 | 价格 970 | 外观 971 | 操控 972 | 配置 973 | 操控 974 | 操控 975 | 操控 976 | 操控 977 | 操控 978 | 动力 979 | 油耗 980 | 油耗 981 | 油耗 982 | 油耗 983 | 油耗 984 | 油耗 985 | 油耗 986 | 油耗 987 | 油耗 988 | 操控 989 | 舒适性 990 | 外观 991 | 动力 992 | 动力 993 | 动力 994 | 动力 995 | 动力 996 | 动力 997 | 动力 998 | 外观 999 | 外观 1000 | 空间 1001 | 外观 1002 | 外观 1003 | 外观 1004 | 外观 1005 | 动力 1006 | 价格 1007 | 外观 1008 | 内饰 1009 | 安全性 1010 | 安全性 1011 | 安全性 1012 | 安全性 1013 | 安全性 1014 | 安全性 1015 | 安全性 1016 | 安全性 1017 | 动力 1018 | 安全性 1019 | 动力 1020 | 动力 1021 | 动力 1022 | 价格 1023 | 价格 1024 | 动力 1025 | 价格 1026 | 动力 1027 | 配置 1028 | 动力 1029 | 安全性 1030 | 舒适性 1031 | 动力 1032 | 动力 1033 | 动力 1034 | 动力 1035 | 动力 1036 | 动力 1037 | 动力 1038 | 动力 1039 | 动力 1040 | 动力 1041 | 动力 1042 | 动力 1043 | 安全性 1044 | 空间 1045 | 配置 1046 | 内饰 1047 | 内饰 1048 | 动力 1049 | 价格 1050 | 价格 1051 | 安全性 1052 | 油耗 1053 | 外观 1054 | 油耗 1055 | 操控 1056 | 安全性 1057 | 内饰 1058 | 配置 1059 | 外观 1060 | 安全性 1061 | 油耗 1062 | 舒适性 1063 | 舒适性 1064 | 动力 1065 | 操控 1066 | 价格 1067 | 内饰 1068 | 安全性 1069 | 动力 1070 | 动力 1071 | 空间 1072 | 配置 1073 | 价格 1074 | 价格 1075 | 价格 1076 | 油耗 1077 | 动力 1078 | 油耗 1079 | 动力 1080 | 动力 1081 | 动力 1082 | 安全性 1083 | 外观 1084 | 安全性 1085 | 外观 1086 | 动力 1087 | 配置 1088 | 动力 1089 | 配置 1090 | 配置 1091 | 油耗 1092 | 动力 1093 | 价格 1094 | 油耗 1095 | 动力 1096 | 动力 1097 | 动力 1098 | 外观 1099 | 动力 1100 | 动力 1101 | 动力 1102 | 外观 1103 | 安全性 1104 | 安全性 1105 | 操控 1106 | 操控 1107 | 价格 1108 | 动力 1109 | 舒适性 1110 | 操控 1111 | 动力 1112 | 价格 1113 | 油耗 1114 | 价格 1115 | 内饰 1116 | 舒适性 1117 | 操控 1118 | 动力 1119 | 舒适性 1120 | 油耗 1121 | 油耗 1122 | 舒适性 1123 | 动力 1124 | 油耗 1125 | 价格 1126 | 外观 1127 | 外观 1128 | 动力 1129 | 油耗 1130 | 外观 1131 | 外观 1132 | 操控 1133 | 安全性 1134 | 安全性 1135 | 油耗 1136 | 配置 1137 | 操控 1138 | 油耗 1139 | 动力 1140 | 动力 1141 | 动力 1142 | 油耗 1143 | 动力 1144 | 外观 1145 | 价格 1146 | 安全性 1147 | 价格 1148 | 操控 1149 | 操控 1150 | 配置 1151 | 配置 1152 | 外观 1153 | 动力 1154 | 动力 1155 | 动力 1156 | 操控 1157 | 价格 1158 | 操控 1159 | 安全性 1160 | 操控 1161 | 动力 1162 | 动力 1163 | 价格 1164 | 内饰 1165 | 操控 1166 | 操控 1167 | 油耗 1168 | 动力 1169 | 内饰 1170 | 价格 1171 | 价格 1172 | 动力 1173 | 安全性 1174 | 动力 1175 | 动力 1176 | 安全性 1177 | 空间 1178 | 动力 1179 | 动力 1180 | 动力 1181 | 外观 1182 | 动力 1183 | 安全性 1184 | 动力 1185 | 配置 1186 | 配置 1187 | 操控 1188 | 动力 1189 | 舒适性 1190 | 舒适性 1191 | 外观 1192 | 舒适性 1193 | 外观 1194 | 价格 1195 | 动力 1196 | 安全性 1197 | 动力 1198 | 价格 1199 | 动力 1200 | 内饰 1201 | 外观 1202 | 油耗 1203 | 价格 1204 | 操控 1205 | 动力 1206 | 舒适性 1207 | 油耗 1208 | 外观 1209 | 安全性 1210 | 油耗 1211 | 价格 1212 | 动力 1213 | 动力 1214 | 配置 1215 | 外观 1216 | 油耗 1217 | 动力 1218 | 动力 1219 | 动力 1220 | 价格 1221 | 配置 1222 | 安全性 1223 | 操控 1224 | 动力 1225 | 配置 1226 | 内饰 1227 | 操控 1228 | 油耗 1229 | 油耗 1230 | 操控 1231 | 操控 1232 | 操控 1233 | 内饰 1234 | 动力 1235 | 安全性 1236 | 价格 1237 | 配置 1238 | 操控 1239 | 动力 1240 | 外观 1241 | 操控 1242 | 动力 1243 | 动力 1244 | 配置 1245 | 配置 1246 | 安全性 1247 | 操控 1248 | 操控 1249 | 价格 1250 | 价格 1251 | 价格 1252 | 价格 1253 | 价格 1254 | 操控 1255 | 价格 1256 | 配置 1257 | 配置 1258 | 配置 1259 | 操控 1260 | 操控 1261 | 安全性 1262 | 油耗 1263 | 油耗 1264 | 油耗 1265 | 油耗 1266 | 油耗 1267 | 油耗 1268 | 油耗 1269 | 舒适性 1270 | 内饰 1271 | 内饰 1272 | 内饰 1273 | 内饰 1274 | 操控 1275 | 安全性 1276 | 内饰 1277 | 操控 1278 | 安全性 1279 | 操控 1280 | 安全性 1281 | 动力 1282 | 动力 1283 | 动力 1284 | 动力 1285 | 动力 1286 | 动力 1287 | 动力 1288 | 动力 1289 | 动力 1290 | 空间 1291 | 动力 1292 | 外观 1293 | 安全性 1294 | 舒适性 1295 | 外观 1296 | 油耗 1297 | 动力 1298 | 动力 1299 | 油耗 1300 | 油耗 1301 | 动力 1302 | 油耗 1303 | 油耗 1304 | 价格 1305 | 油耗 1306 | 油耗 1307 | 动力 1308 | 油耗 1309 | 动力 1310 | 动力 1311 | 舒适性 1312 | 安全性 1313 | 价格 1314 | 舒适性 1315 | 配置 1316 | 动力 1317 | 动力 1318 | 价格 1319 | 价格 1320 | 价格 1321 | 外观 1322 | 价格 1323 | 价格 1324 | 动力 1325 | 操控 1326 | 价格 1327 | 动力 1328 | 安全性 1329 | 舒适性 1330 | 动力 1331 | 操控 1332 | 舒适性 1333 | 动力 1334 | 操控 1335 | 动力 1336 | 舒适性 1337 | 动力 1338 | 动力 1339 | 舒适性 1340 | 油耗 1341 | 价格 1342 | 动力 1343 | 价格 1344 | 油耗 1345 | 安全性 1346 | 动力 1347 | 动力 1348 | 舒适性 1349 | 动力 1350 | 操控 1351 | 舒适性 1352 | 动力 1353 | 价格 1354 | 动力 1355 | 动力 1356 | 动力 1357 | 动力 1358 | 油耗 1359 | 动力 1360 | 安全性 1361 | 动力 1362 | 油耗 1363 | 油耗 1364 | 价格 1365 | 价格 1366 | 动力 1367 | 动力 1368 | 动力 1369 | 舒适性 1370 | 动力 1371 | 价格 1372 | 动力 1373 | 操控 1374 | 操控 1375 | 动力 1376 | 配置 1377 | 舒适性 1378 | 安全性 1379 | 舒适性 1380 | 动力 1381 | 安全性 1382 | 动力 1383 | 动力 1384 | 油耗 1385 | 价格 1386 | 配置 1387 | 配置 1388 | 配置 1389 | 动力 1390 | 油耗 1391 | 价格 1392 | 操控 1393 | 配置 1394 | 舒适性 1395 | 价格 1396 | 价格 1397 | 动力 1398 | 价格 1399 | 价格 1400 | 油耗 1401 | 价格 1402 | 外观 1403 | 价格 1404 | 动力 1405 | 空间 1406 | 空间 1407 | 操控 1408 | 配置 1409 | 价格 1410 | 动力 1411 | 动力 1412 | 外观 1413 | 外观 1414 | 操控 1415 | 动力 1416 | 价格 1417 | 价格 1418 | 价格 1419 | 价格 1420 | 价格 1421 | 价格 1422 | 动力 1423 | 价格 1424 | 动力 1425 | 动力 1426 | 动力 1427 | 动力 1428 | 动力 1429 | 外观 1430 | 外观 1431 | 安全性 1432 | 动力 1433 | 动力 1434 | 动力 1435 | 动力 1436 | 动力 1437 | 动力 1438 | 动力 1439 | 动力 1440 | 油耗 1441 | 油耗 1442 | 动力 1443 | 油耗 1444 | 动力 1445 | 动力 1446 | 动力 1447 | 动力 1448 | 外观 1449 | 动力 1450 | 油耗 1451 | 油耗 1452 | 油耗 1453 | 油耗 1454 | 动力 1455 | 动力 1456 | 舒适性 1457 | 操控 1458 | 油耗 1459 | 舒适性 1460 | 操控 1461 | 动力 1462 | 操控 1463 | 动力 1464 | 动力 1465 | 动力 1466 | 安全性 1467 | 动力 1468 | 动力 1469 | 动力 1470 | 动力 1471 | 舒适性 1472 | 舒适性 1473 | 安全性 1474 | 动力 1475 | 油耗 1476 | 舒适性 1477 | 配置 1478 | 舒适性 1479 | 舒适性 1480 | 内饰 1481 | 操控 1482 | 舒适性 1483 | 舒适性 1484 | 配置 1485 | 安全性 1486 | 动力 1487 | 配置 1488 | 价格 1489 | 油耗 1490 | 价格 1491 | 外观 1492 | 动力 1493 | 操控 1494 | 外观 1495 | 动力 1496 | 价格 1497 | 动力 1498 | 外观 1499 | 动力 1500 | 配置 1501 | 动力 1502 | 动力 1503 | 动力 1504 | 油耗 1505 | 动力 1506 | 动力 1507 | 操控 1508 | 舒适性 1509 | 动力 1510 | 动力 1511 | 油耗 1512 | 舒适性 1513 | 价格 1514 | 舒适性 1515 | 动力 1516 | 价格 1517 | 动力 1518 | 动力 1519 | 动力 1520 | 操控 1521 | 价格 1522 | 配置 1523 | 油耗 1524 | 油耗 1525 | 内饰 1526 | 动力 1527 | 动力 1528 | 价格 1529 | 操控 1530 | 油耗 1531 | 动力 1532 | 空间 1533 | 油耗 1534 | 动力 1535 | 舒适性 1536 | 操控 1537 | 动力 1538 | 舒适性 1539 | 动力 1540 | 动力 1541 | 外观 1542 | 舒适性 1543 | 配置 1544 | 动力 1545 | 动力 1546 | 动力 1547 | 配置 1548 | 动力 1549 | 价格 1550 | 动力 1551 | 价格 1552 | 动力 1553 | 动力 1554 | 油耗 1555 | 价格 1556 | 动力 1557 | 油耗 1558 | 价格 1559 | 价格 1560 | 价格 1561 | 价格 1562 | 动力 1563 | 价格 1564 | 价格 1565 | 操控 1566 | 操控 1567 | 油耗 1568 | 动力 1569 | 动力 1570 | 内饰 1571 | 动力 1572 | 价格 1573 | 配置 1574 | 配置 1575 | 价格 1576 | 操控 1577 | 油耗 1578 | 动力 1579 | 动力 1580 | 动力 1581 | 价格 1582 | 油耗 1583 | 动力 1584 | 价格 1585 | 油耗 1586 | 动力 1587 | 油耗 1588 | 动力 1589 | 外观 1590 | 油耗 1591 | 动力 1592 | 配置 1593 | 动力 1594 | 动力 1595 | 价格 1596 | 价格 1597 | 油耗 1598 | 价格 1599 | 空间 1600 | 价格 1601 | 油耗 1602 | 价格 1603 | 外观 1604 | 操控 1605 | 动力 1606 | 动力 1607 | 价格 1608 | 价格 1609 | 内饰 1610 | 舒适性 1611 | 外观 1612 | 舒适性 1613 | 价格 1614 | 动力 1615 | 价格 1616 | 操控 1617 | 外观 1618 | 油耗 1619 | 油耗 1620 | 油耗 1621 | 价格 1622 | 操控 1623 | 动力 1624 | 动力 1625 | 配置 1626 | 动力 1627 | 动力 1628 | 价格 1629 | 价格 1630 | 动力 1631 | 动力 1632 | 外观 1633 | 操控 1634 | 价格 1635 | 内饰 1636 | 动力 1637 | 油耗 1638 | 动力 1639 | 动力 1640 | 动力 1641 | 油耗 1642 | 油耗 1643 | 动力 1644 | 动力 1645 | 动力 1646 | 动力 1647 | 空间 1648 | 油耗 1649 | 舒适性 1650 | 动力 1651 | 动力 1652 | 操控 1653 | 动力 1654 | 价格 1655 | 安全性 1656 | 油耗 1657 | 操控 1658 | 内饰 1659 | 外观 1660 | 动力 1661 | 价格 1662 | 空间 1663 | 价格 1664 | 动力 1665 | 外观 1666 | 操控 1667 | 配置 1668 | 价格 1669 | 操控 1670 | 外观 1671 | 配置 1672 | 配置 1673 | 配置 1674 | 操控 1675 | 价格 1676 | 价格 1677 | 操控 1678 | 动力 1679 | 安全性 1680 | 油耗 1681 | 舒适性 1682 | 动力 1683 | 价格 1684 | 外观 1685 | 操控 1686 | 操控 1687 | 空间 1688 | 价格 1689 | 价格 1690 | 动力 1691 | 动力 1692 | 价格 1693 | 动力 1694 | 动力 1695 | 动力 1696 | 价格 1697 | 配置 1698 | 价格 1699 | 价格 1700 | 价格 1701 | 价格 1702 | 油耗 1703 | 操控 1704 | 价格 1705 | 动力 1706 | 动力 1707 | 舒适性 1708 | 动力 1709 | 动力 1710 | 油耗 1711 | 舒适性 1712 | 价格 1713 | 动力 1714 | 价格 1715 | 价格 1716 | 配置 1717 | 价格 1718 | 油耗 1719 | 配置 1720 | 动力 1721 | 动力 1722 | 配置 1723 | 配置 1724 | 安全性 1725 | 安全性 1726 | 动力 1727 | 配置 1728 | 内饰 1729 | 动力 1730 | 油耗 1731 | 动力 1732 | 安全性 1733 | 动力 1734 | 价格 1735 | 价格 1736 | 安全性 1737 | 动力 1738 | 动力 1739 | 操控 1740 | 舒适性 1741 | 动力 1742 | 外观 1743 | 内饰 1744 | 操控 1745 | 价格 1746 | 动力 1747 | 舒适性 1748 | 操控 1749 | 动力 1750 | 安全性 1751 | 动力 1752 | 舒适性 1753 | 安全性 1754 | 配置 1755 | 油耗 1756 | 动力 1757 | 油耗 1758 | 安全性 1759 | 动力 1760 | 动力 1761 | 动力 1762 | 安全性 1763 | 外观 1764 | 价格 1765 | 外观 1766 | 价格 1767 | 动力 1768 | 安全性 1769 | 外观 1770 | 价格 1771 | 动力 1772 | 动力 1773 | 操控 1774 | 动力 1775 | 动力 1776 | 外观 1777 | 安全性 1778 | 安全性 1779 | 配置 1780 | 油耗 1781 | 外观 1782 | 安全性 1783 | 价格 1784 | 价格 1785 | 舒适性 1786 | 动力 1787 | 动力 1788 | 外观 1789 | 配置 1790 | 动力 1791 | 安全性 1792 | 外观 1793 | 内饰 1794 | 内饰 1795 | 价格 1796 | 操控 1797 | 动力 1798 | 价格 1799 | 舒适性 1800 | 动力 1801 | 外观 1802 | 安全性 1803 | 外观 1804 | 操控 1805 | 空间 1806 | 安全性 1807 | 价格 1808 | 操控 1809 | 安全性 1810 | 油耗 1811 | 价格 1812 | 内饰 1813 | 价格 1814 | 操控 1815 | 动力 1816 | 动力 1817 | 外观 1818 | 安全性 1819 | 价格 1820 | 动力 1821 | 动力 1822 | 操控 1823 | 价格 1824 | 价格 1825 | 价格 1826 | 动力 1827 | 价格 1828 | 价格 1829 | 动力 1830 | 油耗 1831 | 油耗 1832 | 动力 1833 | 价格 1834 | 动力 1835 | 价格 1836 | 价格 1837 | 动力 1838 | 价格 1839 | 价格 1840 | 动力 1841 | 油耗 1842 | 舒适性 1843 | 价格 1844 | 价格 1845 | 油耗 1846 | 动力 1847 | 油耗 1848 | 动力 1849 | 价格 1850 | 价格 1851 | 操控 1852 | 空间 1853 | 空间 1854 | 操控 1855 | 空间 1856 | 空间 1857 | 动力 1858 | 油耗 1859 | 油耗 1860 | 价格 1861 | 动力 1862 | 油耗 1863 | 动力 1864 | 动力 1865 | 内饰 1866 | 动力 1867 | 价格 1868 | 操控 1869 | 价格 1870 | 价格 1871 | 动力 1872 | 价格 1873 | 动力 1874 | 动力 1875 | 操控 1876 | 动力 1877 | 操控 1878 | 内饰 1879 | 动力 1880 | 外观 1881 | 价格 1882 | 动力 1883 | 动力 1884 | 动力 1885 | 油耗 1886 | 外观 1887 | 动力 1888 | 动力 1889 | 价格 1890 | 动力 1891 | 动力 1892 | 价格 1893 | 动力 1894 | 价格 1895 | 外观 1896 | 操控 1897 | 油耗 1898 | 油耗 1899 | 价格 1900 | 动力 1901 | 动力 1902 | 油耗 1903 | 安全性 1904 | 外观 1905 | 价格 1906 | 动力 1907 | 安全性 1908 | 动力 1909 | 舒适性 1910 | 安全性 1911 | 油耗 1912 | 外观 1913 | 油耗 1914 | 外观 1915 | 价格 1916 | 配置 1917 | 舒适性 1918 | 油耗 1919 | 安全性 1920 | 舒适性 1921 | 舒适性 1922 | 动力 1923 | 安全性 1924 | 动力 1925 | 配置 1926 | 动力 1927 | 价格 1928 | 动力 1929 | 安全性 1930 | 安全性 1931 | 价格 1932 | 操控 1933 | 油耗 1934 | 动力 1935 | 动力 1936 | 动力 1937 | 舒适性 1938 | 动力 1939 | 动力 1940 | 舒适性 1941 | 动力 1942 | 动力 1943 | 安全性 1944 | 外观 1945 | 动力 1946 | 价格 1947 | 动力 1948 | 油耗 1949 | 舒适性 1950 | 油耗 1951 | 油耗 1952 | 舒适性 1953 | 操控 1954 | 动力 1955 | 价格 1956 | 舒适性 1957 | 价格 1958 | 舒适性 1959 | 操控 1960 | 配置 1961 | 安全性 1962 | 价格 1963 | 动力 1964 | 舒适性 1965 | 安全性 1966 | 动力 1967 | 外观 1968 | 价格 1969 | 价格 1970 | 动力 1971 | 油耗 1972 | 动力 1973 | 价格 1974 | 油耗 1975 | 安全性 1976 | 舒适性 1977 | 动力 1978 | 动力 1979 | 动力 1980 | 动力 1981 | 价格 1982 | 价格 1983 | 动力 1984 | 外观 1985 | 油耗 1986 | 动力 1987 | 价格 1988 | 价格 1989 | 安全性 1990 | 舒适性 1991 | 动力 1992 | 动力 1993 | 动力 1994 | 油耗 1995 | 操控 1996 | 配置 1997 | 动力 1998 | 操控 1999 | 安全性 2000 | 动力 2001 | 外观 2002 | 舒适性 2003 | 油耗 2004 | 动力 2005 | 操控 2006 | 外观 2007 | 操控 2008 | 价格 2009 | 价格 2010 | 动力 2011 | 价格 2012 | 外观 2013 | 内饰 2014 | 安全性 2015 | 外观 2016 | 外观 2017 | 内饰 2018 | 外观 2019 | 价格 2020 | 外观 2021 | 内饰 2022 | 安全性 2023 | 操控 2024 | 油耗 2025 | 动力 2026 | 安全性 2027 | 动力 2028 | 动力 2029 | 操控 2030 | 价格 2031 | 配置 2032 | 安全性 2033 | 动力 2034 | 舒适性 2035 | 外观 2036 | 空间 2037 | 安全性 2038 | 动力 2039 | 操控 2040 | 外观 2041 | 操控 2042 | 安全性 2043 | 动力 2044 | 动力 2045 | 舒适性 2046 | 动力 2047 | 外观 2048 | 内饰 2049 | 安全性 2050 | 操控 2051 | 动力 2052 | 外观 2053 | 价格 2054 | 配置 2055 | 配置 2056 | 配置 2057 | 配置 2058 | 油耗 2059 | 动力 2060 | 动力 2061 | 油耗 2062 | 操控 2063 | 配置 2064 | 价格 2065 | 动力 2066 | 操控 2067 | 安全性 2068 | 油耗 2069 | 舒适性 2070 | 油耗 2071 | 配置 2072 | 价格 2073 | 舒适性 2074 | 舒适性 2075 | 配置 2076 | 油耗 2077 | 动力 2078 | 动力 2079 | 油耗 2080 | 动力 2081 | 油耗 2082 | 安全性 2083 | 动力 2084 | 动力 2085 | 外观 2086 | 动力 2087 | 动力 2088 | 操控 2089 | 油耗 2090 | 油耗 2091 | 空间 2092 | 动力 2093 | 油耗 2094 | 价格 2095 | 操控 2096 | 动力 2097 | 动力 2098 | 油耗 2099 | 舒适性 2100 | 外观 2101 | 配置 2102 | 动力 2103 | 外观 2104 | 外观 2105 | 舒适性 2106 | 配置 2107 | 动力 2108 | 安全性 2109 | 操控 2110 | 价格 2111 | 外观 2112 | 价格 2113 | 油耗 2114 | 价格 2115 | 操控 2116 | 内饰 2117 | 价格 2118 | 价格 2119 | 价格 2120 | 价格 2121 | 价格 2122 | 价格 2123 | 价格 2124 | 外观 2125 | 操控 2126 | 操控 2127 | 操控 2128 | 油耗 2129 | 油耗 2130 | 油耗 2131 | 内饰 2132 | 空间 2133 | 外观 2134 | 配置 2135 | 动力 2136 | 动力 2137 | 动力 2138 | 动力 2139 | 动力 2140 | 动力 2141 | 动力 2142 | 动力 2143 | 动力 2144 | 动力 2145 | 动力 2146 | 外观 2147 | 空间 2148 | 外观 2149 | 价格 2150 | 外观 2151 | 外观 2152 | 外观 2153 | 安全性 2154 | 安全性 2155 | 安全性 2156 | 安全性 2157 | 安全性 2158 | 外观 2159 | 安全性 2160 | 安全性 2161 | 安全性 2162 | 动力 2163 | 动力 2164 | 动力 2165 | 油耗 2166 | 配置 2167 | 配置 2168 | 配置 2169 | 动力 2170 | 动力 2171 | 动力 2172 | 空间 2173 | 配置 2174 | 油耗 2175 | 外观 2176 | 舒适性 2177 | 舒适性 2178 | 油耗 2179 | 油耗 2180 | 油耗 2181 | 动力 2182 | 油耗 2183 | 内饰 2184 | 安全性 2185 | 安全性 2186 | 动力 2187 | 舒适性 2188 | 动力 2189 | 配置 2190 | 油耗 2191 | 内饰 2192 | 价格 2193 | 配置 2194 | 配置 2195 | 油耗 2196 | 油耗 2197 | 外观 2198 | 价格 2199 | 内饰 2200 | 动力 2201 | 动力 2202 | 外观 2203 | 外观 2204 | 安全性 2205 | 动力 2206 | 价格 2207 | 外观 2208 | 动力 2209 | 动力 2210 | 舒适性 2211 | 动力 2212 | 油耗 2213 | 安全性 2214 | 安全性 2215 | 舒适性 2216 | 油耗 2217 | 动力 2218 | 动力 2219 | 配置 2220 | 操控 2221 | 价格 2222 | 动力 2223 | 价格 2224 | 动力 2225 | 外观 2226 | 配置 2227 | 配置 2228 | 内饰 2229 | 油耗 2230 | 油耗 2231 | 油耗 2232 | 内饰 2233 | 外观 2234 | 操控 2235 | 价格 2236 | 动力 2237 | 配置 2238 | 价格 2239 | 动力 2240 | 操控 2241 | 配置 2242 | 配置 2243 | 油耗 2244 | 操控 2245 | 动力 2246 | 舒适性 2247 | 空间 2248 | 内饰 2249 | 内饰 2250 | 外观 2251 | 价格 2252 | 动力 2253 | 安全性 2254 | 油耗 2255 | 动力 2256 | 安全性 2257 | 动力 2258 | 价格 2259 | 舒适性 2260 | 动力 2261 | 舒适性 2262 | 舒适性 2263 | 操控 2264 | 动力 2265 | 价格 2266 | 外观 2267 | 操控 2268 | 价格 2269 | 舒适性 2270 | 配置 2271 | 舒适性 2272 | 内饰 2273 | 安全性 2274 | 内饰 2275 | 动力 2276 | 配置 2277 | 动力 2278 | 舒适性 2279 | 动力 2280 | 外观 2281 | 配置 2282 | 油耗 2283 | 油耗 2284 | 安全性 2285 | 操控 2286 | 油耗 2287 | 动力 2288 | 安全性 2289 | 动力 2290 | 舒适性 2291 | 价格 2292 | 动力 2293 | 油耗 2294 | 油耗 2295 | 油耗 2296 | 油耗 2297 | 舒适性 2298 | 动力 2299 | 价格 2300 | 配置 2301 | 外观 2302 | 动力 2303 | 价格 2304 | 动力 2305 | 价格 2306 | 动力 2307 | 动力 2308 | 动力 2309 | 操控 2310 | 动力 2311 | 价格 2312 | 动力 2313 | 动力 2314 | 油耗 2315 | 操控 2316 | 安全性 2317 | 舒适性 2318 | 动力 2319 | 配置 2320 | 内饰 2321 | 动力 2322 | 价格 2323 | 操控 2324 | 价格 2325 | 价格 2326 | 空间 2327 | 动力 2328 | 操控 2329 | 操控 2330 | 动力 2331 | 舒适性 2332 | 舒适性 2333 | 动力 2334 | 安全性 2335 | 空间 2336 | 外观 2337 | 动力 2338 | 动力 2339 | 舒适性 2340 | 外观 2341 | 舒适性 2342 | 操控 2343 | 动力 2344 | 内饰 2345 | 配置 2346 | 外观 2347 | 外观 2348 | 操控 2349 | 油耗 2350 | 动力 2351 | 舒适性 2352 | 油耗 2353 | 操控 2354 | 配置 2355 | 油耗 2356 | 价格 2357 | 配置 2358 | 动力 2359 | 油耗 2360 | 安全性 2361 | 动力 2362 | 动力 2363 | 舒适性 2364 | 动力 2365 | 价格 2366 | -------------------------------------------------------------------------------- /data/cnews_loader.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import sys 4 | from collections import Counter 5 | 6 | import numpy as np 7 | import tensorflow.contrib.keras as kr 8 | import tensorflow as tf 9 | 10 | if sys.version_info[0] > 2: 11 | is_py3 = True 12 | else: 13 | # reload(sys) 14 | sys.setdefaultencoding("utf-8") 15 | is_py3 = False 16 | 17 | 18 | def native_word(word, encoding='utf-8'): 19 | """如果在python2下面使用python3训练的模型,可考虑调用此函数转化一下字符编码""" 20 | if not is_py3: 21 | return word.encode(encoding) 22 | else: 23 | return word 24 | 25 | 26 | def native_content(content): 27 | if not is_py3: 28 | return content.decode('utf-8') 29 | else: 30 | return content 31 | 32 | 33 | def open_file(filename, mode='r'): 34 | """ 35 | 常用文件操作,可在python2和python3间切换. 36 | mode: 'r' or 'w' for read or write 37 | """ 38 | if is_py3: 39 | return open(filename, mode, encoding='utf-8', errors='ignore') 40 | else: 41 | return open(filename, mode) 42 | 43 | 44 | def read_file(filename): 45 | """读取文件数据""" 46 | contents, labels = [], [] 47 | with open_file(filename) as f: 48 | for line in f: 49 | try: 50 | label, content = line.strip().split('\t') 51 | contents.append(content) 52 | if content: 53 | 54 | labels.append(native_content(label)) 55 | except: 56 | pass 57 | return contents, labels 58 | 59 | def read_file_nolabel(filename): 60 | """读取文件数据""" 61 | contents = [] 62 | with open_file(filename) as f: 63 | for line in f: 64 | try: 65 | content = line.strip() 66 | contents.append(content) 67 | # if content: 68 | # contents.append(list(native_content(content))) 69 | # labels.append(native_content(label)) 70 | except: 71 | pass 72 | return contents 73 | 74 | def build_vocab(train_dir, vocab_dir, vocab_size=5000): 75 | """根据训练集构建词汇表,存储, x, y""" 76 | data_train, _ = read_file(train_dir) 77 | 78 | all_data = [] 79 | for content in data_train: 80 | all_data.extend(content) 81 | 82 | # with open('./data/word.txt', 'w') as out: 83 | # for i in range(len(all_data)): 84 | # out.write(all_data[i] + ' ') 85 | counter = Counter(all_data) 86 | count_pairs = counter.most_common(vocab_size - 1) 87 | words, _ = list(zip(*count_pairs)) 88 | # 添加一个 来将所有文本pad为同一长度 89 | words = [''] + list(words) 90 | open_file(vocab_dir, mode='w').write('\n'.join(words) + '\n') 91 | 92 | def load_word2vec_embedding(word_embedding_file, vocab_size, embedding_dim): 93 | ''' 94 | 加载外接的词向量。 95 | :return: 96 | ''' 97 | print ('loading word embedding, it will take few minutes...') 98 | embeddings = np.random.uniform(-1,1,(vocab_size, embedding_dim)) # 4223, 300 99 | # 保证每次随机出来的数一样。 100 | rng = np.random.RandomState(23455) 101 | unknown = np.asarray(rng.normal(size=(embedding_dim))) # 300 102 | # padding = np.asarray(rng.normal(size=(embedding_dim))) 103 | 104 | f = open(word_embedding_file) 105 | for index, line in enumerate(f): 106 | values = line.split() 107 | try: 108 | coefs = np.asarray(values[1:], dtype='float32') # 取向量 109 | except ValueError: 110 | # 如果真的这个词出现在了训练数据里,这么做就会有潜在的bug。那coefs的值就是上一轮的值。 111 | print (values[0], values[1:]) 112 | 113 | embeddings[index] = coefs # 将词和对应的向量存到字典里 114 | f.close() 115 | 116 | 117 | # 顺序不能错,这个和unkown_id和padding id需要一一对应。 118 | # embeddings[-2] = unknown 119 | # embeddings[-1] = unknown 120 | 121 | return tf.get_variable("embeddings", dtype=tf.float32, 122 | shape=[vocab_size, embedding_dim], 123 | initializer=tf.constant_initializer(embeddings), trainable=False) 124 | 125 | def read_vocab(vocab_dir): 126 | """读取词汇表""" 127 | # words = open_file(vocab_dir).read().strip().split('\n') 128 | with open_file(vocab_dir) as fp: 129 | # 如果是py2 则每个值都转化为unicode 130 | words = [native_content(_.strip()) for _ in fp.readlines()] 131 | word_to_id = dict(zip(words, range(len(words)))) 132 | return words, word_to_id 133 | 134 | 135 | def read_category(): 136 | """读取分类目录,固定""" 137 | categories = ['价格', '动力', '油耗', '操控', '配置', '舒适性', '安全性', '内饰', '外观', '空间'] 138 | 139 | categories = [native_content(x) for x in categories] 140 | 141 | cat_to_id = dict(zip(categories, range(len(categories)))) 142 | 143 | return categories, cat_to_id 144 | 145 | 146 | def to_words(content, words): 147 | """将id表示的内容转换为文字""" 148 | return ''.join(words[x] for x in content) 149 | 150 | 151 | # def process_file(filename, word_to_id, cat_to_id, max_length=600): 152 | def process_file(filename, cat_to_id): 153 | """将文件转换为id表示""" 154 | contents, labels = read_file(filename) 155 | 156 | # data_id, label_id = [], [] 157 | label_id = [] 158 | for i in range(len(contents)): 159 | # data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id]) 160 | label_id.append(cat_to_id[labels[i]]) 161 | 162 | # 使用keras提供的pad_sequences来将文本pad为固定长度 163 | # x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length) 164 | y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id)) # 将标签转换为one-hot表示 165 | 166 | return contents, y_pad 167 | 168 | def process_file_nolabel(filename, word_to_id, max_length=600): 169 | """将文件转换为id表示""" 170 | contents = read_file_nolabel(filename) 171 | 172 | data_id = [] 173 | for i in range(len(contents)): 174 | data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id]) 175 | # label_id.append(cat_to_id[labels[i]]) 176 | 177 | # 使用keras提供的pad_sequences来将文本pad为固定长度 178 | x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length) 179 | # y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id)) # 将标签转换为one-hot表示 180 | 181 | # return x_pad 182 | return contents 183 | 184 | 185 | def batch_iter(x, y, batch_size=64): 186 | """生成批次数据""" 187 | data_len = len(x) 188 | num_batch = int((data_len - 1) / batch_size) + 1 189 | # 区别在于shuffle直接在原来的数组上进行操作,改变原来数组的顺序,无返回值。 190 | # 而permutation不直接在原来的数组上进行操作,而是返回一个新的打乱顺序的数组,并不改变原来的数组。 191 | # indices = np.random.permutation(np.arange(data_len)) 192 | # x_shuffle = x[indices] 193 | # y_shuffle = y[indices] 194 | for i in range(num_batch): 195 | start_id = i * batch_size 196 | end_id = min((i + 1) * batch_size, data_len) 197 | yield x[start_id:end_id], y[start_id:end_id] 198 | 199 | def attention(inputs, attention_size, l2_reg_lambda): 200 | """ 201 | Attention mechanism layer. 202 | :param inputs: outputs of RNN/Bi-RNN layer (not final state) 203 | :param attention_size: linear size of attention weights 204 | :return: outputs of the passed RNN/Bi-RNN reduced with attention vector 205 | """ 206 | # In case of Bi-RNN input we need to concatenate outputs of its forward and backward parts 207 | if isinstance(inputs, tuple): 208 | inputs = tf.concat(2, inputs) 209 | 210 | sequence_length = inputs.get_shape()[1].value # the length of sequences processed in the antecedent RNN layer 211 | hidden_size = inputs.get_shape()[2].value # hidden size of the RNN layer 212 | 213 | # Attention mechanism 214 | W_omega = tf.get_variable("W_omega", initializer=tf.random_normal([hidden_size, attention_size], stddev=0.1)) 215 | b_omega = tf.get_variable("b_omega", initializer=tf.random_normal([attention_size], stddev=0.1)) 216 | u_omega = tf.get_variable("u_omega", initializer=tf.random_normal([attention_size], stddev=0.1)) 217 | 218 | v = tf.tanh(tf.matmul(tf.reshape(inputs, [-1, hidden_size]), W_omega) + tf.reshape(b_omega, [1, -1])) 219 | vu = tf.matmul(v, tf.reshape(u_omega, [-1, 1])) 220 | exps = tf.reshape(tf.exp(vu), [-1, sequence_length]) 221 | alphas = exps / tf.reshape(tf.reduce_sum(exps, 1), [-1, 1]) 222 | 223 | # Output of Bi-RNN is reduced with attention vector 224 | output = tf.reduce_sum(inputs * tf.reshape(alphas, [-1, sequence_length, 1]), 1) 225 | #if l2_reg_lambda > 0: 226 | # l2_loss += tf.nn.l2_loss(W_omega) 227 | # l2_loss += tf.nn.l2_loss(b_omega) 228 | # l2_loss += tf.nn.l2_loss(u_omega) 229 | # tf.add_to_collection('losses', l2_loss) 230 | 231 | return output -------------------------------------------------------------------------------- /data/new_para.txt: -------------------------------------------------------------------------------- 1 | 价格 2 | 价格 3 | 价格 4 | 价格 5 | 价格 6 | 价格 7 | 价格 8 | 价格 9 | 价格 10 | 配置 11 | 配置 12 | 配置 13 | 配置 14 | 配置 15 | 配置 16 | 舒适性 17 | 舒适性 18 | 配置 19 | 油耗 20 | 油耗 21 | 油耗 22 | 油耗 23 | 动力 24 | 油耗 25 | 油耗 26 | 油耗 27 | 舒适性 28 | 动力 29 | 动力 30 | 内饰 31 | 内饰 32 | 内饰 33 | 安全性 34 | 安全性 35 | 安全性 36 | 安全性 37 | 安全性 38 | 安全性 39 | 安全性 40 | 安全性 41 | 安全性 42 | 操控 43 | 操控 44 | 操控 45 | 动力 46 | 动力 47 | 动力 48 | 动力 49 | 动力 50 | 动力 51 | 动力 52 | 动力 53 | 动力 54 | 动力 55 | 动力 56 | 动力 57 | 动力 58 | 动力 59 | 动力 60 | 动力 61 | 动力 62 | 动力 63 | 动力 64 | 空间 65 | 空间 66 | 外观 67 | 安全性 68 | 价格 69 | 舒适性 70 | 舒适性 71 | 油耗 72 | 油耗 73 | 操控 74 | 外观 75 | 空间 76 | 操控 77 | 油耗 78 | 配置 79 | 外观 80 | 油耗 81 | 动力 82 | 操控 83 | 舒适性 84 | 油耗 85 | 油耗 86 | 油耗 87 | 安全性 88 | 价格 89 | 安全性 90 | 油耗 91 | 价格 92 | 油耗 93 | 动力 94 | 动力 95 | 动力 96 | 安全性 97 | 操控 98 | 配置 99 | 舒适性 100 | 舒适性 101 | 操控 102 | 空间 103 | 舒适性 104 | 动力 105 | 价格 106 | 动力 107 | 价格 108 | 安全性 109 | 价格 110 | 外观 111 | 内饰 112 | 价格 113 | 动力 114 | 配置 115 | 配置 116 | 油耗 117 | 配置 118 | 配置 119 | 配置 120 | 油耗 121 | 动力 122 | 舒适性 123 | 舒适性 124 | 价格 125 | 内饰 126 | 价格 127 | 舒适性 128 | 动力 129 | 舒适性 130 | 动力 131 | 外观 132 | 内饰 133 | 油耗 134 | 动力 135 | 油耗 136 | 内饰 137 | 外观 138 | 价格 139 | 内饰 140 | 价格 141 | 动力 142 | 动力 143 | 内饰 144 | 动力 145 | 舒适性 146 | 舒适性 147 | 动力 148 | 外观 149 | 空间 150 | 空间 151 | 操控 152 | 动力 153 | 安全性 154 | 外观 155 | 舒适性 156 | 动力 157 | 油耗 158 | 油耗 159 | 价格 160 | 动力 161 | 舒适性 162 | 油耗 163 | 动力 164 | 动力 165 | 价格 166 | 外观 167 | 动力 168 | 油耗 169 | 内饰 170 | 动力 171 | 舒适性 172 | 动力 173 | 动力 174 | 安全性 175 | 动力 176 | 动力 177 | 内饰 178 | 价格 179 | 外观 180 | 价格 181 | 油耗 182 | 外观 183 | 外观 184 | 内饰 185 | 动力 186 | 操控 187 | 舒适性 188 | 配置 189 | 安全性 190 | 舒适性 191 | 动力 192 | 舒适性 193 | 舒适性 194 | 配置 195 | 空间 196 | 动力 197 | 动力 198 | 价格 199 | 动力 200 | 安全性 201 | 动力 202 | 价格 203 | 空间 204 | 舒适性 205 | 动力 206 | 动力 207 | 动力 208 | 动力 209 | 配置 210 | 空间 211 | 油耗 212 | 空间 213 | 油耗 214 | 配置 215 | 操控 216 | 动力 217 | 动力 218 | 动力 219 | 外观 220 | 动力 221 | 动力 222 | 价格 223 | 舒适性 224 | 空间 225 | 空间 226 | 空间 227 | 内饰 228 | 内饰 229 | 动力 230 | 配置 231 | 动力 232 | 外观 233 | 动力 234 | 配置 235 | 外观 236 | 油耗 237 | 配置 238 | 价格 239 | 操控 240 | 内饰 241 | 内饰 242 | 内饰 243 | 内饰 244 | 内饰 245 | 空间 246 | 操控 247 | 价格 248 | 价格 249 | 价格 250 | 动力 251 | 动力 252 | 价格 253 | 外观 254 | 动力 255 | 动力 256 | 动力 257 | 动力 258 | 动力 259 | 动力 260 | 动力 261 | 动力 262 | 操控 263 | 动力 264 | 动力 265 | 动力 266 | 动力 267 | 动力 268 | 动力 269 | 动力 270 | 动力 271 | 油耗 272 | 油耗 273 | 动力 274 | 动力 275 | 动力 276 | 动力 277 | 动力 278 | 动力 279 | 操控 280 | 动力 281 | 油耗 282 | 动力 283 | 油耗 284 | 油耗 285 | 安全性 286 | 操控 287 | 舒适性 288 | 外观 289 | 动力 290 | 动力 291 | 动力 292 | 动力 293 | 动力 294 | 动力 295 | 动力 296 | 动力 297 | 安全性 298 | 动力 299 | 动力 300 | 舒适性 301 | 舒适性 302 | 舒适性 303 | 安全性 304 | 动力 305 | 空间 306 | 舒适性 307 | 空间 308 | 空间 309 | 内饰 310 | 内饰 311 | 舒适性 312 | 安全性 313 | 舒适性 314 | 安全性 315 | 内饰 316 | 配置 317 | 动力 318 | 油耗 319 | 舒适性 320 | 动力 321 | 动力 322 | 空间 323 | 舒适性 324 | 动力 325 | 动力 326 | 油耗 327 | 价格 328 | 价格 329 | 动力 330 | 动力 331 | 安全性 332 | 动力 333 | 安全性 334 | 配置 335 | 动力 336 | 舒适性 337 | 动力 338 | 空间 339 | 空间 340 | 动力 341 | 油耗 342 | 舒适性 343 | 油耗 344 | 操控 345 | 动力 346 | 动力 347 | 动力 348 | 动力 349 | 动力 350 | 价格 351 | 动力 352 | 动力 353 | 价格 354 | 动力 355 | 价格 356 | 舒适性 357 | 动力 358 | 动力 359 | 操控 360 | 动力 361 | 操控 362 | 价格 363 | 动力 364 | 动力 365 | 操控 366 | 价格 367 | 油耗 368 | 动力 369 | 油耗 370 | 舒适性 371 | 油耗 372 | 操控 373 | 动力 374 | 安全性 375 | 动力 376 | 价格 377 | 舒适性 378 | 价格 379 | 动力 380 | 价格 381 | 动力 382 | 价格 383 | 动力 384 | 内饰 385 | 动力 386 | 价格 387 | 油耗 388 | 价格 389 | 动力 390 | 外观 391 | 油耗 392 | 价格 393 | 动力 394 | 操控 395 | 油耗 396 | 价格 397 | 油耗 398 | 动力 399 | 价格 400 | 油耗 401 | 油耗 402 | 油耗 403 | 内饰 404 | 油耗 405 | 油耗 406 | 油耗 407 | 价格 408 | 动力 409 | 操控 410 | 安全性 411 | 操控 412 | 配置 413 | 动力 414 | 价格 415 | 价格 416 | 动力 417 | 动力 418 | 油耗 419 | 油耗 420 | 安全性 421 | 配置 422 | 安全性 423 | 油耗 424 | 动力 425 | 油耗 426 | 价格 427 | 动力 428 | 动力 429 | 内饰 430 | 安全性 431 | 油耗 432 | 价格 433 | 动力 434 | 动力 435 | 安全性 436 | 操控 437 | 油耗 438 | 价格 439 | 价格 440 | 油耗 441 | 油耗 442 | 外观 443 | 动力 444 | 操控 445 | 价格 446 | 外观 447 | 动力 448 | 舒适性 449 | 舒适性 450 | 价格 451 | 外观 452 | 舒适性 453 | 价格 454 | 动力 455 | 外观 456 | 油耗 457 | 舒适性 458 | 舒适性 459 | 操控 460 | 油耗 461 | 价格 462 | 动力 463 | 安全性 464 | 操控 465 | 动力 466 | 动力 467 | 价格 468 | 舒适性 469 | 动力 470 | 操控 471 | 动力 472 | 动力 473 | 内饰 474 | 油耗 475 | 外观 476 | 油耗 477 | 空间 478 | 动力 479 | 油耗 480 | 舒适性 481 | 价格 482 | 动力 483 | 动力 484 | 外观 485 | 动力 486 | 动力 487 | 动力 488 | 外观 489 | 动力 490 | 价格 491 | 动力 492 | 油耗 493 | 动力 494 | 动力 495 | 舒适性 496 | 外观 497 | 价格 498 | 内饰 499 | 动力 500 | 价格 501 | 外观 502 | 动力 503 | 价格 504 | 内饰 505 | 动力 506 | 动力 507 | 动力 508 | 空间 509 | 动力 510 | 价格 511 | 配置 512 | 舒适性 513 | 油耗 514 | 动力 515 | 动力 516 | 油耗 517 | 配置 518 | 价格 519 | 油耗 520 | 油耗 521 | 动力 522 | 外观 523 | 价格 524 | 配置 525 | 动力 526 | 舒适性 527 | 动力 528 | 价格 529 | 价格 530 | 价格 531 | 舒适性 532 | 价格 533 | 舒适性 534 | 价格 535 | 价格 536 | 动力 537 | 动力 538 | 价格 539 | 动力 540 | 操控 541 | 配置 542 | 动力 543 | 价格 544 | 空间 545 | 舒适性 546 | 舒适性 547 | 内饰 548 | 操控 549 | 内饰 550 | 空间 551 | 价格 552 | 价格 553 | 外观 554 | 舒适性 555 | 油耗 556 | 外观 557 | 价格 558 | 配置 559 | 舒适性 560 | 油耗 561 | 动力 562 | 油耗 563 | 配置 564 | 动力 565 | 油耗 566 | 价格 567 | 动力 568 | 空间 569 | 配置 570 | 油耗 571 | 动力 572 | 价格 573 | 动力 574 | 动力 575 | 动力 576 | 舒适性 577 | 动力 578 | 油耗 579 | 油耗 580 | 动力 581 | 动力 582 | 动力 583 | 价格 584 | 油耗 585 | 安全性 586 | 空间 587 | 安全性 588 | 舒适性 589 | 动力 590 | 内饰 591 | 配置 592 | 油耗 593 | 配置 594 | 舒适性 595 | 动力 596 | 动力 597 | 动力 598 | 动力 599 | 动力 600 | 舒适性 601 | 油耗 602 | 价格 603 | 油耗 604 | 价格 605 | 空间 606 | 动力 607 | 油耗 608 | 动力 609 | 价格 610 | 油耗 611 | 动力 612 | 安全性 613 | 操控 614 | 安全性 615 | 配置 616 | 动力 617 | 动力 618 | 油耗 619 | 舒适性 620 | 操控 621 | 动力 622 | 外观 623 | 动力 624 | 配置 625 | 内饰 626 | 配置 627 | 动力 628 | 空间 629 | 价格 630 | 价格 631 | 动力 632 | 安全性 633 | 油耗 634 | 油耗 635 | 外观 636 | 动力 637 | 操控 638 | 价格 639 | 油耗 640 | 安全性 641 | 动力 642 | 配置 643 | 油耗 644 | 价格 645 | 油耗 646 | 动力 647 | 动力 648 | 动力 649 | 舒适性 650 | 内饰 651 | 动力 652 | 动力 653 | 动力 654 | 操控 655 | 操控 656 | 动力 657 | 油耗 658 | 内饰 659 | 空间 660 | 油耗 661 | 动力 662 | 动力 663 | 动力 664 | 动力 665 | 动力 666 | 油耗 667 | 空间 668 | 安全性 669 | 配置 670 | 动力 671 | 动力 672 | 空间 673 | 动力 674 | 空间 675 | 价格 676 | 动力 677 | 动力 678 | 动力 679 | 操控 680 | 外观 681 | 动力 682 | 动力 683 | 动力 684 | 动力 685 | 动力 686 | 动力 687 | 动力 688 | 动力 689 | 动力 690 | 操控 691 | 操控 692 | 动力 693 | 动力 694 | 外观 695 | 油耗 696 | 动力 697 | 油耗 698 | 油耗 699 | 动力 700 | 操控 701 | 空间 702 | 油耗 703 | 价格 704 | 价格 705 | 价格 706 | 价格 707 | 动力 708 | 动力 709 | 内饰 710 | 动力 711 | 价格 712 | 价格 713 | 动力 714 | 操控 715 | 动力 716 | 动力 717 | 油耗 718 | 油耗 719 | 油耗 720 | 动力 721 | 内饰 722 | 动力 723 | 空间 724 | 操控 725 | 操控 726 | 舒适性 727 | 舒适性 728 | 价格 729 | 内饰 730 | 动力 731 | 动力 732 | 操控 733 | 内饰 734 | 外观 735 | 动力 736 | 动力 737 | 安全性 738 | 安全性 739 | 动力 740 | 舒适性 741 | 舒适性 742 | 空间 743 | 安全性 744 | 价格 745 | 舒适性 746 | 安全性 747 | 舒适性 748 | 动力 749 | 价格 750 | 安全性 751 | 舒适性 752 | 动力 753 | 动力 754 | 动力 755 | 油耗 756 | 动力 757 | 空间 758 | 内饰 759 | 舒适性 760 | 动力 761 | 动力 762 | 油耗 763 | 安全性 764 | 安全性 765 | 舒适性 766 | 动力 767 | 配置 768 | 动力 769 | 安全性 770 | 动力 771 | 舒适性 772 | 舒适性 773 | 舒适性 774 | 安全性 775 | 油耗 776 | 外观 777 | 操控 778 | 内饰 779 | 空间 780 | 舒适性 781 | 动力 782 | 内饰 783 | 价格 784 | 动力 785 | 安全性 786 | 价格 787 | 外观 788 | 油耗 789 | 动力 790 | 油耗 791 | 油耗 792 | 油耗 793 | 配置 794 | 动力 795 | 外观 796 | 外观 797 | 价格 798 | 舒适性 799 | 价格 800 | 价格 801 | 价格 802 | 空间 803 | 舒适性 804 | 动力 805 | 操控 806 | 动力 807 | 舒适性 808 | 舒适性 809 | 油耗 810 | 操控 811 | 价格 812 | 价格 813 | 油耗 814 | 价格 815 | 动力 816 | 价格 817 | 动力 818 | 舒适性 819 | 操控 820 | 外观 821 | 价格 822 | 外观 823 | 操控 824 | 动力 825 | 安全性 826 | 动力 827 | 价格 828 | 动力 829 | 油耗 830 | 油耗 831 | 安全性 832 | 外观 833 | 动力 834 | 空间 835 | 外观 836 | 操控 837 | 动力 838 | 配置 839 | 操控 840 | 舒适性 841 | 配置 842 | 舒适性 843 | 操控 844 | 动力 845 | 动力 846 | 动力 847 | 动力 848 | 动力 849 | 动力 850 | 油耗 851 | 外观 852 | 动力 853 | 配置 854 | 动力 855 | 操控 856 | 动力 857 | 安全性 858 | 操控 859 | 舒适性 860 | 动力 861 | 舒适性 862 | 油耗 863 | 价格 864 | 价格 865 | 操控 866 | 配置 867 | 动力 868 | 动力 869 | 配置 870 | 配置 871 | 动力 872 | 动力 873 | 价格 874 | 价格 875 | 价格 876 | 动力 877 | 油耗 878 | 内饰 879 | 安全性 880 | 价格 881 | 动力 882 | 内饰 883 | 舒适性 884 | 价格 885 | 安全性 886 | 油耗 887 | 动力 888 | 动力 889 | 舒适性 890 | 配置 891 | 配置 892 | 动力 893 | 动力 894 | 价格 895 | 动力 896 | 价格 897 | 动力 898 | 配置 899 | 价格 900 | 动力 901 | 价格 902 | 配置 903 | 动力 904 | 价格 905 | 内饰 906 | 动力 907 | 操控 908 | 配置 909 | 动力 910 | 油耗 911 | 舒适性 912 | 油耗 913 | 外观 914 | 操控 915 | 动力 916 | 动力 917 | 舒适性 918 | 动力 919 | 安全性 920 | 舒适性 921 | 操控 922 | 油耗 923 | 油耗 924 | 外观 925 | 操控 926 | 内饰 927 | 配置 928 | 空间 929 | 油耗 930 | 动力 931 | 舒适性 932 | 配置 933 | 动力 934 | 安全性 935 | 动力 936 | 油耗 937 | 操控 938 | 内饰 939 | 价格 940 | 操控 941 | 动力 942 | 动力 943 | 操控 944 | 空间 945 | 价格 946 | 价格 947 | 操控 948 | 舒适性 949 | 动力 950 | 安全性 951 | 动力 952 | 外观 953 | 动力 954 | 配置 955 | 价格 956 | 操控 957 | 配置 958 | 动力 959 | 价格 960 | 安全性 961 | 价格 962 | 价格 963 | 动力 964 | 价格 965 | 价格 966 | 外观 967 | 价格 968 | 价格 969 | 操控 970 | 安全性 971 | 操控 972 | 操控 973 | 操控 974 | 操控 975 | 操控 976 | 操控 977 | 空间 978 | 油耗 979 | 油耗 980 | 油耗 981 | 油耗 982 | 油耗 983 | 油耗 984 | 油耗 985 | 油耗 986 | 油耗 987 | 操控 988 | 舒适性 989 | 安全性 990 | 动力 991 | 动力 992 | 动力 993 | 动力 994 | 动力 995 | 动力 996 | 动力 997 | 动力 998 | 操控 999 | 空间 1000 | 外观 1001 | 内饰 1002 | 外观 1003 | 外观 1004 | 配置 1005 | 配置 1006 | 外观 1007 | 外观 1008 | 安全性 1009 | 安全性 1010 | 安全性 1011 | 安全性 1012 | 安全性 1013 | 安全性 1014 | 安全性 1015 | 安全性 1016 | 安全性 1017 | 安全性 1018 | 动力 1019 | 动力 1020 | 动力 1021 | 价格 1022 | 价格 1023 | 油耗 1024 | 外观 1025 | 配置 1026 | 配置 1027 | 配置 1028 | 安全性 1029 | 舒适性 1030 | 动力 1031 | 动力 1032 | 动力 1033 | 动力 1034 | 动力 1035 | 动力 1036 | 动力 1037 | 动力 1038 | 动力 1039 | 动力 1040 | 动力 1041 | 动力 1042 | 动力 1043 | 内饰 1044 | 内饰 1045 | 内饰 1046 | 内饰 1047 | 动力 1048 | 价格 1049 | 价格 1050 | 安全性 1051 | 油耗 1052 | 外观 1053 | 空间 1054 | 空间 1055 | 价格 1056 | 操控 1057 | 配置 1058 | 配置 1059 | 安全性 1060 | 油耗 1061 | 空间 1062 | 舒适性 1063 | 动力 1064 | 价格 1065 | 价格 1066 | 内饰 1067 | 内饰 1068 | 动力 1069 | 动力 1070 | 空间 1071 | 配置 1072 | 外观 1073 | 价格 1074 | 价格 1075 | 油耗 1076 | 动力 1077 | 油耗 1078 | 动力 1079 | 动力 1080 | 动力 1081 | 安全性 1082 | 操控 1083 | 安全性 1084 | 内饰 1085 | 动力 1086 | 配置 1087 | 动力 1088 | 动力 1089 | 配置 1090 | 油耗 1091 | 动力 1092 | 配置 1093 | 油耗 1094 | 动力 1095 | 动力 1096 | 动力 1097 | 外观 1098 | 动力 1099 | 动力 1100 | 动力 1101 | 配置 1102 | 安全性 1103 | 安全性 1104 | 操控 1105 | 操控 1106 | 价格 1107 | 动力 1108 | 操控 1109 | 操控 1110 | 动力 1111 | 价格 1112 | 油耗 1113 | 价格 1114 | 内饰 1115 | 舒适性 1116 | 操控 1117 | 动力 1118 | 空间 1119 | 油耗 1120 | 油耗 1121 | 舒适性 1122 | 动力 1123 | 油耗 1124 | 价格 1125 | 内饰 1126 | 外观 1127 | 动力 1128 | 油耗 1129 | 操控 1130 | 动力 1131 | 操控 1132 | 安全性 1133 | 安全性 1134 | 动力 1135 | 配置 1136 | 操控 1137 | 油耗 1138 | 动力 1139 | 动力 1140 | 动力 1141 | 油耗 1142 | 动力 1143 | 外观 1144 | 动力 1145 | 操控 1146 | 价格 1147 | 操控 1148 | 空间 1149 | 配置 1150 | 配置 1151 | 安全性 1152 | 动力 1153 | 动力 1154 | 动力 1155 | 操控 1156 | 价格 1157 | 操控 1158 | 安全性 1159 | 操控 1160 | 动力 1161 | 动力 1162 | 价格 1163 | 配置 1164 | 内饰 1165 | 空间 1166 | 油耗 1167 | 动力 1168 | 内饰 1169 | 内饰 1170 | 价格 1171 | 动力 1172 | 安全性 1173 | 动力 1174 | 动力 1175 | 安全性 1176 | 空间 1177 | 动力 1178 | 动力 1179 | 内饰 1180 | 操控 1181 | 动力 1182 | 价格 1183 | 动力 1184 | 配置 1185 | 配置 1186 | 空间 1187 | 动力 1188 | 舒适性 1189 | 舒适性 1190 | 配置 1191 | 舒适性 1192 | 内饰 1193 | 价格 1194 | 动力 1195 | 安全性 1196 | 动力 1197 | 价格 1198 | 动力 1199 | 内饰 1200 | 内饰 1201 | 油耗 1202 | 价格 1203 | 操控 1204 | 油耗 1205 | 舒适性 1206 | 油耗 1207 | 外观 1208 | 安全性 1209 | 油耗 1210 | 价格 1211 | 动力 1212 | 动力 1213 | 配置 1214 | 外观 1215 | 动力 1216 | 动力 1217 | 动力 1218 | 动力 1219 | 价格 1220 | 配置 1221 | 安全性 1222 | 操控 1223 | 动力 1224 | 配置 1225 | 内饰 1226 | 操控 1227 | 油耗 1228 | 油耗 1229 | 动力 1230 | 安全性 1231 | 舒适性 1232 | 动力 1233 | 动力 1234 | 安全性 1235 | 价格 1236 | 配置 1237 | 操控 1238 | 动力 1239 | 内饰 1240 | 操控 1241 | 动力 1242 | 动力 1243 | 动力 1244 | 配置 1245 | 安全性 1246 | 操控 1247 | 操控 1248 | 价格 1249 | 价格 1250 | 价格 1251 | 价格 1252 | 价格 1253 | 配置 1254 | 价格 1255 | 配置 1256 | 配置 1257 | 配置 1258 | 操控 1259 | 舒适性 1260 | 舒适性 1261 | 油耗 1262 | 油耗 1263 | 油耗 1264 | 油耗 1265 | 油耗 1266 | 油耗 1267 | 油耗 1268 | 舒适性 1269 | 舒适性 1270 | 内饰 1271 | 内饰 1272 | 内饰 1273 | 安全性 1274 | 安全性 1275 | 安全性 1276 | 动力 1277 | 安全性 1278 | 操控 1279 | 操控 1280 | 动力 1281 | 动力 1282 | 动力 1283 | 动力 1284 | 动力 1285 | 动力 1286 | 动力 1287 | 动力 1288 | 动力 1289 | 空间 1290 | 空间 1291 | 安全性 1292 | 动力 1293 | 舒适性 1294 | 外观 1295 | 外观 1296 | 动力 1297 | 动力 1298 | 油耗 1299 | 油耗 1300 | 动力 1301 | 油耗 1302 | 油耗 1303 | 价格 1304 | 油耗 1305 | 油耗 1306 | 动力 1307 | 油耗 1308 | 动力 1309 | 动力 1310 | 舒适性 1311 | 安全性 1312 | 价格 1313 | 舒适性 1314 | 配置 1315 | 动力 1316 | 外观 1317 | 价格 1318 | 价格 1319 | 价格 1320 | 空间 1321 | 价格 1322 | 价格 1323 | 动力 1324 | 配置 1325 | 价格 1326 | 动力 1327 | 安全性 1328 | 舒适性 1329 | 舒适性 1330 | 操控 1331 | 舒适性 1332 | 油耗 1333 | 操控 1334 | 动力 1335 | 操控 1336 | 动力 1337 | 动力 1338 | 舒适性 1339 | 油耗 1340 | 价格 1341 | 动力 1342 | 价格 1343 | 油耗 1344 | 安全性 1345 | 油耗 1346 | 动力 1347 | 操控 1348 | 动力 1349 | 舒适性 1350 | 舒适性 1351 | 动力 1352 | 价格 1353 | 动力 1354 | 舒适性 1355 | 动力 1356 | 动力 1357 | 油耗 1358 | 动力 1359 | 安全性 1360 | 动力 1361 | 油耗 1362 | 舒适性 1363 | 动力 1364 | 价格 1365 | 动力 1366 | 动力 1367 | 动力 1368 | 舒适性 1369 | 动力 1370 | 价格 1371 | 动力 1372 | 舒适性 1373 | 油耗 1374 | 安全性 1375 | 配置 1376 | 油耗 1377 | 安全性 1378 | 舒适性 1379 | 动力 1380 | 安全性 1381 | 动力 1382 | 动力 1383 | 油耗 1384 | 价格 1385 | 配置 1386 | 油耗 1387 | 配置 1388 | 油耗 1389 | 油耗 1390 | 外观 1391 | 操控 1392 | 配置 1393 | 舒适性 1394 | 价格 1395 | 价格 1396 | 动力 1397 | 价格 1398 | 价格 1399 | 油耗 1400 | 价格 1401 | 外观 1402 | 价格 1403 | 动力 1404 | 空间 1405 | 空间 1406 | 配置 1407 | 价格 1408 | 配置 1409 | 动力 1410 | 内饰 1411 | 内饰 1412 | 外观 1413 | 内饰 1414 | 内饰 1415 | 价格 1416 | 价格 1417 | 价格 1418 | 价格 1419 | 价格 1420 | 价格 1421 | 动力 1422 | 动力 1423 | 动力 1424 | 动力 1425 | 动力 1426 | 操控 1427 | 动力 1428 | 外观 1429 | 外观 1430 | 油耗 1431 | 动力 1432 | 动力 1433 | 动力 1434 | 动力 1435 | 油耗 1436 | 动力 1437 | 动力 1438 | 动力 1439 | 动力 1440 | 油耗 1441 | 动力 1442 | 油耗 1443 | 动力 1444 | 动力 1445 | 动力 1446 | 动力 1447 | 动力 1448 | 动力 1449 | 油耗 1450 | 动力 1451 | 油耗 1452 | 动力 1453 | 动力 1454 | 价格 1455 | 操控 1456 | 操控 1457 | 舒适性 1458 | 舒适性 1459 | 舒适性 1460 | 操控 1461 | 操控 1462 | 动力 1463 | 动力 1464 | 动力 1465 | 安全性 1466 | 动力 1467 | 动力 1468 | 动力 1469 | 动力 1470 | 舒适性 1471 | 舒适性 1472 | 安全性 1473 | 动力 1474 | 油耗 1475 | 动力 1476 | 配置 1477 | 舒适性 1478 | 安全性 1479 | 舒适性 1480 | 操控 1481 | 舒适性 1482 | 舒适性 1483 | 配置 1484 | 安全性 1485 | 动力 1486 | 配置 1487 | 价格 1488 | 油耗 1489 | 价格 1490 | 操控 1491 | 操控 1492 | 舒适性 1493 | 内饰 1494 | 动力 1495 | 价格 1496 | 动力 1497 | 价格 1498 | 动力 1499 | 配置 1500 | 动力 1501 | 动力 1502 | 空间 1503 | 油耗 1504 | 动力 1505 | 动力 1506 | 操控 1507 | 舒适性 1508 | 动力 1509 | 动力 1510 | 油耗 1511 | 动力 1512 | 价格 1513 | 舒适性 1514 | 动力 1515 | 价格 1516 | 动力 1517 | 动力 1518 | 动力 1519 | 价格 1520 | 价格 1521 | 舒适性 1522 | 油耗 1523 | 油耗 1524 | 内饰 1525 | 安全性 1526 | 动力 1527 | 价格 1528 | 操控 1529 | 油耗 1530 | 动力 1531 | 空间 1532 | 动力 1533 | 动力 1534 | 配置 1535 | 操控 1536 | 动力 1537 | 价格 1538 | 舒适性 1539 | 动力 1540 | 外观 1541 | 舒适性 1542 | 配置 1543 | 动力 1544 | 动力 1545 | 动力 1546 | 配置 1547 | 动力 1548 | 价格 1549 | 油耗 1550 | 价格 1551 | 动力 1552 | 动力 1553 | 油耗 1554 | 价格 1555 | 动力 1556 | 油耗 1557 | 配置 1558 | 价格 1559 | 价格 1560 | 价格 1561 | 操控 1562 | 价格 1563 | 价格 1564 | 操控 1565 | 操控 1566 | 油耗 1567 | 油耗 1568 | 动力 1569 | 内饰 1570 | 动力 1571 | 价格 1572 | 配置 1573 | 空间 1574 | 价格 1575 | 空间 1576 | 油耗 1577 | 动力 1578 | 动力 1579 | 动力 1580 | 价格 1581 | 油耗 1582 | 安全性 1583 | 价格 1584 | 油耗 1585 | 动力 1586 | 油耗 1587 | 动力 1588 | 外观 1589 | 油耗 1590 | 动力 1591 | 配置 1592 | 动力 1593 | 动力 1594 | 价格 1595 | 价格 1596 | 油耗 1597 | 价格 1598 | 空间 1599 | 价格 1600 | 油耗 1601 | 配置 1602 | 外观 1603 | 舒适性 1604 | 动力 1605 | 动力 1606 | 价格 1607 | 舒适性 1608 | 内饰 1609 | 舒适性 1610 | 内饰 1611 | 舒适性 1612 | 价格 1613 | 动力 1614 | 价格 1615 | 操控 1616 | 舒适性 1617 | 油耗 1618 | 油耗 1619 | 油耗 1620 | 价格 1621 | 操控 1622 | 油耗 1623 | 动力 1624 | 配置 1625 | 舒适性 1626 | 动力 1627 | 价格 1628 | 配置 1629 | 动力 1630 | 动力 1631 | 舒适性 1632 | 操控 1633 | 价格 1634 | 空间 1635 | 油耗 1636 | 油耗 1637 | 动力 1638 | 动力 1639 | 动力 1640 | 油耗 1641 | 油耗 1642 | 动力 1643 | 动力 1644 | 动力 1645 | 动力 1646 | 空间 1647 | 油耗 1648 | 舒适性 1649 | 动力 1650 | 动力 1651 | 操控 1652 | 动力 1653 | 价格 1654 | 安全性 1655 | 油耗 1656 | 舒适性 1657 | 内饰 1658 | 外观 1659 | 动力 1660 | 舒适性 1661 | 空间 1662 | 价格 1663 | 动力 1664 | 外观 1665 | 操控 1666 | 操控 1667 | 舒适性 1668 | 空间 1669 | 舒适性 1670 | 操控 1671 | 配置 1672 | 配置 1673 | 操控 1674 | 价格 1675 | 舒适性 1676 | 操控 1677 | 动力 1678 | 动力 1679 | 油耗 1680 | 动力 1681 | 油耗 1682 | 价格 1683 | 外观 1684 | 操控 1685 | 操控 1686 | 空间 1687 | 价格 1688 | 价格 1689 | 动力 1690 | 动力 1691 | 价格 1692 | 动力 1693 | 动力 1694 | 动力 1695 | 外观 1696 | 配置 1697 | 价格 1698 | 舒适性 1699 | 价格 1700 | 价格 1701 | 油耗 1702 | 操控 1703 | 价格 1704 | 动力 1705 | 舒适性 1706 | 舒适性 1707 | 动力 1708 | 油耗 1709 | 油耗 1710 | 舒适性 1711 | 价格 1712 | 动力 1713 | 价格 1714 | 价格 1715 | 配置 1716 | 价格 1717 | 油耗 1718 | 配置 1719 | 动力 1720 | 油耗 1721 | 配置 1722 | 配置 1723 | 安全性 1724 | 安全性 1725 | 动力 1726 | 配置 1727 | 舒适性 1728 | 动力 1729 | 油耗 1730 | 动力 1731 | 安全性 1732 | 动力 1733 | 价格 1734 | 价格 1735 | 安全性 1736 | 动力 1737 | 动力 1738 | 配置 1739 | 舒适性 1740 | 动力 1741 | 操控 1742 | 内饰 1743 | 操控 1744 | 价格 1745 | 动力 1746 | 舒适性 1747 | 安全性 1748 | 价格 1749 | 安全性 1750 | 动力 1751 | 舒适性 1752 | 安全性 1753 | 配置 1754 | 油耗 1755 | 动力 1756 | 油耗 1757 | 安全性 1758 | 动力 1759 | 动力 1760 | 安全性 1761 | 动力 1762 | 价格 1763 | 价格 1764 | 操控 1765 | 配置 1766 | 动力 1767 | 安全性 1768 | 外观 1769 | 价格 1770 | 动力 1771 | 动力 1772 | 配置 1773 | 油耗 1774 | 动力 1775 | 外观 1776 | 安全性 1777 | 安全性 1778 | 操控 1779 | 油耗 1780 | 外观 1781 | 安全性 1782 | 价格 1783 | 价格 1784 | 动力 1785 | 动力 1786 | 动力 1787 | 空间 1788 | 配置 1789 | 动力 1790 | 安全性 1791 | 配置 1792 | 外观 1793 | 内饰 1794 | 内饰 1795 | 操控 1796 | 动力 1797 | 价格 1798 | 舒适性 1799 | 动力 1800 | 内饰 1801 | 油耗 1802 | 舒适性 1803 | 操控 1804 | 空间 1805 | 安全性 1806 | 价格 1807 | 油耗 1808 | 安全性 1809 | 油耗 1810 | 价格 1811 | 内饰 1812 | 价格 1813 | 操控 1814 | 动力 1815 | 动力 1816 | 内饰 1817 | 内饰 1818 | 价格 1819 | 动力 1820 | 动力 1821 | 操控 1822 | 价格 1823 | 价格 1824 | 价格 1825 | 动力 1826 | 价格 1827 | 价格 1828 | 动力 1829 | 油耗 1830 | 动力 1831 | 动力 1832 | 价格 1833 | 动力 1834 | 价格 1835 | 价格 1836 | 动力 1837 | 外观 1838 | 价格 1839 | 操控 1840 | 油耗 1841 | 空间 1842 | 价格 1843 | 价格 1844 | 油耗 1845 | 动力 1846 | 动力 1847 | 油耗 1848 | 价格 1849 | 价格 1850 | 操控 1851 | 空间 1852 | 空间 1853 | 舒适性 1854 | 空间 1855 | 空间 1856 | 动力 1857 | 动力 1858 | 油耗 1859 | 价格 1860 | 动力 1861 | 油耗 1862 | 动力 1863 | 动力 1864 | 舒适性 1865 | 动力 1866 | 价格 1867 | 操控 1868 | 价格 1869 | 价格 1870 | 动力 1871 | 价格 1872 | 动力 1873 | 动力 1874 | 操控 1875 | 动力 1876 | 操控 1877 | 内饰 1878 | 动力 1879 | 外观 1880 | 价格 1881 | 动力 1882 | 动力 1883 | 动力 1884 | 油耗 1885 | 外观 1886 | 动力 1887 | 动力 1888 | 价格 1889 | 动力 1890 | 动力 1891 | 动力 1892 | 动力 1893 | 空间 1894 | 操控 1895 | 空间 1896 | 油耗 1897 | 油耗 1898 | 价格 1899 | 动力 1900 | 动力 1901 | 油耗 1902 | 安全性 1903 | 空间 1904 | 价格 1905 | 动力 1906 | 安全性 1907 | 动力 1908 | 舒适性 1909 | 安全性 1910 | 油耗 1911 | 外观 1912 | 油耗 1913 | 空间 1914 | 价格 1915 | 配置 1916 | 舒适性 1917 | 油耗 1918 | 动力 1919 | 舒适性 1920 | 舒适性 1921 | 动力 1922 | 安全性 1923 | 油耗 1924 | 操控 1925 | 动力 1926 | 价格 1927 | 动力 1928 | 安全性 1929 | 安全性 1930 | 价格 1931 | 操控 1932 | 油耗 1933 | 动力 1934 | 动力 1935 | 安全性 1936 | 空间 1937 | 动力 1938 | 动力 1939 | 舒适性 1940 | 动力 1941 | 动力 1942 | 安全性 1943 | 外观 1944 | 动力 1945 | 动力 1946 | 动力 1947 | 油耗 1948 | 舒适性 1949 | 油耗 1950 | 油耗 1951 | 价格 1952 | 动力 1953 | 油耗 1954 | 舒适性 1955 | 舒适性 1956 | 动力 1957 | 舒适性 1958 | 操控 1959 | 配置 1960 | 空间 1961 | 价格 1962 | 舒适性 1963 | 舒适性 1964 | 安全性 1965 | 动力 1966 | 内饰 1967 | 价格 1968 | 价格 1969 | 动力 1970 | 油耗 1971 | 动力 1972 | 动力 1973 | 动力 1974 | 安全性 1975 | 舒适性 1976 | 动力 1977 | 动力 1978 | 动力 1979 | 动力 1980 | 价格 1981 | 价格 1982 | 动力 1983 | 外观 1984 | 油耗 1985 | 动力 1986 | 价格 1987 | 价格 1988 | 舒适性 1989 | 舒适性 1990 | 动力 1991 | 动力 1992 | 动力 1993 | 油耗 1994 | 操控 1995 | 价格 1996 | 动力 1997 | 安全性 1998 | 舒适性 1999 | 动力 2000 | 外观 2001 | 舒适性 2002 | 油耗 2003 | 动力 2004 | 操控 2005 | 外观 2006 | 价格 2007 | 价格 2008 | 价格 2009 | 动力 2010 | 舒适性 2011 | 外观 2012 | 内饰 2013 | 安全性 2014 | 价格 2015 | 外观 2016 | 内饰 2017 | 外观 2018 | 价格 2019 | 舒适性 2020 | 动力 2021 | 安全性 2022 | 操控 2023 | 油耗 2024 | 动力 2025 | 安全性 2026 | 动力 2027 | 动力 2028 | 舒适性 2029 | 价格 2030 | 配置 2031 | 安全性 2032 | 动力 2033 | 舒适性 2034 | 操控 2035 | 空间 2036 | 安全性 2037 | 动力 2038 | 配置 2039 | 舒适性 2040 | 操控 2041 | 安全性 2042 | 动力 2043 | 动力 2044 | 舒适性 2045 | 安全性 2046 | 外观 2047 | 舒适性 2048 | 安全性 2049 | 操控 2050 | 动力 2051 | 安全性 2052 | 价格 2053 | 配置 2054 | 配置 2055 | 配置 2056 | 舒适性 2057 | 操控 2058 | 动力 2059 | 动力 2060 | 油耗 2061 | 操控 2062 | 配置 2063 | 价格 2064 | 动力 2065 | 操控 2066 | 安全性 2067 | 油耗 2068 | 舒适性 2069 | 油耗 2070 | 配置 2071 | 价格 2072 | 舒适性 2073 | 舒适性 2074 | 配置 2075 | 油耗 2076 | 安全性 2077 | 安全性 2078 | 动力 2079 | 动力 2080 | 油耗 2081 | 安全性 2082 | 动力 2083 | 舒适性 2084 | 外观 2085 | 动力 2086 | 动力 2087 | 操控 2088 | 油耗 2089 | 油耗 2090 | 空间 2091 | 动力 2092 | 动力 2093 | 价格 2094 | 外观 2095 | 动力 2096 | 动力 2097 | 油耗 2098 | 舒适性 2099 | 外观 2100 | 配置 2101 | 油耗 2102 | 外观 2103 | 外观 2104 | 舒适性 2105 | 配置 2106 | 动力 2107 | 安全性 2108 | 操控 2109 | 价格 2110 | 外观 2111 | 外观 2112 | 油耗 2113 | 价格 2114 | 操控 2115 | 动力 2116 | 价格 2117 | 价格 2118 | 外观 2119 | 外观 2120 | 价格 2121 | 价格 2122 | 价格 2123 | 操控 2124 | 操控 2125 | 操控 2126 | 操控 2127 | 油耗 2128 | 油耗 2129 | 油耗 2130 | 空间 2131 | 空间 2132 | 空间 2133 | 外观 2134 | 动力 2135 | 动力 2136 | 动力 2137 | 动力 2138 | 动力 2139 | 动力 2140 | 动力 2141 | 动力 2142 | 动力 2143 | 动力 2144 | 动力 2145 | 空间 2146 | 空间 2147 | 外观 2148 | 配置 2149 | 内饰 2150 | 外观 2151 | 操控 2152 | 安全性 2153 | 安全性 2154 | 安全性 2155 | 安全性 2156 | 安全性 2157 | 安全性 2158 | 动力 2159 | 安全性 2160 | 安全性 2161 | 油耗 2162 | 动力 2163 | 配置 2164 | 配置 2165 | 配置 2166 | 配置 2167 | 配置 2168 | 动力 2169 | 动力 2170 | 动力 2171 | 配置 2172 | 空间 2173 | 内饰 2174 | 外观 2175 | 舒适性 2176 | 空间 2177 | 油耗 2178 | 油耗 2179 | 油耗 2180 | 动力 2181 | 舒适性 2182 | 内饰 2183 | 安全性 2184 | 安全性 2185 | 动力 2186 | 舒适性 2187 | 操控 2188 | 配置 2189 | 油耗 2190 | 空间 2191 | 配置 2192 | 配置 2193 | 外观 2194 | 油耗 2195 | 油耗 2196 | 外观 2197 | 价格 2198 | 内饰 2199 | 动力 2200 | 动力 2201 | 内饰 2202 | 动力 2203 | 动力 2204 | 动力 2205 | 价格 2206 | 内饰 2207 | 动力 2208 | 动力 2209 | 舒适性 2210 | 动力 2211 | 油耗 2212 | 安全性 2213 | 动力 2214 | 油耗 2215 | 油耗 2216 | 动力 2217 | 操控 2218 | 配置 2219 | 配置 2220 | 价格 2221 | 动力 2222 | 价格 2223 | 配置 2224 | 外观 2225 | 配置 2226 | 配置 2227 | 内饰 2228 | 油耗 2229 | 油耗 2230 | 油耗 2231 | 内饰 2232 | 外观 2233 | 配置 2234 | 价格 2235 | 动力 2236 | 配置 2237 | 价格 2238 | 操控 2239 | 操控 2240 | 配置 2241 | 配置 2242 | 油耗 2243 | 操控 2244 | 动力 2245 | 舒适性 2246 | 空间 2247 | 内饰 2248 | 舒适性 2249 | 操控 2250 | 价格 2251 | 动力 2252 | 安全性 2253 | 油耗 2254 | 动力 2255 | 安全性 2256 | 动力 2257 | 动力 2258 | 舒适性 2259 | 动力 2260 | 动力 2261 | 舒适性 2262 | 操控 2263 | 动力 2264 | 价格 2265 | 内饰 2266 | 安全性 2267 | 价格 2268 | 舒适性 2269 | 安全性 2270 | 空间 2271 | 舒适性 2272 | 安全性 2273 | 内饰 2274 | 动力 2275 | 空间 2276 | 动力 2277 | 舒适性 2278 | 动力 2279 | 外观 2280 | 配置 2281 | 舒适性 2282 | 油耗 2283 | 安全性 2284 | 操控 2285 | 油耗 2286 | 动力 2287 | 安全性 2288 | 动力 2289 | 舒适性 2290 | 价格 2291 | 动力 2292 | 油耗 2293 | 油耗 2294 | 油耗 2295 | 油耗 2296 | 舒适性 2297 | 动力 2298 | 安全性 2299 | 外观 2300 | 外观 2301 | 动力 2302 | 安全性 2303 | 动力 2304 | 价格 2305 | 动力 2306 | 动力 2307 | 动力 2308 | 操控 2309 | 动力 2310 | 价格 2311 | 操控 2312 | 动力 2313 | 油耗 2314 | 空间 2315 | 安全性 2316 | 舒适性 2317 | 动力 2318 | 配置 2319 | 内饰 2320 | 操控 2321 | 价格 2322 | 操控 2323 | 价格 2324 | 价格 2325 | 空间 2326 | 动力 2327 | 内饰 2328 | 操控 2329 | 动力 2330 | 舒适性 2331 | 安全性 2332 | 操控 2333 | 安全性 2334 | 空间 2335 | 外观 2336 | 动力 2337 | 动力 2338 | 舒适性 2339 | 价格 2340 | 舒适性 2341 | 动力 2342 | 动力 2343 | 内饰 2344 | 配置 2345 | 动力 2346 | 外观 2347 | 价格 2348 | 油耗 2349 | 动力 2350 | 舒适性 2351 | 油耗 2352 | 操控 2353 | 配置 2354 | 油耗 2355 | 价格 2356 | 配置 2357 | 动力 2358 | 动力 2359 | 安全性 2360 | 动力 2361 | 操控 2362 | 配置 2363 | 动力 2364 | 安全性 2365 | -------------------------------------------------------------------------------- /data/test_x.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/data/test_x.npy -------------------------------------------------------------------------------- /data/test_xs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/data/test_xs.npy -------------------------------------------------------------------------------- /data/test_y.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/data/test_y.npy -------------------------------------------------------------------------------- /data/test_ys.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/data/test_ys.npy -------------------------------------------------------------------------------- /data/train_x.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/data/train_x.npy -------------------------------------------------------------------------------- /data/train_xs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/data/train_xs.npy -------------------------------------------------------------------------------- /data/train_y.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/data/train_y.npy -------------------------------------------------------------------------------- /data/train_ys.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/data/train_ys.npy -------------------------------------------------------------------------------- /doc_classfier_bert.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | import numpy as np 3 | import pandas as pd 4 | import pickle 5 | import time 6 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer 7 | from sklearn.linear_model import LogisticRegression 8 | from doc_textLoad import * 9 | from sklearn.metrics import classification_report 10 | from sklearn.naive_bayes import MultinomialNB 11 | from sklearn import tree 12 | from sklearn import svm 13 | from sklearn.ensemble import RandomForestClassifier 14 | from sklearn.model_selection import cross_val_score 15 | from service.client import BertClient 16 | 17 | classfier = 'LR' 18 | feature = 'Tfidf' 19 | folds = 6 20 | def main(): 21 | # 汽车观点提取,分过词的数据 22 | # df_train = pd.read_csv('./data/ccf_car_train.csv') 23 | # df_test = pd.read_csv('./data/ccf_car_test.csv') 24 | # train = df_train['word_seg'] 25 | # test_x = df_test['content'] 26 | # target = df_train['class'] 27 | # train = np.array(train) 28 | # target = np.array(target) 29 | # list = [] 30 | # for i in target: 31 | # list.append(i) 32 | bc = BertClient() 33 | train = np.load('./data/train_x.npy') 34 | train = list(train) 35 | train = bc.encode(train) 36 | target = np.load('./data/train_y.npy') 37 | test_x = np.load('./data/test_x.npy') 38 | test_y = np.load('./data/test_y.npy') 39 | 40 | test_x = list(test_x) 41 | test_x = bc.encode(test_x) 42 | # 分为训练集和测试的时候加上 43 | # train, test_x, target, test_y = train_test_split(train, target, test_size = 0.15, random_state = 0) 44 | # np.savetxt('./data/train_x', train_X, fmt='%s') 45 | # np.savetxt('./data/train_y', train_y, fmt='%s') 46 | # np.savetxt('./data/test_X', test_X, fmt='%s') 47 | # np.savetxt('./data/test_y', test_y, fmt='%s') 48 | # df_test = pd.read_csv(test_file) 49 | # df_train = df_train.drop(['article'], axis=1) 50 | # df_test = df_test.drop(['article'], axis=1) 51 | # ngram_range:tuple (min_n, max_n) 要提取的不同n-gram的n值范围的下边界和上边界。 将使用n的所有值,使得min_n <= n <= max_n。 52 | # max_features: 构建一个词汇表,该词汇表仅考虑语料库中按术语频率排序的最高max_features 53 | # if feature == 'Count': 54 | # vectoriser = CountVectorizer(ngram_range=(1, 2), min_df = 3) 55 | # elif feature == 'Tfidf': 56 | # vectoriser = TfidfVectorizer(ngram_range=(1, 5), min_df = 3, max_df = 0.7) 57 | # # 构建特征,先训练 58 | # vectoriser.fit(train) 59 | # 训练完进行归一化 总共有315503个词,过滤小于3的,剩下23546, max_df 貌似没用 60 | # (7957, 2082), (1990, 2082) type:crs_matrix 61 | # train_X = vectoriser.transform(train) 62 | # test_X = vectoriser.transform(test_x) 63 | # y_train = df_train['class'] - 1 64 | # train_X = np.array(train_X.data).reshape(train_X.shape) 65 | 66 | # 开始构建分类器 67 | if classfier == 'LR': 68 | ''' 69 | c:正则化系数λ的倒数,float类型,默认为1.0。必须是正浮点型数。像SVM一样,越小的数值表示越强的正则化 70 | ''' 71 | rg = LogisticRegression(C=1) 72 | rg.fit(train, target) 73 | y_pred = rg.predict(test_x) 74 | # elif classfier == 'NB': 75 | # # 使用默认的配置对分类器进行初始化。 76 | # mnb_count = MultinomialNB(alpha=0.2) 77 | # # 使用朴素贝叶斯分类器,对CountVectorizer(不去除停用词)后的训练样本进行参数学习。 78 | # mnb_count.fit(train_X, target) 79 | # y_pred = mnb_count.predict(test_X) 80 | # 81 | # elif classfier =='tree': 82 | # DT = tree.DecisionTreeClassifier() 83 | # DT.fit(train_X, target) 84 | # y_pred = DT.predict(test_X) 85 | ''' 86 | kernel='linear'时,为线性核,C越大分类效果越好,但有可能会过拟合(defaul C=1)。 87 | 88 |    kernel='rbf'时(default),为高斯核,gamma值越小,分类界面越连续;gamma值越大,分类界面越“散”,分类效果越好,但有可能会过拟合。 89 | ''' 90 | # C=0.1 准确率高召回率低 C = 0.8 91 | # elif classfier =='RT': 92 | # sv = RandomForestClassifier(n_estimators=400) 93 | # sv.fit(train_X, target) 94 | # # y_hat = sv.predict(train_X) 95 | # y_pred = sv.predict(test_X) 96 | # scores = cross_val_score(sv, train_X, target, cv=5) 97 | # print(scores) 98 | # 从sklearn.metrics 导入 classification_report。 99 | 100 | # 输出更加详细的其他评价分类性能的指标。 101 | print('classifier is : ' + classfier + '\tFeature is : ' + feature) 102 | print(classification_report(test_y, y_pred)) 103 | # print(classification_report(target, y_hat)) 104 | 105 | # test 106 | # df_test['subject'] = y_pred 107 | # df_result = df_test.loc[:, ['content_id', 'subject']] 108 | # print('last') 109 | # df_result.to_csv('./data/result.csv', index=False) 110 | 111 | # Second apporache 112 | # dataLoad = TextLoader('./data/', batch_size=128) 113 | # dataLoad_test = Loader_test('./data/', batch_size=1) 114 | # # 普通统计CountVectorizer提取特征向量 115 | # vectoriser = CountVectorizer(ngram_range=(1, 2), max_df=.9, min_df=3, max_features=10000) 116 | # rg = LogisticRegression(C=4, dual=True) 117 | # 118 | # # 训练得到文本的特征表示 119 | # filename = './data/vectoriser.sav' 120 | # if not os.path.exists(filename): 121 | # for i in range(dataLoad.num_batches): 122 | # x, y = dataLoad.next_batch() 123 | # vectoriser.fit(x) 124 | # pickle.dump(vectoriser, open(filename, 'wb')) 125 | # else: 126 | # with open(filename, 'rb') as f: 127 | # vectoriser = pickle.load(f) 128 | # 129 | # # # 训练得到逻辑回归模型 130 | # model_name = './data/model.sav' 131 | # if not os.path.exists(model_name): 132 | # for i in range(dataLoad.num_batches): 133 | # dataLoad.pointer = 0 134 | # x, y = dataLoad.next_batch() 135 | # # 训练完进行归一化 136 | # x_train = vectoriser.transform(x) 137 | # y_train = y - 1 138 | # # 开始构建分类器 139 | # rg.fit(x_train, y_train) 140 | # pickle.dump(rg, open(model_name, 'wb')) 141 | # else: 142 | # with open(model_name, 'rb') as f: 143 | # rg = pickle.load(f) 144 | # # test 145 | # x_test = vectoriser.transform(dataLoad_test.df_test['word_seg']) 146 | # y_test = rg.predict(x_test) 147 | # dataLoad_test.df_test['class'] = y_test.to_list() 148 | # dataLoad_test.df_test['class'] += 1 149 | # df_result = dataLoad_test.df_test.loc[:, ['id', 'class']] 150 | # # print('last') 151 | # df_result.to_csv('./data/result.csv', index=False) 152 | # for i in range(dataLoad_test.num_batches): 153 | # x = dataLoad_test.next_batch() 154 | # x_test = vectoriser.transform(x) 155 | # y_test = rg.predict(x_test) 156 | # dataLoad_test.df_test['class'] = y_test.to_list() 157 | # dataLoad_test.df_test['class'] += 1 158 | # df_result = dataLoad_test.df_test.loc[:, ['id', 'class']] 159 | # df_result.to_csv('./data/result.csv', index=False) 160 | 161 | # Third 162 | # start = time.time() 163 | # file_path = './data/test_set.csv' # 要拆分文件的位置 164 | # reader = pd.read_csv(file_path, chunksize=20000) 165 | # count = 0 166 | # for chunk in reader: 167 | # print('save test_set%s.csv' % count) 168 | # chunk.to_csv('test_set' + str(count) + '.csv', index=0) 169 | # use = time.time() - start 170 | # print('{:.0f}m {:.0f}s ...'.format(use // 60, use % 60)) 171 | # count += 1 172 | # 读取大文件,上面使用 173 | # df_train = read_files(6, 'train_set') 174 | # df_test = read_files(6, 'test_set') 175 | # df_train.drop(columns=['article', 'id'], inplace=True) 176 | # df_test.drop(columns=['article'], inplace=True) 177 | # vectorizer = TfidfVectorizer(ngram_range=(1, 2), min_df=3, max_df=0.9) 178 | # vectorizer.fit(df_train['word_seg']) 179 | # x_train = vectorizer.transform(df_train['word_seg']) 180 | # x_test = vectorizer.transform(df_test['word_seg']) 181 | # y_train = df_train['class'] - 1 182 | 183 | if __name__ == '__main__': 184 | main() 185 | -------------------------------------------------------------------------------- /doc_textLoad.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | 3 | import os 4 | import codecs 5 | import collections 6 | from six.moves import cPickle 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import re 11 | import itertools 12 | ''' 13 | 经常遇到在Python程序运行中得到了一些字符串、列表、字典等数据, 14 | 想要长久的保存下来,方便以后使用,而不是简单的放入内存中关机断电就丢失数据。 15 | ''' 16 | 17 | class TextLoader(): 18 | def __init__(self, data_dir, batch_size): 19 | self.data_dir = data_dir 20 | self.batch_size = batch_size 21 | 22 | train_file = os.path.join(data_dir, "train_set.csv") 23 | # test_file = os.path.join(data_dir, "test_set.csv") 24 | 25 | self.preprocess(train_file) 26 | self.create_batches() 27 | self.reset_batch_pointer() 28 | 29 | def preprocess(self, train_file): 30 | 31 | f = open(train_file) 32 | reader = pd.read_csv(f, sep=',', iterator=True) 33 | loop = True 34 | chunkSize = 10000 35 | chunks = [] 36 | while loop: 37 | try: 38 | chunk = reader.get_chunk(chunkSize) 39 | chunks.append(chunk) 40 | except StopIteration: 41 | loop = False 42 | print("Iteration is stopped.") 43 | df_train = pd.concat(chunks, ignore_index=True) 44 | 45 | 46 | self.train_x = df_train['word_seg'] 47 | self.train_y = df_train['class'] 48 | # self.test = df_test 49 | # 构造语言对,前t-1个词作为输入,t作为label 50 | def create_batches(self): 51 | self.num_batches = int(len(self.train_x) / self.batch_size) 52 | if self.num_batches == 0: 53 | assert False, "Not enough data. Make seq_length and batch_size small." 54 | xdata = np.array(self.train_x[:self.num_batches * self.batch_size]) 55 | ydata = np.array(self.train_y[:self.num_batches * self.batch_size]) 56 | # 直接分成(1464×128) 57 | self.x_batches = np.split(xdata, self.num_batches, 0) 58 | self.y_batches = np.split(ydata, self.num_batches, 0) 59 | 60 | def next_batch(self): 61 | x, y = self.x_batches[self.pointer], self.y_batches[self.pointer] 62 | self.pointer += 1 63 | return x, y 64 | 65 | def reset_batch_pointer(self): 66 | self.pointer = 0 67 | 68 | class Loader_test(): 69 | def __init__(self, data_dir, batch_size): 70 | self.data_dir = data_dir 71 | self.batch_size = batch_size 72 | 73 | # train_file = os.path.join(data_dir, "train_set.csv") 74 | test_file = os.path.join(data_dir, "test_set.csv") 75 | 76 | self.preprocess(test_file) 77 | self.create_batches() 78 | self.reset_batch_pointer() 79 | 80 | def preprocess(self, test_file): 81 | 82 | f = open(test_file) 83 | reader = pd.read_csv(f, sep=',', iterator=True) 84 | loop = True 85 | chunkSize = 10000 86 | chunks = [] 87 | while loop: 88 | try: 89 | chunk = reader.get_chunk(chunkSize) 90 | chunks.append(chunk) 91 | except StopIteration: 92 | loop = False 93 | print("Iteration is stopped.") 94 | self.df_test = pd.concat(chunks, ignore_index=True) 95 | 96 | self.test_x = self.df_test['word_seg'] 97 | # self.train_y = df_train['class'] 98 | # self.test = df_test 99 | # 构造语言对,前t-1个词作为输入,t作为label 100 | def create_batches(self): 101 | self.num_batches = int(len(self.test_x) / self.batch_size) 102 | if self.num_batches == 0: 103 | assert False, "Not enough data. Make seq_length and batch_size small." 104 | xdata = np.array(self.test_x[:self.num_batches * self.batch_size]) 105 | # ydata = np.array(self.train_y[:self.num_batches * self.batch_size]) 106 | # 直接分成(1464×128) 107 | self.x_batches = np.split(xdata, self.num_batches, 0) 108 | # self.y_batches = np.split(ydata, self.num_batches, 0) 109 | 110 | def next_batch(self): 111 | x = self.x_batches[self.pointer] 112 | self.pointer += 1 113 | return x 114 | 115 | def reset_batch_pointer(self): 116 | self.pointer = 0 117 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:1.12.0-gpu-py3 2 | COPY ./ /app 3 | COPY ./docker/entrypoint.sh /app 4 | WORKDIR /app 5 | RUN pip install -r requirements.gpu.txt 6 | ENTRYPOINT ["/app/entrypoint.sh"] 7 | CMD [] 8 | 9 | -------------------------------------------------------------------------------- /docker/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | num_worker=$1 3 | python app.py -num_worker=${num_worker} -model_dir /model 4 | -------------------------------------------------------------------------------- /gpu_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Han Xiao 4 | 5 | import os 6 | from datetime import datetime 7 | from enum import Enum 8 | 9 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 10 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 11 | 12 | IGNORE_PATTERNS = ('data', '*.pyc', 'CVS', '.git', 'tmp', '.svn', '__pycache__', '.gitignore', '.*.yaml') 13 | MODEL_ID = datetime.now().strftime("%m%d-%H%M%S") + ( 14 | os.environ['suffix_model_id'] if 'suffix_model_id' in os.environ else '') 15 | APP_NAME = 'bert' 16 | 17 | 18 | class SummaryType(Enum): 19 | SCALAR = 1 20 | HISTOGRAM = 2 21 | SAMPLED = 3 22 | 23 | 24 | class ModeKeys(Enum): 25 | TRAIN = 1 26 | EVAL = 2 27 | INFER = 3 28 | INTERACT = 4 29 | INIT_LAW_EMBED = 5 30 | BOTTLENECK = 6 31 | COMPETITION = 7 32 | ENSEMBLE = 8 33 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def set_logger(context): 5 | logger = logging.getLogger(context) 6 | logger.setLevel(logging.INFO) 7 | formatter = logging.Formatter( 8 | '%(levelname)-.1s:' + context + ':[%(filename).3s:%(funcName).3s:%(lineno)3d]:%(message)s', datefmt= 9 | '%m-%d %H:%M:%S') 10 | console_handler = logging.StreamHandler() 11 | console_handler.setLevel(logging.INFO) 12 | console_handler.setFormatter(formatter) 13 | logger.handlers = [] 14 | logger.addHandler(console_handler) 15 | return logger 16 | -------------------------------------------------------------------------------- /helper_text/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/helper_text/__init__.py -------------------------------------------------------------------------------- /helper_text/cnews_group.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | 将文本整合到 train、test、val 三个文件中 6 | """ 7 | 8 | import os 9 | 10 | def _read_file(filename): 11 | """读取一个文件并转换为一行""" 12 | with open(filename, 'r', encoding='utf-8') as f: 13 | return f.read().replace('\n', '').replace('\t', '').replace('\u3000', '') 14 | 15 | def save_file(dirname): 16 | """ 17 | 将多个文件整合并存到3个文件中 18 | dirname: 原数据目录 19 | 文件内容格式: 类别\t内容 20 | """ 21 | f_train = open('data/cnews/cnews.train.txt', 'w', encoding='utf-8') 22 | f_test = open('data/cnews/cnews.test.txt', 'w', encoding='utf-8') 23 | f_val = open('data/cnews/cnews.val.txt', 'w', encoding='utf-8') 24 | for category in os.listdir(dirname): # 分类目录 25 | cat_dir = os.path.join(dirname, category) 26 | if not os.path.isdir(cat_dir): 27 | continue 28 | files = os.listdir(cat_dir) 29 | count = 0 30 | for cur_file in files: 31 | filename = os.path.join(cat_dir, cur_file) 32 | content = _read_file(filename) 33 | if count < 5000: 34 | f_train.write(category + '\t' + content + '\n') 35 | elif count < 6000: 36 | f_test.write(category + '\t' + content + '\n') 37 | else: 38 | f_val.write(category + '\t' + content + '\n') 39 | count += 1 40 | 41 | print('Finished:', category) 42 | 43 | f_train.close() 44 | f_test.close() 45 | f_val.close() 46 | 47 | 48 | if __name__ == '__main__': 49 | save_file('data/thucnews') 50 | print(len(open('data/cnews/cnews.train.txt', 'r', encoding='utf-8').readlines())) 51 | print(len(open('data/cnews/cnews.test.txt', 'r', encoding='utf-8').readlines())) 52 | print(len(open('data/cnews/cnews.val.txt', 'r', encoding='utf-8').readlines())) 53 | -------------------------------------------------------------------------------- /helper_text/copy_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # copy MAXCOUNT files from each directory 4 | 5 | MAXCOUNT=6500 6 | 7 | for category in $( ls THUCNews); do 8 | echo item: $category 9 | 10 | dir=THUCNews/$category 11 | newdir=data/thucnews/$category 12 | if [ -d $newdir ]; then 13 | rm -rf $newdir 14 | mkdir $newdir 15 | fi 16 | 17 | COUNTER=1 18 | for i in $(ls $dir); do 19 | cp $dir/$i $newdir 20 | if [ $COUNTER -ge $MAXCOUNT ] 21 | then 22 | echo finished 23 | break 24 | fi 25 | let COUNTER=COUNTER+1 26 | done 27 | 28 | done 29 | -------------------------------------------------------------------------------- /images/acc_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/images/acc_loss.png -------------------------------------------------------------------------------- /images/acc_loss_rnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/images/acc_loss_rnn.png -------------------------------------------------------------------------------- /images/cnn_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/images/cnn_architecture.png -------------------------------------------------------------------------------- /images/rnn_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/images/rnn_architecture.png -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import print_function 4 | 5 | import os 6 | import tensorflow as tf 7 | import tensorflow.contrib.keras as kr 8 | 9 | from cnn_model import TCNNConfig, TextCNN 10 | from data.cnews_loader import read_category, read_vocab 11 | 12 | try: 13 | bool(type(unicode)) 14 | except NameError: 15 | unicode = str 16 | 17 | base_dir = 'data/cnews' 18 | vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') 19 | 20 | save_dir = 'checkpoints/textcnn' 21 | save_path = os.path.join(save_dir, 'best_validation') # 最佳验证结果保存路径 22 | 23 | 24 | class CnnModel: 25 | def __init__(self): 26 | self.config = TCNNConfig() 27 | self.categories, self.cat_to_id = read_category() 28 | self.words, self.word_to_id = read_vocab(vocab_dir) 29 | self.config.vocab_size = len(self.words) 30 | self.model = TextCNN(self.config) 31 | 32 | self.session = tf.Session() 33 | self.session.run(tf.global_variables_initializer()) 34 | saver = tf.train.Saver() 35 | saver.restore(sess=self.session, save_path=save_path) # 读取保存的模型 36 | 37 | def predict(self, message): 38 | # 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行 39 | content = unicode(message) 40 | data = [self.word_to_id[x] for x in content if x in self.word_to_id] 41 | 42 | feed_dict = { 43 | self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length), 44 | self.model.keep_prob: 1.0 45 | } 46 | 47 | y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict) 48 | return self.categories[y_pred_cls[0]] 49 | 50 | 51 | if __name__ == '__main__': 52 | cnn_model = CnnModel() 53 | test_demo = ['三星ST550以全新的拍摄方式超越了以往任何一款数码相机', 54 | '热火vs骑士前瞻:皇帝回乡二番战 东部次席唾手可得新浪体育讯北京时间3月30日7:00'] 55 | for i in test_demo: 56 | print(cnn_model.predict(i)) 57 | -------------------------------------------------------------------------------- /requirements.client.txt: -------------------------------------------------------------------------------- 1 | # client-side requirements, pretty light-weight right? 2 | numpy 3 | pyzmq >= 17.1.0 # python zmq -------------------------------------------------------------------------------- /requirements.gpu.txt: -------------------------------------------------------------------------------- 1 | # server-side requirements 2 | #tensorflow >= 1.11.0 # CPU Version of TensorFlow. 3 | tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow. 4 | GPUtil >= 1.3.0 # no need if you dont have GPU 5 | pyzmq >= 17.1.0 # python zmq 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn 2 | scipy 3 | numpy -------------------------------------------------------------------------------- /rnn_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import tensorflow as tf 5 | from data.cnews_loader import attention 6 | # from service.client import BertClient 7 | 8 | class TRNNConfig(object): 9 | """RNN配置参数""" 10 | 11 | # 模型参数 12 | embedding_dim = 768 # 词向量维度 13 | # embedding_dim = 200 14 | seq_length = 128 # 序列长度 15 | num_classes = 10 # 类别数 16 | vocab_size = 5000 # 词汇表达小 17 | 18 | num_layers= 1 # 隐藏层层数 19 | hidden_dim = 128 # 隐藏层神经元 20 | rnn = 'gru' # lstm 或 gru 21 | 22 | attention_dim = 50 23 | l2_reg_lambda = 0.01 24 | 25 | dropout_keep_prob = 0.5 # dropout保留比例 26 | learning_rate = 1e-3 # 学习率 27 | 28 | batch_size = 128 # 每批训练大小 29 | num_epochs = 20 # 总迭代轮次 30 | 31 | print_per_batch = 100 # 每多少轮输出一次结果 32 | save_per_batch = 20 # 每多少轮存入tensorboard 33 | 34 | 35 | class TextRNN(object): 36 | """文本分类,RNN模型""" 37 | def __init__(self, config): 38 | self.config = config 39 | 40 | # 三个待输入的数据 41 | self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x') 42 | # 词向量 43 | # self.input_x = tf.placeholder(tf.float32, [None, self.config.seq_length, self.config.embedding_dim], name='input_x') 44 | # 句向量 45 | # self.input_x = tf.placeholder(tf.float32, [None, self.config.embedding_dim], name='input_x') 46 | self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y') 47 | self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') 48 | 49 | self.rnn() 50 | 51 | def rnn(self): 52 | """rnn模型""" 53 | 54 | def lstm_cell(): # lstm核 55 | return tf.contrib.rnn.BasicLSTMCell(self.config.hidden_dim, state_is_tuple=True) 56 | 57 | def gru_cell(): # gru核 58 | return tf.contrib.rnn.GRUCell(self.config.hidden_dim) 59 | 60 | def dropout(): # 为每一个rnn核后面加一个dropout层 61 | if (self.config.rnn == 'lstm'): 62 | cell = lstm_cell() 63 | else: 64 | cell = gru_cell() 65 | return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob) 66 | 67 | # 词向量映射 68 | with tf.device('/cpu:0'): 69 | embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim]) 70 | embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x) 71 | 72 | with tf.name_scope("rnn"): 73 | # 多层rnn网络 74 | cells = [dropout() for _ in range(self.config.num_layers)] 75 | # cell = dropout() 76 | rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True) 77 | # (batch_size, num_step, embeddings) (b, 80, 128) 78 | _outputs, _ = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=embedding_inputs, dtype=tf.float32) 79 | 80 | # new 81 | # output = attention(_outputs, self.config.attention_dim, self.config.l2_reg_lambda) 82 | # (b, 128) 83 | last = _outputs[:, -1, :] # 取最后一个时序输出作为结果(b, 128) 84 | 85 | with tf.name_scope("score"): 86 | # bc = BertClient() 87 | # self.input (128) 88 | # output = bc.encode(self.input_x) # (batch_size, 768) 89 | # 全连接层,后面接dropout以及relu激活 90 | fc = tf.layers.dense(last, self.config.hidden_dim, name='fc1') 91 | fc = tf.contrib.layers.dropout(fc, self.keep_prob) 92 | fc = tf.nn.relu(fc) 93 | # fc = tf.layers.dense(fc, 512, name='fc2') 94 | # fc = tf.contrib.layers.dropout(fc, self.keep_prob) 95 | # fc = tf.nn.relu(fc) 96 | # 97 | # fc = tf.layers.dense(fc, 256, name='fc3') 98 | # fc = tf.contrib.layers.dropout(fc, self.keep_prob) 99 | # fc = tf.nn.relu(fc) 100 | # fc = tf.layers.dense(fc, 128, name='fc4') 101 | # fc = tf.contrib.layers.dropout(fc, self.keep_prob) 102 | # fc = tf.nn.relu(fc) 103 | 104 | # 分类器(b, 128)->(b, 10) 105 | self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc5') 106 | self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1) # 预测类别 107 | 108 | with tf.name_scope("optimize"): 109 | # 损失函数,交叉熵 110 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y) 111 | self.loss = tf.reduce_mean(cross_entropy) 112 | # 优化器 113 | self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss) 114 | 115 | with tf.name_scope("accuracy"): 116 | # 准确率 117 | correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls) 118 | self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) 119 | -------------------------------------------------------------------------------- /run_pre.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | import time 8 | from datetime import timedelta 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | from sklearn import metrics 13 | 14 | from rnn_model import TRNNConfig, TextRNN 15 | from lstm import * 16 | from data.cnews_loader import * 17 | from service.client import BertClient 18 | 19 | base_dir = 'data/cnews' 20 | train_dir = os.path.join(base_dir, 'cnews.train.txt') 21 | test_dir = os.path.join(base_dir, 'cnews.test.txt') 22 | # val_dir = os.path.join(base_dir, 'cnews.val.txt') 23 | vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') 24 | 25 | save_dir = 'checkpoints/bert_model_20' 26 | save_path = os.path.join(save_dir, 'best_validation') # 最佳验证结果保存路径 27 | action = 'pre' 28 | 29 | def get_time_dif(start_time): 30 | """获取已使用时间""" 31 | end_time = time.time() 32 | time_dif = end_time - start_time 33 | return timedelta(seconds=int(round(time_dif))) 34 | 35 | 36 | def feed_data(x_batch, y_batch, keep_prob): 37 | feed_dict = { 38 | model.input_x: x_batch, 39 | model.input_y: y_batch, 40 | model.keep_prob: keep_prob 41 | } 42 | return feed_dict 43 | 44 | 45 | def evaluate(sess, x_, y_): 46 | """评估在某一数据上的准确率和损失""" 47 | data_len = len(x_) 48 | batch_eval = batch_iter(x_, y_, 128) 49 | total_loss = 0.0 50 | total_acc = 0.0 51 | for x_batch, y_batch in batch_eval: 52 | batch_len = len(x_batch) 53 | feed_dict = feed_data(x_batch, y_batch, 1.0) 54 | loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict) 55 | total_loss += loss * batch_len 56 | total_acc += acc * batch_len 57 | 58 | return total_loss / data_len, total_acc / data_len 59 | 60 | 61 | def train(): 62 | print("Configuring TensorBoard and Saver...") 63 | # 配置 Tensorboard,重新训练时,请将tensorboard文件夹删除,不然图会覆盖 64 | tensorboard_dir = 'tensorboard/textlstm' 65 | if not os.path.exists(tensorboard_dir): 66 | os.makedirs(tensorboard_dir) 67 | 68 | tf.summary.scalar("loss", model.loss) 69 | tf.summary.scalar("accuracy", model.acc) 70 | merged_summary = tf.summary.merge_all() 71 | writer = tf.summary.FileWriter(tensorboard_dir) 72 | 73 | # 配置 Saver 74 | saver = tf.train.Saver() 75 | if not os.path.exists(save_dir): 76 | os.makedirs(save_dir) 77 | 78 | print("Loading training and validation data...") 79 | # 载入训练集与验证集 80 | start_time = time.time() 81 | x_train, y_train = process_file(train_dir, cat_to_id) 82 | # x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length) 83 | time_dif = get_time_dif(start_time) 84 | print("Time usage:", time_dif) 85 | 86 | # 创建session 87 | session = tf.Session() 88 | session.run(tf.global_variables_initializer()) 89 | writer.add_graph(session.graph) 90 | 91 | print('Training and evaluating...') 92 | start_time = time.time() 93 | total_batch = 0 # 总批次 94 | best_acc_val = 0.0 # 最佳验证集准确率 95 | last_improved = 0 # 记录上一次提升批次 96 | require_improvement = 1000 # 如果超过1000轮未提升,提前结束训练 97 | 98 | # bc = BertClient() 99 | flag = False 100 | for epoch in range(config.num_epochs): 101 | print('Epoch:', epoch + 1) 102 | batch_train = batch_iter(x_train, y_train, config.batch_size) 103 | for x_batch, y_batch in batch_train: 104 | x_batch = bc.encode(x_batch) 105 | feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob) 106 | 107 | if total_batch % config.save_per_batch == 0: 108 | # 每多少轮次将训练结果写入tensorboard scalar 109 | s = session.run(merged_summary, feed_dict=feed_dict) 110 | writer.add_summary(s, total_batch) 111 | 112 | if total_batch % config.print_per_batch == 0: 113 | # 每多少轮次输出在训练集和验证集上的性能 114 | feed_dict[model.keep_prob] = 1.0 115 | loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict) 116 | # loss_val, acc_val = evaluate(session, x_val, y_val) # todo 117 | 118 | # if acc_val > best_acc_val: 119 | # # 保存最好结果 120 | # best_acc_val = acc_val 121 | # last_improved = total_batch 122 | saver.save(sess=session, save_path=save_path) 123 | improved_str = '*' 124 | # else: 125 | # improved_str = '' 126 | 127 | time_dif = get_time_dif(start_time) 128 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \ 129 | + 'Time: {3} {4}' 130 | # print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str)) 131 | print(msg.format(total_batch, loss_train, acc_train, time_dif, improved_str)) 132 | 133 | session.run(model.optim, feed_dict=feed_dict) # 运行优化 134 | total_batch += 1 135 | 136 | # if total_batch - last_improved > require_improvement: 137 | # # 验证集正确率长期不提升,提前结束训练 138 | # print("No optimization for a long time, auto-stopping...") 139 | # flag = True 140 | # break # 跳出循环 141 | # if flag: # 同上 142 | # break 143 | 144 | 145 | def test(): 146 | print("Loading test data...") 147 | start_time = time.time() 148 | x_test = process_file_nolabel(test_dir, word_to_id, config.seq_length) 149 | bc = BertClient() 150 | x_test = bc.encode(x_test) 151 | # (test 2364, 80) 152 | session = tf.Session() 153 | session.run(tf.global_variables_initializer()) 154 | saver = tf.train.Saver() 155 | saver.restore(sess=session, save_path=save_path) # 读取保存的模型 156 | 157 | # print('Testing...') 158 | # loss_test, acc_test = evaluate(session, x_test, y_test) 159 | # msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}' 160 | # print(msg.format(loss_test, acc_test)) 161 | 162 | batch_size = 32 163 | data_len = len(x_test) 164 | num_batch = int((data_len - 1) / batch_size) + 1 165 | 166 | # y_test_cls = np.argmax(y_test, 1) 167 | y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32) # 保存预测结果 168 | for i in range(num_batch): # 逐批次处理 169 | start_id = i * batch_size 170 | end_id = min((i + 1) * batch_size, data_len) 171 | feed_dict = { 172 | model.input_x: x_test[start_id:end_id], 173 | model.keep_prob: 1.0 174 | } 175 | y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict) 176 | 177 | # 评估 178 | print("Precision, Recall and F1-Score...") 179 | # print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories)) 180 | 181 | # 混淆矩阵 182 | print("Confusion Matrix...") 183 | # cm = metrics.confusion_matrix(y_test_cls, y_pred_cls) 184 | # print(cm) 185 | 186 | time_dif = get_time_dif(start_time) 187 | print("Time usage:", time_dif) 188 | return y_pred_cls 189 | 190 | 191 | if __name__ == '__main__': 192 | # if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']: 193 | # raise ValueError("""usage: python run_rnn.py [train / test]""") 194 | 195 | print('Configuring RNN model...') 196 | config = TRNNConfig() 197 | if not os.path.exists(vocab_dir): # 如果不存在词汇表,重建 198 | build_vocab(train_dir, vocab_dir, config.vocab_size) 199 | categories, cat_to_id = read_category() 200 | words, word_to_id = read_vocab(vocab_dir) 201 | config.vocab_size = len(words) 202 | # s = ['我们的你的', '你的收到'] 203 | # x_train, y_train = process_file(train_dir, cat_to_id) 204 | # batch_train = batch_iter(x_train, y_train, 128) 205 | # for x_batch, y_batch in batch_train: 206 | # # x_batch = np.array(x_batch) 207 | # print(x_batch) 208 | # print(s) 209 | 210 | 211 | model = TextRNN(config) 212 | 213 | if action == 'train': 214 | train() 215 | else: 216 | y_pred = test() 217 | y_word = [] 218 | for i in range(len(y_pred)): 219 | y_word.append(list(cat_to_id.keys())[list(cat_to_id.values()).index(y_pred[i])]) 220 | np.savetxt('./data/predict_lstm_early_atten.txt', y_word, fmt='%s') 221 | 222 | -------------------------------------------------------------------------------- /run_rnn.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | import time 8 | from datetime import timedelta 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | # from sklearn import metrics 13 | 14 | from rnn_model import TRNNConfig, TextRNN 15 | from lstm import * 16 | from data.cnews_loader import * 17 | # from service.client import BertClient 18 | 19 | base_dir = 'data/cnews' 20 | train_dir = os.path.join(base_dir, 'cnews.train.txt') 21 | test_dir = os.path.join(base_dir, 'cnews.test.txt') 22 | # val_dir = os.path.join(base_dir, 'cnews.val.txt') 23 | vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') 24 | 25 | save_dir = 'checkpoints/rnn_randomword' 26 | save_path = os.path.join(save_dir, 'best_validation') # 最佳验证结果保存路径 27 | action = 'train' 28 | 29 | def get_time_dif(start_time): 30 | """获取已使用时间""" 31 | end_time = time.time() 32 | time_dif = end_time - start_time 33 | return timedelta(seconds=int(round(time_dif))) 34 | 35 | 36 | def feed_data(x_batch, y_batch, keep_prob): 37 | feed_dict = { 38 | model.input_x: x_batch, 39 | model.input_y: y_batch, 40 | model.keep_prob: keep_prob 41 | } 42 | return feed_dict 43 | 44 | 45 | def evaluate(sess, x_, y_): 46 | """评估在某一数据上的准确率和损失""" 47 | data_len = len(x_) 48 | batch_eval = batch_iter(x_, y_, 32) 49 | total_loss = 0.0 50 | total_acc = 0.0 51 | for x_batch, y_batch in batch_eval: 52 | batch_len = len(x_batch) 53 | feed_dict = feed_data(x_batch, y_batch, 1.0) 54 | loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict) 55 | total_loss += loss * batch_len 56 | total_acc += acc * batch_len 57 | 58 | return total_loss / data_len, total_acc / data_len 59 | 60 | 61 | def train(): 62 | print("Configuring TensorBoard and Saver...") 63 | # 配置 Tensorboard,重新训练时,请将tensorboard文件夹删除,不然图会覆盖 64 | tensorboard_dir = 'tensorboard/rnn_randomword' 65 | if not os.path.exists(tensorboard_dir): 66 | os.makedirs(tensorboard_dir) 67 | 68 | tf.summary.scalar("loss", model.loss) 69 | tf.summary.scalar("accuracy", model.acc) 70 | merged_summary = tf.summary.merge_all() 71 | writer = tf.summary.FileWriter(tensorboard_dir) 72 | 73 | # 配置 Saver 74 | saver = tf.train.Saver() 75 | if not os.path.exists(save_dir): 76 | os.makedirs(save_dir) 77 | 78 | print("Loading training and validation data...") 79 | # 载入训练集与验证集 80 | start_time = time.time() 81 | # x_train, y_train = process_file(train_dir, cat_to_id) 82 | # # x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length) 83 | x_train= np.load('./data/train_x.npy') 84 | # x_train = list(x_train) 85 | y_train = np.load('./data/train_y.npy') 86 | x_val = np.load('./data/test_x.npy') 87 | y_val = np.load('./data/test_y.npy') 88 | time_dif = get_time_dif(start_time) 89 | print("Time usage:", time_dif) 90 | 91 | # 创建session 92 | session = tf.Session() 93 | session.run(tf.global_variables_initializer()) 94 | writer.add_graph(session.graph) 95 | 96 | print('Training and evaluating...') 97 | start_time = time.time() 98 | total_batch = 0 # 总批次 99 | best_acc_val = 0.0 # 最佳验证集准确率 100 | last_improved = 0 # 记录上一次提升批次 101 | require_improvement = 1000 # 如果超过1000轮未提升,提前结束训练 102 | 103 | # bc = BertClient() 104 | # x_val = list(x_val) 105 | # x_val = bc.encode(x_val) 106 | flag = False 107 | for epoch in range(config.num_epochs): 108 | print('Epoch:', epoch + 1) 109 | batch_train = batch_iter(x_train, y_train, config.batch_size) 110 | for x_batch, y_batch in batch_train: 111 | # x_batch = bc.encode(x_batch) 112 | feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob) 113 | 114 | # if total_batch % config.save_per_batch == 0: 115 | # # 每多少轮次将训练结果写入tensorboard scalar 116 | # s = session.run(merged_summary, feed_dict=feed_dict) 117 | # writer.add_summary(s, total_batch) 118 | 119 | if total_batch % config.print_per_batch == 0: 120 | # 每多少轮次输出在训练集和验证集上的性能 121 | feed_dict[model.keep_prob] = 1.0 122 | loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict) 123 | loss_val, acc_val = evaluate(session, x_val, y_val) # todo 124 | 125 | if acc_val > best_acc_val: 126 | # 保存最好结果 127 | best_acc_val = acc_val 128 | last_improved = total_batch 129 | saver.save(sess=session, save_path=save_path) 130 | improved_str = '*' 131 | else: 132 | improved_str = '' 133 | 134 | time_dif = get_time_dif(start_time) 135 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \ 136 | + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}' 137 | print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str)) 138 | # print(msg.format(total_batch, loss_train, acc_train, time_dif, improved_str)) 139 | 140 | session.run(model.optim, feed_dict=feed_dict) # 运行优化 141 | total_batch += 1 142 | 143 | if total_batch - last_improved > require_improvement: 144 | # 验证集正确率长期不提升,提前结束训练 145 | print("No optimization for a long time, auto-stopping...") 146 | flag = True 147 | break # 跳出循环 148 | if flag: # 同上 149 | break 150 | 151 | 152 | def test(): 153 | print("Loading test data...") 154 | start_time = time.time() 155 | x_test = process_file_nolabel(test_dir, word_to_id, config.seq_length) 156 | # bc = BertClient() 157 | # x_test = bc.encode(x_test) 158 | # (test 2364, 80) 159 | session = tf.Session() 160 | session.run(tf.global_variables_initializer()) 161 | saver = tf.train.Saver() 162 | saver.restore(sess=session, save_path=save_path) # 读取保存的模型 163 | 164 | # print('Testing...') 165 | # loss_test, acc_test = evaluate(session, x_test, y_test) 166 | # msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}' 167 | # print(msg.format(loss_test, acc_test)) 168 | 169 | batch_size = 32 170 | data_len = len(x_test) 171 | num_batch = int((data_len - 1) / batch_size) + 1 172 | 173 | # y_test_cls = np.argmax(y_test, 1) 174 | y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32) # 保存预测结果 175 | for i in range(num_batch): # 逐批次处理 176 | start_id = i * batch_size 177 | end_id = min((i + 1) * batch_size, data_len) 178 | feed_dict = { 179 | model.input_x: x_test[start_id:end_id], 180 | model.keep_prob: 1.0 181 | } 182 | y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict) 183 | 184 | # 评估 185 | print("Precision, Recall and F1-Score...") 186 | # print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories)) 187 | 188 | # 混淆矩阵 189 | print("Confusion Matrix...") 190 | # cm = metrics.confusion_matrix(y_test_cls, y_pred_cls) 191 | # print(cm) 192 | 193 | time_dif = get_time_dif(start_time) 194 | print("Time usage:", time_dif) 195 | return y_pred_cls 196 | 197 | 198 | if __name__ == '__main__': 199 | # if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']: 200 | # raise ValueError("""usage: python run_rnn.py [train / test]""") 201 | 202 | print('Configuring RNN model...') 203 | config = TRNNConfig() 204 | if not os.path.exists(vocab_dir): # 如果不存在词汇表,重建 205 | build_vocab(train_dir, vocab_dir, config.vocab_size) 206 | categories, cat_to_id = read_category() 207 | words, word_to_id = read_vocab(vocab_dir) 208 | config.vocab_size = len(words) 209 | # s = ['我们的你的', '你的收到'] 210 | # x_train, y_train = process_file(train_dir, cat_to_id) 211 | # batch_train = batch_iter(x_train, y_train, 128) 212 | # for x_batch, y_batch in batch_train: 213 | # # x_batch = np.array(x_batch) 214 | # print(x_batch) 215 | # print(s) 216 | 217 | 218 | model = TextRNN(config) 219 | 220 | if action == 'train': 221 | train() 222 | else: 223 | y_pred = test() 224 | y_word = [] 225 | for i in range(len(y_pred)): 226 | y_word.append(list(cat_to_id.keys())[list(cat_to_id.values()).index(y_pred[i])]) 227 | np.savetxt('./data/new_para.txt', y_word, fmt='%s') 228 | 229 | -------------------------------------------------------------------------------- /run_rnn_bert.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | import time 8 | from datetime import timedelta 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | from sklearn import metrics 13 | 14 | from rnn_model import TRNNConfig, TextRNN 15 | from lstm import * 16 | from data.cnews_loader import * 17 | from bert_serving.client import BertClient 18 | 19 | base_dir = 'data/cnews' 20 | train_dir = os.path.join(base_dir, 'cnews.train.txt') 21 | test_dir = os.path.join(base_dir, 'cnews.test.txt') 22 | # val_dir = os.path.join(base_dir, 'cnews.val.txt') 23 | vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') 24 | 25 | save_dir = 'checkpoints/bert_model_word' 26 | save_path = os.path.join(save_dir, 'best_validation') # 最佳验证结果保存路径 27 | action = 'train' 28 | 29 | def get_time_dif(start_time): 30 | """获取已使用时间""" 31 | end_time = time.time() 32 | time_dif = end_time - start_time 33 | return timedelta(seconds=int(round(time_dif))) 34 | 35 | 36 | def feed_data(x_batch, y_batch, keep_prob): 37 | feed_dict = { 38 | model.input_x: x_batch, 39 | model.input_y: y_batch, 40 | model.keep_prob: keep_prob 41 | } 42 | return feed_dict 43 | 44 | 45 | def evaluate(sess, x_, y_): 46 | """评估在某一数据上的准确率和损失""" 47 | data_len = len(x_) 48 | batch_eval = batch_iter(x_, y_, 32) 49 | total_loss = 0.0 50 | total_acc = 0.0 51 | for x_batch, y_batch in batch_eval: 52 | batch_len = len(x_batch) 53 | feed_dict = feed_data(x_batch, y_batch, 1.0) 54 | loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict) 55 | total_loss += loss * batch_len 56 | total_acc += acc * batch_len 57 | 58 | return total_loss / data_len, total_acc / data_len 59 | 60 | 61 | def train(): 62 | print("Configuring TensorBoard and Saver...") 63 | # 配置 Tensorboard,重新训练时,请将tensorboard文件夹删除,不然图会覆盖 64 | tensorboard_dir = 'tensorboard/textlstm' 65 | if not os.path.exists(tensorboard_dir): 66 | os.makedirs(tensorboard_dir) 67 | 68 | tf.summary.scalar("loss", model.loss) 69 | tf.summary.scalar("accuracy", model.acc) 70 | merged_summary = tf.summary.merge_all() 71 | writer = tf.summary.FileWriter(tensorboard_dir) 72 | 73 | # 配置 Saver 74 | saver = tf.train.Saver() 75 | if not os.path.exists(save_dir): 76 | os.makedirs(save_dir) 77 | 78 | print("Loading training and validation data...") 79 | # 载入训练集与验证集 80 | start_time = time.time() 81 | # x_train, y_train = process_file(train_dir, cat_to_id) 82 | # # x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length) 83 | x_train= np.load('./data/train_xs.npy') 84 | x_train = list(x_train) 85 | y_train = np.load('./data/train_ys.npy') 86 | x_val = np.load('./data/test_xs.npy') 87 | y_val = np.load('./data/test_ys.npy') 88 | time_dif = get_time_dif(start_time) 89 | print("Time usage:", time_dif) 90 | 91 | # 创建session 92 | session = tf.Session() 93 | session.run(tf.global_variables_initializer()) 94 | writer.add_graph(session.graph) 95 | 96 | print('Training and evaluating...') 97 | start_time = time.time() 98 | total_batch = 0 # 总批次 99 | best_acc_val = 0.0 # 最佳验证集准确率 100 | last_improved = 0 # 记录上一次提升批次 101 | require_improvement = 1000 # 如果超过1000轮未提升,提前结束训练 102 | 103 | bc = BertClient() 104 | x_val = list(x_val) 105 | x_val = bc.encode(x_val) 106 | flag = False 107 | for epoch in range(config.num_epochs): 108 | print('Epoch:', epoch + 1) 109 | batch_train = batch_iter(x_train, y_train, config.batch_size) 110 | for x_batch, y_batch in batch_train: 111 | x_batch = bc.encode(x_batch) 112 | feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob) 113 | session.run(model.optim, feed_dict=feed_dict) # 运行优化 114 | # total_batch += 1 115 | if total_batch % config.save_per_batch == 0: 116 | # 每多少轮次将训练结果写入tensorboard scalar 117 | s = session.run(merged_summary, feed_dict=feed_dict) 118 | writer.add_summary(s, total_batch) 119 | 120 | if total_batch % config.print_per_batch == 0: 121 | # 每多少轮次输出在训练集和验证集上的性能 122 | feed_dict[model.keep_prob] = 1.0 123 | loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict) 124 | loss_val, acc_val = evaluate(session, x_val, y_val) # todo 125 | 126 | if acc_val > best_acc_val: 127 | # 保存最好结果 128 | best_acc_val = acc_val 129 | last_improved = total_batch 130 | saver.save(sess=session, save_path=save_path) 131 | improved_str = '*' 132 | else: 133 | improved_str = '' 134 | 135 | time_dif = get_time_dif(start_time) 136 | msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \ 137 | + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}' 138 | print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str)) 139 | # print(msg.format(total_batch, loss_train, acc_train, time_dif, improved_str)) 140 | 141 | session.run(model.optim, feed_dict=feed_dict) # 运行优化 142 | total_batch += 1 143 | 144 | if total_batch - last_improved > require_improvement: 145 | # 验证集正确率长期不提升,提前结束训练 146 | print("No optimization for a long time, auto-stopping...") 147 | flag = True 148 | break # 跳出循环 149 | if flag: # 同上 150 | break 151 | 152 | 153 | def test(): 154 | print("Loading test data...") 155 | start_time = time.time() 156 | x_test = process_file_nolabel(test_dir, word_to_id, config.seq_length) 157 | x_test = list(x_test) 158 | bc = BertClient() 159 | x_test = bc.encode(x_test) 160 | # (test 2364, 80) 161 | session = tf.Session() 162 | session.run(tf.global_variables_initializer()) 163 | saver = tf.train.Saver() 164 | saver.restore(sess=session, save_path=save_path) # 读取保存的模型 165 | 166 | # print('Testing...') 167 | # loss_test, acc_test = evaluate(session, x_test, y_test) 168 | # msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}' 169 | # print(msg.format(loss_test, acc_test)) 170 | 171 | batch_size = 32 172 | data_len = len(x_test) 173 | num_batch = int((data_len - 1) / batch_size) + 1 174 | 175 | # y_test_cls = np.argmax(y_test, 1) 176 | y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32) # 保存预测结果 177 | for i in range(num_batch): # 逐批次处理 178 | start_id = i * batch_size 179 | end_id = min((i + 1) * batch_size, data_len) 180 | feed_dict = { 181 | model.input_x: x_test[start_id:end_id], 182 | model.keep_prob: 1.0 183 | } 184 | y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict) 185 | 186 | # 评估 187 | print("Precision, Recall and F1-Score...") 188 | # print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories)) 189 | 190 | # 混淆矩阵 191 | print("Confusion Matrix...") 192 | # cm = metrics.confusion_matrix(y_test_cls, y_pred_cls) 193 | # print(cm) 194 | 195 | time_dif = get_time_dif(start_time) 196 | print("Time usage:", time_dif) 197 | return y_pred_cls 198 | 199 | 200 | if __name__ == '__main__': 201 | # if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']: 202 | # raise ValueError("""usage: python run_rnn.py [train / test]""") 203 | 204 | print('Configuring RNN model...') 205 | config = TRNNConfig() 206 | if not os.path.exists(vocab_dir): # 如果不存在词汇表,重建 207 | build_vocab(train_dir, vocab_dir, config.vocab_size) 208 | categories, cat_to_id = read_category() 209 | words, word_to_id = read_vocab(vocab_dir) 210 | config.vocab_size = len(words) 211 | # s = ['我们的你的', '你的收到'] 212 | # x_train, y_train = process_file(train_dir, cat_to_id) 213 | # batch_train = batch_iter(x_train, y_train, 128) 214 | # for x_batch, y_batch in batch_train: 215 | # # x_batch = np.array(x_batch) 216 | # print(x_batch) 217 | # print(s) 218 | 219 | 220 | model = TextRNN(config) 221 | 222 | if action == 'train': 223 | train() 224 | else: 225 | y_pred = test() 226 | y_word = [] 227 | for i in range(len(y_pred)): 228 | y_word.append(list(cat_to_id.keys())[list(cat_to_id.values()).index(y_pred[i])]) 229 | np.savetxt('./data/bert_word.txt', y_word, fmt='%s') 230 | 231 | -------------------------------------------------------------------------------- /service/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/service/__init__.py -------------------------------------------------------------------------------- /service/client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Han Xiao 4 | 5 | import sys 6 | import threading 7 | import uuid 8 | 9 | import numpy as np 10 | import zmq 11 | from zmq.utils import jsonapi 12 | 13 | if sys.version_info >= (3, 0): 14 | _str = str 15 | _buffer = memoryview 16 | _unicode = lambda x: x 17 | else: 18 | # make it compatible for py2 19 | _str = basestring 20 | _buffer = buffer 21 | _unicode = lambda x: [BertClient.force_to_unicode(y) for y in x] 22 | 23 | 24 | class BertClient: 25 | def __init__(self, ip='localhost', port=5555, port_out=5556, 26 | output_fmt='ndarray', show_server_config=False, 27 | identity=None): 28 | self.context = zmq.Context() 29 | self.sender = self.context.socket(zmq.PUSH) 30 | self.identity = identity or str(uuid.uuid4()).encode('ascii') 31 | self.sender.connect('tcp://%s:%d' % (ip, port)) 32 | 33 | self.receiver = self.context.socket(zmq.SUB) 34 | self.receiver.setsockopt(zmq.SUBSCRIBE, self.identity) 35 | self.receiver.connect('tcp://%s:%d' % (ip, port_out)) 36 | 37 | if output_fmt == 'ndarray': 38 | self.formatter = lambda x: x 39 | elif output_fmt == 'list': 40 | self.formatter = lambda x: x.tolist() 41 | else: 42 | raise AttributeError('"output_fmt" must be "ndarray" or "list"') 43 | 44 | if show_server_config: 45 | print('server returns the following config:') 46 | for k, v in self.get_server_config().items(): 47 | print('%30s\t=\t%-30s' % (k, v)) 48 | print('you should NOT see this message multiple times! ' 49 | 'if you see it appears repeatedly, ' 50 | 'consider moving "BertClient()" out of the loop.') 51 | 52 | def send(self, msg): 53 | self.sender.send_multipart([self.identity, msg]) 54 | 55 | def recv(self): 56 | return self.receiver.recv_multipart() 57 | 58 | def recv_ndarray(self): 59 | response = self.recv() 60 | arr_info, arr_val = jsonapi.loads(response[1]), response[2] 61 | X = np.frombuffer(_buffer(arr_val), dtype=arr_info['dtype']) 62 | return self.formatter(X.reshape(arr_info['shape'])) 63 | 64 | def get_server_config(self): 65 | self.send(b'SHOW_CONFIG') 66 | response = self.recv() 67 | return jsonapi.loads(response[1]) 68 | 69 | def encode(self, texts, blocking=True): 70 | if self.is_valid_input(texts): 71 | texts = _unicode(texts) 72 | self.send(jsonapi.dumps(texts)) 73 | return self.recv_ndarray() if blocking else None 74 | else: 75 | raise AttributeError('"texts" must be "List[str]" and non-empty!') 76 | 77 | def listen(self, max_num_batch=None): 78 | forever = max_num_batch is None 79 | cnt = 0 80 | while forever or cnt < max_num_batch: 81 | yield self.recv_ndarray() 82 | cnt += 1 83 | 84 | # experimental, use with caution! 85 | def encode_async(self, batch_generator, max_num_batch=None): 86 | def run(): 87 | cnt = 0 88 | for texts in batch_generator: 89 | self.encode(texts, blocking=False) 90 | cnt += 1 91 | if max_num_batch and cnt == max_num_batch: 92 | break 93 | 94 | t = threading.Thread(target=run) 95 | t.start() 96 | return self.listen(max_num_batch) 97 | 98 | @staticmethod 99 | def is_valid_input(texts): 100 | return isinstance(texts, list) and all(isinstance(s, _str) and s.strip() for s in texts) 101 | 102 | @staticmethod 103 | def force_to_unicode(text): 104 | "If text is unicode, it is returned as is. If it's str, convert it to Unicode using UTF-8 encoding" 105 | return text if isinstance(text, unicode) else text.decode('utf-8') 106 | -------------------------------------------------------------------------------- /service/server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Han Xiao 4 | import multiprocessing 5 | import os 6 | import sys 7 | import threading 8 | import time 9 | import uuid 10 | from collections import defaultdict 11 | from datetime import datetime 12 | from multiprocessing import Process 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | import zmq 17 | from tensorflow.python.estimator.estimator import Estimator 18 | from tensorflow.python.estimator.run_config import RunConfig 19 | from zmq.utils import jsonapi 20 | 21 | from bert import tokenization, modeling 22 | from bert.extract_features import model_fn_builder, convert_lst_to_features 23 | from helper import set_logger 24 | from service.client import BertClient 25 | 26 | 27 | class ServerCommand: 28 | terminate = b'TERMINATION' 29 | show_config = b'SHOW_CONFIG' 30 | new_job = b'REGISTER' 31 | 32 | 33 | class BertServer(threading.Thread): 34 | def __init__(self, args): 35 | super().__init__() 36 | self.logger = set_logger('VENTILATOR') 37 | 38 | self.model_dir = args.model_dir 39 | self.max_seq_len = args.max_seq_len 40 | self.num_worker = args.num_worker 41 | self.max_batch_size = args.max_batch_size 42 | self.port = args.port 43 | self.args = args 44 | self.args_dict = { 45 | 'model_dir': args.model_dir, 46 | 'max_seq_len': args.max_seq_len, 47 | 'num_worker': args.num_worker, 48 | 'max_batch_size': args.max_batch_size, 49 | 'port': args.port, 50 | 'port_out': args.port_out, 51 | 'pooling_layer': args.pooling_layer, 52 | 'pooling_strategy': args.pooling_strategy.value, 53 | 'tensorflow_version': tf.__version__, 54 | 'python_version': sys.version, 55 | 'server_start_time': str(datetime.now()) 56 | } 57 | self.processes = [] 58 | self.context = zmq.Context() 59 | 60 | # frontend facing client 61 | self.frontend = self.context.socket(zmq.PULL) 62 | self.frontend.bind('tcp://*:%d' % self.port) 63 | 64 | # pair connection between frontend and sink 65 | self.sink = self.context.socket(zmq.PAIR) 66 | self.sink.bind('ipc://*') 67 | self.addr_front2sink = self.sink.getsockopt(zmq.LAST_ENDPOINT).decode('ascii') 68 | 69 | # backend facing workers 70 | self.backend = self.context.socket(zmq.PUSH) 71 | self.backend.bind('ipc://*') 72 | self.addr_backend = self.backend.getsockopt(zmq.LAST_ENDPOINT).decode('ascii') 73 | 74 | # start the sink thread 75 | proc_sink = BertSink(self.args, self.addr_front2sink) 76 | proc_sink.start() 77 | self.processes.append(proc_sink) 78 | self.addr_sink = self.sink.recv().decode('ascii') 79 | self.logger.info('frontend-sink ipc: %s' % self.addr_sink) 80 | 81 | def close(self): 82 | self.logger.info('shutting down...') 83 | for p in self.processes: 84 | p.close() 85 | self.frontend.close() 86 | self.backend.close() 87 | self.sink.close() 88 | self.context.term() 89 | self.logger.info('terminated!') 90 | 91 | def run(self): 92 | available_gpus = range(self.num_worker) 93 | run_on_gpu = True 94 | num_req = 0 95 | try: 96 | import GPUtil 97 | available_gpus = GPUtil.getAvailable(limit=self.num_worker) 98 | if len(available_gpus) < self.num_worker: 99 | self.logger.warn('only %d GPU(s) is available, but ask for %d' % (len(available_gpus), self.num_worker)) 100 | except FileNotFoundError: 101 | self.logger.warn('nvidia-smi is missing, often means no gpu found on this machine. ' 102 | 'will run service on cpu instead') 103 | run_on_gpu = False 104 | 105 | # start the backend processes 106 | for i in available_gpus: 107 | process = BertWorker(i, self.args, self.addr_backend, self.addr_sink) 108 | self.processes.append(process) 109 | process.start() 110 | 111 | try: 112 | while True: 113 | client, msg = self.frontend.recv_multipart() 114 | if msg == ServerCommand.show_config: 115 | self.sink.send_multipart([client, msg, 116 | jsonapi.dumps({**{'client': client.decode('ascii'), 117 | 'num_subprocess': len(self.processes), 118 | 'frontend -> backend': self.addr_backend, 119 | 'backend -> sink': self.addr_sink, 120 | 'frontend <-> sink': self.addr_front2sink, 121 | 'server_current_time': str(datetime.now()), 122 | 'run_on_gpu': run_on_gpu, 123 | 'num_request': num_req}, 124 | **self.args_dict})]) 125 | continue 126 | 127 | num_req += 1 128 | client = client + b'#' + str(uuid.uuid4()).encode('ascii') 129 | seqs = jsonapi.loads(msg) 130 | num_seqs = len(seqs) 131 | # tell sink to collect a new job 132 | self.sink.send_multipart([client, ServerCommand.new_job, b'%d' % num_seqs]) 133 | 134 | if num_seqs > self.max_batch_size: 135 | # divide the large batch into small batches 136 | s_idx = 0 137 | while s_idx < num_seqs: 138 | tmp = seqs[s_idx: (s_idx + self.max_batch_size)] 139 | if tmp: 140 | # get the worker with minimum workload 141 | client_partial_id = client + b'@%d' % s_idx 142 | self.backend.send_multipart([client_partial_id, jsonapi.dumps(tmp)]) 143 | s_idx += len(tmp) 144 | else: 145 | self.backend.send_multipart([client, msg]) 146 | except zmq.error.ContextTerminated: 147 | self.logger.error('context is closed!') 148 | 149 | 150 | class BertSink(Process): 151 | def __init__(self, args, front_sink_addr): 152 | super().__init__() 153 | self.port = args.port_out 154 | self.exit_flag = multiprocessing.Event() 155 | self.logger = set_logger('SINK') 156 | self.front_sink_addr = front_sink_addr 157 | 158 | def close(self): 159 | self.logger.info('shutting down...') 160 | self.exit_flag.set() 161 | self.terminate() 162 | self.join() 163 | self.logger.info('terminated!') 164 | 165 | def run(self): 166 | context = zmq.Context() 167 | # receive from workers 168 | receiver = context.socket(zmq.PULL) 169 | receiver.bind('ipc://*') 170 | 171 | frontend = context.socket(zmq.PAIR) 172 | frontend.connect(self.front_sink_addr) 173 | 174 | # publish to client 175 | sender = context.socket(zmq.PUB) 176 | sender.bind('tcp://*:%d' % self.port) 177 | 178 | pending_checksum = defaultdict(int) 179 | pending_result = defaultdict(list) 180 | job_checksum = {} 181 | 182 | poller = zmq.Poller() 183 | poller.register(frontend, zmq.POLLIN) 184 | poller.register(receiver, zmq.POLLIN) 185 | 186 | # send worker receiver address back to frontend 187 | frontend.send(receiver.getsockopt(zmq.LAST_ENDPOINT)) 188 | 189 | try: 190 | while not self.exit_flag.is_set(): 191 | socks = dict(poller.poll()) 192 | if socks.get(receiver) == zmq.POLLIN: 193 | msg = receiver.recv_multipart() 194 | job_id = msg[0] 195 | # parsing the ndarray 196 | arr_info, arr_val = jsonapi.loads(msg[1]), msg[2] 197 | X = np.frombuffer(memoryview(arr_val), dtype=arr_info['dtype']) 198 | X = X.reshape(arr_info['shape']) 199 | job_info = job_id.split(b'@') 200 | job_id = job_info[0] 201 | partial_id = job_info[1] if len(job_info) == 2 else 0 202 | pending_result[job_id].append((X, partial_id)) 203 | pending_checksum[job_id] += X.shape[0] 204 | self.logger.info('collected job %s (%d/%d)' % (job_id, 205 | pending_checksum[job_id], 206 | job_checksum[job_id])) 207 | 208 | # check if there are finished jobs, send it back to workers 209 | finished = [(k, v) for k, v in pending_result.items() if pending_checksum[k] == job_checksum[k]] 210 | for job_info, tmp in finished: 211 | self.logger.info( 212 | 'job %s %d samples are done! sending back to client' % ( 213 | job_info, job_checksum[job_info])) 214 | # re-sort to the original order 215 | tmp = [x[0] for x in sorted(tmp, key=lambda x: x[1])] 216 | client_addr = job_info.split(b'#')[0] 217 | send_ndarray(sender, client_addr, np.concatenate(tmp, axis=0)) 218 | pending_result.pop(job_info) 219 | pending_checksum.pop(job_info) 220 | job_checksum.pop(job_info) 221 | 222 | if socks.get(frontend) == zmq.POLLIN: 223 | job_info, msg_type, msg_info = frontend.recv_multipart() 224 | if msg_type == ServerCommand.new_job: 225 | job_checksum[job_info] = int(msg_info) 226 | self.logger.info('new job %s size: %d is registered!' % (job_info, int(msg_info))) 227 | elif msg_type == ServerCommand.show_config: 228 | sender.send_multipart([job_info, msg_info]) 229 | except zmq.error.ContextTerminated: 230 | self.logger.error('context is closed!') 231 | 232 | 233 | class BertWorker(Process): 234 | def __init__(self, id, args, worker_address, sink_address): 235 | super().__init__() 236 | self.model_dir = args.model_dir 237 | self.config_fp = os.path.join(self.model_dir, 'bert_config.json') 238 | self.checkpoint_fp = os.path.join(self.model_dir, 'bert_model.ckpt') 239 | self.vocab_fp = os.path.join(args.model_dir, 'vocab.txt') 240 | self.tokenizer = tokenization.FullTokenizer(vocab_file=self.vocab_fp) 241 | self.max_seq_len = args.max_seq_len 242 | self.worker_id = id 243 | self.daemon = True 244 | self.model_fn = model_fn_builder( 245 | bert_config=modeling.BertConfig.from_json_file(self.config_fp), 246 | init_checkpoint=self.checkpoint_fp, 247 | pooling_strategy=args.pooling_strategy, 248 | pooling_layer=args.pooling_layer 249 | ) 250 | os.environ['CUDA_VISIBLE_DEVICES'] = str(self.worker_id) 251 | config = tf.ConfigProto() 252 | config.gpu_options.allow_growth = True 253 | config.gpu_options.per_process_gpu_memory_fraction = args.gpu_memory_fraction 254 | self.estimator = Estimator(self.model_fn, config=RunConfig(session_config=config)) 255 | self.exit_flag = multiprocessing.Event() 256 | self.logger = set_logger('WORKER-%d' % self.worker_id) 257 | self.worker_address = worker_address 258 | self.sink_address = sink_address 259 | 260 | def close(self): 261 | self.logger.info('shutting down...') 262 | self.exit_flag.set() 263 | self.terminate() 264 | self.join() 265 | self.logger.info('terminated!') 266 | 267 | def run(self): 268 | context = zmq.Context() 269 | receiver = context.socket(zmq.PULL) 270 | receiver.connect(self.worker_address) 271 | 272 | sink = context.socket(zmq.PUSH) 273 | sink.connect(self.sink_address) 274 | 275 | input_fn = self.input_fn_builder(receiver) 276 | 277 | self.logger.info('ready and listening') 278 | start_t = time.perf_counter() 279 | for r in self.estimator.predict(input_fn, yield_single_examples=False): 280 | # logger.info('new result!') 281 | send_ndarray(sink, r['client_id'], r['encodes']) 282 | time_used = time.perf_counter() - start_t 283 | start_t = time.perf_counter() 284 | self.logger.info('job %s\tsamples: %4d\tdone: %.2fs' % 285 | (r['client_id'], r['encodes'].shape[0], time_used)) 286 | 287 | receiver.close() 288 | sink.close() 289 | context.term() 290 | self.logger.info('terminated!') 291 | 292 | def input_fn_builder(self, worker): 293 | def gen(): 294 | while not self.exit_flag.is_set(): 295 | client_id, msg = worker.recv_multipart() 296 | msg = jsonapi.loads(msg) 297 | self.logger.info('new job %s, size: %d' % (client_id, len(msg))) 298 | if BertClient.is_valid_input(msg): 299 | tmp_f = list(convert_lst_to_features(msg, self.max_seq_len, self.tokenizer)) 300 | yield { 301 | 'client_id': client_id, 302 | 'input_ids': [f.input_ids for f in tmp_f], 303 | 'input_mask': [f.input_mask for f in tmp_f], 304 | 'input_type_ids': [f.input_type_ids for f in tmp_f] 305 | } 306 | else: 307 | self.logger.error('unsupported type of job %s! sending back None' % client_id) 308 | 309 | def input_fn(): 310 | return (tf.data.Dataset.from_generator( 311 | gen, 312 | output_types={'input_ids': tf.int32, 313 | 'input_mask': tf.int32, 314 | 'input_type_ids': tf.int32, 315 | 'client_id': tf.string}, 316 | output_shapes={ 317 | 'client_id': (), 318 | 'input_ids': (None, self.max_seq_len), 319 | 'input_mask': (None, self.max_seq_len), 320 | 'input_type_ids': (None, self.max_seq_len)})) 321 | 322 | return input_fn 323 | 324 | 325 | def send_ndarray(src, dest, X, flags=0, copy=True, track=False): 326 | """send a numpy array with metadata""" 327 | md = dict(dtype=str(X.dtype), shape=X.shape) 328 | return src.send_multipart([dest, jsonapi.dumps(md), X], flags, copy=copy, track=track) 329 | -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1543996656.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1543996656.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544006771.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544006771.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544060846.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544060846.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544086349.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544086349.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544098320.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544098320.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544100290.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544100290.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544100390.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544100390.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544145498.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544145498.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544145532.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544145532.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544145759.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544145759.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544146122.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544146122.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544146235.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544146235.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544260427.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544260427.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544260511.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544260511.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544320131.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544320131.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544882354.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544882354.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544882966.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544882966.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544884075.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544884075.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544884383.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544884383.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544884648.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544884648.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544886089.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544886089.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544924477.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544924477.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544925170.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544925170.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544925237.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544925237.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544925420.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544925420.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544927106.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544927106.ubuntu -------------------------------------------------------------------------------- /tensorboard/textlstm/events.out.tfevents.1544929233.ubuntu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengwei-iie/Bert-TextClassification/d9f79ccab909a5f48082ed7b866293a9caaf2abb/tensorboard/textlstm/events.out.tfevents.1544929233.ubuntu -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | from service.client import BertClient 3 | 4 | bc = BertClient() 5 | s = ['你好。', '嗨!', '今日我本来打算去公园。'] 6 | a = bc.encode(s) 7 | # a = bc.encode(['Today is good.']) 8 | print(a.shape) 9 | # print(s.type) 10 | print(a[0][1].shape) 11 | print(a[0][1]) 12 | print(a[0][4]) 13 | print(a[1][3]) 14 | print(a[0][5]) 15 | --------------------------------------------------------------------------------