├── README.md
└── bert_of_theseus.py

/README.md:
--------------------------------------------------------------------------------
# bert-of-theseus

A Keras implementation of BERT-of-Theseus. Also available as [task_iflytek_bert_of_theseus.py](https://github.com/bojone/bert4keras/blob/master/examples/task_iflytek_bert_of_theseus.py) in the bert4keras examples.

## Requirements
```bash
pip install bert4keras==0.8.3
```

## Introduction
- Blog: https://kexue.fm/archives/7575

## Contact
- QQ Group: 67729435
- WeChat: spaces_ac_cn
--------------------------------------------------------------------------------
/bert_of_theseus.py:
--------------------------------------------------------------------------------
#! -*- coding:utf-8 -*-
# Model compression on a text classification task
# Method: BERT-of-Theseus
# Paper: https://arxiv.org/abs/2002.02925
# Blog: https://kexue.fm/archives/7575

import json
import numpy as np
from bert4keras.backend import keras, K
from bert4keras.tokenizers import Tokenizer
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam, extend_with_piecewise_linear_lr
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from keras.layers import Input, Lambda, Dense, Layer
from keras.models import Model

num_classes = 119
maxlen = 128
batch_size = 32

# BERT base
config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt'


def load_data(filename):
    """Load a JSON-lines file into a list of (text, label) pairs."""
    D = []
    with open(filename) as f:
        for i, l in enumerate(f):
            l = json.loads(l)
            text, label = l['sentence'], l['label']
            D.append((text, int(label)))
    return D


# Load the datasets
train_data = load_data(
    '/root/CLUE-master/baselines/CLUEdataset/iflytek/train.json'
)
valid_data = load_data(
    '/root/CLUE-master/baselines/CLUEdataset/iflytek/dev.json'
)

# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)


class data_generator(DataGenerator):
    """Data generator."""
    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids, batch_labels = [], [], []
        for is_end, (text, label) in self.sample(random):
            token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_labels.append([label])
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_labels = sequence_padding(batch_labels)
                yield [batch_token_ids, batch_segment_ids], batch_labels
                batch_token_ids, batch_segment_ids, batch_labels = [], [], []


# Wrap the datasets in generators
train_generator = data_generator(train_data, batch_size)
valid_generator = data_generator(valid_data, batch_size)


class BinaryRandomChoice(Layer):
    """Randomly pick one of two inputs (a Bernoulli gate).

    During training, the output is the `source` branch with probability 0.5
    and the `target` branch otherwise; at inference time it is always the
    `target` branch.
    """
    def __init__(self, **kwargs):
        super(BinaryRandomChoice, self).__init__(**kwargs)
        self.supports_masking = True

    def compute_mask(self, inputs, mask=None):
        if mask is not None:
            return mask[1]

    def call(self, inputs):
        source, target = inputs
        mask = K.random_binomial(shape=[1], p=0.5)
        output = mask * source + (1 - mask) * target
        return K.in_train_phase(output, target)

    def compute_output_shape(self, input_shape):
        return input_shape[1]
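
# Illustrative sketch (not part of the original script): a tiny check of the
# Bernoulli gate above. The underscore-prefixed helper is hypothetical and is
# never called by the training pipeline.
def _demo_binary_random_choice():
    """Pass two constant tensors through BinaryRandomChoice.

    In the Keras training phase the result equals `source` or `target` with
    equal probability (one fresh draw per call); outside the training phase
    it is always `target`.
    """
    source = K.constant([[1.0, 1.0]])
    target = K.constant([[0.0, 0.0]])
    return BinaryRandomChoice()([source, target])
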

def bert_of_theseus(predecessor, successor, classifier):
    """Assemble the BERT-of-Theseus model from a predecessor, a successor
    and a shared classifier head.
    """
    inputs = predecessor.inputs
    # Freeze the layers that have already been trained
    for layer in predecessor.model.layers:
        layer.trainable = False
    classifier.trainable = False
    # Embedding-layer replacement
    predecessor_outputs = predecessor.apply_embeddings(inputs)
    successor_outputs = successor.apply_embeddings(inputs)
    outputs = BinaryRandomChoice()([predecessor_outputs, successor_outputs])
    # Transformer-layer replacement: each successor layer stands in for a
    # module of `layers_per_module` consecutive predecessor layers
    layers_per_module = predecessor.num_hidden_layers // successor.num_hidden_layers
    for index in range(successor.num_hidden_layers):
        predecessor_outputs = outputs
        for sub_index in range(layers_per_module):
            predecessor_outputs = predecessor.apply_main_layers(
                predecessor_outputs, layers_per_module * index + sub_index
            )
        successor_outputs = successor.apply_main_layers(outputs, index)
        outputs = BinaryRandomChoice()([predecessor_outputs, successor_outputs])
    # Attach the classifier head and return the full model
    outputs = classifier(outputs)
    model = Model(inputs, outputs)
    return model


def evaluate(data, model):
    total, right = 0., 0.
    for x_true, y_true in data:
        y_pred = model.predict(x_true).argmax(axis=1)
        y_true = y_true[:, 0]
        total += len(y_true)
        right += (y_true == y_pred).sum()
    return right / total


class Evaluator(keras.callbacks.Callback):
    def __init__(self, savename):
        super(Evaluator, self).__init__()
        self.best_val_acc = 0.
        self.savename = savename

    def on_epoch_end(self, epoch, logs=None):
        val_acc = evaluate(valid_generator, self.model)
        if val_acc > self.best_val_acc:
            self.best_val_acc = val_acc
            self.model.save_weights(self.savename)
        print(
            u'val_acc: %.5f, best_val_acc: %.5f\n' %
            (val_acc, self.best_val_acc)
        )


# Load the pre-trained model (12 layers) as the predecessor
predecessor = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    return_keras_model=False,
    prefix='Predecessor-'
)

# Load the pre-trained model (3 layers) as the successor
successor = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    return_keras_model=False,
    num_hidden_layers=3,
    prefix='Successor-'
)

# Classifier head (shared by the predecessor, successor and theseus models)
x_in = Input(shape=K.int_shape(predecessor.output)[1:])
x = Lambda(lambda x: x[:, 0])(x_in)
x = Dense(units=num_classes, activation='softmax')(x)
classifier = Model(x_in, x)

predecessor_model = Model(predecessor.inputs, classifier(predecessor.output))
predecessor_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=Adam(2e-5),  # use a sufficiently small learning rate
    metrics=['sparse_categorical_accuracy'],
)
predecessor_model.summary()

successor_model = Model(successor.inputs, classifier(successor.output))
successor_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=Adam(2e-5),  # use a sufficiently small learning rate
    metrics=['sparse_categorical_accuracy'],
)
successor_model.summary()

theseus_model = bert_of_theseus(predecessor, successor, classifier)
theseus_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=Adam(2e-5),  # use a sufficiently small learning rate
    metrics=['sparse_categorical_accuracy'],
)
theseus_model.summary()
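
# Illustrative sketch (not part of the original script): print how
# bert_of_theseus pairs the two networks. With the 12-layer predecessor and
# 3-layer successor built above, layers_per_module = 4, so successor layer i
# substitutes for predecessor layers 4*i .. 4*i+3. The helper is hypothetical.
def _print_module_mapping(predecessor, successor):
    layers_per_module = predecessor.num_hidden_layers // successor.num_hidden_layers
    for index in range(successor.num_hidden_layers):
        covered = list(range(layers_per_module * index,
                             layers_per_module * (index + 1)))
        print('successor layer %d <- predecessor layers %s' % (index, covered))
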

if __name__ == '__main__':

    # Train the predecessor
    predecessor_evaluator = Evaluator('best_predecessor.weights')
    predecessor_model.fit_generator(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=5,
        callbacks=[predecessor_evaluator]
    )

    # Train the theseus model (module replacement phase)
    theseus_evaluator = Evaluator('best_theseus.weights')
    theseus_model.fit_generator(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=10,
        callbacks=[theseus_evaluator]
    )
    theseus_model.load_weights('best_theseus.weights')

    # Fine-tune the successor
    successor_evaluator = Evaluator('best_successor.weights')
    successor_model.fit_generator(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=5,
        callbacks=[successor_evaluator]
    )
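
# Illustrative follow-up (an assumption, not in the original script): after
# the three training stages, the compressed 3-layer successor can be used on
# its own. The underscore-prefixed helper name is hypothetical.
def _load_and_evaluate_successor(weights_path='best_successor.weights'):
    """Load the best successor checkpoint and report validation accuracy."""
    successor_model.load_weights(weights_path)
    return evaluate(valid_generator, successor_model)
--------------------------------------------------------------------------------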