├── __init__.py ├── lib ├── __init__.py └── chatbot_model.py ├── old ├── data │ └── .gitignore ├── generated │ └── .gitignore ├── deploy.py ├── config.py ├── README.md ├── find_long_reply.py ├── tweet_listener.py ├── tweet_replyer.py ├── predict.py ├── TODO.md ├── lib │ ├── data_utils.py │ ├── seq2seq_model.py │ └── my_seq2seq.py ├── train.py └── data_processer.py ├── seq2seq.py ├── tweet_bot.py ├── README.md └── LICENSE /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /old/data/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /old/generated/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /seq2seq.py: -------------------------------------------------------------------------------- 1 | # mkdir -p ~/seq2seq_run; cd ~/seq2seq_run; python ~/Google\ 2 | # Drive/tensorflow_seq2seq_chatbot/seq2seq.py 3 | 4 | import lib.chatbot_model as sq 5 | 6 | sq.test_train_rl() 7 | -------------------------------------------------------------------------------- /tweet_bot.py: -------------------------------------------------------------------------------- 1 | # mkdir -p ~/seq2seq_run; cd ~/seq2seq_run; python ~/Google\ 2 | # Drive/tensorflow_seq2seq_chatbot/tweet_bot.py 3 | 4 | import lib.chatbot_model as sq 5 | 6 | sq.listener(sq.conversations_large_rl_hparams) 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | __Please check the [old/](https://github.com/higepon/tensorflow_seq2seq_chatbot/tree/master/old) directory if you're looking for the seq2seq chatbot + tweet bot.__ 2 | 3 | # What is this? 4 | This is an in-progress, experimental seq2seq + reinforcement learning chatbot. 5 | The bot is not functional yet. 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2017 Taro Minowa higepon 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | -------------------------------------------------------------------------------- /old/deploy.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import argparse 3 | import config 4 | import glob 5 | import subprocess 6 | 7 | # Script to deploy the latest model to the server 8 | 9 | 10 | def create_checkpoint_file(path): 11 | # rewrite the local /Users/... path to the server-side /home/... path 12 | path = path.replace('Users', 'home') 13 | f = open('{}/checkpoint'.format(config.GENERATED_DIR), 'w') 14 | f.write('model_checkpoint_path: "{}"\n'.format(path)) 15 | f.write('all_model_checkpoint_paths: "{}"\n'.format(path)) 16 | f.close() 17 | 18 | 19 | def deploy(host): 20 | files_to_deploy = get_files_to_deploy() 21 | subprocess.call("scp {} {}:chatbot_generated/".format(' '.join(files_to_deploy), host), shell=True) 22 | # double check that the checkpoint set did not change while copying 23 | if files_to_deploy != get_files_to_deploy(): 24 | raise RuntimeError("inconsistent state") 25 | 26 | 27 | def get_files_to_deploy(): 28 | source_dir = config.GENERATED_DIR 29 | all_files = sorted(glob.glob('{}/*ckpt*'.format(source_dir)), reverse=True) 30 | # first is latest 31 | target_checkpoint = all_files[0] 32 | path, _ = os.path.splitext(target_checkpoint) 33 | create_checkpoint_file(path) 34 | files_to_deploy = glob.glob('{}.*'.format(path)) 35 | files_to_deploy.append('{}/checkpoint'.format(config.GENERATED_DIR)) 36 | return files_to_deploy 37 | 38 | 39 | def main(): 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('-host', type=str, required=True, help='deploy target host') 42 | args = parser.parse_args() 43 | host = args.host 44 | deploy(host) 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /old/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sys import platform 3 | 4 | if platform == 'linux': 5 | GENERATED_DIR = os.getenv("HOME") + "/chatbot_generated" 6 | LOGS_DIR = os.getenv("HOME") + "/chatbot_train_logs" 7 | else: 8 | GENERATED_DIR = os.getenv("HOME") + "/Dropbox/tensorflow_seq2seq_chatbot/chatbot_generated" 9 | LOGS_DIR = os.getenv("HOME") + "/chatbot_train_logs" 10 | 11 | is_fast_build = False 12 | beam_search = True 13 | beam_size = 20 14 | 15 | DATA_DIR = "data" 16 | if is_fast_build: 17 | TWEETS_TXT = "{0}/tweets_short.txt".format(DATA_DIR) 18 | else: 19 | # TWEETS_TXT = "{0}/tweets1M.txt".format(DATA_DIR) 20 | TWEETS_TXT = "{0}/hoge.txt".format(DATA_DIR) 21 | 22 | if is_fast_build: 23 | MAX_ENC_VOCABULARY = 5 24 | NUM_LAYERS = 2 25 | LAYER_SIZE = 2 26 | BATCH_SIZE = 2 27 | buckets = [(5, 10), (8, 13)] 28 | else: 29 | MAX_ENC_VOCABULARY = 50000 30 | NUM_LAYERS = 3 31 | LAYER_SIZE = 1024 32 | BATCH_SIZE = 128 33 | buckets = [(5, 10), (10, 15), (20, 25), (40, 50)] 34 | 35 | MAX_DEC_VOCABULARY = MAX_ENC_VOCABULARY 36 | 37 | LEARNING_RATE = 0.5 38 | LEARNING_RATE_DECAY_FACTOR = 0.99 39 | MAX_GRADIENT_NORM = 5.0 40 | 41 | TWEETS_TRAIN_ENC_IDX_TXT = "{0}/tweets_train_enc_idx.txt".format(GENERATED_DIR) 42 | TWEETS_TRAIN_DEC_IDX_TXT = "{0}/tweets_train_dec_idx.txt".format(GENERATED_DIR) 43 | TWEETS_VAL_ENC_IDX_TXT = "{0}/tweets_val_enc_idx.txt".format(GENERATED_DIR) 44 | TWEETS_VAL_DEC_IDX_TXT = "{0}/tweets_val_dec_idx.txt".format(GENERATED_DIR) 45 | 46 | VOCAB_ENC_TXT = "{0}/vocab_enc.txt".format(GENERATED_DIR) 47 | VOCAB_DEC_TXT = "{0}/vocab_dec.txt".format(GENERATED_DIR) 48 | -------------------------------------------------------------------------------- /old/README.md: 
-------------------------------------------------------------------------------- 1 | # What is this? 2 | This is a seq2seq chatbot implementation. Most credit goes to [1228337123](https://github.com/1228337123/tensorflow-seq2seq-chatbot) and [AvaisP](https://github.com/AvaisP/Neural_Conversation_Models); I'm just reimplementing their work to gain a better understanding of seq2seq. This chatbot is optimized for Japanese. You may replace the existing tokenizer with one for your language. 3 | 4 | The main differences in my implementation are 5 | - More comments 6 | - An easy-to-understand input/output format for each process 7 | 8 | # Requirements 9 | - Python 3.6 10 | - TensorFlow 1.1.0 11 | 12 | # How to run 13 | 1. Prepare the training data. 14 | 1. Put your training data at data/tweets.txt; the file consists of tweet/reply pairs. 15 | 1. Odd lines are tweets and even lines are the corresponding replies. 16 | 1. You can collect training data using [github.com/Marsan-Ma/twitter_scraper](https://github.com/Marsan-Ma/twitter_scraper). 17 | 1. Process the training data to generate the vocabulary file and other necessary files. Run the following command; the files will appear in the generated/ directory. 18 | 19 | python data_processer.py 20 | 1. Train! Training may take from a few hours to a day, and it never stops on its own. Once you think the model is ready, just press Ctrl-C. Model parameters are saved in the generated/ directory. 21 | 22 | python train.py 23 | 24 | 1. Talk to him! 25 | 26 | python predict.py 27 | 28 | # Twitter Bot 29 | By running tweet_listener.py and tweet_replyer.py, you can run this chatbot on Twitter. 30 | 31 | 32 | Here are some interesting conversations with him. 33 | ![sample1](http://cdn-ak.f.st-hatena.com/images/fotolife/h/higepon/20170428/20170428211132.jpg?1493381493?changed=1493381493) 34 | ![sample2](http://cdn-ak.f.st-hatena.com/images/fotolife/h/higepon/20170428/20170428211230.jpg?1493381551?changed=1493381551) 35 | -------------------------------------------------------------------------------- /old/find_long_reply.py: -------------------------------------------------------------------------------- 1 | import predict 2 | import tensorflow as tf 3 | import os 4 | import json 5 | import tweepy 6 | import time 7 | import socket 8 | import http.client 9 | from tweepy import OAuthHandler, Stream 10 | from tweepy.streaming import StreamListener 11 | 12 | tcpip_delay = 0.25 13 | MAX_TCPIP_TIMEOUT = 16 14 | 15 | 16 | class QueueListener(StreamListener): 17 | 18 | def __init__(self, sess): 19 | consumer_key = os.getenv("consumer_key") 20 | consumer_secret = os.getenv("consumer_secret") 21 | access_token = os.getenv("access_token") 22 | access_token_secret = os.getenv("access_token_secret") 23 | 24 | self.auth = OAuthHandler(consumer_key, consumer_secret) 25 | self.auth.set_access_token(access_token, access_token_secret) 26 | self.api = tweepy.API(self.auth) 27 | self.predictor = predict.EasyPredictor(sess) 28 | 29 | def on_data(self, data): 30 | """Routes the raw stream data to the appropriate method.""" 31 | raw = json.loads(data) 32 | if 'in_reply_to_status_id' in raw: 33 | if self.on_status(raw) is False: 34 | return False 35 | elif 'limit' in raw: 36 | if self.on_limit(raw['limit']['track']) is False: 37 | return False 38 | return True 39 | 40 | def on_status(self, status): 41 | if 'retweeted_status' in status: 42 | return True 43 | text = status['text'] 44 | replies = self.predictor.predict(text) 45 | if not replies: 46 | return True 47 | reply_body = replies[0] 48 | text = text.replace('\n', ' ') 49 | 
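# Log the newline-stripped tweet together with its top predicted reply, so interesting (long) replies can be spotted in the console output.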
print(text) 50 | print("reply:{}".format(reply_body)) 51 | return True 52 | 53 | def on_error(self, status): 54 | print('ON ERROR:', status) 55 | 56 | def on_limit(self, track): 57 | print('ON LIMIT:', track) 58 | 59 | 60 | def main(): 61 | with tf.Session() as sess: 62 | listener = QueueListener(sess) 63 | stream = Stream(listener.auth, listener) 64 | stream.filter(languages=["ja"], 65 | track=['「', '」', '私', '俺', 'わたし', 'おれ', 'ぼく', '僕', 'http', 'www', 'co', '@', '#', '。', ',', '!', 66 | '.', '!', ',', ':', ':', '』', ')', '...', 'これ']) 67 | try: 68 | while True: 69 | try: 70 | stream.sample() 71 | except KeyboardInterrupt: 72 | print('KEYBOARD INTERRUPT') 73 | return 74 | except (socket.error, http.client.HTTPException): 75 | global tcpip_delay 76 | print('TCP/IP Error: Restarting after %.2f seconds.' % tcpip_delay) 77 | time.sleep(min(tcpip_delay, MAX_TCPIP_TIMEOUT)) 78 | tcpip_delay += 0.25 79 | finally: 80 | stream.disconnect() 81 | print('Exit successful, corpus dumped in %s' % (listener.dumpfile)) 82 | 83 | 84 | if __name__ == '__main__': 85 | main() -------------------------------------------------------------------------------- /old/tweet_listener.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sqlite3 3 | import pickle 4 | import tweepy 5 | from datetime import datetime, timedelta 6 | 7 | DB_NAME = 'tweets.db' 8 | SHOULD_TWEET = 1 9 | 10 | 11 | 12 | def create_tables(): 13 | conn = sqlite3.connect(DB_NAME) 14 | sql = 'create table tweets(sid integer primary key, data blob not null, processed integer not null default 0)' 15 | c = conn.cursor() 16 | c.execute(sql) 17 | conn.commit() 18 | conn.close() 19 | # alter table tweets add column bot_flag integer NOT NULL default 0; 20 | 21 | 22 | def insert_tweet(status_id, tweet, bot_flag=0): 23 | conn = sqlite3.connect(DB_NAME) 24 | binary_data = pickle.dumps(tweet, pickle.HIGHEST_PROTOCOL) 25 | c = conn.cursor() 26 | c.execute("insert into tweets (sid, data, bot_flag) values (?, ?, ?)", [status_id, sqlite3.Binary(binary_data), bot_flag]) 27 | conn.commit() 28 | conn.close() 29 | 30 | 31 | class StreamListener(tweepy.StreamListener): 32 | def __init__(self, api): 33 | self.api = api 34 | self.next_tweet_time = self.get_next_tweet_time() 35 | 36 | def on_status(self, status): 37 | print("{0}: {1}".format(status.text, status.author.screen_name)) 38 | 39 | screen_name = status.author.screen_name 40 | # ignore my tweets 41 | if screen_name == self.api.me().screen_name: 42 | print("Ignored my tweet") 43 | return True 44 | elif status.text.startswith("@{0}".format(self.api.me().screen_name)): 45 | # Save mentions 46 | print("Saved mention") 47 | insert_tweet(status.id, status) 48 | return True 49 | else: 50 | if self.next_tweet_time < datetime.today(): 51 | print("Saving normal tweet as seed") 52 | self.next_tweet_time = self.get_next_tweet_time() 53 | insert_tweet(status.id, status, bot_flag=SHOULD_TWEET) 54 | print("Ignored this tweet") 55 | return True 56 | 57 | @staticmethod 58 | def get_next_tweet_time(): 59 | return datetime.today() + timedelta(hours=4) 60 | 61 | @staticmethod 62 | def on_error(status_code): 63 | print(status_code) 64 | return True 65 | 66 | 67 | def tweet_listener(): 68 | consumer_key = os.getenv("consumer_key") 69 | consumer_secret = os.getenv("consumer_secret") 70 | access_token = os.getenv("access_token") 71 | access_token_secret = os.getenv("access_token_secret") 72 | 73 | auth = tweepy.OAuthHandler(consumer_key, consumer_secret) 74 | 
auth.set_access_token(access_token, access_token_secret) 75 | api = tweepy.API(auth) 76 | 77 | while True: 78 | try: 79 | stream = tweepy.Stream(auth=api.auth, 80 | listener=StreamListener(api)) 81 | print("listener starting...") 82 | stream.userstream() 83 | except Exception as e: 84 | print(e) 85 | print(e.__doc__) 86 | 87 | if __name__ == '__main__': 88 | if False: 89 | create_tables() 90 | tweet_listener() 91 | -------------------------------------------------------------------------------- /old/tweet_replyer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import tweepy 4 | import time 5 | import predict 6 | import sqlite3 7 | import pickle 8 | import tweet_listener 9 | 10 | 11 | def select_next_tweet(): 12 | conn = sqlite3.connect(tweet_listener.DB_NAME) 13 | c = conn.cursor() 14 | c.execute("select sid, data, bot_flag from tweets where processed = 0") 15 | for row in c: 16 | sid = row[0] 17 | data = pickle.loads(row[1]) 18 | bot_flag = row[2] 19 | return sid, data, bot_flag 20 | return None, None, None 21 | 22 | 23 | def mark_tweet_processed(status_id): 24 | conn = sqlite3.connect(tweet_listener.DB_NAME) 25 | c = conn.cursor() 26 | c.execute("update tweets set processed = 1 where sid = ?", [status_id]) 27 | conn.commit() 28 | conn.close() 29 | 30 | 31 | def tweets(): 32 | while True: 33 | status_id, tweet, bot_flag = select_next_tweet() 34 | if status_id is not None: 35 | yield(status_id, tweet, bot_flag) 36 | time.sleep(1) 37 | 38 | 39 | def post_reply(api, bot_flag, reply_body, screen_name, status_id): 40 | unk_count = reply_body.count('_UNK') 41 | reply_body = reply_body.replace('_UNK', '💩') 42 | if bot_flag == tweet_listener.SHOULD_TWEET: 43 | if unk_count > 0: 44 | return 45 | reply_text = reply_body 46 | print("My Tweet:{0}".format(reply_text)) 47 | if not reply_text: 48 | return 49 | api.update_status(status=reply_text) 50 | else: 51 | if not reply_body: 52 | reply_body = "🐶(適切なお返事が生成できませんでした)" 53 | reply_text = "@" + screen_name + " " + reply_body 54 | print("Reply:{0}".format(reply_text)) 55 | api.update_status(status=reply_text, 56 | in_reply_to_status_id=status_id) 57 | 58 | 59 | def twitter_bot(): 60 | # Only allocate part of the gpu memory when predicting. 
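# per_process_gpu_memory_fraction=0.2 caps this process at roughly 20% of the GPU's memory, so the reply bot can run alongside a training job on the same GPU.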
61 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2) 62 | tf_config = tf.ConfigProto(gpu_options=gpu_options) 63 | 64 | consumer_key = os.getenv("consumer_key") 65 | consumer_secret = os.getenv("consumer_secret") 66 | access_token = os.getenv("access_token") 67 | access_token_secret = os.getenv("access_token_secret") 68 | 69 | auth = tweepy.OAuthHandler(consumer_key, consumer_secret) 70 | auth.set_access_token(access_token, access_token_secret) 71 | api = tweepy.API(auth) 72 | with tf.Session(config=tf_config) as sess: 73 | predictor = predict.EasyPredictor(sess) 74 | 75 | for tweet in tweets(): 76 | status_id, status, bot_flag = tweet 77 | print("Processing {0}...".format(status.text)) 78 | screen_name = status.author.screen_name 79 | replies = predictor.predict(status.text) 80 | if not replies: 81 | print("no reply") 82 | continue 83 | reply_body = replies[0] 84 | if reply_body is None: 85 | print("No reply predicted") 86 | else: 87 | try: 88 | post_reply(api, bot_flag, reply_body, screen_name, status_id) 89 | except tweepy.TweepError as e: 90 | # duplicate status 91 | if e.api_code == 187: 92 | pass 93 | else: 94 | raise 95 | mark_tweet_processed(status_id) 96 | 97 | 98 | if __name__ == '__main__': 99 | twitter_bot() 100 | -------------------------------------------------------------------------------- /old/predict.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tensorflow as tf 3 | import numpy as np 4 | import train 5 | import config 6 | import data_processer 7 | 8 | 9 | def get_prediction(session, model, enc_vocab, rev_dec_vocab, text): 10 | token_ids = data_processer.sentence_to_token_ids(text, enc_vocab) 11 | bucket_id = min([b for b in range(len(config.buckets)) 12 | if config.buckets[b][0] > len(token_ids)]) 13 | encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(token_ids, [])]}, bucket_id) 14 | 15 | _, _, output_logits = model.step(session, encoder_inputs, decoder_inputs, 16 | target_weights, bucket_id, True, beam_search=False) 17 | 18 | outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits] 19 | if data_processer.EOS_ID in outputs: 20 | outputs = outputs[:outputs.index(data_processer.EOS_ID)] 21 | text = "".join([tf.compat.as_str(rev_dec_vocab[output]) for output in outputs]) 22 | return text 23 | 24 | # normal_prediction(bucket_id, decoder_inputs, encoder_inputs, model, rev_dec_vocab, session, target_weights) 25 | 26 | #if config.beam_search: 27 | # beam_search_prediction(bucket_id, decoder_inputs, encoder_inputs, model, rev_dec_vocab, session, 28 | # target_weights) 29 | 30 | #def normal_prediction(bucket_id, decoder_inputs, encoder_inputs, model, rev_dec_vocab, session, target_weights): 31 | # _, _, output_logits = model.step(session, encoder_inputs, decoder_inputs, 32 | # target_weights, bucket_id, True, beam_search=False) 33 | # outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits] 34 | # if data_processer.EOS_ID in outputs: 35 | # outputs = outputs[:outputs.index(data_processer.EOS_ID)] 36 | # text = "".join([tf.compat.as_str(rev_dec_vocab[output]) for output in outputs]) 37 | # print("Normal Prediction") 38 | # print(text) 39 | 40 | 41 | def get_beam_serch_prediction(session, model, enc_vocab, rev_dec_vocab, text): 42 | max_len = config.buckets[-1][0] 43 | target_text = text 44 | if len(text) > max_len: 45 | target_text = text[:max_len] 46 | token_ids = data_processer.sentence_to_token_ids(target_text, enc_vocab) 47 | target_buckets = 
[b for b in range(len(config.buckets)) 48 | if config.buckets[b][0] > len(token_ids)] 49 | if not target_buckets: 50 | return [] 51 | 52 | bucket_id = min(target_buckets) 53 | encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(token_ids, [])]}, bucket_id) 54 | 55 | path, symbol, output_logits = model.step(session, encoder_inputs, decoder_inputs, 56 | target_weights, bucket_id, True, beam_search=config.beam_search) 57 | beam_size = config.beam_size 58 | k = output_logits[0] 59 | paths = [] 60 | for kk in range(beam_size): 61 | paths.append([]) 62 | curr = list(range(beam_size)) 63 | num_steps = len(path) 64 | for i in range(num_steps - 1, -1, -1): 65 | for kk in range(beam_size): 66 | paths[kk].append(symbol[i][curr[kk]]) 67 | curr[kk] = path[i][curr[kk]] 68 | recos = set() 69 | ret = [] 70 | i = 0 71 | for kk in range(beam_size): 72 | foutputs = [int(logit) for logit in paths[kk][::-1]] 73 | 74 | # If there is an EOS symbol in outputs, cut them at that point. 75 | if data_processer.EOS_ID in foutputs: 76 | # # print outputs 77 | foutputs = foutputs[:foutputs.index(data_processer.EOS_ID)] 78 | rec = "".join([tf.compat.as_str(rev_dec_vocab[output]) for output in foutputs]) 79 | if rec not in recos: 80 | recos.add(rec) 81 | # print("reply {}".format(i)) 82 | # i = i + 1 83 | ret.append(rec) 84 | return ret 85 | 86 | 87 | class EasyPredictor: 88 | def __init__(self, session): 89 | self.session = session 90 | train.show_progress("Creating model...") 91 | self.model = train.create_or_restore_model(self.session, config.buckets, forward_only=True, beam_search=config.beam_search, beam_size=config.beam_size) 92 | self.model.batch_size = 1 93 | train.show_progress("done\n") 94 | self.enc_vocab, _ = data_processer.initialize_vocabulary(config.VOCAB_ENC_TXT) 95 | _, self.rev_dec_vocab = data_processer.initialize_vocabulary(config.VOCAB_DEC_TXT) 96 | 97 | def predict(self, text): 98 | text = text.replace('\n', ' ') 99 | text = data_processer.sanitize_line(text) 100 | if config.beam_search: 101 | replies = get_beam_serch_prediction(self.session, self.model, self.enc_vocab, self.rev_dec_vocab, text) 102 | return replies 103 | else: 104 | reply = get_prediction(self.session, self.model, self.enc_vocab, self.rev_dec_vocab, text) 105 | return [reply] 106 | 107 | 108 | def predict(): 109 | # Only allocate part of the gpu memory when predicting. 110 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2) 111 | tf_config = tf.ConfigProto(gpu_options=gpu_options) 112 | 113 | with tf.Session(config=tf_config) as sess: 114 | 115 | predictor = EasyPredictor(sess) 116 | 117 | sys.stdout.write("> ") 118 | sys.stdout.flush() 119 | line = sys.stdin.readline() 120 | while line: 121 | replies = predictor.predict(line) 122 | for i, text in enumerate(replies): 123 | print(i, text) 124 | print("> ", end="") 125 | sys.stdout.flush() 126 | line = sys.stdin.readline() 127 | 128 | if __name__ == '__main__': 129 | predict() 130 | -------------------------------------------------------------------------------- /old/TODO.md: -------------------------------------------------------------------------------- 1 | # Goal 2 | Make *fun* chatbot like human. 3 | # TODO 4 | 5 | - Revisit GNMT once it's stablized. 
6 | - check out news API [Release TensorFlow 1.2.0 · tensorflow/tensorflow](https://github.com/tensorflow/tensorflow/releases/tag/v1.2.0) 7 | - Get 5M data 8 | - see if the data text above works well 9 | - Stop using bio; it affects the bucket distribution and is too biased toward bio 10 | - Improvement 4: Create a chatbot with personality 11 | - [B] Improvement 5: Make your chatbot remember information from the previous conversation 12 | - can we simply train with 1 pair? 13 | - Improvements todo 14 | - Train on multiple datasets 15 | - Create a feedback loop that allows users to train your chatbot 16 | - How can I have the bot tweet something interesting? 17 | - [Make your chatbot remember information · higepon/tensorflow_seq2seq_chatbot Wiki](https://github.com/higepon/tensorflow_seq2seq_chatbot/wiki/Make-your-chatbot-remember-information) 18 | - tweet something based on 19 | - someone's tweet 20 | - news? 21 | - maybe 3 times a day 22 | - changed it to 3 layers of 1024 (rather than 3 layers of 256) see https://github.com/tensorflow/tensorflow/issues/550 23 | - [In my seq2seq chatbot, I'm seeing many general replies like Thank you, lol, this is it or yes. Even for train inputs, outputs are generic and not interesting. How can I debug seq2seq output? - Quora](https://www.quora.com/unanswered/In-my-seq2seq-chatbot-Im-seeing-many-general-replies-like-Thank-you-lol-this-is-it-or-yes-Even-for-train-inputs-outputs-are-generic-and-not-interesting-How-can-I-debug-seq2seq-output) 24 | - __here__: Read https://arxiv.org/pdf/1606.01541.pdf 25 | - related papers 26 | - Deep Reinforcement Learning for Dialogue Generation (56) 27 | - __done__ 8/27 The paper above 28 | - __done__ ref impl: [liuyuemaicha/Deep-Reinforcement-Learning-for-Dialogue-Generation-in-tensorflow: Deep-Reinforcement-Learning-for-Dialogue-Generation-in-tensorflow](https://github.com/liuyuemaicha/Deep-Reinforcement-Learning-for-Dialogue-Generation-in-tensorflow) 29 | - Building end-to-end dialogue systems using generative hierarchical neural network models. 30 | - __done__ 8/27 Useless for our case? 31 | - A diversity-promoting objective function for neural conversation models. (87) 32 | - __done__ 8/27 Penalize generic responses 33 | - __done__ 8/27 ref impl [pender/chatbot-rnn: A toy chatbot powered by deep learning and trained on data from Reddit](https://github.com/pender/chatbot-rnn) 34 | - 8/17 I'll implement Deep Reinforcement Learning for Dialogue Generation, because it is a newer version from the same author 35 | 36 | - 8/23 super summary 37 | # random ideas 38 | - Can we use 2ch.net data? 
39 | 40 | # Done 41 | - __done__ 6/11 [Understand how CS20SI chatbot works](https://github.com/higepon/tensorflow_seq2seq_chatbot/wiki/Understand-how-CS20SI-chatbot-works) 42 | - __done__ 6/17 [Make old chatbot compatible with tensorflow 1.0](https://github.com/higepon/tensorflow_seq2seq_chatbot/wiki/Make-old-chatbot-compatible-with-tensorflow-1.0) 43 | - __done__ beam search 44 | - __done__ read https://arxiv.org/abs/1609.08144 45 | - __done__ read https://arxiv.org/abs/1611.04558 46 | - __done__ read [this pdf](http://2boy.org/~yuta/publications/neural-dialog-model-kanto-mt-20170714.pdf) 47 | - __done__ read [tensorflow/nmt: TensorFlow Neural Machine Translation Tutorial](https://github.com/tensorflow/nmt) 48 | - __done__ 7/30[Install tensorflow FreeBSD · higepon/tensorflow_seq2seq_chatbot Wiki](https://github.com/higepon/tensorflow_seq2seq_chatbot/wiki/Install-tensorflow-FreeBSD) 49 | - __done__ greedy get tweets 50 | - __done__ train the bot with above 51 | - __done__ set up only bot somehow 52 | - __done__ [A] Improvement 3: Use more than just one utterance as the encoder 53 | - __done__ Check if tweet collector get one more deep conversation 54 | - __done__ collect 55 | - __done__ 7/30 0.0.3 tag with commit 56 | - __done__ 7/30 deploy script 57 | - __done__ Clean up tweets data, we see a lot これw 58 | - __done__ write a script which remove spam tweets 59 | - __done__ add column to the db is_spam. 60 | - __done__ Make deploy process for tweet bot 61 | - __done__ Make tweet bot available in cloud 62 | ## tags 63 | ### 0.0.3 64 | Refactoring and make sure tweet bot is working. 65 | Actually now the bot is working in vps (Ubuntu). 66 | 67 | ### 0.0.2 68 | beam search implemented. 69 | 70 | >おはよう 71 | normal:おはようございます 72 | beam 73 | 0 おはよう ござい ます 74 | 1 お は あり 75 | 2 お は あり です 〜 ♪ 76 | 77 | >こんにちは 78 | normal:はい(˘ω˘) 79 | beam 80 | 0 はい ( ˘ ω ˘ ) 81 | 1 はい ( ˘ ω ˘ ) スヤァ 82 | 2 はい ( ˙ㅿ˙ 。 . 83 | 3 はい ♡ 84 | 4 はい ( ´ ω 。 85 | 5 はい 、 さ www 86 | 6 はい ( 笑 87 | 88 | >ばいばいー 89 | わろきちってんじゃんwww 90 | normal:beam 91 | 0 がち やし ま ー ん 92 | 1 いや ま ー ! 93 | 2 わろ ぶ や ! 94 | 3 ほら 95 | 4 ネタ やし ぶ 96 | 5 ど ま ー 97 | 6 がち やし ま ーー 98 | 7 いつの間に ま ー 99 | 8 す 100 | 9 いつの間に ぶ 101 | 10 いつの間に やし ぶ うち 102 | 11 やらかし た ❤ 103 | 12 現実 やし 104 | 13 ほんま やし ぶ () 105 | 14 や ま ー 106 | 107 | >(月曜日から)逃げちゃ駄目だ……! 108 | normal;えぇこれは、、、 109 | beam 110 | 0 なんで 進捗 は これ じゃ ねぇ ・ ・ ω ! 111 | 1 え ぇ これ は 光 は 、 ! 112 | 2 え ぇ これ は 嫌 ) 113 | 3 なんで 進捗 おっ け ( ω ! 114 | 4 なんで 進捗 は これ じゃ ぞ 〜 115 | 116 | > 子供たちにつられて苦手なミニオンズ…(´・ω・`)w 117 | normal:気をしてねー(˘ω˘) 118 | beam 119 | 0 気 を し て ( ˘ つ ω -(´∀`; ) 120 | 1 気 を すん な ( ˙ ˘ ) 121 | 2 仕事 を すん や ( ˙ ω -(´∀`; ) 122 | 3 気 を し て ねー 。 ( ^ ー ` ・ ) 123 | 4 気 を し てる やろ ( ˙ ˘ ω ˘ ・) ! 124 | 5 気 を し てる やろ ( ˙ ˘ ω ˘ ω ・ ) 125 | 6 気 を し てる の だ よ ) 126 | 127 | > 中華そば醤油をいただきました💕お、おいしい〜😍大盛いけたかも? 128 | normal: 追加ですよねwww 129 | beam 130 | 0 追加 し まし た ☺ 131 | 1 追加 です よ ☺ 132 | 2 追加 です よ ね www 133 | ### 0.0.1 134 | Adam optimizer and summary op work well. 135 | 136 | global step 25000 learning rate 0.4522 perplexity 19.24 137 | eval: bucket 0 perplexity 4.63 138 | eval: bucket 1 perplexity 13.32 139 | eval: bucket 2 perplexity 27.23 140 | eval: bucket 3 perplexity 2.20 141 | > おはようー 142 | おはようございます 143 | > どうなの最近? 144 | これはいいの? 145 | > あほか。 146 | なにしwww 147 | > 君の名は 148 | まじでしょ? 149 | > まじかよ。 150 | とりあえず増えてないんですか 151 | > 最近映画見た? 152 | これ? 
153 | > いやちがうよ。それじゃないって。適当なこと言うなよ。 154 | _UNKかwww 155 | > うんこじゃない 156 | どんな撮www 157 | -------------------------------------------------------------------------------- /old/lib/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | import gzip 18 | import os 19 | import re 20 | import tarfile 21 | 22 | from six.moves import urllib 23 | 24 | from tensorflow.python.platform import gfile 25 | 26 | # Special vocabulary symbols - we always put them at the start. 27 | _PAD = "_PAD" 28 | _GO = "_GO" 29 | _EOS = "_EOS" 30 | _UNK = "_UNK" 31 | _START_VOCAB = [_PAD, _GO, _EOS, _UNK] 32 | 33 | PAD_ID = 0 34 | GO_ID = 1 35 | EOS_ID = 2 36 | UNK_ID = 3 37 | 38 | # Regular expressions used to tokenize. 39 | # _WORD_SPLIT = re.compile("([.,!?\"':;)(])") 40 | _WORD_SPLIT = re.compile("([.,!/?\":;)(])") 41 | _DIGIT_RE = re.compile(r"\d") 42 | 43 | 44 | 45 | def gunzip_file(gz_path, new_path): 46 | """Unzips from gz_path into new_path.""" 47 | print("Unpacking %s to %s" % (gz_path, new_path)) 48 | with gzip.open(gz_path, "rb") as gz_file: 49 | with open(new_path, "w") as new_file: 50 | for line in gz_file: 51 | new_file.write(line) 52 | 53 | 54 | def basic_tokenizer(sentence): 55 | """Very basic tokenizer: split the sentence into a list of tokens.""" 56 | words = [] 57 | for space_separated_fragment in sentence.strip().split(): 58 | words.extend(re.split(_WORD_SPLIT, space_separated_fragment)) 59 | return [w for w in words if w] 60 | 61 | 62 | def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size, 63 | tokenizer=None, normalize_digits=True): 64 | """Create vocabulary file (if it does not exist yet) from data file. 65 | 66 | Data file is assumed to contain one sentence per line. Each sentence is 67 | tokenized and digits are normalized (if normalize_digits is set). 68 | Vocabulary contains the most-frequent tokens up to max_vocabulary_size. 69 | We write it to vocabulary_path in a one-token-per-line format, so that later 70 | token in the first line gets id=0, second line gets id=1, and so on. 71 | 72 | Args: 73 | vocabulary_path: path where the vocabulary will be created. 74 | data_path: data file that will be used to create vocabulary. 75 | max_vocabulary_size: limit on the size of the created vocabulary. 76 | tokenizer: a function to use to tokenize each data sentence; 77 | if None, basic_tokenizer will be used. 78 | normalize_digits: Boolean; if true, all digits are replaced by 0s. 
79 | """ 80 | if not gfile.Exists(vocabulary_path): 81 | print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path)) 82 | vocab = {} 83 | with gfile.GFile(data_path, mode="r") as f: 84 | counter = 0 85 | for line in f: 86 | counter += 1 87 | if counter % 100000 == 0: 88 | print(" processing line %d" % counter) 89 | text_conversation =line.strip().lower().split("\t") 90 | if len(text_conversation) == 2: 91 | txt = text_conversation[0] + " " + text_conversation[1] 92 | tokens = tokenizer(txt) if tokenizer else basic_tokenizer(txt) 93 | for w in tokens: 94 | # word = re.sub(_DIGIT_RE, "0", w) if normalize_digits else w 95 | word = w 96 | if word in vocab: 97 | vocab[word] += 1 98 | else: 99 | vocab[word] = 1 100 | vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True) 101 | print(len(vocab_list)) 102 | if len(vocab_list) > max_vocabulary_size: 103 | vocab_list = vocab_list[:max_vocabulary_size] 104 | with gfile.GFile(vocabulary_path, mode="w") as vocab_file: 105 | for w in vocab_list: 106 | vocab_file.write(w + "\n") 107 | 108 | def initialize_vocabulary(vocabulary_path): 109 | """Initialize vocabulary from file. 110 | 111 | We assume the vocabulary is stored one-item-per-line, so a file: 112 | dog 113 | cat 114 | will result in a vocabulary {"dog": 0, "cat": 1}, and this function will 115 | also return the reversed-vocabulary ["dog", "cat"]. 116 | 117 | Args: 118 | vocabulary_path: path to the file containing the vocabulary. 119 | 120 | Returns: 121 | a pair: the vocabulary (a dictionary mapping string to integers), and 122 | the reversed vocabulary (a list, which reverses the vocabulary mapping). 123 | 124 | Raises: 125 | ValueError: if the provided vocabulary_path does not exist. 126 | """ 127 | if gfile.Exists(vocabulary_path): 128 | rev_vocab = [] 129 | with gfile.GFile(vocabulary_path, mode="r") as f: 130 | rev_vocab.extend(f.readlines()) 131 | rev_vocab = [line.strip() for line in rev_vocab] 132 | vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)]) 133 | return vocab, rev_vocab 134 | else: 135 | raise ValueError("Vocabulary file %s not found.", vocabulary_path) 136 | 137 | 138 | def sentence_to_token_ids(sentence, vocabulary, 139 | tokenizer=None, normalize_digits=True): 140 | """Convert a string to list of integers representing token-ids. 141 | 142 | For example, a sentence "I have a dog" may become tokenized into 143 | ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2, 144 | "a": 4, "dog": 7"} this function will return [1, 2, 4, 7]. 145 | 146 | Args: 147 | sentence: a string, the sentence to convert to token-ids. 148 | vocabulary: a dictionary mapping tokens to integers. 149 | tokenizer: a function to use to tokenize each sentence; 150 | if None, basic_tokenizer will be used. 151 | normalize_digits: Boolean; if true, all digits are replaced by 0s. 152 | 153 | Returns: 154 | a list of integers, the token-ids for the sentence. 155 | """ 156 | if tokenizer: 157 | words = tokenizer(sentence) 158 | else: 159 | words = basic_tokenizer(sentence) 160 | # if not normalize_digits: 161 | return [vocabulary.get(w, UNK_ID) for w in words] 162 | # Normalize digits by 0 before looking words up in the vocabulary. 163 | # return [vocabulary.get(re.sub(_DIGIT_RE, "0", w), UNK_ID) for w in words] 164 | 165 | 166 | def data_to_token_ids(data_path, target_path, vocabulary_path, 167 | tokenizer=None, normalize_digits=True): 168 | """Tokenize data file and turn into token-ids using given vocabulary file. 
169 | 170 | This function loads data line-by-line from data_path, calls the above 171 | sentence_to_token_ids, and saves the result to target_path. See comment 172 | for sentence_to_token_ids on the details of token-ids format. 173 | 174 | Args: 175 | data_path: path to the data file in one-sentence-per-line format. 176 | target_path: path where the file with token-ids will be created. 177 | vocabulary_path: path to the vocabulary file. 178 | tokenizer: a function to use to tokenize each sentence; 179 | if None, basic_tokenizer will be used. 180 | normalize_digits: Boolean; if true, all digits are replaced by 0s. 181 | """ 182 | if not gfile.Exists(target_path): 183 | print("Tokenizing data in %s" % data_path) 184 | vocab, _ = initialize_vocabulary(vocabulary_path) 185 | with gfile.GFile(data_path, mode="r") as data_file: 186 | with gfile.GFile(target_path, mode="w") as tokens_file: 187 | counter = 0 188 | for line in data_file: 189 | counter += 1 190 | if counter % 100000 == 0: 191 | print(" tokenizing line %d" % counter) 192 | token_ids = sentence_to_token_ids(line, vocab, tokenizer, 193 | normalize_digits) 194 | tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n") 195 | -------------------------------------------------------------------------------- /old/train.py: -------------------------------------------------------------------------------- 1 | import config 2 | import os 3 | import sys 4 | import math 5 | import random 6 | import numpy as np 7 | import tensorflow as tf 8 | import data_processer 9 | import lib.seq2seq_model as seq2seq_model 10 | 11 | 12 | def show_progress(text): 13 | sys.stdout.write(text) 14 | sys.stdout.flush() 15 | 16 | 17 | def read_data_into_buckets(enc_path, dec_path, buckets): 18 | """Read tweets and reply and put them into buckets based on their length 19 | 20 | Args: 21 | enc_path: path to indexed tweets 22 | dec_path: path to indexed replies 23 | buckets: list of bucket 24 | 25 | Returns: 26 | data_set: data_set[i] has [tweet, reply] pairs for bucket[i] 27 | """ 28 | # data_set[i] corresponds data for buckets[i] 29 | data_set = [[] for _ in buckets] 30 | with tf.gfile.GFile(enc_path, mode="r") as ef, tf.gfile.GFile(dec_path, mode="r") as df: 31 | tweet, reply = ef.readline(), df.readline() 32 | counter = 0 33 | while tweet and reply: 34 | counter += 1 35 | if counter % 100000 == 0: 36 | print(" reading data line %d" % counter) 37 | sys.stdout.flush() 38 | source_ids = [int(x) for x in tweet.split()] 39 | target_ids = [int(x) for x in reply.split()] 40 | target_ids.append(data_processer.EOS_ID) 41 | for bucket_id, (source_size, target_size) in enumerate(buckets): 42 | # Find bucket to put this conversation based on tweet and reply length 43 | if len(source_ids) < source_size and len(target_ids) < target_size: 44 | data_set[bucket_id].append([source_ids, target_ids]) 45 | break 46 | tweet, reply = ef.readline(), df.readline() 47 | for bucket_id in range(len(buckets)): 48 | print("{}={}=".format(buckets[bucket_id], len(data_set[bucket_id]))) 49 | return data_set 50 | 51 | 52 | # Originally from https://github.com/1228337123/tensorflow-seq2seq-chatbot 53 | def create_or_restore_model(session, buckets, forward_only, beam_search, beam_size): 54 | 55 | # beam search is off for training 56 | """Create model and initialize or load parameters""" 57 | 58 | model = seq2seq_model.Seq2SeqModel(source_vocab_size=config.MAX_ENC_VOCABULARY, 59 | target_vocab_size=config.MAX_DEC_VOCABULARY, 60 | buckets=buckets, 61 | size=config.LAYER_SIZE, 62 | 
num_layers=config.NUM_LAYERS, 63 | max_gradient_norm=config.MAX_GRADIENT_NORM, 64 | batch_size=config.BATCH_SIZE, 65 | learning_rate=config.LEARNING_RATE, 66 | learning_rate_decay_factor=config.LEARNING_RATE_DECAY_FACTOR, 67 | beam_search=beam_search, 68 | attention=True, 69 | forward_only=forward_only, 70 | beam_size=beam_size) 71 | 72 | print("model initialized") 73 | ckpt = tf.train.get_checkpoint_state(config.GENERATED_DIR) 74 | # the checkpoint filename has changed in recent versions of tensorflow 75 | checkpoint_suffix = ".index" 76 | if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path + checkpoint_suffix): 77 | print("Reading model parameters from %s" % ckpt.model_checkpoint_path) 78 | model.saver.restore(session, ckpt.model_checkpoint_path) 79 | else: 80 | print("Created model with fresh parameters.") 81 | session.run(tf.global_variables_initializer()) 82 | return model 83 | 84 | 85 | def next_random_bucket_id(buckets_scale): 86 | n = np.random.random_sample() 87 | bucket_id = min([i for i in range(len(buckets_scale)) if buckets_scale[i] > n]) 88 | return bucket_id 89 | 90 | 91 | def train(): 92 | # Only allocate 2/3 of the gpu memory to allow for running gpu-based predictions while training: 93 | # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.666) 94 | # tf_config = tf.ConfigProto(gpu_options=gpu_options) 95 | # tf_config.gpu_options.allocator_type = 'BFC' 96 | 97 | #with tf.Session(config=tf_config) as sess: 98 | with tf.Session() as sess: 99 | 100 | show_progress("Setting up data set for each buckets...") 101 | train_set = read_data_into_buckets(config.TWEETS_TRAIN_ENC_IDX_TXT, config.TWEETS_TRAIN_DEC_IDX_TXT, config.buckets) 102 | valid_set = read_data_into_buckets(config.TWEETS_VAL_ENC_IDX_TXT, config.TWEETS_VAL_DEC_IDX_TXT, config.buckets) 103 | show_progress("done\n") 104 | 105 | show_progress("Creating model...") 106 | # False for train 107 | beam_search = False 108 | model = create_or_restore_model(sess, config.buckets, forward_only=False, beam_search=beam_search, beam_size=config.beam_size) 109 | 110 | show_progress("done\n") 111 | 112 | # list of # of data in ith bucket 113 | train_bucket_sizes = [len(train_set[b]) for b in range(len(config.buckets))] 114 | train_total_size = float(sum(train_bucket_sizes)) 115 | 116 | # Originally from https://github.com/1228337123/tensorflow-seq2seq-chatbot 117 | # This is for choosing randomly bucket based on distribution 118 | train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size 119 | for i in range(len(train_bucket_sizes))] 120 | 121 | show_progress("before train loop") 122 | # Train Loop 123 | steps = 0 124 | previous_perplexities = [] 125 | writer = tf.summary.FileWriter(config.LOGS_DIR, sess.graph) 126 | 127 | while True: 128 | bucket_id = next_random_bucket_id(train_buckets_scale) 129 | # print(bucket_id) 130 | 131 | # Get batch 132 | encoder_inputs, decoder_inputs, target_weights = model.get_batch(train_set, bucket_id) 133 | # show_progress("Training bucket_id={0}...".format(bucket_id)) 134 | 135 | # Train! 
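# model.step runs one forward/backward pass on a batch from the chosen bucket (forward_only=False) and returns the gradient norm, the average loss (used to compute perplexity below), a TensorBoard summary, and the model outputs.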
136 | # _, average_perplexity, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, 137 | # bucket_id, 138 | # forward_only=False, 139 | # beam_search=beam_search) 140 | _, average_perplexity, summary, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, 141 | bucket_id, 142 | forward_only=False, 143 | beam_search=beam_search) 144 | 145 | # show_progress("done {0}\n".format(average_perplexity)) 146 | 147 | steps = steps + 1 148 | if steps % 2 == 0: 149 | writer.add_summary(summary, steps) 150 | show_progress(".") 151 | if steps % 50 != 0: 152 | continue 153 | 154 | # check point 155 | checkpoint_path = os.path.join(config.GENERATED_DIR, "seq2seq.ckpt") 156 | show_progress("Saving checkpoint...") 157 | model.saver.save(sess, checkpoint_path, global_step=model.global_step) 158 | show_progress("done\n") 159 | 160 | perplexity = math.exp(average_perplexity) if average_perplexity < 300 else float('inf') 161 | print ("global step %d learning rate %.4f perplexity " 162 | "%.2f" % (model.global_step.eval(), model.learning_rate.eval(), perplexity)) 163 | 164 | # Decrease learning rate if no improvement was seen over last 3 times. 165 | if len(previous_perplexities) > 2 and perplexity > max(previous_perplexities[-3:]): 166 | sess.run(model.learning_rate_decay_op) 167 | previous_perplexities.append(perplexity) 168 | 169 | for bucket_id in range(len(config.buckets)): 170 | if len(valid_set[bucket_id]) == 0: 171 | print(" eval: empty bucket %d" % bucket_id) 172 | continue 173 | encoder_inputs, decoder_inputs, target_weights = model.get_batch(valid_set, bucket_id) 174 | # _, average_perplexity, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True, beam_search=beam_search) 175 | _, average_perplexity, valid_summary, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True, beam_search=beam_search) 176 | writer.add_summary(valid_summary, steps) 177 | eval_ppx = math.exp(average_perplexity) if average_perplexity < 300 else float('inf') 178 | print(" eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx)) 179 | 180 | if __name__ == '__main__': 181 | train() 182 | -------------------------------------------------------------------------------- /old/data_processer.py: -------------------------------------------------------------------------------- 1 | import re 2 | import tensorflow as tf 3 | import config 4 | # For Japanese tokenizer 5 | import MeCab 6 | from tensorflow.python.platform import gfile 7 | 8 | # The data format 9 | # 10 | # (A) data/tweets.txt 11 | # You have to parepare this file by yourself. 12 | # This file has many raw tweet and reply pairs. Odd lines are tweets and even lines are replies. 13 | # example) 14 | # Line 1: Hey how are you doing? 15 | # Line 2: @higepon doing good. 16 | # 17 | # Following files are generated by data_processer.py for training. 18 | # 19 | # (B) generated/tweets_enc.txt 20 | # Each lines consists of one tweet, @username and URL are removed. 21 | # 22 | # (C) generated/tweets_dec.txt 23 | # Each lines consists of one reply, @username and URL are removed. 24 | # 25 | # (D) generated/tweets_train_[enc|dec].txt 26 | # Tweets or replies train data 27 | # 28 | # (E) generated/tweets_val_[enc|dec].txt 29 | # Tweets or replies validation data 30 | # 31 | # (F) generated/vocab_enc.txt 32 | # Vocabulary for tweets. 33 | # Words in frequency order 34 | # 35 | # (G) generated/vocab_dec.txt 36 | # Vocabulary for replies. 
37 | # Words in frequency order 38 | # 39 | # (H) generated/tweets_[train|val]_[dec|enc]_idx.txt 40 | # Generated from tweets_[train|val]_[enc|dec].txt. 41 | # All words in the source file are replaced idx to the word. 42 | # 43 | 44 | import sys 45 | 46 | TWEETS_ENC_TXT = "{0}/tweets_enc.txt".format(config.GENERATED_DIR) 47 | TWEETS_DEC_TXT = "{0}/tweets_dec.txt".format(config.GENERATED_DIR) 48 | 49 | TWEETS_TRAIN_ENC_TXT = "{0}/tweets_train_enc.txt".format(config.GENERATED_DIR) 50 | TWEETS_TRAIN_DEC_TXT = "{0}/tweets_train_dec.txt".format(config.GENERATED_DIR) 51 | 52 | TWEETS_VAL_ENC_TXT = "{0}/tweets_val_enc.txt".format(config.GENERATED_DIR) 53 | TWEETS_VAL_DEC_TXT = "{0}/tweets_val_dec.txt".format(config.GENERATED_DIR) 54 | TWEETS_VAL_ENC_IDX_TXT = "{0}/tweets_val_enc_idx.txt".format(config.GENERATED_DIR) 55 | TWEETS_VAL_DEC_IDX_TXT = "{0}/tweets_val_dec_idx.txt".format(config.GENERATED_DIR) 56 | 57 | DIGIT_RE = re.compile(r"\d") 58 | 59 | _PAD = "_PAD" 60 | _GO = "_GO" 61 | _EOS = "_EOS" 62 | _UNK = "_UNK" 63 | _START_VOCAB = [_PAD, _GO, _EOS, _UNK] 64 | 65 | PAD_ID = 0 66 | GO_ID = 1 67 | EOS_ID = 2 68 | UNK_ID = 3 69 | 70 | tagger = MeCab.Tagger("-Owakati") 71 | 72 | 73 | def japanese_tokenizer(sentence): 74 | assert type(sentence) is str 75 | # Mecab doesn't accept binary, but Python string (utf-8). 76 | result = tagger.parse(sentence) 77 | return result.split() 78 | 79 | 80 | def split_tweets_replies(tweets_path, enc_path, dec_path): 81 | """Read data from tweets_paths and split it to tweets and replies. 82 | 83 | Args: 84 | tweets_path: original tweets data 85 | enc_path: path to write tweets 86 | dec_path: path to write replies 87 | 88 | Returns: 89 | None 90 | """ 91 | i = 1 92 | with gfile.GFile(tweets_path, mode="rb") as f, gfile.GFile(enc_path, mode="w+") as ef, gfile.GFile(dec_path, 93 | mode="w+") as df: 94 | for line in f: 95 | if not isinstance(line, str): 96 | line = line.decode('utf-8') 97 | line = sanitize_line(line) 98 | 99 | # Odd lines are tweets 100 | if i % 2 == 1: 101 | ef.write(line) 102 | # Even lines are replies 103 | else: 104 | df.write(line) 105 | i = i + 1 106 | 107 | 108 | def sanitize_line(line): 109 | # Remove @username 110 | line = re.sub(r"@([A-Za-z0-9_]+)", "", line) 111 | # Remove URL 112 | line = re.sub(r'https?:\/\/.*', "", line) 113 | line = re.sub(DIGIT_RE, "0", line) 114 | return line 115 | 116 | 117 | def num_lines(file): 118 | """Return # of lines in file 119 | 120 | Args: 121 | file: Target file. 
122 | 123 | Returns: 124 | # of lines in file 125 | """ 126 | return sum(1 for _ in open(file)) 127 | 128 | 129 | def create_train_validation(source_path, train_path, validation_path, train_ratio=0.75): 130 | """Split source file into train and validation data 131 | 132 | Args: 133 | source_path: source file path 134 | train_path: Path to write train data 135 | validation_path: Path to write validatio data 136 | train_ratio: Train data ratio 137 | 138 | Returns: 139 | None 140 | """ 141 | nb_lines = num_lines(source_path) 142 | nb_train = int(nb_lines * train_ratio) 143 | counter = 0 144 | with gfile.GFile(source_path, "r") as f, gfile.GFile(train_path, "w") as tf, gfile.GFile(validation_path, 145 | "w") as vf: 146 | for line in f: 147 | if counter < nb_train: 148 | tf.write(line) 149 | else: 150 | vf.write(line) 151 | counter = counter + 1 152 | 153 | 154 | # Originally from https://github.com/1228337123/tensorflow-seq2seq-chatbot 155 | def sentence_to_token_ids(sentence, vocabulary, tokenizer=japanese_tokenizer, normalize_digits=True): 156 | if tokenizer: 157 | words = tokenizer(sentence) 158 | else: 159 | words = basic_tokenizer(sentence) 160 | if not normalize_digits: 161 | return [vocabulary.get(w, UNK_ID) for w in words] 162 | # Normalize digits by 0 before looking words up in the vocabulary. 163 | # return [vocabulary.get(re.sub(_DIGIT_RE, b"0", w), UNK_ID) for w in words] #mark added .decode by Ken 164 | return [vocabulary.get(w, UNK_ID) for w in words] # added by Ken 165 | 166 | 167 | # Originally from https://github.com/1228337123/tensorflow-seq2seq-chatbot 168 | def data_to_token_ids(data_path, target_path, vocabulary_path, 169 | tokenizer=japanese_tokenizer, normalize_digits=True): 170 | if not gfile.Exists(target_path): 171 | print("Tokenizing data in %s" % data_path) 172 | vocab, _ = initialize_vocabulary(vocabulary_path) 173 | with gfile.GFile(data_path, mode="rb") as data_file: 174 | with gfile.GFile(target_path, mode="wb") as tokens_file: # edit w to wb 175 | counter = 0 176 | for line in data_file: 177 | # line = tf.compat.as_bytes(line) # added by Ken 178 | counter += 1 179 | if counter % 100000 == 0: 180 | print(" tokenizing line %d" % counter) 181 | # line is binary here 182 | line = line.decode('utf-8') 183 | token_ids = sentence_to_token_ids(line, vocab, tokenizer, 184 | normalize_digits) 185 | tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n") 186 | 187 | 188 | # Originally from https://github.com/1228337123/tensorflow-seq2seq-chatbot 189 | def initialize_vocabulary(vocabulary_path): 190 | if gfile.Exists(vocabulary_path): 191 | rev_vocab = [] 192 | with gfile.GFile(vocabulary_path, mode="r") as f: 193 | rev_vocab.extend(f.readlines()) 194 | rev_vocab = [line.strip() for line in rev_vocab] 195 | # Dictionary of (word, idx) 196 | vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)]) 197 | return vocab, rev_vocab 198 | else: 199 | raise ValueError("Vocabulary file %s not found.", vocabulary_path) 200 | 201 | 202 | # From https://github.com/1228337123/tensorflow-seq2seq-chatbot 203 | def create_vocabulary(source_path, vocabulary_path, max_vocabulary_size, tokenizer=japanese_tokenizer): 204 | """Create vocabulary file. 
Please see comments in head for file format 205 | 206 | Args: 207 | source_path: source file path 208 | vocabulary_path: Path to write vocabulary 209 | max_vocabulary_size: Max vocabulary size 210 | tokenizer: tokenizer used for tokenize each lines 211 | 212 | Returns: 213 | None 214 | """ 215 | if gfile.Exists(vocabulary_path): 216 | print("Found vocabulary file") 217 | return 218 | with gfile.GFile(source_path, mode="r") as f: 219 | counter = 0 220 | vocab = {} # (word, word_freq) 221 | for line in f: 222 | counter += 1 223 | words = tokenizer(line) 224 | if counter % 5000 == 0: 225 | sys.stdout.write(".") 226 | sys.stdout.flush() 227 | for word in words: 228 | # Normalize numbers. Not sure if it's necessary. 229 | word = re.sub(DIGIT_RE, "0", word) 230 | if word in vocab: 231 | vocab[word] += 1 232 | else: 233 | vocab[word] = 1 234 | vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True) 235 | if len(vocab_list) > max_vocabulary_size: 236 | vocab_list = vocab_list[:max_vocabulary_size] 237 | with gfile.GFile(vocabulary_path, mode="w") as vocab_file: 238 | for w in vocab_list: 239 | vocab_file.write(w + "\n") 240 | print("\n") 241 | 242 | 243 | if __name__ == '__main__': 244 | print("Splitting into tweets and replies...") 245 | split_tweets_replies(config.TWEETS_TXT, TWEETS_ENC_TXT, TWEETS_DEC_TXT) 246 | print("Done") 247 | 248 | print("Splitting into train and validation data...") 249 | create_train_validation(TWEETS_ENC_TXT, TWEETS_TRAIN_ENC_TXT, TWEETS_VAL_ENC_TXT) 250 | create_train_validation(TWEETS_DEC_TXT, TWEETS_TRAIN_DEC_TXT, TWEETS_VAL_DEC_TXT) 251 | print("Done") 252 | 253 | print("Creating vocabulary files...") 254 | create_vocabulary(TWEETS_ENC_TXT, config.VOCAB_ENC_TXT, config.MAX_ENC_VOCABULARY) 255 | create_vocabulary(TWEETS_DEC_TXT, config.VOCAB_DEC_TXT, config.MAX_DEC_VOCABULARY) 256 | print("Done") 257 | 258 | print("Creating sentence idx files...") 259 | data_to_token_ids(TWEETS_TRAIN_ENC_TXT, config.TWEETS_TRAIN_ENC_IDX_TXT, config.VOCAB_ENC_TXT) 260 | data_to_token_ids(TWEETS_TRAIN_DEC_TXT, config.TWEETS_TRAIN_DEC_IDX_TXT, config.VOCAB_DEC_TXT) 261 | data_to_token_ids(TWEETS_VAL_ENC_TXT, TWEETS_VAL_ENC_IDX_TXT, config.VOCAB_ENC_TXT) 262 | data_to_token_ids(TWEETS_VAL_DEC_TXT, TWEETS_VAL_DEC_IDX_TXT, config.VOCAB_DEC_TXT) 263 | print("Done") 264 | -------------------------------------------------------------------------------- /old/lib/seq2seq_model.py: -------------------------------------------------------------------------------- 1 | 2 | """Sequence-to-sequence model with an attention mechanism.""" 3 | 4 | 5 | import random 6 | 7 | import numpy as np 8 | from six.moves import xrange # pylint: disable=redefined-builtin 9 | import tensorflow as tf 10 | 11 | from lib.data_utils import * 12 | from lib.my_seq2seq import * 13 | 14 | class Seq2SeqModel(object): 15 | """Sequence-to-sequence model with attention and for multiple buckets. 16 | 17 | This class implements a multi-layer recurrent neural network as encoder, 18 | and an attention-based decoder. This is the same as the model described in 19 | this paper: http://arxiv.org/abs/1412.7449 - please look there for details, 20 | or into the seq2seq library for complete model implementation. 21 | This class also allows to use GRU cells in addition to LSTM cells, and 22 | sampled softmax to handle large output vocabulary size. 
A single-layer 23 | version of this model, but with bi-directional encoder, was presented in 24 | http://arxiv.org/abs/1409.0473 25 | and sampled softmax is described in Section 3 of the following paper. 26 | http://arxiv.org/abs/1412.2007 27 | """ 28 | 29 | def __init__(self, source_vocab_size, target_vocab_size, buckets, size, 30 | num_layers, max_gradient_norm, batch_size, learning_rate, 31 | learning_rate_decay_factor, use_lstm=False, 32 | num_samples=1024, forward_only=False, beam_search = True, beam_size=10, attention=True): 33 | """Create the model. 34 | 35 | Args: 36 | source_vocab_size: size of the source vocabulary. 37 | target_vocab_size: size of the target vocabulary. 38 | buckets: a list of pairs (I, O), where I specifies maximum input length 39 | that will be processed in that bucket, and O specifies maximum output 40 | length. Training instances that have inputs longer than I or outputs 41 | longer than O will be pushed to the next bucket and padded accordingly. 42 | We assume that the list is sorted, e.g., [(2, 4), (8, 16)]. 43 | size: number of units in each layer of the model. 44 | num_layers: number of layers in the model. 45 | max_gradient_norm: gradients will be clipped to maximally this norm. 46 | batch_size: the size of the batches used during training; 47 | the model construction is independent of batch_size, so it can be 48 | changed after initialization if this is convenient, e.g., for decoding. 49 | learning_rate: learning rate to start with. 50 | learning_rate_decay_factor: decay learning rate by this much when needed. 51 | use_lstm: if true, we use LSTM cells instead of GRU cells. 52 | num_samples: number of samples for sampled softmax. 53 | forward_only: if set, we do not construct the backward pass in the model. 54 | """ 55 | self.source_vocab_size = source_vocab_size 56 | self.target_vocab_size = target_vocab_size 57 | self.buckets = buckets 58 | self.batch_size = batch_size 59 | self.learning_rate = tf.Variable(float(learning_rate), trainable=False) 60 | self.learning_rate_decay_op = self.learning_rate.assign( 61 | self.learning_rate * learning_rate_decay_factor) 62 | self.global_step = tf.Variable(0, trainable=False) 63 | 64 | # If we use sampled softmax, we need an output projection. 65 | output_projection = None 66 | softmax_loss_function = None 67 | # Sampled softmax only makes sense if we sample less than vocabulary size. 68 | if num_samples > 0 and num_samples < self.target_vocab_size: 69 | with tf.device("/cpu:0"): 70 | w = tf.get_variable("proj_w", [size, self.target_vocab_size]) 71 | w_t = tf.transpose(w) 72 | b = tf.get_variable("proj_b", [self.target_vocab_size]) 73 | output_projection = (w, b) 74 | 75 | def sampled_loss(inputs, labels): 76 | with tf.device("/cpu:0"): 77 | labels = tf.reshape(labels, [-1, 1]) 78 | return tf.nn.sampled_softmax_loss(w_t, b, labels, inputs, num_samples, 79 | self.target_vocab_size) 80 | softmax_loss_function = sampled_loss 81 | # Create the internal multi-layer cell for our RNN. 
82 | print('###### tf.get_variable_scope().reuse : {}'.format(tf.get_variable_scope().reuse)) 83 | def gru_cell(): 84 | return tf.contrib.rnn.core_rnn_cell.GRUCell(size, reuse=tf.get_variable_scope().reuse)#tf.get_variable_scope().reuse 85 | def lstm_cell(): 86 | return tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(size, reuse=tf.get_variable_scope().reuse)#tf.get_variable_scope().reuse 87 | single_cell = gru_cell 88 | if use_lstm: 89 | single_cell = lstm_cell 90 | cell = single_cell() 91 | if num_layers > 1: 92 | cell_1 = tf.contrib.rnn.core_rnn_cell.MultiRNNCell([single_cell() for _ in range(num_layers)], state_is_tuple=False) 93 | cell_2 = tf.contrib.rnn.core_rnn_cell.MultiRNNCell([single_cell() for _ in range(num_layers)], state_is_tuple=False) 94 | 95 | # The seq2seq function: we use embedding for the input and attention. 96 | print('##### num_layers: {} #####'.format(num_layers)) 97 | print('##### {} #####'.format(output_projection)) 98 | def seq2seq_f(encoder_inputs, decoder_inputs, do_decode): 99 | if attention: 100 | print("Attention Model") 101 | return embedding_attention_seq2seq( 102 | encoder_inputs, decoder_inputs, cell_1, cell_2, 103 | num_encoder_symbols=source_vocab_size, 104 | num_decoder_symbols=target_vocab_size, 105 | embedding_size=size, 106 | output_projection=output_projection, 107 | feed_previous=do_decode, 108 | beam_search=beam_search, 109 | beam_size=beam_size ) 110 | else: 111 | print("Simple Model") 112 | return embedding_rnn_seq2seq( 113 | encoder_inputs, decoder_inputs, cell, 114 | num_encoder_symbols=source_vocab_size, 115 | num_decoder_symbols=target_vocab_size, 116 | embedding_size=size, 117 | output_projection=output_projection, 118 | feed_previous=do_decode, 119 | beam_search=beam_search, 120 | beam_size=beam_size ) 121 | 122 | 123 | # Feeds for inputs. 124 | self.encoder_inputs = [] 125 | self.decoder_inputs = [] 126 | self.target_weights = [] 127 | for i in xrange(buckets[-1][0]): # Last bucket is the biggest one. 128 | self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None], 129 | name="encoder{0}".format(i))) 130 | for i in xrange(buckets[-1][1] + 1): 131 | self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None], 132 | name="decoder{0}".format(i))) 133 | self.target_weights.append(tf.placeholder(tf.float32, shape=[None], 134 | name="weight{0}".format(i))) 135 | 136 | # Our targets are decoder inputs shifted by one. 137 | targets = [self.decoder_inputs[i + 1] 138 | for i in xrange(len(self.decoder_inputs) - 1)] 139 | 140 | # Training outputs and losses. 141 | if forward_only: 142 | if beam_search: 143 | self.losses = [] 144 | self.outputs, self.beam_path, self.beam_symbol = decode_model_with_buckets( 145 | self.encoder_inputs, self.decoder_inputs, targets, 146 | self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, True), 147 | softmax_loss_function=softmax_loss_function) 148 | else: 149 | # print self.decoder_inputs 150 | self.outputs, self.losses = model_with_buckets( 151 | self.encoder_inputs, self.decoder_inputs, targets, 152 | self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, True), 153 | softmax_loss_function=softmax_loss_function) 154 | # If we use output projection, we need to project outputs for decoding. 
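      # With sampled softmax the decoder emits vectors of size `size`, not full
      # vocabulary logits, so each output is mapped to [batch_size, target_vocab_size]
      # as output @ proj_w + proj_b before decoding. The beam-search branch above
      # skips this because its loop_function already applies the projection before
      # taking the top-k symbols.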
155 | if output_projection is not None: 156 | for b in xrange(len(buckets)): 157 | self.outputs[b] = [ 158 | tf.matmul(output, output_projection[0]) + output_projection[1] 159 | for output in self.outputs[b] 160 | ] 161 | 162 | 163 | else: 164 | self.outputs, self.losses = model_with_buckets( 165 | self.encoder_inputs, self.decoder_inputs, targets, 166 | self.target_weights, buckets, 167 | lambda x, y: seq2seq_f(x, y, False), 168 | softmax_loss_function=softmax_loss_function) 169 | 170 | self.train_loss_summaries = [] 171 | self.forward_only_loss_summaries = [] 172 | for i in range(len(self.losses)): 173 | self.train_loss_summaries.append(tf.summary.scalar("train_loss_bucket_{}".format(i), self.losses[i])) 174 | self.forward_only_loss_summaries.append(tf.summary.scalar("forward_only_loss_bucket_{}".format(i), 175 | self.losses[i])) 176 | 177 | # Gradients and SGD update operation for training the model. 178 | params = tf.trainable_variables() 179 | if not forward_only: 180 | self.gradient_norms = [] 181 | self.updates = [] 182 | opt = tf.train.GradientDescentOptimizer(self.learning_rate) 183 | for b in xrange(len(buckets)): 184 | gradients = tf.gradients(self.losses[b], params) 185 | clipped_gradients, norm = tf.clip_by_global_norm(gradients, 186 | max_gradient_norm) 187 | self.gradient_norms.append(norm) 188 | self.updates.append(opt.apply_gradients( 189 | zip(clipped_gradients, params), global_step=self.global_step)) 190 | 191 | self.saver = tf.train.Saver(tf.global_variables()) 192 | 193 | def step(self, session, encoder_inputs, decoder_inputs, target_weights, 194 | bucket_id, forward_only, beam_search): 195 | """Run a step of the model feeding the given inputs. 196 | 197 | Args: 198 | session: tensorflow session to use. 199 | encoder_inputs: list of numpy int vectors to feed as encoder inputs. 200 | decoder_inputs: list of numpy int vectors to feed as decoder inputs. 201 | target_weights: list of numpy float vectors to feed as target weights. 202 | bucket_id: which bucket of the model to use. 203 | forward_only: whether to do the backward step or only forward. 204 | 205 | Returns: 206 | A triple consisting of gradient norm (or None if we did not do backward), 207 | average perplexity, and the outputs. 208 | 209 | Raises: 210 | ValueError: if length of encoder_inputs, decoder_inputs, or 211 | target_weights disagrees with bucket size for the specified bucket_id. 212 | """ 213 | # Check if the sizes match. 214 | encoder_size, decoder_size = self.buckets[bucket_id] 215 | if len(encoder_inputs) != encoder_size: 216 | raise ValueError("Encoder length must be equal to the one in bucket," 217 | " %d != %d." % (len(encoder_inputs), encoder_size)) 218 | if len(decoder_inputs) != decoder_size: 219 | raise ValueError("Decoder length must be equal to the one in bucket," 220 | " %d != %d." % (len(decoder_inputs), decoder_size)) 221 | if len(target_weights) != decoder_size: 222 | raise ValueError("Weights length must be equal to the one in bucket," 223 | " %d != %d." % (len(target_weights), decoder_size)) 224 | 225 | # Input feed: encoder inputs, decoder inputs, target_weights, as provided. 226 | input_feed = {} 227 | for l in xrange(encoder_size): 228 | input_feed[self.encoder_inputs[l].name] = encoder_inputs[l] 229 | for l in xrange(decoder_size): 230 | input_feed[self.decoder_inputs[l].name] = decoder_inputs[l] 231 | input_feed[self.target_weights[l].name] = target_weights[l] 232 | 233 | # Since our targets are decoder inputs shifted by one, we need one more. 
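    # Example: with decoder_size = 10, targets are decoder_inputs[1..10], so the
    # placeholder for decoder_inputs[10], which was never filled above, is fed an
    # all-zero batch vector just to satisfy the graph; get_batch zeroes the weight
    # for that final position anyway.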
234 | last_target = self.decoder_inputs[decoder_size].name 235 | input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32) 236 | 237 | # Output feed: depends on whether we do a backward step or not. 238 | if not forward_only: 239 | output_feed = [self.updates[bucket_id], # Update Op that does SGD. 240 | self.gradient_norms[bucket_id], # Gradient norm. 241 | self.losses[bucket_id], # Loss for this batch. 242 | self.train_loss_summaries[bucket_id]] # Summary op for Train Loss 243 | else: 244 | if beam_search: 245 | output_feed = [self.beam_path[bucket_id]] # Loss for this batch. 246 | output_feed.append(self.beam_symbol[bucket_id]) 247 | else: 248 | output_feed = [self.losses[bucket_id], 249 | self.forward_only_loss_summaries[bucket_id]] # Summary op for forward only loss 250 | 251 | for l in xrange(decoder_size): # Output logits. 252 | output_feed.append(self.outputs[bucket_id][l]) 253 | # print bucket_id 254 | outputs = session.run(output_feed, input_feed) 255 | if not forward_only: 256 | return outputs[1], outputs[2], outputs[3], None # Gradient norm, loss, no outputs. 257 | else: 258 | if beam_search: 259 | return outputs[0], outputs[1], outputs[2:] # No gradient norm, loss, outputs. 260 | else: 261 | return None, outputs[0], outputs[1], outputs[2:] # No gradient norm, loss, outputs. 262 | 263 | def get_batch(self, data, bucket_id): 264 | """Get a random batch of data from the specified bucket, prepare for step. 265 | 266 | To feed data in step(..) it must be a list of batch-major vectors, while 267 | data here contains single length-major cases. So the main logic of this 268 | function is to re-index data cases to be in the proper format for feeding. 269 | 270 | Args: 271 | data: a tuple of size len(self.buckets) in which each element contains 272 | lists of pairs of input and output data that we use to create a batch. 273 | bucket_id: integer, which bucket to get the batch for. 274 | 275 | Returns: 276 | The triple (encoder_inputs, decoder_inputs, target_weights) for 277 | the constructed batch that has the proper format to call step(...) later. 278 | """ 279 | encoder_size, decoder_size = self.buckets[bucket_id] 280 | encoder_inputs, decoder_inputs = [], [] 281 | 282 | # Get a random batch of encoder and decoder inputs from data, 283 | # pad them if needed, reverse encoder inputs and add GO to decoder. 284 | for _ in xrange(self.batch_size): 285 | encoder_input, decoder_input = random.choice(data[bucket_id]) 286 | 287 | # Encoder inputs are padded and then reversed. 288 | encoder_pad = [PAD_ID] * (encoder_size - len(encoder_input)) 289 | encoder_inputs.append(list(reversed(encoder_input + encoder_pad))) 290 | 291 | # Decoder inputs get an extra "GO" symbol, and are padded then. 292 | decoder_pad_size = decoder_size - len(decoder_input) - 1 293 | decoder_inputs.append([GO_ID] + decoder_input + 294 | [PAD_ID] * decoder_pad_size) 295 | 296 | # Now we create batch-major vectors from the data selected above. 297 | batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], [] 298 | 299 | # Batch encoder inputs are just re-indexed encoder_inputs. 300 | for length_idx in xrange(encoder_size): 301 | batch_encoder_inputs.append( 302 | np.array([encoder_inputs[batch_idx][length_idx] 303 | for batch_idx in xrange(self.batch_size)], dtype=np.int32)) 304 | 305 | # Batch decoder inputs are re-indexed decoder_inputs, we create weights. 
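    # Worked example with decoder_size = 5 and a decoder row [GO, 7, 8, PAD, PAD]:
    # the shifted targets are [7, 8, PAD, PAD, -], so the per-step weights come out
    # as [1, 1, 0, 0, 0] and padded positions (plus the final step) contribute
    # nothing to the loss.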
306 | for length_idx in xrange(decoder_size): 307 | batch_decoder_inputs.append( 308 | np.array([decoder_inputs[batch_idx][length_idx] 309 | for batch_idx in xrange(self.batch_size)], dtype=np.int32)) 310 | 311 | # Create target_weights to be 0 for targets that are padding. 312 | batch_weight = np.ones(self.batch_size, dtype=np.float32) 313 | for batch_idx in xrange(self.batch_size): 314 | # We set weight to 0 if the corresponding target is a PAD symbol. 315 | # The corresponding target is decoder_input shifted by 1 forward. 316 | if length_idx < decoder_size - 1: 317 | target = decoder_inputs[batch_idx][length_idx + 1] 318 | if length_idx == decoder_size - 1 or target == PAD_ID: 319 | batch_weight[batch_idx] = 0.0 320 | batch_weights.append(batch_weight) 321 | return batch_encoder_inputs, batch_decoder_inputs, batch_weights -------------------------------------------------------------------------------- /old/lib/my_seq2seq.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | """Library for creating sequence-to-sequence models in TensorFlow. 4 | 5 | Sequence-to-sequence recurrent neural networks can learn complex functions 6 | that map input sequences to output sequences. These models yield very good 7 | results on a number of tasks, such as speech recognition, parsing, machine 8 | translation, or even constructing automated replies to emails. 9 | 10 | 11 | * Full sequence-to-sequence models. 12 | 13 | - embedding_rnn_seq2seq: The basic model with input embedding. 14 | - embedding_attention_seq2seq: Advanced model with input embedding and 15 | the neural attention mechanism; recommended for complex tasks. 16 | 17 | 18 | * Decoders 19 | - rnn_decoder: The basic decoder based on a pure RNN. 20 | - attention_decoder: A decoder that uses the attention mechanism. 21 | 22 | * Losses. 23 | - sequence_loss: Loss for a sequence model returning average log-perplexity. 24 | - sequence_loss_by_example: As above, but not averaging over all examples. 25 | 26 | * model_with_buckets: A convenience function to create models with bucketing 27 | (see the tutorial above for an explanation of why and how to use it). 28 | """ 29 | 30 | from six.moves import xrange # pylint: disable=redefined-builtin 31 | from six.moves import zip # pylint: disable=redefined-builtin 32 | 33 | from tensorflow.python.framework import dtypes 34 | from tensorflow.python.framework import ops 35 | from tensorflow.python.ops import array_ops 36 | from tensorflow.python.ops import control_flow_ops 37 | from tensorflow.python.ops import embedding_ops 38 | from tensorflow.python.ops import math_ops 39 | from tensorflow.python.ops import nn_ops 40 | # from tensorflow.python.ops import rnn 41 | # from tensorflow.contrib.rnn.python.ops import rnn 42 | # from tensorflow.contrib import rnn 43 | from tensorflow.contrib.rnn.python.ops import core_rnn 44 | from tensorflow.contrib.rnn.python.ops import core_rnn_cell as rnn_cell 45 | from tensorflow.python.ops import variable_scope 46 | import tensorflow as tf 47 | 48 | try: 49 | # linear = tf.nn.rnn_cell.linear 50 | linear = tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.linear 51 | except: 52 | from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _linear as linear 53 | 54 | 55 | def _extract_argmax_and_embed(embedding, output_projection=None, 56 | update_embedding=True): 57 | """Get a loop_function that extracts the previous symbol and embeds it. 58 | Args: 59 | embedding: embedding tensor for symbols. 
60 | output_projection: None or a pair (W, B). If provided, each fed previous 61 | output will first be multiplied by W and added B. 62 | update_embedding: Boolean; if False, the gradients will not propagate 63 | through the embeddings. 64 | Returns: 65 | A loop function. 66 | """ 67 | def loop_function(prev, _): 68 | if output_projection is not None: 69 | prev = nn_ops.xw_plus_b( 70 | prev, output_projection[0], output_projection[1]) 71 | prev_symbol = math_ops.argmax(prev, 1) 72 | # Note that gradients will not propagate through the second parameter of 73 | # embedding_lookup. 74 | emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol) 75 | if not update_embedding: 76 | emb_prev = array_ops.stop_gradient(emb_prev) 77 | return emb_prev 78 | return loop_function 79 | 80 | def _extract_beam_search(embedding, beam_size, num_symbols, embedding_size, output_projection=None, 81 | update_embedding=True): 82 | """Get a loop_function that extracts the previous symbol and embeds it. 83 | 84 | Args: 85 | embedding: embedding tensor for symbols. 86 | output_projection: None or a pair (W, B). If provided, each fed previous 87 | output will first be multiplied by W and added B. 88 | update_embedding: Boolean; if False, the gradients will not propagate 89 | through the embeddings. 90 | 91 | Returns: 92 | A loop function. 93 | """ 94 | def loop_function(prev, i, log_beam_probs, beam_path, beam_symbols): 95 | if output_projection is not None: 96 | prev = nn_ops.xw_plus_b( 97 | prev, output_projection[0], output_projection[1]) 98 | # prev= prev.get_shape().with_rank(2)[1] 99 | 100 | probs = tf.log(tf.nn.softmax(prev)) 101 | 102 | if i > 1: 103 | 104 | probs = tf.reshape(probs + log_beam_probs[-1], 105 | [-1, beam_size * num_symbols]) 106 | 107 | best_probs, indices = tf.nn.top_k(probs, beam_size) 108 | indices = tf.stop_gradient(tf.squeeze(tf.reshape(indices, [-1, 1]))) 109 | best_probs = tf.stop_gradient(tf.reshape(best_probs, [-1, 1])) 110 | 111 | symbols = indices % num_symbols # Which word in vocabulary. 112 | beam_parent = indices // num_symbols # Which hypothesis it came from. 113 | 114 | 115 | beam_symbols.append(symbols) 116 | beam_path.append(beam_parent) 117 | log_beam_probs.append(best_probs) 118 | 119 | # Note that gradients will not propagate through the second parameter of 120 | # embedding_lookup. 121 | 122 | emb_prev = embedding_ops.embedding_lookup(embedding, symbols) 123 | emb_prev = tf.reshape(emb_prev,[beam_size,embedding_size]) 124 | # emb_prev = embedding_ops.embedding_lookup(embedding, symbols) 125 | if not update_embedding: 126 | emb_prev = array_ops.stop_gradient(emb_prev) 127 | return emb_prev 128 | return loop_function 129 | 130 | 131 | def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None, 132 | scope=None): 133 | """RNN decoder for the sequence-to-sequence model. 134 | 135 | Args: 136 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 137 | initial_state: 2D Tensor with shape [batch_size x cell.state_size]. 138 | cell: rnn_cell.RNNCell defining the cell function and size. 139 | loop_function: If not None, this function will be applied to the i-th output 140 | in order to generate the i+1-st input, and decoder_inputs will be ignored, 141 | except for the first element ("GO" symbol). This can be used for decoding, 142 | but also for training to emulate http://arxiv.org/abs/1506.03099. 
143 | Signature -- loop_function(prev, i) = next 144 | * prev is a 2D Tensor of shape [batch_size x output_size], 145 | * i is an integer, the step number (when advanced control is needed), 146 | * next is a 2D Tensor of shape [batch_size x input_size]. 147 | scope: VariableScope for the created subgraph; defaults to "rnn_decoder". 148 | 149 | Returns: 150 | A tuple of the form (outputs, state), where: 151 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 152 | shape [batch_size x output_size] containing generated outputs. 153 | state: The state of each cell at the final time-step. 154 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 155 | (Note that in some cases, like basic RNN cell or GRU cell, outputs and 156 | states can be the same. They are different for LSTM cells though.) 157 | """ 158 | with variable_scope.variable_scope(scope or "rnn_decoder"): 159 | state = initial_state 160 | outputs = [] 161 | prev = None 162 | for i, inp in enumerate(decoder_inputs): 163 | if loop_function is not None and prev is not None: 164 | with variable_scope.variable_scope("loop_function", reuse=True): 165 | inp = loop_function(prev, i) 166 | if i > 0: 167 | variable_scope.get_variable_scope().reuse_variables() 168 | output, state = cell(inp, state) 169 | 170 | outputs.append(output) 171 | if loop_function is not None: 172 | prev = output 173 | return outputs, state 174 | 175 | def beam_rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None, 176 | scope=None,output_projection=None, beam_size=10): 177 | """RNN decoder for the sequence-to-sequence model. 178 | 179 | Args: 180 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 181 | initial_state: 2D Tensor with shape [batch_size x cell.state_size]. 182 | cell: rnn_cell.RNNCell defining the cell function and size. 183 | loop_function: If not None, this function will be applied to the i-th output 184 | in order to generate the i+1-st input, and decoder_inputs will be ignored, 185 | except for the first element ("GO" symbol). This can be used for decoding, 186 | but also for training to emulate http://arxiv.org/abs/1506.03099. 187 | Signature -- loop_function(prev, i) = next 188 | * prev is a 2D Tensor of shape [batch_size x output_size], 189 | * i is an integer, the step number (when advanced control is needed), 190 | * next is a 2D Tensor of shape [batch_size x input_size]. 191 | scope: VariableScope for the created subgraph; defaults to "rnn_decoder". 192 | 193 | Returns: 194 | A tuple of the form (outputs, state), where: 195 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 196 | shape [batch_size x output_size] containing generated outputs. 197 | state: The state of each cell at the final time-step. 198 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 199 | (Note that in some cases, like basic RNN cell or GRU cell, outputs and 200 | states can be the same. They are different for LSTM cells though.) 
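  Note:
    In practice this beam-search variant returns a 4-tuple
    (outputs, state, beam_path, beam_symbols): outputs hold argmax symbol ids
    (the output projection is applied inside the decoder rather than afterwards),
    and beam_path / beam_symbols are reshaped to [-1, beam_size] and record, per
    step, which hypothesis each beam entry extends and which symbol it emitted.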
201 | """ 202 | with variable_scope.variable_scope(scope or "rnn_decoder"): 203 | state = initial_state 204 | outputs = [] 205 | prev = None 206 | log_beam_probs, beam_path, beam_symbols = [],[],[] 207 | state_size = int(initial_state.get_shape().with_rank(2)[1]) 208 | 209 | for i, inp in enumerate(decoder_inputs): 210 | if loop_function is not None and prev is not None: 211 | with variable_scope.variable_scope("loop_function", reuse=True): 212 | inp = loop_function(prev, i,log_beam_probs, beam_path, beam_symbols) 213 | if i > 0: 214 | variable_scope.get_variable_scope().reuse_variables() 215 | 216 | input_size = inp.get_shape().with_rank(2)[1] 217 | print(input_size) 218 | x = inp 219 | output, state = cell(x, state) 220 | 221 | if loop_function is not None: 222 | prev = output 223 | if i ==0: 224 | states =[] 225 | for kk in range(beam_size): 226 | states.append(state) 227 | state = tf.reshape(tf.concat(axis=0, values=states), [-1, state_size]) 228 | 229 | outputs.append(tf.argmax(nn_ops.xw_plus_b( 230 | output, output_projection[0], output_projection[1]), axis=1)) 231 | return outputs, state, tf.reshape(tf.concat(axis=0, values=beam_path),[-1,beam_size]), tf.reshape(tf.concat(axis=0, values=beam_symbols),[-1,beam_size]) 232 | 233 | 234 | def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols, 235 | embedding_size, output_projection=None, 236 | feed_previous=False, 237 | update_embedding_for_previous=True, scope=None, beam_search=True, beam_size=10 ): 238 | """RNN decoder with embedding and a pure-decoding option. 239 | 240 | Args: 241 | decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). 242 | initial_state: 2D Tensor [batch_size x cell.state_size]. 243 | cell: rnn_cell.RNNCell defining the cell function. 244 | num_symbols: Integer, how many symbols come into the embedding. 245 | embedding_size: Integer, the length of the embedding vector for each symbol. 246 | output_projection: None or a pair (W, B) of output projection weights and 247 | biases; W has shape [output_size x num_symbols] and B has 248 | shape [num_symbols]; if provided and feed_previous=True, each fed 249 | previous output will first be multiplied by W and added B. 250 | feed_previous: Boolean; if True, only the first of decoder_inputs will be 251 | used (the "GO" symbol), and all other decoder inputs will be generated by: 252 | next = embedding_lookup(embedding, argmax(previous_output)), 253 | In effect, this implements a greedy decoder. It can also be used 254 | during training to emulate http://arxiv.org/abs/1506.03099. 255 | If False, decoder_inputs are used as given (the standard decoder case). 256 | update_embedding_for_previous: Boolean; if False and feed_previous=True, 257 | only the embedding for the first symbol of decoder_inputs (the "GO" 258 | symbol) will be updated by back propagation. Embeddings for the symbols 259 | generated from the decoder itself remain unchanged. This parameter has 260 | no effect if feed_previous=False. 261 | scope: VariableScope for the created subgraph; defaults to 262 | "embedding_rnn_decoder". 263 | 264 | Returns: 265 | A tuple of the form (outputs, state), where: 266 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 267 | shape [batch_size x output_size] containing the generated outputs. 268 | state: The state of each decoder cell in each time-step. This is a list 269 | with length len(decoder_inputs) -- one item for each time-step. 270 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 
271 | 272 | Raises: 273 | ValueError: When output_projection has the wrong shape. 274 | """ 275 | if output_projection is not None: 276 | proj_weights = ops.convert_to_tensor(output_projection[0], 277 | dtype=dtypes.float32) 278 | proj_weights.get_shape().assert_is_compatible_with([None, num_symbols]) 279 | proj_biases = ops.convert_to_tensor( 280 | output_projection[1], dtype=dtypes.float32) 281 | proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 282 | 283 | with variable_scope.variable_scope(scope or "embedding_rnn_decoder"): 284 | with ops.device("/cpu:0"): 285 | embedding = variable_scope.get_variable("embedding", 286 | [num_symbols, embedding_size]) 287 | 288 | if beam_search: 289 | loop_function = _extract_beam_search( 290 | embedding, beam_size,num_symbols,embedding_size, output_projection, 291 | update_embedding_for_previous) 292 | else: 293 | loop_function = _extract_argmax_and_embed( 294 | embedding, output_projection, 295 | update_embedding_for_previous) if feed_previous else None 296 | 297 | emb_inp = [ 298 | embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs] 299 | 300 | 301 | if beam_search: 302 | return beam_rnn_decoder(emb_inp, initial_state, cell, 303 | loop_function=loop_function,output_projection=output_projection, beam_size=beam_size) 304 | 305 | else: 306 | return rnn_decoder(emb_inp, initial_state, cell, 307 | loop_function=loop_function) 308 | 309 | 310 | 311 | def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, 312 | num_encoder_symbols, num_decoder_symbols, 313 | embedding_size, output_projection=None, 314 | feed_previous=False, dtype=dtypes.float32, 315 | scope=None, beam_search=True, beam_size=10): 316 | """Embedding RNN sequence-to-sequence model. 317 | 318 | This model first embeds encoder_inputs by a newly created embedding (of shape 319 | [num_encoder_symbols x input_size]). Then it runs an RNN to encode 320 | embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs 321 | by another newly created embedding (of shape [num_decoder_symbols x 322 | input_size]). Then it runs RNN decoder, initialized with the last 323 | encoder state, on embedded decoder_inputs. 324 | 325 | Args: 326 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 327 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 328 | cell: rnn_cell.RNNCell defining the cell function and size. 329 | num_encoder_symbols: Integer; number of symbols on the encoder side. 330 | num_decoder_symbols: Integer; number of symbols on the decoder side. 331 | embedding_size: Integer, the length of the embedding vector for each symbol. 332 | output_projection: None or a pair (W, B) of output projection weights and 333 | biases; W has shape [output_size x num_decoder_symbols] and B has 334 | shape [num_decoder_symbols]; if provided and feed_previous=True, each 335 | fed previous output will first be multiplied by W and added B. 336 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 337 | of decoder_inputs will be used (the "GO" symbol), and all other decoder 338 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 339 | If False, decoder_inputs are used as given (the standard decoder case). 340 | dtype: The dtype of the initial state for both the encoder and encoder 341 | rnn cells (default: tf.float32). 
342 | scope: VariableScope for the created subgraph; defaults to 343 | "embedding_rnn_seq2seq" 344 | 345 | Returns: 346 | A tuple of the form (outputs, state), where: 347 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 348 | shape [batch_size x num_decoder_symbols] containing the generated 349 | outputs. 350 | state: The state of each decoder cell in each time-step. This is a list 351 | with length len(decoder_inputs) -- one item for each time-step. 352 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 353 | """ 354 | with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq"): 355 | # Encoder. 356 | encoder_cell = rnn_cell.EmbeddingWrapper( 357 | cell, embedding_classes=num_encoder_symbols, 358 | embedding_size=embedding_size) 359 | _, encoder_state = core_rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype) 360 | 361 | # Decoder. 362 | if output_projection is None: 363 | cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) 364 | 365 | 366 | return embedding_rnn_decoder( 367 | decoder_inputs, encoder_state, cell, num_decoder_symbols, 368 | embedding_size, output_projection=output_projection, 369 | feed_previous=feed_previous, beam_search=beam_search, beam_size=beam_size) 370 | 371 | 372 | 373 | 374 | 375 | def attention_decoder(decoder_inputs, initial_state, attention_states, cell, 376 | output_size=None, num_heads=1, loop_function=None, 377 | dtype=dtypes.float32, scope=None, 378 | initial_state_attention=False): 379 | """RNN decoder with attention for the sequence-to-sequence model. 380 | 381 | In this context "attention" means that, during decoding, the RNN can look up 382 | information in the additional tensor attention_states, and it does this by 383 | focusing on a few entries from the tensor. This model has proven to yield 384 | especially good results in a number of sequence-to-sequence tasks. This 385 | implementation is based on http://arxiv.org/abs/1412.7449 (see below for 386 | details). It is recommended for complex sequence-to-sequence tasks. 387 | 388 | Args: 389 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 390 | initial_state: 2D Tensor [batch_size x cell.state_size]. 391 | attention_states: 3D Tensor [batch_size x attn_length x attn_size]. 392 | cell: rnn_cell.RNNCell defining the cell function and size. 393 | output_size: Size of the output vectors; if None, we use cell.output_size. 394 | num_heads: Number of attention heads that read from attention_states. 395 | loop_function: If not None, this function will be applied to i-th output 396 | in order to generate i+1-th input, and decoder_inputs will be ignored, 397 | except for the first element ("GO" symbol). This can be used for decoding, 398 | but also for training to emulate http://arxiv.org/abs/1506.03099. 399 | Signature -- loop_function(prev, i) = next 400 | * prev is a 2D Tensor of shape [batch_size x output_size], 401 | * i is an integer, the step number (when advanced control is needed), 402 | * next is a 2D Tensor of shape [batch_size x input_size]. 403 | dtype: The dtype to use for the RNN initial state (default: tf.float32). 404 | scope: VariableScope for the created subgraph; default: "attention_decoder". 405 | initial_state_attention: If False (default), initial attentions are zero. 406 | If True, initialize the attentions from the initial state and attention 407 | states -- useful when we wish to resume decoding from a previously 408 | stored decoder state and attention states. 
409 | 410 | Returns: 411 | A tuple of the form (outputs, state), where: 412 | outputs: A list of the same length as decoder_inputs of 2D Tensors of 413 | shape [batch_size x output_size]. These represent the generated outputs. 414 | Output i is computed from input i (which is either the i-th element 415 | of decoder_inputs or loop_function(output {i-1}, i)) as follows. 416 | First, we run the cell on a combination of the input and previous 417 | attention masks: 418 | cell_output, new_state = cell(linear(input, prev_attn), prev_state). 419 | Then, we calculate new attention masks: 420 | new_attn = softmax(V^T * tanh(W * attention_states + U * new_state)) 421 | and then we calculate the output: 422 | output = linear(cell_output, new_attn). 423 | state: The state of each decoder cell the final time-step. 424 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 425 | 426 | Raises: 427 | ValueError: when num_heads is not positive, there are no inputs, shapes 428 | of attention_states are not set, or input size cannot be inferred 429 | from the input. 430 | """ 431 | if not decoder_inputs: 432 | raise ValueError("Must provide at least 1 input to attention decoder.") 433 | if num_heads < 1: 434 | raise ValueError("With less than 1 heads, use a non-attention decoder.") 435 | if not attention_states.get_shape()[1:2].is_fully_defined(): 436 | raise ValueError("Shape[1] and [2] of attention_states must be known: %s" 437 | % attention_states.get_shape()) 438 | if output_size is None: 439 | output_size = cell.output_size 440 | 441 | with variable_scope.variable_scope(scope or "attention_decoder"): 442 | batch_size = array_ops.shape(decoder_inputs[0])[0] # Needed for reshaping. 443 | attn_length = attention_states.get_shape()[1].value 444 | attn_size = attention_states.get_shape()[2].value 445 | 446 | # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before. 447 | hidden = array_ops.reshape( 448 | attention_states, [-1, attn_length, 1, attn_size]) 449 | hidden_features = [] 450 | v = [] 451 | attention_vec_size = attn_size # Size of query vectors for attention. 452 | for a in xrange(num_heads): 453 | k = variable_scope.get_variable("AttnW_%d" % a, 454 | [1, 1, attn_size, attention_vec_size]) 455 | hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")) 456 | v.append(variable_scope.get_variable("AttnV_%d" % a, 457 | [attention_vec_size])) 458 | 459 | state = initial_state 460 | def attention(query): 461 | """Put attention masks on hidden using hidden_features and query.""" 462 | ds = [] # Results of attention reads will be stored here. 463 | for a in xrange(num_heads): 464 | with variable_scope.variable_scope("Attention_%d" % a): 465 | y = linear(query, attention_vec_size, True) 466 | y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) 467 | # Attention mask is a softmax of v^T * tanh(...). 468 | s = math_ops.reduce_sum( 469 | v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3]) 470 | a = nn_ops.softmax(s) 471 | # Now calculate the attention-weighted vector d. 472 | d = math_ops.reduce_sum( 473 | array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, 474 | [1, 2]) 475 | ds.append(array_ops.reshape(d, [-1, attn_size])) 476 | return ds 477 | 478 | outputs = [] 479 | prev = None 480 | batch_attn_size = array_ops.stack([batch_size, attn_size]) 481 | attns = [array_ops.zeros(batch_attn_size, dtype=dtype) 482 | for _ in xrange(num_heads)] 483 | for a in attns: # Ensure the second shape of attention vectors is set. 
484 | a.set_shape([None, attn_size]) 485 | if initial_state_attention: 486 | attns = attention(initial_state) 487 | for i, inp in enumerate(decoder_inputs): 488 | if i > 0: 489 | variable_scope.get_variable_scope().reuse_variables() 490 | # If loop_function is set, we use it instead of decoder_inputs. 491 | if loop_function is not None : 492 | with variable_scope.variable_scope("loop_function", reuse=True): 493 | if prev is not None: 494 | inp = loop_function(prev, i) 495 | 496 | input_size = inp.get_shape().with_rank(2)[1] 497 | 498 | x = linear([inp] + attns, input_size, True) 499 | # Run the RNN. 500 | cell_output, state = cell(x, state) 501 | # Run the attention mechanism. 502 | if i == 0 and initial_state_attention: 503 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), 504 | reuse=True): 505 | attns = attention(state) 506 | else: 507 | attns = attention(state) 508 | 509 | with variable_scope.variable_scope("AttnOutputProjection"): 510 | output = linear([cell_output] + attns, output_size, True) 511 | if loop_function is not None: 512 | prev = output 513 | outputs.append(output) 514 | 515 | return outputs, state 516 | 517 | 518 | def beam_attention_decoder(decoder_inputs, initial_state, attention_states, cell, 519 | output_size=None, num_heads=1, loop_function=None, 520 | dtype=dtypes.float32, scope=None, 521 | initial_state_attention=False, output_projection=None, beam_size=10): 522 | """RNN decoder with attention for the sequence-to-sequence model. 523 | 524 | In this context "attention" means that, during decoding, the RNN can look up 525 | information in the additional tensor attention_states, and it does this by 526 | focusing on a few entries from the tensor. This model has proven to yield 527 | especially good results in a number of sequence-to-sequence tasks. This 528 | implementation is based on http://arxiv.org/abs/1412.7449 (see below for 529 | details). It is recommended for complex sequence-to-sequence tasks. 530 | 531 | Args: 532 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 533 | initial_state: 2D Tensor [batch_size x cell.state_size]. 534 | attention_states: 3D Tensor [batch_size x attn_length x attn_size]. 535 | cell: rnn_cell.RNNCell defining the cell function and size. 536 | output_size: Size of the output vectors; if None, we use cell.output_size. 537 | num_heads: Number of attention heads that read from attention_states. 538 | loop_function: If not None, this function will be applied to i-th output 539 | in order to generate i+1-th input, and decoder_inputs will be ignored, 540 | except for the first element ("GO" symbol). This can be used for decoding, 541 | but also for training to emulate http://arxiv.org/abs/1506.03099. 542 | Signature -- loop_function(prev, i) = next 543 | * prev is a 2D Tensor of shape [batch_size x output_size], 544 | * i is an integer, the step number (when advanced control is needed), 545 | * next is a 2D Tensor of shape [batch_size x input_size]. 546 | dtype: The dtype to use for the RNN initial state (default: tf.float32). 547 | scope: VariableScope for the created subgraph; default: "attention_decoder". 548 | initial_state_attention: If False (default), initial attentions are zero. 549 | If True, initialize the attentions from the initial state and attention 550 | states -- useful when we wish to resume decoding from a previously 551 | stored decoder state and attention states. 
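  Note:
    Two arguments are not described above: output_projection, the (W, B) pair
    applied to each emitted output, and beam_size, the number of hypotheses kept
    per step. The (outputs, state) description below applies to the plain
    attention_decoder; this beam variant actually returns
    (outputs, state, beam_path, beam_symbols), with outputs holding argmax
    symbol ids rather than logits.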
552 | 553 | Returns: 554 | A tuple of the form (outputs, state), where: 555 | outputs: A list of the same length as decoder_inputs of 2D Tensors of 556 | shape [batch_size x output_size]. These represent the generated outputs. 557 | Output i is computed from input i (which is either the i-th element 558 | of decoder_inputs or loop_function(output {i-1}, i)) as follows. 559 | First, we run the cell on a combination of the input and previous 560 | attention masks: 561 | cell_output, new_state = cell(linear(input, prev_attn), prev_state). 562 | Then, we calculate new attention masks: 563 | new_attn = softmax(V^T * tanh(W * attention_states + U * new_state)) 564 | and then we calculate the output: 565 | output = linear(cell_output, new_attn). 566 | state: The state of each decoder cell the final time-step. 567 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 568 | 569 | Raises: 570 | ValueError: when num_heads is not positive, there are no inputs, shapes 571 | of attention_states are not set, or input size cannot be inferred 572 | from the input. 573 | """ 574 | if not decoder_inputs: 575 | raise ValueError("Must provide at least 1 input to attention decoder.") 576 | if num_heads < 1: 577 | raise ValueError("With less than 1 heads, use a non-attention decoder.") 578 | if not attention_states.get_shape()[1:2].is_fully_defined(): 579 | raise ValueError("Shape[1] and [2] of attention_states must be known: %s" 580 | % attention_states.get_shape()) 581 | if output_size is None: 582 | output_size = cell.output_size 583 | 584 | with variable_scope.variable_scope(scope or "attention_decoder"): 585 | batch_size = array_ops.shape(decoder_inputs[0])[0] # Needed for reshaping. 586 | attn_length = attention_states.get_shape()[1].value 587 | attn_size = attention_states.get_shape()[2].value 588 | 589 | # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before. 590 | hidden = array_ops.reshape( 591 | attention_states, [-1, attn_length, 1, attn_size]) 592 | hidden_features = [] 593 | v = [] 594 | attention_vec_size = attn_size # Size of query vectors for attention. 595 | for a in xrange(num_heads): 596 | k = variable_scope.get_variable("AttnW_%d" % a, 597 | [1, 1, attn_size, attention_vec_size]) 598 | hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")) 599 | v.append(variable_scope.get_variable("AttnV_%d" % a, 600 | [attention_vec_size])) 601 | 602 | print("Initial_state") 603 | 604 | state_size = int(initial_state.get_shape().with_rank(2)[1]) 605 | states =[] 606 | for kk in range(1): 607 | states.append(initial_state) 608 | state = tf.reshape(tf.concat(axis=0, values=states), [-1, state_size]) 609 | def attention(query): 610 | """Put attention masks on hidden using hidden_features and query.""" 611 | ds = [] # Results of attention reads will be stored here. 612 | for a in xrange(num_heads): 613 | with variable_scope.variable_scope("Attention_%d" % a): 614 | y = linear(query, attention_vec_size, True) 615 | y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) 616 | # Attention mask is a softmax of v^T * tanh(...). 617 | s = math_ops.reduce_sum( 618 | v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3]) 619 | a = nn_ops.softmax(s) 620 | # Now calculate the attention-weighted vector d. 
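        # i.e. d = sum_j a_j * h_j over the attn_length encoder positions, giving
        # one attn_size-dimensional context vector per row of the (beam-tiled) batch.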
621 | d = math_ops.reduce_sum( 622 | array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, 623 | [1, 2]) 624 | # for c in range(ct): 625 | ds.append(array_ops.reshape(d, [-1, attn_size])) 626 | return ds 627 | 628 | outputs = [] 629 | prev = None 630 | batch_attn_size = array_ops.stack([batch_size, attn_size]) 631 | attns = [array_ops.zeros(batch_attn_size, dtype=dtype) 632 | for _ in xrange(num_heads)] 633 | for a in attns: # Ensure the second shape of attention vectors is set. 634 | a.set_shape([None, attn_size]) 635 | 636 | if initial_state_attention: 637 | attns = [] 638 | attns.append(attention(initial_state)) 639 | tmp = tf.reshape(tf.concat(axis=0, values=attns), [-1, attn_size]) 640 | attns = [] 641 | attns.append(tmp) 642 | 643 | log_beam_probs, beam_path, beam_symbols = [],[],[] 644 | for i, inp in enumerate(decoder_inputs): 645 | 646 | if i > 0: 647 | variable_scope.get_variable_scope().reuse_variables() 648 | # If loop_function is set, we use it instead of decoder_inputs. 649 | if loop_function is not None : 650 | with variable_scope.variable_scope("loop_function", reuse=True): 651 | if prev is not None: 652 | inp = loop_function(prev, i,log_beam_probs, beam_path, beam_symbols) 653 | 654 | input_size = inp.get_shape().with_rank(2)[1] 655 | x = linear([inp] + attns, input_size, True) 656 | cell_output, state = cell(x, state) 657 | 658 | # Run the attention mechanism. 659 | if i == 0 and initial_state_attention: 660 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), 661 | reuse=True): 662 | attns = attention(state) 663 | else: 664 | attns = attention(state) 665 | 666 | with variable_scope.variable_scope("AttnOutputProjection"): 667 | output = linear([cell_output] + attns, output_size, True) 668 | if loop_function is not None: 669 | prev = output 670 | if i ==0: 671 | states =[] 672 | for kk in range(beam_size): 673 | states.append(state) 674 | state = tf.reshape(tf.concat(axis=0, values=states), [-1, state_size]) 675 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True): 676 | attns = attention(state) 677 | 678 | outputs.append(tf.argmax(nn_ops.xw_plus_b( 679 | output, output_projection[0], output_projection[1]), axis=1)) 680 | 681 | return outputs, state, tf.reshape(tf.concat(axis=0, values=beam_path),[-1,beam_size]), tf.reshape(tf.concat(axis=0, values=beam_symbols),[-1,beam_size]) 682 | 683 | def embedding_attention_decoder(decoder_inputs, initial_state, attention_states, 684 | cell, num_symbols, embedding_size, num_heads=1, 685 | output_size=None, output_projection=None, 686 | feed_previous=False, 687 | update_embedding_for_previous=True, 688 | dtype=dtypes.float32, scope=None, 689 | initial_state_attention=False, beam_search=True, beam_size=10): 690 | """RNN decoder with embedding and attention and a pure-decoding option. 691 | 692 | Args: 693 | decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). 694 | initial_state: 2D Tensor [batch_size x cell.state_size]. 695 | attention_states: 3D Tensor [batch_size x attn_length x attn_size]. 696 | cell: rnn_cell.RNNCell defining the cell function. 697 | num_symbols: Integer, how many symbols come into the embedding. 698 | embedding_size: Integer, the length of the embedding vector for each symbol. 699 | num_heads: Number of attention heads that read from attention_states. 700 | output_size: Size of the output vectors; if None, use output_size. 
701 | output_projection: None or a pair (W, B) of output projection weights and 702 | biases; W has shape [output_size x num_symbols] and B has shape 703 | [num_symbols]; if provided and feed_previous=True, each fed previous 704 | output will first be multiplied by W and added B. 705 | feed_previous: Boolean; if True, only the first of decoder_inputs will be 706 | used (the "GO" symbol), and all other decoder inputs will be generated by: 707 | next = embedding_lookup(embedding, argmax(previous_output)), 708 | In effect, this implements a greedy decoder. It can also be used 709 | during training to emulate http://arxiv.org/abs/1506.03099. 710 | If False, decoder_inputs are used as given (the standard decoder case). 711 | update_embedding_for_previous: Boolean; if False and feed_previous=True, 712 | only the embedding for the first symbol of decoder_inputs (the "GO" 713 | symbol) will be updated by back propagation. Embeddings for the symbols 714 | generated from the decoder itself remain unchanged. This parameter has 715 | no effect if feed_previous=False. 716 | dtype: The dtype to use for the RNN initial states (default: tf.float32). 717 | scope: VariableScope for the created subgraph; defaults to 718 | "embedding_attention_decoder". 719 | initial_state_attention: If False (default), initial attentions are zero. 720 | If True, initialize the attentions from the initial state and attention 721 | states -- useful when we wish to resume decoding from a previously 722 | stored decoder state and attention states. 723 | 724 | Returns: 725 | A tuple of the form (outputs, state), where: 726 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 727 | shape [batch_size x output_size] containing the generated outputs. 728 | state: The state of each decoder cell at the final time-step. 729 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 730 | 731 | Raises: 732 | ValueError: When output_projection has the wrong shape. 
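  Note:
    When beam_search is True this function routes to beam_attention_decoder
    (feed_previous is then ignored) and returns the 4-tuple
    (outputs, state, beam_path, beam_symbols); only the beam_search=False path
    returns the (outputs, state) pair described above.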
733 | """ 734 | if output_size is None: 735 | output_size = cell.output_size 736 | if output_projection is not None: 737 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 738 | proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 739 | 740 | with variable_scope.variable_scope(scope or "embedding_attention_decoder"): 741 | with ops.device("/cpu:0"): 742 | embedding = variable_scope.get_variable("embedding", 743 | [num_symbols, embedding_size]) 744 | print("Check number of symbols") 745 | print(num_symbols) 746 | if beam_search: 747 | loop_function = _extract_beam_search( 748 | embedding, beam_size,num_symbols, embedding_size, output_projection, 749 | update_embedding_for_previous) 750 | else: 751 | loop_function = _extract_argmax_and_embed( 752 | embedding, output_projection, 753 | update_embedding_for_previous) if feed_previous else None 754 | emb_inp = [ 755 | embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs] 756 | if beam_search: 757 | return beam_attention_decoder( 758 | emb_inp, initial_state, attention_states, cell, output_size=output_size, 759 | num_heads=num_heads, loop_function=loop_function, 760 | initial_state_attention=initial_state_attention, output_projection=output_projection, beam_size=beam_size) 761 | else: 762 | 763 | return attention_decoder( 764 | emb_inp, initial_state, attention_states, cell, output_size=output_size, 765 | num_heads=num_heads, loop_function=loop_function, 766 | initial_state_attention=initial_state_attention) 767 | 768 | 769 | def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell_1,cell_2, 770 | num_encoder_symbols, num_decoder_symbols, 771 | embedding_size, 772 | num_heads=1, output_projection=None, 773 | feed_previous=False, dtype=dtypes.float32, 774 | scope=None, initial_state_attention=False, beam_search =True, beam_size = 10 ): 775 | """Embedding sequence-to-sequence model with attention. 776 | 777 | This model first embeds encoder_inputs by a newly created embedding (of shape 778 | [num_encoder_symbols x input_size]). Then it runs an RNN to encode 779 | embedded encoder_inputs into a state vector. It keeps the outputs of this 780 | RNN at every step to use for attention later. Next, it embeds decoder_inputs 781 | by another newly created embedding (of shape [num_decoder_symbols x 782 | input_size]). Then it runs attention decoder, initialized with the last 783 | encoder state, on embedded decoder_inputs and attending to encoder outputs. 784 | 785 | Args: 786 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 787 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 788 | cell: rnn_cell.RNNCell defining the cell function and size. 789 | num_encoder_symbols: Integer; number of symbols on the encoder side. 790 | num_decoder_symbols: Integer; number of symbols on the decoder side. 791 | embedding_size: Integer, the length of the embedding vector for each symbol. 792 | num_heads: Number of attention heads that read from attention_states. 793 | output_projection: None or a pair (W, B) of output projection weights and 794 | biases; W has shape [output_size x num_decoder_symbols] and B has 795 | shape [num_decoder_symbols]; if provided and feed_previous=True, each 796 | fed previous output will first be multiplied by W and added B. 
797 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 798 | of decoder_inputs will be used (the "GO" symbol), and all other decoder 799 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 800 | If False, decoder_inputs are used as given (the standard decoder case). 801 | dtype: The dtype of the initial RNN state (default: tf.float32). 802 | scope: VariableScope for the created subgraph; defaults to 803 | "embedding_attention_seq2seq". 804 | initial_state_attention: If False (default), initial attentions are zero. 805 | If True, initialize the attentions from the initial state and attention 806 | states. 807 | 808 | Returns: 809 | A tuple of the form (outputs, state), where: 810 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 811 | shape [batch_size x num_decoder_symbols] containing the generated 812 | outputs. 813 | state: The state of each decoder cell at the final time-step. 814 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 815 | """ 816 | with variable_scope.variable_scope(scope or "embedding_attention_seq2seq"): 817 | # Encoder. 818 | encoder_cell = rnn_cell.EmbeddingWrapper( 819 | cell_1, embedding_classes=num_encoder_symbols, 820 | embedding_size=embedding_size)#reuse=tf.get_variable_scope().reuse 821 | encoder_outputs, encoder_state = core_rnn.static_rnn( 822 | encoder_cell, encoder_inputs, 823 | #scope='embedding_attention_decoder/attention_decoder', 824 | dtype=dtype) 825 | print('####### embedding_attention_seq2seq scope: {}'.format(encoder_cell)) 826 | print("Symbols") 827 | print(num_encoder_symbols) 828 | print(num_decoder_symbols) 829 | # First calculate a concatenation of encoder outputs to put attention on. 830 | top_states = [array_ops.reshape(e, [-1, 1, cell_1.output_size]) 831 | for e in encoder_outputs] 832 | attention_states = array_ops.concat(axis=1, values=top_states) 833 | print(attention_states) 834 | 835 | # Decoder. 836 | output_size = None 837 | if output_projection is None: 838 | cell_2 = rnn_cell.OutputProjectionWrapper(cell_2, num_decoder_symbols) 839 | output_size = num_decoder_symbols 840 | return embedding_attention_decoder( 841 | decoder_inputs, encoder_state, attention_states, cell_2, 842 | num_decoder_symbols, embedding_size, num_heads=num_heads, 843 | output_size=output_size, output_projection=output_projection, 844 | feed_previous=feed_previous, 845 | initial_state_attention=initial_state_attention, beam_search=beam_search, beam_size=beam_size) 846 | 847 | 848 | 849 | 850 | def sequence_loss_by_example(logits, targets, weights, 851 | average_across_timesteps=True, 852 | softmax_loss_function=None, name=None): 853 | """Weighted cross-entropy loss for a sequence of logits (per example). 854 | 855 | Args: 856 | logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols]. 857 | targets: List of 1D batch-sized int32 Tensors of the same length as logits. 858 | weights: List of 1D batch-sized float-Tensors of the same length as logits. 859 | average_across_timesteps: If set, divide the returned cost by the total 860 | label weight. 861 | softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch 862 | to be used instead of the standard softmax (the default if this is None). 863 | name: Optional name for this operation, default: "sequence_loss_by_example". 864 | 865 | Returns: 866 | 1D batch-sized float Tensor: The log-perplexity for each sequence. 
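  Note:
    Concretely, each example's loss is sum_t weight_t * crossent_t; with
    average_across_timesteps it is divided by (sum_t weight_t + 1e-12), so an
    all-padding sequence yields ~0 instead of a division by zero.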
867 | 868 | Raises: 869 | ValueError: If len(logits) is different from len(targets) or len(weights). 870 | """ 871 | if len(targets) != len(logits) or len(weights) != len(logits): 872 | raise ValueError("Lengths of logits, weights, and targets must be the same " 873 | "%d, %d, %d." % (len(logits), len(weights), len(targets))) 874 | with ops.name_scope( name, 875 | "sequence_loss_by_example",logits + targets + weights): 876 | log_perp_list = [] 877 | for logit, target, weight in zip(logits, targets, weights): 878 | if softmax_loss_function is None: 879 | target = array_ops.reshape(target, [-1]) 880 | crossent = nn_ops.sparse_softmax_cross_entropy_with_logits( 881 | logit, target) 882 | else: 883 | crossent = softmax_loss_function(logit, target) 884 | log_perp_list.append(crossent * weight) 885 | log_perps = math_ops.add_n(log_perp_list) 886 | if average_across_timesteps: 887 | total_size = math_ops.add_n(weights) 888 | total_size += 1e-12 # Just to avoid division by 0 for all-0 weights. 889 | log_perps /= total_size 890 | return log_perps 891 | 892 | 893 | def sequence_loss(logits, targets, weights, 894 | average_across_timesteps=True, average_across_batch=True, 895 | softmax_loss_function=None, name=None): 896 | """Weighted cross-entropy loss for a sequence of logits, batch-collapsed. 897 | 898 | Args: 899 | logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols]. 900 | targets: List of 1D batch-sized int32 Tensors of the same length as logits. 901 | weights: List of 1D batch-sized float-Tensors of the same length as logits. 902 | average_across_timesteps: If set, divide the returned cost by the total 903 | label weight. 904 | average_across_batch: If set, divide the returned cost by the batch size. 905 | softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch 906 | to be used instead of the standard softmax (the default if this is None). 907 | name: Optional name for this operation, defaults to "sequence_loss". 908 | 909 | Returns: 910 | A scalar float Tensor: The average log-perplexity per symbol (weighted). 911 | 912 | Raises: 913 | ValueError: If len(logits) is different from len(targets) or len(weights). 914 | """ 915 | with ops.name_scope( name, "sequence_loss",logits + targets + weights): 916 | cost = math_ops.reduce_sum(sequence_loss_by_example( 917 | logits, targets, weights, 918 | average_across_timesteps=average_across_timesteps, 919 | softmax_loss_function=softmax_loss_function)) 920 | if average_across_batch: 921 | batch_size = array_ops.shape(targets[0])[0] 922 | return cost / math_ops.cast(batch_size, dtypes.float32) 923 | else: 924 | return cost 925 | 926 | 927 | def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights, 928 | buckets, seq2seq, softmax_loss_function=None, 929 | per_example_loss=False, name=None): 930 | """Create a sequence-to-sequence model with support for bucketing. 931 | 932 | The seq2seq argument is a function that defines a sequence-to-sequence model, 933 | e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24)) 934 | 935 | Args: 936 | encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input. 937 | decoder_inputs: A list of Tensors to feed the decoder; second seq2seq input. 938 | targets: A list of 1D batch-sized int32 Tensors (desired output sequence). 939 | weights: List of 1D batch-sized float-Tensors to weight the targets. 940 | buckets: A list of pairs of (input size, output size) for each bucket. 
941 | seq2seq: A sequence-to-sequence model function; it takes 2 input that 942 | agree with encoder_inputs and decoder_inputs, and returns a pair 943 | consisting of outputs and states (as, e.g., basic_rnn_seq2seq). 944 | softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch 945 | to be used instead of the standard softmax (the default if this is None). 946 | per_example_loss: Boolean. If set, the returned loss will be a batch-sized 947 | tensor of losses for each sequence in the batch. If unset, it will be 948 | a scalar with the averaged loss from all examples. 949 | name: Optional name for this operation, defaults to "model_with_buckets". 950 | 951 | Returns: 952 | A tuple of the form (outputs, losses), where: 953 | outputs: The outputs for each bucket. Its j'th element consists of a list 954 | of 2D Tensors of shape [batch_size x num_decoder_symbols] (jth outputs). 955 | losses: List of scalar Tensors, representing losses for each bucket, or, 956 | if per_example_loss is set, a list of 1D batch-sized float Tensors. 957 | 958 | Raises: 959 | ValueError: If length of encoder_inputsut, targets, or weights is smaller 960 | than the largest (last) bucket. 961 | """ 962 | if len(encoder_inputs) < buckets[-1][0]: 963 | raise ValueError("Length of encoder_inputs (%d) must be at least that of la" 964 | "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0])) 965 | if len(targets) < buckets[-1][1]: 966 | raise ValueError("Length of targets (%d) must be at least that of last" 967 | "bucket (%d)." % (len(targets), buckets[-1][1])) 968 | if len(weights) < buckets[-1][1]: 969 | raise ValueError("Length of weights (%d) must be at least that of last" 970 | "bucket (%d)." % (len(weights), buckets[-1][1])) 971 | 972 | all_inputs = encoder_inputs + decoder_inputs + targets + weights 973 | losses = [] 974 | outputs = [] 975 | with ops.name_scope(name, "model_with_buckets", all_inputs): 976 | for j, bucket in enumerate(buckets): 977 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), 978 | reuse=True if j > 0 else None): 979 | 980 | bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]], 981 | decoder_inputs[:bucket[1]]) 982 | 983 | outputs.append(bucket_outputs) 984 | if per_example_loss: 985 | losses.append(sequence_loss_by_example( 986 | outputs[-1], targets[:bucket[1]], weights[:bucket[1]], 987 | softmax_loss_function=softmax_loss_function)) 988 | else: 989 | losses.append(sequence_loss( 990 | outputs[-1], targets[:bucket[1]], weights[:bucket[1]], 991 | softmax_loss_function=softmax_loss_function)) 992 | 993 | return outputs, losses 994 | 995 | def decode_model_with_buckets(encoder_inputs, decoder_inputs, targets, weights, 996 | buckets, seq2seq, softmax_loss_function=None, 997 | per_example_loss=False, name=None): 998 | """Create a sequence-to-sequence model with support for bucketing. 999 | 1000 | The seq2seq argument is a function that defines a sequence-to-sequence model, 1001 | e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24)) 1002 | 1003 | Args: 1004 | encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input. 1005 | decoder_inputs: A list of Tensors to feed the decoder; second seq2seq input. 1006 | targets: A list of 1D batch-sized int32 Tensors (desired output sequence). 1007 | weights: List of 1D batch-sized float-Tensors to weight the targets. 1008 | buckets: A list of pairs of (input size, output size) for each bucket. 
1009 | seq2seq: A sequence-to-sequence model function; it takes 2 inputs that 1010 | agree with encoder_inputs and decoder_inputs, and returns a pair 1011 | consisting of outputs and states (as, e.g., basic_rnn_seq2seq). 1012 | softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch 1013 | to be used instead of the standard softmax (the default if this is None). 1014 | per_example_loss: Boolean. If set, the returned loss will be a batch-sized 1015 | tensor of losses for each sequence in the batch. If unset, it will be 1016 | a scalar with the averaged loss from all examples. 1017 | name: Optional name for this operation, defaults to "model_with_buckets". 1018 | 1019 | Returns: 1020 | A tuple of the form (outputs, beam_paths, beam_symbols), where: 1021 | outputs: The outputs for each bucket. Its j'th element consists of a list 1022 | of 2D Tensors of shape [batch_size x num_decoder_symbols] (jth outputs). 1023 | beam_paths: The beam search paths returned by seq2seq, one per bucket. 1024 | beam_symbols: The beam search symbols returned by seq2seq, one per bucket. 1025 | 1026 | Raises: 1027 | ValueError: If length of encoder_inputs, targets, or weights is smaller 1028 | than the largest (last) bucket. 1029 | """ 1030 | if len(encoder_inputs) < buckets[-1][0]: 1031 | raise ValueError("Length of encoder_inputs (%d) must be at least that of la" 1032 | "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0])) 1033 | if len(targets) < buckets[-1][1]: 1034 | raise ValueError("Length of targets (%d) must be at least that of last " 1035 | "bucket (%d)." % (len(targets), buckets[-1][1])) 1036 | if len(weights) < buckets[-1][1]: 1037 | raise ValueError("Length of weights (%d) must be at least that of last " 1038 | "bucket (%d)." % (len(weights), buckets[-1][1])) 1039 | 1040 | all_inputs = encoder_inputs + decoder_inputs + targets + weights 1041 | losses = [] 1042 | outputs = [] 1043 | beam_paths = [] 1044 | beam_symbols = [] 1045 | with ops.name_scope(name, "model_with_buckets", all_inputs): 1046 | for j, bucket in enumerate(buckets): 1047 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), 1048 | reuse=True if j > 0 else None): 1049 | bucket_outputs, _, beam_path, beam_symbol = seq2seq(encoder_inputs[:bucket[0]], 1050 | decoder_inputs[:bucket[1]]) 1051 | outputs.append(bucket_outputs) 1052 | beam_paths.append(beam_path) 1053 | beam_symbols.append(beam_symbol) 1054 | print("End**********") 1055 | 1056 | return outputs, beam_paths, beam_symbols -------------------------------------------------------------------------------- /lib/chatbot_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import copy as copy 4 | import datetime 5 | import filecmp 6 | import hashlib 7 | import json 8 | import os 9 | import os.path 10 | import re 11 | import shutil 12 | import platform 13 | import random 14 | import tweepy 15 | 16 | import MeCab 17 | # noinspection PyUnresolvedReferences 18 | import easy_tf_log 19 | import matplotlib.pyplot as plt 20 | import numpy as np 21 | import tensorflow as tf 22 | import yaml 23 | # noinspection PyUnresolvedReferences 24 | from easy_tf_log import tflog 25 | # noinspection PyUnresolvedReferences 26 | from google.colab import files 27 | # noinspection PyUnresolvedReferences 28 | from pushbullet import Pushbullet 29 | from tensorflow.python.layers import core as layers_core 30 | from tensorflow.python.platform import gfile 31 | from enum import Enum, auto 32 | 33 | 34 | class
Mode(Enum): 35 | Test = auto() 36 | TrainSeq2Seq = auto() 37 | TrainSeq2SeqSwapped = auto() 38 | TrainRL = auto() 39 | TweetBot = auto() 40 | 41 | 42 | def pp(*arguments): 43 | print(*arguments) 44 | with open("stdout.txt", "a") as fout: 45 | print(*arguments, file=fout) 46 | 47 | 48 | def is_local(): 49 | return platform.system() == 'Darwin' 50 | 51 | 52 | def client_id(): 53 | if is_local(): 54 | return "local" 55 | # noinspection SpellCheckingInspection 56 | clients = {'dfc1d5b22ba03430800179d23e522f6f': 'client1', 57 | 'f8e857a2d792038820ebb2ae8d803f7c': 'client2', 58 | '7628f983785173edabbde501ef8f781d': 'client3'} 59 | with open('/content/datalab/adc.json') as json_data: 60 | d = json.load(json_data) 61 | email = d['id_token']['email'].encode('utf-8') 62 | return clients[hashlib.md5(email).hexdigest()] 63 | 64 | 65 | if is_local(): 66 | drive_path = '/Users/higepon/Google Drive/seq2seq_data' 67 | else: 68 | drive_path = 'drive/seq2seq_data' 69 | 70 | pp(client_id()) 71 | current_client_id = client_id() 72 | 73 | mode = Mode.Test 74 | 75 | 76 | # mode = Mode.TrainSeq2Seq 77 | # mode = Mode.TrainSeq2SeqSwapped 78 | # mode = Mode.TrainRL 79 | # mode = Mode.TweetBot 80 | 81 | class DeltaLogger: 82 | def __init__(self, key, step, stdout=None): 83 | self.key = "{}_time_sec".format(key) 84 | self.step = step 85 | self.stdout = stdout 86 | 87 | def __enter__(self): 88 | self.start_time = datetime.datetime.now() 89 | return self 90 | 91 | def __exit__(self, exc_type, exc_value, traceback): 92 | end_time = datetime.datetime.now() 93 | delta_sec = (end_time - self.start_time).total_seconds() 94 | 95 | if self.step is not None: 96 | tflog("{}_{}".format(self.key, current_client_id), delta_sec, 97 | step=self.step) 98 | if self.stdout is not None: 99 | pp("{}={}".format(self.key, round(delta_sec, 1))) 100 | if exc_type is None: 101 | return False 102 | 103 | 104 | def delta(key, step, stdout=False): 105 | return DeltaLogger(key, step, stdout) 106 | 107 | 108 | class Shell: 109 | @staticmethod 110 | def download_file_if_necessary(file_name): 111 | if os.path.exists(file_name): 112 | return 113 | pp("downloading {}...".format(file_name)) 114 | shutil.copy2(os.path.join(drive_path, file_name), file_name) 115 | pp("downloaded") 116 | 117 | @staticmethod 118 | def download_model_data_if_necessary(model_path): 119 | if not os.path.exists(model_path): 120 | os.makedirs(model_path) 121 | pp("Downloading model files...") 122 | src_dir = os.path.join(drive_path, model_path) 123 | Shell.copy_all_files(src_dir, model_path) 124 | pp("done") 125 | 126 | @staticmethod 127 | def copy_all_files(src_dir, dst_dir): 128 | if os.path.exists(src_dir): 129 | for file in os.listdir(src_dir): 130 | src = os.path.join(src_dir, file) 131 | dst = os.path.join(dst_dir, file) 132 | if os.path.exists(dst) and filecmp.cmp(src, dst): 133 | pp("Skip copying ", src) 134 | continue 135 | else: 136 | pp("Copying ", src) 137 | shutil.copy2(src, dst) 138 | 139 | @staticmethod 140 | def remove_all_files(target_dir): 141 | for file in Shell.listdir(target_dir): 142 | os.remove(file) 143 | 144 | @staticmethod 145 | def remove_matched_files(target_dir, pattern): 146 | for file in Shell.listdir(target_dir): 147 | if re.match(pattern, file): 148 | os.remove(file) 149 | 150 | @staticmethod 151 | def download_logs(path): 152 | for file in Shell.listdir(path): 153 | if re.match('.*events', file): 154 | files.download(file) 155 | 156 | @staticmethod 157 | def download(path): 158 | files.download(path) 159 | 160 | @staticmethod 161 | def 
remove_saved_model(hparams): 162 | os.makedirs(hparams.model_path, exist_ok=True) 163 | Shell.remove_all_files(hparams.model_path) 164 | os.makedirs(os.path.join(drive_path, hparams.model_path), exist_ok=True) 165 | Shell.remove_all_files(os.path.join(drive_path, hparams.model_path)) 166 | 167 | @staticmethod 168 | def copy_saved_model(src_hparams, dst_hparams): 169 | Shell.copy_all_files(src_hparams.model_path, dst_hparams.model_path) 170 | # rm tf.logs from source so that it wouldn't be mixed in dest tf.logs. 171 | Shell.remove_matched_files(dst_hparams.model_path, ".*events.*") 172 | 173 | @staticmethod 174 | def listdir(target_dir): 175 | for dir_path, _, file_names in os.walk(target_dir): 176 | for file in file_names: 177 | yield os.path.abspath(os.path.join(dir_path, file)) 178 | 179 | @staticmethod 180 | def list_model_file(path): 181 | file = open('{}/checkpoint'.format(path)) 182 | text = file.read() 183 | file.close() 184 | pp("model_file", text) 185 | m = re.match(r".*ChatbotModel-(\d+)", text) 186 | model_name = m.group(1) 187 | files = ["checkpoint"] 188 | files.extend([x for x in os.listdir(path) if 189 | re.search(model_name, x) or re.search('events.out', x)]) 190 | return files 191 | 192 | @staticmethod 193 | def save_model_in_drive(model_path): 194 | path = os.path.join(drive_path, model_path) 195 | os.makedirs(path, exist_ok=True) 196 | Shell.remove_all_files(os.path.join(drive_path, model_path)) 197 | pp("Saving model in Google Drive...") 198 | for file in Shell.list_model_file(model_path): 199 | pp("Saving ", file) 200 | shutil.copy2(os.path.join(model_path, file), 201 | os.path.join(drive_path, model_path, file)) 202 | pp("done") 203 | 204 | 205 | config_path = 'config.yml' 206 | Shell.download_file_if_necessary(config_path) 207 | f = open(config_path, 'rt') 208 | push_key = yaml.load(f)['pushbullet']['api_key'] 209 | 210 | pb = Pushbullet(push_key) 211 | 212 | # Note for myself. 213 | # You've summarized Seq2Seq 214 | # at http://d.hatena.ne.jp/higepon/20171210/1512887715. 215 | 216 | # If you see following error, it means your max(len(tweets of training set)) 217 | # < decoder_length. 218 | # This should be a bug somewhere in build_decoder, but couldn't find one yet. 219 | # You can workaround by setting hparams.decoder_length=max len of tweet in 220 | # training set. 
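# (Editor's sketch, hypothetical value.) The workaround would look something
# like the following before building the model; 54 here only mirrors the
# labels shape in the error quoted below:
#   hparams.set_hparam('decoder_length', 54)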
221 | # InvalidArgumentError: logits and labels must have the same first dimension, 222 | # got logits shape [48,50] and labels shape [54] 223 | # [[Node: root/SparseSoftmaxCrossEntropyWithLogits 224 | # /SparseSoftmaxCrossEntropyWithLogits = SparseSoftmaxCrossEntropyWithLogits[ 225 | # T=DT_FLOAT, Tlabels=DT_INT32, 226 | 227 | pp(tf.__version__) 228 | 229 | 230 | def has_gpu0(): 231 | return tf.test.gpu_device_name() == "/device:GPU:0" 232 | 233 | 234 | class ModelDirectory(Enum): 235 | tweet_large = 'model/tweet_large' 236 | tweet_large_rl = 'model/tweet_large_rl' 237 | tweet_large_swapped = 'model/tweet_large_swapped' 238 | tweet_small = 'model/tweet_small' 239 | tweet_small_swapped = 'model/tweet_small_swapped' 240 | tweet_small_rl = 'model/tweet_small_rl' 241 | conversations_small = 'model/conversations_small' 242 | conversations_small_backward = 'model/conversations_small_backward' 243 | conversations_small_rl = 'model/conversations_small_rl' 244 | conversations_large = 'model/conversations_large' 245 | conversations_large_backward = 'model/conversations_large_backward' 246 | conversations_large_rl = 'model/conversations_large_rl' 247 | test_multiple1 = 'model/test_multiple1' 248 | test_multiple2 = 'model/test_multiple2' 249 | test_multiple3 = 'model/test_multiple3' 250 | test_distributed = 'model/test_distributed' 251 | 252 | @staticmethod 253 | def create_all_directories(): 254 | for d in ModelDirectory: 255 | os.makedirs(d.value, exist_ok=True) 256 | 257 | 258 | ModelDirectory.create_all_directories() 259 | 260 | base_hparams = tf.contrib.training.HParams( 261 | machine=current_client_id, 262 | batch_size=3, 263 | num_units=6, 264 | num_layers=2, 265 | vocab_size=9, 266 | embedding_size=8, 267 | learning_rate=0.01, 268 | learning_rate_decay=0.99, 269 | use_attention=False, 270 | encoder_length=5, 271 | decoder_length=5, 272 | max_gradient_norm=5.0, 273 | beam_width=2, 274 | num_train_steps=100, 275 | debug_verbose=False, 276 | model_path='Please override model_directory', 277 | sos_id=0, 278 | eos_id=1, 279 | pad_id=2, 280 | unk_id=3, 281 | sos_token="[SOS]", 282 | eos_token="[EOS]", 283 | pad_token="[PAD]", 284 | unk_token="[UNK]", 285 | ) 286 | 287 | # For debug purpose. 288 | tf.reset_default_graph() 289 | 290 | 291 | class ChatbotModel: 292 | def __init__(self, sess, hparams, model_path, scope='ChatbotModel'): 293 | self.sess = sess 294 | # todo remove 295 | self.hparams = hparams 296 | 297 | # todo 298 | self.model_path = model_path 299 | self.scope = scope 300 | # Sampled replies in previous session, 301 | # this is necessary to back propagation. 302 | self.sampled = tf.placeholder(tf.int32, name="sampled") 303 | 304 | # Used to store previously inferred by beam_search. 
305 | # self.beam_predicted_ids = tf.placeholder(tf.int32, 306 | # 307 | # name="beam_predicted_ids") 308 | self.enc_inputs, self.enc_inputs_lengths, enc_outputs, enc_state, \ 309 | emb_encoder = self._build_encoder( 310 | hparams, scope) 311 | 312 | self.dec_inputs, self.dec_tgt_lengths, self._logits, \ 313 | self.sample_logits, self.sample_replies, \ 314 | self.log_probs_selected, self.infer_logits, self.replies, \ 315 | self.beam_replies = self._build_decoder( 316 | hparams, self.enc_inputs_lengths, emb_encoder, 317 | enc_state, enc_outputs) 318 | 319 | self._probs = tf.nn.softmax(self.infer_logits) 320 | self._log_probs = tf.nn.log_softmax(self.infer_logits) 321 | 322 | self.reward = tf.placeholder(tf.float32, name="reward") 323 | self.tgt_labels, self.global_step, self.loss, self.train_op = \ 324 | self._build_seq2seq_optimizer( 325 | hparams, self._logits) 326 | self.rl_loss, self.rl_train_op = self._build_rl_optimizer(hparams) 327 | 328 | self.train_loss_summary = tf.summary.scalar("loss", self.loss) 329 | self.val_loss_summary = tf.summary.scalar("validation_loss", 330 | self.loss) 331 | self.merged_summary = tf.summary.merge_all() 332 | 333 | # Initialize saver after model created 334 | self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=1) 335 | 336 | def train(self, enc_inputs, enc_inputs_lengths, target_labels, 337 | dec_inputs, dec_target_lengths): 338 | 339 | feed_dict = { 340 | self.enc_inputs: enc_inputs, 341 | self.enc_inputs_lengths: enc_inputs_lengths, 342 | self.tgt_labels: target_labels, 343 | self.dec_inputs: dec_inputs, 344 | self.dec_tgt_lengths: dec_target_lengths, 345 | } 346 | _, global_step, summary = self.sess.run( 347 | [self.train_op, self.global_step, self.train_loss_summary], 348 | feed_dict=feed_dict) 349 | 350 | return global_step, summary 351 | 352 | def infer(self, enc_inputs, enc_inputs_lengths): 353 | infer_feed_dic = { 354 | self.enc_inputs: enc_inputs, 355 | self.enc_inputs_lengths: enc_inputs_lengths, 356 | } 357 | return self.sess.run([self.replies, self.infer_logits], 358 | feed_dict=infer_feed_dic) 359 | 360 | def log_probs(self, enc_inputs, enc_inputs_lengths): 361 | infer_feed_dic = { 362 | self.enc_inputs: enc_inputs, 363 | self.enc_inputs_lengths: enc_inputs_lengths, 364 | } 365 | return self.sess.run([self._log_probs, self._probs], 366 | feed_dict=infer_feed_dic) 367 | 368 | def log_probs_sampled(self, enc_inputs, enc_inputs_lengths, sampled): 369 | infer_feed_dic = { 370 | self.enc_inputs: enc_inputs, 371 | self.enc_inputs_lengths: enc_inputs_lengths, 372 | self.sampled: sampled 373 | } 374 | return self.sess.run( 375 | [self.log_probs_selected, self.sample_logits], 376 | feed_dict=infer_feed_dic) 377 | 378 | def infer_beam_search(self, enc_inputs, enc_inputs_lengths): 379 | """ 380 | :return: (replies: [batch_size, decoder_length, beam_size], 381 | logits: [batch_size, decoder_length, vocab_size])) 382 | """ 383 | infer_feed_dic = { 384 | self.enc_inputs: enc_inputs, 385 | self.enc_inputs_lengths: enc_inputs_lengths, 386 | } 387 | return self.sess.run([self.beam_replies, self.infer_logits, self._probs, 388 | self._log_probs], 389 | feed_dict=infer_feed_dic) 390 | 391 | def sample(self, enc_inputs, enc_inputs_lengths): 392 | infer_feed_dic = { 393 | self.enc_inputs: enc_inputs, 394 | self.enc_inputs_lengths: enc_inputs_lengths, 395 | } 396 | 397 | replies, logits = self.sess.run( 398 | [self.sample_replies, self.sample_logits], 399 | feed_dict=infer_feed_dic) 400 | return replies, logits 401 | 402 | def batch_loss(self, 
enc_inputs, enc_inputs_lengths, tgt_labels, 403 | dec_inputs, dec_tgt_lengths): 404 | feed_dict = { 405 | self.enc_inputs: enc_inputs, 406 | self.enc_inputs_lengths: enc_inputs_lengths, 407 | self.tgt_labels: tgt_labels, 408 | self.dec_inputs: dec_inputs, 409 | self.dec_tgt_lengths: dec_tgt_lengths, 410 | } 411 | return self.sess.run([self.loss, self.val_loss_summary], 412 | feed_dict=feed_dict) 413 | 414 | def seq_len(self, seq): 415 | try: 416 | # length includes the first eos_id. 417 | return seq.index(self.hparams.eos_id) + 1 418 | except ValueError: 419 | return self.hparams.encoder_length 420 | 421 | def train_with_reward(self, enc_inputs, enc_inputs_lengths, sampled, 422 | reward): 423 | feed_dict = { 424 | self.enc_inputs: enc_inputs, 425 | self.enc_inputs_lengths: enc_inputs_lengths, 426 | self.sampled: sampled, 427 | self.reward: reward 428 | } 429 | 430 | _, global_step, loss = self.sess.run( 431 | [self.rl_train_op, self.global_step, self.rl_loss], 432 | feed_dict=feed_dict) 433 | return global_step, loss 434 | 435 | def save(self, model_path=None): 436 | if model_path is None: 437 | model_path = self.model_path 438 | model_dir = "{}/{}".format(model_path, self.scope) 439 | self.saver.save(self.sess, model_dir, global_step=self.global_step) 440 | 441 | def restore(self): 442 | ckpt = tf.train.get_checkpoint_state(self.model_path) 443 | if ckpt: 444 | last_model = ckpt.model_checkpoint_path 445 | self.saver.restore(self.sess, last_model) 446 | return True 447 | else: 448 | pp("Created fresh model.") 449 | return False 450 | 451 | @staticmethod 452 | def _softmax(x): 453 | return np.exp(x) / np.sum(np.exp(x), axis=0) 454 | 455 | def _build_rl_optimizer(self, hparams): 456 | # todo mask the sampling results 457 | sample_log_prob_shape = tf.shape(self.log_probs_selected) 458 | reward_shape = tf.shape(self.reward) 459 | reward_shape_print = tf.Print(reward_shape, 460 | [reward_shape], 461 | message="reward_shape") 462 | reward_print = tf.Print(self.reward, 463 | [self.reward], 464 | message="reward") 465 | 466 | asserts = [tf.assert_equal(sample_log_prob_shape[0], 467 | reward_shape_print[0], 468 | [self.log_probs_selected, 469 | self.reward]), 470 | tf.assert_equal(sample_log_prob_shape[1], 471 | reward_shape_print[1], 472 | [self.log_probs_selected, 473 | self.reward]), reward_print 474 | ] 475 | with tf.control_dependencies(asserts): 476 | loss = -tf.reduce_sum( 477 | self.log_probs_selected * self.reward) / tf.to_float( 478 | hparams.batch_size) 479 | train_op = self._build_optimizer_with_loss(self.global_step, hparams, 480 | loss) 481 | return loss, train_op 482 | 483 | def _build_optimizer_with_loss(self, global_step, hparams, loss): 484 | params = tf.trainable_variables() 485 | optimizer = tf.train.GradientDescentOptimizer(hparams.learning_rate) 486 | gradients = tf.gradients(loss, params) 487 | clipped_gradients, _ = tf.clip_by_global_norm( 488 | gradients, hparams.max_gradient_norm) 489 | with tf.device(self.available_device()): 490 | train_op = optimizer.apply_gradients( 491 | zip(clipped_gradients, params), global_step=global_step) 492 | return train_op 493 | 494 | def _build_seq2seq_optimizer(self, hparams, logits): 495 | # Target labels 496 | # As described in doc for sparse_softmax_cross_entropy_with_logits, 497 | # labels should be [batch_size, decoder_target_lengths] 498 | # instead of [batch_size, decoder_target_lengths, vocab_size]. 499 | # So labels should have indices instead of vocab_size classes. 
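# (Editor's sketch, hypothetical ids.) With batch_size=2 and decoder_length=3,
# tgt_labels holds integer word ids such as
#   [[4, 1, 2],
#    [7, 9, 1]]   # shape [batch_size, decoder_length]
# sparse_softmax_cross_entropy_with_logits takes these integer class ids
# directly; no one-hot rows of length vocab_size are needed.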
500 | tgt_labels = tf.placeholder(tf.int32, shape=( 501 | hparams.batch_size, hparams.decoder_length), name="tgt_labels") 502 | # Loss 503 | # tgt_labels: [batch_size, decoder_length] 504 | # _logits: [batch_size, decoder_length, vocab_size] 505 | # crossent: [batch_size, decoder_length] 506 | crossent = tf.nn.sparse_softmax_cross_entropy_with_logits( 507 | labels=tgt_labels, logits=logits) 508 | tgt_weights = tf.sequence_mask(self.dec_tgt_lengths, 509 | hparams.decoder_length, 510 | dtype=logits.dtype) 511 | crossent = crossent * tgt_weights 512 | crossent_by_batch = tf.reduce_sum(crossent, axis=1) 513 | loss = tf.reduce_sum(crossent_by_batch) / tf.to_float( 514 | hparams.batch_size) 515 | # Train 516 | global_step = tf.get_variable(name="global_step", shape=[], 517 | dtype=tf.int32, 518 | initializer=tf.constant_initializer(0), 519 | trainable=False) 520 | train_op = self._build_optimizer_with_loss(global_step, hparams, loss) 521 | return tgt_labels, global_step, loss, train_op 522 | 523 | @staticmethod 524 | def available_device(): 525 | device = '/cpu:0' 526 | if has_gpu0(): 527 | device = '/gpu:0' 528 | pp("$$$ GPU ENABLED $$$") 529 | return device 530 | 531 | @staticmethod 532 | def _build_encoder(hparams, scope): 533 | # Encoder 534 | # enc_inputs: [encoder_length, batch_size] 535 | # This is time major where encoder_length comes 536 | # first instead of batch_size. 537 | # enc_inputs_lengths: [batch_size] 538 | enc_inputs = tf.placeholder(tf.int32, shape=( 539 | hparams.encoder_length, hparams.batch_size), name="enc_inputs") 540 | enc_inputs_lengths = tf.placeholder(tf.int32, 541 | shape=hparams.batch_size, 542 | name="enc_inputs_lengths") 543 | 544 | # Embedding 545 | # We originally didn't share embedding between encoder and decoder. 546 | # But now we share it. It makes much easier to calculate rewards. 547 | # Matrix for embedding: [vocab_size, embedding_size] 548 | # Should be shared between training and inference. 549 | with tf.variable_scope(scope): 550 | emb_encoder = tf.get_variable("emb_encoder", 551 | [hparams.vocab_size, 552 | hparams.embedding_size]) 553 | 554 | # Look up embedding: 555 | # enc_inputs: [encoder_length, batch_size] 556 | # enc_emb_inputs: [encoder_length, batch_size, embedding_size] 557 | enc_emb_inputs = tf.nn.embedding_lookup(emb_encoder, enc_inputs) 558 | 559 | # LSTM cell. 560 | with tf.variable_scope(scope): 561 | # Should be shared between training and inference. 562 | cells = [] 563 | for _ in range(hparams.num_layers): 564 | cells.append( 565 | tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)) 566 | encoder_cell = tf.contrib.rnn.MultiRNNCell(cells) 567 | 568 | # Run Dynamic RNN 569 | # enc_outputs: [encoder_length, batch_size, num_units] 570 | # enc_state: [batch_size, num_units], 571 | # this is final state of the cell for each batch. 
572 | enc_outputs, enc_state = tf.nn.dynamic_rnn(encoder_cell, 573 | enc_emb_inputs, 574 | time_major=True, 575 | dtype=tf.float32, 576 | sequence_length=enc_inputs_lengths) 577 | 578 | return enc_inputs, enc_inputs_lengths, enc_outputs, enc_state, \ 579 | emb_encoder 580 | 581 | @staticmethod 582 | def _build_training_decoder(hparams, enc_inputs_lengths, 583 | enc_state, enc_outputs, dec_cell, 584 | dec_emb_inputs, dec_tgt_lengths, 585 | projection_layer, scope): 586 | 587 | dynamic_batch_size = tf.shape(enc_inputs_lengths)[0] 588 | initial_state, wrapped_dec_cell = ChatbotModel._attention_wrapper( 589 | dec_cell, dynamic_batch_size, enc_inputs_lengths, enc_outputs, 590 | enc_state, hparams, scope, reuse=False) 591 | 592 | # Decoder with helper: 593 | # dec_emb_inputs: [decoder_length, batch_size, embedding_size] 594 | # dec_tgt_lengths: [batch_size] vector, 595 | # which represents each target sequence length. 596 | with tf.variable_scope(scope): 597 | training_helper = tf.contrib.seq2seq.TrainingHelper( 598 | dec_emb_inputs, 599 | dec_tgt_lengths, 600 | time_major=True) 601 | 602 | # Decoder and decode 603 | with tf.variable_scope(scope): 604 | with tf.variable_scope("training"): 605 | training_decoder = tf.contrib.seq2seq.BasicDecoder( 606 | wrapped_dec_cell, training_helper, initial_state, 607 | output_layer=projection_layer) 608 | 609 | # Dynamic decoding 610 | # final_outputs.rnn_output: [batch_size, decoder_length, 611 | # vocab_size], list of RNN state. 612 | # final_outputs.sample_id: [batch_size, decoder_length], 613 | # list of argmax of rnn_output. 614 | # final_state: [batch_size, num_units], 615 | # list of final state of RNN on decode process. 616 | # final_sequence_lengths: [batch_size], list of each decoded sequence. 617 | with tf.variable_scope(scope): 618 | final_outputs, _final_state, _final_sequence_lengths = \ 619 | tf.contrib.seq2seq.dynamic_decode( 620 | training_decoder) 621 | 622 | if hparams.debug_verbose: 623 | pp("rnn_output.shape=", final_outputs.rnn_output.shape) 624 | pp("sample_id.shape=", final_outputs.sample_id.shape) 625 | pp("final_state=", _final_state) 626 | pp("final_sequence_lengths.shape=", 627 | _final_sequence_lengths.shape) 628 | 629 | logits = final_outputs.rnn_output 630 | return logits, wrapped_dec_cell, initial_state 631 | 632 | def _build_decoder(self, hparams, enc_inputs_lengths, embedding_encoder, 633 | enc_state, enc_outputs): 634 | # Decoder input 635 | # dec_inputs: [decoder_length, batch_size] 636 | # dec_tgt_lengths: [batch_size] 637 | # This is grand truth target inputs for training. 638 | dec_inputs = tf.placeholder(tf.int32, shape=( 639 | hparams.decoder_length, hparams.batch_size), name="dec_inputs") 640 | dec_tgt_lengths = tf.placeholder(tf.int32, 641 | shape=hparams.batch_size, 642 | name="dec_tgt_lengths") 643 | 644 | # Look up embedding: 645 | # dec_inputs: [decoder_length, batch_size] 646 | # decoder_emb_inp: [decoder_length, batch_size, embedding_size] 647 | dec_emb_inputs = tf.nn.embedding_lookup(embedding_encoder, 648 | dec_inputs) 649 | 650 | # https://stackoverflow.com/questions/39573188/output-projection-in 651 | # -seq2seq-model-tensorflow 652 | # Internally, a neural network operates on dense vectors of some size, 653 | # often 256, 512 or 1024 floats (let's say 512 for here). 654 | # But at the end it needs to predict a word 655 | # from the vocabulary which is often much larger, 656 | # e.g., 40000 words. 
Output projection is the final linear layer 657 | # that converts (projects) from the internal representation 658 | # to the larger one. 659 | # So, for example, it can consist of a 512 x 40000 parameter matrix 660 | # and a 40000-parameter bias vector. 661 | projection_layer = layers_core.Dense(hparams.vocab_size, use_bias=False) 662 | 663 | # We share this between training and inference. 664 | cells = [] 665 | for _ in range(hparams.num_layers): 666 | cells.append(tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)) 667 | dec_cell = tf.contrib.rnn.MultiRNNCell(cells) 668 | 669 | # Training graph 670 | logits, wrapped_dec_cell, initial_state = self._build_training_decoder( 671 | hparams, enc_inputs_lengths, enc_state, enc_outputs, 672 | dec_cell, dec_emb_inputs, dec_tgt_lengths, 673 | projection_layer, self.scope) 674 | 675 | infer_logits, replies = self._build_greedy_inference(hparams, 676 | embedding_encoder, 677 | enc_state, 678 | enc_inputs_lengths, 679 | enc_outputs, 680 | dec_cell, 681 | projection_layer, 682 | self.scope) 683 | 684 | # Beam Search Inference graph 685 | beam_replies = self._build_beam_search_inference(hparams, 686 | enc_inputs_lengths, 687 | embedding_encoder, 688 | enc_state, 689 | enc_outputs, 690 | dec_cell, 691 | projection_layer, 692 | self.scope) 693 | 694 | # beam_log_probs = self._log_probs_beam(infer_logits, 695 | # self.beam_predicted_ids) 696 | 697 | # Sample Inference graph 698 | _, sample_replies = self._build_sample_inference(hparams, 699 | embedding_encoder, 700 | enc_state, 701 | enc_inputs_lengths, 702 | enc_outputs, 703 | dec_cell, 704 | projection_layer, 705 | self.scope) 706 | 707 | # Here we use infer_logits which is generated from argmax. 708 | # We don't use sample_logits for RL, because infer_logits and 709 | # sample_logits are different. 710 | # And eventually infer_logits should become our desired inference 711 | # with our RL training.
712 | logits_print = tf.Print(infer_logits, [infer_logits], 713 | message="infer_logits") 714 | indices = self._convert_indices(self.sampled) 715 | log_probs = tf.nn.log_softmax(logits_print) 716 | log_probs_selected = tf.gather_nd(log_probs, indices) 717 | return dec_inputs, dec_tgt_lengths, logits, logits_print, \ 718 | sample_replies, log_probs_selected, infer_logits, \ 719 | replies, beam_replies 720 | 721 | @staticmethod 722 | def _build_greedy_inference(hparams, embedding_encoder, enc_state, 723 | encoder_inputs_lengths, encoder_outputs, 724 | dec_cell, projection_layer, scope): 725 | # Greedy decoder 726 | dynamic_batch_size = tf.shape(encoder_inputs_lengths)[0] 727 | inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper( 728 | embedding_encoder, 729 | tf.fill([dynamic_batch_size], hparams.sos_id), hparams.eos_id) 730 | 731 | infer_logits, replies = ChatbotModel._dynamic_decode(dec_cell, 732 | dynamic_batch_size, 733 | encoder_inputs_lengths, 734 | encoder_outputs, 735 | enc_state, 736 | hparams, 737 | inference_helper, 738 | projection_layer, 739 | scope) 740 | return infer_logits, replies 741 | 742 | @staticmethod 743 | def _build_beam_search_inference(hparams, encoder_inputs_lengths, 744 | embedding_encoder, enc_state, 745 | encoder_outputs, dec_cell, 746 | projection_layer, scope): 747 | 748 | assert (hparams.beam_width != 0) 749 | 750 | dynamic_batch_size = tf.shape(encoder_inputs_lengths)[0] 751 | # https://github.com/tensorflow/tensorflow/issues/11904 752 | if hparams.use_attention: 753 | with tf.variable_scope(scope, reuse=True): 754 | # Attention 755 | # encoder_outputs is time major, so transopse it to batch major. 756 | # attention_encoder_outputs: [batch_size, encoder_length, 757 | # num_units] 758 | attention_encoder_outputs = tf.transpose(encoder_outputs, 759 | [1, 0, 2]) 760 | 761 | tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( 762 | attention_encoder_outputs, multiplier=hparams.beam_width) 763 | tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( 764 | enc_state, multiplier=hparams.beam_width) 765 | tiled_encoder_inputs_lengths = tf.contrib.seq2seq.tile_batch( 766 | encoder_inputs_lengths, multiplier=hparams.beam_width) 767 | 768 | # Create an attention mechanism 769 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 770 | hparams.num_units, tiled_encoder_outputs, 771 | memory_sequence_length=tiled_encoder_inputs_lengths) 772 | 773 | wrapped_de_cell = tf.contrib.seq2seq.AttentionWrapper( 774 | dec_cell, attention_mechanism, 775 | attention_layer_size=hparams.num_units) 776 | 777 | dec_initial_state = wrapped_de_cell.zero_state( 778 | dtype=tf.float32, 779 | batch_size=dynamic_batch_size * hparams.beam_width) 780 | dec_initial_state = dec_initial_state.clone( 781 | cell_state=tiled_encoder_final_state) 782 | else: 783 | with tf.variable_scope(scope, reuse=True): 784 | wrapped_de_cell = dec_cell 785 | dec_initial_state = tf.contrib.seq2seq.tile_batch( 786 | enc_state, 787 | multiplier=hparams.beam_width) 788 | 789 | # len(inferred_reply) is lte encoder_length, 790 | # because we are targeting tweet (140 for each tweet) 791 | # Also by doing this, 792 | # we can pass the reply to other seq2seq w/o shorten it. 
793 | maximum_iterations = hparams.encoder_length 794 | 795 | inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder( 796 | cell=wrapped_de_cell, 797 | embedding=embedding_encoder, 798 | start_tokens=tf.fill([dynamic_batch_size], hparams.sos_id), 799 | end_token=hparams.eos_id, 800 | initial_state=dec_initial_state, 801 | beam_width=hparams.beam_width, 802 | output_layer=projection_layer, 803 | length_penalty_weight=0.0) 804 | 805 | # Dynamic decoding 806 | with tf.variable_scope(scope, reuse=True): 807 | beam_outputs, final_state, _ = tf.contrib.seq2seq.dynamic_decode( 808 | inference_decoder, maximum_iterations=maximum_iterations) 809 | beam_replies = beam_outputs.predicted_ids 810 | return beam_replies 811 | 812 | @staticmethod 813 | def _build_sample_inference(hparams, embedding_encoder, enc_state, 814 | enc_inputs_lengths, enc_outputs, 815 | dec_cell, projection_layer, scope): 816 | # Sample decoder 817 | dynamic_batch_size = tf.shape(enc_inputs_lengths)[0] 818 | inference_helper = tf.contrib.seq2seq.SampleEmbeddingHelper( 819 | embedding_encoder, 820 | tf.fill([dynamic_batch_size], hparams.sos_id), hparams.eos_id, 821 | softmax_temperature=0.1) # 1.0 is default 822 | 823 | infer_logits, replies = ChatbotModel._dynamic_decode(dec_cell, 824 | dynamic_batch_size, 825 | enc_inputs_lengths, 826 | enc_outputs, 827 | enc_state, 828 | hparams, 829 | inference_helper, 830 | projection_layer, 831 | scope) 832 | return infer_logits, replies 833 | 834 | @staticmethod 835 | def _dynamic_decode(dec_cell, dynamic_batch_size, 836 | enc_inputs_lengths, enc_outputs, enc_state, 837 | hparams, dec_helper, projection_layer, scope): 838 | initial_state, wrapped_dec_cell = ChatbotModel._attention_wrapper( 839 | dec_cell, dynamic_batch_size, enc_inputs_lengths, enc_outputs, 840 | enc_state, hparams, scope) 841 | with tf.variable_scope(scope): 842 | with tf.variable_scope("infer"): 843 | inference_decoder = tf.contrib.seq2seq.BasicDecoder( 844 | wrapped_dec_cell, dec_helper, initial_state, 845 | output_layer=projection_layer) 846 | # len(inferred_reply) is lte encoder_length, 847 | # because we are targeting tweet (140 for each tweet) 848 | # Also by doing this, 849 | # we can pass the reply to other seq2seq w/o shorten it. 850 | maximum_iterations = hparams.encoder_length 851 | # Dynamic decoding 852 | # Here we reuse Attention Wrapper 853 | with tf.variable_scope(scope, reuse=True): 854 | outputs, _, _ = tf.contrib.seq2seq.dynamic_decode( 855 | inference_decoder, maximum_iterations=maximum_iterations) 856 | replies = outputs.sample_id 857 | # We use infer_logits instead of _logits when calculating log_prob, 858 | # because infer_logits doesn't require decoder_target_lengths input. 859 | infer_logits = outputs.rnn_output 860 | return infer_logits, replies 861 | 862 | @staticmethod 863 | def _attention_wrapper(dec_cell, dynamic_batch_size, enc_inputs_lengths, 864 | enc_outputs, enc_state, hparams, scope, reuse=True): 865 | # See https://github.com/tensorflow/tensorflow/issues/11904 866 | if hparams.use_attention: 867 | with tf.variable_scope(scope, reuse=reuse): 868 | # Attention 869 | # encoder_outputs is time major, so transopse it to batch major. 
870 | # attention_encoder_outputs: [batch_size, encoder_length, 871 | # num_units] 872 | attention_encoder_outputs = tf.transpose(enc_outputs, 873 | [1, 0, 2]) 874 | 875 | # Create an attention mechanism 876 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 877 | hparams.num_units, 878 | attention_encoder_outputs, 879 | memory_sequence_length=enc_inputs_lengths) 880 | 881 | wrapped_dec_cell = tf.contrib.seq2seq.AttentionWrapper( 882 | dec_cell, attention_mechanism, 883 | attention_layer_size=hparams.num_units) 884 | 885 | initial_state = wrapped_dec_cell.zero_state( 886 | dynamic_batch_size, 887 | tf.float32).clone( 888 | cell_state=enc_state) 889 | else: 890 | with tf.variable_scope(scope, reuse=reuse): 891 | wrapped_dec_cell = dec_cell 892 | initial_state = enc_state 893 | return initial_state, wrapped_dec_cell 894 | 895 | # convert sampled_indices to indices for tf.gather_nd. 896 | @staticmethod 897 | def _convert_indices(sampled_indices): 898 | print_sampled_indices = tf.Print(sampled_indices, 899 | [tf.shape(sampled_indices)], 900 | message="sampled_indices") 901 | batch_size = tf.shape(print_sampled_indices)[0] 902 | dec_length = tf.shape(print_sampled_indices)[1] 903 | print_batch_size = tf.Print(batch_size, [batch_size, dec_length], 904 | message="(batch_size, dec_length)") 905 | first_indices = tf.tile( 906 | tf.expand_dims(tf.range(print_batch_size), dim=1), 907 | [1, dec_length]) 908 | second_indices = tf.reshape( 909 | tf.tile(tf.range(dec_length), [print_batch_size]), 910 | [print_batch_size, dec_length]) 911 | print_first_indices = tf.Print(first_indices, [tf.shape(first_indices), 912 | tf.shape( 913 | second_indices)], 914 | message="(first_indices, " 915 | "second_indices)") 916 | return tf.stack([print_first_indices, second_indices, sampled_indices], 917 | axis=2) 918 | 919 | 920 | class TrainDataSource: 921 | def __init__(self, source_path, hparams, vocab_path=None): 922 | Shell.download_file_if_necessary(source_path) 923 | generator = TrainDataGenerator(source_path=source_path, 924 | hparams=hparams) 925 | # generator.remove_generated() 926 | train_dataset, vocab, rev_vocab = generator.generate(vocab_path) 927 | # We don't use shuffle here, because we want to align two data source 928 | # here. 929 | self.train_dataset = train_dataset.repeat() 930 | self.vocab_path = generator.vocab_path 931 | # todo(higepon): Use actual validation dataset. 
932 | self.valid_dataset = train_dataset.repeat() 933 | self.vocab = vocab 934 | self.rev_vocab = rev_vocab 935 | 936 | 937 | class Trainer: 938 | def __init__(self): 939 | self.loss_step = [] 940 | self.val_losses = [] 941 | self.reward_step = [] 942 | self.reward_average = [] 943 | self.last_saved_time = datetime.datetime.now() 944 | self.last_stats_time = datetime.datetime.now() 945 | self.num_stats_per = 20 946 | self._valid_tweets = ["おはようございます。寒いですね。", "さて帰ろう。明日は早い。", "今回もよろしくです。"] 947 | 948 | def train_rl(self, rl_hparams, seq2seq_hparams, backward_hparams, 949 | seq2seq_source_path, rl_source_path, tweets=None, 950 | should_clean_saved_model=False): 951 | if tweets is None: 952 | tweets = [] 953 | pp("===== Train RL {} ====".format(seq2seq_source_path)) 954 | now = datetime.datetime.today().strftime("%Y%m%d%H%M%S") 955 | pp("{}_rl_test".format(now)) 956 | pp("rl_hparams") 957 | print_hparams(rl_hparams) 958 | 959 | seq2seq_data_source = TrainDataSource(seq2seq_source_path, rl_hparams) 960 | rl_data_source = TrainDataSource(rl_source_path, rl_hparams) 961 | easy_tf_log.set_dir(rl_hparams.model_path) 962 | Shell.download_model_data_if_necessary(rl_hparams.model_path) 963 | device = self._available_device() 964 | 965 | if should_clean_saved_model: 966 | clean_model_path(rl_hparams.model_path) 967 | with tf.device(device): 968 | rl_model = self.create_model(rl_hparams) 969 | seq2seq_model = self.create_model(seq2seq_hparams) 970 | backward_model = self.create_model(backward_hparams) 971 | 972 | vocab = seq2seq_data_source.vocab 973 | rev_vocab = seq2seq_data_source.rev_vocab 974 | infer_helper_rl = InferenceHelper(rl_model, vocab, rev_vocab) 975 | 976 | graph = rl_model.sess.graph 977 | _ = tf.summary.FileWriter(rl_hparams.model_path, graph) 978 | last_saved_time = datetime.datetime.now() 979 | 980 | with graph.as_default(): 981 | seq2seq_train_data_next = \ 982 | seq2seq_data_source.train_dataset.make_one_shot_iterator( 983 | 984 | ).get_next() 985 | rl_train_data_next = \ 986 | rl_data_source.train_dataset.make_one_shot_iterator().get_next() 987 | 988 | global_step = None 989 | for step in range(rl_hparams.num_train_steps): 990 | with delta("one_train_step", global_step) as _: 991 | with delta("data_fetch_time", global_step) as _: 992 | seq2seq_train_data = rl_model.sess.run( 993 | seq2seq_train_data_next) 994 | rl_train_data = rl_model.sess.run(rl_train_data_next) 995 | 996 | batch_size = rl_hparams.batch_size 997 | 998 | # Sample! 999 | with delta("sample_time", global_step) as _: 1000 | samples, _ = rl_model.sample(seq2seq_train_data[0], 1001 | seq2seq_train_data[1]) 1002 | 1003 | # Calc 1/N_a * logP_seq2seq(a|p_i, q_i) for each sampled. 1004 | with delta("calc_reward_s", global_step) as _: 1005 | reward_s = self.calc_reward_s(seq2seq_model, 1006 | seq2seq_train_data, 1007 | samples) 1008 | 1009 | # Calc 1/N_qi * logP_backward(qi|a) 1010 | # TODO: Vectorized implementation here. 
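# (Editor's note, paraphrasing the calc_reward_* helpers used below.) The
# per-sample reward combined a few lines down is, roughly,
#   reward = reward_s + reward_qi
#          = (1/N_a)  * log P_seq2seq(a | p_i, q_i)
#            + (1/N_qi) * log P_backward(q_i | a)
# where a is the sampled reply and both terms are length-normalized
# log-likelihoods; the sum is divided by its standard deviation (but not
# mean-shifted) before being fed to train_with_reward.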
1011 | with delta("calc_reward_qi", global_step) as _: 1012 | reward_qi = self.calc_reward_qi(backward_model, 1013 | rl_train_data, samples) 1014 | 1015 | reward = reward_s + reward_qi 1016 | max_len = len(samples[0]) 1017 | reward_avg = np.sum(reward) / max_len / batch_size 1018 | 1019 | # standardize reward 1020 | # don't shift mean (by RL tips) 1021 | # reward -= np.mean(reward) 1022 | reward /= (np.std(reward)) 1023 | 1024 | rl_hparams = rl_model.hparams 1025 | with delta("train_with_reward", global_step) as _: 1026 | global_step, loss = rl_model.train_with_reward( 1027 | seq2seq_train_data[0], 1028 | seq2seq_train_data[1], 1029 | samples, 1030 | reward) 1031 | self._print_log("rl_loss", loss, global_step) 1032 | 1033 | self._print_log("reward_avg", reward_avg, step=global_step) 1034 | if global_step is not None and global_step % 20 == 0: 1035 | # This takes about 100sec. 1036 | with delta("calc_entropy", global_step) as _: 1037 | self._print_log("entropy", 1038 | self.calc_policy_entropy(infer_helper_rl), 1039 | global_step) 1040 | 1041 | validation_tweets = [ 1042 | "危うく子供を引きかけた……駐車場でバックしようとしてたら子供が走って来てた:(", 1043 | "鏡に写る自分の顔を見て思ったヤバい、痩せすぎて頰が…そこで一大決心!今夜からちゃんと食べる", 1044 | "エスカレーター乗る位置で関西帰ってきたな〜〜って実感します🤔"] 1045 | with delta("valid_infer", global_step) as _: 1046 | for t in validation_tweets: 1047 | infer_helper_rl.print_inferences(t) 1048 | 1049 | # greedy results from RL rl_model 1050 | with delta("rl_infer", global_step) as _: 1051 | replies, _ = rl_model.infer(seq2seq_train_data[0], 1052 | seq2seq_train_data[1]) 1053 | 1054 | # This is for debug to see if probability of RL looks 1055 | # reasonable. 1056 | with delta("calc_rl_reward", global_step) as _: 1057 | reward_s_rl = self.calc_reward_s( 1058 | rl_model, 1059 | seq2seq_train_data, 1060 | replies) 1061 | 1062 | with delta("calc_rl_reward_qi", global_step) as _: 1063 | reward_qi_rl = self.calc_reward_qi(backward_model, 1064 | rl_train_data, 1065 | replies) 1066 | 1067 | with delta("seq2seq_infer", global_step) as _: 1068 | seq2seq_replies, _ = seq2seq_model.infer( 1069 | seq2seq_train_data[0], 1070 | seq2seq_train_data[1]) 1071 | 1072 | # This is for debug to see if reward_s looks reasonable. 
1073 | with delta("calc_seq2seq_reward_s", global_step) as _: 1074 | reward_s_seq2seq = self.calc_reward_s( 1075 | seq2seq_model, 1076 | seq2seq_train_data, 1077 | seq2seq_replies) 1078 | with delta("calc_seq2seq_reward_qi", global_step) as _: 1079 | reward_qi_seq2seq = self.calc_reward_qi(backward_model, 1080 | rl_train_data, 1081 | seq2seq_replies) 1082 | 1083 | with delta("debug_print", global_step) as _: 1084 | for batch in range(2): 1085 | pp( 1086 | infer_helper_rl.ids_to_string( 1087 | seq2seq_train_data[0][:, batch])) 1088 | pp( 1089 | " [seq2] : {} {:.2f} => ({:.2f}) <= {" 1090 | ":.2f}".format( 1091 | infer_helper_rl.ids_to_string( 1092 | seq2seq_replies[batch]), 1093 | reward_s_seq2seq[batch][0].item(), 1094 | reward_s_seq2seq[batch][0].item() + 1095 | reward_qi_seq2seq[batch][ 1096 | 0].item(), 1097 | reward_qi_seq2seq[batch][0].item())) 1098 | pp( 1099 | " [RL greedy] : {} {:.2f} => ({:.2f}) <= {" 1100 | ":.2f}".format( 1101 | infer_helper_rl.ids_to_string(replies[batch]), 1102 | reward_s_rl[batch][0].item(), 1103 | reward_s_rl[batch][0].item() + 1104 | reward_qi_rl[batch][0].item(), 1105 | reward_qi_rl[batch][0].item())) 1106 | pp( 1107 | " [RL sample]: {} {:.2f} => ({:.2f}) <= {" 1108 | ":.2f}".format( 1109 | infer_helper_rl.ids_to_string(samples[batch]), 1110 | reward_s[batch][0].item(), 1111 | reward_s[batch][0].item() + reward_qi[batch][ 1112 | 0].item(), 1113 | reward_qi[batch][0].item())) 1114 | 1115 | if is_local() or (step != 0 and step % 50 == 0): 1116 | pp("save and restore") 1117 | with delta("save", global_step) as _: 1118 | rl_model.save() 1119 | with delta("restore", global_step) as _: 1120 | is_restored = rl_model.restore() 1121 | assert is_restored 1122 | with delta("save_drive", global_step) as _: 1123 | self._save_model_in_drive(rl_hparams) 1124 | 1125 | if (step != 0 and step % 100 == 0): 1126 | with delta("print_inferences", global_step) as _: 1127 | self._print_inferences(step, tweets, infer_helper_rl) 1128 | 1129 | # 1130 | # Calculate action entropy. 1131 | # 1132 | # In general entropy is defined as -E[logP(X)] which is 1133 | # -Sum(p(X)logP(X)). But we can't calculate it, because we can't 1134 | # enumerate all the 1135 | # possible actions (= replies). Because (A) it's gonna be dec_len^( 1136 | # vocab_size) pattern. (B) We can't list all the possible input to the 1137 | # model. Here we calculate the entropy by limiting target beam_width = 3 and 1138 | # limiting # of input to 1. 
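# (Editor's sketch, hypothetical numbers.) For a single decode step with
# vocab_size=3 and beam_width=3, the approximation below reduces to:
#   import numpy as np
#   probs = np.array([0.5, 0.3, 0.2])       # p(w) at this step
#   log_probs = np.log(probs)
#   beam_word_ids = [0, 1, 2]               # the top-3 beam candidates
#   entropy = -sum(probs[w] * log_probs[w] for w in beam_word_ids)
#   # ~1.03 nats; the exact -E[logP(X)] over all replies stays intractable.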
1139 | @staticmethod 1140 | def calc_policy_entropy(infer_helper): 1141 | tweets = ["おはようございます。今日も暑いですね", 1142 | "鏡に写る自分の顔を見て思ったヤバい、痩せすぎて頰が…そこで一大決心!今夜からちゃんと食べる", 1143 | "同じく寒かったので*はだいぶ楽になりました🙇💦᷆"] 1144 | entropy = 0.0 1145 | for tweet in tweets: 1146 | encoder_inputs, encoder_inputs_lengths = \ 1147 | infer_helper.create_inference_input( 1148 | tweet) 1149 | beam_replies, _, probs, log_probs = \ 1150 | infer_helper.model.infer_beam_search( 1151 | encoder_inputs, encoder_inputs_lengths) 1152 | 1153 | # [dec_len, vocab_size] 1154 | log_prob = log_probs[0] 1155 | prob = probs[0] 1156 | # [dec_len, beam_size] 1157 | replies = beam_replies[0] 1158 | 1159 | for i in range(infer_helper.model.hparams.beam_width): 1160 | reply = replies[:, i] 1161 | for idx, word_id in enumerate(reply): 1162 | if idx < len(log_prob): 1163 | entropy = entropy + log_prob[idx][word_id] * prob[idx][ 1164 | word_id] 1165 | return -entropy 1166 | 1167 | def calc_reward_qi(self, backward_model, train_data, samples): 1168 | hparams = backward_model.hparams 1169 | batch_size = hparams.batch_size 1170 | max_len = len(samples[0]) 1171 | pp("reward_qi size=", batch_size, max_len) 1172 | reward_qi = np.zeros((batch_size, max_len)) 1173 | # target label with eos. 1174 | # [batch_size, dec_length] 1175 | qi = train_data[2] 1176 | a_enc_inputs, a_enc_inputs_lengths = self.format_enc_inputs( 1177 | hparams, backward_model, samples) 1178 | # [batch_size, dec_len, vocab_size] 1179 | log_probs, _ = backward_model.log_probs(a_enc_inputs, 1180 | a_enc_inputs_lengths) 1181 | for batch in range(batch_size): 1182 | tweet = qi[batch] 1183 | tweet_len = 0 1184 | p = 0 1185 | for i, word_id in enumerate(tweet): 1186 | # log_probs shape is supposed to be [batch_size, 1187 | # dec_length, vocab_size], 1188 | # but it sometimes becomes [batch_size, 1189 | # smaller_value, vocab_size]. 1190 | # This is because we're using GreedyDecoder, 1191 | # dynamic_decode finishes the decoder process when it 1192 | # sees eos_id. 1193 | # If all enc_inputs ends up shorter dec_output, 1194 | # we can have smaller_value here. 1195 | if i < len(log_probs[batch]): 1196 | p += log_probs[batch][i][word_id] 1197 | tweet_len = tweet_len + 1 1198 | if word_id == hparams.eos_id: 1199 | break 1200 | assert (tweet_len != 0) 1201 | p /= tweet_len 1202 | # reward is zero, after eos. So that we can ignore them. 
1203 | for i in range(min([tweet_len, max_len])): 1204 | reward_qi[batch][i] = p 1205 | return reward_qi 1206 | 1207 | @staticmethod 1208 | def calc_reward_s(seq2seq_model, train_data, samples): 1209 | max_len = len(samples[0]) 1210 | # [batch_size, dec_len] 1211 | log_probs_sampled, logits1 = seq2seq_model.log_probs_sampled( 1212 | train_data[0], 1213 | train_data[1], 1214 | samples) 1215 | # log_probs_sampled2, logits2 = seq2seq_model.log_probs_sampled( 1216 | # train_data[0], 1217 | # train_data[1], 1218 | # samples) 1219 | # 1220 | # for b in range(seq2seq_model.hparams.batch_size): 1221 | # for i in range(max_len): 1222 | # for v in range(seq2seq_model.hparams.vocab_size): 1223 | # if logits1[b][i][v] != logits2[b][i][v]: 1224 | # pp("Unmatch b={} i={} v={} {} vs {}".format(b, 1225 | # i, v, 1226 | # 1227 | # logits1[ 1228 | # b][ 1229 | # i][ 1230 | # v], 1231 | # 1232 | # logits2[ 1233 | # b][ 1234 | # i][ 1235 | # 1236 | # v])) 1237 | 1238 | # if np.array_equal(log_probs_sampled, log_probs_sampled2): 1239 | # pp("log probs equl") 1240 | # else: 1241 | # pp("noooo") 1242 | 1243 | # [batch_size, dec_len, vocab_size] 1244 | # log_probs, _ = seq2seq_model.log_probs(train_data[0], train_data[1]) 1245 | # for batch in range(seq2seq_model.hparams.batch_size): 1246 | # log_probs_sampled_batch = log_probs_sampled[batch] 1247 | # for i in range(max_len): 1248 | # pp("debugging[{}][{}] {} {} = {}".format(batch, i, 1249 | # 1250 | # log_probs_sampled_batch[ 1251 | # i] == 1252 | # log_probs[ 1253 | # batch][i][ 1254 | # samples[ 1255 | # batch][ 1256 | # i]], 1257 | # 1258 | # log_probs_sampled_batch[ 1259 | # i], 1260 | # log_probs[ 1261 | # batch][i][ 1262 | # samples[ 1263 | # batch][ 1264 | # i]])) 1265 | 1266 | batch_size = seq2seq_model.hparams.batch_size 1267 | reward_s = np.zeros((batch_size, max_len)) 1268 | for batch in range(batch_size): 1269 | tweet = samples[batch] 1270 | tweet_len = 0 1271 | p = 0 1272 | for i in range(len(tweet)): 1273 | p += log_probs_sampled[batch][i] 1274 | tweet_len = tweet_len + 1 1275 | if tweet[i] == seq2seq_model.hparams.eos_id: 1276 | break 1277 | assert (tweet_len != 0) 1278 | p /= tweet_len 1279 | # reward is zero, after eos. So that we can ignore them. 1280 | for i in range(tweet_len): 1281 | reward_s[batch][i] = p 1282 | return reward_s 1283 | 1284 | @staticmethod 1285 | def format_enc_inputs(hparams, model, replies): 1286 | enc_inputs = [] 1287 | enc_inputs_lengths = [] 1288 | 1289 | # replies: [batch_size, dec_length] 1290 | for reply in replies: 1291 | reply_len = model.seq_len(reply.tolist()) 1292 | # Safe guard: sampled reply has sometimes 0 len. 1293 | # adjusted_len = hparams.encoder_length if reply_len 1294 | # == 0 else reply_len 1295 | enc_inputs_lengths.append(reply_len) 1296 | if reply_len <= hparams.encoder_length: 1297 | padded_reply = np.append(reply, ([hparams.pad_id] * ( 1298 | hparams.encoder_length - len(reply)))) 1299 | enc_inputs.append(padded_reply) 1300 | else: 1301 | raise Exception( 1302 | "Inferred" 1303 | " reply shouldn't be longer than encoder_input") 1304 | 1305 | # Expected enc_inputs param is time major. 
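# (Editor's note, hypothetical shapes.) With batch_size=2 and encoder_length=5
# the list built above stacks to shape [2, 5] (one padded reply per row); the
# transpose below turns it into the time-major [5, 2] layout expected by the
# enc_inputs placeholder.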
1306 | enc_inputs = np.transpose(np.array(enc_inputs)) 1307 | return enc_inputs, enc_inputs_lengths 1308 | 1309 | @staticmethod 1310 | def _reward_for_test(model, sampled_replies): 1311 | max_len = len(sampled_replies[0]) 1312 | # default negative reward 1313 | reward = np.ones((model.hparams.batch_size, max_len)) * -1.0 1314 | good_value = 0 1315 | for i, reply in enumerate(sampled_replies): 1316 | reply_len = model.input_length(reply.tolist()) 1317 | if reply_len == 8 or reply_len == 0 or reply_len == 1: 1318 | for r in range(max_len): 1319 | reward[i][r] = -1.0 1320 | else: 1321 | good_value += 1 1322 | for r in range(max_len): 1323 | reward[i][r] = 1.0 1324 | return good_value, reward 1325 | 1326 | def train_seq2seq_swapped(self, hparams, tweets_path, validation_tweets, 1327 | should_clean_saved_model=True, vocab_path=None): 1328 | Shell.download_file_if_necessary(tweets_path) 1329 | swapped_path = TrainDataGenerator.generate_source_target_swapped( 1330 | tweets_path) 1331 | return self.train_seq2seq(hparams, swapped_path, validation_tweets, 1332 | should_clean_saved_model, vocab_path) 1333 | 1334 | def train_seq2seq(self, hparams, tweets_path, val_tweets, 1335 | should_clean_saved_model=True, vocab_path=None): 1336 | pp("===== Train Seq2Seq {} ====".format(tweets_path)) 1337 | print_hparams(hparams) 1338 | 1339 | if should_clean_saved_model: 1340 | clean_model_path(hparams.model_path) 1341 | data_source = TrainDataSource(tweets_path, hparams, vocab_path) 1342 | return self._train_loop(data_source, hparams, val_tweets) 1343 | 1344 | def _print_inferences(self, global_step, tweets, helper, ): 1345 | pp("==== {} ====".format(global_step)) 1346 | len_array = [] 1347 | for tweet in tweets: 1348 | len_array.append(len(helper.inferences(tweet)[0])) 1349 | helper.print_inferences(tweet) 1350 | self._print_log('average reply len', np.mean(len_array)) 1351 | 1352 | @staticmethod 1353 | def create_model(hparams): 1354 | 1355 | # See https://www.tensorflow.org/tutorials/using_gpu 1356 | # #allowing_gpu_memory_growth 1357 | config = tf.ConfigProto(log_device_placement=False) 1358 | config.gpu_options.allow_growth = True 1359 | 1360 | train_graph = tf.Graph() 1361 | train_sess = tf.Session(graph=train_graph, config=config) 1362 | with train_graph.as_default(): 1363 | with tf.variable_scope('root'): 1364 | model = ChatbotModel(train_sess, hparams, 1365 | model_path=hparams.model_path) 1366 | if not model.restore(): 1367 | train_sess.run(tf.global_variables_initializer()) 1368 | 1369 | return model 1370 | 1371 | def _train_loop(self, data_source, 1372 | hparams, tweets): 1373 | Shell.download_model_data_if_necessary(hparams.model_path) 1374 | 1375 | device = self._available_device() 1376 | with tf.device(device): 1377 | model = self.create_model(hparams) 1378 | 1379 | def my_train(**kwargs): 1380 | data = kwargs['train_data'] 1381 | return model.train(data[0], data[1], data[2], data[3], data[4]) 1382 | 1383 | return self._generic_train_loop(data_source, hparams, 1384 | model, 1385 | tweets, my_train) 1386 | 1387 | @staticmethod 1388 | def _available_device(): 1389 | device = '/cpu:0' 1390 | if has_gpu0(): 1391 | device = '/gpu:0' 1392 | pp("$$$ GPU ENABLED $$$") 1393 | return device 1394 | 1395 | @staticmethod 1396 | def tokenize(infer_helper, text): 1397 | tagger = MeCab.Tagger("-Owakati") 1398 | words = tagger.parse(text).split() 1399 | return infer_helper.words_to_ids(words) 1400 | 1401 | def _generic_train_loop(self, data_source, hparams, 1402 | model, 1403 | tweets, train_func): 1404 | try: 1405 
| return self._raw_train_loop(data_source, hparams, model, train_func, 1406 | tweets) 1407 | except KeyboardInterrupt as ke: 1408 | raise ke 1409 | except Exception as e: 1410 | pb.push_note("Train error", str(e)) 1411 | raise e 1412 | 1413 | def _raw_train_loop(self, data_source, hparams, 1414 | model, train_func, 1415 | tweets): 1416 | vocab = data_source.vocab 1417 | rev_vocab = data_source.rev_vocab 1418 | infer_helper = InferenceHelper(model, vocab, rev_vocab) 1419 | graph = model.sess.graph 1420 | with graph.as_default(): 1421 | train_data_next = \ 1422 | data_source.train_dataset.make_one_shot_iterator().get_next() 1423 | val_data_next = data_source.valid_dataset.make_one_shot_iterator( 1424 | 1425 | ).get_next() 1426 | easy_tf_log.set_dir(hparams.model_path) 1427 | writer = tf.summary.FileWriter(hparams.model_path, graph) 1428 | self.last_saved_time = datetime.datetime.now() 1429 | for i in range(hparams.num_train_steps): 1430 | train_data = model.sess.run(train_data_next) 1431 | 1432 | step, summary = train_func( 1433 | train_data=train_data, 1434 | ) 1435 | writer.add_summary(summary, step) 1436 | 1437 | if i != 0 and i % self.num_stats_per == 0: 1438 | model.save(hparams.model_path) 1439 | is_restored = model.restore() 1440 | assert is_restored 1441 | self._print_inferences(step, tweets, infer_helper) 1442 | self._compute_val_loss(step, model, val_data_next, writer) 1443 | # self._print_stats(hparams, 1444 | # learning_rate) 1445 | self._plot_if_necessary() 1446 | self._save_model_in_drive(hparams) 1447 | else: 1448 | print('.', end='') 1449 | return model, infer_helper 1450 | 1451 | def _plot_if_necessary(self): 1452 | if len(self.reward_average) > 0 and len(self.reward_average) % 30 == 0: 1453 | self._plot(self.reward_step, self.reward_average, 1454 | y_label='reward average') 1455 | self._plot(self.loss_step, self.val_losses, 1456 | y_label='validation_loss') 1457 | 1458 | def _print_stats(self, hparams, learning_rate): 1459 | pp("learning rate", learning_rate) 1460 | delta = ( 1461 | datetime.datetime.now() - 1462 | self.last_stats_time).total_seconds() * 1000 1463 | self._print_log("msec/data", 1464 | delta / hparams.batch_size / self.num_stats_per) 1465 | self.last_stats_time = datetime.datetime.now() 1466 | 1467 | def _save_model_in_drive(self, hparams): 1468 | now = datetime.datetime.now() 1469 | delta_in_min = (now - self.last_saved_time).total_seconds() / 60 1470 | 1471 | if delta_in_min >= 60: 1472 | self.last_saved_time = datetime.datetime.now() 1473 | Shell.save_model_in_drive(hparams.model_path) 1474 | 1475 | @staticmethod 1476 | def _log(key, value, step=None): 1477 | tflog("{}[{}]".format(key, current_client_id), value, step) 1478 | 1479 | @staticmethod 1480 | def _print_log(key, value, step=None): 1481 | if step is None: 1482 | return 1483 | tflog("{}_{}".format(key, current_client_id), value, step) 1484 | pp("{}={}".format(key, round(value, 1))) 1485 | 1486 | @staticmethod 1487 | def _plot(x, y, x_label="step", y_label='y'): 1488 | title = "{}_{}".format(current_client_id, y_label) 1489 | plt.plot(x, y, label=title) 1490 | plt.plot() 1491 | plt.ylabel(title) 1492 | plt.xlabel(x_label) 1493 | plt.legend() 1494 | plt.show() 1495 | 1496 | def _compute_val_loss(self, global_step, model, val_data_next, 1497 | writer): 1498 | val_data = model.sess.run(val_data_next) 1499 | val_loss, val_loss_log = model.batch_loss(val_data[0], 1500 | val_data[1], 1501 | val_data[2], 1502 | val_data[3], 1503 | val_data[4]) 1504 | # np.float64 to native float 1505 | val_loss = 
val_loss.item() 1506 | writer.add_summary(val_loss_log, global_step) 1507 | self._print_log("validation loss", val_loss) 1508 | self.loss_step.append(global_step) 1509 | self.val_losses.append(val_loss) 1510 | return val_loss 1511 | 1512 | 1513 | class InferenceHelper: 1514 | def __init__(self, model, vocab, rev_vocab): 1515 | self.model = model 1516 | self.vocab = vocab 1517 | self.rev_vocab = rev_vocab 1518 | 1519 | def inferences(self, tweet): 1520 | encoder_inputs, encoder_inputs_lengths = self.create_inference_input( 1521 | tweet) 1522 | replies, _ = self.model.infer(encoder_inputs, encoder_inputs_lengths) 1523 | ids = replies[0].tolist() 1524 | all_infer = [self.sanitize_text(self.ids_to_words(ids))] 1525 | beam_replies, logits, _, _ = self.model.infer_beam_search( 1526 | encoder_inputs, 1527 | encoder_inputs_lengths) 1528 | beam_infer = [ 1529 | self.sanitize_text(self.ids_to_words(beam_replies[0][:, i])) for i 1530 | in range(self.model.hparams.beam_width)] 1531 | all_infer.extend(beam_infer) 1532 | return all_infer 1533 | 1534 | @staticmethod 1535 | def sanitize_text(line): 1536 | line = re.sub(r"\[EOS\]", " ", line) 1537 | line = re.sub(r"\[UNK\]", "💩", line) 1538 | return line 1539 | 1540 | def print_inferences(self, tweet): 1541 | pp(tweet) 1542 | for i, reply in enumerate(self.inferences(tweet)): 1543 | pp(" [{}]{}".format(i, reply)) 1544 | 1545 | def words_to_ids(self, words): 1546 | ids = [] 1547 | for word in words: 1548 | if word in self.vocab: 1549 | ids.append(self.vocab[word]) 1550 | else: 1551 | ids.append(self.model.hparams.unk_id) 1552 | return ids 1553 | 1554 | def ids_to_string(self, ids): 1555 | return self.sanitize_text(self.ids_to_words(ids)) 1556 | 1557 | def ids_to_words(self, ids): 1558 | words = "" 1559 | for word_id in ids: 1560 | words += self.rev_vocab[word_id] 1561 | return words 1562 | 1563 | def create_inference_input(self, text): 1564 | inference_encoder_inputs = np.empty( 1565 | (self.model.hparams.encoder_length, self.model.hparams.batch_size), 1566 | dtype=np.int) 1567 | inference_encoder_inputs_lengths = np.empty( 1568 | self.model.hparams.batch_size, dtype=np.int) 1569 | text = TrainDataGenerator.sanitize_line(text) 1570 | tagger = MeCab.Tagger("-Owakati") 1571 | words = tagger.parse(text).split() 1572 | ids = self.words_to_ids(words) 1573 | ids = ids[:self.model.hparams.encoder_length] 1574 | len_ids = len(ids) 1575 | ids.extend([self.model.hparams.pad_id] * ( 1576 | self.model.hparams.encoder_length - len(ids))) 1577 | for i in range(self.model.hparams.batch_size): 1578 | inference_encoder_inputs[:, i] = np.array(ids, dtype=np.int) 1579 | inference_encoder_inputs_lengths[i] = len_ids 1580 | return inference_encoder_inputs, inference_encoder_inputs_lengths 1581 | 1582 | 1583 | class ConversationTrainDataGenerator: 1584 | def __init__(self): 1585 | return 1586 | 1587 | # Generate the following file from conversations_txt file. 1588 | # Let p_i: line 3i in the txt file, which is original tweet. 1589 | # q_i: line 3i + 1 in the txt file, which is reply to the tweet. 1590 | # p_i+1: line 3i + 2 in the txt file, which is reply to the reply above. 1591 | # (A) conversation_seq2seq.txt for train p_seq2seq and p_seq2seq_backward 1592 | # line 2i: p_i + q_i 1593 | # line 2i+1: p_i+1 1594 | # 1595 | # (B) conversation_rl.txt for train p_rl. 1596 | # line 2i: p_i + q_i 1597 | # line 2i+1: q_i 1598 | # 1599 | # (A) and (B) should share the vocabulary. 
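# (Editor's sketch, hypothetical three-line conversation.) Given an input file
#   line 0 (p_1): "Good morning"
#   line 1 (q_1): "Morning, cold today"
#   line 2 (p_2): "Yes, stay warm"
# generate() writes to the *_seq2seq file
#   "Good morning Morning, cold today"
#   "Yes, stay warm"
# and to the *_rl file
#   "Good morning Morning, cold today"
#   "Morning, cold today"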
1600 | # noinspection PyUnusedLocal 1601 | def generate(self, conversations_txt): 1602 | basename, extension = os.path.splitext(conversations_txt) 1603 | seq2seq_path = "{}_seq2seq{}".format(basename, extension) 1604 | rl_path = "{}_rl{}".format(basename, extension) 1605 | with open(seq2seq_path, "w") as s_out, open(rl_path, 1606 | "w") as r_out, gfile.GFile( 1607 | conversations_txt, 1608 | mode="rb") as fin: 1609 | tweet = None 1610 | reply = None 1611 | reply2 = None 1612 | for i, line in enumerate(fin): 1613 | line = line.decode('utf-8') 1614 | line = line.rstrip() 1615 | if i % 3 == 0: 1616 | tweet = line 1617 | elif i % 3 == 1: 1618 | reply = line 1619 | else: 1620 | reply2 = line 1621 | self._write(s_out, tweet, reply, reply2) 1622 | self._write(r_out, tweet, reply, reply) 1623 | 1624 | @staticmethod 1625 | def _write(s_out, tweet, reply, reply2): 1626 | s_out.write(tweet) 1627 | s_out.write(' ') 1628 | s_out.write(reply) 1629 | s_out.write('\n') 1630 | s_out.write(reply2) 1631 | s_out.write('\n') 1632 | 1633 | 1634 | class TrainDataGenerator: 1635 | def __init__(self, source_path, hparams): 1636 | self.source_path = source_path 1637 | self.hparams = hparams 1638 | basename, extension = os.path.splitext(self.source_path) 1639 | self.enc_path = "{}_enc{}".format(basename, extension) 1640 | self.dec_path = "{}_dec{}".format(basename, extension) 1641 | self.enc_idx_path = "{}_enc_idx{}".format(basename, extension) 1642 | self.dec_idx_path = "{}_dec_idx{}".format(basename, extension) 1643 | self.dec_idx_eos_path = "{}_dec_idx_eos{}".format(basename, extension) 1644 | self.dec_idx_sos_path = "{}_dec_idx_sos{}".format(basename, extension) 1645 | self.dec_idx_len_path = "{}_dec_idx_len{}".format(basename, extension) 1646 | 1647 | self.enc_idx_padded_path = "{}_enc_idx_padded{}".format(basename, 1648 | extension) 1649 | self.enc_idx_len_path = "{}_enc_idx_len{}".format(basename, extension) 1650 | 1651 | self.vocab_path = "{}_vocab{}".format(basename, extension) 1652 | 1653 | self.generated_files = [self.enc_path, self.dec_path, self.enc_idx_path, 1654 | self.dec_idx_path, self.dec_idx_eos_path, 1655 | self.dec_idx_sos_path, self.dec_idx_len_path, 1656 | self.enc_idx_padded_path, self.vocab_path, 1657 | self.enc_idx_len_path] 1658 | self.max_vocab_size = hparams.vocab_size 1659 | self.start_vocabs = [hparams.sos_token, hparams.eos_token, 1660 | hparams.pad_token, hparams.unk_token] 1661 | self.tagger = MeCab.Tagger("-Owakati") 1662 | 1663 | def remove_generated(self): 1664 | for file in self.generated_files: 1665 | if os.path.exists(file): 1666 | os.remove(file) 1667 | 1668 | def generate(self, vocab_path=None): 1669 | pp("generating enc and dec files...") 1670 | self._generate_enc_dec() 1671 | pp("generating vocab file...") 1672 | if vocab_path is None: 1673 | self._generate_vocab() 1674 | else: 1675 | shutil.copyfile(vocab_path, self.vocab_path) 1676 | pp("loading vocab...") 1677 | vocab, _ = self._load_vocab() 1678 | pp("generating id files...") 1679 | self._generate_id_file(self.enc_path, self.enc_idx_path, vocab) 1680 | self._generate_id_file(self.dec_path, self.dec_idx_path, vocab) 1681 | pp("generating padded input file...") 1682 | self._generate_enc_idx_padded(self.enc_idx_path, 1683 | self.enc_idx_padded_path, 1684 | self.enc_idx_len_path, 1685 | self.hparams.encoder_length) 1686 | pp("generating dec eos/sos files...") 1687 | self._generate_dec_idx_eos(self.dec_idx_path, self.dec_idx_eos_path, 1688 | self.hparams.decoder_length) 1689 | self._generate_dec_idx_sos(self.dec_idx_path, 
self.dec_idx_sos_path, 1690 | self.dec_idx_len_path, 1691 | self.hparams.decoder_length) 1692 | pp("done") 1693 | return self._create_dataset() 1694 | 1695 | def _generate_id_file(self, source_path, dest_path, vocab): 1696 | if gfile.Exists(dest_path): 1697 | return 1698 | with gfile.GFile(source_path, mode="rb") as file, gfile.GFile(dest_path, 1699 | mode="wb") as of: 1700 | for line in file: 1701 | line = line.decode('utf-8') 1702 | words = self.tagger.parse(line).split() 1703 | ids = [vocab.get(w, self.hparams.unk_id) for w in words] 1704 | of.write(" ".join([str(word_id) for word_id in ids]) + "\n") 1705 | 1706 | def _load_vocab(self): 1707 | rev_vocab = [] 1708 | with gfile.GFile(self.vocab_path, mode="r") as file: 1709 | rev_vocab.extend(file.readlines()) 1710 | rev_vocab = [line.strip() for line in rev_vocab] 1711 | # Dictionary of (word, idx) 1712 | vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)]) 1713 | return vocab, rev_vocab 1714 | 1715 | def _generate_vocab(self): 1716 | if gfile.Exists(self.vocab_path): 1717 | return 1718 | vocab_dic = self._build_vocab_dic(self.enc_path) 1719 | vocab_dic = self._build_vocab_dic(self.dec_path, vocab_dic) 1720 | vocab_list = self.start_vocabs + sorted(vocab_dic, key=vocab_dic.get, 1721 | reverse=True) 1722 | if len(vocab_list) > self.max_vocab_size: 1723 | pp("vocab_len=", len(vocab_list)) 1724 | vocab_list = vocab_list[:self.max_vocab_size] 1725 | with gfile.GFile(self.vocab_path, mode="w") as vocab_file: 1726 | for w in vocab_list: 1727 | vocab_file.write(w + "\n") 1728 | 1729 | # noinspection PyUnusedLocal,PyUnusedLocal 1730 | def _generate_enc_dec(self): 1731 | if gfile.Exists(self.enc_path) and gfile.Exists(self.dec_path): 1732 | return 1733 | with gfile.GFile(self.source_path, mode="rb") as file, gfile.GFile( 1734 | self.enc_path, mode="w+") as ef, gfile.GFile(self.dec_path, 1735 | mode="w+") as df: 1736 | tweet = None 1737 | reply = None 1738 | for i, line in enumerate(file): 1739 | line = line.decode('utf-8') 1740 | line = self.sanitize_line(line) 1741 | if i % 2 == 0: 1742 | tweet = line 1743 | else: 1744 | reply = line 1745 | if tweet and reply: 1746 | ef.write(tweet) 1747 | df.write(reply) 1748 | tweet = None 1749 | reply = None 1750 | 1751 | def _generate_enc_idx_padded(self, source_path, dest_path, dest_len_path, 1752 | max_line_len): 1753 | if gfile.Exists(dest_path): 1754 | return 1755 | with open(source_path) as fin, open(dest_path, 1756 | "w") as fout, open(dest_len_path, 1757 | "w") as flen: 1758 | line = fin.readline() 1759 | while line: 1760 | ids = [int(x) for x in line.split()] 1761 | if len(ids) > max_line_len: 1762 | ids = ids[:max_line_len] 1763 | # i don't remember why we did this 1764 | # ids = ids[-max_line_len:] 1765 | flen.write(str(len(ids))) 1766 | flen.write("\n") 1767 | if len(ids) < max_line_len: 1768 | ids.extend( 1769 | [self.hparams.pad_id] * (max_line_len - len(ids))) 1770 | ids = [str(x) for x in ids] 1771 | fout.write(" ".join(ids)) 1772 | fout.write("\n") 1773 | line = fin.readline() 1774 | 1775 | # read decoder_idx file and append eos at the end of idx list. 
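# Illustrative example (token ids are arbitrary; eos_id and pad_id come from hparams):
# with decoder_length = 5, the line "7 8 9" becomes "7 8 9 <eos_id> <pad_id>",
# and "7 8 9 10 11 12" is truncated to "7 8 9 10 <eos_id>".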
1776 | def _generate_dec_idx_eos(self, source_path, dest_path, max_line_len): 1777 | if gfile.Exists(dest_path): 1778 | return 1779 | with open(source_path) as fin, open(dest_path, "w") as fout: 1780 | line = fin.readline() 1781 | while line: 1782 | ids = [int(x) for x in line.split()] 1783 | if len(ids) > max_line_len - 1: 1784 | ids = ids[:max_line_len - 1] 1785 | # ids = ids[-(max_line_len - 1):] 1786 | ids.append(self.hparams.eos_id) 1787 | if len(ids) < max_line_len: 1788 | ids.extend( 1789 | [self.hparams.pad_id] * (max_line_len - len(ids))) 1790 | ids = [str(x) for x in ids] 1791 | fout.write(" ".join(ids)) 1792 | fout.write("\n") 1793 | line = fin.readline() 1794 | 1795 | # read decoder_idx file and put sos at the beginning of the idx list. 1796 | # also write out length of index list. 1797 | def _generate_dec_idx_sos(self, source_path, dest_path, dest_len_path, 1798 | max_line_len): 1799 | if gfile.Exists(dest_path): 1800 | return 1801 | with open(source_path) as fin, open(dest_path, "w") as fout, open( 1802 | dest_len_path, "w") as flen: 1803 | line = fin.readline() 1804 | while line: 1805 | ids = [self.hparams.sos_id] 1806 | ids.extend([int(x) for x in line.split()]) 1807 | if len(ids) > max_line_len: 1808 | ids = ids[:max_line_len] 1809 | flen.write(str(len(ids))) 1810 | flen.write("\n") 1811 | if len(ids) < max_line_len: 1812 | ids.extend( 1813 | [self.hparams.pad_id] * (max_line_len - len(ids))) 1814 | ids = [str(x) for x in ids] 1815 | fout.write(" ".join(ids)) 1816 | fout.write("\n") 1817 | line = fin.readline() 1818 | 1819 | @staticmethod 1820 | def sanitize_line(line): 1821 | # replace @username 1822 | # replacing @username had bad impact where USERNAME token shows up 1823 | # everywhere. 1824 | # line = re.sub(r"@([A-Za-z0-9_]+)", "USERNAME", line) 1825 | line = re.sub(r"@([A-Za-z0-9_]+)", "", line) 1826 | # Remove URL 1827 | line = re.sub(r'https?://.*', "", line) 1828 | line = line.lstrip() 1829 | return line 1830 | 1831 | @staticmethod 1832 | def generate_source_target_swapped(source_path): 1833 | basename, extension = os.path.splitext(source_path) 1834 | dest_path = "{}_swapped{}".format(basename, extension) 1835 | with gfile.GFile(source_path, mode="rb") as fin, gfile.GFile(dest_path, 1836 | mode="w+") as fout: 1837 | temp = None 1838 | for i, line in enumerate(fin): 1839 | if i % 2 == 0: 1840 | temp = line 1841 | else: 1842 | fout.write(line) 1843 | fout.write(temp) 1844 | temp = None 1845 | return dest_path 1846 | 1847 | def _build_vocab_dic(self, source_path, vocab_dic=None): 1848 | if vocab_dic is None: 1849 | vocab_dic = {} 1850 | with gfile.GFile(source_path, mode="r") as file: 1851 | for line in file: 1852 | words = self.tagger.parse(line).split() 1853 | for word in words: 1854 | if word in vocab_dic: 1855 | vocab_dic[word] += 1 1856 | else: 1857 | vocab_dic[word] = 1 1858 | return vocab_dic 1859 | 1860 | @staticmethod 1861 | def _read_file(source_path): 1862 | file = open(source_path) 1863 | data = file.read() 1864 | file.close() 1865 | return data 1866 | 1867 | def _read_vocab(self, source_path): 1868 | rev_vocab = [] 1869 | rev_vocab.extend(self._read_file(source_path).splitlines()) 1870 | rev_vocab = [line.strip() for line in rev_vocab] 1871 | vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)]) 1872 | return vocab, rev_vocab 1873 | 1874 | def text_line_split_dataset(self, filename): 1875 | return tf.data.TextLineDataset(filename).map(self.split_to_int_values) 1876 | 1877 | @staticmethod 1878 | def split_to_int_values(x): 1879 | return 
tf.string_to_number(tf.string_split([x]).values, tf.int32) 1880 | 1881 | def _create_dataset(self): 1882 | 1883 | tweets_dataset = self.text_line_split_dataset(self.enc_idx_padded_path) 1884 | tweets_lengths_dataset = tf.data.TextLineDataset( 1885 | self.enc_idx_len_path) 1886 | 1887 | replies_sos_dataset = self.text_line_split_dataset( 1888 | self.dec_idx_sos_path) 1889 | replies_eos_dataset = self.text_line_split_dataset( 1890 | self.dec_idx_eos_path) 1891 | replies_sos_lengths_dataset = tf.data.TextLineDataset( 1892 | self.dec_idx_len_path) 1893 | 1894 | tweets_transposed = tweets_dataset.apply( 1895 | tf.contrib.data.batch_and_drop_remainder( 1896 | self.hparams.batch_size)).map( 1897 | lambda x: tf.transpose(x)) 1898 | tweets_lengths = tweets_lengths_dataset.apply( 1899 | tf.contrib.data.batch_and_drop_remainder(self.hparams.batch_size)) 1900 | 1901 | replies_with_eos_suffix = replies_eos_dataset.apply( 1902 | tf.contrib.data.batch_and_drop_remainder(self.hparams.batch_size)) 1903 | replies_with_sos_prefix = replies_sos_dataset.apply( 1904 | tf.contrib.data.batch_and_drop_remainder( 1905 | self.hparams.batch_size)).map( 1906 | lambda x: tf.transpose(x)) 1907 | replies_with_sos_suffix_lengths = replies_sos_lengths_dataset.apply( 1908 | tf.contrib.data.batch_and_drop_remainder( 1909 | self.hparams.batch_size)) 1910 | vocab, rev_vocab = self._read_vocab(self.vocab_path) 1911 | return tf.data.Dataset.zip((tweets_transposed, tweets_lengths, 1912 | replies_with_eos_suffix, 1913 | replies_with_sos_prefix, 1914 | replies_with_sos_suffix_lengths)), vocab, \ 1915 | rev_vocab 1916 | 1917 | 1918 | def print_hparams(hparams): 1919 | result = {} 1920 | for key in ['machine', 'batch_size', 'num_units', 'num_layers', 1921 | 'vocab_size', 1922 | 'embedding_size', 'learning_rate', 'learning_rate_decay', 1923 | 'use_attention', 'encoder_length', 'decoder_length', 1924 | 'max_gradient_norm', 'beam_width', 'num_train_steps', 1925 | 'model_path']: 1926 | result[key] = hparams.get(key) 1927 | pp("hparams=", result) 1928 | 1929 | 1930 | # Helper functions to test 1931 | def make_test_training_data(hparams): 1932 | train_encoder_inputs = np.empty( 1933 | (hparams.encoder_length, hparams.batch_size), dtype=np.int) 1934 | train_encoder_inputs_lengths = np.empty(hparams.batch_size, dtype=np.int) 1935 | training_target_labels = np.empty( 1936 | (hparams.batch_size, hparams.decoder_length), dtype=np.int) 1937 | training_decoder_inputs = np.empty( 1938 | (hparams.decoder_length, hparams.batch_size), dtype=np.int) 1939 | 1940 | # We keep first tweet to validate inference. 1941 | first_tweet = None 1942 | 1943 | for i in range(hparams.batch_size): 1944 | # Tweet 1945 | tweet = np.random.randint(low=0, high=hparams.vocab_size, 1946 | size=hparams.encoder_length) 1947 | train_encoder_inputs[:, i] = tweet 1948 | train_encoder_inputs_lengths[i] = len(tweet) 1949 | # Reply 1950 | # Note that low = 2, as 0 and 1 are reserved. 
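# (Presumably ids 0 and 1 are the sos/eos control tokens; drawing the random reply
# from [2, vocab_size) keeps stray start/end markers out of the test targets.
# This is an assumption based on how sos_id and eos_id are appended below.)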
1951 | reply = np.random.randint(low=2, high=hparams.vocab_size, 1952 | size=hparams.decoder_length - 1) 1953 | 1954 | training_target_label = np.concatenate( 1955 | (reply, np.array([hparams.eos_id]))) 1956 | training_target_labels[i] = training_target_label 1957 | 1958 | training_decoder_input = np.concatenate(([hparams.sos_id], reply)) 1959 | training_decoder_inputs[:, i] = training_decoder_input 1960 | 1961 | if i == 0: 1962 | first_tweet = tweet 1963 | return first_tweet, train_encoder_inputs, train_encoder_inputs_lengths, \ 1964 | training_target_labels, training_decoder_inputs 1965 | 1966 | 1967 | def test_training(hparams, model): 1968 | if hparams.use_attention: 1969 | pp("==== training model[attention] ====") 1970 | else: 1971 | pp("==== training model ====") 1972 | first_tweet, train_encoder_inputs, train_encoder_inputs_lengths, \ 1973 | training_target_labels, training_decoder_inputs = make_test_training_data( 1974 | hparams) 1975 | for i in range(hparams.num_train_steps): 1976 | _ = model.train(train_encoder_inputs, 1977 | train_encoder_inputs_lengths, 1978 | training_target_labels, 1979 | training_decoder_inputs, 1980 | np.ones(hparams.batch_size, 1981 | dtype=int) * hparams.decoder_length) 1982 | if i % 5 == 0 and hparams.debug_verbose: 1983 | print('.', end='') 1984 | 1985 | if i % 15 == 0: 1986 | model.save() 1987 | 1988 | inference_encoder_inputs = np.empty((hparams.encoder_length, 1), 1989 | dtype=np.int) 1990 | inference_encoder_inputs_lengths = np.empty(1, dtype=np.int) 1991 | for i in range(1): 1992 | inference_encoder_inputs[:, i] = first_tweet 1993 | inference_encoder_inputs_lengths[i] = len(first_tweet) 1994 | 1995 | # testing 1996 | log_prob54 = model.log_prob(inference_encoder_inputs, 1997 | inference_encoder_inputs_lengths, 1998 | np.array([5, 4])) 1999 | log_prob65 = model.log_prob(inference_encoder_inputs, 2000 | inference_encoder_inputs_lengths, 2001 | np.array([6, 5])) 2002 | pp("log_prob for 54", log_prob54) 2003 | pp("log_prob for 65", log_prob65) 2004 | 2005 | reward = model.reward_ease_of_answering(hparams.encoder_length, 2006 | inference_encoder_inputs, 2007 | inference_encoder_inputs_lengths, 2008 | np.array([[5], [6]])) 2009 | pp("reward=", reward) 2010 | 2011 | if hparams.debug_verbose: 2012 | pp(inference_encoder_inputs) 2013 | replies, _ = model.infer(inference_encoder_inputs, 2014 | inference_encoder_inputs_lengths) 2015 | pp("Inferred replies", replies[0]) 2016 | pp("Expected replies", training_target_labels[0]) 2017 | 2018 | 2019 | def test_distributed_pattern(hparams): 2020 | for d in [hparams.model_path]: 2021 | shutil.rmtree(d, ignore_errors=True) 2022 | os.makedirs(d, exist_ok=True) 2023 | 2024 | pp('==== test_distributed_pattern[{} {}] ===='.format( 2025 | 'attention' if hparams.use_attention else '', 2026 | 'beam' if hparams.beam_width > 0 else '')) 2027 | 2028 | first_tweet, train_encoder_inputs, train_encoder_inputs_lengths, \ 2029 | training_target_labels, training_decoder_inputs = make_test_training_data( 2030 | hparams) 2031 | 2032 | model = Trainer().create_model(hparams) 2033 | 2034 | for i in range(hparams.num_train_steps): 2035 | _ = model.train(train_encoder_inputs, 2036 | train_encoder_inputs_lengths, 2037 | training_target_labels, 2038 | training_decoder_inputs, 2039 | np.ones(hparams.batch_size, 2040 | dtype=int) * hparams.decoder_length) 2041 | 2042 | model.save() 2043 | 2044 | inference_encoder_inputs = np.empty( 2045 | (hparams.encoder_length, hparams.batch_size), 2046 | dtype=np.int) 2047 | inference_encoder_inputs_lengths 
= np.empty(hparams.batch_size, 2048 | dtype=np.int) 2049 | 2050 | for i in range(hparams.batch_size): 2051 | inference_encoder_inputs[:, i] = first_tweet 2052 | inference_encoder_inputs_lengths[i] = len(first_tweet) 2053 | 2054 | model.restore() 2055 | replies, _ = model.infer(inference_encoder_inputs, 2056 | inference_encoder_inputs_lengths) 2057 | pp("Inferred replies", replies[0]) 2058 | 2059 | beam_replies, logits, _, _ = model.infer_beam_search( 2060 | inference_encoder_inputs, 2061 | inference_encoder_inputs_lengths) 2062 | 2063 | pp("logits", logits[0]) 2064 | pp("Inferred replies candidate0", beam_replies[0][:, 0]) 2065 | pp("Inferred replies candidate1", beam_replies[0][:, 1]) 2066 | 2067 | inference_encoder_inputs = np.empty( 2068 | (hparams.encoder_length, hparams.batch_size), 2069 | dtype=np.int) 2070 | inference_encoder_inputs_lengths = np.empty(hparams.batch_size, 2071 | dtype=np.int) 2072 | 2073 | for i in range(hparams.batch_size): 2074 | inference_encoder_inputs[:, i] = first_tweet 2075 | inference_encoder_inputs_lengths[i] = len(first_tweet) 2076 | 2077 | replies = model.sample(inference_encoder_inputs, 2078 | inference_encoder_inputs_lengths) 2079 | pp("sample replies", replies[0]) 2080 | pp("Expected replies", training_target_labels[0]) 2081 | 2082 | 2083 | def test_distributed_one(enable_attention): 2084 | hparams = copy.deepcopy(base_hparams).override_from_dict({ 2085 | 'model_path': ModelDirectory.test_distributed.value, 2086 | 'use_attention': enable_attention, 2087 | 'beam_width': 2, 2088 | 'num_train_steps': 100, 2089 | 'learning_rate': 0.5 2090 | }) 2091 | test_distributed_pattern(hparams) 2092 | 2093 | 2094 | def clean_model_path(model_path): 2095 | shutil.rmtree(model_path) 2096 | os.makedirs(model_path) 2097 | 2098 | 2099 | def print_header(text): 2100 | pp("============== {} ==============".format(text)) 2101 | 2102 | 2103 | def test_tweets_small_swapped(hparams): 2104 | replies = ["@higepon おはようございます!", "おつかれさまー。気をつけて。", "こちらこそよろしくお願いします。"] 2105 | trainer = Trainer() 2106 | trainer.train_seq2seq_swapped(hparams, "tweets_small.txt", replies) 2107 | 2108 | 2109 | def test_tweets_large(hparams): 2110 | tweets = ["さて福岡行ってきます!", "誰か飲みに行こう", "熱でてるけど、でもなんか食べなきゃーと思ってアイス買おうとしたの", 2111 | "今日のドラマ面白そう!", "お腹すいたー", "おやすみ~", "おはようございます。寒いですね。", 2112 | "さて帰ろう。明日は早い。", "今回もよろしくです。", "ばいとおわ!"] 2113 | trainer = Trainer() 2114 | trainer.train_seq2seq(hparams, "tweets_conversation.txt", tweets, 2115 | should_clean_saved_model=False) 2116 | return trainer.model 2117 | 2118 | 2119 | def test_tweets_large_swapped(hparams): 2120 | tweets = ["今日のドラマ面白そう!", "お腹すいたー", "おやすみ~", "おはようございます。寒いですね。", 2121 | "さて帰ろう。明日は早い。", "今回もよろしくです。", "ばいとおわ!"] 2122 | trainer = Trainer() 2123 | trainer.train_seq2seq_swapped(hparams, "tweets_large.txt", tweets, 2124 | should_clean_saved_model=False) 2125 | return trainer.model 2126 | 2127 | 2128 | class StreamListener(tweepy.StreamListener): 2129 | def __init__(self, api, helper): 2130 | self.api = api 2131 | self.helper = helper 2132 | 2133 | def on_status(self, status): 2134 | # done handle @reply only 2135 | # done print reply 2136 | # add model parameter 2137 | # direct reply 2138 | # unk reply 2139 | # shuffle beam search 2140 | print("{0}: {1}".format(status.text, status.author.screen_name)) 2141 | 2142 | screen_name = status.author.screen_name 2143 | # ignore my tweets 2144 | if screen_name == self.api.me().screen_name: 2145 | print("Ignored my tweet") 2146 | return True 2147 | elif status.text.startswith("@{0}".format(self.api.me().screen_name)): 
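# Mention directed at the bot account: pick one of the model's inferred replies
# at random and post it as a reply to the incoming status.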
2148 | 2149 | replies = self.helper.inferences(status.text) 2150 | reply = random.choice(replies) 2151 | reply = "@" + status.author.screen_name + " " + reply 2152 | print(reply) 2153 | self.api.update_status(status=reply, 2154 | in_reply_to_status_id=status.id) 2155 | 2156 | return True 2157 | 2158 | @staticmethod 2159 | def on_error(status_code, **kwargs): 2160 | print(status_code) 2161 | return True 2162 | 2163 | 2164 | def listener(hparams): 2165 | Shell.download_model_data_if_necessary(hparams.model_path) 2166 | 2167 | infer_model = Trainer().create_model(hparams) 2168 | 2169 | source_path = "conversations_large.txt" 2170 | Shell.download_file_if_necessary(source_path) 2171 | generator = TrainDataGenerator(source_path=source_path, hparams=hparams) 2172 | _, vocab, rev_vocab = generator.generate() 2173 | infer_model.restore() 2174 | helper = InferenceHelper(infer_model, vocab, rev_vocab) 2175 | 2176 | config_yml = 'config.yml' 2177 | Shell.download_file_if_necessary(config_yml) 2178 | file = open(config_yml, 'rt') 2179 | cfg = yaml.safe_load(file)['twitter'] 2180 | 2181 | consumer_key = cfg['consumer_key'] 2182 | consumer_secret = cfg['consumer_secret'] 2183 | access_token = cfg['access_token'] 2184 | access_token_secret = cfg['access_token_secret'] 2185 | 2186 | auth = tweepy.OAuthHandler(consumer_key, consumer_secret) 2187 | auth.set_access_token(access_token, access_token_secret) 2188 | api = tweepy.API(auth) 2189 | 2190 | while True: 2191 | # try: 2192 | stream = tweepy.Stream(auth=api.auth, 2193 | listener=StreamListener(api, helper)) 2194 | print("listener starting...") 2195 | stream.userstream() 2196 | 2197 | 2198 | conversations_large_hparams = copy.deepcopy(base_hparams).override_from_dict( 2199 | { 2200 | # In a typical seq2seq chatbot: 2201 | # num_layers=3, learning_rate=0.5, batch_size=64, vocab=20000-100000; 2202 | # the learning_rate decay of 0.99 is taken care of by the default 2203 | # parameters of AdamOptimizer. 2204 | 'batch_size': 128, 2205 | # The number of tweets should be divisible by batch_size (default 64). 2206 | 'encoder_length': 28, 2207 | 'decoder_length': 28, 2208 | 'num_units': 1024, 2209 | 'num_layers': 3, 2210 | 'vocab_size': 60000, 2211 | # conversations.txt actually has about 70K unique words. 2212 | 'embedding_size': 1024, 2213 | 'beam_width': 2, # reduced for faster iteration; ideally this should be 10 2214 | 'num_train_steps': 0, 2215 | 'model_path': ModelDirectory.conversations_large.value, 2216 | 'learning_rate': 0.5, 2217 | # For vocab_size 50000, num_layers 3, num_units 1024, tweet_large, 2218 | # a starting learning_rate of 0.05 works well; change it to 0.01 at 2219 | # perplexity 800 and to 0.005 at 200. 2220 | 'learning_rate_decay': 0.99, 2221 | 'use_attention': True, 2222 | 2223 | }) 2224 | 2225 | # batch_size=128 and learning_rate=0.001 work very well for RL. Loss decreases 2226 | # as expected, and entropy did not flatten out.
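# Expected layout of config.yml (an assumption inferred from the keys read in
# listener() above; the actual file is not checked into the repository):
#
#   twitter:
#     consumer_key: "..."
#     consumer_secret: "..."
#     access_token: "..."
#     access_token_secret: "..."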
2227 | 2228 | conversations_large_rl_hparams = copy.deepcopy( 2229 | conversations_large_hparams).override_from_dict( 2230 | { 2231 | 'model_path': ModelDirectory.conversations_large_rl.value, 2232 | 'num_train_steps': 2000, 2233 | 'learning_rate': 0.001, 2234 | 'beam_width': 3, 2235 | }) 2236 | 2237 | conversations_large_backward_hparams = copy.deepcopy( 2238 | conversations_large_hparams).override_from_dict( 2239 | { 2240 | 'model_path': ModelDirectory.conversations_large_backward.value, 2241 | 'num_train_steps': 0, 2242 | }) 2243 | 2244 | 2245 | def test_train_rl(): 2246 | resume_rl = True 2247 | 2248 | conversations_txt = "conversations_large.txt" 2249 | Shell.download_file_if_necessary(conversations_txt) 2250 | ConversationTrainDataGenerator().generate(conversations_txt) 2251 | 2252 | trainer = Trainer() 2253 | valid_tweets = ["さて福岡行ってきます!", "誰か飲みに行こう", 2254 | "熱でてるけど、でもなんか食べなきゃーと思ってアイス買おうとしたの", 2255 | "今日のドラマ面白そう!", "お腹すいたー", "おやすみ~", "おはようございます。寒いですね。", 2256 | "さて帰ろう。明日は早い。", "今回もよろしくです。", "ばいとおわ!"] 2257 | trainer.train_seq2seq(conversations_large_hparams, 2258 | "conversations_large_seq2seq.txt", 2259 | valid_tweets, should_clean_saved_model=False) 2260 | trainer.train_seq2seq_swapped(conversations_large_backward_hparams, 2261 | "conversations_large_seq2seq.txt", 2262 | ["この難にでも応用可能なひどいやつ", 2263 | "おはようございます。明日はよろしくおねがいします。"], 2264 | vocab_path="conversations_large_seq2seq_vocab.txt", 2265 | should_clean_saved_model=False) 2266 | 2267 | if not resume_rl: 2268 | Shell.copy_saved_model(conversations_large_hparams, 2269 | conversations_large_rl_hparams) 2270 | Trainer().train_rl(conversations_large_rl_hparams, 2271 | conversations_large_hparams, 2272 | conversations_large_backward_hparams, 2273 | "conversations_large_seq2seq.txt", 2274 | 2275 | "conversations_large_rl.txt", 2276 | valid_tweets) 2277 | --------------------------------------------------------------------------------
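For a quick manual check of a trained model, the pieces above can be wired together the same way listener() does it, minus the Twitter streaming. A minimal sketch, assuming the checkpoint under hparams.model_path and conversations_large.txt already exist locally (so the Shell.download_* steps are skipped); everything referenced is defined in lib/chatbot_model.py above:

    import lib.chatbot_model as sq

    # Build the inference model and restore the latest checkpoint.
    hparams = sq.conversations_large_rl_hparams
    model = sq.Trainer().create_model(hparams)

    # Regenerate (or reuse) the vocab and index files derived from the conversation data.
    generator = sq.TrainDataGenerator(source_path="conversations_large.txt",
                                      hparams=hparams)
    _, vocab, rev_vocab = generator.generate()

    model.restore()
    helper = sq.InferenceHelper(model, vocab, rev_vocab)

    # Print greedy and beam-search replies for one of the validation tweets.
    helper.print_inferences("おはようございます。寒いですね。")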