├── README.md ├── example.py └── convert.py /README.md: -------------------------------------------------------------------------------- 1 | # CDial-GPT-tf 2 | 用bert4keras加载CDial-GPT。原版链接:https://github.com/thu-coai/CDial-GPT 3 | 4 | ## 下载 5 | 6 | **GPT_LCCC-base-tf.zip**,链接: https://pan.baidu.com/s/1-k9_PWg9GuYnywxZabDCyQ 提取码: vs5b 7 | 8 | **GPT_LCCC-large-tf.zip**,链接: https://pan.baidu.com/s/1Akw4NxjPJC-LiXISEfK-nw 提取码: ydzr 9 | 10 | 注:这里的base和large描述的是语料的大小,并非模型的大小,两个模型都是BERT Base级别大小的。 11 | 12 | ## 使用 13 | 14 | 参考[example.py](https://github.com/bojone/CDial-GPT-tf/blob/master/example.py)。关于输入格式,请参考原版项目介绍。 15 | 16 | bert4keras >= 0.9.3 17 | 18 | ## 交流 19 | 20 | QQ交流群:67729435,微信群请加机器人微信号spaces_ac_cn 21 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | # bert4keras加载CDial-GPT 3 | # https://github.com/bojone/CDial-GPT-tf 4 | 5 | import numpy as np 6 | from bert4keras.models import build_transformer_model 7 | from bert4keras.tokenizers import Tokenizer 8 | from bert4keras.snippets import AutoRegressiveDecoder 9 | from bert4keras.snippets import uniout 10 | 11 | config_path = '/root/kg/bert/GPT_LCCC-base-tf/gpt_config.json' 12 | checkpoint_path = '/root/kg/bert/GPT_LCCC-base-tf/gpt_model.ckpt' 13 | dict_path = '/root/kg/bert/GPT_LCCC-base-tf/vocab.txt' 14 | 15 | tokenizer = Tokenizer(dict_path, do_lower_case=True) # 建立分词器 16 | speakers = [ 17 | tokenizer.token_to_id('[speaker1]'), 18 | tokenizer.token_to_id('[speaker2]') 19 | ] 20 | 21 | model = build_transformer_model( 22 | config_path=config_path, 23 | checkpoint_path=checkpoint_path, 24 | model='gpt' 25 | ) # 建立模型,加载权重 26 | 27 | 28 | class ChatBot(AutoRegressiveDecoder): 29 | """基于随机采样对话机器人 30 | """ 31 | @AutoRegressiveDecoder.wraps(default_rtype='probas') 32 | def predict(self, inputs, output_ids, states): 33 | token_ids, segment_ids = inputs 34 | 
curr_segment_ids = np.zeros_like(output_ids) + token_ids[0, -1] 35 | token_ids = np.concatenate([token_ids, output_ids], 1) 36 | segment_ids = np.concatenate([segment_ids, curr_segment_ids], 1) 37 | return model.predict([token_ids, segment_ids])[:, -1] 38 | 39 | def response(self, texts, topk=5): 40 | token_ids = [tokenizer._token_start_id, speakers[0]] 41 | segment_ids = [tokenizer._token_start_id, speakers[0]] 42 | for i, text in enumerate(texts): 43 | ids = tokenizer.encode(text)[0][1:-1] + [speakers[(i + 1) % 2]] 44 | token_ids.extend(ids) 45 | segment_ids.extend([speakers[i % 2]] * len(ids)) 46 | segment_ids[-1] = speakers[(i + 1) % 2] 47 | results = self.random_sample([token_ids, segment_ids], 1, topk) 48 | return tokenizer.decode(results[0]) 49 | 50 | 51 | chatbot = ChatBot(start_id=None, end_id=tokenizer._token_end_id, maxlen=32) 52 | print(chatbot.response([u'别爱我没结果', u'你这样会失去我的', u'失去了又能怎样'])) 53 | """ 54 | 回复是随机的,例如:你还有我 | 那就不要爱我 | 你是不是傻 | 等等。 55 | """ 56 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | #! 
# -*- coding: utf-8 -*-
# Convert CDial-GPT (PyTorch) weights into a TensorFlow checkpoint so they can
# be loaded with bert4keras (``build_transformer_model(..., model='gpt')``).

import argparse
import os
import re

import numpy as np

DEFAULT_IN_FILE = 'GPT_LCCC-base/pytorch_model.bin'
DEFAULT_OUT_FILE = 'GPT_LCCC-base-tf/gpt_model.ckpt'

# Per-layer parameters copied one-to-one from ``transformer.h.<i>.<torch>``
# to ``gpt/transformer/layer_<i>/<tf>`` (no reshaping or splitting needed).
_LAYER_PARAMS = (
    ('attn.c_proj.weight', 'attention/output/dense/kernel'),
    ('attn.c_proj.bias', 'attention/output/dense/bias'),
    ('ln_1.weight', 'attention/output/LayerNorm/gamma'),
    ('ln_1.bias', 'attention/output/LayerNorm/beta'),
    ('mlp.c_fc.weight', 'intermediate/dense/kernel'),
    ('mlp.c_fc.bias', 'intermediate/dense/bias'),
    ('mlp.c_proj.weight', 'output/dense/kernel'),
    ('mlp.c_proj.bias', 'output/dense/bias'),
    ('ln_2.weight', 'output/LayerNorm/gamma'),
    ('ln_2.bias', 'output/LayerNorm/beta'),
)

