# -*- coding: utf-8 -*-
"""pytorch_bert_to_tf: convert PyTorch BERT weights to a TensorFlow checkpoint.

Loads a PyTorch ``pytorch_model.bin`` state dict, renames every entry to the
TensorFlow variable name produced by ``to_tf_var_name``, transposes the
dense/attention weight matrices, and writes a TF1-style checkpoint.

Reference:
https://github.com/huggingface/transformers/blob/master/src/transformers/convert_bert_pytorch_checkpoint_to_original_tf.py

Usage::

    python convert.py [--in_file PATH] [--out_file PATH]
"""

import argparse

import torch

# The original author's paths, kept as CLI defaults so that running the
# script with no arguments behaves as before.
DEFAULT_IN_FILE = '/root/kg/bert/thu/ms_bert/ms/pytorch_model.bin'
DEFAULT_OUT_FILE = '/root/kg/bert/thu/ms_bert/bert_model.ckpt'

# Substrings of PyTorch parameter names whose arrays are transposed before
# saving. torch.nn.Linear stores weights as (out_features, in_features); the
# TF kernels are presumably (in, out), as in the referenced HF converter.
# Matching bias vectors are 1-D, for which ``.T`` is a no-op.
tensors_to_transpose = (
    "dense.weight", "attention.self.query", "attention.self.key",
    "attention.self.value"
)

# Ordered (pattern, replacement) substitutions that turn a PyTorch parameter
# name into a TF variable name. Order matters: '.' is turned into '/' before
# the LayerNorm / kernel / cls renames, which are written with '/'.
var_map = (
    ('layer.', 'layer_'),
    ('word_embeddings.weight', 'word_embeddings'),
    ('position_embeddings.weight', 'position_embeddings'),
    ('token_type_embeddings.weight', 'token_type_embeddings'),
    ('.', '/'),
    ('LayerNorm/weight', 'LayerNorm/gamma'),
    ('LayerNorm/bias', 'LayerNorm/beta'),
    ('weight', 'kernel'),
    ('cls/predictions/bias', 'cls/predictions/output_bias'),
    ('cls/seq_relationship/kernel', 'cls/seq_relationship/output_weights'),
    ('cls/seq_relationship/bias', 'cls/seq_relationship/output_bias'),
)


def to_tf_var_name(name):
    """Return the TF checkpoint variable name for PyTorch parameter *name*.

    Applies every ``var_map`` substitution in order, e.g.
    ``bert.encoder.layer.0.attention.self.query.weight`` becomes
    ``bert/encoder/layer_0/attention/self/query/kernel``.
    """
    for patt, repl in var_map:
        name = name.replace(patt, repl)
    return name


def needs_transpose(var_name):
    """Return True if PyTorch parameter *var_name* must be transposed for TF."""
    return any(patt in var_name for patt in tensors_to_transpose)


def convert_state_dict(torch_weights):
    """Map a PyTorch state dict to an ordered ``{tf_var_name: ndarray}`` dict.

    Args:
        torch_weights: mapping of PyTorch parameter name -> tensor, e.g. the
            result of ``torch.load`` on a ``pytorch_model.bin``.

    Returns:
        dict keyed by TF variable name, in the iteration order of
        *torch_weights*. Arrays are transposed where ``needs_transpose`` says so.

    Raises:
        ValueError: if two PyTorch names map to the same TF name. Without this
            check, ``tf.Variable`` would silently rename the second one.
    """
    converted = {}
    for var_name, tensor in torch_weights.items():
        tf_name = to_tf_var_name(var_name)
        if tf_name in converted:
            raise ValueError(
                'PyTorch parameter %r maps to TF name %r, which is already taken'
                % (var_name, tf_name))
        array = tensor.detach().cpu().numpy()
        if needs_transpose(var_name):
            array = array.T
        converted[tf_name] = array
    return converted


def save_tf_checkpoint(arrays, out_file):
    """Write *arrays* (``{tf_var_name: ndarray}``) as a TF checkpoint.

    Prints each variable name as its variable is created. *out_file* is passed
    to ``Saver.save`` as the save path; no meta graph is written.
    """
    # Imported here so the name-mapping helpers above work without TensorFlow.
    import tensorflow as tf

    # Session/Saver live under tf.compat.v1 in TF >= 1.13 and are only
    # reachable there in TF 2.x; older TF 1.x exposes them at the top level.
    tf1 = getattr(getattr(tf, 'compat', None), 'v1', tf)

    with tf.Graph().as_default():
        for tf_name, array in arrays.items():
            print(tf_name)
            tf1.Variable(array, name=tf_name)
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            saver = tf1.train.Saver()
            saver.save(sess, out_file, write_meta_graph=False)


def convert(in_file=DEFAULT_IN_FILE, out_file=DEFAULT_OUT_FILE):
    """Convert the PyTorch BERT weights in *in_file* to a TF checkpoint at *out_file*."""
    # NOTE(security): torch.load unpickles the file; only convert trusted checkpoints.
    torch_weights = torch.load(in_file, map_location='cpu')
    save_tf_checkpoint(convert_state_dict(torch_weights), out_file)


def main(argv=None):
    """Command-line entry point; *argv* defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(
        description='Convert PyTorch BERT weights to a TensorFlow checkpoint.')
    parser.add_argument('--in_file', default=DEFAULT_IN_FILE,
                        help='PyTorch state dict file (default: %(default)s)')
    parser.add_argument('--out_file', default=DEFAULT_OUT_FILE,
                        help='output checkpoint path (default: %(default)s)')
    args = parser.parse_args(argv)
    convert(args.in_file, args.out_file)


if __name__ == '__main__':
    main()