├── lib ├── __init__.py ├── chat.py ├── predict.py ├── config.py ├── train.py ├── seq2seq_model_utils.py ├── data_utils.py └── seq2seq_model.py ├── requirements.txt ├── .DS_Store ├── doc ├── messenger.png ├── beam_search_10.png └── beam_search_10_antilm.png ├── works └── lyrics_ptt │ └── data │ └── train │ └── chat.txt.gz ├── .gitignore ├── go_server ├── main.py ├── templates ├── index.html └── privacy.html ├── go_example ├── example_chat.md ├── app.py ├── README2.md ├── README.md └── ref ├── data_utils.py ├── translate.py ├── seq2seq_model.py └── seq2seq.py /lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==0.10.0rc0 2 | jieba 3 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsan-ma/tf_chatbot_seq2seq_antilm/HEAD/.DS_Store -------------------------------------------------------------------------------- /doc/messenger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsan-ma/tf_chatbot_seq2seq_antilm/HEAD/doc/messenger.png -------------------------------------------------------------------------------- /doc/beam_search_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsan-ma/tf_chatbot_seq2seq_antilm/HEAD/doc/beam_search_10.png -------------------------------------------------------------------------------- /doc/beam_search_10_antilm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsan-ma/tf_chatbot_seq2seq_antilm/HEAD/doc/beam_search_10_antilm.png -------------------------------------------------------------------------------- /works/lyrics_ptt/data/train/chat.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marsan-ma/tf_chatbot_seq2seq_antilm/HEAD/works/lyrics_ptt/data/train/chat.txt.gz -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | sftp-config.json 2 | *.pyc 3 | works/* 4 | !works/lyrics_ptt 5 | go_train 6 | *.yaml 7 | ssl/* 8 | lib/old 9 | lib/foward_only 10 | lib/foward_only_tensor 11 | -------------------------------------------------------------------------------- /go_server: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # python3 app.py --mode test --model_name 17live_comments --vocab_size 200000 --antilm 0.7 4 | 5 | python3 app.py --mode test --model_name twitter_en --vocab_size 100000 --antilm 0.7 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os, sys, argparse 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | import tensorflow as tf 6 | 7 | from lib.config import params_setup 8 | from lib.train import train 9 | from lib.predict import predict 10 | from lib.chat import chat 11 | # from lib.mert import mert 12 | 13 | 14 | 
def main(_): 15 | args = params_setup() 16 | print("[args]: ", args) 17 | if args.mode == 'train': 18 | train(args) 19 | elif args.mode == 'test': 20 | predict(args) 21 | elif args.mode == 'chat': 22 | chat(args) 23 | # elif args.mode == 'mert': 24 | # mert(args) 25 | 26 | 27 | if __name__ == "__main__": 28 | tf.app.run() -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Chatko 7 | 8 | 17 | 18 |
19 |
Chatko
20 |
21 |
This is a chatbot developed by Marsan Ma.
22 | 23 | -------------------------------------------------------------------------------- /go_example: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NOW=$(date +"%Y%m%d_%H%M") 4 | PY3='stdbuf -o0 nohup python3 -u' 5 | mkdir -p logs 6 | 7 | if [ ! -d ./works/example/nn_models ]; then 8 | wget https://github.com/Marsan-Ma/tf_chatbot_pretrained_model/raw/master/example.tar.gz.part-aa 9 | wget https://github.com/Marsan-Ma/tf_chatbot_pretrained_model/raw/master/example.tar.gz.part-ab 10 | cat example.tar.gz.part-* > example.tar.gz 11 | tar -zxvf example.tar.gz 12 | fi 13 | 14 | #============================= 15 | # Example Twitter Corpus 16 | #============================= 17 | # 1. train model 18 | # $PY3 main.py --mode train --model_name example --vocab_size 100000 --size 128 > "./logs/seq2seq_twitter_en_$NOW.log" & 19 | 20 | # 2. chat interactively in command line 21 | python3 main.py --mode chat --model_name example --vocab_size 100000 --size 128 # --beam_size 5 --antilm 0.7 22 | 23 | # 3. do batch output test 24 | # python3 main.py --mode test --model_name example --vocab_size 100000 --size 128 --beam_size 5 --antilm 0.7 25 | 26 | -------------------------------------------------------------------------------- /lib/chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import tensorflow as tf 5 | 6 | from lib import data_utils 7 | from lib.seq2seq_model_utils import create_model, get_predicted_sentence 8 | 9 | 10 | def chat(args): 11 | with tf.Session() as sess: 12 | # Create model and load parameters. 13 | args.batch_size = 1 # We decode one sentence at a time. 14 | model = create_model(sess, args) 15 | 16 | # Load vocabularies. 17 | vocab_path = os.path.join(args.data_dir, "vocab%d.in" % args.vocab_size) 18 | vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path) 19 | 20 | # Decode from standard input. 21 | sys.stdout.write("> ") 22 | sys.stdout.flush() 23 | sentence = sys.stdin.readline() 24 | 25 | while sentence: 26 | predicted_sentence = get_predicted_sentence(args, sentence, vocab, rev_vocab, model, sess) 27 | # print(predicted_sentence) 28 | if isinstance(predicted_sentence, list): 29 | for sent in predicted_sentence: 30 | print(" (%s) -> %s" % (sent['prob'], sent['dec_inp'])) 31 | else: 32 | print(sentence, ' -> ', predicted_sentence) 33 | 34 | sys.stdout.write("> ") 35 | sys.stdout.flush() 36 | sentence = sys.stdin.readline() 37 | 38 | -------------------------------------------------------------------------------- /lib/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | from datetime import datetime 5 | 6 | from lib import data_utils 7 | from lib.seq2seq_model_utils import create_model, get_predicted_sentence 8 | 9 | 10 | def predict(args, debug=False): 11 | def _get_test_dataset(): 12 | with open(args.test_dataset_path) as test_fh: 13 | test_sentences = [s.strip() for s in test_fh.readlines()] 14 | return test_sentences 15 | 16 | results_filename = '_'.join(['results', str(args.num_layers), str(args.size), str(args.vocab_size)]) 17 | results_path = os.path.join(args.results_dir, results_filename+'.txt') 18 | 19 | with tf.Session() as sess, open(results_path, 'w') as results_fh: 20 | # Create model and load parameters. 21 | args.batch_size = 1 22 | model = create_model(sess, args) 23 | 24 | # Load vocabularies. 
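    # (added note, not in the original source) vocab%d.in is assumed to follow the
    # one-token-per-line format documented in ref/data_utils.py: the token on line i
    # gets id i, so initialize_vocabulary() returns both the token->id dict and the
    # reversed id->token list used to decode model outputs back into words.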
25 | vocab_path = os.path.join(args.data_dir, "vocab%d.in" % args.vocab_size) 26 | vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path) 27 | 28 | test_dataset = _get_test_dataset() 29 | 30 | for sentence in test_dataset: 31 | # Get token-ids for the input sentence. 32 | predicted_sentence = get_predicted_sentence(args, sentence, vocab, rev_vocab, model, sess, debug=debug) 33 | if isinstance(predicted_sentence, list): 34 | print("%s : (%s)" % (sentence, datetime.now())) 35 | results_fh.write("%s : (%s)\n" % (sentence, datetime.now())) 36 | for sent in predicted_sentence: 37 | print(" (%s) -> %s" % (sent['prob'], sent['dec_inp'])) 38 | results_fh.write(" (%f) -> %s\n" % (sent['prob'], sent['dec_inp'])) 39 | else: 40 | print(sentence, ' -> ', predicted_sentence) 41 | results_fh.write("%s -> %s\n" % (sentence, predicted_sentence)) 42 | # break 43 | 44 | results_fh.close() 45 | print("results written in %s" % results_path) 46 | -------------------------------------------------------------------------------- /templates/privacy.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 隱私權政策 7 | 8 | 17 | 18 |
19 |
Privacy Policy
20 |
21 |
Scope of Application
22 | The following privacy protection policy applies to the collection, use, and protection of personal data involved when you use this event website.

24 |
Data Collection
25 | Depending on the services this website provides, the following basic personal data may be collected from users: 26 | 
    27 |
  • Member data: when you start participating in this website's activities, we will ask for your consent to access the personal data registered on your facebook account, including your name, e-mail address, and other personal information. 28 | 
  • General browsing: this event website keeps the records (LOG) generated by the server while you browse or query, including the IP address of the connecting device, access time, browser used, and records of pages browsed and clicked, in order to aggregate which pages users' browsers visit within this event website and for how long, so as to improve the quality of this website's service. 29 | 
30 |

31 |
How User Data Is Shared, Disclosed, and Used
32 | This event website will not arbitrarily sell, exchange, or rent any of your personal data to other groups or individuals. Only under the following circumstances will this event website use your personal data, within the principles of this Privacy Policy. 33 | 
    34 |
  • Statistics and analysis: this event website compiles and analyzes statistics based on user account data, voting results, or server log files as the outcome of this event's polls; it does not analyze individual users, nor does it provide analysis reports on any specific individual's data. 35 | 
  • When judicial authorities, for reasons of public safety, request that this event website disclose specific personal data, the website will cooperate as may be necessary, subject to the authorities following lawful and formal procedures and with the safety of all users of this site in mind. Except where required in connection with legal judgments, government review, judicial investigation, crime prevention or action against illegal activities, suspected fraud, incidents threatening anyone's personal safety, or other unlawful conduct, all data remain confidential. 36 | 
37 |

38 |
Privacy Policy Revisions
39 | This event website may revise this policy from time to time to comply with the latest privacy protection regulations. When we make substantial changes to how personal data are used, we will post a notice on the website to inform you of the relevant revisions.

41 |
42 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /lib/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def params_setup(cmdline=None): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--mode', type=str, required=True, help='work mode: train/test/chat') 6 | 7 | # path ctrl 8 | parser.add_argument('--model_name', type=str, default='movie_subtitles_en', help='model name, affects data, model, result save path') 9 | parser.add_argument('--scope_name', type=str, help='separate namespace, for multi-models working together') 10 | parser.add_argument('--work_root', type=str, default='works', help='root dir for data, model, result save path') 11 | 12 | # training params 13 | parser.add_argument('--learning_rate', type=float, default=0.5, help='Learning rate.') 14 | parser.add_argument('--learning_rate_decay_factor', type=float, default=0.99, help='Learning rate decays by this much.') 15 | parser.add_argument('--max_gradient_norm', type=float, default=5.0, help='Clip gradients to this norm.') 16 | parser.add_argument('--batch_size', type=int, default=64, help='Batch size to use during training.') 17 | 18 | parser.add_argument('--vocab_size', type=int, default=100000, help='Dialog vocabulary size.') 19 | parser.add_argument('--size', type=int, default=256, help='Size of each model layer.') 20 | parser.add_argument('--num_layers', type=int, default=4, help='Number of layers in the model.') 21 | 22 | parser.add_argument('--max_train_data_size', type=int, default=0, help='Limit on the size of training data (0: no limit)') 23 | parser.add_argument('--steps_per_checkpoint', type=int, default=500, help='How many training steps to do per checkpoint') 24 | 25 | # predicting params 26 | parser.add_argument('--beam_size', type=int, default=1, help='beam search size') 27 | parser.add_argument('--antilm', type=float, default=0, help='anti-language model weight') 28 | parser.add_argument('--n_bonus', type=int, default=0, help='bonus with sentence length') 29 | 30 | # environment params 31 | parser.add_argument('--gpu_usage', type=float, default=1.0, help='tensorflow gpu memory fraction used') 32 | parser.add_argument('--rev_model', type=int, default=0, help='reverse Q-A pair, for bi-direction model') 33 | parser.add_argument('--reinforce_learn', type=int, default=0, help='1 to enable reinforcement learning mode') 34 | parser.add_argument('--en_tfboard', type=int, default=0, help='Enable writing out tensorboard meta data') 35 | 36 | 37 | if cmdline: 38 | args = parser.parse_args(cmdline) 39 | else: 40 | args = parser.parse_args() 41 | 42 | if not args.scope_name: args.scope_name = args.model_name 43 | if args.rev_model: args.model_name += '_bidi' # bi-direction model 44 | 45 | # We use a number of buckets and pad to the closest one for efficiency. 46 | # See seq2seq_model.Seq2SeqModel for details of how they work. 
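  # (added note, not in the original file) a hedged illustration of how the buckets
  # are used: a question that tokenizes to at most 5 ids with an answer of at most 10
  # ids lands in bucket (5, 10) and is padded to exactly that shape; longer pairs fall
  # through to (10, 15), (20, 25) or (40, 50), and pairs that fit no bucket are
  # typically skipped when the training data is read.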
47 | args.buckets = [(5, 10), (10, 15), (20, 25), (40, 50)] 48 | 49 | # post-process 50 | args.workspace = '%s/%s' % (args.work_root, args.model_name) 51 | args.test_dataset_path = '%s/data/test/test_set.txt' % (args.workspace) 52 | args.mert_dataset_path = '%s/data/test/mert_set.txt' % (args.workspace) 53 | args.data_dir = '%s/data' % args.workspace 54 | args.model_dir = '%s/nn_models' % args.workspace 55 | args.results_dir = '%s/results' % args.workspace 56 | args.tf_board_dir = '%s/tf_board' % args.workspace 57 | return args 58 | 59 | -------------------------------------------------------------------------------- /example_chat.md: -------------------------------------------------------------------------------- 1 | ## 2017.03.14 2 | 3 | This is a chat history for the demo model of [tf_chatbot_seq2seq_antilm](https://github.com/Marsan-Ma/tf_chatbot_seq2seq_antilm). 4 | 5 | You could use `./go_example` to get exactly the same model and same result, you could also try `beam_search` and `antiLM` by uncomment the option in `go_example`. 6 | 7 | 8 | > Tell me about Trump. 9 | (1) -> sorry funny for that though . 10 | > How about Hillary 11 | (1) -> better check up . 12 | > check up what? 13 | (1) -> where ? are you at ? 14 | > why you ask me where am I? 15 | (1) -> welcome to the gay party ! hugs to you anytime ! ! ! 16 | > no, I am not gay... 17 | (1) -> you are the sassy part of the country 18 | > thank you! 19 | (1) -> you're welcome ! 20 | > how about order some pizza for the lunch? 21 | (1) -> second pink pink league 22 | > what do you mean pink league? 23 | (1) -> do you think you're talking to me ? ? ? 24 | > are you argry? 25 | (1) -> yeah yes , thanks 😉 26 | > sorry for being rude. 27 | (1) -> what worry college use a twitter ? 28 | > which college do you mean? 29 | (1) -> i agree than i am , never understand about my point 30 | > yes I never understand your point, too... 31 | (1) -> who were that 32 | > who? 33 | (1) -> peach you on 34 | > who is peach? 35 | (1) -> playing jason 36 | > who is jason? 37 | (1) -> explosion for the hotel recorded . 38 | > ok, bye. 39 | (1) -> magic good show later scared 40 | > you idiot. 41 | (1) -> really . 42 | > so you think you know nothing? 43 | (1) -> i know yes 44 | > that's a good point... 45 | (1) -> i want these car keep brought up and getting regular 46 | > well, it's nice to chat with you. 47 | (1) -> i have no idea 48 | > why? 49 | (1) -> i feel from the level of anyone . 50 | > I am an engineer. 51 | (1) -> too babe ! good fries 52 | > yes, engineers love fries, and cokes. 53 | (1) -> i have the full of dreams in la . 54 | > I've been in LA last month 55 | (1) -> bake just watch - and watch the whole phone out 56 | > I can't understand you. 57 | (1) -> i have to research with you 58 | > That's great idea... 59 | (1) -> islam is not the truth 60 | > don't comment on islam... 61 | (1) -> haha ? sec of the jets are related . 62 | > well at least you are not always answering i don't know. 63 | (1) -> i wish i should be scared too 64 | > you never scare someone. 65 | (1) -> i thought that was 66 | > that was ... what? 67 | (1) -> and again you didn't enjoy em 68 | > Well, I enjoy talking with you. 69 | (1) -> don't even bring shot for the show ? 70 | > no I won't shot you. 71 | (1) -> im too expensive with that 72 | > yes, you cost a lot, lol. 73 | (1) -> just then i saw it . 74 | > an AWS g2 instance is not cheap at all 75 | (1) -> honestly rose can’t use law mode without a kind of eyes 76 | > that's a anology hard to understand... 
77 | (1) -> why are you there ? ? ? 78 | > because i am talking with you. 79 | (1) -> i actually want to read that . 80 | > read about what? more training corpus? 81 | (1) -> by the android of a nuclear first . 82 | > why android got to do with nuclear? 83 | (1) -> unlike gary gold , you expect ? you are the muslim . 84 | > no i am not ... 85 | (1) -> i can't believe twitter 86 | > but i use twitter corpus to train your model ... 87 | (1) -> rocking hair ass at my house 88 | > you high? 89 | (1) -> nope , i plan to see it . 90 | > you plan to see rocking hair ass ... whatever. 91 | (1) -> i still have a chance but i was coming rn 92 | > yeah you have many chances, i believe you are promising. 93 | (1) -> i heard of this as well . 94 | > yeah, that's cool. 95 | (1) -> you will be a team this year . 96 | > that's cool, i've tired of work alone. 97 | (1) -> who really 98 | > ok, seems you are good enough for publish this time. 99 | (1) -> yeah , thanks ! all more him with letting it know something better . 100 | > whatever you say ... being a language model you are good. 101 | (1) -> congrats , - so many need to have a better place : ) 102 | > I'll find some way to make your response more consistent. 103 | (1) -> i want to buy the experience , and i wear a phone button . 104 | > yeah you did need much experience by some other learning ways! 105 | (1) -> with new york ? 106 | > with new network! 107 | (1) -> and i have a great weekend , bruce : 108 | > no I am not bruce ... 109 | (1) -> i can't believe i was sorry for him 110 | > you mean bruce lee? 111 | (1) -> was the part of the new york . 112 | > no york is absense today. 113 | (1) -> i will have been watching a season lately on ! 114 | > see you! 115 | (1) -> thanks buddy ! you rock 😘 116 | > thanks :) 117 | (1) -> high night 118 | 119 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import os, json, yaml, requests, jieba 3 | from datetime import datetime 4 | from flask import Flask, request, render_template 5 | from OpenSSL import SSL 6 | from random import random, choice 7 | 8 | app = Flask(__name__) 9 | 10 | #--------------------------- 11 | # Load Model 12 | #--------------------------- 13 | import tensorflow as tf 14 | from lib import data_utils 15 | from lib.config import params_setup 16 | from lib.seq2seq_model_utils import create_model, get_predicted_sentence 17 | 18 | 19 | class ChatBot(object): 20 | 21 | def __init__(self, args, debug=False): 22 | start_time = datetime.now() 23 | 24 | # flow ctrl 25 | self.args = args 26 | self.debug = debug 27 | self.fbm_processed = [] 28 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_usage) 29 | self.sess = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=gpu_options)) 30 | 31 | # Create model and load parameters. 32 | self.args.batch_size = 1 # We decode one sentence at a time. 33 | self.model = create_model(self.sess, self.args) 34 | 35 | # Load vocabularies. 
36 | self.vocab_path = os.path.join(self.args.data_dir, "vocab%d.in" % self.args.vocab_size) 37 | self.vocab, self.rev_vocab = data_utils.initialize_vocabulary(self.vocab_path) 38 | print("[ChatBot] model initialize, cost %i secs" % (datetime.now() - start_time).seconds) 39 | 40 | # load yaml setup 41 | self.FBM_API = "https://graph.facebook.com/v2.6/me/messages" 42 | with open("config.yaml", 'rt') as stream: 43 | try: 44 | cfg = yaml.load(stream) 45 | self.FACEBOOK_TOKEN = cfg.get('FACEBOOK_TOKEN') 46 | self.VERIFY_TOKEN = cfg.get('VERIFY_TOKEN') 47 | except yaml.YAMLError as exc: 48 | print(exc) 49 | 50 | 51 | def process_fbm(self, payload): 52 | for sender, msg in self.fbm_events(payload): 53 | self.fbm_api({"recipient": {"id": sender}, "sender_action": 'typing_on'}) 54 | resp = self.gen_response(msg) 55 | self.fbm_api({"recipient": {"id": sender}, "message": {"text": resp}}) 56 | if self.debug: print("%s: %s => resp: %s" % (sender, msg, resp)) 57 | 58 | 59 | def gen_response(self, sent, max_cand=100): 60 | sent = " ".join([w.lower() for w in jieba.cut(sent) if w not in [' ']]) 61 | # if self.debug: return sent 62 | raw = get_predicted_sentence(self.args, sent, self.vocab, self.rev_vocab, self.model, self.sess, debug=False) 63 | # find bests candidates 64 | cands = sorted(raw, key=lambda v: v['prob'], reverse=True)[:max_cand] 65 | 66 | if max_cand == -1: # return all cands for debug 67 | cands = [(r['prob'], ' '.join([w for w in r['dec_inp'].split() if w[0] != '_'])) for r in cands] 68 | return cands 69 | else: 70 | cands = [[w for w in r['dec_inp'].split() if w[0] != '_'] for r in cands] 71 | return ' '.join(choice(cands)) or 'No comment' 72 | 73 | 74 | def gen_response_debug(self, sent, args=None): 75 | sent = " ".join([w.lower() for w in jieba.cut(sent) if w not in [' ']]) 76 | raw = get_predicted_sentence(args, sent, self.vocab, self.rev_vocab, self.model, self.sess, debug=False, return_raw=True) 77 | return raw 78 | 79 | 80 | #------------------------------ 81 | # FB Messenger API 82 | #------------------------------ 83 | def fbm_events(self, payload): 84 | data = json.loads(payload.decode('utf8')) 85 | if self.debug: print("[fbm_payload]", data) 86 | for event in data["entry"][0]["messaging"]: 87 | if "message" in event and "text" in event["message"]: 88 | q = (event["sender"]["id"], event["message"]["seq"]) 89 | if q in self.fbm_processed: 90 | continue 91 | else: 92 | self.fbm_processed.append(q) 93 | yield event["sender"]["id"], event["message"]["text"] 94 | 95 | 96 | def fbm_api(self, data): 97 | r = requests.post(self.FBM_API, 98 | params={"access_token": self.FACEBOOK_TOKEN}, 99 | data=json.dumps(data), 100 | headers={'Content-type': 'application/json'}) 101 | if r.status_code != requests.codes.ok: 102 | print("fb error:", r.text) 103 | if self.debug: print("fbm_send", r.status_code, r.text) 104 | 105 | 106 | #--------------------------- 107 | # Server 108 | #--------------------------- 109 | @app.route('/chat', methods=['GET']) 110 | def verify(): 111 | if request.args.get('hub.verify_token', '') == chatbot.VERIFY_TOKEN: 112 | return request.args.get('hub.challenge', '') 113 | else: 114 | return 'Error, wrong validation token' 115 | 116 | @app.route('/chat', methods=['POST']) 117 | def chat(): 118 | payload = request.get_data() 119 | chatbot.process_fbm(payload) 120 | return "ok" 121 | 122 | 123 | @app.route('/', methods=['GET']) 124 | def home(): 125 | return render_template('index.html') 126 | 127 | 128 | @app.route('/privacy', methods=['GET']) 129 | def privacy(): 130 | 
return render_template('privacy.html') 131 | 132 | 133 | #--------------------------- 134 | # Start Server 135 | #--------------------------- 136 | if __name__ == '__main__': 137 | # check ssl files 138 | if not os.path.exists('ssl/server.crt'): 139 | print("SSL certificate not found! (should placed in ./ssl/server.crt)") 140 | elif not os.path.exists('ssl/server.key'): 141 | print("SSL key not found! (should placed in ./ssl/server.key)") 142 | else: 143 | # initialize model 144 | args = params_setup() 145 | chatbot = ChatBot(args, debug=False) 146 | # start server 147 | context = ('ssl/server.crt', 'ssl/server.key') 148 | app.run(host='0.0.0.0', port=443, debug=False, ssl_context=context) 149 | 150 | -------------------------------------------------------------------------------- /lib/train.py: -------------------------------------------------------------------------------- 1 | import sys, os, math, time, argparse, shutil, gzip 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 6 | 7 | from six.moves import xrange # pylint: disable=redefined-builtin 8 | from datetime import datetime 9 | from lib import seq2seq_model_utils, data_utils 10 | 11 | 12 | def setup_workpath(workspace): 13 | for p in ['data', 'nn_models', 'results']: 14 | wp = "%s/%s" % (workspace, p) 15 | if not os.path.exists(wp): os.mkdir(wp) 16 | 17 | data_dir = "%s/data" % (workspace) 18 | # training data 19 | if not os.path.exists("%s/chat.in" % data_dir): 20 | n = 0 21 | f_zip = gzip.open("%s/train/chat.txt.gz" % data_dir, 'rt') 22 | f_train = open("%s/chat.in" % data_dir, 'w') 23 | f_dev = open("%s/chat_test.in" % data_dir, 'w') 24 | for line in f_zip: 25 | f_train.write(line) 26 | if n < 10000: 27 | f_dev.write(line) 28 | n += 1 29 | 30 | 31 | def train(args): 32 | print("[%s] Preparing dialog data in %s" % (args.model_name, args.data_dir)) 33 | setup_workpath(workspace=args.workspace) 34 | train_data, dev_data, _ = data_utils.prepare_dialog_data(args.data_dir, args.vocab_size) 35 | 36 | if args.reinforce_learn: 37 | args.batch_size = 1 # We decode one sentence at a time. 38 | 39 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_usage) 40 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 41 | 42 | # Create model. 43 | print("Creating %d layers of %d units." % (args.num_layers, args.size)) 44 | model = seq2seq_model_utils.create_model(sess, args, forward_only=False) 45 | 46 | # Read data into buckets and compute their sizes. 47 | print("Reading development and training data (limit: %d)." % args.max_train_data_size) 48 | dev_set = data_utils.read_data(dev_data, args.buckets, reversed=args.rev_model) 49 | train_set = data_utils.read_data(train_data, args.buckets, args.max_train_data_size, reversed=args.rev_model) 50 | train_bucket_sizes = [len(train_set[b]) for b in xrange(len(args.buckets))] 51 | train_total_size = float(sum(train_bucket_sizes)) 52 | 53 | # A bucket scale is a list of increasing numbers from 0 to 1 that we'll use 54 | # to select a bucket. Length of [scale[i], scale[i+1]] is proportional to 55 | # the size if i-th training bucket, as used later. 56 | train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size 57 | for i in xrange(len(train_bucket_sizes))] 58 | 59 | # This is the training loop. 60 | step_time, loss = 0.0, 0.0 61 | current_step = 0 62 | previous_losses = [] 63 | 64 | # Load vocabularies. 
65 | vocab_path = os.path.join(args.data_dir, "vocab%d.in" % args.vocab_size) 66 | vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path) 67 | 68 | while True: 69 | # Choose a bucket according to data distribution. We pick a random number 70 | # in [0, 1] and use the corresponding interval in train_buckets_scale. 71 | random_number_01 = np.random.random_sample() 72 | bucket_id = min([i for i in xrange(len(train_buckets_scale)) 73 | if train_buckets_scale[i] > random_number_01]) 74 | 75 | # Get a batch and make a step. 76 | start_time = time.time() 77 | encoder_inputs, decoder_inputs, target_weights = model.get_batch( 78 | train_set, bucket_id) 79 | 80 | # print("[shape]", np.shape(encoder_inputs), np.shape(decoder_inputs), np.shape(target_weights)) 81 | if args.reinforce_learn: 82 | _, step_loss, _ = model.step_rf(args, sess, encoder_inputs, decoder_inputs, 83 | target_weights, bucket_id, rev_vocab=rev_vocab) 84 | else: 85 | _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, 86 | target_weights, bucket_id, forward_only=False, force_dec_input=True) 87 | 88 | step_time += (time.time() - start_time) / args.steps_per_checkpoint 89 | loss += step_loss / args.steps_per_checkpoint 90 | current_step += 1 91 | 92 | # Once in a while, we save checkpoint, print statistics, and run evals. 93 | if (current_step % args.steps_per_checkpoint == 0) and (not args.reinforce_learn): 94 | # Print statistics for the previous epoch. 95 | perplexity = math.exp(loss) if loss < 300 else float('inf') 96 | print ("global step %d learning rate %.4f step-time %.2f perplexity %.2f @ %s" % 97 | (model.global_step.eval(), model.learning_rate.eval(), step_time, perplexity, datetime.now())) 98 | 99 | # Decrease learning rate if no improvement was seen over last 3 times. 100 | if len(previous_losses) > 2 and loss > max(previous_losses[-3:]): 101 | sess.run(model.learning_rate_decay_op) 102 | 103 | previous_losses.append(loss) 104 | 105 | # # Save checkpoint and zero timer and loss. 106 | checkpoint_path = os.path.join(args.model_dir, "model.ckpt") 107 | model.saver.save(sess, checkpoint_path, global_step=model.global_step) 108 | step_time, loss = 0.0, 0.0 109 | 110 | # Run evals on development set and print their perplexity. 111 | for bucket_id in xrange(len(args.buckets)): 112 | encoder_inputs, decoder_inputs, target_weights = model.get_batch(dev_set, bucket_id) 113 | _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, 114 | target_weights, bucket_id, forward_only=True, force_dec_input=False) 115 | 116 | eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf') 117 | print(" eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx)) 118 | 119 | sys.stdout.flush() 120 | -------------------------------------------------------------------------------- /README2.md: -------------------------------------------------------------------------------- 1 | # A more detailed explaination about "the tensorflow chatbot" 2 | 3 | Here I'll try to explain some algorithm and implementation details about [this work][a1] in layman's terms. 4 | 5 | [a1]: https://github.com/Marsan-Ma/tf_chatbot_seq2seq_antilm 6 | 7 | 8 | ## Sequence to sequence model 9 | 10 | ### What is a language model? 11 | 12 | Let's say a language model is ... 13 | a) Trained by a lot of corpus. 14 | b) It could predict the **probability of next word** given foregoing words. 
15 | => It's just conditional probability, **P(next_word | foregoing_words)** 16 | c) Since we can predict the next word: 17 | => we can then predict the word after that, conditioned on the words just generated 18 | => and by continuing, we can produce whole sentences, even paragraphs. 19 | 20 | We could easily achieve this with a simple [LSTM model][b1]. 21 | 22 | 23 | ### The seq2seq model architecture 24 | 25 | Again we quote this seq2seq architecture from [Google's blogpost][b3]: 26 | [![seq2seq][b2]][b3] 27 | 28 | It's composed of two language models: an encoder and a decoder. Both of them could be the LSTM model we just mentioned. 29 | 30 | The encoder accepts input tokens and transforms the whole input sentence into an embedded **"thought vector"**, which expresses the meaning of the input sentence in our language model's domain. 31 | 32 | The decoder is then just a language model; as we just said, a language model can generate a new sentence according to the foregoing corpus. Here we use this **"thought vector"** as the kick-off, receive the corresponding mapping, and decode it into the response. 33 | 34 | 35 | ### Reversed encoder input and Attention mechanism 36 | 37 | Now you might wonder: 38 | a) Considering this architecture, will the "thought vector" be dominated by the later stages of the encoder? 39 | b) Is it enough to represent the meaning of the whole input sentence in just one vector? 40 | 41 | 42 | For (a), there is actually an implementation detail we didn't mention before: the input sentence is reversed before being fed into the encoder. This shortens the distance between the head of the input sentence and the head of the response sentence. Empirically, it achieves better results. (This trick is not shown in the architecture figure above, to keep it easy to understand.) 43 | 44 | For (b), another method to expose more information to the decoder is the [attention mechanism][b4]. The idea is simple: allow each decoder stage to peek at any encoder stage it finds useful during the training phase. The decoder can thus understand the input sentence better and automagically peek at suitable positions while generating the response. 45 | 46 | 47 | 48 | [b1]: http://colah.github.io/posts/2015-08-Understanding-LSTMs 49 | [b2]: http://4.bp.blogspot.com/-aArS0l1pjHQ/Vjj71pKAaEI/AAAAAAAAAxE/Nvy1FSbD_Vs/s640/2TFstaticgraphic_alt-01.png 50 | [b3]: http://googleresearch.blogspot.ru/2015/11/computer-respond-to-this-email.html 51 | [b4]: http://arxiv.org/abs/1412.7449 52 | 53 | 54 | 55 | ## Techniques about the language model 56 | 57 | ### Dictionary space compression and projection 58 | 59 | A naive implementation of a language model goes like this: suppose we are training an English language model, for which a dictionary size of 80,000 is roughly enough. If we one-hot encode each word in our dictionary, our LSTM cell needs 80,000 outputs, and we run a softmax over them to choose the word with the best probability... 60 | 61 | ... even if you have lots of computing resources, you don't need to waste them like that. Especially if you are dealing with languages with more words, like Chinese, where 200,000 words is barely enough. 62 | 63 | Practically, we can reduce this 80,000-entry one-hot dictionary into an embedding space: we can use, say, 64, 128 or 256 dimensions to embed our 80,000-word dictionary and train our model only in this lower dimension. Finally, when we are generating the response, we project the embedding back into one-hot space for the dictionary lookup. 64 | 65 | 66 | ### Beam search 67 | 68 | The original tensorflow implementation decodes the response sentence greedily. 
Empirically this traps the result in a local optimum and produces dumb responses which merely have the maximum probability for the first couple of words. 69 | 70 | So we do beam search: keep the best N candidates and move forward, so that we can avoid the local optimum and find longer, more interesting responses closer to the global optimum. 71 | 72 | In [this paper][b4], the Google Brain team found that beam search didn't benefit machine translation much; I guess that's why they didn't implement it. But in my experience, a chatbot does benefit a lot from beam search. 73 | 74 | 75 | ## Anti-Language Model 76 | 77 | ### Generic response problem 78 | 79 | As the seq2seq model is trained by [MLE][c1] (maximum likelihood estimation), the model follows this objective function well by finding the "most probable" response. But in human dialogue, a high-probability response like "thank you", "I don't know", or "I love you" is not informative at all. 80 | 81 | Since we haven't yet found a good enough objective function to replace MLE, there are some works that try to suppress this "generic response problem". 82 | 83 | 84 | ### Suppressing generic responses 85 | 86 | The work of [Li et al.][c2] from Stanford and Microsoft Research tries to suppress generic responses by lowering the probability of generic candidates while doing the beam search. 87 | 88 | The idea is somewhat like Tf-Idf: if a response is suitable for all kinds of foregoing sentences, it is not a specific answer to the current sentence, so we discard it. 89 | 90 | According to my own experimental results, this helps a lot! The cost is that we sometimes choose something grammatically less correct, but most of the time the effect is acceptable. It does generate more interesting, informative responses. 91 | 92 | 93 | [c1]: https://en.wikipedia.org/wiki/Maximum_likelihood_estimation 94 | [c2]: https://arxiv.org/abs/1606.01541 95 | 96 | 97 | 98 | ## Deep Reinforcement Learning 99 | 100 | ### Reinforcement learning 101 | 102 | Reinforcement learning is a promising domain now (in 2016). It's promising because it solves the delayed-reward problem, and that's a huge merit for chatbot training: we can judge a continuous dialogue consisting of several sentences, rather than one single sentence at a time, and design more sophisticated metrics to reward the model and make it learn more abstract ideas. 103 | 104 | 105 | ### Implementation tricks in tensorflow 106 | 107 | The magic of tensorflow is that it constructs a graph, and all the computation in the graph can be dispatched automagically to CPU, GPU, or even a distributed system (more CPUs/GPUs). 108 | 109 | So far tensorflow has no native operations supporting delayed rewards, so we have to do some work-around: we calculate the gradients in-graph, accumulate and post-process them out-of-graph, and finally inject them back to do the `apply_gradient()`, as sketched below. You could find a minimum example in [this ipython notebook][d1]. 
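To make this work-around concrete, here is a minimal, self-contained sketch of the fetch-scale-reapply loop (a toy example assuming the TensorFlow 1.x API; the single variable, the fixed reward value, and the optimizer settings are illustrative and are not the ones used in `lib/seq2seq_model.py`):

    import tensorflow as tf

    # toy stand-in for the seq2seq parameters and loss
    theta = tf.Variable([1.0, -1.0])
    loss = tf.reduce_sum(tf.square(theta))

    tvars = tf.trainable_variables()
    grads = tf.gradients(loss, tvars)                  # 1. gradients are defined in-graph
    grad_phs = [tf.placeholder(tf.float32, v.get_shape()) for v in tvars]
    apply_op = tf.train.GradientDescentOptimizer(0.1).apply_gradients(list(zip(grad_phs, tvars)))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        grad_vals = sess.run(grads)                    # 2. fetch the gradients out of the graph
        reward = 0.5                                   # 3. delayed reward judged on the whole dialogue
        feed = {ph: reward * g for ph, g in zip(grad_phs, grad_vals)}
        sess.run(apply_op, feed_dict=feed)             # 4. inject the scaled gradients back
        print(sess.run(theta))

In the real model the gradients come from the loss of a sampled response and the reward from whatever dialogue-level metric you design; only the in-graph / out-of-graph split matters here.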
110 | 111 | 112 | [d1]: https://github.com/awjuliani/DeepRL-Agents/blob/master/Policy-Network.ipynb 113 | 114 | 115 | 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /lib/seq2seq_model_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from random import random 8 | from datetime import datetime 9 | from tensorflow.python.platform import gfile 10 | 11 | from lib import data_utils 12 | from lib import seq2seq_model 13 | 14 | 15 | import heapq 16 | 17 | def create_model(session, args, forward_only=True): 18 | """Create translation model and initialize or load parameters in session.""" 19 | model = seq2seq_model.Seq2SeqModel( 20 | source_vocab_size=args.vocab_size, 21 | target_vocab_size=args.vocab_size, 22 | buckets=args.buckets, 23 | size=args.size, 24 | num_layers=args.num_layers, 25 | max_gradient_norm=args.max_gradient_norm, 26 | batch_size=args.batch_size, 27 | learning_rate=args.learning_rate, 28 | learning_rate_decay_factor=args.learning_rate_decay_factor, 29 | forward_only=forward_only, 30 | ) 31 | 32 | # for tensorboard 33 | if args.en_tfboard: 34 | summary_writer = tf.train.SummaryWriter(args.tf_board_dir, session.graph) 35 | 36 | ckpt = tf.train.get_checkpoint_state(args.model_dir) 37 | # if ckpt and gfile.Exists(ckpt.model_checkpoint_path): 38 | if ckpt and ckpt.model_checkpoint_path: 39 | print("Reading model parameters from %s @ %s" % (ckpt.model_checkpoint_path, datetime.now())) 40 | model.saver.restore(session, ckpt.model_checkpoint_path) 41 | print("Model reloaded @ %s" % (datetime.now())) 42 | else: 43 | print("Created model with fresh parameters.") 44 | session.run(tf.global_variables_initializer()) 45 | return model 46 | 47 | 48 | def dict_lookup(rev_vocab, out): 49 | word = rev_vocab[out] if (out < len(rev_vocab)) else data_utils._UNK 50 | if isinstance(word, bytes): 51 | word = word.decode() 52 | return word 53 | 54 | 55 | 56 | def softmax(x): 57 | prob = np.exp(x) / np.sum(np.exp(x), axis=0) 58 | return prob 59 | 60 | 61 | 62 | def cal_bleu(cands, ref, stopwords=['的', '嗎']): 63 | cands = [s['dec_inp'].split() for s in cands] 64 | cands = [[w for w in sent if w[0] != '_'] for sent in cands] 65 | refs = [w for w in ref.split() if w not in stopwords] 66 | bleus = [] 67 | for cand in cands: 68 | if len(cand) < 4: cand += [''] * (4 - len(cand)) 69 | bleu = sentence_bleu(refs, cand) 70 | bleus.append(bleu) 71 | print(refs, cand, bleu) 72 | return np.average(bleus) 73 | 74 | 75 | 76 | def get_predicted_sentence(args, input_sentence, vocab, rev_vocab, model, sess, debug=False, return_raw=False): 77 | def model_step(enc_inp, dec_inp, dptr, target_weights, bucket_id): 78 | _, _, logits = model.step(sess, enc_inp, dec_inp, target_weights, bucket_id, forward_only=True) 79 | prob = softmax(logits[dptr][0]) 80 | # print("model_step @ %s" % (datetime.now())) 81 | return prob 82 | 83 | def greedy_dec(output_logits, rev_vocab): 84 | selected_token_ids = [int(np.argmax(logit, axis=1)) for logit in output_logits] 85 | if data_utils.EOS_ID in selected_token_ids: 86 | eos = selected_token_ids.index(data_utils.EOS_ID) 87 | selected_token_ids = selected_token_ids[:eos] 88 | output_sentence = ' '.join([dict_lookup(rev_vocab, t) for t in selected_token_ids]) 89 | return output_sentence 90 | 91 | input_token_ids = 
data_utils.sentence_to_token_ids(input_sentence, vocab) 92 | 93 | # Which bucket does it belong to? 94 | bucket_id = min([b for b in range(len(args.buckets)) if args.buckets[b][0] > len(input_token_ids)]) 95 | outputs = [] 96 | feed_data = {bucket_id: [(input_token_ids, outputs)]} 97 | 98 | # Get a 1-element batch to feed the sentence to the model. 99 | encoder_inputs, decoder_inputs, target_weights = model.get_batch(feed_data, bucket_id) 100 | if debug: print("\n[get_batch]\n", encoder_inputs, decoder_inputs, target_weights) 101 | 102 | ### Original greedy decoding 103 | if args.beam_size == 1: 104 | _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, forward_only=True) 105 | return [{"dec_inp": greedy_dec(output_logits, rev_vocab), 'prob': 1}] 106 | 107 | # Get output logits for the sentence. 108 | beams, new_beams, results = [(1, 0, {'eos': 0, 'dec_inp': decoder_inputs, 'prob': 1, 'prob_ts': 1, 'prob_t': 1})], [], [] # initialize beams as (log_prob, empty_string, eos) 109 | dummy_encoder_inputs = [np.array([data_utils.PAD_ID]) for _ in range(len(encoder_inputs))] 110 | 111 | for dptr in range(len(decoder_inputs)-1): 112 | if dptr > 0: 113 | target_weights[dptr] = [1.] 114 | beams, new_beams = new_beams[:args.beam_size], [] 115 | if debug: print("=====[beams]=====", beams) 116 | heapq.heapify(beams) # since we will remove something 117 | for prob, _, cand in beams: 118 | if cand['eos']: 119 | results += [(prob, 0, cand)] 120 | continue 121 | 122 | # normal seq2seq 123 | if debug: print(cand['prob'], " ".join([dict_lookup(rev_vocab, w) for w in cand['dec_inp']])) 124 | 125 | all_prob_ts = model_step(encoder_inputs, cand['dec_inp'], dptr, target_weights, bucket_id) 126 | if args.antilm: 127 | # anti-lm 128 | all_prob_t = model_step(dummy_encoder_inputs, cand['dec_inp'], dptr, target_weights, bucket_id) 129 | # adjusted probability 130 | all_prob = all_prob_ts - args.antilm * all_prob_t #+ args.n_bonus * dptr + random() * 1e-50 131 | else: 132 | all_prob_t = [0]*len(all_prob_ts) 133 | all_prob = all_prob_ts 134 | 135 | # suppress copy-cat (respond the same as input) 136 | if dptr < len(input_token_ids): 137 | all_prob[input_token_ids[dptr]] = all_prob[input_token_ids[dptr]] * 0.01 138 | 139 | # for debug use 140 | if return_raw: return all_prob, all_prob_ts, all_prob_t 141 | 142 | # beam search 143 | for c in np.argsort(all_prob)[::-1][:args.beam_size]: 144 | new_cand = { 145 | 'eos' : (c == data_utils.EOS_ID), 146 | 'dec_inp' : [(np.array([c]) if i == (dptr+1) else k) for i, k in enumerate(cand['dec_inp'])], 147 | 'prob_ts' : cand['prob_ts'] * all_prob_ts[c], 148 | 'prob_t' : cand['prob_t'] * all_prob_t[c], 149 | 'prob' : cand['prob'] * all_prob[c], 150 | } 151 | new_cand = (new_cand['prob'], random(), new_cand) # stuff a random to prevent comparing new_cand 152 | 153 | try: 154 | if (len(new_beams) < args.beam_size): 155 | heapq.heappush(new_beams, new_cand) 156 | elif (new_cand[0] > new_beams[0][0]): 157 | heapq.heapreplace(new_beams, new_cand) 158 | except Exception as e: 159 | print("[Error]", e) 160 | print("-----[new_beams]-----\n", new_beams) 161 | print("-----[new_cand]-----\n", new_cand) 162 | 163 | results += new_beams # flush last cands 164 | 165 | # post-process results 166 | res_cands = [] 167 | for prob, _, cand in sorted(results, reverse=True): 168 | cand['dec_inp'] = " ".join([dict_lookup(rev_vocab, w) for w in cand['dec_inp']]) 169 | res_cands.append(cand) 170 | return res_cands[:args.beam_size] 171 | 
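# (added note, not part of the original file) the adjusted score computed above,
#     all_prob = all_prob_ts - args.antilm * all_prob_t
# corresponds to Score(T) = p(T|S) - lambda * U(T) from Li et al. (2016): all_prob_ts is
# the normal seq2seq probability p(T|S), while all_prob_t is the same decoder run on
# dummy PAD-only encoder inputs, acting as the anti-language model U(T); the optional
# length bonus (gamma * Nt, controlled by args.n_bonus) is left commented out above.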
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tensorflow chatbot 2 | ### (with seq2seq + attention + dict-compress + beam search + anti-LM + facebook messenger server) 3 | 4 | 5 | > ####[Update 2017-03-14] 6 | > 1. Upgraded to tensorflow v1.0.0; not backward compatible, since tensorflow has changed so much. 7 | > 2. A pre-trained model on a twitter corpus is added, just `./go_example` to chat! (or preview my [chat example](https://github.com/Marsan-Ma/tf_chatbot_seq2seq_antilm/blob/master/example_chat.md)) 8 | > 3. You could start by tracing this `go_example` script to see how things work! 9 | 10 | 11 | ## Briefing 12 | This is a [seq2seq model][a1] modified from the [tensorflow example][a2]. 13 | 14 | 1. The original tensorflow seq2seq has the [attention mechanism][a3] implemented out-of-the-box. 15 | 2. Training is sped up by [dictionary space compressing][a4], then decompressed by projecting the embedding while decoding. 16 | 3. This work adds the option to do [beam search][a5] in the decoding procedure, which usually finds better, more interesting responses. 17 | 4. Added an [anti-language model][a6] to suppress the generic response problem intrinsic to the seq2seq model. 18 | 5. Implemented [this deep reinforcement learning architecture][a7] as an option to enhance the semantic coherence and perplexity of responses. 19 | 6. A lightweight [Flask][a8] server `app.py` is included to serve as the Facebook Messenger App backend. 20 | 21 | 22 | [a1]: http://arxiv.org/abs/1406.1078 23 | [a2]: https://www.tensorflow.org/versions/r0.10/tutorials/seq2seq/index.html 24 | [a3]: http://arxiv.org/abs/1412.7449 25 | [a4]: https://arxiv.org/pdf/1412.2007v2.pdf 26 | [a5]: https://en.wikipedia.org/wiki/Beam_search 27 | [a6]: https://arxiv.org/abs/1510.03055 28 | [a7]: https://arxiv.org/abs/1606.01541 29 | [a8]: http://flask.pocoo.org/ 30 | [a9]: https://github.com/Marsan-Ma/tf_chatbot_seq2seq_antilm/blob/master/README2.md 31 | 32 | 33 | ## In Layman's terms 34 | 35 | I explained some details about the features and some implementation tricks [here][a9]. 36 | 37 | 38 | ## Just tell me how it works 39 | 40 | #### Clone the repository 41 | 42 | git clone https://github.com/Marsan-Ma/tf_chatbot_seq2seq_antilm.git 43 | 44 | #### Prepare the corpus 45 | You may find corpora such as twitter chat, open movie subtitles, or ptt forums in [my chat corpus repository][b1]. You need to put one under a path like: 46 | 47 | tf_chatbot_seq2seq_antilm/works/<model_name>/data/train/chat.txt 48 | 49 | And hand-craft some testing sentences (one sentence per line) in: 50 | 51 | tf_chatbot_seq2seq_antilm/works/<model_name>/data/test/test_set.txt 52 | 53 | #### Train the model 54 | 55 | python3 main.py --mode train --model_name <model_name> 56 | 57 | #### Run some test examples and see the bot response 58 | 59 | After you have trained your model until the perplexity is under 50 or so, you could do: 60 | 61 | python3 main.py --mode test --model_name <model_name> 62 | 63 | 64 | **[Note!!!] If you override any parameter in this main.py command, be sure to apply it to both train and test, or just modify lib/config.py to be failsafe.** 65 | 66 | 67 | 68 | ## Start your Facebook Messenger backend server 69 | 70 | python3 app.py --model_name <model_name> 71 | 72 | You may see this [minimum fb_messenger example][b2] for more details like setting up SSL, the webhook, and work-arounds for known bugs. 
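Once `app.py` is running, a quick hedged sanity check of the webhook (host, token, and the challenge value `12345` below are placeholders) is to call the verification endpoint the way Facebook does, with `hub.verify_token` and `hub.challenge` as GET parameters:

    curl "https://your-server.example.com/chat?hub.verify_token=<VERIFY_TOKEN>&hub.challenge=12345"

If the token matches the `VERIFY_TOKEN` in `config.yaml`, the server echoes the challenge (`12345`) back; otherwise it returns an error string.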
73 | 74 | Here's an interesting comparison: the left conversation enabled beam search with beam = 10; the responses are barely better than always answering "i don't know". The right conversation also used beam search and, additionally, enabled the anti-language model. This is supposed to suppress generic responses, and the responses do seem better. 75 | 76 | ![messenger.png][h1] 77 | 78 | [h1]: https://raw.githubusercontent.com/Marsan-Ma/tf_chatbot_seq2seq_antilm/master/doc/messenger.png 79 | 80 | 81 | 82 | 83 | ## Deep reinforcement learning 84 | 85 | > [Update 2017-03-09] Reinforcement learning does not work now, wait for a fix. 86 | 87 | If you want a chance to further improve your model, I implemented here a reinforcement learning architecture inspired by [Li et al., 2016][b3]. Just enable the reinforce_learn option in `config.py`; you might want to add your own rule in the `step_rf()` function in `lib/seq2seq_model.py`. 88 | 89 | Note that you should **train in normal mode to get a decent model first!**, since reinforcement learning will explore the brave new world starting from this pre-trained model. It will take forever to improve itself if you start with a bad model. 90 | 91 | [b1]: https://github.com/Marsan-Ma/chat_corpus 92 | [b2]: https://github.com/Marsan-Ma/fb_messenger 93 | [b3]: https://arxiv.org/abs/1606.01541 94 | 95 | ## Introduction 96 | 97 | Seq2seq is a great model released by [Cho et al., 2014][c1]. At first it was used for machine translation, and soon people found that anything about **mapping something to another thing** could also be achieved by a seq2seq model. Chatbots are one of these miracles, where we consider consecutive dialog as a kind of "mapping" relationship. 98 | 99 | Here is the classic intro picture showing the seq2seq model architecture, quoted from this [blogpost about the gmail auto-reply feature][c2]. 100 | 101 | [![seq2seq][c3]][c3] 102 | 103 | 104 | The problem is that, so far, we haven't found a better objective function for chatbots. We are still using [MLE (maximum likelihood estimation)][c4], which does well for machine translation but always generates generic responses like "me too", "I think so", "I love you" when doing chat. 105 | 106 | These responses are not informative, but they do have large probability --- since they tend to appear many times in the training corpus. We don't want our chatbot to always reply with this nonsense, so we need some way to make our bot more "interesting"; technically speaking, to increase the "perplexity" of the response. 107 | 108 | Here we reproduce the work of [Li et al., 2016][c5] to try to solve this problem. The main idea is to use the same seq2seq model as a language model, to get the candidate words with high probability at each decoding timestep as an anti-model, and then penalize these words that are always high-probability for any input. With this anti-model, we can get more specific, non-generic, informative responses. 109 | 110 | The original work of [Li et al.][c5] uses [MERT (Och, 2003)][c6] with [BLEU][c7] as the metric to find the best probability weighting (the **λ** and **γ** in 111 | **Score(T) = p(T|S) − λU(T) + γNt**) of the corresponding anti-language model. But I find that the BLEU score on a chat corpus tends to always be zero, so I can't get meaningful results here. If anyone has any idea about this, drop me a message, thanks! 
112 | 113 | 114 | [c1]: http://arxiv.org/abs/1406.1078 115 | [c2]: http://googleresearch.blogspot.ru/2015/11/computer-respond-to-this-email.html 116 | [c3]: http://4.bp.blogspot.com/-aArS0l1pjHQ/Vjj71pKAaEI/AAAAAAAAAxE/Nvy1FSbD_Vs/s640/2TFstaticgraphic_alt-01.png 117 | [c4]: https://en.wikipedia.org/wiki/Maximum_likelihood_estimation 118 | [c5]: http://arxiv.org/pdf/1510.03055v3.pdf 119 | [c6]: http://delivery.acm.org/10.1145/1080000/1075117/p160-och.pdf 120 | [c7]: https://en.wikipedia.org/wiki/BLEU 121 | 122 | 123 | ## Parameters 124 | 125 | There are some options to for model training and predicting in lib/config.py. Basically they are self-explained and could work with default value for most of cases. Here we only list something you need to config: 126 | 127 | **About environment** 128 | 129 | name | type | Description 130 | ---- | ---- | ----------- 131 | mode | string | work mode: train/test/chat 132 | model_name | string | model name, affects your working path (storing the data, nn_model, result folders) 133 | scope_name | string | In tensorflow if you need to load two graph at the same time, you need to save/load them in different namespace. (If you need only one seq2seq model, leave it as default) 134 | vocab_size | integer | depends on your corpus language: for english, 60000 is good enough. For chinese you need at least 100000 or 200000. 135 | gpu_usage | float | tensorflow gpu memory fraction used, default is 1 and tensorflow will occupy 100% of your GPU. If you have multi jobs sharing your GPU resource, make it 0.5 or 0.3, for 2 or 3 jobs. 136 | reinforce_learn | int | set 1 to enable reinforcement learning mode 137 | 138 | 139 | **About decoding** 140 | 141 | name | type | default | Description 142 | ---- | ---- | ------- | ------- 143 | beam_size | int | 10 | beam search size, setting 1 equals to greedy search 144 | antilm | float | 0 (disabled) | punish weight of [anti-language model][d1] 145 | n_bonus | float | 0 (disabled) | reward weight of sentence length 146 | 147 | 148 | The anti-LM functin is disabled by default, you may start from setting antilm=0.5~0.7 and n_bonus=0.05 to see if you like the difference in results. 149 | 150 | [d1]: http://arxiv.org/pdf/1510.03055v3.pdf 151 | 152 | 153 | ## Requirements 154 | 155 | 1. For training, GPU is recommended since seq2seq is a large model, you need certain computing power to do the training and predicting efficiently, especially when you set a large beam-search size. 156 | 157 | 2. DRAM requirement is not strict as CPU/GPU, since we are doing stochastic gradient decent. 158 | 159 | 3. If you are new to deep-learning, setting-up things like GPU, python environment is annoying to you, here are dockers of my machine learning environment: 160 | [(non-gpu version docker)][e1] / [(gpu version docker)][e2] 161 | 162 | [e1]: https://github.com/Marsan-Ma/docker_mldm 163 | [e2]: https://github.com/Marsan-Ma/docker_mldm_gpu 164 | 165 | 166 | ## References 167 | 168 | Seq2seq is a model with many preliminaries, I've been spend quite some time surveying and here are some best materials which benefit me a lot: 169 | 170 | 1. The best blogpost explaining RNN, LSTM, GRU and seq2seq model: [Understanding LSTM Networks][f1] by Christopher Olah. 171 | 172 | 2. This work [sherjilozair/char-rnn-tensorflow][f2] helps me learn a lot about language model and implementation graph in tensorflow. 173 | 174 | 3. 
If you are interested in more magic about RNN, here is a MUST-READ blogpost: [The Unreasonable Effectiveness of Recurrent Neural Networks][f3] by Andrej Karpathy. 175 | 176 | 4. The vanilla version seq2seq+attention: [nicolas-ivanov/tf_seq2seq_chatbot][f4]. This will help you figure out the main flow of vanilla seq2seq model, and I build this repository based on this work. 177 | 178 | [f1]: http://colah.github.io/posts/2015-08-Understanding-LSTMs/ 179 | [f2]: https://github.com/sherjilozair/char-rnn-tensorflow 180 | [f3]: http://karpathy.github.io/2015/05/21/rnn-effectiveness/ 181 | [f4]: https://github.com/nicolas-ivanov/tf_seq2seq_chatbot 182 | 183 | 184 | ## TODOs 185 | 1. Currently I build beam-search out of graph, which means --- it's very slow. There are discussions about build it in-graph [here][g1] and [there][g2]. But unfortunately if you want add something more than beam-search, like this anti-LM work, you need much more than just beam search to be in-graph. 186 | 187 | 2. I haven't figure out how the MERT with BLEU can optimize weight of anti-LM model, since currently the BLEU is often being zero. 188 | 189 | [g1]: https://github.com/tensorflow/tensorflow/issues/654#issuecomment-196168030 190 | [g2]: https://github.com/tensorflow/tensorflow/pull/3756 191 | -------------------------------------------------------------------------------- /ref/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utilities for downloading data from WMT, tokenizing, vocabularies.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import gzip 22 | import os 23 | import re 24 | import tarfile 25 | 26 | from six.moves import urllib 27 | 28 | from tensorflow.python.platform import gfile 29 | import tensorflow as tf 30 | 31 | # Special vocabulary symbols - we always put them at the start. 32 | _PAD = b"_PAD" 33 | _GO = b"_GO" 34 | _EOS = b"_EOS" 35 | _UNK = b"_UNK" 36 | _START_VOCAB = [_PAD, _GO, _EOS, _UNK] 37 | 38 | PAD_ID = 0 39 | GO_ID = 1 40 | EOS_ID = 2 41 | UNK_ID = 3 42 | 43 | # Regular expressions used to tokenize. 44 | _WORD_SPLIT = re.compile(b"([.,!?\"':;)(])") 45 | _DIGIT_RE = re.compile(br"\d") 46 | 47 | # URLs for WMT data. 
48 | _WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar" 49 | _WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz" 50 | 51 | 52 | def maybe_download(directory, filename, url): 53 | """Download filename from url unless it's already in directory.""" 54 | if not os.path.exists(directory): 55 | print("Creating directory %s" % directory) 56 | os.mkdir(directory) 57 | filepath = os.path.join(directory, filename) 58 | if not os.path.exists(filepath): 59 | print("Downloading %s to %s" % (url, filepath)) 60 | filepath, _ = urllib.request.urlretrieve(url, filepath) 61 | statinfo = os.stat(filepath) 62 | print("Successfully downloaded", filename, statinfo.st_size, "bytes") 63 | return filepath 64 | 65 | 66 | def gunzip_file(gz_path, new_path): 67 | """Unzips from gz_path into new_path.""" 68 | print("Unpacking %s to %s" % (gz_path, new_path)) 69 | with gzip.open(gz_path, "rb") as gz_file: 70 | with open(new_path, "wb") as new_file: 71 | for line in gz_file: 72 | new_file.write(line) 73 | 74 | 75 | def get_wmt_enfr_train_set(directory): 76 | """Download the WMT en-fr training corpus to directory unless it's there.""" 77 | train_path = os.path.join(directory, "giga-fren.release2.fixed") 78 | if not (gfile.Exists(train_path +".fr") and gfile.Exists(train_path +".en")): 79 | corpus_file = maybe_download(directory, "training-giga-fren.tar", 80 | _WMT_ENFR_TRAIN_URL) 81 | print("Extracting tar file %s" % corpus_file) 82 | with tarfile.open(corpus_file, "r") as corpus_tar: 83 | corpus_tar.extractall(directory) 84 | gunzip_file(train_path + ".fr.gz", train_path + ".fr") 85 | gunzip_file(train_path + ".en.gz", train_path + ".en") 86 | return train_path 87 | 88 | 89 | def get_wmt_enfr_dev_set(directory): 90 | """Download the WMT en-fr training corpus to directory unless it's there.""" 91 | dev_name = "newstest2013" 92 | dev_path = os.path.join(directory, dev_name) 93 | if not (gfile.Exists(dev_path + ".fr") and gfile.Exists(dev_path + ".en")): 94 | dev_file = maybe_download(directory, "dev-v2.tgz", _WMT_ENFR_DEV_URL) 95 | print("Extracting tgz file %s" % dev_file) 96 | with tarfile.open(dev_file, "r:gz") as dev_tar: 97 | fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr") 98 | en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en") 99 | fr_dev_file.name = dev_name + ".fr" # Extract without "dev/" prefix. 100 | en_dev_file.name = dev_name + ".en" 101 | dev_tar.extract(fr_dev_file, directory) 102 | dev_tar.extract(en_dev_file, directory) 103 | return dev_path 104 | 105 | 106 | def basic_tokenizer(sentence): 107 | """Very basic tokenizer: split the sentence into a list of tokens.""" 108 | words = [] 109 | for space_separated_fragment in sentence.strip().split(): 110 | words.extend(_WORD_SPLIT.split(space_separated_fragment)) 111 | return [w for w in words if w] 112 | 113 | 114 | def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size, 115 | tokenizer=None, normalize_digits=True): 116 | """Create vocabulary file (if it does not exist yet) from data file. 117 | 118 | Data file is assumed to contain one sentence per line. Each sentence is 119 | tokenized and digits are normalized (if normalize_digits is set). 120 | Vocabulary contains the most-frequent tokens up to max_vocabulary_size. 121 | We write it to vocabulary_path in a one-token-per-line format, so that later 122 | token in the first line gets id=0, second line gets id=1, and so on. 123 | 124 | Args: 125 | vocabulary_path: path where the vocabulary will be created. 
126 | data_path: data file that will be used to create vocabulary. 127 | max_vocabulary_size: limit on the size of the created vocabulary. 128 | tokenizer: a function to use to tokenize each data sentence; 129 | if None, basic_tokenizer will be used. 130 | normalize_digits: Boolean; if true, all digits are replaced by 0s. 131 | """ 132 | if not gfile.Exists(vocabulary_path): 133 | print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path)) 134 | vocab = {} 135 | with gfile.GFile(data_path, mode="rb") as f: 136 | counter = 0 137 | for line in f: 138 | counter += 1 139 | if counter % 100000 == 0: 140 | print(" processing line %d" % counter) 141 | line = tf.compat.as_bytes(line) 142 | tokens = tokenizer(line) if tokenizer else basic_tokenizer(line) 143 | for w in tokens: 144 | word = _DIGIT_RE.sub(b"0", w) if normalize_digits else w 145 | if word in vocab: 146 | vocab[word] += 1 147 | else: 148 | vocab[word] = 1 149 | vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True) 150 | if len(vocab_list) > max_vocabulary_size: 151 | vocab_list = vocab_list[:max_vocabulary_size] 152 | with gfile.GFile(vocabulary_path, mode="wb") as vocab_file: 153 | for w in vocab_list: 154 | vocab_file.write(w + b"\n") 155 | 156 | 157 | def initialize_vocabulary(vocabulary_path): 158 | """Initialize vocabulary from file. 159 | 160 | We assume the vocabulary is stored one-item-per-line, so a file: 161 | dog 162 | cat 163 | will result in a vocabulary {"dog": 0, "cat": 1}, and this function will 164 | also return the reversed-vocabulary ["dog", "cat"]. 165 | 166 | Args: 167 | vocabulary_path: path to the file containing the vocabulary. 168 | 169 | Returns: 170 | a pair: the vocabulary (a dictionary mapping string to integers), and 171 | the reversed vocabulary (a list, which reverses the vocabulary mapping). 172 | 173 | Raises: 174 | ValueError: if the provided vocabulary_path does not exist. 175 | """ 176 | if gfile.Exists(vocabulary_path): 177 | rev_vocab = [] 178 | with gfile.GFile(vocabulary_path, mode="rb") as f: 179 | rev_vocab.extend(f.readlines()) 180 | rev_vocab = [tf.compat.as_bytes(line.strip()) for line in rev_vocab] 181 | vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)]) 182 | return vocab, rev_vocab 183 | else: 184 | raise ValueError("Vocabulary file %s not found.", vocabulary_path) 185 | 186 | 187 | def sentence_to_token_ids(sentence, vocabulary, 188 | tokenizer=None, normalize_digits=True): 189 | """Convert a string to list of integers representing token-ids. 190 | 191 | For example, a sentence "I have a dog" may become tokenized into 192 | ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2, 193 | "a": 4, "dog": 7"} this function will return [1, 2, 4, 7]. 194 | 195 | Args: 196 | sentence: the sentence in bytes format to convert to token-ids. 197 | vocabulary: a dictionary mapping tokens to integers. 198 | tokenizer: a function to use to tokenize each sentence; 199 | if None, basic_tokenizer will be used. 200 | normalize_digits: Boolean; if true, all digits are replaced by 0s. 201 | 202 | Returns: 203 | a list of integers, the token-ids for the sentence. 204 | """ 205 | 206 | if tokenizer: 207 | words = tokenizer(sentence) 208 | else: 209 | words = basic_tokenizer(sentence) 210 | if not normalize_digits: 211 | return [vocabulary.get(w, UNK_ID) for w in words] 212 | # Normalize digits by 0 before looking words up in the vocabulary. 
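  # Illustrative example with a hypothetical vocabulary {b"i": 4, b"was": 7,
  # b"born": 11, b"0000": 23}: the sentence b"i was born in 1989" tokenizes to
  # [b"i", b"was", b"born", b"in", b"1989"]; "in" is out of vocabulary and maps
  # to UNK_ID (3), while "1989" is normalized to "0000" and maps to 23, giving
  # [4, 7, 11, 3, 23].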
213 | return [vocabulary.get(_DIGIT_RE.sub(b"0", w), UNK_ID) for w in words] 214 | 215 | 216 | def data_to_token_ids(data_path, target_path, vocabulary_path, 217 | tokenizer=None, normalize_digits=True): 218 | """Tokenize data file and turn into token-ids using given vocabulary file. 219 | 220 | This function loads data line-by-line from data_path, calls the above 221 | sentence_to_token_ids, and saves the result to target_path. See comment 222 | for sentence_to_token_ids on the details of token-ids format. 223 | 224 | Args: 225 | data_path: path to the data file in one-sentence-per-line format. 226 | target_path: path where the file with token-ids will be created. 227 | vocabulary_path: path to the vocabulary file. 228 | tokenizer: a function to use to tokenize each sentence; 229 | if None, basic_tokenizer will be used. 230 | normalize_digits: Boolean; if true, all digits are replaced by 0s. 231 | """ 232 | if not gfile.Exists(target_path): 233 | print("Tokenizing data in %s" % data_path) 234 | vocab, _ = initialize_vocabulary(vocabulary_path) 235 | with gfile.GFile(data_path, mode="rb") as data_file: 236 | with gfile.GFile(target_path, mode="w") as tokens_file: 237 | counter = 0 238 | for line in data_file: 239 | counter += 1 240 | if counter % 100000 == 0: 241 | print(" tokenizing line %d" % counter) 242 | token_ids = sentence_to_token_ids(tf.compat.as_bytes(line), vocab, 243 | tokenizer, normalize_digits) 244 | tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n") 245 | 246 | 247 | def prepare_wmt_data(data_dir, en_vocabulary_size, fr_vocabulary_size, tokenizer=None): 248 | """Get WMT data into data_dir, create vocabularies and tokenize data. 249 | 250 | Args: 251 | data_dir: directory in which the data sets will be stored. 252 | en_vocabulary_size: size of the English vocabulary to create and use. 253 | fr_vocabulary_size: size of the French vocabulary to create and use. 254 | tokenizer: a function to use to tokenize each data sentence; 255 | if None, basic_tokenizer will be used. 256 | 257 | Returns: 258 | A tuple of 6 elements: 259 | (1) path to the token-ids for English training data-set, 260 | (2) path to the token-ids for French training data-set, 261 | (3) path to the token-ids for English development data-set, 262 | (4) path to the token-ids for French development data-set, 263 | (5) path to the English vocabulary file, 264 | (6) path to the French vocabulary file. 265 | """ 266 | # Get wmt data to the specified directory. 267 | train_path = get_wmt_enfr_train_set(data_dir) 268 | dev_path = get_wmt_enfr_dev_set(data_dir) 269 | 270 | from_train_path = train_path + ".en" 271 | to_train_path = train_path + ".fr" 272 | from_dev_path = dev_path + ".en" 273 | to_dev_path = dev_path + ".fr" 274 | return prepare_data(data_dir, from_train_path, to_train_path, from_dev_path, to_dev_path, en_vocabulary_size, 275 | fr_vocabulary_size, tokenizer) 276 | 277 | 278 | def prepare_data(data_dir, from_train_path, to_train_path, from_dev_path, to_dev_path, from_vocabulary_size, 279 | to_vocabulary_size, tokenizer=None): 280 | """Preapre all necessary files that are required for the training. 281 | 282 | Args: 283 | data_dir: directory in which the data sets will be stored. 284 | from_train_path: path to the file that includes "from" training samples. 285 | to_train_path: path to the file that includes "to" training samples. 286 | from_dev_path: path to the file that includes "from" dev samples. 287 | to_dev_path: path to the file that includes "to" dev samples. 
288 | from_vocabulary_size: size of the "from language" vocabulary to create and use. 289 | to_vocabulary_size: size of the "to language" vocabulary to create and use. 290 | tokenizer: a function to use to tokenize each data sentence; 291 | if None, basic_tokenizer will be used. 292 | 293 | Returns: 294 | A tuple of 6 elements: 295 | (1) path to the token-ids for "from language" training data-set, 296 | (2) path to the token-ids for "to language" training data-set, 297 | (3) path to the token-ids for "from language" development data-set, 298 | (4) path to the token-ids for "to language" development data-set, 299 | (5) path to the "from language" vocabulary file, 300 | (6) path to the "to language" vocabulary file. 301 | """ 302 | # Create vocabularies of the appropriate sizes. 303 | to_vocab_path = os.path.join(data_dir, "vocab%d.to" % to_vocabulary_size) 304 | from_vocab_path = os.path.join(data_dir, "vocab%d.from" % from_vocabulary_size) 305 | create_vocabulary(to_vocab_path, to_train_path , to_vocabulary_size, tokenizer) 306 | create_vocabulary(from_vocab_path, from_train_path , from_vocabulary_size, tokenizer) 307 | 308 | # Create token ids for the training data. 309 | to_train_ids_path = to_train_path + (".ids%d" % to_vocabulary_size) 310 | from_train_ids_path = from_train_path + (".ids%d" % from_vocabulary_size) 311 | data_to_token_ids(to_train_path, to_train_ids_path, to_vocab_path, tokenizer) 312 | data_to_token_ids(from_train_path, from_train_ids_path, from_vocab_path, tokenizer) 313 | 314 | # Create token ids for the development data. 315 | to_dev_ids_path = to_dev_path + (".ids%d" % to_vocabulary_size) 316 | from_dev_ids_path = from_dev_path + (".ids%d" % from_vocabulary_size) 317 | data_to_token_ids(to_dev_path, to_dev_ids_path, to_vocab_path, tokenizer) 318 | data_to_token_ids(from_dev_path, from_dev_ids_path, from_vocab_path, tokenizer) 319 | 320 | return (from_train_ids_path, to_train_ids_path, 321 | from_dev_ids_path, to_dev_ids_path, 322 | from_vocab_path, to_vocab_path) 323 | -------------------------------------------------------------------------------- /ref/translate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Binary for training translation models and decoding from them. 17 | 18 | Running this program without --decode will download the WMT corpus into 19 | the directory specified as --data_dir and tokenize it in a very basic way, 20 | and then start training a model saving checkpoints to --train_dir. 21 | 22 | Running with --decode starts an interactive loop so you can see how 23 | the current checkpoint translates English sentences into French. 24 | 25 | See the following papers for more information on neural translation models. 
26 | * http://arxiv.org/abs/1409.3215 27 | * http://arxiv.org/abs/1409.0473 28 | * http://arxiv.org/abs/1412.2007 29 | """ 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import math 35 | import os 36 | import random 37 | import sys 38 | import time 39 | import logging 40 | 41 | import numpy as np 42 | from six.moves import xrange # pylint: disable=redefined-builtin 43 | import tensorflow as tf 44 | 45 | import data_utils 46 | import seq2seq_model 47 | 48 | 49 | tf.app.flags.DEFINE_float("learning_rate", 0.5, "Learning rate.") 50 | tf.app.flags.DEFINE_float("learning_rate_decay_factor", 0.99, 51 | "Learning rate decays by this much.") 52 | tf.app.flags.DEFINE_float("max_gradient_norm", 5.0, 53 | "Clip gradients to this norm.") 54 | tf.app.flags.DEFINE_integer("batch_size", 64, 55 | "Batch size to use during training.") 56 | tf.app.flags.DEFINE_integer("size", 1024, "Size of each model layer.") 57 | tf.app.flags.DEFINE_integer("num_layers", 3, "Number of layers in the model.") 58 | tf.app.flags.DEFINE_integer("from_vocab_size", 40000, "English vocabulary size.") 59 | tf.app.flags.DEFINE_integer("to_vocab_size", 40000, "French vocabulary size.") 60 | tf.app.flags.DEFINE_string("data_dir", "/tmp", "Data directory") 61 | tf.app.flags.DEFINE_string("train_dir", "/tmp", "Training directory.") 62 | tf.app.flags.DEFINE_string("from_train_data", None, "Training data.") 63 | tf.app.flags.DEFINE_string("to_train_data", None, "Training data.") 64 | tf.app.flags.DEFINE_string("from_dev_data", None, "Training data.") 65 | tf.app.flags.DEFINE_string("to_dev_data", None, "Training data.") 66 | tf.app.flags.DEFINE_integer("max_train_data_size", 0, 67 | "Limit on the size of training data (0: no limit).") 68 | tf.app.flags.DEFINE_integer("steps_per_checkpoint", 200, 69 | "How many training steps to do per checkpoint.") 70 | tf.app.flags.DEFINE_boolean("decode", False, 71 | "Set to True for interactive decoding.") 72 | tf.app.flags.DEFINE_boolean("self_test", False, 73 | "Run a self-test if this is set to True.") 74 | tf.app.flags.DEFINE_boolean("use_fp16", False, 75 | "Train using fp16 instead of fp32.") 76 | 77 | FLAGS = tf.app.flags.FLAGS 78 | 79 | # We use a number of buckets and pad to the closest one for efficiency. 80 | # See seq2seq_model.Seq2SeqModel for details of how they work. 81 | _buckets = [(5, 10), (10, 15), (20, 25), (40, 50)] 82 | 83 | 84 | def read_data(source_path, target_path, max_size=None): 85 | """Read data from source and target files and put into buckets. 86 | 87 | Args: 88 | source_path: path to the files with token-ids for the source language. 89 | target_path: path to the file with token-ids for the target language; 90 | it must be aligned with the source file: n-th line contains the desired 91 | output for n-th line from the source_path. 92 | max_size: maximum number of lines to read, all other will be ignored; 93 | if 0 or None, data files will be read completely (no limit). 94 | 95 | Returns: 96 | data_set: a list of length len(_buckets); data_set[n] contains a list of 97 | (source, target) pairs read from the provided data files that fit 98 | into the n-th bucket, i.e., such that len(source) < _buckets[n][0] and 99 | len(target) < _buckets[n][1]; source and target are lists of token-ids. 
100 | """ 101 | data_set = [[] for _ in _buckets] 102 | with tf.gfile.GFile(source_path, mode="r") as source_file: 103 | with tf.gfile.GFile(target_path, mode="r") as target_file: 104 | source, target = source_file.readline(), target_file.readline() 105 | counter = 0 106 | while source and target and (not max_size or counter < max_size): 107 | counter += 1 108 | if counter % 100000 == 0: 109 | print(" reading data line %d" % counter) 110 | sys.stdout.flush() 111 | source_ids = [int(x) for x in source.split()] 112 | target_ids = [int(x) for x in target.split()] 113 | target_ids.append(data_utils.EOS_ID) 114 | for bucket_id, (source_size, target_size) in enumerate(_buckets): 115 | if len(source_ids) < source_size and len(target_ids) < target_size: 116 | data_set[bucket_id].append([source_ids, target_ids]) 117 | break 118 | source, target = source_file.readline(), target_file.readline() 119 | return data_set 120 | 121 | 122 | def create_model(session, forward_only): 123 | """Create translation model and initialize or load parameters in session.""" 124 | dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 125 | model = seq2seq_model.Seq2SeqModel( 126 | FLAGS.from_vocab_size, 127 | FLAGS.to_vocab_size, 128 | _buckets, 129 | FLAGS.size, 130 | FLAGS.num_layers, 131 | FLAGS.max_gradient_norm, 132 | FLAGS.batch_size, 133 | FLAGS.learning_rate, 134 | FLAGS.learning_rate_decay_factor, 135 | forward_only=forward_only, 136 | dtype=dtype) 137 | ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir) 138 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 139 | print("Reading model parameters from %s" % ckpt.model_checkpoint_path) 140 | model.saver.restore(session, ckpt.model_checkpoint_path) 141 | else: 142 | print("Created model with fresh parameters.") 143 | session.run(tf.global_variables_initializer()) 144 | return model 145 | 146 | 147 | def train(): 148 | """Train a en->fr translation model using WMT data.""" 149 | from_train = None 150 | to_train = None 151 | from_dev = None 152 | to_dev = None 153 | if FLAGS.from_train_data and FLAGS.to_train_data: 154 | from_train_data = FLAGS.from_train_data 155 | to_train_data = FLAGS.to_train_data 156 | from_dev_data = from_train_data 157 | to_dev_data = to_train_data 158 | if FLAGS.from_dev_data and FLAGS.to_dev_data: 159 | from_dev_data = FLAGS.from_dev_data 160 | to_dev_data = FLAGS.to_dev_data 161 | from_train, to_train, from_dev, to_dev, _, _ = data_utils.prepare_data( 162 | FLAGS.data_dir, 163 | from_train_data, 164 | to_train_data, 165 | from_dev_data, 166 | to_dev_data, 167 | FLAGS.from_vocab_size, 168 | FLAGS.to_vocab_size) 169 | else: 170 | # Prepare WMT data. 171 | print("Preparing WMT data in %s" % FLAGS.data_dir) 172 | from_train, to_train, from_dev, to_dev, _, _ = data_utils.prepare_wmt_data( 173 | FLAGS.data_dir, FLAGS.from_vocab_size, FLAGS.to_vocab_size) 174 | 175 | with tf.Session() as sess: 176 | # Create model. 177 | print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size)) 178 | model = create_model(sess, False) 179 | 180 | # Read data into buckets and compute their sizes. 181 | print ("Reading development and training data (limit: %d)." 
182 | % FLAGS.max_train_data_size) 183 | dev_set = read_data(from_dev, to_dev) 184 | train_set = read_data(from_train, to_train, FLAGS.max_train_data_size) 185 | train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))] 186 | train_total_size = float(sum(train_bucket_sizes)) 187 | 188 | # A bucket scale is a list of increasing numbers from 0 to 1 that we'll use 189 | # to select a bucket. Length of [scale[i], scale[i+1]] is proportional to 190 | # the size if i-th training bucket, as used later. 191 | train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size 192 | for i in xrange(len(train_bucket_sizes))] 193 | 194 | # This is the training loop. 195 | step_time, loss = 0.0, 0.0 196 | current_step = 0 197 | previous_losses = [] 198 | while True: 199 | # Choose a bucket according to data distribution. We pick a random number 200 | # in [0, 1] and use the corresponding interval in train_buckets_scale. 201 | random_number_01 = np.random.random_sample() 202 | bucket_id = min([i for i in xrange(len(train_buckets_scale)) 203 | if train_buckets_scale[i] > random_number_01]) 204 | 205 | # Get a batch and make a step. 206 | start_time = time.time() 207 | encoder_inputs, decoder_inputs, target_weights = model.get_batch( 208 | train_set, bucket_id) 209 | _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, 210 | target_weights, bucket_id, False) 211 | step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint 212 | loss += step_loss / FLAGS.steps_per_checkpoint 213 | current_step += 1 214 | 215 | # Once in a while, we save checkpoint, print statistics, and run evals. 216 | if current_step % FLAGS.steps_per_checkpoint == 0: 217 | # Print statistics for the previous epoch. 218 | perplexity = math.exp(float(loss)) if loss < 300 else float("inf") 219 | print ("global step %d learning rate %.4f step-time %.2f perplexity " 220 | "%.2f" % (model.global_step.eval(), model.learning_rate.eval(), 221 | step_time, perplexity)) 222 | # Decrease learning rate if no improvement was seen over last 3 times. 223 | if len(previous_losses) > 2 and loss > max(previous_losses[-3:]): 224 | sess.run(model.learning_rate_decay_op) 225 | previous_losses.append(loss) 226 | # Save checkpoint and zero timer and loss. 227 | checkpoint_path = os.path.join(FLAGS.train_dir, "translate.ckpt") 228 | model.saver.save(sess, checkpoint_path, global_step=model.global_step) 229 | step_time, loss = 0.0, 0.0 230 | # Run evals on development set and print their perplexity. 231 | for bucket_id in xrange(len(_buckets)): 232 | if len(dev_set[bucket_id]) == 0: 233 | print(" eval: empty bucket %d" % (bucket_id)) 234 | continue 235 | encoder_inputs, decoder_inputs, target_weights = model.get_batch( 236 | dev_set, bucket_id) 237 | _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, 238 | target_weights, bucket_id, True) 239 | eval_ppx = math.exp(float(eval_loss)) if eval_loss < 300 else float( 240 | "inf") 241 | print(" eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx)) 242 | sys.stdout.flush() 243 | 244 | 245 | def decode(): 246 | with tf.Session() as sess: 247 | # Create model and load parameters. 248 | model = create_model(sess, True) 249 | model.batch_size = 1 # We decode one sentence at a time. 250 | 251 | # Load vocabularies. 
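    # initialize_vocabulary returns a (token -> id dict, id -> token list) pair:
    # en_vocab encodes the typed English sentence into token ids, and
    # rev_fr_vocab is used below to turn the decoder's output ids back into words.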
252 | en_vocab_path = os.path.join(FLAGS.data_dir, 253 | "vocab%d.from" % FLAGS.from_vocab_size) 254 | fr_vocab_path = os.path.join(FLAGS.data_dir, 255 | "vocab%d.to" % FLAGS.to_vocab_size) 256 | en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path) 257 | _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path) 258 | 259 | # Decode from standard input. 260 | sys.stdout.write("> ") 261 | sys.stdout.flush() 262 | sentence = sys.stdin.readline() 263 | while sentence: 264 | # Get token-ids for the input sentence. 265 | token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), en_vocab) 266 | # Which bucket does it belong to? 267 | bucket_id = len(_buckets) - 1 268 | for i, bucket in enumerate(_buckets): 269 | if bucket[0] >= len(token_ids): 270 | bucket_id = i 271 | break 272 | else: 273 | logging.warning("Sentence truncated: %s", sentence) 274 | 275 | # Get a 1-element batch to feed the sentence to the model. 276 | encoder_inputs, decoder_inputs, target_weights = model.get_batch( 277 | {bucket_id: [(token_ids, [])]}, bucket_id) 278 | # Get output logits for the sentence. 279 | _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, 280 | target_weights, bucket_id, True) 281 | # This is a greedy decoder - outputs are just argmaxes of output_logits. 282 | outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits] 283 | # If there is an EOS symbol in outputs, cut them at that point. 284 | if data_utils.EOS_ID in outputs: 285 | outputs = outputs[:outputs.index(data_utils.EOS_ID)] 286 | # Print out French sentence corresponding to outputs. 287 | print(" ".join([tf.compat.as_str(rev_fr_vocab[output]) for output in outputs])) 288 | print("> ", end="") 289 | sys.stdout.flush() 290 | sentence = sys.stdin.readline() 291 | 292 | 293 | def self_test(): 294 | """Test the translation model.""" 295 | with tf.Session() as sess: 296 | print("Self-test for neural translation model.") 297 | # Create model with vocabularies of 10, 2 small buckets, 2 layers of 32. 298 | model = seq2seq_model.Seq2SeqModel(10, 10, [(3, 3), (6, 6)], 32, 2, 299 | 5.0, 32, 0.3, 0.99, num_samples=8) 300 | sess.run(tf.global_variables_initializer()) 301 | 302 | # Fake data set for both the (3, 3) and (6, 6) bucket. 303 | data_set = ([([1, 1], [2, 2]), ([3, 3], [4]), ([5], [6])], 304 | [([1, 1, 1, 1, 1], [2, 2, 2, 2, 2]), ([3, 3, 3], [5, 6])]) 305 | for _ in xrange(5): # Train the fake model for 5 steps. 306 | bucket_id = random.choice([0, 1]) 307 | encoder_inputs, decoder_inputs, target_weights = model.get_batch( 308 | data_set, bucket_id) 309 | model.step(sess, encoder_inputs, decoder_inputs, target_weights, 310 | bucket_id, False) 311 | 312 | 313 | def main(_): 314 | if FLAGS.self_test: 315 | self_test() 316 | elif FLAGS.decode: 317 | decode() 318 | else: 319 | train() 320 | 321 | if __name__ == "__main__": 322 | tf.app.run() 323 | -------------------------------------------------------------------------------- /ref/seq2seq_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Sequence-to-sequence model with an attention mechanism.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import random 23 | 24 | import numpy as np 25 | from six.moves import xrange # pylint: disable=redefined-builtin 26 | import tensorflow as tf 27 | 28 | import data_utils 29 | 30 | 31 | class Seq2SeqModel(object): 32 | """Sequence-to-sequence model with attention and for multiple buckets. 33 | 34 | This class implements a multi-layer recurrent neural network as encoder, 35 | and an attention-based decoder. This is the same as the model described in 36 | this paper: http://arxiv.org/abs/1412.7449 - please look there for details, 37 | or into the seq2seq library for complete model implementation. 38 | This class also allows to use GRU cells in addition to LSTM cells, and 39 | sampled softmax to handle large output vocabulary size. A single-layer 40 | version of this model, but with bi-directional encoder, was presented in 41 | http://arxiv.org/abs/1409.0473 42 | and sampled softmax is described in Section 3 of the following paper. 43 | http://arxiv.org/abs/1412.2007 44 | """ 45 | 46 | def __init__(self, 47 | source_vocab_size, 48 | target_vocab_size, 49 | buckets, 50 | size, 51 | num_layers, 52 | max_gradient_norm, 53 | batch_size, 54 | learning_rate, 55 | learning_rate_decay_factor, 56 | use_lstm=False, 57 | num_samples=512, 58 | forward_only=False, 59 | dtype=tf.float32): 60 | """Create the model. 61 | 62 | Args: 63 | source_vocab_size: size of the source vocabulary. 64 | target_vocab_size: size of the target vocabulary. 65 | buckets: a list of pairs (I, O), where I specifies maximum input length 66 | that will be processed in that bucket, and O specifies maximum output 67 | length. Training instances that have inputs longer than I or outputs 68 | longer than O will be pushed to the next bucket and padded accordingly. 69 | We assume that the list is sorted, e.g., [(2, 4), (8, 16)]. 70 | size: number of units in each layer of the model. 71 | num_layers: number of layers in the model. 72 | max_gradient_norm: gradients will be clipped to maximally this norm. 73 | batch_size: the size of the batches used during training; 74 | the model construction is independent of batch_size, so it can be 75 | changed after initialization if this is convenient, e.g., for decoding. 76 | learning_rate: learning rate to start with. 77 | learning_rate_decay_factor: decay learning rate by this much when needed. 78 | use_lstm: if true, we use LSTM cells instead of GRU cells. 79 | num_samples: number of samples for sampled softmax. 80 | forward_only: if set, we do not construct the backward pass in the model. 81 | dtype: the data type to use to store internal variables. 
82 | """ 83 | self.source_vocab_size = source_vocab_size 84 | self.target_vocab_size = target_vocab_size 85 | self.buckets = buckets 86 | self.batch_size = batch_size 87 | self.learning_rate = tf.Variable( 88 | float(learning_rate), trainable=False, dtype=dtype) 89 | self.learning_rate_decay_op = self.learning_rate.assign( 90 | self.learning_rate * learning_rate_decay_factor) 91 | self.global_step = tf.Variable(0, trainable=False) 92 | 93 | # If we use sampled softmax, we need an output projection. 94 | output_projection = None 95 | softmax_loss_function = None 96 | # Sampled softmax only makes sense if we sample less than vocabulary size. 97 | if num_samples > 0 and num_samples < self.target_vocab_size: 98 | w_t = tf.get_variable("proj_w", [self.target_vocab_size, size], dtype=dtype) 99 | w = tf.transpose(w_t) 100 | b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype) 101 | output_projection = (w, b) 102 | 103 | def sampled_loss(labels, inputs): 104 | labels = tf.reshape(labels, [-1, 1]) 105 | # We need to compute the sampled_softmax_loss using 32bit floats to 106 | # avoid numerical instabilities. 107 | local_w_t = tf.cast(w_t, tf.float32) 108 | local_b = tf.cast(b, tf.float32) 109 | local_inputs = tf.cast(inputs, tf.float32) 110 | return tf.cast( 111 | tf.nn.sampled_softmax_loss( 112 | weights=local_w_t, 113 | biases=local_b, 114 | labels=labels, 115 | inputs=local_inputs, 116 | num_sampled=num_samples, 117 | num_classes=self.target_vocab_size), 118 | dtype) 119 | softmax_loss_function = sampled_loss 120 | 121 | # Create the internal multi-layer cell for our RNN. 122 | def single_cell(): 123 | return tf.contrib.rnn.GRUCell(size) 124 | if use_lstm: 125 | def single_cell(): 126 | return tf.contrib.rnn.BasicLSTMCell(size) 127 | cell = single_cell() 128 | if num_layers > 1: 129 | cell = tf.contrib.rnn.MultiRNNCell([single_cell() for _ in range(num_layers)]) 130 | 131 | # The seq2seq function: we use embedding for the input and attention. 132 | def seq2seq_f(encoder_inputs, decoder_inputs, do_decode): 133 | return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq( 134 | encoder_inputs, 135 | decoder_inputs, 136 | cell, 137 | num_encoder_symbols=source_vocab_size, 138 | num_decoder_symbols=target_vocab_size, 139 | embedding_size=size, 140 | output_projection=output_projection, 141 | feed_previous=do_decode, 142 | dtype=dtype) 143 | 144 | # Feeds for inputs. 145 | self.encoder_inputs = [] 146 | self.decoder_inputs = [] 147 | self.target_weights = [] 148 | for i in xrange(buckets[-1][0]): # Last bucket is the biggest one. 149 | self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None], 150 | name="encoder{0}".format(i))) 151 | for i in xrange(buckets[-1][1] + 1): 152 | self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None], 153 | name="decoder{0}".format(i))) 154 | self.target_weights.append(tf.placeholder(dtype, shape=[None], 155 | name="weight{0}".format(i))) 156 | 157 | # Our targets are decoder inputs shifted by one. 158 | targets = [self.decoder_inputs[i + 1] 159 | for i in xrange(len(self.decoder_inputs) - 1)] 160 | 161 | # Training outputs and losses. 162 | if forward_only: 163 | self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets( 164 | self.encoder_inputs, self.decoder_inputs, targets, 165 | self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, True), 166 | softmax_loss_function=softmax_loss_function) 167 | # If we use output projection, we need to project outputs for decoding. 
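      # With sampled softmax the decoder emits `size`-dimensional vectors rather
      # than full-vocabulary logits, so each output is multiplied by the
      # projection matrix (the transpose of proj_w) and shifted by proj_b to
      # recover logits over all target_vocab_size symbols before decoding.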
168 | if output_projection is not None: 169 | for b in xrange(len(buckets)): 170 | self.outputs[b] = [ 171 | tf.matmul(output, output_projection[0]) + output_projection[1] 172 | for output in self.outputs[b] 173 | ] 174 | else: 175 | self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets( 176 | self.encoder_inputs, self.decoder_inputs, targets, 177 | self.target_weights, buckets, 178 | lambda x, y: seq2seq_f(x, y, False), 179 | softmax_loss_function=softmax_loss_function) 180 | 181 | # Gradients and SGD update operation for training the model. 182 | params = tf.trainable_variables() 183 | if not forward_only: 184 | self.gradient_norms = [] 185 | self.updates = [] 186 | opt = tf.train.GradientDescentOptimizer(self.learning_rate) 187 | for b in xrange(len(buckets)): 188 | gradients = tf.gradients(self.losses[b], params) 189 | clipped_gradients, norm = tf.clip_by_global_norm(gradients, 190 | max_gradient_norm) 191 | self.gradient_norms.append(norm) 192 | self.updates.append(opt.apply_gradients( 193 | zip(clipped_gradients, params), global_step=self.global_step)) 194 | 195 | self.saver = tf.train.Saver(tf.global_variables()) 196 | 197 | def step(self, session, encoder_inputs, decoder_inputs, target_weights, 198 | bucket_id, forward_only): 199 | """Run a step of the model feeding the given inputs. 200 | 201 | Args: 202 | session: tensorflow session to use. 203 | encoder_inputs: list of numpy int vectors to feed as encoder inputs. 204 | decoder_inputs: list of numpy int vectors to feed as decoder inputs. 205 | target_weights: list of numpy float vectors to feed as target weights. 206 | bucket_id: which bucket of the model to use. 207 | forward_only: whether to do the backward step or only forward. 208 | 209 | Returns: 210 | A triple consisting of gradient norm (or None if we did not do backward), 211 | average perplexity, and the outputs. 212 | 213 | Raises: 214 | ValueError: if length of encoder_inputs, decoder_inputs, or 215 | target_weights disagrees with bucket size for the specified bucket_id. 216 | """ 217 | # Check if the sizes match. 218 | encoder_size, decoder_size = self.buckets[bucket_id] 219 | if len(encoder_inputs) != encoder_size: 220 | raise ValueError("Encoder length must be equal to the one in bucket," 221 | " %d != %d." % (len(encoder_inputs), encoder_size)) 222 | if len(decoder_inputs) != decoder_size: 223 | raise ValueError("Decoder length must be equal to the one in bucket," 224 | " %d != %d." % (len(decoder_inputs), decoder_size)) 225 | if len(target_weights) != decoder_size: 226 | raise ValueError("Weights length must be equal to the one in bucket," 227 | " %d != %d." % (len(target_weights), decoder_size)) 228 | 229 | # Input feed: encoder inputs, decoder inputs, target_weights, as provided. 230 | input_feed = {} 231 | for l in xrange(encoder_size): 232 | input_feed[self.encoder_inputs[l].name] = encoder_inputs[l] 233 | for l in xrange(decoder_size): 234 | input_feed[self.decoder_inputs[l].name] = decoder_inputs[l] 235 | input_feed[self.target_weights[l].name] = target_weights[l] 236 | 237 | # Since our targets are decoder inputs shifted by one, we need one more. 238 | last_target = self.decoder_inputs[decoder_size].name 239 | input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32) 240 | 241 | # Output feed: depends on whether we do a backward step or not. 242 | if not forward_only: 243 | output_feed = [self.updates[bucket_id], # Update Op that does SGD. 244 | self.gradient_norms[bucket_id], # Gradient norm. 
245 | self.losses[bucket_id]] # Loss for this batch. 246 | else: 247 | output_feed = [self.losses[bucket_id]] # Loss for this batch. 248 | for l in xrange(decoder_size): # Output logits. 249 | output_feed.append(self.outputs[bucket_id][l]) 250 | 251 | outputs = session.run(output_feed, input_feed) 252 | if not forward_only: 253 | return outputs[1], outputs[2], None # Gradient norm, loss, no outputs. 254 | else: 255 | return None, outputs[0], outputs[1:] # No gradient norm, loss, outputs. 256 | 257 | def get_batch(self, data, bucket_id): 258 | """Get a random batch of data from the specified bucket, prepare for step. 259 | 260 | To feed data in step(..) it must be a list of batch-major vectors, while 261 | data here contains single length-major cases. So the main logic of this 262 | function is to re-index data cases to be in the proper format for feeding. 263 | 264 | Args: 265 | data: a tuple of size len(self.buckets) in which each element contains 266 | lists of pairs of input and output data that we use to create a batch. 267 | bucket_id: integer, which bucket to get the batch for. 268 | 269 | Returns: 270 | The triple (encoder_inputs, decoder_inputs, target_weights) for 271 | the constructed batch that has the proper format to call step(...) later. 272 | """ 273 | encoder_size, decoder_size = self.buckets[bucket_id] 274 | encoder_inputs, decoder_inputs = [], [] 275 | 276 | # Get a random batch of encoder and decoder inputs from data, 277 | # pad them if needed, reverse encoder inputs and add GO to decoder. 278 | for _ in xrange(self.batch_size): 279 | encoder_input, decoder_input = random.choice(data[bucket_id]) 280 | 281 | # Encoder inputs are padded and then reversed. 282 | encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(encoder_input)) 283 | encoder_inputs.append(list(reversed(encoder_input + encoder_pad))) 284 | 285 | # Decoder inputs get an extra "GO" symbol, and are padded then. 286 | decoder_pad_size = decoder_size - len(decoder_input) - 1 287 | decoder_inputs.append([data_utils.GO_ID] + decoder_input + 288 | [data_utils.PAD_ID] * decoder_pad_size) 289 | 290 | # Now we create batch-major vectors from the data selected above. 291 | batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], [] 292 | 293 | # Batch encoder inputs are just re-indexed encoder_inputs. 294 | for length_idx in xrange(encoder_size): 295 | batch_encoder_inputs.append( 296 | np.array([encoder_inputs[batch_idx][length_idx] 297 | for batch_idx in xrange(self.batch_size)], dtype=np.int32)) 298 | 299 | # Batch decoder inputs are re-indexed decoder_inputs, we create weights. 300 | for length_idx in xrange(decoder_size): 301 | batch_decoder_inputs.append( 302 | np.array([decoder_inputs[batch_idx][length_idx] 303 | for batch_idx in xrange(self.batch_size)], dtype=np.int32)) 304 | 305 | # Create target_weights to be 0 for targets that are padding. 306 | batch_weight = np.ones(self.batch_size, dtype=np.float32) 307 | for batch_idx in xrange(self.batch_size): 308 | # We set weight to 0 if the corresponding target is a PAD symbol. 309 | # The corresponding target is decoder_input shifted by 1 forward. 
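        # For example, with decoder_size 4 and decoder_inputs [GO, 5, 8, PAD],
        # the shifted targets are [5, 8, PAD], so this item's weights end up
        # [1.0, 1.0, 0.0, 0.0]: the position whose target is PAD and the final
        # position contribute no loss.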
310 | if length_idx < decoder_size - 1: 311 | target = decoder_inputs[batch_idx][length_idx + 1] 312 | if length_idx == decoder_size - 1 or target == data_utils.PAD_ID: 313 | batch_weight[batch_idx] = 0.0 314 | batch_weights.append(batch_weight) 315 | return batch_encoder_inputs, batch_decoder_inputs, batch_weights 316 | -------------------------------------------------------------------------------- /lib/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utilities for downloading data from WMT, tokenizing, vocabularies.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import sys, os, re, gzip, tarfile 22 | 23 | from six.moves import urllib 24 | 25 | from tensorflow.python.platform import gfile 26 | import tensorflow as tf 27 | 28 | # Special vocabulary symbols - we always put them at the start. 29 | _PAD = b"_PAD" 30 | _GO = b"_GO" 31 | _EOS = b"_EOS" 32 | _UNK = b"_UNK" 33 | _START_VOCAB = [_PAD, _GO, _EOS, _UNK] 34 | 35 | PAD_ID = 0 36 | GO_ID = 1 37 | EOS_ID = 2 38 | UNK_ID = 3 39 | 40 | # Regular expressions used to tokenize. 41 | #_WORD_SPLIT = re.compile(b"([.,!?\"':;,。!)(])") 42 | _WORD_SPLIT = re.compile(b"([.,!?\"':;)(])") 43 | # _DIGIT_RE = re.compile(br"\d{3,}") 44 | _DIGIT_RE = re.compile(br"\d") 45 | 46 | def get_dialog_train_set_path(path): 47 | return os.path.join(path, 'chat') 48 | 49 | 50 | # URLs for WMT data. 
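# The WMT download helpers below are kept from the original translate example;
# the chatbot pipeline instead reads local dialog data via
# get_dialog_train_set_path / get_dialog_dev_set_path (i.e. "<data_dir>/chat.in"
# and "<data_dir>/chat_test.in") through prepare_dialog_data further down.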
51 | _WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar" 52 | _WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz" 53 | 54 | 55 | def maybe_download(directory, filename, url): 56 | """Download filename from url unless it's already in directory.""" 57 | if not os.path.exists(directory): 58 | print("Creating directory %s" % directory) 59 | os.mkdir(directory) 60 | filepath = os.path.join(directory, filename) 61 | if not os.path.exists(filepath): 62 | print("Downloading %s to %s" % (url, filepath)) 63 | filepath, _ = urllib.request.urlretrieve(url, filepath) 64 | statinfo = os.stat(filepath) 65 | print("Successfully downloaded", filename, statinfo.st_size, "bytes") 66 | return filepath 67 | 68 | 69 | def gunzip_file(gz_path, new_path): 70 | """Unzips from gz_path into new_path.""" 71 | print("Unpacking %s to %s" % (gz_path, new_path)) 72 | with gzip.open(gz_path, "rb") as gz_file: 73 | with open(new_path, "wb") as new_file: 74 | for line in gz_file: 75 | new_file.write(line) 76 | 77 | 78 | def get_wmt_enfr_train_set(directory): 79 | """Download the WMT en-fr training corpus to directory unless it's there.""" 80 | train_path = os.path.join(directory, "giga-fren.release2.fixed") 81 | if not (gfile.Exists(train_path +".fr") and gfile.Exists(train_path +".en")): 82 | corpus_file = maybe_download(directory, "training-giga-fren.tar", 83 | _WMT_ENFR_TRAIN_URL) 84 | print("Extracting tar file %s" % corpus_file) 85 | with tarfile.open(corpus_file, "r") as corpus_tar: 86 | corpus_tar.extractall(directory) 87 | gunzip_file(train_path + ".fr.gz", train_path + ".fr") 88 | gunzip_file(train_path + ".en.gz", train_path + ".en") 89 | return train_path 90 | 91 | 92 | def get_wmt_enfr_dev_set(directory): 93 | """Download the WMT en-fr training corpus to directory unless it's there.""" 94 | dev_name = "newstest2013" 95 | dev_path = os.path.join(directory, dev_name) 96 | if not (gfile.Exists(dev_path + ".fr") and gfile.Exists(dev_path + ".en")): 97 | dev_file = maybe_download(directory, "dev-v2.tgz", _WMT_ENFR_DEV_URL) 98 | print("Extracting tgz file %s" % dev_file) 99 | with tarfile.open(dev_file, "r:gz") as dev_tar: 100 | fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr") 101 | en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en") 102 | fr_dev_file.name = dev_name + ".fr" # Extract without "dev/" prefix. 103 | en_dev_file.name = dev_name + ".en" 104 | dev_tar.extract(fr_dev_file, directory) 105 | dev_tar.extract(en_dev_file, directory) 106 | return dev_path 107 | 108 | 109 | def get_dialog_dev_set_path(path): 110 | return os.path.join(path, 'chat_test') 111 | 112 | 113 | def basic_tokenizer(sentence, en_jieba=False): 114 | """Very basic tokenizer: split the sentence into a list of tokens.""" 115 | if en_jieba: 116 | tokens = list([w.lower() for w in jieba.cut(sentence) if w not in [' ']]) 117 | return tokens 118 | else: 119 | words = [] 120 | for space_separated_fragment in sentence.strip().split(): 121 | if isinstance(space_separated_fragment, str): 122 | space_separated_fragment = space_separated_fragment.encode() 123 | words.extend(_WORD_SPLIT.split(space_separated_fragment)) 124 | return [w.lower() for w in words if w] 125 | 126 | 127 | 128 | def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size, 129 | tokenizer=None, normalize_digits=True): 130 | """Create vocabulary file (if it does not exist yet) from data file. 131 | 132 | Data file is assumed to contain one sentence per line. 
Each sentence is 133 | tokenized and digits are normalized (if normalize_digits is set). 134 | Vocabulary contains the most-frequent tokens up to max_vocabulary_size. 135 | We write it to vocabulary_path in a one-token-per-line format, so that later 136 | token in the first line gets id=0, second line gets id=1, and so on. 137 | 138 | Args: 139 | vocabulary_path: path where the vocabulary will be created. 140 | data_path: data file that will be used to create vocabulary. 141 | max_vocabulary_size: limit on the size of the created vocabulary. 142 | tokenizer: a function to use to tokenize each data sentence; 143 | if None, basic_tokenizer will be used. 144 | normalize_digits: Boolean; if true, all digits are replaced by 0s. 145 | """ 146 | if not gfile.Exists(vocabulary_path): 147 | print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path)) 148 | vocab = {} 149 | with gfile.GFile(data_path, mode="rb") as f: 150 | counter = 0 151 | for line in f: 152 | # try: 153 | # line = line.decode('utf8', 'ignore') 154 | # except Exception as e: 155 | # print(e, line) 156 | # continue 157 | counter += 1 158 | if counter % 100000 == 0: 159 | print(" processing line %d" % counter) 160 | line = tf.compat.as_bytes(line) 161 | tokens = tokenizer(line) if tokenizer else basic_tokenizer(line) 162 | for w in tokens: 163 | word = _DIGIT_RE.sub(b"0", w) if normalize_digits else w 164 | if word in vocab: 165 | vocab[word] += 1 166 | else: 167 | vocab[word] = 1 168 | vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True) 169 | if len(vocab_list) > max_vocabulary_size: 170 | vocab_list = vocab_list[:max_vocabulary_size] 171 | with gfile.GFile(vocabulary_path, mode="wb") as vocab_file: 172 | for w in vocab_list: 173 | vocab_file.write(w + b"\n") 174 | 175 | 176 | def initialize_vocabulary(vocabulary_path): 177 | """Initialize vocabulary from file. 178 | 179 | We assume the vocabulary is stored one-item-per-line, so a file: 180 | dog 181 | cat 182 | will result in a vocabulary {"dog": 0, "cat": 1}, and this function will 183 | also return the reversed-vocabulary ["dog", "cat"]. 184 | 185 | Args: 186 | vocabulary_path: path to the file containing the vocabulary. 187 | 188 | Returns: 189 | a pair: the vocabulary (a dictionary mapping string to integers), and 190 | the reversed vocabulary (a list, which reverses the vocabulary mapping). 191 | 192 | Raises: 193 | ValueError: if the provided vocabulary_path does not exist. 194 | """ 195 | if gfile.Exists(vocabulary_path): 196 | rev_vocab = [] 197 | with gfile.GFile(vocabulary_path, mode="rb") as f: 198 | rev_vocab.extend(f.readlines()) 199 | rev_vocab = [tf.compat.as_bytes(line.strip()) for line in rev_vocab] 200 | vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)]) 201 | return vocab, rev_vocab 202 | else: 203 | raise ValueError("Vocabulary file %s not found.", vocabulary_path) 204 | 205 | 206 | def sentence_to_token_ids(sentence, vocabulary, 207 | tokenizer=None, normalize_digits=True): 208 | """Convert a string to list of integers representing token-ids. 209 | 210 | For example, a sentence "I have a dog" may become tokenized into 211 | ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2, 212 | "a": 4, "dog": 7"} this function will return [1, 2, 4, 7]. 213 | 214 | Args: 215 | sentence: the sentence in bytes format to convert to token-ids. 216 | vocabulary: a dictionary mapping tokens to integers. 217 | tokenizer: a function to use to tokenize each sentence; 218 | if None, basic_tokenizer will be used. 
219 | normalize_digits: Boolean; if true, all digits are replaced by 0s. 220 | 221 | Returns: 222 | a list of integers, the token-ids for the sentence. 223 | """ 224 | 225 | if tokenizer: 226 | words = tokenizer(sentence) 227 | else: 228 | words = basic_tokenizer(sentence) 229 | if not normalize_digits: 230 | return [vocabulary.get(w, UNK_ID) for w in words] 231 | # Normalize digits by 0 before looking words up in the vocabulary. 232 | return [vocabulary.get(_DIGIT_RE.sub(b"0", w), UNK_ID) for w in words] 233 | 234 | 235 | def data_to_token_ids(data_path, target_path, vocabulary_path, 236 | tokenizer=None, normalize_digits=True): 237 | """Tokenize data file and turn into token-ids using given vocabulary file. 238 | 239 | This function loads data line-by-line from data_path, calls the above 240 | sentence_to_token_ids, and saves the result to target_path. See comment 241 | for sentence_to_token_ids on the details of token-ids format. 242 | 243 | Args: 244 | data_path: path to the data file in one-sentence-per-line format. 245 | target_path: path where the file with token-ids will be created. 246 | vocabulary_path: path to the vocabulary file. 247 | tokenizer: a function to use to tokenize each sentence; 248 | if None, basic_tokenizer will be used. 249 | normalize_digits: Boolean; if true, all digits are replaced by 0s. 250 | """ 251 | if not gfile.Exists(target_path): 252 | print("Tokenizing data in %s" % data_path) 253 | vocab, _ = initialize_vocabulary(vocabulary_path) 254 | with gfile.GFile(data_path, mode="rb") as data_file: 255 | with gfile.GFile(target_path, mode="w") as tokens_file: 256 | counter = 0 257 | for line in data_file: 258 | # try: 259 | # line = line.decode('utf8', 'ignore') 260 | # except Exception as e: 261 | # print(e, line) 262 | # continue 263 | counter += 1 264 | if counter % 100000 == 0: 265 | print(" tokenizing line %d" % counter) 266 | token_ids = sentence_to_token_ids(tf.compat.as_bytes(line), vocab, 267 | tokenizer, normalize_digits) 268 | tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n") 269 | 270 | 271 | def prepare_wmt_data(data_dir, en_vocabulary_size, fr_vocabulary_size, tokenizer=None): 272 | """Get WMT data into data_dir, create vocabularies and tokenize data. 273 | 274 | Args: 275 | data_dir: directory in which the data sets will be stored. 276 | en_vocabulary_size: size of the English vocabulary to create and use. 277 | fr_vocabulary_size: size of the French vocabulary to create and use. 278 | tokenizer: a function to use to tokenize each data sentence; 279 | if None, basic_tokenizer will be used. 280 | 281 | Returns: 282 | A tuple of 6 elements: 283 | (1) path to the token-ids for English training data-set, 284 | (2) path to the token-ids for French training data-set, 285 | (3) path to the token-ids for English development data-set, 286 | (4) path to the token-ids for French development data-set, 287 | (5) path to the English vocabulary file, 288 | (6) path to the French vocabulary file. 289 | """ 290 | # Get wmt data to the specified directory. 
291 | train_path = get_wmt_enfr_train_set(data_dir) 292 | dev_path = get_wmt_enfr_dev_set(data_dir) 293 | 294 | from_train_path = train_path + ".en" 295 | to_train_path = train_path + ".fr" 296 | from_dev_path = dev_path + ".en" 297 | to_dev_path = dev_path + ".fr" 298 | return prepare_data(data_dir, from_train_path, to_train_path, from_dev_path, to_dev_path, en_vocabulary_size, 299 | fr_vocabulary_size, tokenizer) 300 | 301 | 302 | def prepare_data(data_dir, from_train_path, to_train_path, from_dev_path, to_dev_path, from_vocabulary_size, 303 | to_vocabulary_size, tokenizer=None): 304 | """Preapre all necessary files that are required for the training. 305 | 306 | Args: 307 | data_dir: directory in which the data sets will be stored. 308 | from_train_path: path to the file that includes "from" training samples. 309 | to_train_path: path to the file that includes "to" training samples. 310 | from_dev_path: path to the file that includes "from" dev samples. 311 | to_dev_path: path to the file that includes "to" dev samples. 312 | from_vocabulary_size: size of the "from language" vocabulary to create and use. 313 | to_vocabulary_size: size of the "to language" vocabulary to create and use. 314 | tokenizer: a function to use to tokenize each data sentence; 315 | if None, basic_tokenizer will be used. 316 | 317 | Returns: 318 | A tuple of 6 elements: 319 | (1) path to the token-ids for "from language" training data-set, 320 | (2) path to the token-ids for "to language" training data-set, 321 | (3) path to the token-ids for "from language" development data-set, 322 | (4) path to the token-ids for "to language" development data-set, 323 | (5) path to the "from language" vocabulary file, 324 | (6) path to the "to language" vocabulary file. 325 | """ 326 | # Create vocabularies of the appropriate sizes. 327 | to_vocab_path = os.path.join(data_dir, "vocab%d.to" % to_vocabulary_size) 328 | from_vocab_path = os.path.join(data_dir, "vocab%d.from" % from_vocabulary_size) 329 | create_vocabulary(to_vocab_path, to_train_path , to_vocabulary_size, tokenizer) 330 | create_vocabulary(from_vocab_path, from_train_path , from_vocabulary_size, tokenizer) 331 | 332 | # Create token ids for the training data. 333 | to_train_ids_path = to_train_path + (".ids%d" % to_vocabulary_size) 334 | from_train_ids_path = from_train_path + (".ids%d" % from_vocabulary_size) 335 | data_to_token_ids(to_train_path, to_train_ids_path, to_vocab_path, tokenizer) 336 | data_to_token_ids(from_train_path, from_train_ids_path, from_vocab_path, tokenizer) 337 | 338 | # Create token ids for the development data. 339 | to_dev_ids_path = to_dev_path + (".ids%d" % to_vocabulary_size) 340 | from_dev_ids_path = from_dev_path + (".ids%d" % from_vocabulary_size) 341 | data_to_token_ids(to_dev_path, to_dev_ids_path, to_vocab_path, tokenizer) 342 | data_to_token_ids(from_dev_path, from_dev_ids_path, from_vocab_path, tokenizer) 343 | 344 | return (from_train_ids_path, to_train_ids_path, 345 | from_dev_ids_path, to_dev_ids_path, 346 | from_vocab_path, to_vocab_path) 347 | 348 | 349 | 350 | def prepare_dialog_data(data_dir, vocabulary_size): 351 | """Get dialog data into data_dir, create vocabularies and tokenize data. 352 | 353 | Args: 354 | data_dir: directory in which the data sets will be stored. 355 | vocabulary_size: size of the English vocabulary to create and use. 
356 | 357 | Returns: 358 | A tuple of 3 elements: 359 | (1) path to the token-ids for chat training data-set, 360 | (2) path to the token-ids for chat development data-set, 361 | (3) path to the chat vocabulary file 362 | """ 363 | # Get dialog data to the specified directory. 364 | train_path = get_dialog_train_set_path(data_dir) 365 | dev_path = get_dialog_dev_set_path(data_dir) 366 | 367 | # Create vocabularies of the appropriate sizes. 368 | vocab_path = os.path.join(data_dir, "vocab%d.in" % vocabulary_size) 369 | create_vocabulary(vocab_path, train_path + ".in", vocabulary_size) 370 | 371 | # Create token ids for the training data. 372 | train_ids_path = train_path + (".ids%d.in" % vocabulary_size) 373 | data_to_token_ids(train_path + ".in", train_ids_path, vocab_path) 374 | 375 | # Create token ids for the development data. 376 | dev_ids_path = dev_path + (".ids%d.in" % vocabulary_size) 377 | data_to_token_ids(dev_path + ".in", dev_ids_path, vocab_path) 378 | 379 | return (train_ids_path, dev_ids_path, vocab_path) 380 | 381 | 382 | def read_data(tokenized_dialog_path, buckets, max_size=None, reversed=False): 383 | """Read data from source file and put into buckets. 384 | 385 | Args: 386 | source_path: path to the files with token-ids. 387 | max_size: maximum number of lines to read, all other will be ignored; 388 | if 0 or None, data files will be read completely (no limit). 389 | 390 | Returns: 391 | data_set: a list of length len(_buckets); data_set[n] contains a list of 392 | (source, target) pairs read from the provided data files that fit 393 | into the n-th bucket, i.e., such that len(source) < _buckets[n][0] and 394 | len(target) < _buckets[n][1]; source and target are lists of token-ids. 395 | """ 396 | data_set = [[] for _ in buckets] 397 | 398 | with gfile.GFile(tokenized_dialog_path, mode="r") as fh: 399 | source, target = fh.readline(), fh.readline() 400 | if reversed: 401 | source, target = target, source # reverse Q-A pair, for bi-direction model 402 | counter = 0 403 | while source and target and (not max_size or counter < max_size): 404 | counter += 1 405 | if counter % 100000 == 0: 406 | print(" reading data line %d" % counter) 407 | sys.stdout.flush() 408 | 409 | source_ids = [int(x) for x in source.split()] 410 | target_ids = [int(x) for x in target.split()] 411 | target_ids.append(EOS_ID) 412 | 413 | for bucket_id, (source_size, target_size) in enumerate(buckets): 414 | if len(source_ids) < source_size and len(target_ids) < target_size: 415 | data_set[bucket_id].append([source_ids, target_ids]) 416 | break 417 | source, target = fh.readline(), fh.readline() 418 | return data_set 419 | -------------------------------------------------------------------------------- /lib/seq2seq_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Sequence-to-sequence model with an attention mechanism.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import random 23 | from math import log 24 | import numpy as np 25 | from six.moves import xrange # pylint: disable=redefined-builtin 26 | import tensorflow as tf 27 | 28 | from tensorflow.python.ops import control_flow_ops 29 | from lib import data_utils as data_utils 30 | from lib import seq2seq as tf_seq2seq 31 | 32 | class Seq2SeqModel(object): 33 | """Sequence-to-sequence model with attention and for multiple buckets. 34 | 35 | This class implements a multi-layer recurrent neural network as encoder, 36 | and an attention-based decoder. This is the same as the model described in 37 | this paper: http://arxiv.org/abs/1412.7449 - please look there for details, 38 | or into the seq2seq library for complete model implementation. 39 | This class also allows to use GRU cells in addition to LSTM cells, and 40 | sampled softmax to handle large output vocabulary size. A single-layer 41 | version of this model, but with bi-directional encoder, was presented in 42 | http://arxiv.org/abs/1409.0473 43 | and sampled softmax is described in Section 3 of the following paper. 44 | http://arxiv.org/abs/1412.2007 45 | """ 46 | 47 | def __init__(self, 48 | source_vocab_size, 49 | target_vocab_size, 50 | buckets, 51 | size, 52 | num_layers, 53 | max_gradient_norm, 54 | batch_size, 55 | learning_rate, 56 | learning_rate_decay_factor, 57 | use_lstm=False, 58 | num_samples=512, 59 | forward_only=False, 60 | scope_name='seq2seq', 61 | dtype=tf.float32): 62 | """Create the model. 63 | 64 | Args: 65 | source_vocab_size: size of the source vocabulary. 66 | target_vocab_size: size of the target vocabulary. 67 | buckets: a list of pairs (I, O), where I specifies maximum input length 68 | that will be processed in that bucket, and O specifies maximum output 69 | length. Training instances that have inputs longer than I or outputs 70 | longer than O will be pushed to the next bucket and padded accordingly. 71 | We assume that the list is sorted, e.g., [(2, 4), (8, 16)]. 72 | size: number of units in each layer of the model. 73 | num_layers: number of layers in the model. 74 | max_gradient_norm: gradients will be clipped to maximally this norm. 75 | batch_size: the size of the batches used during training; 76 | the model construction is independent of batch_size, so it can be 77 | changed after initialization if this is convenient, e.g., for decoding. 78 | learning_rate: learning rate to start with. 79 | learning_rate_decay_factor: decay learning rate by this much when needed. 80 | use_lstm: if true, we use LSTM cells instead of GRU cells. 81 | num_samples: number of samples for sampled softmax. 82 | forward_only: if set, we do not construct the backward pass in the model. 83 | dtype: the data type to use to store internal variables. 
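    Example (illustrative hyper-parameter values only):
      model = Seq2SeqModel(
          source_vocab_size=100000, target_vocab_size=100000,
          buckets=[(5, 10), (10, 15), (20, 25), (40, 50)],
          size=512, num_layers=2, max_gradient_norm=5.0, batch_size=64,
          learning_rate=0.5, learning_rate_decay_factor=0.99,
          forward_only=False, scope_name='seq2seq')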
84 | """ 85 | self.scope_name = scope_name 86 | with tf.variable_scope(self.scope_name): 87 | self.source_vocab_size = source_vocab_size 88 | self.target_vocab_size = target_vocab_size 89 | self.buckets = buckets 90 | self.batch_size = batch_size 91 | self.learning_rate = tf.Variable( 92 | float(learning_rate), trainable=False, dtype=dtype) 93 | self.learning_rate_decay_op = self.learning_rate.assign( 94 | self.learning_rate * learning_rate_decay_factor) 95 | self.global_step = tf.Variable(0, trainable=False) 96 | self.dummy_dialogs = [] # [TODO] load dummy sentences 97 | 98 | # If we use sampled softmax, we need an output projection. 99 | output_projection = None 100 | softmax_loss_function = None 101 | # Sampled softmax only makes sense if we sample less than vocabulary size. 102 | if num_samples > 0 and num_samples < self.target_vocab_size: 103 | w_t = tf.get_variable("proj_w", [self.target_vocab_size, size], dtype=dtype) 104 | w = tf.transpose(w_t) 105 | b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype) 106 | output_projection = (w, b) 107 | 108 | def sampled_loss(labels, inputs): 109 | labels = tf.reshape(labels, [-1, 1]) 110 | # We need to compute the sampled_softmax_loss using 32bit floats to 111 | # avoid numerical instabilities. 112 | local_w_t = tf.cast(w_t, tf.float32) 113 | local_b = tf.cast(b, tf.float32) 114 | local_inputs = tf.cast(inputs, tf.float32) 115 | return tf.cast( 116 | tf.nn.sampled_softmax_loss( 117 | weights=local_w_t, 118 | biases=local_b, 119 | labels=labels, 120 | inputs=local_inputs, 121 | num_sampled=num_samples, 122 | num_classes=self.target_vocab_size), 123 | dtype) 124 | softmax_loss_function = sampled_loss 125 | 126 | # Create the internal multi-layer cell for our RNN. 127 | def single_cell(): 128 | return tf.contrib.rnn.GRUCell(size) 129 | if use_lstm: 130 | def single_cell(): 131 | return tf.contrib.rnn.BasicLSTMCell(size) 132 | cell = single_cell() 133 | if num_layers > 1: 134 | cell = tf.contrib.rnn.MultiRNNCell([single_cell() for _ in range(num_layers)]) 135 | 136 | # The seq2seq function: we use embedding for the input and attention. 137 | def seq2seq_f(encoder_inputs, decoder_inputs, feed_previous): 138 | return tf_seq2seq.embedding_attention_seq2seq( 139 | encoder_inputs, 140 | decoder_inputs, 141 | cell, 142 | num_encoder_symbols=source_vocab_size, 143 | num_decoder_symbols=target_vocab_size, 144 | embedding_size=size, 145 | output_projection=output_projection, 146 | feed_previous=feed_previous, #do_decode, 147 | dtype=dtype) 148 | 149 | # Feeds for inputs. 150 | self.encoder_inputs = [] 151 | self.decoder_inputs = [] 152 | self.target_weights = [] 153 | for i in xrange(buckets[-1][0]): # Last bucket is the biggest one. 154 | self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None], 155 | name="encoder{0}".format(i))) 156 | for i in xrange(buckets[-1][1] + 1): 157 | self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None], 158 | name="decoder{0}".format(i))) 159 | self.target_weights.append(tf.placeholder(dtype, shape=[None], 160 | name="weight{0}".format(i))) 161 | 162 | # Our targets are decoder inputs shifted by one. 163 | targets = [self.decoder_inputs[i + 1] 164 | for i in xrange(len(self.decoder_inputs) - 1)] 165 | 166 | # for reinforcement learning 167 | # self.force_dec_input = tf.placeholder(tf.bool, name="force_dec_input") 168 | # self.en_output_proj = tf.placeholder(tf.bool, name="en_output_proj") 169 | 170 | # Training outputs and losses. 
171 | if forward_only: 172 | self.outputs, self.losses, self.encoder_state = tf_seq2seq.model_with_buckets( 173 | self.encoder_inputs, self.decoder_inputs, targets, 174 | self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, True), 175 | softmax_loss_function=softmax_loss_function) 176 | # If we use output projection, we need to project outputs for decoding. 177 | if output_projection is not None: 178 | for b in xrange(len(buckets)): 179 | self.outputs[b] = [ 180 | tf.matmul(output, output_projection[0]) + output_projection[1] 181 | for output in self.outputs[b] 182 | ] 183 | else: 184 | self.outputs, self.losses, self.encoder_state = tf_seq2seq.model_with_buckets( 185 | self.encoder_inputs, self.decoder_inputs, targets, 186 | self.target_weights, buckets, 187 | lambda x, y: seq2seq_f(x, y, False), 188 | softmax_loss_function=softmax_loss_function) 189 | 190 | # # Training outputs and losses. 191 | # self.outputs, self.losses, self.encoder_state = tf_seq2seq.model_with_buckets( 192 | # self.encoder_inputs, self.decoder_inputs, targets, 193 | # self.target_weights, buckets, 194 | # lambda x, y: seq2seq_f(x, y, tf.where(self.force_dec_input, False, True)), 195 | # softmax_loss_function=softmax_loss_function 196 | # ) 197 | # # If we use output projection, we need to project outputs for decoding. 198 | # # if output_projection is not None: 199 | # for b in xrange(len(buckets)): 200 | # self.outputs[b] = [ 201 | # control_flow_ops.cond( 202 | # self.en_output_proj, 203 | # lambda: tf.matmul(output, output_projection[0]) + output_projection[1], 204 | # lambda: output 205 | # ) 206 | # for output in self.outputs[b] 207 | # ] 208 | 209 | # Gradients and SGD update operation for training the model. 210 | params = tf.trainable_variables() 211 | # if not forward_only: 212 | self.gradient_norms = [] 213 | self.updates = [] 214 | self.advantage = [tf.placeholder(tf.float32, name="advantage_%i" % i) for i in xrange(len(buckets))] 215 | opt = tf.train.GradientDescentOptimizer(self.learning_rate) 216 | for b in xrange(len(buckets)): 217 | # self.losses[b] = tf.subtract(self.losses[b], self.advantage[b]) 218 | gradients = tf.gradients(self.losses[b], params) 219 | clipped_gradients, norm = tf.clip_by_global_norm(gradients, 220 | max_gradient_norm) 221 | self.gradient_norms.append(norm) 222 | self.updates.append(opt.apply_gradients( 223 | zip(clipped_gradients, params), global_step=self.global_step)) 224 | 225 | all_variables = tf.global_variables() 226 | all_variables = [k for k in tf.global_variables() if k.name.startswith(self.scope_name)] 227 | self.saver = tf.train.Saver(all_variables) 228 | 229 | 230 | def step(self, session, encoder_inputs, decoder_inputs, target_weights, 231 | bucket_id, forward_only, force_dec_input=False, advantage=None): 232 | 233 | """Run a step of the model feeding the given inputs. 234 | 235 | Args: 236 | session: tensorflow session to use. 237 | encoder_inputs: list of numpy int vectors to feed as encoder inputs. 238 | decoder_inputs: list of numpy int vectors to feed as decoder inputs. 239 | target_weights: list of numpy float vectors to feed as target weights. 240 | bucket_id: which bucket of the model to use. 241 | forward_only: whether to do the backward step or only forward. 242 | 243 | Returns: 244 | A triple consisting of gradient norm (or None if we did not do backward), 245 | average perplexity, and the outputs. 
246 | 247 | Raises: 248 | ValueError: if length of encoder_inputs, decoder_inputs, or 249 | target_weights disagrees with bucket size for the specified bucket_id. 250 | """ 251 | # Check if the sizes match. 252 | encoder_size, decoder_size = self.buckets[bucket_id] 253 | if len(encoder_inputs) != encoder_size: 254 | raise ValueError("Encoder length must be equal to the one in bucket," 255 | " %d != %d." % (len(encoder_inputs), encoder_size)) 256 | if len(decoder_inputs) != decoder_size: 257 | raise ValueError("Decoder length must be equal to the one in bucket," 258 | " %d != %d." % (len(decoder_inputs), decoder_size)) 259 | if len(target_weights) != decoder_size: 260 | raise ValueError("Weights length must be equal to the one in bucket," 261 | " %d != %d." % (len(target_weights), decoder_size)) 262 | 263 | # Input feed: encoder inputs, decoder inputs, target_weights, as provided. 264 | input_feed = { 265 | # self.force_dec_input.name: force_dec_input, 266 | # self.en_output_proj.name: forward_only, 267 | } 268 | for l in xrange(len(self.buckets)): 269 | input_feed[self.advantage[l].name] = advantage[l] if advantage else 0 270 | for l in xrange(encoder_size): 271 | input_feed[self.encoder_inputs[l].name] = encoder_inputs[l] 272 | for l in xrange(decoder_size): 273 | input_feed[self.decoder_inputs[l].name] = decoder_inputs[l] 274 | input_feed[self.target_weights[l].name] = target_weights[l] 275 | 276 | # Since our targets are decoder inputs shifted by one, we need one more. 277 | last_target = self.decoder_inputs[decoder_size].name 278 | input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32) 279 | 280 | 281 | # Output feed: depends on whether we do a backward step or not. 282 | if not forward_only: 283 | output_feed = [self.updates[bucket_id], # Update Op that does SGD. 284 | self.gradient_norms[bucket_id], # Gradient norm. 285 | self.losses[bucket_id]] # Loss for this batch. 286 | else: 287 | output_feed = [self.encoder_state[bucket_id], 288 | self.losses[bucket_id]] # Loss for this batch. 289 | for l in xrange(decoder_size): # Output logits. 290 | output_feed.append(self.outputs[bucket_id][l]) 291 | 292 | outputs = session.run(output_feed, input_feed) 293 | if not forward_only: 294 | return outputs[1], outputs[2], None # Gradient norm, loss, no outputs. 295 | else: 296 | return outputs[0], outputs[1], outputs[2:] # encoder_state, loss, outputs. 297 | 298 | 299 | # # Output feed: depends on whether we do a backward step or not. 300 | # if training: # normal training 301 | # output_feed = [self.updates[bucket_id], # Update Op that does SGD. 302 | # self.gradient_norms[bucket_id], # Gradient norm. 303 | # self.losses[bucket_id]] # Loss for this batch. 304 | # else: # testing or reinforcement learning 305 | # output_feed = [self.encoder_state[bucket_id], self.losses[bucket_id]] # Loss for this batch. 306 | # for l in xrange(decoder_size): # Output logits. 307 | # output_feed.append(self.outputs[bucket_id][l]) 308 | 309 | # outputs = session.run(output_feed, input_feed) 310 | # if training: 311 | # return outputs[1], outputs[2], None # Gradient norm, loss, no outputs. 312 | # else: 313 | # return outputs[0], outputs[1], outputs[2:] # encoder_state, loss, outputs. 
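# Usage sketch (illustrative, not part of the original training loop): assuming a
# `model` instance of this class, an open `session`, and a bucketed `train_set`
# produced by data_utils.read_data; the variable names here are hypothetical.
#
#   encoder_inputs, decoder_inputs, target_weights = model.get_batch(train_set, bucket_id)
#   # Backward pass included: returns (gradient norm, loss, None).
#   grad_norm, loss, _ = model.step(session, encoder_inputs, decoder_inputs,
#                                   target_weights, bucket_id, forward_only=False)
#   # Forward only: returns (encoder_state, loss, per-step output logits).
#   _, eval_loss, output_logits = model.step(session, encoder_inputs, decoder_inputs,
#                                            target_weights, bucket_id, forward_only=True)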
314 | 315 | 316 | 317 | def step_rf(self, args, session, encoder_inputs, decoder_inputs, target_weights, 318 | bucket_id, rev_vocab=None, debug=True): 319 | 320 | # initialize 321 | init_inputs = [encoder_inputs, decoder_inputs, target_weights, bucket_id] 322 | sent_max_length = args.buckets[-1][0] 323 | resp_tokens, resp_txt = self.logits2tokens(encoder_inputs, rev_vocab, sent_max_length, reverse=True) 324 | if debug: print("[INPUT]:", resp_txt) 325 | 326 | # Initialize 327 | ep_rewards, ep_step_loss, enc_states = [], [], [] 328 | ep_encoder_inputs, ep_target_weights, ep_bucket_id = [], [], [] 329 | 330 | # [Episode] per episode = n steps, until break 331 | while True: 332 | #----[Step]---------------------------------------- 333 | encoder_state, step_loss, output_logits = self.step(session, encoder_inputs, decoder_inputs, target_weights, 334 | bucket_id, forward_only=True, force_dec_input=False) 335 | 336 | # memorize inputs for reproducing curriculum with adjusted losses 337 | ep_encoder_inputs.append(encoder_inputs) 338 | ep_target_weights.append(target_weights) 339 | ep_bucket_id.append(bucket_id) 340 | ep_step_loss.append(step_loss) 341 | enc_states_vec = np.reshape(np.squeeze(encoder_state, axis=1), (-1)) 342 | enc_states.append(enc_states_vec) 343 | 344 | # process response 345 | resp_tokens, resp_txt = self.logits2tokens(output_logits, rev_vocab, sent_max_length) 346 | if debug: print("[RESP]: (%.4f) %s" % (step_loss, resp_txt)) 347 | 348 | # prepare for next dialogue 349 | bucket_id = min([b for b in range(len(args.buckets)) if args.buckets[b][0] > len(resp_tokens)]) 350 | feed_data = {bucket_id: [(resp_tokens, [])]} 351 | encoder_inputs, decoder_inputs, target_weights = self.get_batch(feed_data, bucket_id) 352 | 353 | #----[Reward]---------------------------------------- 354 | # r1: Ease of answering 355 | r1 = [self.logProb(session, args.buckets, resp_tokens, d) for d in self.dummy_dialogs] 356 | r1 = -np.mean(r1) if r1 else 0 357 | 358 | # r2: Information Flow 359 | if len(enc_states) < 2: 360 | r2 = 0 361 | else: 362 | vec_a, vec_b = enc_states[-2], enc_states[-1] 363 | r2 = sum(vec_a*vec_b) / sum(abs(vec_a)*abs(vec_b)) 364 | r2 = -log(r2) 365 | 366 | # r3: Semantic Coherence 367 | r3 = -self.logProb(session, args.buckets, resp_tokens, ep_encoder_inputs[-1]) 368 | 369 | # Episode total reward 370 | R = 0.25*r1 + 0.25*r2 + 0.5*r3 371 | ep_rewards.append(R) 372 | #---------------------------------------------------- 373 | if (resp_txt in self.dummy_dialogs) or (len(resp_tokens) <= 3) or (encoder_inputs in ep_encoder_inputs): 374 | break # check if dialog ended 375 | 376 | # gradient descent according to batch rewards 377 | rto = (max(ep_step_loss) - min(ep_step_loss)) / (max(ep_rewards) - min(ep_rewards)) 378 | advantage = [np.mean(ep_rewards)*rto] * len(args.buckets) 379 | _, step_loss, _ = self.step(session, init_inputs[0], init_inputs[1], init_inputs[2], init_inputs[3], 380 | forward_only=False, force_dec_input=False, advantage=advantage) 381 | 382 | return None, step_loss, None 383 | 384 | 385 | 386 | # log(P(b|a)), the conditional likelihood 387 | def logProb(self, session, buckets, tokens_a, tokens_b): 388 | def softmax(x): 389 | return np.exp(x) / np.sum(np.exp(x), axis=0) 390 | 391 | # prepare for next dialogue 392 | bucket_id = min([b for b in range(len(buckets)) if buckets[b][0] > len(tokens_a)]) 393 | feed_data = {bucket_id: [(tokens_a, tokens_b)]} 394 | encoder_inputs, decoder_inputs, target_weights = self.get_batch(feed_data, bucket_id) 395 | 396 | # step 397 | _, _, 
output_logits = self.step(session, encoder_inputs, decoder_inputs, target_weights, 398 | bucket_id, forward_only=True, force_dec_input=True) 399 | 400 | # p = log(P(b|a)) / N 401 | p = 1 402 | for t, logit in zip(tokens_b, output_logits): 403 | p *= softmax(logit[0])[t] 404 | p = log(p) / len(tokens_b) 405 | return p 406 | 407 | 408 | def logits2tokens(self, logits, rev_vocab, sent_max_length=None, reverse=False): 409 | if reverse: 410 | tokens = [t[0] for t in reversed(logits)] 411 | else: 412 | tokens = [int(np.argmax(t, axis=1)) for t in logits] 413 | if data_utils.EOS_ID in tokens: 414 | eos = tokens.index(data_utils.EOS_ID) 415 | tokens = tokens[:eos] 416 | txt = [rev_vocab[t] for t in tokens] 417 | if sent_max_length: 418 | tokens, txt = tokens[:sent_max_length], txt[:sent_max_length] 419 | return tokens, txt 420 | 421 | 422 | def discount_rewards(self, r, gamma=0.99): 423 | """ take 1D float array of rewards and compute discounted reward """ 424 | discounted_r = np.zeros_like(r) 425 | running_add = 0 426 | for t in reversed(xrange(0, r.size)): 427 | running_add = running_add * gamma + r[t] 428 | discounted_r[t] = running_add 429 | return discounted_r 430 | 431 | 432 | def get_batch(self, data, bucket_id): 433 | """Get a random batch of data from the specified bucket, prepare for step. 434 | 435 | To feed data in step(..) it must be a list of batch-major vectors, while 436 | data here contains single length-major cases. So the main logic of this 437 | function is to re-index data cases to be in the proper format for feeding. 438 | 439 | Args: 440 | data: a list or dict indexed by bucket_id in which each element contains 441 | lists of pairs of input and output data that we use to create a batch. 442 | bucket_id: integer, which bucket to get the batch for. 443 | 444 | Returns: 445 | The triple (encoder_inputs, decoder_inputs, target_weights) for 446 | the constructed batch that has the proper format to call step(...) later. 447 | """ 448 | encoder_size, decoder_size = self.buckets[bucket_id] 449 | encoder_inputs, decoder_inputs = [], [] 450 | 451 | # Get a random batch of encoder and decoder inputs from data, 452 | # pad them if needed, reverse encoder inputs and add GO to decoder. 453 | for _ in xrange(self.batch_size): 454 | encoder_input, decoder_input = random.choice(data[bucket_id]) 455 | 456 | # Encoder inputs are padded and then reversed. 457 | encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(encoder_input)) 458 | encoder_inputs.append(list(reversed(encoder_input + encoder_pad))) 459 | 460 | # Decoder inputs get an extra "GO" symbol, and are then padded. 461 | decoder_pad_size = decoder_size - len(decoder_input) - 1 462 | decoder_inputs.append([data_utils.GO_ID] + decoder_input + 463 | [data_utils.PAD_ID] * decoder_pad_size) 464 | 465 | # Now we create batch-major vectors from the data selected above. 466 | batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], [] 467 | 468 | # Batch encoder inputs are just re-indexed encoder_inputs. 469 | for length_idx in xrange(encoder_size): 470 | batch_encoder_inputs.append( 471 | np.array([encoder_inputs[batch_idx][length_idx] 472 | for batch_idx in xrange(self.batch_size)], dtype=np.int32)) 473 | 474 | # Batch decoder inputs are re-indexed decoder_inputs, we create weights. 
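# The weight for position t is set to 0 whenever the target at t (i.e. the decoder
# input shifted one step forward) is PAD, and for the final position, so padded
# slots do not contribute to the loss.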
475 | for length_idx in xrange(decoder_size): 476 | batch_decoder_inputs.append( 477 | np.array([decoder_inputs[batch_idx][length_idx] 478 | for batch_idx in xrange(self.batch_size)], dtype=np.int32)) 479 | 480 | # Create target_weights to be 0 for targets that are padding. 481 | batch_weight = np.ones(self.batch_size, dtype=np.float32) 482 | for batch_idx in xrange(self.batch_size): 483 | # We set weight to 0 if the corresponding target is a PAD symbol. 484 | # The corresponding target is decoder_input shifted by 1 forward. 485 | if length_idx < decoder_size - 1: 486 | target = decoder_inputs[batch_idx][length_idx + 1] 487 | if length_idx == decoder_size - 1 or target == data_utils.PAD_ID: 488 | batch_weight[batch_idx] = 0.0 489 | batch_weights.append(batch_weight) 490 | return batch_encoder_inputs, batch_decoder_inputs, batch_weights 491 | -------------------------------------------------------------------------------- /ref/seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Library for creating sequence-to-sequence models in TensorFlow. 16 | 17 | Sequence-to-sequence recurrent neural networks can learn complex functions 18 | that map input sequences to output sequences. These models yield very good 19 | results on a number of tasks, such as speech recognition, parsing, machine 20 | translation, or even constructing automated replies to emails. 21 | 22 | Before using this module, it is recommended to read the TensorFlow tutorial 23 | on sequence-to-sequence models. It explains the basic concepts of this module 24 | and shows an end-to-end example of how to build a translation model. 25 | https://www.tensorflow.org/versions/master/tutorials/seq2seq/index.html 26 | 27 | Here is an overview of functions available in this module. They all use 28 | a very similar interface, so after reading the above tutorial and using 29 | one of them, others should be easy to substitute. 30 | 31 | * Full sequence-to-sequence models. 32 | - basic_rnn_seq2seq: The most basic RNN-RNN model. 33 | - tied_rnn_seq2seq: The basic model with tied encoder and decoder weights. 34 | - embedding_rnn_seq2seq: The basic model with input embedding. 35 | - embedding_tied_rnn_seq2seq: The tied model with input embedding. 36 | - embedding_attention_seq2seq: Advanced model with input embedding and 37 | the neural attention mechanism; recommended for complex tasks. 38 | 39 | * Multi-task sequence-to-sequence models. 40 | - one2many_rnn_seq2seq: The embedding model with multiple decoders. 41 | 42 | * Decoders (when you write your own encoder, you can use these to decode; 43 | e.g., if you want to write a model that generates captions for images). 44 | - rnn_decoder: The basic decoder based on a pure RNN. 
45 | - attention_decoder: A decoder that uses the attention mechanism. 46 | 47 | * Losses. 48 | - sequence_loss: Loss for a sequence model returning average log-perplexity. 49 | - sequence_loss_by_example: As above, but not averaging over all examples. 50 | 51 | * model_with_buckets: A convenience function to create models with bucketing 52 | (see the tutorial above for an explanation of why and how to use it). 53 | """ 54 | 55 | from __future__ import absolute_import 56 | from __future__ import division 57 | from __future__ import print_function 58 | 59 | import copy 60 | 61 | # We disable pylint because we need python3 compatibility. 62 | from six.moves import xrange # pylint: disable=redefined-builtin 63 | from six.moves import zip # pylint: disable=redefined-builtin 64 | 65 | from tensorflow.contrib.rnn.python.ops import core_rnn 66 | from tensorflow.contrib.rnn.python.ops import core_rnn_cell 67 | from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl 68 | from tensorflow.python.framework import dtypes 69 | from tensorflow.python.framework import ops 70 | from tensorflow.python.ops import array_ops 71 | from tensorflow.python.ops import control_flow_ops 72 | from tensorflow.python.ops import embedding_ops 73 | from tensorflow.python.ops import math_ops 74 | from tensorflow.python.ops import nn_ops 75 | from tensorflow.python.ops import variable_scope 76 | from tensorflow.python.util import nest 77 | 78 | # TODO(ebrevdo): Remove once _linear is fully deprecated. 79 | linear = core_rnn_cell_impl._linear # pylint: disable=protected-access 80 | 81 | 82 | def _extract_argmax_and_embed(embedding, 83 | output_projection=None, 84 | update_embedding=True): 85 | """Get a loop_function that extracts the previous symbol and embeds it. 86 | 87 | Args: 88 | embedding: embedding tensor for symbols. 89 | output_projection: None or a pair (W, B). If provided, each fed previous 90 | output will first be multiplied by W and added B. 91 | update_embedding: Boolean; if False, the gradients will not propagate 92 | through the embeddings. 93 | 94 | Returns: 95 | A loop function. 96 | """ 97 | 98 | def loop_function(prev, _): 99 | if output_projection is not None: 100 | prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1]) 101 | prev_symbol = math_ops.argmax(prev, 1) 102 | # Note that gradients will not propagate through the second parameter of 103 | # embedding_lookup. 104 | emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol) 105 | if not update_embedding: 106 | emb_prev = array_ops.stop_gradient(emb_prev) 107 | return emb_prev 108 | 109 | return loop_function 110 | 111 | 112 | def rnn_decoder(decoder_inputs, 113 | initial_state, 114 | cell, 115 | loop_function=None, 116 | scope=None): 117 | """RNN decoder for the sequence-to-sequence model. 118 | 119 | Args: 120 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 121 | initial_state: 2D Tensor with shape [batch_size x cell.state_size]. 122 | cell: core_rnn_cell.RNNCell defining the cell function and size. 123 | loop_function: If not None, this function will be applied to the i-th output 124 | in order to generate the i+1-st input, and decoder_inputs will be ignored, 125 | except for the first element ("GO" symbol). This can be used for decoding, 126 | but also for training to emulate http://arxiv.org/abs/1506.03099. 
127 | Signature -- loop_function(prev, i) = next 128 | * prev is a 2D Tensor of shape [batch_size x output_size], 129 | * i is an integer, the step number (when advanced control is needed), 130 | * next is a 2D Tensor of shape [batch_size x input_size]. 131 | scope: VariableScope for the created subgraph; defaults to "rnn_decoder". 132 | 133 | Returns: 134 | A tuple of the form (outputs, state), where: 135 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 136 | shape [batch_size x output_size] containing generated outputs. 137 | state: The state of each cell at the final time-step. 138 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 139 | (Note that in some cases, like basic RNN cell or GRU cell, outputs and 140 | states can be the same. They are different for LSTM cells though.) 141 | """ 142 | with variable_scope.variable_scope(scope or "rnn_decoder"): 143 | state = initial_state 144 | outputs = [] 145 | prev = None 146 | for i, inp in enumerate(decoder_inputs): 147 | if loop_function is not None and prev is not None: 148 | with variable_scope.variable_scope("loop_function", reuse=True): 149 | inp = loop_function(prev, i) 150 | if i > 0: 151 | variable_scope.get_variable_scope().reuse_variables() 152 | output, state = cell(inp, state) 153 | outputs.append(output) 154 | if loop_function is not None: 155 | prev = output 156 | return outputs, state 157 | 158 | 159 | def basic_rnn_seq2seq(encoder_inputs, 160 | decoder_inputs, 161 | cell, 162 | dtype=dtypes.float32, 163 | scope=None): 164 | """Basic RNN sequence-to-sequence model. 165 | 166 | This model first runs an RNN to encode encoder_inputs into a state vector, 167 | then runs decoder, initialized with the last encoder state, on decoder_inputs. 168 | Encoder and decoder use the same RNN cell type, but don't share parameters. 169 | 170 | Args: 171 | encoder_inputs: A list of 2D Tensors [batch_size x input_size]. 172 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 173 | cell: core_rnn_cell.RNNCell defining the cell function and size. 174 | dtype: The dtype of the initial state of the RNN cell (default: tf.float32). 175 | scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq". 176 | 177 | Returns: 178 | A tuple of the form (outputs, state), where: 179 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 180 | shape [batch_size x output_size] containing the generated outputs. 181 | state: The state of each decoder cell in the final time-step. 182 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 183 | """ 184 | with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"): 185 | enc_cell = copy.deepcopy(cell) 186 | _, enc_state = core_rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype) 187 | return rnn_decoder(decoder_inputs, enc_state, cell) 188 | 189 | 190 | def tied_rnn_seq2seq(encoder_inputs, 191 | decoder_inputs, 192 | cell, 193 | loop_function=None, 194 | dtype=dtypes.float32, 195 | scope=None): 196 | """RNN sequence-to-sequence model with tied encoder and decoder parameters. 197 | 198 | This model first runs an RNN to encode encoder_inputs into a state vector, and 199 | then runs decoder, initialized with the last encoder state, on decoder_inputs. 200 | Encoder and decoder use the same RNN cell and share parameters. 201 | 202 | Args: 203 | encoder_inputs: A list of 2D Tensors [batch_size x input_size]. 204 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 
205 | cell: core_rnn_cell.RNNCell defining the cell function and size. 206 | loop_function: If not None, this function will be applied to i-th output 207 | in order to generate i+1-th input, and decoder_inputs will be ignored, 208 | except for the first element ("GO" symbol), see rnn_decoder for details. 209 | dtype: The dtype of the initial state of the rnn cell (default: tf.float32). 210 | scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq". 211 | 212 | Returns: 213 | A tuple of the form (outputs, state), where: 214 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 215 | shape [batch_size x output_size] containing the generated outputs. 216 | state: The state of each decoder cell in each time-step. This is a list 217 | with length len(decoder_inputs) -- one item for each time-step. 218 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 219 | """ 220 | with variable_scope.variable_scope("combined_tied_rnn_seq2seq"): 221 | scope = scope or "tied_rnn_seq2seq" 222 | _, enc_state = core_rnn.static_rnn( 223 | cell, encoder_inputs, dtype=dtype, scope=scope) 224 | variable_scope.get_variable_scope().reuse_variables() 225 | return rnn_decoder( 226 | decoder_inputs, 227 | enc_state, 228 | cell, 229 | loop_function=loop_function, 230 | scope=scope) 231 | 232 | 233 | def embedding_rnn_decoder(decoder_inputs, 234 | initial_state, 235 | cell, 236 | num_symbols, 237 | embedding_size, 238 | output_projection=None, 239 | feed_previous=False, 240 | update_embedding_for_previous=True, 241 | scope=None): 242 | """RNN decoder with embedding and a pure-decoding option. 243 | 244 | Args: 245 | decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). 246 | initial_state: 2D Tensor [batch_size x cell.state_size]. 247 | cell: core_rnn_cell.RNNCell defining the cell function. 248 | num_symbols: Integer, how many symbols come into the embedding. 249 | embedding_size: Integer, the length of the embedding vector for each symbol. 250 | output_projection: None or a pair (W, B) of output projection weights and 251 | biases; W has shape [output_size x num_symbols] and B has 252 | shape [num_symbols]; if provided and feed_previous=True, each fed 253 | previous output will first be multiplied by W and added B. 254 | feed_previous: Boolean; if True, only the first of decoder_inputs will be 255 | used (the "GO" symbol), and all other decoder inputs will be generated by: 256 | next = embedding_lookup(embedding, argmax(previous_output)), 257 | In effect, this implements a greedy decoder. It can also be used 258 | during training to emulate http://arxiv.org/abs/1506.03099. 259 | If False, decoder_inputs are used as given (the standard decoder case). 260 | update_embedding_for_previous: Boolean; if False and feed_previous=True, 261 | only the embedding for the first symbol of decoder_inputs (the "GO" 262 | symbol) will be updated by back propagation. Embeddings for the symbols 263 | generated from the decoder itself remain unchanged. This parameter has 264 | no effect if feed_previous=False. 265 | scope: VariableScope for the created subgraph; defaults to 266 | "embedding_rnn_decoder". 267 | 268 | Returns: 269 | A tuple of the form (outputs, state), where: 270 | outputs: A list of the same length as decoder_inputs of 2D Tensors. The 271 | output is of shape [batch_size x cell.output_size] when 272 | output_projection is not None (and represents the dense representation 273 | of predicted tokens). 
It is of shape [batch_size x num_decoder_symbols] 274 | when output_projection is None. 275 | state: The state of each decoder cell in each time-step. This is a list 276 | with length len(decoder_inputs) -- one item for each time-step. 277 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 278 | 279 | Raises: 280 | ValueError: When output_projection has the wrong shape. 281 | """ 282 | with variable_scope.variable_scope(scope or "embedding_rnn_decoder") as scope: 283 | if output_projection is not None: 284 | dtype = scope.dtype 285 | proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype) 286 | proj_weights.get_shape().assert_is_compatible_with([None, num_symbols]) 287 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 288 | proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 289 | 290 | embedding = variable_scope.get_variable("embedding", 291 | [num_symbols, embedding_size]) 292 | loop_function = _extract_argmax_and_embed( 293 | embedding, output_projection, 294 | update_embedding_for_previous) if feed_previous else None 295 | emb_inp = (embedding_ops.embedding_lookup(embedding, i) 296 | for i in decoder_inputs) 297 | return rnn_decoder( 298 | emb_inp, initial_state, cell, loop_function=loop_function) 299 | 300 | 301 | def embedding_rnn_seq2seq(encoder_inputs, 302 | decoder_inputs, 303 | cell, 304 | num_encoder_symbols, 305 | num_decoder_symbols, 306 | embedding_size, 307 | output_projection=None, 308 | feed_previous=False, 309 | dtype=None, 310 | scope=None): 311 | """Embedding RNN sequence-to-sequence model. 312 | 313 | This model first embeds encoder_inputs by a newly created embedding (of shape 314 | [num_encoder_symbols x input_size]). Then it runs an RNN to encode 315 | embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs 316 | by another newly created embedding (of shape [num_decoder_symbols x 317 | input_size]). Then it runs RNN decoder, initialized with the last 318 | encoder state, on embedded decoder_inputs. 319 | 320 | Args: 321 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 322 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 323 | cell: core_rnn_cell.RNNCell defining the cell function and size. 324 | num_encoder_symbols: Integer; number of symbols on the encoder side. 325 | num_decoder_symbols: Integer; number of symbols on the decoder side. 326 | embedding_size: Integer, the length of the embedding vector for each symbol. 327 | output_projection: None or a pair (W, B) of output projection weights and 328 | biases; W has shape [output_size x num_decoder_symbols] and B has 329 | shape [num_decoder_symbols]; if provided and feed_previous=True, each 330 | fed previous output will first be multiplied by W and added B. 331 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 332 | of decoder_inputs will be used (the "GO" symbol), and all other decoder 333 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 334 | If False, decoder_inputs are used as given (the standard decoder case). 335 | dtype: The dtype of the initial state for both the encoder and encoder 336 | rnn cells (default: tf.float32). 337 | scope: VariableScope for the created subgraph; defaults to 338 | "embedding_rnn_seq2seq" 339 | 340 | Returns: 341 | A tuple of the form (outputs, state), where: 342 | outputs: A list of the same length as decoder_inputs of 2D Tensors. 
The 343 | output is of shape [batch_size x cell.output_size] when 344 | output_projection is not None (and represents the dense representation 345 | of predicted tokens). It is of shape [batch_size x num_decoder_symbols] 346 | when output_projection is None. 347 | state: The state of each decoder cell in each time-step. This is a list 348 | with length len(decoder_inputs) -- one item for each time-step. 349 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 350 | """ 351 | with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq") as scope: 352 | if dtype is not None: 353 | scope.set_dtype(dtype) 354 | else: 355 | dtype = scope.dtype 356 | 357 | # Encoder. 358 | encoder_cell = copy.deepcopy(cell) 359 | encoder_cell = core_rnn_cell.EmbeddingWrapper( 360 | encoder_cell, 361 | embedding_classes=num_encoder_symbols, 362 | embedding_size=embedding_size) 363 | _, encoder_state = core_rnn.static_rnn( 364 | encoder_cell, encoder_inputs, dtype=dtype) 365 | 366 | # Decoder. 367 | if output_projection is None: 368 | cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) 369 | 370 | if isinstance(feed_previous, bool): 371 | return embedding_rnn_decoder( 372 | decoder_inputs, 373 | encoder_state, 374 | cell, 375 | num_decoder_symbols, 376 | embedding_size, 377 | output_projection=output_projection, 378 | feed_previous=feed_previous) 379 | 380 | # If feed_previous is a Tensor, we construct 2 graphs and use cond. 381 | def decoder(feed_previous_bool): 382 | reuse = None if feed_previous_bool else True 383 | with variable_scope.variable_scope( 384 | variable_scope.get_variable_scope(), reuse=reuse) as scope: 385 | outputs, state = embedding_rnn_decoder( 386 | decoder_inputs, 387 | encoder_state, 388 | cell, 389 | num_decoder_symbols, 390 | embedding_size, 391 | output_projection=output_projection, 392 | feed_previous=feed_previous_bool, 393 | update_embedding_for_previous=False) 394 | state_list = [state] 395 | if nest.is_sequence(state): 396 | state_list = nest.flatten(state) 397 | return outputs + state_list 398 | 399 | outputs_and_state = control_flow_ops.cond(feed_previous, 400 | lambda: decoder(True), 401 | lambda: decoder(False)) 402 | outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs. 403 | state_list = outputs_and_state[outputs_len:] 404 | state = state_list[0] 405 | if nest.is_sequence(encoder_state): 406 | state = nest.pack_sequence_as( 407 | structure=encoder_state, flat_sequence=state_list) 408 | return outputs_and_state[:outputs_len], state 409 | 410 | 411 | def embedding_tied_rnn_seq2seq(encoder_inputs, 412 | decoder_inputs, 413 | cell, 414 | num_symbols, 415 | embedding_size, 416 | num_decoder_symbols=None, 417 | output_projection=None, 418 | feed_previous=False, 419 | dtype=None, 420 | scope=None): 421 | """Embedding RNN sequence-to-sequence model with tied (shared) parameters. 422 | 423 | This model first embeds encoder_inputs by a newly created embedding (of shape 424 | [num_symbols x input_size]). Then it runs an RNN to encode embedded 425 | encoder_inputs into a state vector. Next, it embeds decoder_inputs using 426 | the same embedding. Then it runs RNN decoder, initialized with the last 427 | encoder state, on embedded decoder_inputs. The decoder output is over symbols 428 | from 0 to num_decoder_symbols - 1 if num_decoder_symbols is none; otherwise it 429 | is over 0 to num_symbols - 1. 430 | 431 | Args: 432 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 
433 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 434 | cell: core_rnn_cell.RNNCell defining the cell function and size. 435 | num_symbols: Integer; number of symbols for both encoder and decoder. 436 | embedding_size: Integer, the length of the embedding vector for each symbol. 437 | num_decoder_symbols: Integer; number of output symbols for decoder. If 438 | provided, the decoder output is over symbols 0 to num_decoder_symbols - 1. 439 | Otherwise, decoder output is over symbols 0 to num_symbols - 1. Note that 440 | this assumes that the vocabulary is set up such that the first 441 | num_decoder_symbols of num_symbols are part of decoding. 442 | output_projection: None or a pair (W, B) of output projection weights and 443 | biases; W has shape [output_size x num_symbols] and B has 444 | shape [num_symbols]; if provided and feed_previous=True, each 445 | fed previous output will first be multiplied by W and added B. 446 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 447 | of decoder_inputs will be used (the "GO" symbol), and all other decoder 448 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 449 | If False, decoder_inputs are used as given (the standard decoder case). 450 | dtype: The dtype to use for the initial RNN states (default: tf.float32). 451 | scope: VariableScope for the created subgraph; defaults to 452 | "embedding_tied_rnn_seq2seq". 453 | 454 | Returns: 455 | A tuple of the form (outputs, state), where: 456 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 457 | shape [batch_size x output_symbols] containing the generated 458 | outputs where output_symbols = num_decoder_symbols if 459 | num_decoder_symbols is not None otherwise output_symbols = num_symbols. 460 | state: The state of each decoder cell at the final time-step. 461 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 462 | 463 | Raises: 464 | ValueError: When output_projection has the wrong shape. 465 | """ 466 | with variable_scope.variable_scope( 467 | scope or "embedding_tied_rnn_seq2seq", dtype=dtype) as scope: 468 | dtype = scope.dtype 469 | 470 | if output_projection is not None: 471 | proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype) 472 | proj_weights.get_shape().assert_is_compatible_with([None, num_symbols]) 473 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 474 | proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 475 | 476 | embedding = variable_scope.get_variable( 477 | "embedding", [num_symbols, embedding_size], dtype=dtype) 478 | 479 | emb_encoder_inputs = [ 480 | embedding_ops.embedding_lookup(embedding, x) for x in encoder_inputs 481 | ] 482 | emb_decoder_inputs = [ 483 | embedding_ops.embedding_lookup(embedding, x) for x in decoder_inputs 484 | ] 485 | 486 | output_symbols = num_symbols 487 | if num_decoder_symbols is not None: 488 | output_symbols = num_decoder_symbols 489 | if output_projection is None: 490 | cell = core_rnn_cell.OutputProjectionWrapper(cell, output_symbols) 491 | 492 | if isinstance(feed_previous, bool): 493 | loop_function = _extract_argmax_and_embed(embedding, output_projection, 494 | True) if feed_previous else None 495 | return tied_rnn_seq2seq( 496 | emb_encoder_inputs, 497 | emb_decoder_inputs, 498 | cell, 499 | loop_function=loop_function, 500 | dtype=dtype) 501 | 502 | # If feed_previous is a Tensor, we construct 2 graphs and use cond. 
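# Both decoder graphs below are built over the same variables (the second one with
# reuse=True); control_flow_ops.cond then selects the feed_previous=True or
# feed_previous=False branch at run time.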
503 | def decoder(feed_previous_bool): 504 | loop_function = _extract_argmax_and_embed( 505 | embedding, output_projection, False) if feed_previous_bool else None 506 | reuse = None if feed_previous_bool else True 507 | with variable_scope.variable_scope( 508 | variable_scope.get_variable_scope(), reuse=reuse): 509 | outputs, state = tied_rnn_seq2seq( 510 | emb_encoder_inputs, 511 | emb_decoder_inputs, 512 | cell, 513 | loop_function=loop_function, 514 | dtype=dtype) 515 | state_list = [state] 516 | if nest.is_sequence(state): 517 | state_list = nest.flatten(state) 518 | return outputs + state_list 519 | 520 | outputs_and_state = control_flow_ops.cond(feed_previous, 521 | lambda: decoder(True), 522 | lambda: decoder(False)) 523 | outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs. 524 | state_list = outputs_and_state[outputs_len:] 525 | state = state_list[0] 526 | # Calculate zero-state to know it's structure. 527 | static_batch_size = encoder_inputs[0].get_shape()[0] 528 | for inp in encoder_inputs[1:]: 529 | static_batch_size.merge_with(inp.get_shape()[0]) 530 | batch_size = static_batch_size.value 531 | if batch_size is None: 532 | batch_size = array_ops.shape(encoder_inputs[0])[0] 533 | zero_state = cell.zero_state(batch_size, dtype) 534 | if nest.is_sequence(zero_state): 535 | state = nest.pack_sequence_as( 536 | structure=zero_state, flat_sequence=state_list) 537 | return outputs_and_state[:outputs_len], state 538 | 539 | 540 | def attention_decoder(decoder_inputs, 541 | initial_state, 542 | attention_states, 543 | cell, 544 | output_size=None, 545 | num_heads=1, 546 | loop_function=None, 547 | dtype=None, 548 | scope=None, 549 | initial_state_attention=False): 550 | """RNN decoder with attention for the sequence-to-sequence model. 551 | 552 | In this context "attention" means that, during decoding, the RNN can look up 553 | information in the additional tensor attention_states, and it does this by 554 | focusing on a few entries from the tensor. This model has proven to yield 555 | especially good results in a number of sequence-to-sequence tasks. This 556 | implementation is based on http://arxiv.org/abs/1412.7449 (see below for 557 | details). It is recommended for complex sequence-to-sequence tasks. 558 | 559 | Args: 560 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 561 | initial_state: 2D Tensor [batch_size x cell.state_size]. 562 | attention_states: 3D Tensor [batch_size x attn_length x attn_size]. 563 | cell: core_rnn_cell.RNNCell defining the cell function and size. 564 | output_size: Size of the output vectors; if None, we use cell.output_size. 565 | num_heads: Number of attention heads that read from attention_states. 566 | loop_function: If not None, this function will be applied to i-th output 567 | in order to generate i+1-th input, and decoder_inputs will be ignored, 568 | except for the first element ("GO" symbol). This can be used for decoding, 569 | but also for training to emulate http://arxiv.org/abs/1506.03099. 570 | Signature -- loop_function(prev, i) = next 571 | * prev is a 2D Tensor of shape [batch_size x output_size], 572 | * i is an integer, the step number (when advanced control is needed), 573 | * next is a 2D Tensor of shape [batch_size x input_size]. 574 | dtype: The dtype to use for the RNN initial state (default: tf.float32). 575 | scope: VariableScope for the created subgraph; default: "attention_decoder". 576 | initial_state_attention: If False (default), initial attentions are zero. 
577 | If True, initialize the attentions from the initial state and attention 578 | states -- useful when we wish to resume decoding from a previously 579 | stored decoder state and attention states. 580 | 581 | Returns: 582 | A tuple of the form (outputs, state), where: 583 | outputs: A list of the same length as decoder_inputs of 2D Tensors of 584 | shape [batch_size x output_size]. These represent the generated outputs. 585 | Output i is computed from input i (which is either the i-th element 586 | of decoder_inputs or loop_function(output {i-1}, i)) as follows. 587 | First, we run the cell on a combination of the input and previous 588 | attention masks: 589 | cell_output, new_state = cell(linear(input, prev_attn), prev_state). 590 | Then, we calculate new attention masks: 591 | new_attn = softmax(V^T * tanh(W * attention_states + U * new_state)) 592 | and then we calculate the output: 593 | output = linear(cell_output, new_attn). 594 | state: The state of each decoder cell the final time-step. 595 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 596 | 597 | Raises: 598 | ValueError: when num_heads is not positive, there are no inputs, shapes 599 | of attention_states are not set, or input size cannot be inferred 600 | from the input. 601 | """ 602 | if not decoder_inputs: 603 | raise ValueError("Must provide at least 1 input to attention decoder.") 604 | if num_heads < 1: 605 | raise ValueError("With less than 1 heads, use a non-attention decoder.") 606 | if attention_states.get_shape()[2].value is None: 607 | raise ValueError("Shape[2] of attention_states must be known: %s" % 608 | attention_states.get_shape()) 609 | if output_size is None: 610 | output_size = cell.output_size 611 | 612 | with variable_scope.variable_scope( 613 | scope or "attention_decoder", dtype=dtype) as scope: 614 | dtype = scope.dtype 615 | 616 | batch_size = array_ops.shape(decoder_inputs[0])[0] # Needed for reshaping. 617 | attn_length = attention_states.get_shape()[1].value 618 | if attn_length is None: 619 | attn_length = array_ops.shape(attention_states)[1] 620 | attn_size = attention_states.get_shape()[2].value 621 | 622 | # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before. 623 | hidden = array_ops.reshape(attention_states, 624 | [-1, attn_length, 1, attn_size]) 625 | hidden_features = [] 626 | v = [] 627 | attention_vec_size = attn_size # Size of query vectors for attention. 628 | for a in xrange(num_heads): 629 | k = variable_scope.get_variable("AttnW_%d" % a, 630 | [1, 1, attn_size, attention_vec_size]) 631 | hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")) 632 | v.append( 633 | variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size])) 634 | 635 | state = initial_state 636 | 637 | def attention(query): 638 | """Put attention masks on hidden using hidden_features and query.""" 639 | ds = [] # Results of attention reads will be stored here. 640 | if nest.is_sequence(query): # If the query is a tuple, flatten it. 641 | query_list = nest.flatten(query) 642 | for q in query_list: # Check that ndims == 2 if specified. 643 | ndims = q.get_shape().ndims 644 | if ndims: 645 | assert ndims == 2 646 | query = array_ops.concat(query_list, 1) 647 | for a in xrange(num_heads): 648 | with variable_scope.variable_scope("Attention_%d" % a): 649 | y = linear(query, attention_vec_size, True) 650 | y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) 651 | # Attention mask is a softmax of v^T * tanh(...). 
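# hidden_features[a] has shape [batch, attn_length, 1, attn_vec_size] and y has
# shape [batch, 1, 1, attn_vec_size]; summing v[a] * tanh(...) over dims [2, 3]
# gives scores s of shape [batch, attn_length], one score per encoder position.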
652 | s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y), 653 | [2, 3]) 654 | a = nn_ops.softmax(s) 655 | # Now calculate the attention-weighted vector d. 656 | d = math_ops.reduce_sum( 657 | array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2]) 658 | ds.append(array_ops.reshape(d, [-1, attn_size])) 659 | return ds 660 | 661 | outputs = [] 662 | prev = None 663 | batch_attn_size = array_ops.stack([batch_size, attn_size]) 664 | attns = [ 665 | array_ops.zeros( 666 | batch_attn_size, dtype=dtype) for _ in xrange(num_heads) 667 | ] 668 | for a in attns: # Ensure the second shape of attention vectors is set. 669 | a.set_shape([None, attn_size]) 670 | if initial_state_attention: 671 | attns = attention(initial_state) 672 | for i, inp in enumerate(decoder_inputs): 673 | if i > 0: 674 | variable_scope.get_variable_scope().reuse_variables() 675 | # If loop_function is set, we use it instead of decoder_inputs. 676 | if loop_function is not None and prev is not None: 677 | with variable_scope.variable_scope("loop_function", reuse=True): 678 | inp = loop_function(prev, i) 679 | # Merge input and previous attentions into one vector of the right size. 680 | input_size = inp.get_shape().with_rank(2)[1] 681 | if input_size.value is None: 682 | raise ValueError("Could not infer input size from input: %s" % inp.name) 683 | x = linear([inp] + attns, input_size, True) 684 | # Run the RNN. 685 | cell_output, state = cell(x, state) 686 | # Run the attention mechanism. 687 | if i == 0 and initial_state_attention: 688 | with variable_scope.variable_scope( 689 | variable_scope.get_variable_scope(), reuse=True): 690 | attns = attention(state) 691 | else: 692 | attns = attention(state) 693 | 694 | with variable_scope.variable_scope("AttnOutputProjection"): 695 | output = linear([cell_output] + attns, output_size, True) 696 | if loop_function is not None: 697 | prev = output 698 | outputs.append(output) 699 | 700 | return outputs, state 701 | 702 | 703 | def embedding_attention_decoder(decoder_inputs, 704 | initial_state, 705 | attention_states, 706 | cell, 707 | num_symbols, 708 | embedding_size, 709 | num_heads=1, 710 | output_size=None, 711 | output_projection=None, 712 | feed_previous=False, 713 | update_embedding_for_previous=True, 714 | dtype=None, 715 | scope=None, 716 | initial_state_attention=False): 717 | """RNN decoder with embedding and attention and a pure-decoding option. 718 | 719 | Args: 720 | decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). 721 | initial_state: 2D Tensor [batch_size x cell.state_size]. 722 | attention_states: 3D Tensor [batch_size x attn_length x attn_size]. 723 | cell: core_rnn_cell.RNNCell defining the cell function. 724 | num_symbols: Integer, how many symbols come into the embedding. 725 | embedding_size: Integer, the length of the embedding vector for each symbol. 726 | num_heads: Number of attention heads that read from attention_states. 727 | output_size: Size of the output vectors; if None, use output_size. 728 | output_projection: None or a pair (W, B) of output projection weights and 729 | biases; W has shape [output_size x num_symbols] and B has shape 730 | [num_symbols]; if provided and feed_previous=True, each fed previous 731 | output will first be multiplied by W and added B. 
732 | feed_previous: Boolean; if True, only the first of decoder_inputs will be 733 | used (the "GO" symbol), and all other decoder inputs will be generated by: 734 | next = embedding_lookup(embedding, argmax(previous_output)), 735 | In effect, this implements a greedy decoder. It can also be used 736 | during training to emulate http://arxiv.org/abs/1506.03099. 737 | If False, decoder_inputs are used as given (the standard decoder case). 738 | update_embedding_for_previous: Boolean; if False and feed_previous=True, 739 | only the embedding for the first symbol of decoder_inputs (the "GO" 740 | symbol) will be updated by back propagation. Embeddings for the symbols 741 | generated from the decoder itself remain unchanged. This parameter has 742 | no effect if feed_previous=False. 743 | dtype: The dtype to use for the RNN initial states (default: tf.float32). 744 | scope: VariableScope for the created subgraph; defaults to 745 | "embedding_attention_decoder". 746 | initial_state_attention: If False (default), initial attentions are zero. 747 | If True, initialize the attentions from the initial state and attention 748 | states -- useful when we wish to resume decoding from a previously 749 | stored decoder state and attention states. 750 | 751 | Returns: 752 | A tuple of the form (outputs, state), where: 753 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 754 | shape [batch_size x output_size] containing the generated outputs. 755 | state: The state of each decoder cell at the final time-step. 756 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 757 | 758 | Raises: 759 | ValueError: When output_projection has the wrong shape. 760 | """ 761 | if output_size is None: 762 | output_size = cell.output_size 763 | if output_projection is not None: 764 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 765 | proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 766 | 767 | with variable_scope.variable_scope( 768 | scope or "embedding_attention_decoder", dtype=dtype) as scope: 769 | 770 | embedding = variable_scope.get_variable("embedding", 771 | [num_symbols, embedding_size]) 772 | loop_function = _extract_argmax_and_embed( 773 | embedding, output_projection, 774 | update_embedding_for_previous) if feed_previous else None 775 | emb_inp = [ 776 | embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs 777 | ] 778 | return attention_decoder( 779 | emb_inp, 780 | initial_state, 781 | attention_states, 782 | cell, 783 | output_size=output_size, 784 | num_heads=num_heads, 785 | loop_function=loop_function, 786 | initial_state_attention=initial_state_attention) 787 | 788 | 789 | def embedding_attention_seq2seq(encoder_inputs, 790 | decoder_inputs, 791 | cell, 792 | num_encoder_symbols, 793 | num_decoder_symbols, 794 | embedding_size, 795 | num_heads=1, 796 | output_projection=None, 797 | feed_previous=False, 798 | dtype=None, 799 | scope=None, 800 | initial_state_attention=False): 801 | """Embedding sequence-to-sequence model with attention. 802 | 803 | This model first embeds encoder_inputs by a newly created embedding (of shape 804 | [num_encoder_symbols x input_size]). Then it runs an RNN to encode 805 | embedded encoder_inputs into a state vector. It keeps the outputs of this 806 | RNN at every step to use for attention later. Next, it embeds decoder_inputs 807 | by another newly created embedding (of shape [num_decoder_symbols x 808 | input_size]). 
Then it runs attention decoder, initialized with the last 809 | encoder state, on embedded decoder_inputs and attending to encoder outputs. 810 | 811 | Warning: when output_projection is None, the size of the attention vectors 812 | and variables will be made proportional to num_decoder_symbols, can be large. 813 | 814 | Args: 815 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 816 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 817 | cell: core_rnn_cell.RNNCell defining the cell function and size. 818 | num_encoder_symbols: Integer; number of symbols on the encoder side. 819 | num_decoder_symbols: Integer; number of symbols on the decoder side. 820 | embedding_size: Integer, the length of the embedding vector for each symbol. 821 | num_heads: Number of attention heads that read from attention_states. 822 | output_projection: None or a pair (W, B) of output projection weights and 823 | biases; W has shape [output_size x num_decoder_symbols] and B has 824 | shape [num_decoder_symbols]; if provided and feed_previous=True, each 825 | fed previous output will first be multiplied by W and added B. 826 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 827 | of decoder_inputs will be used (the "GO" symbol), and all other decoder 828 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 829 | If False, decoder_inputs are used as given (the standard decoder case). 830 | dtype: The dtype of the initial RNN state (default: tf.float32). 831 | scope: VariableScope for the created subgraph; defaults to 832 | "embedding_attention_seq2seq". 833 | initial_state_attention: If False (default), initial attentions are zero. 834 | If True, initialize the attentions from the initial state and attention 835 | states. 836 | 837 | Returns: 838 | A tuple of the form (outputs, state), where: 839 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 840 | shape [batch_size x num_decoder_symbols] containing the generated 841 | outputs. 842 | state: The state of each decoder cell at the final time-step. 843 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 844 | """ 845 | with variable_scope.variable_scope( 846 | scope or "embedding_attention_seq2seq", dtype=dtype) as scope: 847 | dtype = scope.dtype 848 | # Encoder. 849 | encoder_cell = copy.deepcopy(cell) 850 | encoder_cell = core_rnn_cell.EmbeddingWrapper( 851 | encoder_cell, 852 | embedding_classes=num_encoder_symbols, 853 | embedding_size=embedding_size) 854 | encoder_outputs, encoder_state = core_rnn.static_rnn( 855 | encoder_cell, encoder_inputs, dtype=dtype) 856 | 857 | # First calculate a concatenation of encoder outputs to put attention on. 858 | top_states = [ 859 | array_ops.reshape(e, [-1, 1, cell.output_size]) for e in encoder_outputs 860 | ] 861 | attention_states = array_ops.concat(top_states, 1) 862 | 863 | # Decoder. 
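# If no output_projection is given, the decoder cell is wrapped with
# OutputProjectionWrapper so each decoder step emits full num_decoder_symbols-sized
# logits (the case the size warning in the docstring above refers to).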
864 | output_size = None 865 | if output_projection is None: 866 | cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) 867 | output_size = num_decoder_symbols 868 | 869 | if isinstance(feed_previous, bool): 870 | return embedding_attention_decoder( 871 | decoder_inputs, 872 | encoder_state, 873 | attention_states, 874 | cell, 875 | num_decoder_symbols, 876 | embedding_size, 877 | num_heads=num_heads, 878 | output_size=output_size, 879 | output_projection=output_projection, 880 | feed_previous=feed_previous, 881 | initial_state_attention=initial_state_attention) 882 | 883 | # If feed_previous is a Tensor, we construct 2 graphs and use cond. 884 | def decoder(feed_previous_bool): 885 | reuse = None if feed_previous_bool else True 886 | with variable_scope.variable_scope( 887 | variable_scope.get_variable_scope(), reuse=reuse) as scope: 888 | outputs, state = embedding_attention_decoder( 889 | decoder_inputs, 890 | encoder_state, 891 | attention_states, 892 | cell, 893 | num_decoder_symbols, 894 | embedding_size, 895 | num_heads=num_heads, 896 | output_size=output_size, 897 | output_projection=output_projection, 898 | feed_previous=feed_previous_bool, 899 | update_embedding_for_previous=False, 900 | initial_state_attention=initial_state_attention) 901 | state_list = [state] 902 | if nest.is_sequence(state): 903 | state_list = nest.flatten(state) 904 | return outputs + state_list 905 | 906 | outputs_and_state = control_flow_ops.cond(feed_previous, 907 | lambda: decoder(True), 908 | lambda: decoder(False)) 909 | outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs. 910 | state_list = outputs_and_state[outputs_len:] 911 | state = state_list[0] 912 | if nest.is_sequence(encoder_state): 913 | state = nest.pack_sequence_as( 914 | structure=encoder_state, flat_sequence=state_list) 915 | return outputs_and_state[:outputs_len], state 916 | 917 | 918 | def one2many_rnn_seq2seq(encoder_inputs, 919 | decoder_inputs_dict, 920 | enc_cell, 921 | dec_cells_dict, 922 | num_encoder_symbols, 923 | num_decoder_symbols_dict, 924 | embedding_size, 925 | feed_previous=False, 926 | dtype=None, 927 | scope=None): 928 | """One-to-many RNN sequence-to-sequence model (multi-task). 929 | 930 | This is a multi-task sequence-to-sequence model with one encoder and multiple 931 | decoders. Reference to multi-task sequence-to-sequence learning can be found 932 | here: http://arxiv.org/abs/1511.06114 933 | 934 | Args: 935 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 936 | decoder_inputs_dict: A dictionary mapping decoder name (string) to 937 | the corresponding decoder_inputs; each decoder_inputs is a list of 1D 938 | Tensors of shape [batch_size]; num_decoders is defined as 939 | len(decoder_inputs_dict). 940 | enc_cell: core_rnn_cell.RNNCell defining the encoder cell function and size. 941 | dec_cells_dict: A dictionary mapping encoder name (string) to an 942 | instance of core_rnn_cell.RNNCell. 943 | num_encoder_symbols: Integer; number of symbols on the encoder side. 944 | num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an 945 | integer specifying number of symbols for the corresponding decoder; 946 | len(num_decoder_symbols_dict) must be equal to num_decoders. 947 | embedding_size: Integer, the length of the embedding vector for each symbol. 
948 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first of 949 | decoder_inputs will be used (the "GO" symbol), and all other decoder 950 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 951 | If False, decoder_inputs are used as given (the standard decoder case). 952 | dtype: The dtype of the initial state for both the encoder and decoder 953 | rnn cells (default: tf.float32). 954 | scope: VariableScope for the created subgraph; defaults to 955 | "one2many_rnn_seq2seq". 956 | 957 | Returns: 958 | A tuple of the form (outputs_dict, state_dict), where: 959 | outputs_dict: A mapping from decoder name (string) to a list of the same 960 | length as decoder_inputs_dict[name]; each element in the list is a 2D 961 | Tensor with shape [batch_size x num_decoder_symbol_list[name]] 962 | containing the generated outputs. 963 | state_dict: A mapping from decoder name (string) to the final state of the 964 | corresponding decoder RNN; it is a 2D Tensor of shape 965 | [batch_size x cell.state_size]. 966 | 967 | Raises: 968 | TypeError: if enc_cell or any of the dec_cells are not instances of RNNCell. 969 | ValueError: if len(dec_cells) != len(decoder_inputs_dict). 970 | """ 971 | outputs_dict = {} 972 | state_dict = {} 973 | 974 | if not isinstance(enc_cell, core_rnn_cell.RNNCell): 975 | raise TypeError("enc_cell is not an RNNCell: %s" % type(enc_cell)) 976 | if set(dec_cells_dict) != set(decoder_inputs_dict): 977 | raise ValueError("keys of dec_cells_dict != keys of decoder_inputs_dict") 978 | for dec_cell in dec_cells_dict.values(): 979 | if not isinstance(dec_cell, core_rnn_cell.RNNCell): 980 | raise TypeError("dec_cell is not an RNNCell: %s" % type(dec_cell)) 981 | 982 | with variable_scope.variable_scope( 983 | scope or "one2many_rnn_seq2seq", dtype=dtype) as scope: 984 | dtype = scope.dtype 985 | 986 | # Encoder. 987 | enc_cell = core_rnn_cell.EmbeddingWrapper( 988 | enc_cell, 989 | embedding_classes=num_encoder_symbols, 990 | embedding_size=embedding_size) 991 | _, encoder_state = core_rnn.static_rnn( 992 | enc_cell, encoder_inputs, dtype=dtype) 993 | 994 | # Decoder. 995 | for name, decoder_inputs in decoder_inputs_dict.items(): 996 | num_decoder_symbols = num_decoder_symbols_dict[name] 997 | dec_cell = dec_cells_dict[name] 998 | 999 | with variable_scope.variable_scope("one2many_decoder_" + str( 1000 | name)) as scope: 1001 | dec_cell = core_rnn_cell.OutputProjectionWrapper( 1002 | dec_cell, num_decoder_symbols) 1003 | if isinstance(feed_previous, bool): 1004 | outputs, state = embedding_rnn_decoder( 1005 | decoder_inputs, 1006 | encoder_state, 1007 | dec_cell, 1008 | num_decoder_symbols, 1009 | embedding_size, 1010 | feed_previous=feed_previous) 1011 | else: 1012 | # If feed_previous is a Tensor, we construct 2 graphs and use cond.
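          # Both branches below are constructed so that cond can pick one at
          # run time; the feed-previous branch creates the decoder variables and
          # the other branch reuses them, so the two graphs share weights.
          # Because cond must return a flat list of tensors, the decoder state
          # is flattened and appended to the outputs, then unpacked again
          # afterwards.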
1013 | def filled_embedding_rnn_decoder(feed_previous): 1014 | """The current decoder with a fixed feed_previous parameter.""" 1015 | # pylint: disable=cell-var-from-loop 1016 | reuse = None if feed_previous else True 1017 | vs = variable_scope.get_variable_scope() 1018 | with variable_scope.variable_scope(vs, reuse=reuse): 1019 | outputs, state = embedding_rnn_decoder( 1020 | decoder_inputs, 1021 | encoder_state, 1022 | dec_cell, 1023 | num_decoder_symbols, 1024 | embedding_size, 1025 | feed_previous=feed_previous) 1026 | # pylint: enable=cell-var-from-loop 1027 | state_list = [state] 1028 | if nest.is_sequence(state): 1029 | state_list = nest.flatten(state) 1030 | return outputs + state_list 1031 | 1032 | outputs_and_state = control_flow_ops.cond( 1033 | feed_previous, lambda: filled_embedding_rnn_decoder(True), 1034 | lambda: filled_embedding_rnn_decoder(False)) 1035 | # Outputs length is the same as for decoder inputs. 1036 | outputs_len = len(decoder_inputs) 1037 | outputs = outputs_and_state[:outputs_len] 1038 | state_list = outputs_and_state[outputs_len:] 1039 | state = state_list[0] 1040 | if nest.is_sequence(encoder_state): 1041 | state = nest.pack_sequence_as( 1042 | structure=encoder_state, flat_sequence=state_list) 1043 | outputs_dict[name] = outputs 1044 | state_dict[name] = state 1045 | 1046 | return outputs_dict, state_dict 1047 | 1048 | 1049 | def sequence_loss_by_example(logits, 1050 | targets, 1051 | weights, 1052 | average_across_timesteps=True, 1053 | softmax_loss_function=None, 1054 | name=None): 1055 | """Weighted cross-entropy loss for a sequence of logits (per example). 1056 | 1057 | Args: 1058 | logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols]. 1059 | targets: List of 1D batch-sized int32 Tensors of the same length as logits. 1060 | weights: List of 1D batch-sized float-Tensors of the same length as logits. 1061 | average_across_timesteps: If set, divide the returned cost by the total 1062 | label weight. 1063 | softmax_loss_function: Function (labels-batch, inputs-batch) -> loss-batch 1064 | to be used instead of the standard softmax (the default if this is None). 1065 | name: Optional name for this operation, default: "sequence_loss_by_example". 1066 | 1067 | Returns: 1068 | 1D batch-sized float Tensor: The log-perplexity for each sequence. 1069 | 1070 | Raises: 1071 | ValueError: If len(logits) is different from len(targets) or len(weights). 1072 | """ 1073 | if len(targets) != len(logits) or len(weights) != len(logits): 1074 | raise ValueError("Lengths of logits, weights, and targets must be the same " 1075 | "%d, %d, %d." % (len(logits), len(weights), len(targets))) 1076 | with ops.name_scope(name, "sequence_loss_by_example", 1077 | logits + targets + weights): 1078 | log_perp_list = [] 1079 | for logit, target, weight in zip(logits, targets, weights): 1080 | if softmax_loss_function is None: 1081 | # TODO(irving,ebrevdo): This reshape is needed because 1082 | # sequence_loss_by_example is called with scalars sometimes, which 1083 | # violates our general scalar strictness policy. 
1084 | target = array_ops.reshape(target, [-1]) 1085 | crossent = nn_ops.sparse_softmax_cross_entropy_with_logits( 1086 | labels=target, logits=logit) 1087 | else: 1088 | crossent = softmax_loss_function(target, logit) 1089 | log_perp_list.append(crossent * weight) 1090 | log_perps = math_ops.add_n(log_perp_list) 1091 | if average_across_timesteps: 1092 | total_size = math_ops.add_n(weights) 1093 | total_size += 1e-12 # Just to avoid division by 0 for all-0 weights. 1094 | log_perps /= total_size 1095 | return log_perps 1096 | 1097 | 1098 | def sequence_loss(logits, 1099 | targets, 1100 | weights, 1101 | average_across_timesteps=True, 1102 | average_across_batch=True, 1103 | softmax_loss_function=None, 1104 | name=None): 1105 | """Weighted cross-entropy loss for a sequence of logits, batch-collapsed. 1106 | 1107 | Args: 1108 | logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols]. 1109 | targets: List of 1D batch-sized int32 Tensors of the same length as logits. 1110 | weights: List of 1D batch-sized float-Tensors of the same length as logits. 1111 | average_across_timesteps: If set, divide the returned cost by the total 1112 | label weight. 1113 | average_across_batch: If set, divide the returned cost by the batch size. 1114 | softmax_loss_function: Function (labels-batch, inputs-batch) -> loss-batch 1115 | to be used instead of the standard softmax (the default if this is None). 1116 | name: Optional name for this operation, defaults to "sequence_loss". 1117 | 1118 | Returns: 1119 | A scalar float Tensor: The average log-perplexity per symbol (weighted). 1120 | 1121 | Raises: 1122 | ValueError: If len(logits) is different from len(targets) or len(weights). 1123 | """ 1124 | with ops.name_scope(name, "sequence_loss", logits + targets + weights): 1125 | cost = math_ops.reduce_sum( 1126 | sequence_loss_by_example( 1127 | logits, 1128 | targets, 1129 | weights, 1130 | average_across_timesteps=average_across_timesteps, 1131 | softmax_loss_function=softmax_loss_function)) 1132 | if average_across_batch: 1133 | batch_size = array_ops.shape(targets[0])[0] 1134 | return cost / math_ops.cast(batch_size, cost.dtype) 1135 | else: 1136 | return cost 1137 | 1138 | 1139 | def model_with_buckets(encoder_inputs, 1140 | decoder_inputs, 1141 | targets, 1142 | weights, 1143 | buckets, 1144 | seq2seq, 1145 | softmax_loss_function=None, 1146 | per_example_loss=False, 1147 | name=None): 1148 | """Create a sequence-to-sequence model with support for bucketing. 1149 | 1150 | The seq2seq argument is a function that defines a sequence-to-sequence model, 1151 | e.g., seq2seq = lambda x, y: basic_rnn_seq2seq( 1152 | x, y, core_rnn_cell.GRUCell(24)) 1153 | 1154 | Args: 1155 | encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input. 1156 | decoder_inputs: A list of Tensors to feed the decoder; second seq2seq input. 1157 | targets: A list of 1D batch-sized int32 Tensors (desired output sequence). 1158 | weights: List of 1D batch-sized float-Tensors to weight the targets. 1159 | buckets: A list of pairs of (input size, output size) for each bucket. 1160 | seq2seq: A sequence-to-sequence model function; it takes 2 input that 1161 | agree with encoder_inputs and decoder_inputs, and returns a pair 1162 | consisting of outputs and states (as, e.g., basic_rnn_seq2seq). 1163 | softmax_loss_function: Function (labels-batch, inputs-batch) -> loss-batch 1164 | to be used instead of the standard softmax (the default if this is None). 1165 | per_example_loss: Boolean. 
If set, the returned loss will be a batch-sized 1166 | tensor of losses for each sequence in the batch. If unset, it will be 1167 | a scalar with the averaged loss from all examples. 1168 | name: Optional name for this operation, defaults to "model_with_buckets". 1169 | 1170 | Returns: 1171 | A tuple of the form (outputs, losses), where: 1172 | outputs: The outputs for each bucket. Its j'th element consists of a list 1173 | of 2D Tensors. The shape of output tensors can be either 1174 | [batch_size x output_size] or [batch_size x num_decoder_symbols] 1175 | depending on the seq2seq model used. 1176 | losses: List of scalar Tensors, representing losses for each bucket, or, 1177 | if per_example_loss is set, a list of 1D batch-sized float Tensors. 1178 | 1179 | Raises: 1180 | ValueError: If length of encoder_inputs, targets, or weights is smaller 1181 | than the largest (last) bucket. 1182 | """ 1183 | if len(encoder_inputs) < buckets[-1][0]: 1184 | raise ValueError("Length of encoder_inputs (%d) must be at least that of la" 1185 | "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0])) 1186 | if len(targets) < buckets[-1][1]: 1187 | raise ValueError("Length of targets (%d) must be at least that of last " 1188 | "bucket (%d)." % (len(targets), buckets[-1][1])) 1189 | if len(weights) < buckets[-1][1]: 1190 | raise ValueError("Length of weights (%d) must be at least that of last " 1191 | "bucket (%d)." % (len(weights), buckets[-1][1])) 1192 | 1193 | all_inputs = encoder_inputs + decoder_inputs + targets + weights 1194 | losses = [] 1195 | outputs = [] 1196 | with ops.name_scope(name, "model_with_buckets", all_inputs): 1197 | for j, bucket in enumerate(buckets): 1198 | with variable_scope.variable_scope( 1199 | variable_scope.get_variable_scope(), reuse=True if j > 0 else None): 1200 | bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]], 1201 | decoder_inputs[:bucket[1]]) 1202 | outputs.append(bucket_outputs) 1203 | if per_example_loss: 1204 | losses.append( 1205 | sequence_loss_by_example( 1206 | outputs[-1], 1207 | targets[:bucket[1]], 1208 | weights[:bucket[1]], 1209 | softmax_loss_function=softmax_loss_function)) 1210 | else: 1211 | losses.append( 1212 | sequence_loss( 1213 | outputs[-1], 1214 | targets[:bucket[1]], 1215 | weights[:bucket[1]], 1216 | softmax_loss_function=softmax_loss_function)) 1217 | 1218 | return outputs, losses 1219 | --------------------------------------------------------------------------------
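The functions in this reference file are building blocks rather than a ready-made model. The sketch below is a rough, hedged illustration of how they are typically wired together for training: embedding_attention_seq2seq wrapped by model_with_buckets. The vocabulary size, bucket shapes, cell size, placeholder names, and optimizer are illustrative assumptions (not code from this repository), and the snippet assumes this module's imports (core_rnn_cell, etc.) plus a top-level TensorFlow 1.x-era import.

import tensorflow as tf

vocab_size = 10000                 # assumed shared source/target vocabulary size
buckets = [(5, 10), (10, 15)]      # assumed (max encoder length, max decoder length) pairs
embedding_dim = 128                # assumed embedding size
cell = core_rnn_cell.GRUCell(256)  # any core_rnn_cell.RNNCell works here

# The seq2seq functions expect one 1D int32 tensor per time step.
encoder_inputs = [tf.placeholder(tf.int32, [None], name="encoder%d" % i)
                  for i in range(buckets[-1][0])]
decoder_inputs = [tf.placeholder(tf.int32, [None], name="decoder%d" % i)
                  for i in range(buckets[-1][1] + 1)]
target_weights = [tf.placeholder(tf.float32, [None], name="weight%d" % i)
                  for i in range(buckets[-1][1])]
# Targets are the decoder inputs shifted one step to the left.
targets = decoder_inputs[1:]

def seq2seq_f(enc, dec):
  # feed_previous=False gives teacher forcing for training; use True at decode time.
  return embedding_attention_seq2seq(
      enc, dec, cell,
      num_encoder_symbols=vocab_size,
      num_decoder_symbols=vocab_size,
      embedding_size=embedding_dim,
      feed_previous=False)

# One sub-graph per bucket, with variables shared across buckets.
outputs, losses = model_with_buckets(
    encoder_inputs, decoder_inputs, targets, target_weights, buckets, seq2seq_f)

# One training op per bucket; each step feeds a batch padded to one bucket's
# lengths and runs the matching op.
train_ops = [tf.train.GradientDescentOptimizer(0.5).minimize(loss)
             for loss in losses]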