# Order in which the fused ``attn.c_attn`` parameters are split.
_QKV = ('query', 'key', 'value')

# Matches per-layer parameter names and captures the layer index.
_LAYER_KEY = re.compile(r'^transformer\.h\.(\d+)\.')


def _to_numpy(tensor):
    """Return *tensor* as a NumPy array; accepts torch tensors or array-likes."""
    if hasattr(tensor, 'detach'):  # torch.Tensor
        return tensor.detach().cpu().numpy()
    return np.asarray(tensor)


def infer_num_hidden_layers(torch_weights):
    """Return the number of ``transformer.h.<i>`` blocks in *torch_weights*.

    The count is the highest layer index found plus one, or 0 when no
    per-layer parameter is present.
    """
    indices = {int(m.group(1)) for m in map(_LAYER_KEY.match, torch_weights) if m}
    return max(indices) + 1 if indices else 0


def convert_weights(torch_weights, num_hidden_layers=None):
    """Map a CDial-GPT PyTorch state dict onto bert4keras GPT variable names.

    Args:
        torch_weights: mapping from PyTorch parameter name to a tensor or
            array (e.g. the result of ``torch.load``).
        num_hidden_layers: number of transformer blocks to convert; inferred
            from the parameter names when ``None``.

    Returns:
        dict mapping TF variable name to ``np.ndarray``.

    Raises:
        KeyError: a required parameter is missing from *torch_weights*.
        ValueError: the layer count was inferred and no layers were found.
    """
    if num_hidden_layers is None:
        num_hidden_layers = infer_num_hidden_layers(torch_weights)
        if num_hidden_layers == 0:
            raise ValueError('no transformer.h.<i> parameters found')

    def get(name):
        return _to_numpy(torch_weights[name])

    tf_weights = {}

    # CDial-GPT uses id 0 for [CLS] and 1 for [PAD], the reverse of the usual
    # convention, so the first two embedding rows are swapped.
    w = get('transformer.tokens_embed.weight')
    tf_weights['gpt/embeddings/word_embeddings'] = np.concatenate(
        [w[1:2], w[:1], w[2:]], axis=0)
    tf_weights['gpt/embeddings/position_embeddings'] = get(
        'transformer.positions_embed.weight')

    for i in range(num_hidden_layers):
        src = 'transformer.h.%s.' % i
        dst = 'gpt/transformer/layer_%s/' % i
        # c_attn fuses query/key/value along the last axis; split it in three.
        kernels = np.split(get(src + 'attn.c_attn.weight'), 3, axis=1)
        biases = np.split(get(src + 'attn.c_attn.bias'), 3, axis=0)
        for part, kernel in zip(_QKV, kernels):
            tf_weights[dst + 'attention/self/%s/kernel' % part] = kernel
        for part, bias in zip(_QKV, biases):
            tf_weights[dst + 'attention/self/%s/bias' % part] = bias
        for torch_suffix, tf_suffix in _LAYER_PARAMS:
            tf_weights[dst + tf_suffix] = get(src + torch_suffix)

    return tf_weights


def save_tf_checkpoint(tf_weights, out_file):
    """Write *tf_weights* (name -> array) to a TF1 checkpoint at *out_file*.

    The parent directory of *out_file* is created if it does not exist.
    """
    # Imported lazily so the pure-NumPy conversion above is usable without
    # TensorFlow/Keras installed. This uses the TF1 graph/session API.
    import tensorflow as tf
    import keras.backend as K

    out_dir = os.path.dirname(out_file)
    if out_dir:
        # tf.train.Saver.save fails when the parent directory is missing.
        os.makedirs(out_dir, exist_ok=True)

    with tf.Graph().as_default():
        pairs = []
        for name, value in tf_weights.items():
            # Zero-initialised variables are filled in by batch_set_value
            # below (presumably to avoid embedding the large arrays in the
            # graph as constants).
            var = K.variable(tf.zeros(value.shape), name=name)
            pairs.append((var, value))
        with tf.Session() as sess:
            # Assumes Keras picks up ``sess`` as the default session here.
            K.batch_set_value(pairs)
            saver = tf.train.Saver()
            saver.save(sess, out_file)


def main(argv=None):
    """Command-line entry point: load, convert and save the checkpoint."""
    parser = argparse.ArgumentParser(
        description='Convert CDial-GPT PyTorch weights to a TF checkpoint.')
    parser.add_argument('--in-file', default=DEFAULT_IN_FILE,
                        help='PyTorch state dict (default: %(default)s)')
    parser.add_argument('--out-file', default=DEFAULT_OUT_FILE,
                        help='output checkpoint prefix (default: %(default)s)')
    parser.add_argument('--num-hidden-layers', type=int, default=None,
                        help='layers to convert (default: inferred)')
    args = parser.parse_args(argv)

    import torch

    # NOTE: torch.load unpickles the file; only convert trusted checkpoints.
    torch_weights = torch.load(args.in_file, map_location='cpu')
    tf_weights = convert_weights(torch_weights, args.num_hidden_layers)
    save_tf_checkpoint(tf_weights, args.out_file)


if __name__ == '__main__':
    main()