├── ner
│   ├── .gitignore
│   ├── README.md
│   ├── extract_txt.py
│   ├── test_crf.py
│   └── train_crf_loss.py
├── en2zh
│   ├── .gitignore
│   ├── README.md
│   ├── extract_tmx.py
│   ├── test.py
│   └── train.py
├── chatbot
│   ├── .gitignore
│   ├── params.json
│   ├── README.md
│   ├── test.py
│   ├── test_anti.py
│   ├── test_compare.py
│   ├── extract_conv.py
│   ├── train.py
│   └── train_anti.py
├── .gitignore
├── chatbot_cut
│   ├── .gitignore
│   ├── README.md
│   ├── read_vector.py
│   ├── test.py
│   ├── test_anti.py
│   ├── test_compare.py
│   ├── extract_conv.py
│   ├── train.py
│   └── train_anti.py
├── xpy.sh
├── fake_data.py
├── threadedgenerator.py
├── test_atten.py
├── word_sequence.py
├── test_crf.py
├── README.md
├── test.py
├── data_utils.py
├── pylintrc
├── rnn_crf.py
└── sequence_to_sequence.py
/ner/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.txt 3 | *.pkl 4 | s2ss*.ckpt.* 5 | checkpoint 6 | -------------------------------------------------------------------------------- /en2zh/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.tmx 3 | *.gz 4 | *.pkl 5 | *.ckpt* 6 | checkpoint 7 | -------------------------------------------------------------------------------- /chatbot/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.zip 3 | *.pkl 4 | *.conv 5 | *.ckpt* 6 | checkpoint 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.py[cod] 3 | .ipynb_checkpoints 4 | __pycache__ 5 | bak 6 | *.pkl 7 | -------------------------------------------------------------------------------- /chatbot_cut/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.zip 3 | *.pkl 4 | *.conv 5 | *.ckpt* 6 | checkpoint 7 | *.vec 8 | -------------------------------------------------------------------------------- /xpy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #pylint **/*.py 4 | find .
-iname "*.py" -not -path "./front/*" | xargs pylint 5 | -------------------------------------------------------------------------------- /chatbot/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "bidirectional": true, 3 | "use_residual": false, 4 | "use_dropout": false, 5 | "time_major": false, 6 | "cell_type": "lstm", 7 | "depth": 2, 8 | "attention_type": "Bahdanau", 9 | "hidden_units": 128, 10 | "optimizer": "adam", 11 | "learning_rate": 0.001, 12 | "embedding_size": 300 13 | } 14 | -------------------------------------------------------------------------------- /ner/README.md: -------------------------------------------------------------------------------- 1 | 2 | # NER 测试 3 | 4 | 中文 命名实体识别(Named Entity Recognizer) 测试 5 | 6 | ## 1、下载数据 7 | 8 | 下载页面 9 | 10 | https://github.com/lancopku/Chinese-Literature-NER-RE-Dataset 11 | 12 | 下载链接: 13 | 14 | wget "https://raw.githubusercontent.com/lancopku/Chinese-Literature-NER-RE-Dataset/master/ner/train.txt" 15 | 16 | wget "https://raw.githubusercontent.com/lancopku/Chinese-Literature-NER-RE-Dataset/master/ner/test.txt" 17 | 18 | wget "https://raw.githubusercontent.com/lancopku/Chinese-Literature-NER-RE-Dataset/master/ner/validation.txt" 19 | 20 | ## 2、预处理数据 21 | 22 | 运行 `extract_txt.py` 得到 `ner.pkl` 23 | 24 | ## 3、训练数据 25 | 26 | ### crf 模型 27 | 28 | 运行 `train_crf_loss.py` 训练(默认到`./s2ss_ner_crf.ckpt`) 29 | 30 | ## 4、测试数据(测试翻译) 31 | 32 | #### seq2seq 模型 33 | 34 | 运行 `test.py` 查看测试结果 35 | 36 | ### 或者 crf 模型 37 | 38 | 运行 `test_crf.py` 查看测试结果 39 | -------------------------------------------------------------------------------- /chatbot/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Chatbot测试 3 | 4 | ## 1、下载数据 5 | 6 | Subtitle data from [here](https://github.com/fateleak/dgk_lost_conv) 7 | 8 | ``` 9 | wget https://lvzhe.oss-cn-beijing.aliyuncs.com/dgk_shooter_min.conv.zip 10 | ``` 11 | 12 | 输出:`dgk_shooter_min.conv.zip` 13 | 14 | ## 2、解压缩 15 | 16 | ``` 17 | unzip dgk_shooter_min.conv.zip 18 | ``` 19 | 20 | 输出:`dgk_shooter_min.conv` 21 | 22 | ## 3、预处理数据 23 | 24 | ``` 25 | python3 extract_conv.py 26 | ``` 27 | 28 | 输出:`chatbot.pkl` 29 | 30 | ## 4、训练数据 31 | 32 | 运行 `python3 train.py` 训练(默认到`./s2ss_chatbot.ckpt`) 33 | 34 | 或者! 35 | 36 | 运行 `python3 train_anti.py` 训练抗语言模型(默认到`./s2ss_chatbot_anti.ckpt`) 37 | 38 | ## 5、测试数据(测试对话) 39 | 40 | 运行 `python3 test.py` 查看测试结果,需要提前训练普通模型 41 | 42 | 或者! 43 | 44 | 运行 `python3 test_anti.py` 查看抗语言模型的测试结果,需要提前训练抗语言模型 45 | 46 | 或者! 
47 | 48 | 运行 `python3 test_compare.py` 查看普通模型和抗语言模型的对比测试结果, 49 | 需要提前训练两个模型 50 | -------------------------------------------------------------------------------- /chatbot_cut/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Chatbot测试 3 | 4 | ## 1、下载数据 5 | 6 | Subtitle data from [here](https://github.com/fateleak/dgk_lost_conv) 7 | 8 | ``` 9 | wget https://github.com/fateleak/dgk_lost_conv/raw/master/dgk_shooter_min.conv.zip 10 | ``` 11 | 12 | 输出:`dgk_shooter_min.conv.zip` 13 | 14 | ## 2、解压缩 15 | 16 | ``` 17 | unzip dgk_shooter_min.conv.zip 18 | ``` 19 | 20 | 输出:`dgk_shooter_min.conv` 21 | 22 | 23 | ## 3、下载训练好的fasttext的embedding 24 | 25 | https://github.com/facebookresearch/fastText/blob/master/pretrained-vectors.md 26 | 27 | 注意是文本格式的 28 | 29 | ``` 30 | wget https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.zh.vec 31 | ``` 32 | 33 | 得到 `wiki.zh.vec` 文件 34 | 35 | ## 4、改变embedding格式 36 | 37 | 运行 38 | 39 | ``` 40 | python3 read_vector.py 41 | ``` 42 | 43 | 得到 `word_vec.pkl`文件在目录下 44 | 45 | ## 5、预处理数据(前面两步embedding部分必须执行完) 46 | 47 | ``` 48 | python3 extract_conv.py 49 | ``` 50 | 51 | 输出:`chatbot.pkl` 52 | 53 | ## 6、训练数据 54 | 55 | 运行 `python3 train.py` 训练(默认到`./s2ss_chatbot.ckpt`) 56 | 57 | 或者! 58 | 59 | 运行 `python3 train_anti.py` 训练抗语言模型(默认到`./s2ss_chatbot_anti.ckpt`) 60 | 61 | ## 7、测试数据(测试对话) 62 | 63 | 运行 `python3 test.py` 查看测试结果,需要提前训练普通模型 64 | 65 | 或者! 66 | 67 | 运行 `python3 test_anti.py` 查看抗语言模型的测试结果,需要提前训练抗语言模型 68 | 69 | 或者! 70 | 71 | 运行 `python3 test_compare.py` 查看普通模型和抗语言模型的对比测试结果, 72 | 需要提前训练两个模型 73 | -------------------------------------------------------------------------------- /chatbot_cut/read_vector.py: -------------------------------------------------------------------------------- 1 | """ 2 | 读取一个文本格式的,保存预训练好的embedding的文件 3 | 4 | wiki.zh.vec 5 | 6 | 它的第一行会被忽略 7 | 第二行开始,每行是 词 + 空格 + 词向量维度0 + 空格 + 词向量维度1 + ... 8 | 9 | 参考fasttext的文本格式 10 | 11 | https://github.com/facebookresearch/fastText/blob/master/pretrained-vectors.md 12 | """ 13 | 14 | import pickle 15 | import numpy as np 16 | from tqdm import tqdm 17 | 18 | 19 | def read_vector(path='wiki.zh.vec', output_path='word_vec.pkl'): 20 | """ 21 | 读取文本文件 path 中的数据,并且生成一个 dict 写入到 output_path 22 | 23 | 格式: 24 | word_vec = { 25 | 'word_1': np.array(vec_of_word_1), 26 | 'word_2': np.array(vec_of_word_2), 27 | ... 
28 | } 29 | """ 30 | fp = open(path, 'r') 31 | word_vec = {} 32 | first_skip = False 33 | dim = None 34 | for line in tqdm(fp): 35 | if not first_skip: 36 | first_skip = True 37 | else: 38 | line = line.strip() 39 | line = line.split(' ') 40 | if len(line) >= 2: 41 | word = line[0] 42 | vec_text = line[1:] 43 | vec = np.array([float(v) for v in vec_text]) 44 | word_vec[word] = vec 45 | if dim is None: 46 | dim = vec.shape 47 | 48 | # PAD_TAG = '' 49 | # UNK_TAG = '' 50 | # START_TAG = '' 51 | # END_TAG = '' 52 | 53 | np.random.seed(0) 54 | word_vec[''] = np.random.random(size=(300,)) - 0.5 55 | word_vec[''] = np.random.random(size=(300,)) - 0.5 56 | word_vec[''] = np.random.random(size=(300,)) - 0.5 57 | 58 | pickle.dump(word_vec, open(output_path, 'wb')) 59 | 60 | if __name__ == '__main__': 61 | read_vector() 62 | -------------------------------------------------------------------------------- /en2zh/README.md: -------------------------------------------------------------------------------- 1 | 2 | # 英汉翻译测试 3 | 4 | ## 1、下载数据 5 | 6 | 下载页面 7 | 8 | English and Chinese parallel corpus from [here](http://opus.nlpl.eu/OpenSubtitles2018.php) 9 | 10 | 下载链接: 11 | 12 | ``` 13 | wget -O "en-zh_cn.tmx.gz" "http://opus.nlpl.eu/download.php?f=OpenSubtitles2018/en-zh_cn.tmx.gz" 14 | ``` 15 | 16 | 输出:`en-zh_cn.tmx.gz` 17 | 18 | ## 2、解压数据 19 | 20 | 这个数据是`英文-中文`的平行语料 21 | 22 | 解压缩: 23 | 24 | ``` 25 | gunzip -k en-zh_cn.tmx.gz 26 | ``` 27 | 28 | 下载并解压数据,然后重命名为 `en-zh_cn.tmx` (如果有必要) 29 | 30 | 这应该是一个xml格式(在`linux`下可以用`head`命令查看下是否正确) 31 | 32 | ## 3、预处理数据 33 | 34 | 运行 `python3 extract_tmx.py` 得到 `en-zh_cn.pkl` 35 | 36 | ## 4、训练数据 37 | 38 | 运行 `python3 train.py` 训练(默认到`./s2ss_en2zh.ckpt`) 39 | 40 | ## 5、测试数据(测试翻译) 41 | 42 | 运行 `python3 test.py` 查看测试结果 43 | 44 | ## 测试结果样例 45 | 46 | ***我不保证重复实现能得到一模一样的结果*** 47 | 48 | ``` 49 | Input English Sentence:go to hell 50 | [[30475 71929 33464]] [3] 51 | [[41337 48900 41337 44789 3]] 52 | ['go', 'to', 'hell'] 53 | ['去', '地狱', '去', '吧', ''] 54 | Input English Sentence:nothing, but the best for you 55 | [[50448 467 13008 71007 10118 27982 79204]] [7] 56 | [[ 25904 132783 90185 4 28145 81577 80498 28798 3]] 57 | ['nothing', ',', 'but', 'the', 'best', 'for', 'you'] 58 | ['什么', '都', '没有', ' ', '但', '最好', '是', '你', ''] 59 | Input English Sentence:i'm a bad boy 60 | [[35437 268 4018 8498 11775]] [5] 61 | [[ 69313 80498 21899 49069 100342 3 -1]] 62 | ['i', "'m", 'a', 'bad', 'boy'] 63 | ['我', '是', '个', '坏', '男孩', '', ''] 64 | Input English Sentence:i'm really a bad boy 65 | [[35437 268 58417 4018 8498 11775]] [6] 66 | [[ 69313 103249 80498 17043 49069 100342 3 3 3 3 67 | 3 3]] 68 | ['i', "'m", 'really', 'a', 'bad', 'boy'] 69 | ['我', '真的', '是', '一个', '坏', '男孩', '', '', '', '', '', ''] 70 | ``` 71 | -------------------------------------------------------------------------------- /ner/extract_txt.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | 读取NER标注的txt数据(train/validation/test),处理后保存到ner.pkl 4 | """ 5 | 6 | # import re 7 | import sys 8 | import pickle 9 | 10 | sys.path.append('..') 11 | 12 | def read_txt(path): 13 | """读取一个txt文件的NER标注数据""" 14 | x_data, y_data = [], [] 15 | x, y = [], [] 16 | for line in open(path, 'r'): 17 | line = line.strip() 18 | line = line.split(' ') 19 | if len(line) == 2: 20 | x.append(line[0]) 21 | y.append(line[1]) 22 | else: 23 | if x and y: 24 | x_data.append(x) 25 | y_data.append(y) 26 | x, y = [], [] 27 | return x_data, y_data 28 | 29 | def main(limit=100): 30 | """执行程序 31 | Args: 32 | limit: 只输出句子长度小于limit的句子 33 |
""" 34 | from word_sequence import WordSequence 35 | 36 | x_data, y_data = [], [] 37 | 38 | x, y = read_txt('train.txt') 39 | x_data += x 40 | y_data += y 41 | 42 | x, y = read_txt('validation.txt') 43 | x_data += x 44 | y_data += y 45 | 46 | x, y = read_txt('test.txt') 47 | x_data += x 48 | y_data += y 49 | 50 | print(len(x_data)) 51 | 52 | print(x_data[:10]) 53 | print(y_data[:10]) 54 | 55 | print('tokenize') 56 | 57 | data = list(zip(x_data, y_data)) 58 | data = [(x, y) for x, y in data if len(x) < limit and len(y) < limit] 59 | x_data, y_data = zip(*data) 60 | 61 | print(x_data[:10]) 62 | print(y_data[:10]) 63 | 64 | print(len(x_data), len(y_data)) 65 | 66 | print('fit word_sequence') 67 | 68 | ws_input = WordSequence() 69 | ws_target = WordSequence() 70 | ws_input.fit(x_data, min_count=1) 71 | ws_target.fit(y_data, min_count=1) 72 | 73 | print('dump') 74 | 75 | pickle.dump( 76 | (x_data, y_data, ws_input, ws_target), 77 | open('ner.pkl', 'wb') 78 | ) 79 | 80 | print('done') 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /fake_data.py: -------------------------------------------------------------------------------- 1 | 2 | """生成一些虚假数据 3 | """ 4 | 5 | import random 6 | import numpy as np 7 | from word_sequence import WordSequence 8 | 9 | def generate(max_len=10, size=1000, same_len=False, seed=0): 10 | """生成虚假数据 11 | """ 12 | 13 | dictionary = { 14 | 'a': '1', 15 | 'b': '2', 16 | 'c': '3', 17 | 'd': '4', 18 | 'aa': '1', 19 | 'bb': '2', 20 | 'cc': '3', 21 | 'dd': '4', 22 | 'aaa': '1', 23 | } 24 | 25 | if seed is not None: 26 | random.seed(seed) 27 | 28 | input_list = sorted(list(dictionary.keys())) 29 | 30 | x_data = [] 31 | y_data = [] 32 | 33 | for _ in range(size): 34 | a_len = int(random.random() * max_len) + 1 35 | x = [] 36 | y = [] 37 | for _ in range(a_len): 38 | word = input_list[int(random.random() * len(input_list))] 39 | x.append(word) 40 | y.append(dictionary[word]) 41 | if not same_len: 42 | if y[-1] == '2': 43 | y.append('2') 44 | elif y[-1] == '3': 45 | y.append('3') 46 | y.append('4') 47 | x_data.append(x) 48 | y_data.append(y) 49 | 50 | ws_input = WordSequence() 51 | ws_input.fit(x_data)# + y_data) 52 | # ws_target = ws_input 53 | ws_target = WordSequence() 54 | ws_target.fit(y_data) 55 | return x_data, y_data, ws_input, ws_target 56 | 57 | def test(): 58 | """测试自身""" 59 | x_data, y_data, ws_input, ws_target = generate() 60 | print(len(x_data)) 61 | assert len(x_data) == 1000 62 | print(len(y_data)) 63 | assert len(y_data) == 1000 64 | print(np.max([len(x) for x in x_data])) 65 | assert np.max([len(x) for x in x_data]) == 10 66 | print(len(ws_input)) 67 | assert len(ws_input) == 14 68 | print(len(ws_target)) 69 | assert len(ws_target) == 9 70 | print('done') 71 | 72 | if __name__ == '__main__': 73 | test() 74 | -------------------------------------------------------------------------------- /chatbot/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对SequenceToSequence模型进行基本的参数组合测试 3 | """ 4 | 5 | import sys 6 | import random 7 | import pickle 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | # import jieba 12 | # from nltk.tokenize import word_tokenize 13 | 14 | sys.path.append('..') 15 | 16 | 17 | def test(params): 18 | """测试不同参数在生成的假数据上的运行结果""" 19 | 20 | from sequence_to_sequence import SequenceToSequence 21 | from data_utils import batch_flow 22 | from word_sequence import WordSequence # pylint: 
disable=unused-variable 23 | 24 | x_data, _ = pickle.load(open('chatbot.pkl', 'rb')) 25 | ws = pickle.load(open('ws.pkl', 'rb')) 26 | 27 | for x in x_data[:5]: 28 | print(' '.join(x)) 29 | 30 | config = tf.ConfigProto( 31 | device_count={'CPU': 1, 'GPU': 0}, 32 | allow_soft_placement=True, 33 | log_device_placement=False 34 | ) 35 | 36 | # save_path = '/tmp/s2ss_chatbot.ckpt' 37 | save_path = './s2ss_chatbot.ckpt' 38 | 39 | # 测试部分 40 | tf.reset_default_graph() 41 | model_pred = SequenceToSequence( 42 | input_vocab_size=len(ws), 43 | target_vocab_size=len(ws), 44 | batch_size=1, 45 | mode='decode', 46 | beam_width=0, 47 | **params 48 | ) 49 | init = tf.global_variables_initializer() 50 | 51 | with tf.Session(config=config) as sess: 52 | sess.run(init) 53 | model_pred.load(sess, save_path) 54 | 55 | while True: 56 | user_text = input('Input Chat Sentence:') 57 | if user_text in ('exit', 'quit'): 58 | exit(0) 59 | x_test = [list(user_text.lower())] 60 | # x_test = [word_tokenize(user_text)] 61 | bar = batch_flow([x_test], ws, 1) 62 | x, xl = next(bar) 63 | x = np.flip(x, axis=1) 64 | # x = np.array([ 65 | # list(reversed(xx)) 66 | # for xx in x 67 | # ]) 68 | print(x, xl) 69 | pred = model_pred.predict( 70 | sess, 71 | np.array(x), 72 | np.array(xl) 73 | ) 74 | print(pred) 75 | # prob = np.exp(prob.transpose()) 76 | print(ws.inverse_transform(x[0])) 77 | # print(ws.inverse_transform(pred[0])) 78 | # print(pred.shape, prob.shape) 79 | for p in pred: 80 | ans = ws.inverse_transform(p) 81 | print(ans) 82 | 83 | 84 | def main(): 85 | """入口程序""" 86 | import json 87 | test(json.load(open('params.json'))) 88 | 89 | 90 | if __name__ == '__main__': 91 | main() 92 | -------------------------------------------------------------------------------- /chatbot/test_anti.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对SequenceToSequence模型进行基本的参数组合测试 3 | """ 4 | 5 | import sys 6 | import random 7 | import pickle 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | # import jieba 12 | # from nltk.tokenize import word_tokenize 13 | 14 | sys.path.append('..') 15 | 16 | 17 | def test(params): 18 | """测试不同参数在生成的假数据上的运行结果""" 19 | 20 | from sequence_to_sequence import SequenceToSequence 21 | from data_utils import batch_flow 22 | from word_sequence import WordSequence # pylint: disable=unused-variable 23 | 24 | x_data, _ = pickle.load(open('chatbot.pkl', 'rb')) 25 | ws = pickle.load(open('ws.pkl', 'rb')) 26 | 27 | for x in x_data[:5]: 28 | print(' '.join(x)) 29 | 30 | config = tf.ConfigProto( 31 | device_count={'CPU': 1, 'GPU': 0}, 32 | allow_soft_placement=True, 33 | log_device_placement=False 34 | ) 35 | 36 | # save_path = '/tmp/s2ss_chatbot.ckpt' 37 | save_path = './s2ss_chatbot_anti.ckpt' 38 | 39 | # 测试部分 40 | tf.reset_default_graph() 41 | model_pred = SequenceToSequence( 42 | input_vocab_size=len(ws), 43 | target_vocab_size=len(ws), 44 | batch_size=1, 45 | mode='decode', 46 | beam_width=0, 47 | **params 48 | ) 49 | init = tf.global_variables_initializer() 50 | 51 | with tf.Session(config=config) as sess: 52 | sess.run(init) 53 | model_pred.load(sess, save_path) 54 | 55 | while True: 56 | user_text = input('Input Chat Sentence:') 57 | if user_text in ('exit', 'quit'): 58 | exit(0) 59 | x_test = [list(user_text.lower())] 60 | # x_test = [word_tokenize(user_text)] 61 | bar = batch_flow([x_test], ws, 1) 62 | x, xl = next(bar) 63 | x = np.flip(x, axis=1) 64 | # x = np.array([ 65 | # list(reversed(xx)) 66 | # for xx in x 67 | # ]) 68 | print(x, xl) 69 | pred = 
model_pred.predict( 70 | sess, 71 | np.array(x), 72 | np.array(xl) 73 | ) 74 | print(pred) 75 | # prob = np.exp(prob.transpose()) 76 | print(ws.inverse_transform(x[0])) 77 | # print(ws.inverse_transform(pred[0])) 78 | # print(pred.shape, prob.shape) 79 | for p in pred: 80 | ans = ws.inverse_transform(p) 81 | print(ans) 82 | 83 | 84 | def main(): 85 | """入口程序""" 86 | import json 87 | test(json.load(open('params.json'))) 88 | 89 | 90 | if __name__ == '__main__': 91 | main() 92 | -------------------------------------------------------------------------------- /en2zh/extract_tmx.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | 把tmx(xml)的数据解开,分词,然后保存到data.pkl 4 | """ 5 | 6 | # import re 7 | import sys 8 | import pickle 9 | import xml.etree.ElementTree as ET 10 | from tqdm import tqdm 11 | import nltk 12 | import jieba 13 | 14 | sys.path.append('..') 15 | 16 | def main(limit=15): 17 | """执行程序 18 | Args: 19 | limit: 只输出句子长度小于limit的句子 20 | """ 21 | from word_sequence import WordSequence 22 | 23 | x_data, y_data = [], [] 24 | tree = ET.parse('en-zh_cn.tmx') 25 | root = tree.getroot() 26 | body = root.find('body') 27 | for tu in tqdm(body.findall('tu')): 28 | en = '' 29 | zh = '' 30 | for tuv in tu.findall('tuv'): 31 | if list(tuv.attrib.values())[0] == 'en': 32 | en += tuv.find('seg').text 33 | elif list(tuv.attrib.values())[0] == 'zh_cn': 34 | zh += tuv.find('seg').text 35 | 36 | if en and zh: 37 | x_data.append(en) 38 | y_data.append(zh) 39 | 40 | print(len(x_data)) 41 | 42 | print(x_data[:10]) 43 | print(y_data[:10]) 44 | 45 | print('tokenize') 46 | 47 | def en_tokenize(text): 48 | # text = re.sub('[\((][^\))]+[\))]', '', text) 49 | return nltk.word_tokenize(text.lower()) 50 | 51 | x_data = [ 52 | en_tokenize(x) 53 | for x in tqdm(x_data) 54 | ] 55 | 56 | def zh_tokenize(text): 57 | # text = text.replace(',', ',') 58 | # text = text.replace('。', '.') 59 | # text = text.replace('?', '?') 60 | # text = text.replace('!', '!') 61 | # text = text.replace(':', ':') 62 | # text = re.sub(r'[^\u4e00-\u9fff,\.\?\!…《》]:', '', text) 63 | # text = text.strip() 64 | text = jieba.lcut(text.lower()) 65 | return text 66 | 67 | y_data = [ 68 | zh_tokenize(y) 69 | for y in tqdm(y_data) 70 | ] 71 | 72 | data = list(zip(x_data, y_data)) 73 | data = [(x, y) for x, y in data if len(x) < limit and len(y) < limit] 74 | x_data, y_data = zip(*data) 75 | 76 | print(x_data[:10]) 77 | print(y_data[:10]) 78 | 79 | print(len(x_data), len(y_data)) 80 | 81 | print('fit word_sequence') 82 | 83 | ws_input = WordSequence() 84 | ws_target = WordSequence() 85 | ws_input.fit(x_data) 86 | ws_target.fit(y_data) 87 | 88 | print('dump') 89 | 90 | pickle.dump( 91 | (x_data, y_data, ws_input, ws_target), 92 | open('en-zh_cn.pkl', 'wb') 93 | ) 94 | 95 | print('done') 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /ner/test_crf.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对RNNCRF模型进行基本的参数组合测试 3 | """ 4 | 5 | import sys 6 | import random 7 | import pickle 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | sys.path.append('..') 13 | 14 | 15 | def test(bidirectional, cell_type, depth, 16 | use_residual, use_dropout, time_major, 17 | hidden_units, output_project_active): 18 | """测试不同参数在生成的假数据上的运行结果""" 19 | 20 | from rnn_crf import RNNCRF 21 | from data_utils import batch_flow 22 | from word_sequence import WordSequence # pylint: 
disable=unused-variable 23 | 24 | x_data, _, ws_input, ws_target = pickle.load(open('ner.pkl', 'rb')) 25 | 26 | for x in x_data[:5]: 27 | print(' '.join(x)) 28 | 29 | config = tf.ConfigProto( 30 | device_count={'CPU': 1, 'GPU': 0}, 31 | allow_soft_placement=True, 32 | log_device_placement=False 33 | ) 34 | 35 | save_path = './s2ss_crf.ckpt' 36 | 37 | # 测试部分 38 | tf.reset_default_graph() 39 | model_pred = RNNCRF( 40 | input_vocab_size=len(ws_input), 41 | target_vocab_size=len(ws_target), 42 | max_decode_step=100, 43 | batch_size=1, 44 | mode='decode', 45 | bidirectional=bidirectional, 46 | cell_type=cell_type, 47 | depth=depth, 48 | use_residual=use_residual, 49 | use_dropout=use_dropout, 50 | parallel_iterations=1, 51 | time_major=time_major, 52 | hidden_units=hidden_units, 53 | output_project_active=output_project_active 54 | ) 55 | init = tf.global_variables_initializer() 56 | 57 | with tf.Session(config=config) as sess: 58 | sess.run(init) 59 | model_pred.load(sess, save_path) 60 | 61 | while True: 62 | user_text = input('Input Sentence:') 63 | if user_text in ('exit', 'quit'): 64 | exit(0) 65 | x_test = [list(user_text.lower())] 66 | bar = batch_flow([x_test, x_test], [ws_input, ws_target], 1) 67 | x, xl, _, _ = next(bar) 68 | # x = np.array([ 69 | # list(reversed(xx)) 70 | # for xx in x 71 | # ]) 72 | print(x, xl) 73 | pred = model_pred.predict( 74 | sess, 75 | np.array(x), 76 | np.array(xl) 77 | ) 78 | print(pred) 79 | print(ws_input.inverse_transform(x[0])) 80 | print(ws_target.inverse_transform(pred[0])) 81 | 82 | 83 | def main(): 84 | """入口程序,开始测试不同参数组合""" 85 | random.seed(0) 86 | np.random.seed(0) 87 | tf.set_random_seed(0) 88 | test(True, 'lstm', 1, False, True, False, 64, 'tanh') 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /en2zh/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对SequenceToSequence模型进行基本的参数组合测试 3 | """ 4 | 5 | import sys 6 | import random 7 | import pickle 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | import nltk 12 | 13 | sys.path.append('..') 14 | 15 | 16 | def test(bidirectional, cell_type, depth, 17 | attention_type, use_residual, use_dropout, time_major, hidden_units): 18 | """测试不同参数在生成的假数据上的运行结果""" 19 | 20 | from sequence_to_sequence import SequenceToSequence 21 | from data_utils import batch_flow 22 | from word_sequence import WordSequence # pylint: disable=unused-variable 23 | 24 | x_data, _, ws_input, ws_target = pickle.load(open('en-zh_cn.pkl', 'rb')) 25 | 26 | for x in x_data[:5]: 27 | print(' '.join(x)) 28 | 29 | config = tf.ConfigProto( 30 | device_count={'CPU': 1, 'GPU': 0}, 31 | allow_soft_placement=True, 32 | log_device_placement=False 33 | ) 34 | 35 | save_path = './s2ss_en2zh.ckpt' 36 | 37 | # 测试部分 38 | tf.reset_default_graph() 39 | model_pred = SequenceToSequence( 40 | input_vocab_size=len(ws_input), 41 | target_vocab_size=len(ws_target), 42 | batch_size=1, 43 | mode='decode', 44 | beam_width=12, 45 | bidirectional=bidirectional, 46 | cell_type=cell_type, 47 | depth=depth, 48 | attention_type=attention_type, 49 | use_residual=use_residual, 50 | use_dropout=use_dropout, 51 | parallel_iterations=1, 52 | time_major=time_major, 53 | hidden_units=hidden_units # for test 54 | ) 55 | init = tf.global_variables_initializer() 56 | 57 | with tf.Session(config=config) as sess: 58 | sess.run(init) 59 | model_pred.load(sess, save_path) 60 | 61 | while True: 62 | user_text = input('Input English 
Sentence:') 63 | if user_text in ('exit', 'quit'): 64 | exit(0) 65 | x_test = [nltk.word_tokenize(user_text.lower())] 66 | bar = batch_flow([x_test], [ws_input], 1) 67 | x, xl = next(bar) 68 | # x = np.array([ 69 | # list(reversed(xx)) 70 | # for xx in x 71 | # ]) 72 | print(x, xl) 73 | pred = model_pred.predict( 74 | sess, 75 | np.array(x), 76 | np.array(xl) 77 | ) 78 | print(pred) 79 | print(ws_input.inverse_transform(x[0])) 80 | print(ws_target.inverse_transform(pred[0])) 81 | 82 | 83 | def main(): 84 | """入口程序,开始测试不同参数组合""" 85 | random.seed(0) 86 | np.random.seed(0) 87 | tf.set_random_seed(0) 88 | test(True, 'lstm', 3, 'Bahdanau', True, True, True, 64) 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /chatbot/test_compare.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对SequenceToSequence模型进行基本的参数组合测试 3 | """ 4 | 5 | import sys 6 | import random 7 | import pickle 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | sys.path.append('..') 13 | 14 | 15 | def test(params): 16 | """测试不同参数在生成的假数据上的运行结果""" 17 | 18 | from sequence_to_sequence import SequenceToSequence 19 | from data_utils import batch_flow 20 | from word_sequence import WordSequence # pylint: disable=unused-variable 21 | 22 | ws = pickle.load(open('ws.pkl', 'rb')) 23 | 24 | # for x in x_data[:5]: 25 | # print(' '.join(x)) 26 | 27 | config = tf.ConfigProto( 28 | device_count={'CPU': 1, 'GPU': 0}, 29 | allow_soft_placement=True, 30 | log_device_placement=False 31 | ) 32 | 33 | # save_path = '/tmp/s2ss_chatbot.ckpt' 34 | save_path = './s2ss_chatbot.ckpt' 35 | save_path_rl = './s2ss_chatbot_anti.ckpt' 36 | 37 | graph = tf.Graph() 38 | graph_rl = tf.Graph() 39 | 40 | with graph_rl.as_default(): 41 | model_rl = SequenceToSequence( 42 | input_vocab_size=len(ws), 43 | target_vocab_size=len(ws), 44 | batch_size=1, 45 | mode='decode', 46 | beam_width=12, 47 | **params 48 | ) 49 | init = tf.global_variables_initializer() 50 | sess_rl = tf.Session(config=config) 51 | sess_rl.run(init) 52 | model_rl.load(sess_rl, save_path_rl) 53 | 54 | # 测试部分 55 | with graph.as_default(): 56 | model_pred = SequenceToSequence( 57 | input_vocab_size=len(ws), 58 | target_vocab_size=len(ws), 59 | batch_size=1, 60 | mode='decode', 61 | beam_width=12, 62 | **params 63 | ) 64 | init = tf.global_variables_initializer() 65 | sess = tf.Session(config=config) 66 | sess.run(init) 67 | model_pred.load(sess, save_path) 68 | 69 | while True: 70 | user_text = input('Input Chat Sentence:') 71 | if user_text in ('exit', 'quit'): 72 | exit(0) 73 | x_test = list(user_text.lower()) 74 | x_test = [x_test] 75 | bar = batch_flow([x_test], [ws], 1) 76 | x, xl = next(bar) 77 | x = np.flip(x, axis=1) 78 | print(x, xl) 79 | pred = model_pred.predict( 80 | sess, 81 | np.array(x), 82 | np.array(xl) 83 | ) 84 | pred_rl = model_rl.predict( 85 | sess_rl, 86 | np.array(x), 87 | np.array(xl) 88 | ) 89 | print(ws.inverse_transform(x[0])) 90 | print('no:', ws.inverse_transform(pred[0])) 91 | print('rl:', ws.inverse_transform(pred_rl[0])) 92 | p = [] 93 | for pp in ws.inverse_transform(pred_rl[0]): 94 | if pp == WordSequence.END_TAG: 95 | break 96 | if pp == WordSequence.PAD_TAG: 97 | break 98 | p.append(pp) 99 | 100 | 101 | def main(): 102 | """入口程序""" 103 | import json 104 | test(json.load(open('params.json'))) 105 | 106 | 107 | if __name__ == '__main__': 108 | main() 109 | 
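# Note: the while-loop above trims the anti-LM ("rl") prediction at WordSequence.END_TAG /
# WordSequence.PAD_TAG into the list `p`, but `p` is never used afterwards. A minimal
# sketch (assuming `p` is built exactly as in the loop above) to actually show the
# trimmed answer would be one extra line at the end of the loop:
#     print('rl (trimmed):', ''.join(p))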
-------------------------------------------------------------------------------- /threadedgenerator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code from https://gist.github.com/everilae/9697228 3 | I added __next__ 4 | QHduan 5 | """ 6 | 7 | # A simple generator wrapper, not sure if it's good for anything at all. 8 | # With basic python threading 9 | from threading import Thread 10 | from queue import Queue 11 | 12 | # ... or use multiprocessing versions 13 | # WARNING: use sentinel based on value, not identity 14 | # from multiprocessing import Process, Queue as MpQueue 15 | 16 | 17 | class ThreadedGenerator(object): 18 | """ 19 | Generator that runs on a separate thread, returning values to calling 20 | thread. Care must be taken that the iterator does not mutate any shared 21 | variables referenced in the calling thread. 22 | """ 23 | 24 | def __init__(self, iterator, 25 | sentinel=object(), 26 | queue_maxsize=0, 27 | daemon=False): 28 | self._iterator = iterator 29 | self._sentinel = sentinel 30 | self._queue = Queue(maxsize=queue_maxsize) 31 | self._thread = Thread( 32 | name=repr(iterator), 33 | target=self._run 34 | ) 35 | self._thread.daemon = daemon 36 | self._started = False 37 | 38 | def __repr__(self): 39 | return 'ThreadedGenerator({!r})'.format(self._iterator) 40 | 41 | def _run(self): 42 | try: 43 | for value in self._iterator: 44 | if not self._started: 45 | return 46 | self._queue.put(value) 47 | finally: 48 | self._queue.put(self._sentinel) 49 | 50 | def close(self): 51 | self._started = False 52 | try: 53 | while True: 54 | self._queue.get(timeout=0) 55 | except KeyboardInterrupt as e: 56 | raise e 57 | except: # pylint: disable=bare-except 58 | pass 59 | # self._thread.join() 60 | 61 | def __iter__(self): 62 | self._started = True 63 | self._thread.start() 64 | for value in iter(self._queue.get, self._sentinel): 65 | yield value 66 | self._thread.join() 67 | self._started = False 68 | 69 | def __next__(self): 70 | if not self._started: 71 | self._started = True 72 | self._thread.start() 73 | value = self._queue.get(timeout=30) 74 | if value == self._sentinel: 75 | raise StopIteration() 76 | return value 77 | 78 | 79 | def test(): 80 | """测试""" 81 | 82 | def gene(): 83 | i = 0 84 | while True: 85 | yield i 86 | i += 1 87 | t = gene() 88 | tt = ThreadedGenerator(t) 89 | for _ in range(10): 90 | print(next(tt)) 91 | tt.close() 92 | # for i in range(10): 93 | # print(next(tt)) 94 | 95 | # for t in ThreadedGenerator(range(10)): 96 | # print(t) 97 | # print('-' * 10) 98 | # 99 | # t = ThreadedGenerator(range(10)) 100 | # # def gene(): 101 | # # for t in range(10): 102 | # # yield t 103 | # # t = gene() 104 | # for _ in range(10): 105 | # print(next(t)) 106 | # print('-' * 10) 107 | 108 | 109 | 110 | if __name__ == '__main__': 111 | test() 112 | -------------------------------------------------------------------------------- /chatbot_cut/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对SequenceToSequence模型进行基本的参数组合测试 3 | """ 4 | 5 | import sys 6 | import random 7 | import pickle 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | import jieba 12 | # from nltk.tokenize import word_tokenize 13 | 14 | sys.path.append('..') 15 | 16 | 17 | def test(bidirectional, cell_type, depth, 18 | attention_type, use_residual, use_dropout, time_major, hidden_units): 19 | """测试不同参数在生成的假数据上的运行结果""" 20 | 21 | from sequence_to_sequence import SequenceToSequence 22 | from data_utils 
import batch_flow 23 | from word_sequence import WordSequence # pylint: disable=unused-variable 24 | 25 | x_data, _, ws = pickle.load(open('chatbot.pkl', 'rb')) 26 | 27 | for x in x_data[:5]: 28 | print(' '.join(x)) 29 | 30 | config = tf.ConfigProto( 31 | device_count={'CPU': 1, 'GPU': 0}, 32 | allow_soft_placement=True, 33 | log_device_placement=False 34 | ) 35 | 36 | # save_path = '/tmp/s2ss_chatbot.ckpt' 37 | save_path = './s2ss_chatbot.ckpt' 38 | 39 | # 测试部分 40 | tf.reset_default_graph() 41 | model_pred = SequenceToSequence( 42 | input_vocab_size=len(ws), 43 | target_vocab_size=len(ws), 44 | batch_size=1, 45 | mode='decode', 46 | beam_width=0, 47 | bidirectional=bidirectional, 48 | cell_type=cell_type, 49 | depth=depth, 50 | attention_type=attention_type, 51 | use_residual=use_residual, 52 | use_dropout=use_dropout, 53 | parallel_iterations=1, 54 | time_major=time_major, 55 | hidden_units=hidden_units, 56 | share_embedding=True, 57 | pretrained_embedding=True 58 | ) 59 | init = tf.global_variables_initializer() 60 | 61 | with tf.Session(config=config) as sess: 62 | sess.run(init) 63 | model_pred.load(sess, save_path) 64 | 65 | while True: 66 | user_text = input('Input Chat Sentence:') 67 | if user_text in ('exit', 'quit'): 68 | exit(0) 69 | x_test = [jieba.lcut(user_text.lower())] 70 | # x_test = [word_tokenize(user_text)] 71 | bar = batch_flow([x_test], ws, 1) 72 | x, xl = next(bar) 73 | x = np.flip(x, axis=1) 74 | # x = np.array([ 75 | # list(reversed(xx)) 76 | # for xx in x 77 | # ]) 78 | pred = model_pred.predict( 79 | sess, 80 | np.array(x), 81 | np.array(xl) 82 | ) 83 | print(pred) 84 | # prob = np.exp(prob.transpose()) 85 | print(ws.inverse_transform(x[0])) 86 | # print(ws.inverse_transform(pred[0])) 87 | # print(pred.shape, prob.shape) 88 | for p in pred: 89 | ans = ws.inverse_transform(p) 90 | print(ans) 91 | 92 | 93 | def main(): 94 | """入口程序,开始测试不同参数组合""" 95 | random.seed(0) 96 | np.random.seed(0) 97 | tf.set_random_seed(0) 98 | test( 99 | bidirectional=True, 100 | cell_type='lstm', 101 | depth=2, 102 | attention_type='Bahdanau', 103 | use_residual=False, 104 | use_dropout=False, 105 | time_major=False, 106 | hidden_units=512 107 | ) 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /chatbot_cut/test_anti.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对SequenceToSequence模型进行基本的参数组合测试 3 | """ 4 | 5 | import sys 6 | import random 7 | import pickle 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | import jieba 12 | # from nltk.tokenize import word_tokenize 13 | 14 | sys.path.append('..') 15 | 16 | 17 | def test(bidirectional, cell_type, depth, 18 | attention_type, use_residual, use_dropout, time_major, hidden_units): 19 | """测试不同参数在生成的假数据上的运行结果""" 20 | 21 | from sequence_to_sequence import SequenceToSequence 22 | from data_utils import batch_flow 23 | from word_sequence import WordSequence # pylint: disable=unused-variable 24 | 25 | x_data, _, ws = pickle.load(open('chatbot.pkl', 'rb')) 26 | 27 | for x in x_data[:5]: 28 | print(' '.join(x)) 29 | 30 | config = tf.ConfigProto( 31 | device_count={'CPU': 1, 'GPU': 0}, 32 | allow_soft_placement=True, 33 | log_device_placement=False 34 | ) 35 | 36 | # save_path = '/tmp/s2ss_chatbot.ckpt' 37 | save_path = './s2ss_chatbot_anti.ckpt' 38 | 39 | # 测试部分 40 | tf.reset_default_graph() 41 | model_pred = SequenceToSequence( 42 | input_vocab_size=len(ws), 43 | target_vocab_size=len(ws), 44 | 
batch_size=1, 45 | mode='decode', 46 | beam_width=0, 47 | bidirectional=bidirectional, 48 | cell_type=cell_type, 49 | depth=depth, 50 | attention_type=attention_type, 51 | use_residual=use_residual, 52 | use_dropout=use_dropout, 53 | parallel_iterations=1, 54 | time_major=time_major, 55 | hidden_units=hidden_units, 56 | share_embedding=True, 57 | pretrained_embedding=True 58 | ) 59 | init = tf.global_variables_initializer() 60 | 61 | with tf.Session(config=config) as sess: 62 | sess.run(init) 63 | model_pred.load(sess, save_path) 64 | 65 | while True: 66 | user_text = input('Input Chat Sentence:') 67 | if user_text in ('exit', 'quit'): 68 | exit(0) 69 | x_test = [jieba.lcut(user_text.lower())] 70 | # x_test = [word_tokenize(user_text)] 71 | bar = batch_flow([x_test], ws, 1) 72 | x, xl = next(bar) 73 | x = np.flip(x, axis=1) 74 | # x = np.array([ 75 | # list(reversed(xx)) 76 | # for xx in x 77 | # ]) 78 | print(x, xl) 79 | pred = model_pred.predict( 80 | sess, 81 | np.array(x), 82 | np.array(xl) 83 | ) 84 | print(pred) 85 | # prob = np.exp(prob.transpose()) 86 | print(ws.inverse_transform(x[0])) 87 | # print(ws.inverse_transform(pred[0])) 88 | # print(pred.shape, prob.shape) 89 | for p in pred: 90 | ans = ws.inverse_transform(p) 91 | print(ans) 92 | 93 | 94 | def main(): 95 | """入口程序,开始测试不同参数组合""" 96 | random.seed(0) 97 | np.random.seed(0) 98 | tf.set_random_seed(0) 99 | test( 100 | bidirectional=True, 101 | cell_type='lstm', 102 | depth=2, 103 | attention_type='Bahdanau', 104 | use_residual=False, 105 | use_dropout=False, 106 | time_major=False, 107 | hidden_units=512 108 | ) 109 | 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /chatbot/extract_conv.py: -------------------------------------------------------------------------------- 1 | """把 dgk_shooter_min.conv 文件格式转换为可训练格式 2 | """ 3 | 4 | import re 5 | import sys 6 | import pickle 7 | from tqdm import tqdm 8 | 9 | sys.path.append('..') 10 | 11 | 12 | def make_split(line): 13 | """构造合并两个句子之间的符号 14 | """ 15 | if re.match(r'.*([,。…?!~\.,!?])$', ''.join(line)): 16 | return [] 17 | return [','] 18 | 19 | 20 | def good_line(line): 21 | if len(re.findall(r'[a-zA-Z0-9]', ''.join(line))) > 2: 22 | return False 23 | return True 24 | 25 | 26 | def regular(sen): 27 | sen = re.sub(r'\.{3,100}', '…', sen) 28 | sen = re.sub(r'…{2,100}', '…', sen) 29 | sen = re.sub(r'[,]{1,100}', ',', sen) 30 | sen = re.sub(r'[\.]{1,100}', '。', sen) 31 | sen = re.sub(r'[\?]{1,100}', '?', sen) 32 | sen = re.sub(r'[!]{1,100}', '!', sen) 33 | return sen 34 | 35 | 36 | def main(limit=20, x_limit=3, y_limit=6): 37 | """执行程序 38 | Args: 39 | limit: 只输出句子长度小于limit的句子 40 | """ 41 | from word_sequence import WordSequence 42 | 43 | print('extract lines') 44 | fp = open('dgk_shooter_min.conv', 'r', errors='ignore') 45 | last_line = None 46 | groups = [] 47 | group = [] 48 | for line in tqdm(fp): 49 | if line.startswith('M '): 50 | line = line.replace('\n', '') 51 | if '/' in line: 52 | line = line[2:].split('/') 53 | else: 54 | line = list(line[2:]) 55 | line = line[:-1] 56 | group.append(list(regular(''.join(line)))) 57 | else: # if line.startswith('E'): 58 | last_line = None 59 | if group: 60 | groups.append(group) 61 | group = [] 62 | if group: 63 | groups.append(group) 64 | group = [] 65 | print('extract groups') 66 | x_data = [] 67 | y_data = [] 68 | for group in tqdm(groups): 69 | for i, line in enumerate(group): 70 | last_line = None 71 | if i > 0: 72 | last_line = group[i - 1] 73 | 
if not good_line(last_line): 74 | last_line = None 75 | next_line = None 76 | if i < len(group) - 1: 77 | next_line = group[i + 1] 78 | if not good_line(next_line): 79 | next_line = None 80 | next_next_line = None 81 | if i < len(group) - 2: 82 | next_next_line = group[i + 2] 83 | if not good_line(next_next_line): 84 | next_next_line = None 85 | 86 | if next_line: 87 | x_data.append(line) 88 | y_data.append(next_line) 89 | if last_line and next_line: 90 | x_data.append(last_line + make_split(last_line) + line) 91 | y_data.append(next_line) 92 | if next_line and next_next_line: 93 | x_data.append(line) 94 | y_data.append(next_line + make_split(next_line) \ 95 | + next_next_line) 96 | 97 | print(len(x_data), len(y_data)) 98 | for ask, answer in zip(x_data[:20], y_data[:20]): 99 | print(''.join(ask)) 100 | print(''.join(answer)) 101 | print('-' * 20) 102 | 103 | data = list(zip(x_data, y_data)) 104 | data = [ 105 | (x, y) 106 | for x, y in data 107 | if len(x) < limit \ 108 | and len(y) < limit \ 109 | and len(y) >= y_limit \ 110 | and len(x) >= x_limit 111 | ] 112 | x_data, y_data = zip(*data) 113 | 114 | print('fit word_sequence') 115 | 116 | ws_input = WordSequence() 117 | ws_input.fit(x_data + y_data) 118 | 119 | print('dump') 120 | 121 | pickle.dump( 122 | (x_data, y_data), 123 | open('chatbot.pkl', 'wb') 124 | ) 125 | pickle.dump(ws_input, open('ws.pkl', 'wb')) 126 | 127 | print('done') 128 | 129 | 130 | if __name__ == '__main__': 131 | main() 132 | -------------------------------------------------------------------------------- /chatbot_cut/test_compare.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对SequenceToSequence模型进行基本的参数组合测试 3 | """ 4 | 5 | import sys 6 | import random 7 | import pickle 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | import jieba 12 | 13 | sys.path.append('..') 14 | 15 | 16 | def test(bidirectional, cell_type, depth, 17 | attention_type, use_residual, use_dropout, time_major, hidden_units): 18 | """测试不同参数在生成的假数据上的运行结果""" 19 | 20 | from sequence_to_sequence import SequenceToSequence 21 | from data_utils import batch_flow 22 | from word_sequence import WordSequence # pylint: disable=unused-variable 23 | 24 | _, _, ws = pickle.load(open('chatbot.pkl', 'rb')) 25 | 26 | # for x in x_data[:5]: 27 | # print(' '.join(x)) 28 | 29 | config = tf.ConfigProto( 30 | device_count={'CPU': 1, 'GPU': 0}, 31 | allow_soft_placement=True, 32 | log_device_placement=False 33 | ) 34 | 35 | # save_path = '/tmp/s2ss_chatbot.ckpt' 36 | save_path = './s2ss_chatbot.ckpt' 37 | save_path_rl = './s2ss_chatbot_anti.ckpt' 38 | 39 | graph = tf.Graph() 40 | graph_rl = tf.Graph() 41 | 42 | with graph_rl.as_default(): 43 | model_rl = SequenceToSequence( 44 | input_vocab_size=len(ws), 45 | target_vocab_size=len(ws), 46 | batch_size=1, 47 | mode='decode', 48 | beam_width=12, 49 | bidirectional=bidirectional, 50 | cell_type=cell_type, 51 | depth=depth, 52 | attention_type=attention_type, 53 | use_residual=use_residual, 54 | use_dropout=use_dropout, 55 | parallel_iterations=1, 56 | time_major=time_major, 57 | hidden_units=hidden_units, 58 | share_embedding=True, 59 | pretrained_embedding=True 60 | ) 61 | init = tf.global_variables_initializer() 62 | sess_rl = tf.Session(config=config) 63 | sess_rl.run(init) 64 | model_rl.load(sess_rl, save_path_rl) 65 | 66 | # 测试部分 67 | with graph.as_default(): 68 | model_pred = SequenceToSequence( 69 | input_vocab_size=len(ws), 70 | target_vocab_size=len(ws), 71 | batch_size=1, 72 | mode='decode', 73 | 
beam_width=12, 74 | bidirectional=bidirectional, 75 | cell_type=cell_type, 76 | depth=depth, 77 | attention_type=attention_type, 78 | use_residual=use_residual, 79 | use_dropout=use_dropout, 80 | parallel_iterations=1, 81 | time_major=time_major, 82 | hidden_units=hidden_units, 83 | share_embedding=True, 84 | pretrained_embedding=True 85 | ) 86 | init = tf.global_variables_initializer() 87 | sess = tf.Session(config=config) 88 | sess.run(init) 89 | model_pred.load(sess, save_path) 90 | 91 | while True: 92 | user_text = input('Input Chat Sentence:') 93 | if user_text in ('exit', 'quit'): 94 | exit(0) 95 | x_test = [jieba.lcut(user_text.lower())] 96 | bar = batch_flow([x_test], [ws], 1) 97 | x, xl = next(bar) 98 | x = np.flip(x, axis=1) 99 | print(x, xl) 100 | pred = model_pred.predict( 101 | sess, 102 | np.array(x), 103 | np.array(xl) 104 | ) 105 | pred_rl = model_rl.predict( 106 | sess_rl, 107 | np.array(x), 108 | np.array(xl) 109 | ) 110 | print(ws.inverse_transform(x[0])) 111 | print('no:', ws.inverse_transform(pred[0])) 112 | print('rl:', ws.inverse_transform(pred_rl[0])) 113 | p = [] 114 | for pp in ws.inverse_transform(pred_rl[0]): 115 | if pp == WordSequence.END_TAG: 116 | break 117 | if pp == WordSequence.PAD_TAG: 118 | break 119 | p.append(pp) 120 | 121 | 122 | def main(): 123 | """入口程序,开始测试不同参数组合""" 124 | random.seed(0) 125 | np.random.seed(0) 126 | tf.set_random_seed(0) 127 | test( 128 | bidirectional=True, 129 | cell_type='lstm', 130 | depth=2, 131 | attention_type='Bahdanau', 132 | use_residual=False, 133 | use_dropout=False, 134 | time_major=False, 135 | hidden_units=512 136 | ) 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /chatbot_cut/extract_conv.py: -------------------------------------------------------------------------------- 1 | """把 dgk_shooter_min.conv 文件格式转换为可训练格式 2 | """ 3 | 4 | import re 5 | import sys 6 | import pickle 7 | import jieba 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | sys.path.append('..') 12 | 13 | 14 | def make_split(line): 15 | """构造合并两个句子之间的符号 16 | """ 17 | if re.match(r'.*([,。…?!~\.,!?])$', ''.join(line)): 18 | return [] 19 | return [','] 20 | 21 | 22 | def good_line(line): 23 | """判断一个句子是否好""" 24 | if len(re.findall(r'[a-zA-Z0-9]', ''.join(line))) > 2: 25 | return False 26 | return True 27 | 28 | 29 | def regular(sen): 30 | """整理句子""" 31 | sen = re.sub(r'\.{3,100}', '…', sen) 32 | sen = re.sub(r'…{2,100}', '…', sen) 33 | sen = re.sub(r'[,]{1,100}', ',', sen) 34 | sen = re.sub(r'[\.]{1,100}', '。', sen) 35 | sen = re.sub(r'[\?]{1,100}', '?', sen) 36 | sen = re.sub(r'[!]{1,100}', '!', sen) 37 | return sen 38 | 39 | 40 | def main(limit=20, x_limit=3, y_limit=6): 41 | """执行程序 42 | Args: 43 | limit: 只输出句子长度小于limit的句子 44 | """ 45 | from word_sequence import WordSequence 46 | 47 | print('load pretrained vec') 48 | word_vec = pickle.load(open('word_vec.pkl', 'rb')) 49 | 50 | print('extract lines') 51 | fp = open('dgk_shooter_min.conv', 'r', errors='ignore') 52 | last_line = None 53 | groups = [] 54 | group = [] 55 | for line in tqdm(fp): 56 | if line.startswith('M '): 57 | line = line.replace('\n', '') 58 | if '/' in line: 59 | line = line[2:].split('/') 60 | else: 61 | line = list(line[2:]) 62 | line = line[:-1] 63 | group.append(jieba.lcut(regular(''.join(line)))) 64 | else: # if line.startswith('E'): 65 | last_line = None 66 | if group: 67 | groups.append(group) 68 | group = [] 69 | if group: 70 | groups.append(group) 71 | group = [] 72 | 
print('extract groups') 73 | x_data = [] 74 | y_data = [] 75 | for group in tqdm(groups): 76 | for i, line in enumerate(group): 77 | last_line = None 78 | if i > 0: 79 | last_line = group[i - 1] 80 | if not good_line(last_line): 81 | last_line = None 82 | next_line = None 83 | if i < len(group) - 1: 84 | next_line = group[i + 1] 85 | if not good_line(next_line): 86 | next_line = None 87 | next_next_line = None 88 | if i < len(group) - 2: 89 | next_next_line = group[i + 2] 90 | if not good_line(next_next_line): 91 | next_next_line = None 92 | 93 | if next_line: 94 | x_data.append(line) 95 | y_data.append(next_line) 96 | # if last_line and next_line: 97 | # x_data.append(last_line + make_split(last_line) + line) 98 | # y_data.append(next_line) 99 | # if next_line and next_next_line: 100 | # x_data.append(line) 101 | # y_data.append(next_line + make_split(next_line) \ 102 | # + next_next_line) 103 | 104 | print(len(x_data), len(y_data)) 105 | for ask, answer in zip(x_data[:20], y_data[:20]): 106 | print(''.join(ask)) 107 | print(''.join(answer)) 108 | print('-' * 20) 109 | 110 | data = list(zip(x_data, y_data)) 111 | data = [ 112 | (x, y) 113 | for x, y in data 114 | if len(x) < limit \ 115 | and len(y) < limit \ 116 | and len(y) >= y_limit \ 117 | and len(x) >= x_limit 118 | ] 119 | x_data, y_data = zip(*data) 120 | 121 | print('refine train data') 122 | 123 | train_data = x_data + y_data 124 | 125 | # good_train_data = [] 126 | # for line in tqdm(train_data): 127 | # good_train_data.append([ 128 | # x for x in line 129 | # if x in word_vec 130 | # ]) 131 | # train_data = good_train_data 132 | 133 | print('fit word_sequence') 134 | 135 | ws_input = WordSequence() 136 | 137 | ws_input.fit(train_data, max_features=100000) 138 | 139 | print('dump word_sequence') 140 | 141 | pickle.dump( 142 | (x_data, y_data, ws_input), 143 | open('chatbot.pkl', 'wb') 144 | ) 145 | 146 | print('make embedding vecs') 147 | 148 | emb = np.zeros((len(ws_input), len(word_vec['']))) 149 | 150 | np.random.seed(1) 151 | for word, ind in ws_input.dict.items(): 152 | if word in word_vec: 153 | emb[ind] = word_vec[word] 154 | else: 155 | emb[ind] = np.random.random(size=(300,)) - 0.5 156 | 157 | print('dump emb') 158 | 159 | pickle.dump( 160 | emb, 161 | open('emb.pkl', 'wb') 162 | ) 163 | 164 | print('done') 165 | 166 | 167 | if __name__ == '__main__': 168 | main() 169 | -------------------------------------------------------------------------------- /test_atten.py: -------------------------------------------------------------------------------- 1 | """ 2 | 测试并展示 attention 3 | """ 4 | 5 | import matplotlib.pyplot as plt 6 | import matplotlib.cm as cm 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | from sequence_to_sequence import SequenceToSequence 11 | from data_utils import batch_flow 12 | 13 | def test(bidirectional, cell_type, depth, attention_type): 14 | """测试并展示attention图 15 | """ 16 | 17 | from tqdm import tqdm 18 | from fake_data import generate 19 | 20 | # 获取一些假数据 21 | x_data, y_data, ws_input, ws_target = generate(size=10000) 22 | 23 | # 训练部分 24 | 25 | split = int(len(x_data) * 0.9) 26 | x_train, x_test, y_train, y_test = ( 27 | x_data[:split], x_data[split:], y_data[:split], y_data[split:]) 28 | n_epoch = 2 29 | batch_size = 32 30 | steps = int(len(x_train) / batch_size) + 1 31 | 32 | config = tf.ConfigProto( 33 | device_count={'CPU': 1, 'GPU': 0}, 34 | allow_soft_placement=True, 35 | log_device_placement=False 36 | ) 37 | 38 | save_path = '/tmp/s2ss_atten.ckpt' 39 | 40 | with 
tf.Graph().as_default(): 41 | 42 | model = SequenceToSequence( 43 | input_vocab_size=len(ws_input), 44 | target_vocab_size=len(ws_target), 45 | batch_size=batch_size, 46 | learning_rate=0.001, 47 | bidirectional=bidirectional, 48 | cell_type=cell_type, 49 | depth=depth, 50 | attention_type=attention_type, 51 | parallel_iterations=1 52 | ) 53 | init = tf.global_variables_initializer() 54 | 55 | with tf.Session(config=config) as sess: 56 | sess.run(init) 57 | for epoch in range(1, n_epoch + 1): 58 | costs = [] 59 | flow = batch_flow( 60 | [x_train, y_train], [ws_input, ws_target], batch_size 61 | ) 62 | bar = tqdm(range(steps), 63 | desc='epoch {}, loss=0.000000'.format(epoch)) 64 | for _ in bar: 65 | x, xl, y, yl = next(flow) 66 | cost = model.train(sess, x, xl, y, yl) 67 | costs.append(cost) 68 | bar.set_description('epoch {} loss={:.6f}'.format( 69 | epoch, 70 | np.mean(costs) 71 | )) 72 | 73 | model.save(sess, save_path) 74 | 75 | # attention 展示 不能用 beam search 的 76 | # 所以这里只是用 greedy 77 | 78 | with tf.Graph().as_default(): 79 | model_pred = SequenceToSequence( 80 | input_vocab_size=len(ws_input), 81 | target_vocab_size=len(ws_target), 82 | batch_size=1, 83 | mode='decode', 84 | beam_width=0, 85 | bidirectional=bidirectional, 86 | cell_type=cell_type, 87 | depth=depth, 88 | attention_type=attention_type, 89 | parallel_iterations=1 90 | ) 91 | init = tf.global_variables_initializer() 92 | 93 | with tf.Session(config=config) as sess: 94 | sess.run(init) 95 | model_pred.load(sess, save_path) 96 | 97 | pbar = batch_flow([x_test, y_test], [ws_input, ws_target], 1) 98 | t = 0 99 | for x, xl, y, yl in pbar: 100 | pred, atten = model_pred.predict( 101 | sess, 102 | np.array(x), 103 | np.array(xl), 104 | attention=True 105 | ) 106 | ox = ws_input.inverse_transform(x[0]) 107 | oy = ws_target.inverse_transform(y[0]) 108 | op = ws_target.inverse_transform(pred[0]) 109 | print(ox) 110 | print(oy) 111 | print(op) 112 | 113 | fig, ax = plt.subplots() 114 | cax = ax.matshow(atten.reshape( 115 | [atten.shape[0], atten.shape[2]] 116 | ), cmap=cm.coolwarm) 117 | ax.set_xticks(np.arange(len(ox))) 118 | ax.set_yticks(np.arange(len(op))) 119 | ax.set_xticklabels(ox) 120 | ax.set_yticklabels(op) 121 | fig.colorbar(cax) 122 | plt.show() 123 | 124 | print('-' * 30) 125 | 126 | 127 | t += 1 128 | if t >= 10: 129 | break 130 | 131 | 132 | if __name__ == '__main__': 133 | 134 | # for bidirectional in (True, False): 135 | # for cell_type in ('gru', 'lstm'): 136 | # for depth in (1, 2, 3): 137 | # for attention_type in ('Luong', 'Bahdanau'): 138 | # print( 139 | # 'bidirectional, cell_type, depth, attention_type', 140 | # bidirectional, cell_type, depth, attention_type 141 | # ) 142 | test(bidirectional=True, cell_type='lstm', 143 | depth=2, attention_type='Bahdanau') 144 | -------------------------------------------------------------------------------- /chatbot/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对SequenceToSequence模型进行基本的参数组合测试 3 | """ 4 | 5 | import sys 6 | import random 7 | import pickle 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | from tqdm import tqdm 12 | # from sklearn.utils import shuffle 13 | 14 | sys.path.append('..') 15 | 16 | 17 | def test(params): 18 | """测试不同参数在生成的假数据上的运行结果""" 19 | 20 | from sequence_to_sequence import SequenceToSequence 21 | from data_utils import batch_flow_bucket as batch_flow 22 | from word_sequence import WordSequence # pylint: disable=unused-variable 23 | from threadedgenerator import 
ThreadedGenerator 24 | 25 | x_data, y_data = pickle.load(open('chatbot.pkl', 'rb')) 26 | ws = pickle.load(open('ws.pkl', 'rb')) 27 | 28 | # 训练部分 29 | n_epoch = 2 30 | batch_size = 128 31 | # x_data, y_data = shuffle(x_data, y_data, random_state=0) 32 | # x_data = x_data[:10000] 33 | # y_data = y_data[:10000] 34 | steps = int(len(x_data) / batch_size) + 1 35 | 36 | config = tf.ConfigProto( 37 | # device_count={'CPU': 1, 'GPU': 0}, 38 | allow_soft_placement=True, 39 | log_device_placement=False 40 | ) 41 | 42 | save_path = './s2ss_chatbot.ckpt' 43 | 44 | tf.reset_default_graph() 45 | with tf.Graph().as_default(): 46 | random.seed(0) 47 | np.random.seed(0) 48 | tf.set_random_seed(0) 49 | 50 | with tf.Session(config=config) as sess: 51 | 52 | model = SequenceToSequence( 53 | input_vocab_size=len(ws), 54 | target_vocab_size=len(ws), 55 | batch_size=batch_size, 56 | **params 57 | ) 58 | init = tf.global_variables_initializer() 59 | sess.run(init) 60 | 61 | # print(sess.run(model.input_layer.kernel)) 62 | # exit(1) 63 | 64 | flow = ThreadedGenerator( 65 | batch_flow([x_data, y_data], ws, batch_size, 66 | add_end=[False, True]), 67 | queue_maxsize=30) 68 | 69 | for epoch in range(1, n_epoch + 1): 70 | costs = [] 71 | bar = tqdm(range(steps), total=steps, 72 | desc='epoch {}, loss=0.000000'.format(epoch)) 73 | for _ in bar: 74 | x, xl, y, yl = next(flow) 75 | x = np.flip(x, axis=1) 76 | # print(x, y) 77 | # print(xl, yl) 78 | # exit(1) 79 | cost, lr = model.train(sess, x, xl, y, yl, return_lr=True) 80 | costs.append(cost) 81 | bar.set_description('epoch {} loss={:.6f} lr={:.6f}'.format( 82 | epoch, 83 | np.mean(costs), 84 | lr 85 | )) 86 | 87 | model.save(sess, save_path) 88 | 89 | flow.close() 90 | 91 | # 测试部分 92 | tf.reset_default_graph() 93 | model_pred = SequenceToSequence( 94 | input_vocab_size=len(ws), 95 | target_vocab_size=len(ws), 96 | batch_size=1, 97 | mode='decode', 98 | beam_width=12, 99 | parallel_iterations=1, 100 | **params 101 | ) 102 | init = tf.global_variables_initializer() 103 | 104 | with tf.Session(config=config) as sess: 105 | sess.run(init) 106 | model_pred.load(sess, save_path) 107 | 108 | bar = batch_flow([x_data, y_data], ws, 1, add_end=False) 109 | t = 0 110 | for x, xl, y, yl in bar: 111 | x = np.flip(x, axis=1) 112 | pred = model_pred.predict( 113 | sess, 114 | np.array(x), 115 | np.array(xl) 116 | ) 117 | print(ws.inverse_transform(x[0])) 118 | print(ws.inverse_transform(y[0])) 119 | print(ws.inverse_transform(pred[0])) 120 | t += 1 121 | if t >= 3: 122 | break 123 | 124 | tf.reset_default_graph() 125 | model_pred = SequenceToSequence( 126 | input_vocab_size=len(ws), 127 | target_vocab_size=len(ws), 128 | batch_size=1, 129 | mode='decode', 130 | beam_width=1, 131 | parallel_iterations=1, 132 | **params 133 | ) 134 | init = tf.global_variables_initializer() 135 | 136 | with tf.Session(config=config) as sess: 137 | sess.run(init) 138 | model_pred.load(sess, save_path) 139 | 140 | bar = batch_flow([x_data, y_data], ws, 1, add_end=False) 141 | t = 0 142 | for x, xl, y, yl in bar: 143 | pred = model_pred.predict( 144 | sess, 145 | np.array(x), 146 | np.array(xl) 147 | ) 148 | print(ws.inverse_transform(x[0])) 149 | print(ws.inverse_transform(y[0])) 150 | print(ws.inverse_transform(pred[0])) 151 | t += 1 152 | if t >= 3: 153 | break 154 | 155 | 156 | def main(): 157 | """入口程序""" 158 | import json 159 | test(json.load(open('params.json'))) 160 | 161 | 162 | if __name__ == '__main__': 163 | main() 164 | 
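# Note on the np.flip(x, axis=1) calls above: the encoder input is fed with the source
# sequence reversed, the usual seq2seq trick of shortening the distance between the start
# of the question and the start of the answer. The interactive scripts (test.py,
# test_anti.py, test_compare.py) apply the same flip before predict(), so the convention
# has to stay consistent between training and inference; a minimal sketch of the shared
# inference pattern (names as used in those scripts):
#
#     x, xl = next(batch_flow([x_test], ws, 1))
#     x = np.flip(x, axis=1)  # same reversal as during training
#     pred = model_pred.predict(sess, np.array(x), np.array(xl))
#     print(ws.inverse_transform(pred[0]))
#
# The second evaluation block above (beam_width=1) feeds x without the flip, which looks
# inconsistent with the first block and with the test scripts, and is worth double-checking.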
-------------------------------------------------------------------------------- /word_sequence.py: -------------------------------------------------------------------------------- 1 | """ 2 | WordSequence类 3 | 4 | 维护一个字典,把一个list(或者字符串)编码化,或者反向恢复 5 | 6 | """ 7 | 8 | 9 | import numpy as np 10 | 11 | 12 | class WordSequence(object): 13 | """一个可以把句子编码化(index)的类 14 | """ 15 | 16 | PAD_TAG = '<pad>' 17 | UNK_TAG = '<unk>' 18 | START_TAG = '<s>' 19 | END_TAG = '</s>' 20 | PAD = 0 21 | UNK = 1 22 | START = 2 23 | END = 3 24 | 25 | 26 | def __init__(self): 27 | """初始化基本的dict 28 | """ 29 | self.dict = { 30 | WordSequence.PAD_TAG: WordSequence.PAD, 31 | WordSequence.UNK_TAG: WordSequence.UNK, 32 | WordSequence.START_TAG: WordSequence.START, 33 | WordSequence.END_TAG: WordSequence.END, 34 | } 35 | self.fited = False 36 | 37 | 38 | def to_index(self, word): 39 | """把一个单字转换为index 40 | """ 41 | assert self.fited, 'WordSequence 尚未 fit' 42 | if word in self.dict: 43 | return self.dict[word] 44 | return WordSequence.UNK 45 | 46 | 47 | def to_word(self, index): 48 | """把一个index转换为单字 49 | """ 50 | assert self.fited, 'WordSequence 尚未 fit' 51 | for k, v in self.dict.items(): 52 | if v == index: 53 | return k 54 | return WordSequence.UNK_TAG 55 | 56 | 57 | def size(self): 58 | """返回字典大小 59 | """ 60 | assert self.fited, 'WordSequence 尚未 fit' 61 | return len(self.dict) + 1 62 | 63 | def __len__(self): 64 | """返回字典大小 65 | """ 66 | return self.size() 67 | 68 | 69 | def fit(self, sentences, min_count=5, max_count=None, max_features=None): 70 | """训练 WordSequence 71 | Args: 72 | min_count 最小出现次数 73 | max_count 最大出现次数 74 | max_features 最大特征数 75 | 76 | ws = WordSequence() 77 | ws.fit([['hello', 'world']]) 78 | """ 79 | assert not self.fited, 'WordSequence 只能 fit 一次' 80 | 81 | count = {} 82 | for sentence in sentences: 83 | arr = list(sentence) 84 | for a in arr: 85 | if a not in count: 86 | count[a] = 0 87 | count[a] += 1 88 | 89 | if min_count is not None: 90 | count = {k: v for k, v in count.items() if v >= min_count} 91 | 92 | if max_count is not None: 93 | count = {k: v for k, v in count.items() if v <= max_count} 94 | 95 | self.dict = { 96 | WordSequence.PAD_TAG: WordSequence.PAD, 97 | WordSequence.UNK_TAG: WordSequence.UNK, 98 | WordSequence.START_TAG: WordSequence.START, 99 | WordSequence.END_TAG: WordSequence.END, 100 | } 101 | 102 | if isinstance(max_features, int): 103 | count = sorted(list(count.items()), key=lambda x: x[1]) 104 | if max_features is not None and len(count) > max_features: 105 | count = count[-int(max_features):] 106 | for w, _ in count: 107 | self.dict[w] = len(self.dict) 108 | else: 109 | for w in sorted(count.keys()): 110 | self.dict[w] = len(self.dict) 111 | 112 | self.fited = True 113 | 114 | 115 | def transform(self, 116 | sentence, max_len=None): 117 | """把句子转换为向量 118 | 例如输入 ['a', 'b', 'c'] 119 | 输出 [1, 2, 3] 这个数字是字典里的编号,顺序没有意义 120 | """ 121 | assert self.fited, 'WordSequence 尚未 fit' 122 | 123 | # if max_len is not None: 124 | # r = [self.PAD] * max_len 125 | # else: 126 | # r = [self.PAD] * len(sentence) 127 | 128 | if max_len is not None: 129 | r = [self.PAD] * max_len 130 | else: 131 | r = [self.PAD] * len(sentence) 132 | 133 | for index, a in enumerate(sentence): 134 | if max_len is not None and index >= len(r): 135 | break 136 | r[index] = self.to_index(a) 137 | 138 | return np.array(r) 139 | 140 | 141 | def inverse_transform(self, indices, 142 | ignore_pad=False, ignore_unk=False, 143 | ignore_start=False, ignore_end=False): 144 | """把向量转换为句子,和上面的相反 145 | """ 146 | ret = [] 147 | for i in
indices: 148 | word = self.to_word(i) 149 | if word == WordSequence.PAD_TAG and ignore_pad: 150 | continue 151 | if word == WordSequence.UNK_TAG and ignore_unk: 152 | continue 153 | if word == WordSequence.START_TAG and ignore_start: 154 | continue 155 | if word == WordSequence.END_TAG and ignore_end: 156 | continue 157 | ret.append(word) 158 | 159 | return ret 160 | 161 | 162 | def test(): 163 | """测试 164 | """ 165 | ws = WordSequence() 166 | ws.fit([ 167 | ['第', '一', '句', '话'], 168 | ['第', '二', '句', '话'] 169 | ]) 170 | 171 | indice = ws.transform(['第', '三']) 172 | print(indice) 173 | 174 | back = ws.inverse_transform(indice) 175 | print(back) 176 | 177 | if __name__ == '__main__': 178 | test() 179 | -------------------------------------------------------------------------------- /test_crf.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对SequenceToSequence模型进行基本的参数组合测试 3 | """ 4 | 5 | import random 6 | import itertools 7 | from collections import OrderedDict 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | from tqdm import tqdm 12 | 13 | from rnn_crf import RNNCRF 14 | from data_utils import batch_flow 15 | from fake_data import generate 16 | 17 | 18 | def test(bidirectional, cell_type, depth, 19 | use_residual, use_dropout, output_project_active, crf_loss): 20 | """测试不同参数在生成的假数据上的运行结果""" 21 | 22 | # 获取一些假数据 23 | x_data, y_data, ws_input, ws_target = generate(size=10000, same_len=True) 24 | 25 | # 训练部分 26 | 27 | split = int(len(x_data) * 0.8) 28 | x_train, x_test, y_train, y_test = ( 29 | x_data[:split], x_data[split:], y_data[:split], y_data[split:]) 30 | n_epoch = 1 31 | batch_size = 32 32 | steps = int(len(x_train) / batch_size) + 1 33 | 34 | config = tf.ConfigProto( 35 | device_count={'CPU': 1, 'GPU': 0}, 36 | allow_soft_placement=True, 37 | log_device_placement=False 38 | ) 39 | 40 | save_path = '/tmp/s2ss_crf.ckpt' 41 | 42 | tf.reset_default_graph() 43 | with tf.Graph().as_default(): 44 | random.seed(0) 45 | np.random.seed(0) 46 | tf.set_random_seed(0) 47 | 48 | with tf.Session(config=config) as sess: 49 | 50 | model = RNNCRF( 51 | input_vocab_size=len(ws_input), 52 | target_vocab_size=len(ws_target), 53 | max_decode_step=100, 54 | batch_size=batch_size, 55 | learning_rate=0.001, 56 | bidirectional=bidirectional, 57 | cell_type=cell_type, 58 | depth=depth, 59 | use_residual=use_residual, 60 | use_dropout=use_dropout, 61 | output_project_active=output_project_active, 62 | hidden_units=64, 63 | embedding_size=64, 64 | parallel_iterations=1, 65 | crf_loss=crf_loss 66 | ) 67 | init = tf.global_variables_initializer() 68 | sess.run(init) 69 | 70 | # print(sess.run(model.input_layer.kernel)) 71 | # exit(1) 72 | 73 | for epoch in range(1, n_epoch + 1): 74 | costs = [] 75 | flow = batch_flow( 76 | [x_train, y_train], [ws_input, ws_target], batch_size 77 | ) 78 | bar = tqdm(range(steps), 79 | desc='epoch {}, loss=0.000000'.format(epoch)) 80 | for _ in bar: 81 | x, xl, y, yl = next(flow) 82 | cost = model.train(sess, x, xl, y, yl) 83 | costs.append(cost) 84 | bar.set_description('epoch {} loss={:.6f}'.format( 85 | epoch, 86 | np.mean(costs) 87 | )) 88 | 89 | model.save(sess, save_path) 90 | 91 | # 测试部分 92 | tf.reset_default_graph() 93 | model_pred = RNNCRF( 94 | input_vocab_size=len(ws_input), 95 | target_vocab_size=len(ws_target), 96 | max_decode_step=100, 97 | batch_size=batch_size, 98 | mode='decode', 99 | bidirectional=bidirectional, 100 | cell_type=cell_type, 101 | depth=depth, 102 | use_residual=use_residual, 103 | 
use_dropout=use_dropout, 104 | output_project_active=output_project_active, 105 | hidden_units=64, 106 | embedding_size=64, 107 | parallel_iterations=1, 108 | crf_loss=crf_loss 109 | ) 110 | init = tf.global_variables_initializer() 111 | 112 | with tf.Session(config=config) as sess: 113 | sess.run(init) 114 | model_pred.load(sess, save_path) 115 | 116 | flow = batch_flow([x_test, y_test], [ws_input, ws_target], batch_size) 117 | pbar = tqdm(range(100)) 118 | acc = [] 119 | for i in pbar: 120 | x, xl, y, yl = next(flow) 121 | pred = model_pred.predict( 122 | sess, 123 | np.array(x), 124 | np.array(xl) 125 | ) 126 | 127 | for j in range(batch_size): 128 | right = np.sum(y[j][:yl[j]] == pred[j][:yl[j]]) 129 | acc.append(right / yl[j]) 130 | 131 | if i < 3: 132 | print(ws_input.inverse_transform(x[0])) 133 | print(ws_target.inverse_transform(y[0])) 134 | print(ws_target.inverse_transform(pred[0])) 135 | else: 136 | pbar.set_description('acc: {}'.format(np.mean(acc))) 137 | 138 | 139 | def main(): 140 | """入口程序,开始测试不同参数组合""" 141 | random.seed(0) 142 | np.random.seed(0) 143 | tf.set_random_seed(0) 144 | 145 | params = OrderedDict(( 146 | ('bidirectional', (True, False)), 147 | ('cell_type', ('gru', 'lstm')), 148 | ('depth', (1, 2, 3)), 149 | ('use_residual', (True, False)), 150 | ('use_dropout', (True, False)), 151 | ('output_project_active', (None, 'tanh', 'sigmoid', 'linear')), 152 | ('crf_loss', (False, True)) 153 | )) 154 | 155 | loop = itertools.product(*params.values()) 156 | 157 | for param_value in loop: 158 | param = OrderedDict(zip(params.keys(), param_value)) 159 | print('=' * 30) 160 | for key, value in param.items(): 161 | print(key, ':', value) 162 | print('-' * 30) 163 | test(**param) 164 | 165 | 166 | if __name__ == '__main__': 167 | main() 168 | -------------------------------------------------------------------------------- /chatbot/train_anti.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对SequenceToSequence模型进行基本的参数组合测试 3 | """ 4 | 5 | import sys 6 | import random 7 | import pickle 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | from tqdm import tqdm 12 | # from sklearn.utils import shuffle 13 | 14 | sys.path.append('..') 15 | 16 | 17 | def test(params): 18 | """测试不同参数在生成的假数据上的运行结果""" 19 | 20 | from sequence_to_sequence import SequenceToSequence 21 | from data_utils import batch_flow_bucket as batch_flow 22 | from word_sequence import WordSequence # pylint: disable=unused-variable 23 | from threadedgenerator import ThreadedGenerator 24 | 25 | x_data, y_data = pickle.load(open('chatbot.pkl', 'rb')) 26 | ws = pickle.load(open('ws.pkl', 'rb')) 27 | 28 | # 训练部分 29 | n_epoch = 2 30 | batch_size = 128 31 | # x_data, y_data = shuffle(x_data, y_data, random_state=0) 32 | # x_data = x_data[:100000] 33 | # y_data = y_data[:100000] 34 | steps = int(len(x_data) / batch_size) + 1 35 | 36 | config = tf.ConfigProto( 37 | # device_count={'CPU': 1, 'GPU': 0}, 38 | allow_soft_placement=True, 39 | log_device_placement=False 40 | ) 41 | 42 | save_path = './s2ss_chatbot_anti.ckpt' 43 | 44 | tf.reset_default_graph() 45 | with tf.Graph().as_default(): 46 | random.seed(0) 47 | np.random.seed(0) 48 | tf.set_random_seed(0) 49 | 50 | with tf.Session(config=config) as sess: 51 | 52 | model = SequenceToSequence( 53 | input_vocab_size=len(ws), 54 | target_vocab_size=len(ws), 55 | batch_size=batch_size, 56 | **params 57 | ) 58 | init = tf.global_variables_initializer() 59 | sess.run(init) 60 | 61 | # print(sess.run(model.input_layer.kernel)) 62 
| # exit(1) 63 | 64 | flow = ThreadedGenerator( 65 | batch_flow([x_data, y_data], ws, batch_size, 66 | add_end=[False, True]), 67 | queue_maxsize=30) 68 | 69 | dummy_encoder_inputs = np.array([ 70 | np.array([WordSequence.PAD]) for _ in range(batch_size)]) 71 | dummy_encoder_inputs_lengths = np.array([1] * batch_size) 72 | 73 | for epoch in range(1, n_epoch + 1): 74 | costs = [] 75 | bar = tqdm(range(steps), total=steps, 76 | desc='epoch {}, loss=0.000000'.format(epoch)) 77 | for _ in bar: 78 | x, xl, y, yl = next(flow) 79 | x = np.flip(x, axis=1) 80 | 81 | add_loss = model.train(sess, 82 | dummy_encoder_inputs, 83 | dummy_encoder_inputs_lengths, 84 | y, yl, loss_only=True) 85 | 86 | add_loss *= -0.5 87 | # print(x, y) 88 | cost, lr = model.train(sess, x, xl, y, yl, 89 | return_lr=True, add_loss=add_loss) 90 | costs.append(cost) 91 | bar.set_description('epoch {} loss={:.6f} lr={:.6f}'.format( 92 | epoch, 93 | np.mean(costs), 94 | lr 95 | )) 96 | 97 | model.save(sess, save_path) 98 | 99 | flow.close() 100 | 101 | # 测试部分 102 | tf.reset_default_graph() 103 | model_pred = SequenceToSequence( 104 | input_vocab_size=len(ws), 105 | target_vocab_size=len(ws), 106 | batch_size=1, 107 | mode='decode', 108 | beam_width=12, 109 | **params 110 | ) 111 | init = tf.global_variables_initializer() 112 | 113 | with tf.Session(config=config) as sess: 114 | sess.run(init) 115 | model_pred.load(sess, save_path) 116 | 117 | bar = batch_flow([x_data, y_data], ws, 1, add_end=False) 118 | t = 0 119 | for x, xl, y, yl in bar: 120 | x = np.flip(x, axis=1) 121 | pred = model_pred.predict( 122 | sess, 123 | np.array(x), 124 | np.array(xl) 125 | ) 126 | print(ws.inverse_transform(x[0])) 127 | print(ws.inverse_transform(y[0])) 128 | print(ws.inverse_transform(pred[0])) 129 | t += 1 130 | if t >= 3: 131 | break 132 | 133 | tf.reset_default_graph() 134 | model_pred = SequenceToSequence( 135 | input_vocab_size=len(ws), 136 | target_vocab_size=len(ws), 137 | batch_size=1, 138 | mode='decode', 139 | beam_width=1, 140 | **params 141 | ) 142 | init = tf.global_variables_initializer() 143 | 144 | with tf.Session(config=config) as sess: 145 | sess.run(init) 146 | model_pred.load(sess, save_path) 147 | 148 | bar = batch_flow([x_data, y_data], ws, 1, add_end=False) 149 | t = 0 150 | for x, xl, y, yl in bar: 151 | pred = model_pred.predict( 152 | sess, 153 | np.array(x), 154 | np.array(xl) 155 | ) 156 | print(ws.inverse_transform(x[0])) 157 | print(ws.inverse_transform(y[0])) 158 | print(ws.inverse_transform(pred[0])) 159 | t += 1 160 | if t >= 3: 161 | break 162 | 163 | 164 | def main(): 165 | """入口程序""" 166 | import json 167 | test(json.load(open('params.json'))) 168 | 169 | 170 | if __name__ == '__main__': 171 | main() 172 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | *A lot of Chinese in codes and docs* 3 | 4 | # Just another seq2seq repo 5 | 6 | - [x] 主要是从个人角度梳理了一下seq2seq的代码 7 | - [x] 加入了可选基本的CRF支持,loss和infer(还不确定对 8 | - [x] 加入了一些中文注释 9 | - [x] 相对于其他一些repo,bug可能会少一些 10 | - 有些repo的实现在不同参数下会有问题:例如有些支持gru不支持lstm,有些不支持bidirectional,有些选择depth > 1的时候会有各种bug之类的,这些问题我都尽量修正了,虽然不保证实现肯定是对的 11 | - [x] 后续我可能会添加一些中文的例子,例如对联、古诗、闲聊、NER 12 | - [x] 根据本repo,我会整理一份seq2seq中间的各种trick和实现细节的坑 13 | - 
[参考这里](https://github.com/qhduan/ConversationalRobotDesign/tree/master/%E8%81%8A%E5%A4%A9%E6%9C%BA%E5%99%A8%E4%BA%BA%EF%BC%9A%E7%A5%9E%E7%BB%8F%E5%AF%B9%E8%AF%9D%E6%A8%A1%E5%9E%8B%E7%9A%84%E5%AE%9E%E7%8E%B0%E4%B8%8E%E6%8A%80%E5%B7%A7) 14 | - [x] pretrained embedding support 15 | - 参考[chatbot_cut](chatbot_cut/) 16 | - [ ] 后续这个repo会作为一个基础完成一个dialogue system(的一部分,例如NLU) 17 | - seq2seq模型至少可以作为通用NER实现 18 | - 截至2018年初,最好的NER应该还是bi-LSTM + CRF,也有不加CRF效果好的 19 | 20 | 21 | [作者的一系列小文章,欢迎吐槽](https://github.com/qhduan/ConversationalRobotDesign) 22 | 23 | # Update Log 24 | 25 | 2018-03-10 26 | 27 | 我把一些代码内的trick设置得更接近NMT了。 28 | 29 | 尝试训练更好的chatbot模型(嘬死)。 30 | 31 | 添加了一个支持加载训练好的embedding的模型,参考[chatbot_cut/](chatbot_cut/), 32 | 这个例子是“词级的”, 33 | 分词用的jieba, 34 | 默认的预训练模型是fasttext, 35 | 详情点击看文档、代码。 36 | 37 | 2018-03-06 38 | 39 | 增加了chatbot中anti-lm的训练方法样例,在`chatbot/train_anti.py`中。 40 | 这个模式参考了[Li et al., 2015](https://arxiv.org/pdf/1510.03055v3.pdf)和代码 41 | [Marsan-Ma/tf_chatbot_seq2seq_antilm](https://github.com/Marsan-Ma/tf_chatbot_seq2seq_antilm)。 42 | 43 | 加入anti-lm来看,diversity是有提高,不过整体来看,并不是说就很好。 44 | 但是明显降低了机器回答“我不知道”和“我不知道你在说什么”这类回答的概率。 45 | 46 | 虽然我在不同的地方还尝试实现了下面这两个(其实都是一个人写的啦) 47 | [Li et al., 2016](https://arxiv.org/abs/1606.01541) 48 | [Li et al., 2017](https://arxiv.org/abs/1701.06547) 49 | 不过基本上不太成功的感觉,虽然我也没做太严格的实验。 50 | 51 | # Known issues 52 | 53 | Example里的例子和整个项目,虽然未经验证,但是在内存较小的电脑上(<8GB),可能会有问题。 54 | 这涉及到数据处理、数据输入、模型参数等部分,所以严格来说并不算BUG。 55 | 56 | chatbot模型根本没有一个什么是`“好”`的评价标准, 57 | 也根本没有`“好”`的数据。 58 | 所以不要对结果有过度期待,仅供娱乐。 59 | 如果你问我仅供娱乐还写它干嘛? 60 | 本repo只是为了实现各种seq2seq技术, 61 | 也有有用的翻译和NER啊, 62 | 当然很多部分都是学习与研究性质的,工业化需要很多改进。 63 | chatbot部分虽然我花了不少时间, 64 | 但是那个还只是娱乐而已, 65 | 实际应用起来,对话质量、系统成本可能很高。 66 | 我能保证的只是,这个模型基本上没原则性问题而已, 67 | 至少给一个参考,看看我写的垃圾代码和别人写的代码的区别,是吧。 68 | 69 | 当然也不是说就是不能用,例如你能自己搞一些质量很高的数据啦。 70 | 比如说[这位仁兄的repo](https://github.com/bshao001/ChatLearner) 71 | 他就自己弄了一份质量很高的数据, 72 | 搭配一些合理的扩展, 73 | 例如给数据添加功能性词汇 `_func_get_current_time` 之类的东西, 74 | 就能让chatbot实现一些简单功能。 75 | 76 | 简单地说就是把训练数据设置为, 77 | 上一句是`现 在 几 点`, 78 | 下一句是`现 在 时 间 _func_get_current_time`, 79 | 这样在输出部分如果解析到`_func_get_current_time`这个词 80 | 就自动替换为时间的话, 81 | 就可以得到类似“报时”的功能了。 82 | (技术没有好坏,应用在哪最重要!~~这句话是不是很装逼) 83 | 84 | # Platform 85 | 86 | 作者在一台64GB内存 + GTX1070 6GB + Ubuntu 16.04电脑上运行。 87 | 88 | 内存肯定不需要这么大,不过如果显存只有2GB左右,要在GPU上运行模型,可能需要调节batch_size等模型参数。 89 | 90 | # Example 91 | 92 | Example里面用到的数据,都是比较小且粗糙的。 93 | 作者只基本验证了可行性,所以不要指望实用,例如英汉翻译就别期待准确率很高了, 94 | 大概意思到了,就说明模型有一定的有效性了。 95 | 96 | [英汉句子翻译实例](/en2zh/) 97 | 98 | #### 测试结果样例 99 | 100 | ***我不保证重复实验能得到一模一样的结果*** 101 | 102 | ``` 103 | Input English Sentence:go to hell 104 | [[30475 71929 33464]] [3] 105 | [[41337 48900 41337 44789 3]] 106 | ['go', 'to', 'hell'] 107 | ['去', '地狱', '去', '吧', ''] 108 | Input English Sentence:nothing, but the best for you 109 | [[50448 467 13008 71007 10118 27982 79204]] [7] 110 | [[ 25904 132783 90185 4 28145 81577 80498 28798 3]] 111 | ['nothing', ',', 'but', 'the', 'best', 'for', 'you'] 112 | ['什么', '都', '没有', ' ', '但', '最好', '是', '你', ''] 113 | Input English Sentence:i'm a bad boy 114 | [[35437 268 4018 8498 11775]] [5] 115 | [[ 69313 80498 21899 49069 100342 3 -1]] 116 | ['i', "'m", 'a', 'bad', 'boy'] 117 | ['我', '是', '个', '坏', '男孩', '', ''] 118 | Input English Sentence:i'm really a bad boy 119 | [[35437 268 58417 4018 8498 11775]] [6] 120 | [[ 69313 103249 80498 17043 49069 100342 3 3 3 3 121 | 3 3]] 122 | ['i', "'m", 'really', 'a', 'bad', 'boy'] 123 | ['我', '真的', '是', '一个', '坏', '男孩', '', '', '', '', '', ''] 124 | ``` 125 | 126 | [NER实例](/ner/) 
127 | 128 | [Chatbot实例](/chatbot/) 129 | 130 | 131 | `test_atten.py` 脚本,测试并展示 attention 的热力图 132 | 133 | 134 | # TensorFlow alert 135 | 136 | Test in 137 | 138 | ```python 139 | import tensorflow as tf 140 | tf.__version__ >= '1.4.0' and tf.__version__ <= '1.5.0' 141 | ``` 142 | 143 | TensorFlow的API总是变,不能保证后续的更新兼容 144 | 145 | 本repo本质是一个学习性质的repo,作者只是希望尽量保持代码的整齐、理解、可读,并不对不同平台(尤其windows)的兼容,或者后续更新做保证,对不起 146 | 147 | # Related work 148 | 149 | As mention in the head of `sequence_to_sequence.py`, 150 | At beginning, the code is heavily borrow from [here](https://github.com/JayParks/tf-seq2seq/blob/master/seq2seq_model.py) 151 | 152 | I have modified a lot of code, some ***Chinese comments*** in the code. 153 | And fix many bugs, restructure many things, add more features. 154 | 155 | Code was borrow heavily from: 156 | 157 | https://github.com/JayParks/tf-seq2seq/blob/master/seq2seq_model.py 158 | 159 | Another wonderful example is: 160 | 161 | https://github.com/Marsan-Ma/tf_chatbot_seq2seq_antilm 162 | 163 | Official sequence2sequence tutorial 164 | 165 | https://www.tensorflow.org/tutorials/seq2seq 166 | 167 | Official sequence2sequence project: 168 | 169 | https://github.com/tensorflow/nmt 170 | 171 | Another official sequence2sequence model: 172 | 173 | https://github.com/tensorflow/tensor2tensor 174 | 175 | Another seq2seq repo: 176 | 177 | https://github.com/ematvey/tensorflow-seq2seq-tutorials 178 | 179 | A very nice chatbot example: 180 | 181 | https://github.com/bshao001/ChatLearner 182 | 183 | 184 | # pylint 185 | 186 | pylintrc from [here](https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/tools/ci_build/pylintrc) 187 | 188 | changed indent from 2 to 4 189 | 190 | PS. 谷歌的lint一般建议indent是2,相反百度的lint很多建议indent是4, 191 | 个人怀疑这里面有“中文”的问题,也许是因为从小习惯作文空两格?(就是四个英文空格了) 192 | 193 | 我个人是习惯4个的 194 | -------------------------------------------------------------------------------- /en2zh/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对SequenceToSequence模型进行基本的参数组合测试 3 | """ 4 | 5 | import sys 6 | import random 7 | import pickle 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | from tqdm import tqdm 12 | 13 | sys.path.append('..') 14 | 15 | 16 | def test(bidirectional, cell_type, depth, 17 | attention_type, use_residual, use_dropout, time_major, hidden_units): 18 | """测试不同参数在生成的假数据上的运行结果""" 19 | 20 | from sequence_to_sequence import SequenceToSequence 21 | from data_utils import batch_flow 22 | from word_sequence import WordSequence # pylint: disable=unused-variable 23 | 24 | x_data, y_data, ws_input, ws_target = pickle.load( 25 | open('en-zh_cn.pkl', 'rb')) 26 | 27 | # 获取一些假数据 28 | # x_data, y_data, ws_input, ws_target = generate(size=10000) 29 | 30 | # 训练部分 31 | split = int(len(x_data) * 0.8) 32 | x_train, x_test, y_train, y_test = ( 33 | x_data[:split], x_data[split:], y_data[:split], y_data[split:]) 34 | n_epoch = 2 35 | batch_size = 256 36 | steps = int(len(x_train) / batch_size) + 1 37 | 38 | config = tf.ConfigProto( 39 | # device_count={'CPU': 1, 'GPU': 0}, 40 | allow_soft_placement=True, 41 | log_device_placement=False 42 | ) 43 | 44 | save_path = './s2ss_en2zh.ckpt' 45 | 46 | tf.reset_default_graph() 47 | with tf.Graph().as_default(): 48 | random.seed(0) 49 | np.random.seed(0) 50 | tf.set_random_seed(0) 51 | 52 | with tf.Session(config=config) as sess: 53 | 54 | model = SequenceToSequence( 55 | input_vocab_size=len(ws_input), 56 | target_vocab_size=len(ws_target), 57 | batch_size=batch_size, 58 | 
learning_rate=0.001, 59 | bidirectional=bidirectional, 60 | cell_type=cell_type, 61 | depth=depth, 62 | attention_type=attention_type, 63 | use_residual=use_residual, 64 | use_dropout=use_dropout, 65 | parallel_iterations=64, 66 | hidden_units=hidden_units, 67 | optimizer='adam', 68 | time_major=time_major 69 | ) 70 | init = tf.global_variables_initializer() 71 | sess.run(init) 72 | 73 | # print(sess.run(model.input_layer.kernel)) 74 | # exit(1) 75 | 76 | flow = batch_flow( 77 | [x_train, y_train], [ws_input, ws_target], batch_size 78 | ) 79 | 80 | for epoch in range(1, n_epoch + 1): 81 | costs = [] 82 | bar = tqdm(range(steps), total=steps, 83 | desc='epoch {}, loss=0.000000'.format(epoch)) 84 | for _ in bar: 85 | x, xl, y, yl = next(flow) 86 | cost = model.train(sess, x, xl, y, yl) 87 | costs.append(cost) 88 | bar.set_description('epoch {} loss={:.6f}'.format( 89 | epoch, 90 | np.mean(costs) 91 | )) 92 | 93 | model.save(sess, save_path) 94 | 95 | # 测试部分 96 | tf.reset_default_graph() 97 | model_pred = SequenceToSequence( 98 | input_vocab_size=len(ws_input), 99 | target_vocab_size=len(ws_target), 100 | batch_size=1, 101 | mode='decode', 102 | beam_width=12, 103 | bidirectional=bidirectional, 104 | cell_type=cell_type, 105 | depth=depth, 106 | attention_type=attention_type, 107 | use_residual=use_residual, 108 | use_dropout=use_dropout, 109 | hidden_units=hidden_units, 110 | time_major=time_major, 111 | parallel_iterations=1 # for test 112 | ) 113 | init = tf.global_variables_initializer() 114 | 115 | with tf.Session(config=config) as sess: 116 | sess.run(init) 117 | model_pred.load(sess, save_path) 118 | 119 | bar = batch_flow([x_test, y_test], [ws_input, ws_target], 1) 120 | t = 0 121 | for x, xl, y, yl in bar: 122 | pred = model_pred.predict( 123 | sess, 124 | np.array(x), 125 | np.array(xl) 126 | ) 127 | print(ws_input.inverse_transform(x[0])) 128 | print(ws_target.inverse_transform(y[0])) 129 | print(ws_target.inverse_transform(pred[0])) 130 | t += 1 131 | if t >= 3: 132 | break 133 | 134 | tf.reset_default_graph() 135 | model_pred = SequenceToSequence( 136 | input_vocab_size=len(ws_input), 137 | target_vocab_size=len(ws_target), 138 | batch_size=1, 139 | mode='decode', 140 | beam_width=1, 141 | bidirectional=bidirectional, 142 | cell_type=cell_type, 143 | depth=depth, 144 | attention_type=attention_type, 145 | use_residual=use_residual, 146 | use_dropout=use_dropout, 147 | hidden_units=hidden_units, 148 | time_major=time_major, 149 | parallel_iterations=1 # for test 150 | ) 151 | init = tf.global_variables_initializer() 152 | 153 | with tf.Session(config=config) as sess: 154 | sess.run(init) 155 | model_pred.load(sess, save_path) 156 | 157 | bar = batch_flow([x_test, y_test], [ws_input, ws_target], 1) 158 | t = 0 159 | for x, xl, y, yl in bar: 160 | pred = model_pred.predict( 161 | sess, 162 | np.array(x), 163 | np.array(xl) 164 | ) 165 | print(ws_input.inverse_transform(x[0])) 166 | print(ws_target.inverse_transform(y[0])) 167 | print(ws_target.inverse_transform(pred[0])) 168 | t += 1 169 | if t >= 3: 170 | break 171 | 172 | 173 | def main(): 174 | """入口程序,开始测试不同参数组合""" 175 | random.seed(0) 176 | np.random.seed(0) 177 | tf.set_random_seed(0) 178 | test(True, 'lstm', 3, 'Bahdanau', True, True, True, 64) 179 | 180 | 181 | if __name__ == '__main__': 182 | main() 183 | -------------------------------------------------------------------------------- /ner/train_crf_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对RNNCRF模型进行基本的参数组合测试 3 | 
""" 4 | 5 | import sys 6 | import random 7 | import pickle 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | from tqdm import tqdm 12 | 13 | sys.path.append('..') 14 | 15 | 16 | def test(bidirectional, cell_type, depth, 17 | use_residual, use_dropout, time_major, hidden_units, 18 | output_project_active, crf_loss=True, save_path='./s2ss_crf.ckpt'): 19 | """测试不同参数在生成的假数据上的运行结果""" 20 | 21 | from rnn_crf import RNNCRF 22 | from data_utils import batch_flow 23 | from word_sequence import WordSequence # pylint: disable=unused-variable 24 | 25 | x_data, y_data, ws_input, ws_target = pickle.load( 26 | open('ner.pkl', 'rb')) 27 | 28 | # 训练部分 29 | split = int(len(x_data) * 0.8) 30 | x_train, x_test, y_train, y_test = ( 31 | x_data[:split], x_data[split:], y_data[:split], y_data[split:]) 32 | n_epoch = 10 33 | batch_size = 128 34 | steps = int(len(x_train) / batch_size) + 1 35 | 36 | config = tf.ConfigProto( 37 | # device_count={'CPU': 1, 'GPU': 0}, 38 | allow_soft_placement=True, 39 | log_device_placement=False 40 | ) 41 | 42 | tf.reset_default_graph() 43 | with tf.Graph().as_default(): 44 | random.seed(0) 45 | np.random.seed(0) 46 | tf.set_random_seed(0) 47 | 48 | with tf.Session(config=config) as sess: 49 | 50 | model = RNNCRF( 51 | input_vocab_size=len(ws_input), 52 | target_vocab_size=len(ws_target), 53 | max_decode_step=100, 54 | batch_size=batch_size, 55 | learning_rate=0.001, 56 | bidirectional=bidirectional, 57 | cell_type=cell_type, 58 | depth=depth, 59 | use_residual=use_residual, 60 | use_dropout=use_dropout, 61 | parallel_iterations=64, 62 | hidden_units=hidden_units, 63 | optimizer='adam', 64 | time_major=time_major, 65 | output_project_active=output_project_active, 66 | crf_loss=crf_loss 67 | ) 68 | init = tf.global_variables_initializer() 69 | sess.run(init) 70 | 71 | # print(sess.run(model.input_layer.kernel)) 72 | # exit(1) 73 | 74 | flow = batch_flow( 75 | [x_train, y_train], [ws_input, ws_target], batch_size 76 | ) 77 | 78 | for epoch in range(1, n_epoch + 1): 79 | costs = [] 80 | bar = tqdm(range(steps), total=steps, 81 | desc='epoch {}, loss=0.000000'.format(epoch)) 82 | for _ in bar: 83 | x, xl, y, yl = next(flow) 84 | cost = model.train(sess, x, xl, y, yl) 85 | costs.append(cost) 86 | bar.set_description('epoch {} loss={:.6f}'.format( 87 | epoch, 88 | np.mean(costs) 89 | )) 90 | 91 | model.save(sess, save_path) 92 | 93 | # 测试部分 94 | 95 | tf.reset_default_graph() 96 | model_pred = RNNCRF( 97 | input_vocab_size=len(ws_input), 98 | target_vocab_size=len(ws_target), 99 | max_decode_step=100, 100 | batch_size=batch_size, 101 | mode='decode', 102 | bidirectional=bidirectional, 103 | cell_type=cell_type, 104 | depth=depth, 105 | use_residual=use_residual, 106 | use_dropout=use_dropout, 107 | hidden_units=hidden_units, 108 | time_major=time_major, 109 | parallel_iterations=1, 110 | output_project_active=output_project_active, 111 | crf_loss=crf_loss 112 | ) 113 | init = tf.global_variables_initializer() 114 | 115 | with tf.Session(config=config) as sess: 116 | sess.run(init) 117 | model_pred.load(sess, save_path) 118 | 119 | pbar = tqdm(range(100)) 120 | flow = batch_flow([x_test, y_test], [ws_input, ws_target], batch_size) 121 | acc = [] 122 | prec = [] 123 | rec = [] 124 | for i in pbar: 125 | x, xl, y, yl = next(flow) 126 | pred = model_pred.predict( 127 | sess, 128 | np.array(x), 129 | np.array(xl) 130 | ) 131 | 132 | for j in range(batch_size): 133 | 134 | right = np.asarray(ws_target.inverse_transform(y[j])) 135 | predict = ws_target.inverse_transform(pred[j]) 136 | 
if len(predict) < len(right): 137 | predict += ['O'] * (len(right) - len(predict)) 138 | predict = np.asarray(predict) 139 | 140 | pp = predict[:yl[j]] 141 | rr = right[:yl[j]] 142 | if len(rr) > 0: 143 | acc.append(np.sum(pp == rr) / len(rr)) 144 | 145 | pp = predict[:yl[j]] 146 | rr = right[:yl[j]] 147 | pp = pp[rr != 'O'] 148 | rr = rr[rr != 'O'] 149 | if len(rr) > 0: 150 | rec.append(np.sum(pp == rr) / len(rr)) 151 | 152 | pp = predict[:yl[j]] 153 | rr = right[:yl[j]] 154 | rr = rr[pp != 'O'] 155 | pp = pp[pp != 'O'] 156 | if len(rr) > 0: 157 | prec.append(np.sum(pp == rr) / len(rr)) 158 | 159 | if i < 3: 160 | # print(ws_input.inverse_transform(x[0])) 161 | # print(ws_target.inverse_transform(y[0])) 162 | # print(ws_target.inverse_transform(pred[0])) 163 | pass 164 | else: 165 | pbar.set_description( 166 | 'acc: {:.4f} prec: {:.4f} rec: {:.4f}'.format( 167 | np.mean(acc), np.mean(prec), np.mean(rec))) 168 | 169 | 170 | def main(): 171 | """入口程序,开始测试不同参数组合""" 172 | random.seed(0) 173 | np.random.seed(0) 174 | tf.set_random_seed(0) 175 | test(True, 'lstm', 1, False, True, False, 64, 'tanh') 176 | 177 | 178 | if __name__ == '__main__': 179 | main() 180 | -------------------------------------------------------------------------------- /chatbot_cut/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对SequenceToSequence模型进行基本的参数组合测试 3 | """ 4 | 5 | import sys 6 | import random 7 | import pickle 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | from tqdm import tqdm 12 | # from sklearn.utils import shuffle 13 | 14 | sys.path.append('..') 15 | 16 | 17 | def test(bidirectional, cell_type, depth, 18 | attention_type, use_residual, use_dropout, time_major, hidden_units): 19 | """测试不同参数在生成的假数据上的运行结果""" 20 | 21 | from sequence_to_sequence import SequenceToSequence 22 | from data_utils import batch_flow_bucket as batch_flow 23 | from word_sequence import WordSequence # pylint: disable=unused-variable 24 | from threadedgenerator import ThreadedGenerator 25 | 26 | emb = pickle.load(open('emb.pkl', 'rb')) 27 | 28 | x_data, y_data, ws = pickle.load( 29 | open('chatbot.pkl', 'rb')) 30 | 31 | # 训练部分 32 | n_epoch = 5 33 | batch_size = 128 34 | # x_data, y_data = shuffle(x_data, y_data, random_state=0) 35 | # x_data = x_data[:10000] 36 | # y_data = y_data[:10000] 37 | steps = int(len(x_data) / batch_size) + 1 38 | 39 | config = tf.ConfigProto( 40 | # device_count={'CPU': 1, 'GPU': 0}, 41 | allow_soft_placement=True, 42 | log_device_placement=False 43 | ) 44 | 45 | save_path = './s2ss_chatbot.ckpt' 46 | 47 | tf.reset_default_graph() 48 | with tf.Graph().as_default(): 49 | random.seed(0) 50 | np.random.seed(0) 51 | tf.set_random_seed(0) 52 | 53 | with tf.Session(config=config) as sess: 54 | 55 | model = SequenceToSequence( 56 | input_vocab_size=len(ws), 57 | target_vocab_size=len(ws), 58 | batch_size=batch_size, 59 | bidirectional=bidirectional, 60 | cell_type=cell_type, 61 | depth=depth, 62 | attention_type=attention_type, 63 | use_residual=use_residual, 64 | use_dropout=use_dropout, 65 | hidden_units=hidden_units, 66 | time_major=time_major, 67 | learning_rate=0.001, 68 | optimizer='adam', 69 | share_embedding=True, 70 | dropout=0.2, 71 | pretrained_embedding=True 72 | ) 73 | init = tf.global_variables_initializer() 74 | sess.run(init) 75 | 76 | # 加载训练好的embedding 77 | model.feed_embedding(sess, encoder=emb) 78 | 79 | # print(sess.run(model.input_layer.kernel)) 80 | # exit(1) 81 | 82 | flow = ThreadedGenerator( 83 | batch_flow([x_data, y_data], 
ws, batch_size), 84 | queue_maxsize=30) 85 | 86 | for epoch in range(1, n_epoch + 1): 87 | costs = [] 88 | bar = tqdm(range(steps), total=steps, 89 | desc='epoch {}, loss=0.000000'.format(epoch)) 90 | for _ in bar: 91 | x, xl, y, yl = next(flow) 92 | x = np.flip(x, axis=1) 93 | # print(x, y) 94 | cost, lr = model.train(sess, x, xl, y, yl, return_lr=True) 95 | costs.append(cost) 96 | bar.set_description('epoch {} loss={:.6f} lr={:.6f}'.format( 97 | epoch, 98 | np.mean(costs), 99 | lr 100 | )) 101 | 102 | model.save(sess, save_path) 103 | 104 | flow.close() 105 | 106 | # 测试部分 107 | tf.reset_default_graph() 108 | model_pred = SequenceToSequence( 109 | input_vocab_size=len(ws), 110 | target_vocab_size=len(ws), 111 | batch_size=1, 112 | mode='decode', 113 | beam_width=12, 114 | bidirectional=bidirectional, 115 | cell_type=cell_type, 116 | depth=depth, 117 | attention_type=attention_type, 118 | use_residual=use_residual, 119 | use_dropout=use_dropout, 120 | hidden_units=hidden_units, 121 | time_major=time_major, 122 | parallel_iterations=1, 123 | learning_rate=0.001, 124 | optimizer='adam', 125 | share_embedding=True, 126 | pretrained_embedding=True 127 | ) 128 | init = tf.global_variables_initializer() 129 | 130 | with tf.Session(config=config) as sess: 131 | sess.run(init) 132 | model_pred.load(sess, save_path) 133 | 134 | bar = batch_flow([x_data, y_data], ws, 1) 135 | t = 0 136 | for x, xl, y, yl in bar: 137 | x = np.flip(x, axis=1) 138 | pred = model_pred.predict( 139 | sess, 140 | np.array(x), 141 | np.array(xl) 142 | ) 143 | print(ws.inverse_transform(x[0])) 144 | print(ws.inverse_transform(y[0])) 145 | print(ws.inverse_transform(pred[0])) 146 | t += 1 147 | if t >= 3: 148 | break 149 | 150 | tf.reset_default_graph() 151 | model_pred = SequenceToSequence( 152 | input_vocab_size=len(ws), 153 | target_vocab_size=len(ws), 154 | batch_size=1, 155 | mode='decode', 156 | beam_width=1, 157 | bidirectional=bidirectional, 158 | cell_type=cell_type, 159 | depth=depth, 160 | attention_type=attention_type, 161 | use_residual=use_residual, 162 | use_dropout=use_dropout, 163 | hidden_units=hidden_units, 164 | time_major=time_major, 165 | parallel_iterations=1, 166 | learning_rate=0.001, 167 | optimizer='adam', 168 | share_embedding=True, 169 | pretrained_embedding=True 170 | ) 171 | init = tf.global_variables_initializer() 172 | 173 | with tf.Session(config=config) as sess: 174 | sess.run(init) 175 | model_pred.load(sess, save_path) 176 | 177 | bar = batch_flow([x_data, y_data], ws, 1) 178 | t = 0 179 | for x, xl, y, yl in bar: 180 | pred = model_pred.predict( 181 | sess, 182 | np.array(x), 183 | np.array(xl) 184 | ) 185 | print(ws.inverse_transform(x[0])) 186 | print(ws.inverse_transform(y[0])) 187 | print(ws.inverse_transform(pred[0])) 188 | t += 1 189 | if t >= 3: 190 | break 191 | 192 | 193 | def main(): 194 | """入口程序,开始测试不同参数组合""" 195 | random.seed(0) 196 | np.random.seed(0) 197 | tf.set_random_seed(0) 198 | test( 199 | bidirectional=True, 200 | cell_type='lstm', 201 | depth=2, 202 | attention_type='Bahdanau', 203 | use_residual=False, 204 | use_dropout=False, 205 | time_major=False, 206 | hidden_units=512 207 | ) 208 | 209 | 210 | if __name__ == '__main__': 211 | main() 212 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对SequenceToSequence模型进行基本的参数组合测试 3 | """ 4 | 5 | import random 6 | import itertools 7 | from collections import OrderedDict 8 | 9 | import 
pandas as pd 10 | import numpy as np 11 | import tensorflow as tf 12 | from tqdm import tqdm 13 | 14 | from sequence_to_sequence import SequenceToSequence 15 | from data_utils import batch_flow 16 | from fake_data import generate 17 | 18 | 19 | def test(bidirectional, cell_type, depth, 20 | attention_type, use_residual, use_dropout, time_major): 21 | """测试不同参数在生成的假数据上的运行结果""" 22 | 23 | # 获取一些假数据 24 | x_data, y_data, ws_input, ws_target = generate(size=1000) 25 | 26 | # 训练部分 27 | 28 | split = int(len(x_data) * 0.8) 29 | x_train, x_test, y_train, y_test = ( 30 | x_data[:split], x_data[split:], y_data[:split], y_data[split:]) 31 | n_epoch = 1 32 | batch_size = 8 33 | steps = int(len(x_train) / batch_size) + 1 34 | 35 | config = tf.ConfigProto( 36 | device_count={'CPU': 1, 'GPU': 0}, 37 | allow_soft_placement=True, 38 | log_device_placement=False 39 | ) 40 | 41 | save_path = '/tmp/s2ss.ckpt' 42 | 43 | tf.reset_default_graph() 44 | with tf.Graph().as_default(): 45 | random.seed(0) 46 | np.random.seed(0) 47 | tf.set_random_seed(0) 48 | 49 | with tf.Session(config=config) as sess: 50 | 51 | model = SequenceToSequence( 52 | input_vocab_size=len(ws_input), 53 | target_vocab_size=len(ws_target), 54 | batch_size=batch_size, 55 | learning_rate=0.001, 56 | bidirectional=bidirectional, 57 | cell_type=cell_type, 58 | depth=depth, 59 | attention_type=attention_type, 60 | use_residual=use_residual, 61 | use_dropout=use_dropout, 62 | time_major=time_major, 63 | hidden_units=64, 64 | embedding_size=64, 65 | parallel_iterations=1 # for test 66 | ) 67 | init = tf.global_variables_initializer() 68 | sess.run(init) 69 | 70 | # print(sess.run(model.input_layer.kernel)) 71 | # exit(1) 72 | 73 | for epoch in range(1, n_epoch + 1): 74 | costs = [] 75 | flow = batch_flow( 76 | [x_train, y_train], [ws_input, ws_target], batch_size 77 | ) 78 | bar = tqdm(range(steps), 79 | desc='epoch {}, loss=0.000000'.format(epoch)) 80 | for _ in bar: 81 | x, xl, y, yl = next(flow) 82 | cost = model.train(sess, x, xl, y, yl) 83 | costs.append(cost) 84 | bar.set_description('epoch {} loss={:.6f}'.format( 85 | epoch, 86 | np.mean(costs) 87 | )) 88 | 89 | model.save(sess, save_path) 90 | 91 | # 测试部分 92 | tf.reset_default_graph() 93 | model_pred = SequenceToSequence( 94 | input_vocab_size=len(ws_input), 95 | target_vocab_size=len(ws_target), 96 | batch_size=1, 97 | mode='decode', 98 | beam_width=5, 99 | bidirectional=bidirectional, 100 | cell_type=cell_type, 101 | depth=depth, 102 | attention_type=attention_type, 103 | use_residual=use_residual, 104 | use_dropout=use_dropout, 105 | time_major=time_major, 106 | hidden_units=64, 107 | embedding_size=64, 108 | parallel_iterations=1 # for test 109 | ) 110 | init = tf.global_variables_initializer() 111 | 112 | with tf.Session(config=config) as sess: 113 | sess.run(init) 114 | model_pred.load(sess, save_path) 115 | 116 | bar = batch_flow([x_test, y_test], [ws_input, ws_target], 1) 117 | t = 0 118 | for x, xl, y, yl in bar: 119 | pred = model_pred.predict( 120 | sess, 121 | np.array(x), 122 | np.array(xl) 123 | ) 124 | print(ws_input.inverse_transform(x[0])) 125 | print(ws_target.inverse_transform(y[0])) 126 | print(ws_target.inverse_transform(pred[0])) 127 | t += 1 128 | if t >= 3: 129 | break 130 | 131 | tf.reset_default_graph() 132 | model_pred = SequenceToSequence( 133 | input_vocab_size=len(ws_input), 134 | target_vocab_size=len(ws_target), 135 | batch_size=1, 136 | mode='decode', 137 | beam_width=0, 138 | bidirectional=bidirectional, 139 | cell_type=cell_type, 140 | depth=depth, 141 | 
attention_type=attention_type, 142 | use_residual=use_residual, 143 | use_dropout=use_dropout, 144 | time_major=time_major, 145 | hidden_units=64, 146 | embedding_size=64, 147 | parallel_iterations=1 # for test 148 | ) 149 | init = tf.global_variables_initializer() 150 | 151 | with tf.Session(config=config) as sess: 152 | sess.run(init) 153 | model_pred.load(sess, save_path) 154 | 155 | bar = batch_flow([x_test, y_test], [ws_input, ws_target], 1) 156 | t = 0 157 | for x, xl, y, yl in bar: 158 | pred = model_pred.predict( 159 | sess, 160 | np.array(x), 161 | np.array(xl) 162 | ) 163 | print(ws_input.inverse_transform(x[0])) 164 | print(ws_target.inverse_transform(y[0])) 165 | print(ws_target.inverse_transform(pred[0])) 166 | t += 1 167 | if t >= 3: 168 | break 169 | 170 | # return last train loss 171 | return np.mean(costs) 172 | 173 | 174 | def main(): 175 | """入口程序,开始测试不同参数组合""" 176 | random.seed(0) 177 | np.random.seed(0) 178 | tf.set_random_seed(0) 179 | 180 | params = OrderedDict(( 181 | ('cell_type', ('gru', 'lstm')), 182 | ('attention_type', ('Luong', 'Bahdanau')), 183 | ('use_dropout', (True, False)), 184 | ('time_major', (True, False)), 185 | ('depth', (1, 2, 3)), 186 | ('bidirectional', (True, False)), 187 | ('use_residual', (True, False)), 188 | )) 189 | 190 | loop = itertools.product(*params.values()) 191 | 192 | rows = [] 193 | for param_value in loop: 194 | param = OrderedDict(zip(params.keys(), param_value)) 195 | print('=' * 30) 196 | row = [] 197 | for key, value in param.items(): 198 | print(key, ':', value) 199 | row.append(value) 200 | print('-' * 30) 201 | cost = test(**param) 202 | row += [cost] 203 | rows.append(row) 204 | 205 | columns = list(params.keys()) + ['loss'] 206 | dframe = pd.DataFrame(rows, columns=columns) 207 | dframe.to_excel('/tmp/s2ss_test.xlsx') 208 | 209 | 210 | if __name__ == '__main__': 211 | main() 212 | -------------------------------------------------------------------------------- /chatbot_cut/train_anti.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对SequenceToSequence模型进行基本的参数组合测试 3 | """ 4 | 5 | import sys 6 | import random 7 | import pickle 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | from tqdm import tqdm 12 | # from sklearn.utils import shuffle 13 | 14 | sys.path.append('..') 15 | 16 | 17 | def test(bidirectional, cell_type, depth, 18 | attention_type, use_residual, use_dropout, time_major, hidden_units): 19 | """测试不同参数在生成的假数据上的运行结果""" 20 | 21 | from sequence_to_sequence import SequenceToSequence 22 | from data_utils import batch_flow_bucket as batch_flow 23 | from word_sequence import WordSequence # pylint: disable=unused-variable 24 | from threadedgenerator import ThreadedGenerator 25 | 26 | emb = pickle.load(open('emb.pkl', 'rb')) 27 | 28 | x_data, y_data, ws = pickle.load( 29 | open('chatbot.pkl', 'rb')) 30 | 31 | # 训练部分 32 | n_epoch = 5 33 | batch_size = 128 34 | # x_data, y_data = shuffle(x_data, y_data, random_state=0) 35 | # x_data = x_data[:100000] 36 | # y_data = y_data[:100000] 37 | steps = int(len(x_data) / batch_size) + 1 38 | 39 | config = tf.ConfigProto( 40 | # device_count={'CPU': 1, 'GPU': 0}, 41 | allow_soft_placement=True, 42 | log_device_placement=False 43 | ) 44 | 45 | save_path = './s2ss_chatbot_anti.ckpt' 46 | 47 | tf.reset_default_graph() 48 | with tf.Graph().as_default(): 49 | random.seed(0) 50 | np.random.seed(0) 51 | tf.set_random_seed(0) 52 | 53 | with tf.Session(config=config) as sess: 54 | 55 | model = SequenceToSequence( 56 | 
input_vocab_size=len(ws), 57 | target_vocab_size=len(ws), 58 | batch_size=batch_size, 59 | bidirectional=bidirectional, 60 | cell_type=cell_type, 61 | depth=depth, 62 | attention_type=attention_type, 63 | use_residual=use_residual, 64 | use_dropout=use_dropout, 65 | hidden_units=hidden_units, 66 | time_major=time_major, 67 | learning_rate=0.001, 68 | optimizer='adam', 69 | share_embedding=True, 70 | dropout=0.2, 71 | pretrained_embedding=True 72 | ) 73 | init = tf.global_variables_initializer() 74 | sess.run(init) 75 | 76 | # 加载训练好的embedding 77 | model.feed_embedding(sess, encoder=emb) 78 | 79 | # print(sess.run(model.input_layer.kernel)) 80 | # exit(1) 81 | 82 | flow = ThreadedGenerator( 83 | batch_flow([x_data, y_data], ws, batch_size), 84 | queue_maxsize=30) 85 | 86 | dummy_encoder_inputs = np.array([ 87 | np.array([WordSequence.PAD]) for _ in range(batch_size)]) 88 | dummy_encoder_inputs_lengths = np.array([1] * batch_size) 89 | 90 | for epoch in range(1, n_epoch + 1): 91 | costs = [] 92 | bar = tqdm(range(steps), total=steps, 93 | desc='epoch {}, loss=0.000000'.format(epoch)) 94 | for _ in bar: 95 | x, xl, y, yl = next(flow) 96 | x = np.flip(x, axis=1) 97 | 98 | add_loss = model.train(sess, 99 | dummy_encoder_inputs, 100 | dummy_encoder_inputs_lengths, 101 | y, yl, loss_only=True) 102 | 103 | add_loss *= -0.5 104 | # print(x, y) 105 | cost, lr = model.train(sess, x, xl, y, yl, 106 | return_lr=True, add_loss=add_loss) 107 | costs.append(cost) 108 | bar.set_description('epoch {} loss={:.6f} lr={:.6f}'.format( 109 | epoch, 110 | np.mean(costs), 111 | lr 112 | )) 113 | 114 | model.save(sess, save_path) 115 | 116 | flow.close() 117 | 118 | # 测试部分 119 | tf.reset_default_graph() 120 | model_pred = SequenceToSequence( 121 | input_vocab_size=len(ws), 122 | target_vocab_size=len(ws), 123 | batch_size=1, 124 | mode='decode', 125 | beam_width=12, 126 | bidirectional=bidirectional, 127 | cell_type=cell_type, 128 | depth=depth, 129 | attention_type=attention_type, 130 | use_residual=use_residual, 131 | use_dropout=use_dropout, 132 | hidden_units=hidden_units, 133 | time_major=time_major, 134 | parallel_iterations=1, 135 | learning_rate=0.001, 136 | optimizer='adam', 137 | share_embedding=True, 138 | pretrained_embedding=True 139 | ) 140 | init = tf.global_variables_initializer() 141 | 142 | with tf.Session(config=config) as sess: 143 | sess.run(init) 144 | model_pred.load(sess, save_path) 145 | 146 | bar = batch_flow([x_data, y_data], ws, 1) 147 | t = 0 148 | for x, xl, y, yl in bar: 149 | x = np.flip(x, axis=1) 150 | pred = model_pred.predict( 151 | sess, 152 | np.array(x), 153 | np.array(xl) 154 | ) 155 | print(ws.inverse_transform(x[0])) 156 | print(ws.inverse_transform(y[0])) 157 | print(ws.inverse_transform(pred[0])) 158 | t += 1 159 | if t >= 3: 160 | break 161 | 162 | tf.reset_default_graph() 163 | model_pred = SequenceToSequence( 164 | input_vocab_size=len(ws), 165 | target_vocab_size=len(ws), 166 | batch_size=1, 167 | mode='decode', 168 | beam_width=1, 169 | bidirectional=bidirectional, 170 | cell_type=cell_type, 171 | depth=depth, 172 | attention_type=attention_type, 173 | use_residual=use_residual, 174 | use_dropout=use_dropout, 175 | hidden_units=hidden_units, 176 | time_major=time_major, 177 | parallel_iterations=1, 178 | learning_rate=0.001, 179 | optimizer='adam', 180 | share_embedding=True, 181 | pretrained_embedding=True 182 | ) 183 | init = tf.global_variables_initializer() 184 | 185 | with tf.Session(config=config) as sess: 186 | sess.run(init) 187 | model_pred.load(sess, 
save_path) 188 | 189 | bar = batch_flow([x_data, y_data], ws, 1) 190 | t = 0 191 | for x, xl, y, yl in bar: 192 | pred = model_pred.predict( 193 | sess, 194 | np.array(x), 195 | np.array(xl) 196 | ) 197 | print(ws.inverse_transform(x[0])) 198 | print(ws.inverse_transform(y[0])) 199 | print(ws.inverse_transform(pred[0])) 200 | t += 1 201 | if t >= 3: 202 | break 203 | 204 | 205 | def main(): 206 | """入口程序,开始测试不同参数组合""" 207 | random.seed(0) 208 | np.random.seed(0) 209 | tf.set_random_seed(0) 210 | test( 211 | bidirectional=True, 212 | cell_type='lstm', 213 | depth=2, 214 | attention_type='Bahdanau', 215 | use_residual=False, 216 | use_dropout=False, 217 | time_major=False, 218 | hidden_units=512 219 | ) 220 | 221 | 222 | if __name__ == '__main__': 223 | main() 224 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | 一些数据操作所需的模块 3 | """ 4 | 5 | import random 6 | import numpy as np 7 | from tensorflow.python.client import device_lib 8 | from word_sequence import WordSequence 9 | 10 | VOCAB_SIZE_THRESHOLD_CPU = 50000 11 | 12 | 13 | def _get_available_gpus(): 14 | """获取当前可用GPU数量""" 15 | local_device_protos = device_lib.list_local_devices() 16 | return [x.name for x in local_device_protos if x.device_type == 'GPU'] 17 | 18 | 19 | def _get_embed_device(vocab_size): 20 | """Decide on which device to place an embed matrix given its vocab size. 21 | 根据输入输出的字典大小,选择在CPU还是GPU上初始化embedding向量 22 | """ 23 | gpus = _get_available_gpus() 24 | if not gpus or vocab_size > VOCAB_SIZE_THRESHOLD_CPU: 25 | return "/cpu:0" 26 | return "/gpu:0" 27 | 28 | 29 | def transform_sentence(sentence, ws, max_len=None, add_end=False): 30 | """转换一个单独句子 31 | Args: 32 | sentence: 一句话,例如一个数组['你', '好', '吗'] 33 | ws: 一个WordSequence对象,转换器 34 | max_len: 35 | 进行padding的长度,也就是如果sentence长度小于max_len 36 | 则padding到max_len这么长 37 | Ret: 38 | encoded: 39 | 一个经过ws转换的数组,例如[4, 5, 6, 3] 40 | encoded_len: 上面的长度 41 | """ 42 | encoded = ws.transform( 43 | sentence, 44 | max_len=max_len if max_len is not None else len(sentence)) 45 | encoded_len = len(sentence) + (1 if add_end else 0) # add end 46 | if encoded_len > len(encoded): 47 | encoded_len = len(encoded) 48 | return encoded, encoded_len 49 | 50 | 51 | def batch_flow(data, ws, batch_size, raw=False, add_end=True): 52 | """从数据中随机 batch_size 个的数据,然后 yield 出去 53 | Args: 54 | data: 55 | 是一个数组,必须包含一个护着更多个同等的数据队列数组 56 | ws: 57 | 可以是一个WordSequence对象,也可以是多个组成的数组 58 | 如果是多个,那么数组数量应该与data的数据数量保持一致,即len(data) == len(ws) 59 | batch_size: 60 | 批量的大小 61 | raw: 62 | 是否返回原始对象,如果为True,假设结果ret,那么len(ret) == len(data) * 3 63 | 如果为False,那么len(ret) == len(data) * 2 64 | 65 | 例如需要输入问题与答案的队列,问题队列Q = (q_1, q_2, q_3 ... q_n) 66 | 答案队列A = (a_1, a_2, a_3 ... 
a_n),有len(Q) == len(A) 67 | ws是一个Q与A共用的WordSequence对象, 68 | 那么可以有: batch_flow([Q, A], ws, batch_size=32) 69 | 这样会返回一个generator,每次next(generator)会返回一个包含4个对象的数组,分别代表: 70 | next(generator) == q_i_encoded, q_i_len, a_i_encoded, a_i_len 71 | 如果设置raw = True,则: 72 | next(generator) == q_i_encoded, q_i_len, q_i, a_i_encoded, a_i_len, a_i 73 | 74 | 其中 q_i_encoded 相当于 ws.transform(q_i) 75 | 76 | 不过经过了batch修正,把一个batch中每个结果的长度,padding到了数组内最大的句子长度 77 | """ 78 | 79 | all_data = list(zip(*data)) 80 | 81 | if isinstance(ws, (list, tuple)): 82 | assert len(ws) == len(data), \ 83 | 'len(ws) must equal to len(data) if ws is list or tuple' 84 | 85 | if isinstance(add_end, bool): 86 | add_end = [add_end] * len(data) 87 | else: 88 | assert(isinstance(add_end, (list, tuple))), \ 89 | 'add_end 不是 boolean,就应该是一个list(tuple) of boolean' 90 | assert len(add_end) == len(data), \ 91 | '如果 add_end 是list(tuple),那么 add_end 的长度应该和输入数据长度一致' 92 | 93 | mul = 2 94 | if raw: 95 | mul = 3 96 | 97 | while True: 98 | data_batch = random.sample(all_data, batch_size) 99 | batches = [[] for i in range(len(data) * mul)] 100 | 101 | max_lens = [] 102 | for j in range(len(data)): 103 | max_len = max([ 104 | len(x[j]) if hasattr(x[j], '__len__') else 0 105 | for x in data_batch 106 | ]) + (1 if add_end[j] else 0) 107 | max_lens.append(max_len) 108 | 109 | for d in data_batch: 110 | for j in range(len(data)): 111 | if isinstance(ws, (list, tuple)): 112 | w = ws[j] 113 | else: 114 | w = ws 115 | 116 | # 添加结尾 117 | line = d[j] 118 | if add_end[j] and isinstance(line, (tuple, list)): 119 | line = list(line) + [WordSequence.END_TAG] 120 | 121 | if w is not None: 122 | x, xl = transform_sentence(line, w, max_lens[j], add_end[j]) 123 | batches[j * mul].append(x) 124 | batches[j * mul + 1].append(xl) 125 | else: 126 | batches[j * mul].append(line) 127 | batches[j * mul + 1].append(line) 128 | if raw: 129 | batches[j * mul + 2].append(line) 130 | batches = [np.asarray(x) for x in batches] 131 | 132 | yield batches 133 | 134 | 135 | 136 | def batch_flow_bucket(data, ws, batch_size, raw=False, 137 | add_end=True, 138 | n_buckets=5, bucket_ind=1, 139 | debug=False): 140 | """batch_flow的bucket版本 141 | 多了两重要参数,一个是n_buckets,一个是bucket_ind 142 | n_buckets是分成几个buckets,理论上n_buckets == 1时就相当于没有进行buckets操作 143 | bucket_ind是指定哪一维度的输入数据作为bucket的依据 144 | """ 145 | 146 | all_data = list(zip(*data)) 147 | lengths = sorted(list(set([len(x[bucket_ind]) for x in all_data]))) 148 | if n_buckets > len(lengths): 149 | n_buckets = len(lengths) 150 | 151 | splits = np.array(lengths)[ 152 | (np.linspace(0, 1, 5, endpoint=False) * len(lengths)).astype(int) 153 | ].tolist() 154 | splits += [np.inf] 155 | 156 | if debug: 157 | print(splits) 158 | 159 | ind_data = {} 160 | for x in all_data: 161 | l = len(x[bucket_ind]) 162 | for ind, s in enumerate(splits[:-1]): 163 | if l >= s and l <= splits[ind + 1]: 164 | if ind not in ind_data: 165 | ind_data[ind] = [] 166 | ind_data[ind].append(x) 167 | break 168 | 169 | 170 | inds = sorted(list(ind_data.keys())) 171 | ind_p = [len(ind_data[x]) / len(all_data) for x in inds] 172 | if debug: 173 | print(np.sum(ind_p), ind_p) 174 | 175 | if isinstance(ws, (list, tuple)): 176 | assert len(ws) == len(data), \ 177 | 'len(ws) must equal to len(data) if ws is list or tuple' 178 | 179 | 180 | 181 | if isinstance(add_end, bool): 182 | add_end = [add_end] * len(data) 183 | else: 184 | assert(isinstance(add_end, (list, tuple))), \ 185 | 'add_end 不是 boolean,就应该是一个list(tuple) of boolean' 186 | assert len(add_end) == len(data), \ 187 | '如果 add_end 
是list(tuple),那么 add_end 的长度应该和输入数据长度一致' 188 | 189 | mul = 2 190 | if raw: 191 | mul = 3 192 | 193 | while True: 194 | choice_ind = np.random.choice(inds, p=ind_p) 195 | if debug: 196 | print('choice_ind', choice_ind) 197 | data_batch = random.sample(ind_data[choice_ind], batch_size) 198 | batches = [[] for i in range(len(data) * mul)] 199 | 200 | max_lens = [] 201 | for j in range(len(data)): 202 | max_len = max([ 203 | len(x[j]) if hasattr(x[j], '__len__') else 0 204 | for x in data_batch 205 | ]) + (1 if add_end[j] else 0) 206 | max_lens.append(max_len) 207 | 208 | for d in data_batch: 209 | for j in range(len(data)): 210 | if isinstance(ws, (list, tuple)): 211 | w = ws[j] 212 | else: 213 | w = ws 214 | 215 | # 添加结尾 216 | line = d[j] 217 | if add_end[j] and isinstance(line, (tuple, list)): 218 | line = list(line) + [WordSequence.END_TAG] 219 | 220 | if w is not None: 221 | x, xl = transform_sentence(line, w, max_lens[j], add_end[j]) 222 | batches[j * mul].append(x) 223 | batches[j * mul + 1].append(xl) 224 | else: 225 | batches[j * mul].append(line) 226 | batches[j * mul + 1].append(line) 227 | if raw: 228 | batches[j * mul + 2].append(line) 229 | batches = [np.asarray(x) for x in batches] 230 | 231 | yield batches 232 | 233 | 234 | 235 | def test_batch_flow(): 236 | """test batch_flow function""" 237 | from fake_data import generate 238 | x_data, y_data, ws_input, ws_target = generate(size=10000) 239 | flow = batch_flow([x_data, y_data], [ws_input, ws_target], 4) 240 | x, xl, y, yl = next(flow) 241 | print(x.shape, y.shape, xl.shape, yl.shape) 242 | 243 | 244 | def test_batch_flow_bucket(): 245 | """test batch_flow function""" 246 | from fake_data import generate 247 | x_data, y_data, ws_input, ws_target = generate(size=10000) 248 | flow = batch_flow_bucket( 249 | [x_data, y_data], [ws_input, ws_target], 4, 250 | debug=True) 251 | for _ in range(10): 252 | x, xl, y, yl = next(flow) 253 | print(x.shape, y.shape, xl.shape, yl.shape) 254 | 255 | 256 | if __name__ == '__main__': 257 | test_batch_flow_bucket() 258 | -------------------------------------------------------------------------------- /pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # Specify a configuration file. 4 | #rcfile= 5 | 6 | # Python code to execute, usually for sys.path manipulation such as 7 | # pygtk.require(). 8 | #init-hook= 9 | 10 | # Profiled execution. 11 | profile=no 12 | 13 | # Add files or directories to the blacklist. They should be base names, not 14 | # paths. 15 | ignore=CVS 16 | 17 | # Pickle collected data for later comparisons. 18 | persistent=yes 19 | 20 | # List of plugins (as comma separated values of python modules names) to load, 21 | # usually to register additional checkers. 22 | load-plugins= 23 | 24 | 25 | [MESSAGES CONTROL] 26 | 27 | # Enable the message, report, category or checker with the given id(s). You can 28 | # either give multiple identifier separated by comma (,) or put this option 29 | # multiple time. See also the "--disable" option for examples. 30 | enable=indexing-exception,old-raise-syntax 31 | 32 | # Disable the message, report, category or checker with the given id(s). You 33 | # can either give multiple identifiers separated by comma (,) or put this 34 | # option multiple times (only on the command line, not in the configuration 35 | # file where it should appear only once).You can also use "--disable=all" to 36 | # disable everything first and then reenable specific checks. 
For example, if 37 | # you want to run only the similarities checker, you can use "--disable=all 38 | # --enable=similarities". If you want to run only the classes checker, but have 39 | # no Warning level messages displayed, use"--disable=all --enable=classes 40 | # --disable=W" 41 | disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager 42 | 43 | 44 | # Set the cache size for astng objects. 45 | cache-size=500 46 | 47 | 48 | [REPORTS] 49 | 50 | # Set the output format. Available formats are text, parseable, colorized, msvs 51 | # (visual studio) and html. You can also give a reporter class, eg 52 | # mypackage.mymodule.MyReporterClass. 53 | output-format=text 54 | 55 | # Put messages in a separate file for each module / package specified on the 56 | # command line instead of printing them on stdout. Reports (if any) will be 57 | # written in a file name "pylint_global.[txt|html]". 58 | files-output=no 59 | 60 | # Tells whether to display a full report or only the messages 61 | reports=no 62 | 63 | # Python expression which should return a note less than 10 (10 is the highest 64 | # note). You have access to the variables errors warning, statement which 65 | # respectively contain the number of errors / warnings messages and the total 66 | # number of statements analyzed. This is used by the global evaluation report 67 | # (RP0004). 68 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 69 | 70 | # Add a comment according to your evaluation note. This is used by the global 71 | # evaluation report (RP0004). 72 | comment=no 73 | 74 | # Template used to display messages. This is a python new-style format string 75 | # used to format the message information. See doc for all details 76 | #msg-template= 77 | 78 | 79 | [TYPECHECK] 80 | 81 | # Tells whether missing members accessed in mixin class should be ignored. A 82 | # mixin class is detected if its name ends with "mixin" (case insensitive). 83 | ignore-mixin-members=yes 84 | 85 | # List of classes names for which member attributes should not be checked 86 | # (useful for classes with attributes dynamically set). 87 | ignored-classes=SQLObject 88 | 89 | # When zope mode is activated, add a predefined set of Zope acquired attributes 90 | # to generated-members. 91 | zope=no 92 | 93 | # List of members which are set dynamically and missed by pylint inference 94 | # system, and so shouldn't trigger E0201 when accessed. Python regular 95 | # expressions are accepted. 96 | generated-members=REQUEST,acl_users,aq_parent 97 | 98 | # List of decorators that create context managers from functions, such as 99 | # contextlib.contextmanager. 100 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager 101 | 102 | 103 | [VARIABLES] 104 | 105 | # Tells whether we should check for unused import in __init__ files. 106 | init-import=no 107 | 108 | # A regular expression matching the beginning of the name of dummy variables 109 | # (i.e. not used). 110 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 111 | 112 | # List of additional names supposed to be defined in builtins. Remember that 113 | # you should avoid to define new builtins when possible. 
114 | additional-builtins= 115 | 116 | 117 | [BASIC] 118 | 119 | # Required attributes for module, separated by a comma 120 | required-attributes= 121 | 122 | # List of builtins function names that should not be used, separated by a comma 123 | bad-functions=apply,input,reduce 124 | 125 | 126 | # Disable the report(s) with the given id(s). 127 | # All non-Google reports are disabled by default. 128 | disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923 129 | 130 | # Regular expression which should only match correct module names 131 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 132 | 133 | # Regular expression which should only match correct module level names 134 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 135 | 136 | # Regular expression which should only match correct class names 137 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 138 | 139 | # Regular expression which should only match correct function names 140 | function-rgx=^(?:(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ 141 | 142 | # Regular expression which should only match correct method names 143 | method-rgx=^(?:(?P__[a-z0-9_]+__|next)|(?P_{0,2}[A-Z][a-zA-Z0-9]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ 144 | 145 | # Regular expression which should only match correct instance attribute names 146 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 147 | 148 | # Regular expression which should only match correct argument names 149 | argument-rgx=^[a-z][a-z0-9_]*$ 150 | 151 | # Regular expression which should only match correct variable names 152 | variable-rgx=^[a-z][a-z0-9_]*$ 153 | 154 | # Regular expression which should only match correct attribute names in class 155 | # bodies 156 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 157 | 158 | # Regular expression which should only match correct list comprehension / 159 | # generator expression variable names 160 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 161 | 162 | # Good variable names which should always be accepted, separated by a comma 163 | good-names=main,_ 164 | 165 | # Bad variable names which should always be refused, separated by a comma 166 | bad-names= 167 | 168 | # Regular expression which should only match function or class names that do 169 | # not require a docstring. 170 | no-docstring-rgx=(__.*__|main) 171 | 172 | # Minimum line length for functions/classes that require docstrings, shorter 173 | # ones are exempt. 174 | docstring-min-length=10 175 | 176 | 177 | [FORMAT] 178 | 179 | # Maximum number of characters on a single line. 180 | max-line-length=80 181 | 182 | # Regexp for a line that is allowed to be longer than the limit. 183 | ignore-long-lines=(?x) 184 | (^\s*(import|from)\s 185 | |\$Id:\s\/\/depot\/.+#\d+\s\$ 186 | |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+') 187 | |^\s*\#\ LINT\.ThenChange 188 | |^[^#]*\#\ type:\ [a-zA-Z_][a-zA-Z0-9_.,[\] ]*$ 189 | |pylint 190 | |""" 191 | |\# 192 | |lambda 193 | |(https?|ftp):) 194 | 195 | # Allow the body of an if to be on the same line as the test if there is no 196 | # else. 197 | single-line-if-stmt=y 198 | 199 | # List of optional constructs for which whitespace checking is disabled 200 | no-space-check= 201 | 202 | # Maximum number of lines in a module 203 | max-module-lines=99999 204 | 205 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 206 | # tab). 
207 | indent-string=' ' 208 | 209 | 210 | [SIMILARITIES] 211 | 212 | # Minimum lines number of a similarity. 213 | min-similarity-lines=4 214 | 215 | # Ignore comments when computing similarities. 216 | ignore-comments=yes 217 | 218 | # Ignore docstrings when computing similarities. 219 | ignore-docstrings=yes 220 | 221 | # Ignore imports when computing similarities. 222 | ignore-imports=no 223 | 224 | 225 | [MISCELLANEOUS] 226 | 227 | # List of note tags to take in consideration, separated by a comma. 228 | notes= 229 | 230 | 231 | [IMPORTS] 232 | 233 | # Deprecated modules which should not be used, separated by a comma 234 | deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets 235 | 236 | # Create a graph of every (i.e. internal and external) dependencies in the 237 | # given file (report RP0402 must not be disabled) 238 | import-graph= 239 | 240 | # Create a graph of external dependencies in the given file (report RP0402 must 241 | # not be disabled) 242 | ext-import-graph= 243 | 244 | # Create a graph of internal dependencies in the given file (report RP0402 must 245 | # not be disabled) 246 | int-import-graph= 247 | 248 | 249 | [CLASSES] 250 | 251 | # List of interface methods to ignore, separated by a comma. This is used for 252 | # instance to not check methods defines in Zope's Interface base class. 253 | ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by 254 | 255 | # List of method names used to declare (i.e. assign) instance attributes. 256 | defining-attr-methods=__init__,__new__,setUp 257 | 258 | # List of valid names for the first argument in a class method. 259 | valid-classmethod-first-arg=cls,class_ 260 | 261 | # List of valid names for the first argument in a metaclass class method. 262 | valid-metaclass-classmethod-first-arg=mcs 263 | 264 | 265 | [DESIGN] 266 | 267 | # Maximum number of arguments for function / method 268 | max-args=5 269 | 270 | # Argument names that match this expression will be ignored. Default to name 271 | # with leading underscore 272 | ignored-argument-names=_.* 273 | 274 | # Maximum number of locals for function / method body 275 | max-locals=15 276 | 277 | # Maximum number of return / yield for function / method body 278 | max-returns=6 279 | 280 | # Maximum number of branch for function / method body 281 | max-branches=12 282 | 283 | # Maximum number of statements in function / method body 284 | max-statements=50 285 | 286 | # Maximum number of parents for a class (see R0901). 287 | max-parents=7 288 | 289 | # Maximum number of attributes for a class (see R0902). 290 | max-attributes=7 291 | 292 | # Minimum number of public methods for a class (see R0903). 293 | min-public-methods=2 294 | 295 | # Maximum number of public methods for a class (see R0904). 296 | max-public-methods=20 297 | 298 | 299 | [EXCEPTIONS] 300 | 301 | # Exceptions that will emit a warning when being caught. Defaults to 302 | # "Exception" 303 | overgeneral-exceptions=Exception,StandardError,BaseException 304 | 305 | 306 | [AST] 307 | 308 | # Maximum line length for lambdas 309 | short-func-length=1 310 | 311 | # List of module members that should be marked as deprecated. 312 | # All of the string functions are listed in 4.1.4 Deprecated string functions 313 | # in the Python 2.4 docs. 
314 | deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc 315 | 316 | 317 | [DOCSTRING] 318 | 319 | # List of exceptions that do not need to be mentioned in the Raises section of 320 | # a docstring. 321 | ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError 322 | 323 | 324 | 325 | [TOKENS] 326 | 327 | # Number of spaces of indent required when the last token on the preceding line 328 | # is an open (, [, or {. 329 | indent-after-paren=4 330 | 331 | 332 | [GOOGLE LINES] 333 | 334 | # Regexp for a proper copyright notice. 335 | copyright=Copyright \d{4} The TensorFlow Authors\. +All [Rr]ights [Rr]eserved\. 336 | -------------------------------------------------------------------------------- /rnn_crf.py: -------------------------------------------------------------------------------- 1 | """ 2 | QHDuan 3 | 2018-02-05 4 | 5 | RNN-CRF Model 6 | 7 | crf: 8 | https://www.tensorflow.org/versions/r1.4/api_docs/python/tf/contrib/crf 9 | """ 10 | 11 | 12 | import math 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | # from tensorflow import layers 17 | from tensorflow.contrib.rnn import LSTMCell 18 | from tensorflow.contrib.rnn import GRUCell 19 | from tensorflow.contrib.rnn import MultiRNNCell 20 | from tensorflow.contrib.rnn import DropoutWrapper 21 | from tensorflow.contrib.rnn import ResidualWrapper 22 | from tensorflow.contrib.rnn import LSTMStateTuple 23 | 24 | from word_sequence import WordSequence 25 | from data_utils import _get_embed_device 26 | 27 | 28 | class RNNCRF(object): 29 | """SequenceToSequence Model 30 | 31 | 基本流程 32 | __init__ 基本参数保存,验证参数合法性 33 | build_model 开始构建整个模型 34 | init_placeholders 初始化一些tensorflow的变量占位符 35 | build_encoder 初始化编码器 36 | build_single_cell 37 | build_encoder_cell 38 | build_decoder_crf 初始化解码器 39 | init_optimizer 如果是在训练模式则初始化优化器 40 | train 训练一个batch的数据 41 | predict 预测一个batch的数据 42 | """ 43 | 44 | def __init__(self, 45 | input_vocab_size, 46 | target_vocab_size, 47 | max_decode_step, 48 | batch_size=32, 49 | embedding_size=128, 50 | mode='train', 51 | hidden_units=256, 52 | depth=1, 53 | cell_type='lstm', 54 | dropout=0.2, 55 | use_dropout=False, 56 | use_residual=False, 57 | optimizer='adam', 58 | learning_rate=0.001, 59 | min_learning_rate=1e-6, 60 | decay_steps=500000, 61 | max_gradient_norm=5.0, 62 | bidirectional=False, 63 | output_project_active=None, 64 | time_major=False, 65 | seed=0, 66 | parallel_iterations=32, 67 | crf_loss=True): 68 | """保存参数变量,开始构建整个模型 69 | Args: 70 | input_vocab_size: 输入词表大小 71 | target_vocab_size: 输出词表大小 72 | max_decode_step: 73 | 最大的解码长度,可以是很大的整数,默认是None 74 | None的情况下默认是encoder输入最大长度的 4 倍 75 | batch_size: 数据batch的大小 76 | embedding_size, 输入词表与输出词表embedding的维度 77 | mode: 取值为 train 或者 decode,训练模式或者预测模式 78 | hidden_units: 79 | RNN模型的中间层大小,encoder和decoder层相同 80 | 如果encoder层是bidirectional的话,decoder层是双倍大小 81 | depth: encoder和decoder的rnn层数 82 | cell_type: rnn神经元类型,lstm 或者 gru 83 | dropout: dropout比例,取值 [0, 1) 84 | use_dropout: 是否使用dropout 85 | use_residual:# 是否使用residual 86 | optimizer: 优化方法, adam, adadelta, sgd, rmsprop, momentum 87 | learning_rate: 学习率 88 | max_gradient_norm: 梯度正则剪裁的系数 89 | bidirectional: encoder 是否为双向 90 | 
output_project_active: 91 | 是否在crf之前使用一个投影层,并指定一个激活函数 92 | None, 'tanh', 'sigmoid', 'linear' 93 | time_major: 94 | 是否在“计算过程”中使用时间为主的批量数据 95 | 注意,改变这个参数并不要求改变输入数据的格式 96 | 输入数据的格式为 [batch_size, time_step] 是一个二维矩阵 97 | time_step是句子长度 98 | 经过 embedding 之后,数据会变为 99 | [batch_size, time_step, embedding_size] 100 | 这是一个三维矩阵(或者三维张量Tensor) 101 | 这样的数据格式是 time_major=False 的 102 | 如果设置 time_major=True 的话,在部分计算的时候,会把矩阵转置为 103 | [time_step, batch_size, embedding_size] 104 | 也就是 time_step 是第一维,所以叫 time_major 105 | TensorFlow官方文档认为time_major=True会比较快 106 | seed: 一些层间操作的随机数 seed 设置 107 | parallel_iterations: 108 | dynamic_rnn 和 dynamic_decode 的并行数量 109 | 如果要取得可重复结果,在有dropout的情况下,应该设置为 1,否则结果会不确定 110 | """ 111 | 112 | self.input_vocab_size = input_vocab_size 113 | self.target_vocab_size = target_vocab_size 114 | self.max_decode_step = max_decode_step 115 | self.batch_size = batch_size 116 | self.embedding_size = embedding_size 117 | self.hidden_units = hidden_units 118 | self.depth = depth 119 | self.cell_type = cell_type 120 | self.use_dropout = use_dropout 121 | self.use_residual = use_residual 122 | self.mode = mode 123 | self.optimizer = optimizer 124 | self.learning_rate = learning_rate 125 | self.min_learning_rate = min_learning_rate 126 | self.decay_steps = decay_steps 127 | self.max_gradient_norm = max_gradient_norm 128 | self.keep_prob = 1.0 - dropout 129 | self.bidirectional = bidirectional 130 | self.output_project_active = output_project_active 131 | self.seed = seed 132 | self.parallel_iterations = parallel_iterations 133 | self.time_major = time_major 134 | self.crf_loss = crf_loss 135 | 136 | assert output_project_active in (None, 'tanh', 'sigmoid', 'linear'), \ 137 | 'output_project_active 必须是 None, "tanh", "sigmoid", "linear"之一' 138 | 139 | assert mode in ('train', 'decode'), \ 140 | 'mode 必须是 "train" 或 "decode" 而不是 "{}"'.format(mode) 141 | 142 | assert dropout >= 0.0 and dropout < 1.0, '0 <= dropout < 1' 143 | 144 | self.keep_prob_placeholder = tf.placeholder( 145 | tf.float32, 146 | shape=[], 147 | name='keep_prob' 148 | ) 149 | 150 | self.global_step = tf.Variable( 151 | 0, trainable=False, name='global_step' 152 | ) 153 | self.global_epoch_step = tf.Variable( 154 | 0, trainable=False, name='global_epoch_step' 155 | ) 156 | self.global_epoch_step_op = tf.assign( 157 | self.global_epoch_step, 158 | self.global_epoch_step + 1 159 | ) 160 | 161 | assert self.optimizer.lower() in \ 162 | ('adadelta', 'adam', 'rmsprop', 'momentum', 'sgd'), \ 163 | 'optimizer 必须是下列之一: adadelta, adam, rmsprop, momentum, sgd' 164 | 165 | self.build_model() 166 | 167 | 168 | def build_model(self): 169 | """构建整个模型 170 | 分别构建 171 | 编码器(encoder) 172 | 解码器(decoder) 173 | 优化器(只在训练时构建,optimizer) 174 | """ 175 | self.init_placeholders() 176 | self.build_encoder() 177 | self.build_decoder_crf() 178 | 179 | if self.mode == 'train': 180 | self.init_optimizer() 181 | 182 | self.saver = tf.train.Saver() 183 | 184 | 185 | def init_placeholders(self): 186 | """初始化训练、预测所需的变量 187 | """ 188 | 189 | # 编码器输入,shape=(batch_size, time_step) 190 | # 有 batch_size 句话,每句话是最大长度为 time_step 的 index 表示 191 | self.encoder_inputs = tf.placeholder( 192 | dtype=tf.int32, 193 | shape=(self.batch_size, None), 194 | name='encoder_inputs' 195 | ) 196 | 197 | # crf 是固定长度的 198 | self.encoder_inputs_length = tf.fill( 199 | dims=[self.batch_size], 200 | value=self.max_decode_step, 201 | name='encoder_inputs_length' 202 | ) 203 | 204 | # 训练模式 205 | 206 | # 解码器输入,shape=(batch_size, time_step) 207 | self.decoder_inputs = tf.placeholder( 208 | 
dtype=tf.int32, 209 | shape=(self.batch_size, None), 210 | name='decoder_inputs' 211 | ) 212 | 213 | # 解码器长度输入,shape=(batch_size,) 214 | self.decoder_inputs_length = tf.placeholder( 215 | dtype=tf.int32, 216 | shape=(self.batch_size,), 217 | name='decoder_inputs_length' 218 | ) 219 | 220 | self.decoder_start_token = tf.ones( 221 | shape=(self.batch_size, 1), 222 | dtype=tf.int32 223 | ) * WordSequence.START 224 | 225 | self.decoder_end_token = tf.ones( 226 | shape=(self.batch_size, 1), 227 | dtype=tf.int32 228 | ) * WordSequence.END 229 | 230 | # 实际训练的解码器输入,实际上是 start_token + decoder_inputs 231 | self.decoder_inputs_train = tf.concat([ 232 | self.decoder_start_token, 233 | self.decoder_inputs 234 | ], axis=1) 235 | 236 | # 这个变量用来计算一个mask,用来对loss函数的反向传播进行修正 237 | # 这里需要 + 1,因为会自动给训练结果增加 end_token 238 | self.decoder_inputs_length_train = self.decoder_inputs_length + 1 239 | 240 | # 实际训练的解码器目标,实际上是 decoder_inputs + end_token 241 | self.decoder_targets_train = tf.concat([ 242 | self.decoder_inputs, 243 | self.decoder_end_token 244 | ], axis=1) 245 | 246 | 247 | def build_single_cell(self, n_hidden, use_residual): 248 | """构建一个单独的rnn cell 249 | Args: 250 | n_hidden: 隐藏层神经元数量 251 | use_residual: 是否使用residual wrapper 252 | """ 253 | cell_type = LSTMCell 254 | if self.cell_type.lower() == 'gru': 255 | cell_type = GRUCell 256 | cell = cell_type(n_hidden) 257 | 258 | if self.use_dropout: 259 | cell = DropoutWrapper( 260 | cell, 261 | dtype=tf.float32, 262 | output_keep_prob=self.keep_prob_placeholder, 263 | seed=self.seed 264 | ) 265 | if use_residual: 266 | cell = ResidualWrapper(cell) 267 | 268 | return cell 269 | 270 | def build_encoder_cell(self): 271 | """构建一个单独的编码器cell 272 | """ 273 | return MultiRNNCell([ 274 | self.build_single_cell( 275 | self.hidden_units, 276 | use_residual=self.use_residual 277 | ) 278 | for _ in range(self.depth) 279 | ]) 280 | 281 | 282 | def build_encoder(self): 283 | """构建编码器 284 | """ 285 | # print("构建编码器") 286 | with tf.variable_scope('encoder'): 287 | # 构建 encoder_cell 288 | self.encoder_cell = self.build_encoder_cell() 289 | 290 | # Initialize encoder_embeddings to have variance=1. 291 | sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1. 
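            # Derivation of the sqrt(3) bound used just above: a uniform
            # distribution U(-a, a) has variance (2a)^2 / 12 = a^2 / 3, so
            # choosing a = sqrt(3) gives variance 1 for the embedding entries.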
292 | initializer = tf.random_uniform_initializer( 293 | -sqrt3, sqrt3, dtype=tf.float32 294 | ) 295 | 296 | # 编码器的embedding 297 | with tf.device(_get_embed_device(self.input_vocab_size)): 298 | self.encoder_embeddings = tf.get_variable( 299 | name='embedding', 300 | shape=(self.input_vocab_size, self.embedding_size), 301 | initializer=initializer, 302 | dtype=tf.float32 303 | ) 304 | 305 | # embedded之后的输入 shape = (batch_size, time_step, embedding_size) 306 | self.encoder_inputs_embedded = tf.nn.embedding_lookup( 307 | params=self.encoder_embeddings, 308 | ids=self.encoder_inputs 309 | ) 310 | 311 | # Encode input sequences into context vectors: 312 | # encoder_outputs: [batch_size, max_time_step, cell_output_size] 313 | # encoder_state: [batch_size, cell_output_size] 314 | 315 | inputs = self.encoder_inputs_embedded 316 | if self.time_major: 317 | inputs = tf.transpose(inputs, (1, 0, 2)) 318 | 319 | if not self.bidirectional: 320 | ( 321 | self.encoder_outputs, 322 | self.encoder_last_state 323 | ) = tf.nn.dynamic_rnn( 324 | cell=self.encoder_cell, 325 | inputs=inputs, 326 | sequence_length=self.encoder_inputs_length, 327 | dtype=tf.float32, 328 | time_major=self.time_major, 329 | parallel_iterations=self.parallel_iterations, 330 | swap_memory=True 331 | ) 332 | else: 333 | self.encoder_cell_bw = self.build_encoder_cell() 334 | ( 335 | (encoder_fw_outputs, encoder_bw_outputs), 336 | (encoder_fw_state, encoder_bw_state) 337 | ) = tf.nn.bidirectional_dynamic_rnn( 338 | cell_fw=self.encoder_cell, 339 | cell_bw=self.encoder_cell_bw, 340 | inputs=inputs, 341 | sequence_length=self.encoder_inputs_length, 342 | dtype=tf.float32, 343 | time_major=self.time_major, 344 | parallel_iterations=self.parallel_iterations, 345 | swap_memory=True 346 | ) 347 | 348 | self.encoder_outputs = tf.concat( 349 | (encoder_fw_outputs, encoder_bw_outputs), 2) 350 | 351 | # 在 bidirectional 的情况下合并 state 352 | # QHD 353 | # borrow from 354 | # https://github.com/ematvey/tensorflow-seq2seq-tutorials/blob/master/model_new.py 355 | # 对上面链接中的代码有修改,因为原代码没有考虑多层cell的情况(MultiRNNCell) 356 | if isinstance(encoder_fw_state[0], LSTMStateTuple): 357 | # LSTM 的 cell 358 | self.encoder_last_state = tuple([ 359 | LSTMStateTuple( 360 | c=tf.concat(( 361 | encoder_fw_state[i].c, 362 | encoder_bw_state[i].c 363 | ), 1), 364 | h=tf.concat(( 365 | encoder_fw_state[i].h, 366 | encoder_bw_state[i].h 367 | ), 1) 368 | ) 369 | for i in range(len(encoder_fw_state)) 370 | ]) 371 | elif isinstance(encoder_fw_state[0], tf.Tensor): 372 | # GRU 的中间状态只有一个,所以类型是 tf.Tensor 373 | # 分别合并(concat)就可以了 374 | self.encoder_last_state = tuple([ 375 | tf.concat( 376 | (encoder_fw_state[i], encoder_bw_state[i]), 377 | 1, name='bidirectional_concat_{}'.format(i) 378 | ) 379 | for i in range(len(encoder_fw_state)) 380 | ]) 381 | 382 | 383 | def build_decoder_crf(self): 384 | """构建crf解码器 385 | """ 386 | 387 | with tf.variable_scope('decoder_crf'): 388 | encoder_outputs = self.encoder_outputs 389 | 390 | hidden_units = self.hidden_units 391 | if self.bidirectional: 392 | hidden_units *= 2 393 | 394 | encoder_outputs = tf.concat(encoder_outputs, 395 | axis=2) 396 | encoder_outputs = tf.reshape(encoder_outputs, 397 | [-1, hidden_units], name='crf_output') 398 | self.encoder_outputs = encoder_outputs 399 | 400 | # Initialize decoder embeddings to have variance=1. 401 | sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1. 
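            # Note: build_decoder_crf has no decoder embedding of its own; the
            # uniform initializer defined next is reused for the optional
            # projection weights (proj_w) and the CRF emission weights (crf_w)
            # created below.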
402 | initializer = tf.random_uniform_initializer( 403 | -sqrt3, sqrt3, dtype=tf.float32 404 | ) 405 | 406 | if self.output_project_active is not None: 407 | proj_w = tf.get_variable('proj_w', 408 | [hidden_units, hidden_units], 409 | initializer=initializer) 410 | proj_b = tf.get_variable('proj_b', [hidden_units], 411 | initializer=tf.zeros_initializer()) 412 | encoder_outputs = tf.nn.xw_plus_b( 413 | encoder_outputs, proj_w, proj_b, name='proj_output') 414 | 415 | if self.output_project_active == 'tanh': 416 | encoder_outputs = tf.tanh(encoder_outputs) 417 | elif self.output_project_active == 'sigmoid': 418 | encoder_outputs = tf.sigmoid(encoder_outputs) 419 | 420 | # 把encoder的结果进行一次线性变换所需要的变量 421 | crf_w = tf.get_variable('crf_w', 422 | [hidden_units, self.target_vocab_size], 423 | initializer=initializer) 424 | crf_b = tf.get_variable('crf_b', [self.target_vocab_size], 425 | initializer=tf.zeros_initializer()) 426 | 427 | outputs = tf.nn.xw_plus_b(encoder_outputs, 428 | crf_w, crf_b, name='crf_output') 429 | 430 | # crf 计算必须固定住max_decode_step 431 | self.logits = tf.reshape( 432 | outputs, 433 | shape=[self.batch_size, 434 | self.max_decode_step, 435 | self.target_vocab_size]) 436 | 437 | if self.crf_loss: 438 | ( 439 | log_likelihood, 440 | self.transition_params 441 | ) = tf.contrib.crf.crf_log_likelihood( 442 | self.logits, 443 | self.decoder_inputs, 444 | self.decoder_inputs_length) 445 | 446 | ( 447 | self.viterbi_sequence, 448 | self.viterbi_score 449 | ) = tf.contrib.crf.crf_decode( 450 | self.logits, 451 | self.transition_params, 452 | self.encoder_inputs_length) 453 | 454 | self.loss = tf.reduce_mean(-log_likelihood) 455 | else: 456 | self.outputs = tf.argmax(self.logits, 2) 457 | 458 | self.loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 459 | labels=self.decoder_inputs, logits=self.logits) 460 | 461 | masks = tf.sequence_mask( 462 | lengths=self.decoder_inputs_length, 463 | maxlen=self.max_decode_step, 464 | dtype=tf.float32, name='masks' 465 | ) 466 | 467 | self.loss *= masks 468 | self.loss = tf.reduce_sum(self.loss) 469 | 470 | 471 | def save(self, sess, save_path='model.ckpt'): 472 | """保存模型""" 473 | self.saver.save(sess, save_path=save_path) 474 | 475 | 476 | def load(self, sess, save_path='model.ckpt'): 477 | """读取模型""" 478 | print('try load model from', save_path) 479 | self.saver.restore(sess, save_path) 480 | 481 | 482 | def check_feeds(self, encoder_inputs, encoder_inputs_length, 483 | decoder_inputs, decoder_inputs_length, decode): 484 | """检查输入变量,并返回input_feed 485 | 486 | 我们首先会把数据编码,例如把“你好吗”,编码为[0, 1, 2] 487 | 多个句子组成一个batch,共同训练,例如一个batch_size=2,那么训练矩阵就可能是 488 | encoder_inputs = [ 489 | [0, 1, 2, 3], 490 | [4, 5, 6, 7] 491 | ] 492 | 它所代表的可能是:[['我', '是', '帅', '哥'], ['你', '好', '啊', '']] 493 | 注意第一句的真实长度是 4,第二句只有 3(最后的是一个填充数据) 494 | 495 | 那么: 496 | encoder_inputs_length = [4, 3] 497 | 来代表输入整个batch的真实长度 498 | 注意,为了符合算法要求,每个batch的句子必须是长度降序的,也就是说你输入一个 499 | encoder_inputs_length = [1, 10] 这样是错误的,必须在输入前排序到 500 | encoder_inputs_length = [10, 1] 这样才行 501 | 502 | decoder_inputs 和 decoder_inputs_length 所代表的含义差不多 503 | 504 | Args: 505 | encoder_inputs: 506 | 一个整形二维矩阵 [batch_size, max_source_time_steps] 507 | encoder_inputs_length: 508 | 一个整形向量 [batch_size] 509 | 每个维度是encoder句子的真实长度 510 | decoder_inputs: 511 | 一个整形矩阵 [batch_size, max_target_time_steps] 512 | decoder_inputs_length: 513 | 一个整形向量 [batch_size] 514 | 每个维度是decoder句子的真实长度 515 | decode: 用来指示正在训练模式(decode=False)还是预测模式(decode=True) 516 | Returns: 517 | tensorflow所操作需要的input_feed,包括 518 | encoder_inputs, 
encoder_inputs_length, 519 | decoder_inputs, decoder_inputs_length 520 | """ 521 | 522 | input_batch_size = encoder_inputs.shape[0] 523 | if input_batch_size != encoder_inputs_length.shape[0]: 524 | raise ValueError( 525 | "Encoder inputs and their lengths must be equal in their " 526 | "batch_size, %d != %d" % ( 527 | input_batch_size, encoder_inputs_length.shape[0])) 528 | 529 | if not decode: 530 | target_batch_size = decoder_inputs.shape[0] 531 | if target_batch_size != input_batch_size: 532 | raise ValueError( 533 | "Encoder inputs and Decoder inputs must be equal in their " 534 | "batch_size, %d != %d" % ( 535 | input_batch_size, target_batch_size)) 536 | if target_batch_size != decoder_inputs_length.shape[0]: 537 | raise ValueError( 538 | "Decoder targets and their lengths must be equal in their " 539 | "batch_size, %d != %d" % ( 540 | target_batch_size, decoder_inputs_length.shape[0])) 541 | 542 | input_feed = {} 543 | 544 | input_feed[self.encoder_inputs.name] = encoder_inputs 545 | input_feed[self.encoder_inputs_length.name] = encoder_inputs_length 546 | 547 | if not decode: 548 | input_feed[self.decoder_inputs.name] = decoder_inputs 549 | input_feed[self.decoder_inputs_length.name] = decoder_inputs_length 550 | 551 | return input_feed 552 | 553 | 554 | def init_optimizer(self): 555 | """初始化优化器 556 | 支持的方法有 sgd, adadelta, adam, rmsprop, momentum 557 | """ 558 | 559 | # 学习率下降算法 560 | learning_rate = tf.train.polynomial_decay( 561 | self.learning_rate, 562 | self.global_step, 563 | self.decay_steps, 564 | self.min_learning_rate, 565 | power=0.5 566 | ) 567 | self.current_learning_rate = learning_rate 568 | 569 | # 设置优化器,合法的优化器如下 570 | # 'adadelta', 'adam', 'rmsprop', 'momentum', 'sgd' 571 | trainable_params = tf.trainable_variables() 572 | if self.optimizer.lower() == 'adadelta': 573 | self.opt = tf.train.AdadeltaOptimizer( 574 | learning_rate=learning_rate) 575 | elif self.optimizer.lower() == 'adam': 576 | self.opt = tf.train.AdamOptimizer( 577 | learning_rate=learning_rate) 578 | elif self.optimizer.lower() == 'rmsprop': 579 | self.opt = tf.train.RMSPropOptimizer( 580 | learning_rate=learning_rate) 581 | elif self.optimizer.lower() == 'momentum': 582 | self.opt = tf.train.MomentumOptimizer( 583 | learning_rate=learning_rate, momentum=0.9) 584 | elif self.optimizer.lower() == 'sgd': 585 | self.opt = tf.train.GradientDescentOptimizer( 586 | learning_rate=learning_rate) 587 | 588 | # Compute gradients of loss w.r.t. 
all trainable variables 589 | gradients = tf.gradients(self.loss, trainable_params) 590 | # Clip gradients by a given maximum_gradient_norm 591 | clip_gradients, _ = tf.clip_by_global_norm( 592 | gradients, self.max_gradient_norm) 593 | # Update the model 594 | self.updates = self.opt.apply_gradients( 595 | zip(clip_gradients, trainable_params), 596 | global_step=self.global_step) 597 | 598 | 599 | def train(self, sess, encoder_inputs, encoder_inputs_length, 600 | decoder_inputs, decoder_inputs_length): 601 | """训练模型""" 602 | 603 | # 如果是crf模式,自动 padding 到 self.max_decode_step 604 | # self.max_decode_step 相当于 max_time_step 605 | encoder_inputs_crf = [] 606 | for item in encoder_inputs: 607 | encoder_inputs_crf.append(list(item) + \ 608 | [WordSequence.PAD] * (self.max_decode_step - len(item))) 609 | encoder_inputs = np.array(encoder_inputs_crf) 610 | 611 | decoder_inputs_crf = [] 612 | for item in decoder_inputs: 613 | decoder_inputs_crf.append(list(item) + \ 614 | [WordSequence.PAD] * (self.max_decode_step - len(item))) 615 | decoder_inputs = np.array(decoder_inputs_crf) 616 | 617 | # 输入 618 | input_feed = self.check_feeds( 619 | encoder_inputs, encoder_inputs_length, 620 | decoder_inputs, decoder_inputs_length, 621 | False 622 | ) 623 | 624 | # 设置 dropout 625 | input_feed[self.keep_prob_placeholder.name] = self.keep_prob 626 | 627 | # 输出 628 | output_feed = [self.updates, self.loss] 629 | _, cost = sess.run(output_feed, input_feed) 630 | 631 | return cost 632 | 633 | 634 | def predict(self, sess, 635 | encoder_inputs, encoder_inputs_length): 636 | """预测输出""" 637 | 638 | # 输入 639 | # 如果是 crf 模式,就把输入补全到最大长度 self.max_decode_step 640 | # 相当于 max_time_step 641 | encoder_inputs_crf = [] 642 | for item in encoder_inputs: 643 | encoder_inputs_crf.append(list(item) + \ 644 | [WordSequence.PAD] * (self.max_decode_step - len(item))) 645 | encoder_inputs = np.array(encoder_inputs_crf) 646 | 647 | input_feed = self.check_feeds( 648 | encoder_inputs, encoder_inputs_length, 649 | np.zeros(encoder_inputs.shape), 650 | np.zeros(encoder_inputs_length.shape), 651 | True 652 | ) 653 | 654 | input_feed[self.keep_prob_placeholder.name] = 1.0 655 | 656 | # crf mode 657 | if self.crf_loss: 658 | pred = sess.run(self.viterbi_sequence, input_feed) 659 | preds = [] 660 | for i in range(pred.shape[0]): 661 | item = pred[i][:encoder_inputs_length[i]] 662 | preds.append(item) 663 | 664 | return np.array(preds) 665 | # else: 666 | pred = sess.run(self.outputs, input_feed) 667 | preds = [] 668 | for i in range(pred.shape[0]): 669 | item = pred[i][:encoder_inputs_length[i]] 670 | preds.append(item) 671 | 672 | return np.array(preds) 673 | -------------------------------------------------------------------------------- /sequence_to_sequence.py: -------------------------------------------------------------------------------- 1 | """ 2 | QHDuan 3 | 2018-02-05 4 | 5 | sequence to sequence Model 6 | 7 | 官方文档的意思好像是time_major=True的情况下会快一点 8 | https://www.tensorflow.org/tutorials/seq2seq 9 | 不过现在代码都在time_major=False上 10 | 11 | test on tensorflow == 1.4.1 12 | seq2seq: 13 | https://www.tensorflow.org/versions/r1.4/api_docs/python/tf/contrib/seq2seq 14 | 15 | Code was borrow heavily from: 16 | https://github.com/JayParks/tf-seq2seq/blob/master/seq2seq_model.py 17 | Another wonderful example is: 18 | https://github.com/Marsan-Ma/tf_chatbot_seq2seq_antilm 19 | Official sequence2sequence project: 20 | https://github.com/tensorflow/nmt 21 | Another official sequence2sequence model: 22 | https://github.com/tensorflow/tensor2tensor 23 | 
""" 24 | 25 | import numpy as np 26 | import tensorflow as tf 27 | from tensorflow import layers 28 | # from tensorflow.python.util import nest 29 | from tensorflow.python.ops import array_ops 30 | from tensorflow.contrib import seq2seq 31 | from tensorflow.contrib.seq2seq import BahdanauAttention 32 | from tensorflow.contrib.seq2seq import LuongAttention 33 | from tensorflow.contrib.seq2seq import AttentionWrapper 34 | from tensorflow.contrib.seq2seq import BeamSearchDecoder 35 | from tensorflow.contrib.rnn import LSTMCell 36 | from tensorflow.contrib.rnn import GRUCell 37 | from tensorflow.contrib.rnn import MultiRNNCell 38 | from tensorflow.contrib.rnn import DropoutWrapper 39 | from tensorflow.contrib.rnn import ResidualWrapper 40 | # from tensorflow.contrib.rnn import LSTMStateTuple 41 | 42 | from word_sequence import WordSequence 43 | from data_utils import _get_embed_device 44 | 45 | 46 | class SequenceToSequence(object): 47 | """SequenceToSequence Model 48 | 49 | 基本流程 50 | __init__ 基本参数保存,验证参数合法性 51 | build_model 开始构建整个模型 52 | init_placeholders 初始化一些tensorflow的变量占位符 53 | build_encoder 初始化编码器 54 | build_single_cell 55 | build_encoder_cell 56 | build_decoder 初始化解码器 57 | build_single_cell 58 | build_decoder_cell 59 | init_optimizer 如果是在训练模式则初始化优化器 60 | train 训练一个batch的数据 61 | predict 预测一个batch的数据 62 | """ 63 | 64 | def __init__(self, 65 | input_vocab_size, 66 | target_vocab_size, 67 | batch_size=32, 68 | embedding_size=300, 69 | mode='train', 70 | hidden_units=256, 71 | depth=1, 72 | beam_width=0, 73 | cell_type='lstm', 74 | dropout=0.2, 75 | use_dropout=False, 76 | use_residual=False, 77 | optimizer='adam', 78 | learning_rate=1e-3, 79 | min_learning_rate=1e-6, 80 | decay_steps=500000, 81 | max_gradient_norm=5.0, 82 | max_decode_step=None, 83 | attention_type='Bahdanau', 84 | bidirectional=False, 85 | time_major=False, 86 | seed=0, 87 | parallel_iterations=None, 88 | share_embedding=False, 89 | pretrained_embedding=False): 90 | """保存参数变量,开始构建整个模型 91 | Args: 92 | input_vocab_size: 输入词表大小 93 | target_vocab_size: 输出词表大小 94 | batch_size: 数据batch的大小 95 | embedding_size, 输入词表与输出词表embedding的维度 96 | mode: 取值为 train 或者 decode,训练模式或者预测模式 97 | hidden_units: 98 | RNN模型的中间层大小,encoder和decoder层相同 99 | 如果encoder层是bidirectional的话,decoder层是双倍大小 100 | depth: encoder和decoder的rnn层数 101 | beam_width: 102 | beam_width是beamsearch的超参,用于解码 103 | 如果大于0则使用beamsearch,小于等于0则不使用 104 | cell_type: rnn神经元类型,lstm 或者 gru 105 | dropout: dropout比例,取值 [0, 1) 106 | use_dropout: 是否使用dropout 107 | use_residual:# 是否使用residual 108 | optimizer: 优化方法, adam, adadelta, sgd, rmsprop, momentum 109 | learning_rate: 学习率 110 | max_gradient_norm: 梯度正则剪裁的系数 111 | max_decode_step: 112 | 最大的解码长度,可以是很大的整数,默认是None 113 | None的情况下默认是encoder输入最大长度的 4 倍 114 | attention_type: 'Bahdanau' or 'Luong' 不同的 attention 类型 115 | bidirectional: encoder 是否为双向 116 | time_major: 117 | 是否在“计算过程”中使用时间为主的批量数据 118 | 注意,改变这个参数并不要求改变输入数据的格式 119 | 输入数据的格式为 [batch_size, time_step] 是一个二维矩阵 120 | time_step是句子长度 121 | 经过 embedding 之后,数据会变为 122 | [batch_size, time_step, embedding_size] 123 | 这是一个三维矩阵(或者三维张量Tensor) 124 | 这样的数据格式是 time_major=False 的 125 | 如果设置 time_major=True 的话,在部分计算的时候,会把矩阵转置为 126 | [time_step, batch_size, embedding_size] 127 | 也就是 time_step 是第一维,所以叫 time_major 128 | TensorFlow官方文档认为time_major=True会比较快 129 | seed: 一些层间操作的随机数 seed 设置 130 | parallel_iterations: 131 | dynamic_rnn 和 dynamic_decode 的并行数量 132 | 如果要取得可重复结果,在有dropout的情况下,应该设置为 133 | share_embedding: 134 | 如果为True,那么encoder和decoder就会公用一个embedding 135 | """ 136 | 137 | self.input_vocab_size = 
input_vocab_size 138 | self.target_vocab_size = target_vocab_size 139 | self.batch_size = batch_size 140 | self.embedding_size = embedding_size 141 | self.hidden_units = hidden_units 142 | self.depth = depth 143 | self.cell_type = cell_type.lower() 144 | self.use_dropout = use_dropout 145 | self.use_residual = use_residual 146 | self.attention_type = attention_type 147 | self.mode = mode 148 | self.optimizer = optimizer 149 | self.learning_rate = learning_rate 150 | self.min_learning_rate = min_learning_rate 151 | self.decay_steps = decay_steps 152 | self.max_gradient_norm = max_gradient_norm 153 | self.keep_prob = 1.0 - dropout 154 | self.bidirectional = bidirectional 155 | self.seed = seed 156 | self.pretrained_embedding = pretrained_embedding 157 | if isinstance(parallel_iterations, int): 158 | self.parallel_iterations = parallel_iterations 159 | else: # if parallel_iterations is None: 160 | self.parallel_iterations = batch_size 161 | self.time_major = time_major 162 | self.share_embedding = share_embedding 163 | # Initialize encoder_embeddings to have variance=1. 164 | # sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1. 165 | # self.initializer = tf.random_uniform_initializer( 166 | # -sqrt3, sqrt3, dtype=tf.float32 167 | # ) 168 | self.initializer = tf.random_uniform_initializer( 169 | -0.05, 0.05, dtype=tf.float32 170 | ) 171 | # self.initializer = None 172 | 173 | assert self.cell_type in ('gru', 'lstm'), \ 174 | 'cell_type 应该是 GRU 或者 LSTM' 175 | 176 | if share_embedding: 177 | assert input_vocab_size == target_vocab_size, \ 178 | '如果打开 share_embedding,两个vocab_size必须一样' 179 | 180 | assert mode in ('train', 'decode'), \ 181 | 'mode 必须是 "train" 或 "decode" 而不是 "{}"'.format(mode) 182 | 183 | assert dropout >= 0.0 and dropout < 1.0, '0 <= dropout < 1' 184 | 185 | assert attention_type.lower() in ('bahdanau', 'luong'), \ 186 | '''attention_type 必须是 "bahdanau" 或 "luong" 而不是 "{}" 187 | '''.format(attention_type) 188 | 189 | assert beam_width < target_vocab_size, \ 190 | 'beam_width {} 应该小于 target vocab size {}'.format( 191 | beam_width, target_vocab_size 192 | ) 193 | 194 | self.keep_prob_placeholder = tf.placeholder( 195 | tf.float32, 196 | shape=[], 197 | name='keep_prob' 198 | ) 199 | 200 | self.global_step = tf.Variable( 201 | 0, trainable=False, name='global_step' 202 | ) 203 | 204 | self.use_beamsearch_decode = False 205 | self.beam_width = beam_width 206 | self.use_beamsearch_decode = True if self.beam_width > 0 else False 207 | self.max_decode_step = max_decode_step 208 | 209 | assert self.optimizer.lower() in \ 210 | ('adadelta', 'adam', 'rmsprop', 'momentum', 'sgd'), \ 211 | 'optimizer 必须是下列之一: adadelta, adam, rmsprop, momentum, sgd' 212 | 213 | self.build_model() 214 | 215 | 216 | def build_model(self): 217 | """构建整个模型 218 | 分别构建 219 | 编码器(encoder) 220 | 解码器(decoder) 221 | 优化器(只在训练时构建,optimizer) 222 | """ 223 | self.init_placeholders() 224 | encoder_outputs, encoder_state = self.build_encoder() 225 | self.build_decoder(encoder_outputs, encoder_state) 226 | 227 | if self.mode == 'train': 228 | self.init_optimizer() 229 | 230 | self.saver = tf.train.Saver() 231 | 232 | 233 | def init_placeholders(self): 234 | """初始化训练、预测所需的变量 235 | """ 236 | 237 | self.add_loss = tf.placeholder( 238 | dtype=tf.float32, 239 | name='add_loss' 240 | ) 241 | 242 | # 编码器输入,shape=(batch_size, time_step) 243 | # 有 batch_size 句话,每句话是最大长度为 time_step 的 index 表示 244 | self.encoder_inputs = tf.placeholder( 245 | dtype=tf.int32, 246 | shape=(self.batch_size, None), 247 | name='encoder_inputs' 
248 | ) 249 | 250 | # 编码器长度输入,shape=(batch_size, 1) 251 | # 指的是 batch_size 句话每句话的长度 252 | self.encoder_inputs_length = tf.placeholder( 253 | dtype=tf.int32, 254 | shape=(self.batch_size,), 255 | name='encoder_inputs_length' 256 | ) 257 | 258 | if self.mode == 'train': 259 | # 训练模式 260 | 261 | # 解码器输入,shape=(batch_size, time_step) 262 | # 注意,会默认里面已经在每句结尾包含 263 | self.decoder_inputs = tf.placeholder( 264 | dtype=tf.int32, 265 | shape=(self.batch_size, None), 266 | name='decoder_inputs' 267 | ) 268 | 269 | # 解码器输入的reward,用于强化学习训练,shape=(batch_size, time_step) 270 | self.rewards = tf.placeholder( 271 | dtype=tf.float32, 272 | shape=(self.batch_size, 1), 273 | name='rewards' 274 | ) 275 | 276 | # 解码器长度输入,shape=(batch_size,) 277 | self.decoder_inputs_length = tf.placeholder( 278 | dtype=tf.int32, 279 | shape=(self.batch_size,), 280 | name='decoder_inputs_length' 281 | ) 282 | 283 | self.decoder_start_token = tf.ones( 284 | shape=(self.batch_size, 1), 285 | dtype=tf.int32 286 | ) * WordSequence.START 287 | 288 | # 实际训练的解码器输入,实际上是 start_token + decoder_inputs 289 | self.decoder_inputs_train = tf.concat([ 290 | self.decoder_start_token, 291 | self.decoder_inputs 292 | ], axis=1) 293 | 294 | 295 | def build_single_cell(self, n_hidden, use_residual): 296 | """构建一个单独的rnn cell 297 | Args: 298 | n_hidden: 隐藏层神经元数量 299 | use_residual: 是否使用residual wrapper 300 | """ 301 | 302 | if self.cell_type == 'gru': 303 | cell_type = GRUCell 304 | else: 305 | cell_type = LSTMCell 306 | 307 | cell = cell_type(n_hidden) 308 | 309 | if self.use_dropout: 310 | cell = DropoutWrapper( 311 | cell, 312 | dtype=tf.float32, 313 | output_keep_prob=self.keep_prob_placeholder, 314 | seed=self.seed 315 | ) 316 | 317 | if use_residual: 318 | cell = ResidualWrapper(cell) 319 | 320 | return cell 321 | 322 | def build_encoder_cell(self): 323 | """构建一个单独的编码器cell 324 | """ 325 | return MultiRNNCell([ 326 | self.build_single_cell( 327 | self.hidden_units, 328 | use_residual=self.use_residual 329 | ) 330 | for _ in range(self.depth) 331 | ]) 332 | 333 | 334 | def feed_embedding(self, sess, encoder=None, decoder=None): 335 | """加载预训练好的embedding 336 | """ 337 | assert self.pretrained_embedding, \ 338 | '必须开启pretrained_embedding才能使用feed_embedding' 339 | assert encoder is not None or decoder is not None, \ 340 | 'encoder 和 decoder 至少得输入一个吧大佬!' 
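        # Minimal usage sketch (hypothetical names, not part of this repo),
        # assuming the model was built with pretrained_embedding=True and the
        # session variables have already been initialized:
        #     embedding_matrix = ...  # np.ndarray, shape (input_vocab_size, embedding_size)
        #     model.feed_embedding(sess, encoder=embedding_matrix)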
341 | 342 | if encoder is not None: 343 | sess.run(self.encoder_embeddings_init, 344 | {self.encoder_embeddings_placeholder: encoder}) 345 | 346 | if decoder is not None: 347 | sess.run(self.decoder_embeddings_init, 348 | {self.decoder_embeddings_placeholder: decoder}) 349 | 350 | 351 | def build_encoder(self): 352 | """构建编码器 353 | """ 354 | # print("构建编码器") 355 | with tf.variable_scope('encoder'): 356 | # 构建 encoder_cell 357 | encoder_cell = self.build_encoder_cell() 358 | 359 | # 编码器的embedding 360 | with tf.device(_get_embed_device(self.input_vocab_size)): 361 | 362 | # 加载训练好的embedding 363 | if self.pretrained_embedding: 364 | 365 | self.encoder_embeddings = tf.Variable( 366 | tf.constant( 367 | 0.0, 368 | shape=(self.input_vocab_size, self.embedding_size) 369 | ), 370 | trainable=True, 371 | name='embeddings' 372 | ) 373 | self.encoder_embeddings_placeholder = tf.placeholder( 374 | tf.float32, 375 | (self.input_vocab_size, self.embedding_size) 376 | ) 377 | self.encoder_embeddings_init = \ 378 | self.encoder_embeddings.assign( 379 | self.encoder_embeddings_placeholder) 380 | 381 | else: 382 | self.encoder_embeddings = tf.get_variable( 383 | name='embedding', 384 | shape=(self.input_vocab_size, self.embedding_size), 385 | initializer=self.initializer, 386 | dtype=tf.float32 387 | ) 388 | 389 | # embedded之后的输入 shape = (batch_size, time_step, embedding_size) 390 | self.encoder_inputs_embedded = tf.nn.embedding_lookup( 391 | params=self.encoder_embeddings, 392 | ids=self.encoder_inputs 393 | ) 394 | 395 | if self.use_residual: 396 | self.encoder_inputs_embedded = \ 397 | layers.dense(self.encoder_inputs_embedded, 398 | self.hidden_units, 399 | use_bias=False, 400 | name='encoder_residual_projection') 401 | 402 | # Encode input sequences into context vectors: 403 | # encoder_outputs: [batch_size, max_time_step, cell_output_size] 404 | # encoder_state: [batch_size, cell_output_size] 405 | 406 | inputs = self.encoder_inputs_embedded 407 | if self.time_major: 408 | inputs = tf.transpose(inputs, (1, 0, 2)) 409 | 410 | if not self.bidirectional: 411 | # 单向 RNN 412 | ( 413 | encoder_outputs, 414 | encoder_state 415 | ) = tf.nn.dynamic_rnn( 416 | cell=encoder_cell, 417 | inputs=inputs, 418 | sequence_length=self.encoder_inputs_length, 419 | dtype=tf.float32, 420 | time_major=self.time_major, 421 | parallel_iterations=self.parallel_iterations, 422 | swap_memory=True 423 | ) 424 | else: 425 | # 双向 RNN 比较麻烦 426 | encoder_cell_bw = self.build_encoder_cell() 427 | ( 428 | (encoder_fw_outputs, encoder_bw_outputs), 429 | (encoder_fw_state, encoder_bw_state) 430 | ) = tf.nn.bidirectional_dynamic_rnn( 431 | cell_fw=encoder_cell, 432 | cell_bw=encoder_cell_bw, 433 | inputs=inputs, 434 | sequence_length=self.encoder_inputs_length, 435 | dtype=tf.float32, 436 | time_major=self.time_major, 437 | parallel_iterations=self.parallel_iterations, 438 | swap_memory=True 439 | ) 440 | 441 | # 首先合并两个方向 RNN 的输出 442 | encoder_outputs = tf.concat( 443 | (encoder_fw_outputs, encoder_bw_outputs), 2) 444 | 445 | encoder_state = [] 446 | for i in range(self.depth): 447 | encoder_state.append(encoder_fw_state[i]) 448 | encoder_state.append(encoder_bw_state[i]) 449 | encoder_state = tuple(encoder_state) 450 | 451 | return encoder_outputs, encoder_state 452 | 453 | 454 | def build_decoder_cell(self, encoder_outputs, encoder_state): 455 | """构建解码器cell""" 456 | 457 | encoder_inputs_length = self.encoder_inputs_length 458 | batch_size = self.batch_size 459 | 460 | if self.bidirectional: 461 | encoder_state = 
encoder_state[-self.depth:] 462 | 463 | if self.time_major: 464 | encoder_outputs = tf.transpose(encoder_outputs, (1, 0, 2)) 465 | 466 | # 使用 BeamSearchDecoder 的时候,必须根据 beam_width 来成倍的扩大一些变量 467 | # encoder_outputs, encoder_state, encoder_inputs_length 468 | # needs to be tiled so that: 469 | # [batch_size, .., ..] -> [batch_size x beam_width, .., ..] 470 | if self.use_beamsearch_decode: 471 | encoder_outputs = seq2seq.tile_batch( 472 | encoder_outputs, multiplier=self.beam_width) 473 | encoder_state = seq2seq.tile_batch( 474 | encoder_state, multiplier=self.beam_width) 475 | encoder_inputs_length = seq2seq.tile_batch( 476 | self.encoder_inputs_length, multiplier=self.beam_width) 477 | # 如果使用了 beamsearch 那么输入应该是 beam_width 倍于 batch_size 的 478 | batch_size *= self.beam_width 479 | 480 | # 下面是两种不同的 Attention 机制 481 | if self.attention_type.lower() == 'luong': 482 | # 'Luong' style attention: https://arxiv.org/abs/1508.04025 483 | self.attention_mechanism = LuongAttention( 484 | num_units=self.hidden_units, 485 | memory=encoder_outputs, 486 | memory_sequence_length=encoder_inputs_length 487 | ) 488 | else: # Default Bahdanau 489 | # 'Bahdanau' style attention: https://arxiv.org/abs/1409.0473 490 | self.attention_mechanism = BahdanauAttention( 491 | num_units=self.hidden_units, 492 | memory=encoder_outputs, 493 | memory_sequence_length=encoder_inputs_length 494 | ) 495 | 496 | # Building decoder_cell 497 | cell = MultiRNNCell([ 498 | self.build_single_cell( 499 | self.hidden_units, 500 | use_residual=self.use_residual 501 | ) 502 | for _ in range(self.depth) 503 | ]) 504 | 505 | # 在非训练(预测)模式,并且没开启 beamsearch 的时候,打开 attention 历史信息 506 | alignment_history = ( 507 | self.mode != 'train' and not self.use_beamsearch_decode 508 | ) 509 | 510 | def cell_input_fn(inputs, attention): 511 | """根据attn_input_feeding属性来判断是否在attention计算前进行一次投影计算 512 | """ 513 | if not self.use_residual: 514 | return array_ops.concat([inputs, attention], -1) 515 | 516 | attn_projection = layers.Dense(self.hidden_units, 517 | dtype=tf.float32, 518 | use_bias=False, 519 | name='attention_cell_input_fn') 520 | return attn_projection(array_ops.concat([inputs, attention], -1)) 521 | 522 | cell = AttentionWrapper( 523 | cell=cell, 524 | attention_mechanism=self.attention_mechanism, 525 | attention_layer_size=self.hidden_units, 526 | alignment_history=alignment_history, 527 | cell_input_fn=cell_input_fn, 528 | name='Attention_Wrapper') 529 | 530 | # 空状态 531 | decoder_initial_state = cell.zero_state( 532 | batch_size, tf.float32) 533 | 534 | # 传递encoder状态 535 | decoder_initial_state = decoder_initial_state.clone( 536 | cell_state=encoder_state) 537 | 538 | # if self.use_beamsearch_decode: 539 | # decoder_initial_state = seq2seq.tile_batch( 540 | # decoder_initial_state, multiplier=self.beam_width) 541 | 542 | return cell, decoder_initial_state 543 | 544 | 545 | def build_decoder(self, encoder_outputs, encoder_state): 546 | """构建解码器 547 | """ 548 | with tf.variable_scope('decoder') as decoder_scope: 549 | # Building decoder_cell and decoder_initial_state 550 | ( 551 | self.decoder_cell, 552 | self.decoder_initial_state 553 | ) = self.build_decoder_cell(encoder_outputs, encoder_state) 554 | 555 | # 解码器embedding 556 | with tf.device(_get_embed_device(self.target_vocab_size)): 557 | if self.share_embedding: 558 | self.decoder_embeddings = self.encoder_embeddings 559 | elif self.pretrained_embedding: 560 | 561 | self.decoder_embeddings = tf.Variable( 562 | tf.constant( 563 | 0.0, 564 | shape=(self.target_vocab_size, 565 | 
self.embedding_size) 566 | ), 567 | trainable=True, 568 | name='embeddings' 569 | ) 570 | self.decoder_embeddings_placeholder = tf.placeholder( 571 | tf.float32, 572 | (self.target_vocab_size, self.embedding_size) 573 | ) 574 | self.decoder_embeddings_init = \ 575 | self.decoder_embeddings.assign( 576 | self.decoder_embeddings_placeholder) 577 | else: 578 | self.decoder_embeddings = tf.get_variable( 579 | name='embeddings', 580 | shape=(self.target_vocab_size, self.embedding_size), 581 | initializer=self.initializer, 582 | dtype=tf.float32 583 | ) 584 | 585 | self.decoder_output_projection = layers.Dense( 586 | self.target_vocab_size, 587 | dtype=tf.float32, 588 | use_bias=False, 589 | name='decoder_output_projection' 590 | ) 591 | 592 | if self.mode == 'train': 593 | # decoder_inputs_embedded: 594 | # [batch_size, max_time_step + 1, embedding_size] 595 | self.decoder_inputs_embedded = tf.nn.embedding_lookup( 596 | params=self.decoder_embeddings, 597 | ids=self.decoder_inputs_train 598 | ) 599 | 600 | # Helper to feed inputs for training: 601 | # read inputs from dense ground truth vectors 602 | inputs = self.decoder_inputs_embedded 603 | 604 | if self.time_major: 605 | inputs = tf.transpose(inputs, (1, 0, 2)) 606 | 607 | training_helper = seq2seq.TrainingHelper( 608 | inputs=inputs, 609 | sequence_length=self.decoder_inputs_length, 610 | time_major=self.time_major, 611 | name='training_helper' 612 | ) 613 | 614 | # 训练的时候不在这里应用 output_layer 615 | # 因为这里会每个 time_step 的进行 output_layer 的投影计算,比较慢 616 | # 注意这个trick要成功必须设置 dynamic_decode 的 scope 参数 617 | training_decoder = seq2seq.BasicDecoder( 618 | cell=self.decoder_cell, 619 | helper=training_helper, 620 | initial_state=self.decoder_initial_state, 621 | # output_layer=self.decoder_output_projection 622 | ) 623 | 624 | # Maximum decoder time_steps in current batch 625 | max_decoder_length = tf.reduce_max( 626 | self.decoder_inputs_length 627 | ) 628 | 629 | # decoder_outputs_train: BasicDecoderOutput 630 | # namedtuple(rnn_outputs, sample_id) 631 | # decoder_outputs_train.rnn_output: 632 | # if output_time_major=False: 633 | # [batch_size, max_time_step + 1, num_decoder_symbols] 634 | # if output_time_major=True: 635 | # [max_time_step + 1, batch_size, num_decoder_symbols] 636 | # decoder_outputs_train.sample_id: [batch_size], tf.int32 637 | 638 | ( 639 | outputs, 640 | self.final_state, # contain attention 641 | _ # self.final_sequence_lengths 642 | ) = seq2seq.dynamic_decode( 643 | decoder=training_decoder, 644 | output_time_major=self.time_major, 645 | impute_finished=True, 646 | maximum_iterations=max_decoder_length, 647 | parallel_iterations=self.parallel_iterations, 648 | swap_memory=True, 649 | scope=decoder_scope 650 | ) 651 | 652 | # More efficient to do the projection 653 | # on the batch-time-concatenated tensor 654 | # logits_train: 655 | # [batch_size, max_time_step + 1, num_decoder_symbols] 656 | # 训练的时候一次性对所有的结果进行 output_layer 的投影运算 657 | # 官方NMT库说这样能提高10~20%的速度 658 | # 实际上我提高的速度会更大 659 | self.decoder_logits_train = self.decoder_output_projection( 660 | outputs.rnn_output 661 | ) 662 | 663 | # masks: masking for valid and padded time steps, 664 | # [batch_size, max_time_step + 1] 665 | self.masks = tf.sequence_mask( 666 | lengths=self.decoder_inputs_length, 667 | maxlen=max_decoder_length, 668 | dtype=tf.float32, name='masks' 669 | ) 670 | 671 | # Computes per word average cross-entropy over a batch 672 | # Internally calls 673 | # 'nn_ops.sparse_softmax_cross_entropy_with_logits' by default 674 | 675 | decoder_logits_train = 
self.decoder_logits_train 676 | if self.time_major: 677 | decoder_logits_train = tf.transpose(decoder_logits_train, 678 | (1, 0, 2)) 679 | 680 | self.decoder_pred_train = tf.argmax( 681 | decoder_logits_train, axis=-1, 682 | name='decoder_pred_train') 683 | 684 | # 下面的一些变量用于特殊的学习训练 685 | # 自定义rewards,其实我这里是修改了masks 686 | # train_entropy = cross entropy 687 | self.train_entropy = \ 688 | tf.nn.sparse_softmax_cross_entropy_with_logits( 689 | labels=self.decoder_inputs, 690 | logits=decoder_logits_train) 691 | 692 | self.masks_rewards = self.masks * self.rewards 693 | 694 | self.loss_rewards = seq2seq.sequence_loss( 695 | logits=decoder_logits_train, 696 | targets=self.decoder_inputs, 697 | weights=self.masks_rewards, 698 | average_across_timesteps=True, 699 | average_across_batch=True, 700 | ) 701 | 702 | self.loss = seq2seq.sequence_loss( 703 | logits=decoder_logits_train, 704 | targets=self.decoder_inputs, 705 | weights=self.masks, 706 | average_across_timesteps=True, 707 | average_across_batch=True, 708 | ) 709 | 710 | self.loss_add = self.loss + self.add_loss 711 | 712 | elif self.mode == 'decode': 713 | # 预测模式,非训练 714 | 715 | start_tokens = tf.tile( 716 | [WordSequence.START], 717 | [self.batch_size] 718 | ) 719 | end_token = WordSequence.END 720 | 721 | def embed_and_input_proj(inputs): 722 | """输入层的投影层wrapper 723 | """ 724 | return tf.nn.embedding_lookup( 725 | self.decoder_embeddings, 726 | inputs 727 | ) 728 | 729 | if not self.use_beamsearch_decode: 730 | # Helper to feed inputs for greedy decoding: 731 | # uses the argmax of the output 732 | decoding_helper = seq2seq.GreedyEmbeddingHelper( 733 | start_tokens=start_tokens, 734 | end_token=end_token, 735 | embedding=embed_and_input_proj 736 | ) 737 | # Basic decoder performs greedy decoding at each time step 738 | # print("building greedy decoder..") 739 | inference_decoder = seq2seq.BasicDecoder( 740 | cell=self.decoder_cell, 741 | helper=decoding_helper, 742 | initial_state=self.decoder_initial_state, 743 | output_layer=self.decoder_output_projection 744 | ) 745 | else: 746 | # Beamsearch is used to approximately 747 | # find the most likely translation 748 | # print("building beamsearch decoder..") 749 | inference_decoder = BeamSearchDecoder( 750 | cell=self.decoder_cell, 751 | embedding=embed_and_input_proj, 752 | start_tokens=start_tokens, 753 | end_token=end_token, 754 | initial_state=self.decoder_initial_state, 755 | beam_width=self.beam_width, 756 | output_layer=self.decoder_output_projection, 757 | ) 758 | 759 | # For GreedyDecoder, return 760 | # decoder_outputs_decode: BasicDecoderOutput instance 761 | # namedtuple(rnn_outputs, sample_id) 762 | # decoder_outputs_decode.rnn_output: 763 | # if output_time_major=False: 764 | # [batch_size, max_time_step, num_decoder_symbols] 765 | # if output_time_major=True 766 | # [max_time_step, batch_size, num_decoder_symbols] 767 | # decoder_outputs_decode.sample_id: 768 | # if output_time_major=False 769 | # [batch_size, max_time_step], tf.int32 770 | # if output_time_major=True 771 | # [max_time_step, batch_size], tf.int32 772 | 773 | # For BeamSearchDecoder, return 774 | # decoder_outputs_decode: FinalBeamSearchDecoderOutput instance 775 | # namedtuple(predicted_ids, beam_search_decoder_output) 776 | # decoder_outputs_decode.predicted_ids: 777 | # if output_time_major=False: 778 | # [batch_size, max_time_step, beam_width] 779 | # if output_time_major=True 780 | # [max_time_step, batch_size, beam_width] 781 | # decoder_outputs_decode.beam_search_decoder_output: 782 | # 
BeamSearchDecoderOutput instance 783 | # namedtuple(scores, predicted_ids, parent_ids) 784 | 785 | # 官方文档提到的一个潜在的最大长度选择 786 | # 我这里改为 * 4 787 | # maximum_iterations = tf.round(tf.reduce_max(source_sequence_length) * 2) 788 | # https://www.tensorflow.org/tutorials/seq2seq 789 | 790 | if self.max_decode_step is not None: 791 | max_decode_step = self.max_decode_step 792 | else: 793 | # 默认 4 倍输入长度的输出解码 794 | max_decode_step = tf.round(tf.reduce_max( 795 | self.encoder_inputs_length) * 4) 796 | 797 | ( 798 | self.decoder_outputs_decode, 799 | self.final_state, 800 | _ # self.decoder_outputs_length_decode 801 | ) = (seq2seq.dynamic_decode( 802 | decoder=inference_decoder, 803 | output_time_major=self.time_major, 804 | # impute_finished=True, # error occurs 805 | maximum_iterations=max_decode_step, 806 | parallel_iterations=self.parallel_iterations, 807 | swap_memory=True, 808 | scope=decoder_scope 809 | )) 810 | 811 | if not self.use_beamsearch_decode: 812 | # decoder_outputs_decode.sample_id: 813 | # [batch_size, max_time_step] 814 | # Or use argmax to find decoder symbols to emit: 815 | # self.decoder_pred_decode = tf.argmax( 816 | # self.decoder_outputs_decode.rnn_output, 817 | # axis=-1, name='decoder_pred_decode') 818 | 819 | # Here, we use expand_dims to be compatible with 820 | # the result of the beamsearch decoder 821 | # decoder_pred_decode: 822 | # [batch_size, max_time_step, 1] (output_major=False) 823 | 824 | # self.decoder_pred_decode = tf.expand_dims( 825 | # self.decoder_outputs_decode.sample_id, 826 | # -1 827 | # ) 828 | 829 | dod = self.decoder_outputs_decode 830 | self.decoder_pred_decode = dod.sample_id 831 | 832 | if self.time_major: 833 | self.decoder_pred_decode = tf.transpose( 834 | self.decoder_pred_decode, (1, 0)) 835 | 836 | else: 837 | # Use beam search to approximately 838 | # find the most likely translation 839 | # decoder_pred_decode: 840 | # [batch_size, max_time_step, beam_width] (output_major=False) 841 | self.decoder_pred_decode = \ 842 | self.decoder_outputs_decode.predicted_ids 843 | 844 | if self.time_major: 845 | self.decoder_pred_decode = tf.transpose( 846 | self.decoder_pred_decode, (1, 0, 2)) 847 | 848 | self.decoder_pred_decode = tf.transpose( 849 | self.decoder_pred_decode, 850 | perm=[0, 2, 1]) 851 | dod = self.decoder_outputs_decode 852 | self.beam_prob = dod.beam_search_decoder_output.scores 853 | 854 | 855 | def save(self, sess, save_path='model.ckpt'): 856 | """保存模型""" 857 | self.saver.save(sess, save_path=save_path) 858 | 859 | 860 | def load(self, sess, save_path='model.ckpt'): 861 | """读取模型""" 862 | print('try load model from', save_path) 863 | self.saver.restore(sess, save_path) 864 | 865 | 866 | def init_optimizer(self): 867 | """初始化优化器 868 | 支持的方法有 sgd, adadelta, adam, rmsprop, momentum 869 | """ 870 | 871 | # 学习率下降算法 872 | learning_rate = tf.train.polynomial_decay( 873 | self.learning_rate, 874 | self.global_step, 875 | self.decay_steps, 876 | self.min_learning_rate, 877 | power=0.5 878 | ) 879 | self.current_learning_rate = learning_rate 880 | 881 | # 设置优化器,合法的优化器如下 882 | # 'adadelta', 'adam', 'rmsprop', 'momentum', 'sgd' 883 | trainable_params = tf.trainable_variables() 884 | if self.optimizer.lower() == 'adadelta': 885 | self.opt = tf.train.AdadeltaOptimizer( 886 | learning_rate=learning_rate) 887 | elif self.optimizer.lower() == 'adam': 888 | self.opt = tf.train.AdamOptimizer( 889 | learning_rate=learning_rate) 890 | elif self.optimizer.lower() == 'rmsprop': 891 | self.opt = tf.train.RMSPropOptimizer( 892 | 
learning_rate=learning_rate) 893 | elif self.optimizer.lower() == 'momentum': 894 | self.opt = tf.train.MomentumOptimizer( 895 | learning_rate=learning_rate, momentum=0.9) 896 | elif self.optimizer.lower() == 'sgd': 897 | self.opt = tf.train.GradientDescentOptimizer( 898 | learning_rate=learning_rate) 899 | 900 | # Compute gradients of loss w.r.t. all trainable variables 901 | gradients = tf.gradients(self.loss, trainable_params) 902 | # Clip gradients by a given maximum_gradient_norm 903 | clip_gradients, _ = tf.clip_by_global_norm( 904 | gradients, self.max_gradient_norm) 905 | # Update the model 906 | self.updates = self.opt.apply_gradients( 907 | zip(clip_gradients, trainable_params), 908 | global_step=self.global_step) 909 | 910 | # 使用包括rewards的loss进行更新 911 | # 是特殊学习的一部分 912 | gradients = tf.gradients(self.loss_rewards, trainable_params) 913 | clip_gradients, _ = tf.clip_by_global_norm( 914 | gradients, self.max_gradient_norm) 915 | self.updates_rewards = self.opt.apply_gradients( 916 | zip(clip_gradients, trainable_params), 917 | global_step=self.global_step) 918 | 919 | # 添加 self.loss_add 的 update 920 | gradients = tf.gradients(self.loss_add, trainable_params) 921 | clip_gradients, _ = tf.clip_by_global_norm( 922 | gradients, self.max_gradient_norm) 923 | self.updates_add = self.opt.apply_gradients( 924 | zip(clip_gradients, trainable_params), 925 | global_step=self.global_step) 926 | 927 | 928 | def check_feeds(self, encoder_inputs, encoder_inputs_length, 929 | decoder_inputs, decoder_inputs_length, decode): 930 | """检查输入变量,并返回input_feed 931 | 932 | 我们首先会把数据编码,例如把“你好吗”,编码为[0, 1, 2] 933 | 多个句子组成一个batch,共同训练,例如一个batch_size=2,那么训练矩阵就可能是 934 | encoder_inputs = [ 935 | [0, 1, 2, 3], 936 | [4, 5, 6, 7] 937 | ] 938 | 它所代表的可能是:[['我', '是', '帅', '哥'], ['你', '好', '啊', '']] 939 | 注意第一句的真实长度是 4,第二句只有 3(最后的是一个填充数据) 940 | 941 | 那么: 942 | encoder_inputs_length = [4, 3] 943 | 来代表输入整个batch的真实长度 944 | 注意,为了符合算法要求,每个batch的句子必须是长度降序的,也就是说你输入一个 945 | encoder_inputs_length = [1, 10] 这样是错误的,必须在输入前排序到 946 | encoder_inputs_length = [10, 1] 这样才行 947 | 948 | decoder_inputs 和 decoder_inputs_length 所代表的含义差不多 949 | 950 | Args: 951 | encoder_inputs: 952 | 一个整形二维矩阵 [batch_size, max_source_time_steps] 953 | encoder_inputs_length: 954 | 一个整形向量 [batch_size] 955 | 每个维度是encoder句子的真实长度 956 | decoder_inputs: 957 | 一个整形矩阵 [batch_size, max_target_time_steps] 958 | decoder_inputs_length: 959 | 一个整形向量 [batch_size] 960 | 每个维度是decoder句子的真实长度 961 | decode: 用来指示正在训练模式(decode=False)还是预测模式(decode=True) 962 | Returns: 963 | tensorflow所操作需要的input_feed,包括 964 | encoder_inputs, encoder_inputs_length, 965 | decoder_inputs, decoder_inputs_length 966 | """ 967 | 968 | input_batch_size = encoder_inputs.shape[0] 969 | if input_batch_size != encoder_inputs_length.shape[0]: 970 | raise ValueError( 971 | "encoder_inputs和encoder_inputs_length的第一维度必须一致 " 972 | "这一维度是batch_size, %d != %d" % ( 973 | input_batch_size, encoder_inputs_length.shape[0])) 974 | 975 | if not decode: 976 | target_batch_size = decoder_inputs.shape[0] 977 | if target_batch_size != input_batch_size: 978 | raise ValueError( 979 | "encoder_inputs和decoder_inputs的第一维度必须一致 " 980 | "这一维度是batch_size, %d != %d" % ( 981 | input_batch_size, target_batch_size)) 982 | if target_batch_size != decoder_inputs_length.shape[0]: 983 | raise ValueError( 984 | "edeoder_inputs和decoder_inputs_length的第一维度必须一致 " 985 | "这一维度是batch_size, %d != %d" % ( 986 | target_batch_size, decoder_inputs_length.shape[0])) 987 | 988 | input_feed = {} 989 | 990 | input_feed[self.encoder_inputs.name] = 
encoder_inputs 991 | input_feed[self.encoder_inputs_length.name] = encoder_inputs_length 992 | 993 | if not decode: 994 | input_feed[self.decoder_inputs.name] = decoder_inputs 995 | input_feed[self.decoder_inputs_length.name] = decoder_inputs_length 996 | 997 | return input_feed 998 | 999 | 1000 | def train(self, sess, encoder_inputs, encoder_inputs_length, 1001 | decoder_inputs, decoder_inputs_length, 1002 | rewards=None, return_lr=False, 1003 | loss_only=False, add_loss=None): 1004 | """训练模型""" 1005 | 1006 | # 输入 1007 | input_feed = self.check_feeds( 1008 | encoder_inputs, encoder_inputs_length, 1009 | decoder_inputs, decoder_inputs_length, 1010 | False 1011 | ) 1012 | 1013 | # 设置 dropout 1014 | input_feed[self.keep_prob_placeholder.name] = self.keep_prob 1015 | 1016 | if loss_only: 1017 | # 输出 1018 | return sess.run(self.loss, input_feed) 1019 | 1020 | if add_loss is not None: 1021 | input_feed[self.add_loss.name] = add_loss 1022 | output_feed = [ 1023 | self.updates_add, self.loss_add, 1024 | self.current_learning_rate] 1025 | _, cost, lr = sess.run(output_feed, input_feed) 1026 | 1027 | if return_lr: 1028 | return cost, lr 1029 | 1030 | return cost 1031 | 1032 | if rewards is not None: 1033 | input_feed[self.rewards.name] = rewards 1034 | output_feed = [ 1035 | self.updates_rewards, self.loss_rewards, 1036 | self.current_learning_rate] 1037 | _, cost, lr = sess.run(output_feed, input_feed) 1038 | 1039 | if return_lr: 1040 | return cost, lr 1041 | return cost 1042 | 1043 | output_feed = [ 1044 | self.updates, self.loss, 1045 | self.current_learning_rate] 1046 | _, cost, lr = sess.run(output_feed, input_feed) 1047 | 1048 | if return_lr: 1049 | return cost, lr 1050 | 1051 | return cost 1052 | 1053 | 1054 | def get_encoder_embedding(self, sess, encoder_inputs): 1055 | """获取经过embedding的encoder_inputs""" 1056 | input_feed = { 1057 | self.encoder_inputs.name: encoder_inputs 1058 | } 1059 | emb = sess.run(self.encoder_inputs_embedded, input_feed) 1060 | return emb 1061 | 1062 | 1063 | def entropy(self, sess, encoder_inputs, encoder_inputs_length, 1064 | decoder_inputs, decoder_inputs_length): 1065 | """获取针对一组输入输出的entropy 1066 | 相当于在计算P(target|source) 1067 | """ 1068 | input_feed = self.check_feeds( 1069 | encoder_inputs, encoder_inputs_length, 1070 | decoder_inputs, decoder_inputs_length, 1071 | False 1072 | ) 1073 | input_feed[self.keep_prob_placeholder.name] = 1.0 1074 | output_feed = [self.train_entropy, self.decoder_pred_train] 1075 | entropy, logits = sess.run(output_feed, input_feed) 1076 | return entropy, logits 1077 | 1078 | 1079 | def predict(self, sess, 1080 | encoder_inputs, 1081 | encoder_inputs_length, 1082 | attention=False): 1083 | """预测输出""" 1084 | 1085 | # 输入 1086 | input_feed = self.check_feeds(encoder_inputs, 1087 | encoder_inputs_length, None, None, True) 1088 | 1089 | input_feed[self.keep_prob_placeholder.name] = 1.0 1090 | 1091 | # Attention 输出 1092 | if attention: 1093 | 1094 | assert not self.use_beamsearch_decode, \ 1095 | 'Attention 模式不能打开 BeamSearch' 1096 | 1097 | pred, atten = sess.run([ 1098 | self.decoder_pred_decode, 1099 | self.final_state.alignment_history.stack() 1100 | ], input_feed) 1101 | 1102 | return pred, atten 1103 | 1104 | # BeamSearch 模式输出 1105 | if self.use_beamsearch_decode: 1106 | pred, beam_prob = sess.run([ 1107 | self.decoder_pred_decode, self.beam_prob 1108 | ], input_feed) 1109 | beam_prob = np.mean(beam_prob, axis=1) 1110 | 1111 | pred = pred[0] 1112 | return pred 1113 | 1114 | # ret = [] 1115 | # for i in range(encoder_inputs.shape[0]): 
1116 | # ret.append(pred[i * self.beam_width]) 1117 | # return np.array(ret) 1118 | # 1119 | 1120 | # 普通(Greedy)模式输出 1121 | pred, = sess.run([ 1122 | self.decoder_pred_decode 1123 | ], input_feed) 1124 | 1125 | return pred 1126 | --------------------------------------------------------------------------------
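
A minimal usage sketch of `SequenceToSequence` in decode mode (greedy decoding). The vocabulary sizes, checkpoint path and input batch below are illustrative assumptions, not values taken from this repository, and the constructor hyper-parameters must match those used during training:

```
import numpy as np
import tensorflow as tf

from sequence_to_sequence import SequenceToSequence

# Illustrative sizes; they must match the trained model
model = SequenceToSequence(
    input_vocab_size=10000,
    target_vocab_size=10000,
    batch_size=1,
    mode='decode',
    beam_width=0  # greedy decoding; > 0 would enable beam search
)

with tf.Session() as sess:
    model.load(sess, 'model.ckpt')  # default save_path of save()/load()

    # One already-indexed sentence and its true length
    encoder_inputs = np.array([[4, 5, 6, 7]], dtype=np.int32)
    encoder_inputs_length = np.array([4], dtype=np.int32)

    pred = model.predict(sess, encoder_inputs, encoder_inputs_length)
    print(pred)  # predicted index sequence, shape (batch_size, time_step)
```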