22 |
23 |
--------------------------------------------------------------------------------
/go_example:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NOW=$(date +"%Y%m%d_%H%M")
4 | PY3='stdbuf -o0 nohup python3 -u'
5 | mkdir -p logs
6 |
7 | if [ ! -d ./works/example/nn_models ]; then
8 | wget https://github.com/Marsan-Ma/tf_chatbot_pretrained_model/raw/master/example.tar.gz.part-aa
9 | wget https://github.com/Marsan-Ma/tf_chatbot_pretrained_model/raw/master/example.tar.gz.part-ab
10 | cat example.tar.gz.part-* > example.tar.gz
11 | tar -zxvf example.tar.gz
12 | fi
13 |
14 | #=============================
15 | # Example Twitter Corpus
16 | #=============================
17 | # 1. train model
18 | # $PY3 main.py --mode train --model_name example --vocab_size 100000 --size 128 > "./logs/seq2seq_twitter_en_$NOW.log" &
19 |
20 | # 2. chat interactively in command line
21 | python3 main.py --mode chat --model_name example --vocab_size 100000 --size 128 # --beam_size 5 --antilm 0.7
22 |
23 | # 3. do batch output test
24 | # python3 main.py --mode test --model_name example --vocab_size 100000 --size 128 --beam_size 5 --antilm 0.7
25 |
26 |
--------------------------------------------------------------------------------
/lib/chat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import tensorflow as tf
5 |
6 | from lib import data_utils
7 | from lib.seq2seq_model_utils import create_model, get_predicted_sentence
8 |
9 |
10 | def chat(args):
11 | with tf.Session() as sess:
12 | # Create model and load parameters.
13 | args.batch_size = 1 # We decode one sentence at a time.
14 | model = create_model(sess, args)
15 |
16 | # Load vocabularies.
17 | vocab_path = os.path.join(args.data_dir, "vocab%d.in" % args.vocab_size)
18 | vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
19 |
20 | # Decode from standard input.
21 | sys.stdout.write("> ")
22 | sys.stdout.flush()
23 | sentence = sys.stdin.readline()
24 |
25 | while sentence:
26 | predicted_sentence = get_predicted_sentence(args, sentence, vocab, rev_vocab, model, sess)
27 | # print(predicted_sentence)
28 | if isinstance(predicted_sentence, list):
29 | for sent in predicted_sentence:
30 | print(" (%s) -> %s" % (sent['prob'], sent['dec_inp']))
31 | else:
32 | print(sentence, ' -> ', predicted_sentence)
33 |
34 | sys.stdout.write("> ")
35 | sys.stdout.flush()
36 | sentence = sys.stdin.readline()
37 |
38 |
--------------------------------------------------------------------------------
/lib/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import tensorflow as tf
4 | from datetime import datetime
5 |
6 | from lib import data_utils
7 | from lib.seq2seq_model_utils import create_model, get_predicted_sentence
8 |
9 |
10 | def predict(args, debug=False):
11 | def _get_test_dataset():
12 | with open(args.test_dataset_path) as test_fh:
13 | test_sentences = [s.strip() for s in test_fh.readlines()]
14 | return test_sentences
15 |
16 | results_filename = '_'.join(['results', str(args.num_layers), str(args.size), str(args.vocab_size)])
17 | results_path = os.path.join(args.results_dir, results_filename+'.txt')
18 |
19 | with tf.Session() as sess, open(results_path, 'w') as results_fh:
20 | # Create model and load parameters.
21 | args.batch_size = 1
22 | model = create_model(sess, args)
23 |
24 | # Load vocabularies.
25 | vocab_path = os.path.join(args.data_dir, "vocab%d.in" % args.vocab_size)
26 | vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
27 |
28 | test_dataset = _get_test_dataset()
29 |
30 | for sentence in test_dataset:
31 | # Get token-ids for the input sentence.
32 | predicted_sentence = get_predicted_sentence(args, sentence, vocab, rev_vocab, model, sess, debug=debug)
33 | if isinstance(predicted_sentence, list):
34 | print("%s : (%s)" % (sentence, datetime.now()))
35 | results_fh.write("%s : (%s)\n" % (sentence, datetime.now()))
36 | for sent in predicted_sentence:
37 | print(" (%s) -> %s" % (sent['prob'], sent['dec_inp']))
38 | results_fh.write(" (%f) -> %s\n" % (sent['prob'], sent['dec_inp']))
39 | else:
40 | print(sentence, ' -> ', predicted_sentence)
41 | results_fh.write("%s -> %s\n" % (sentence, predicted_sentence))
42 | # break
43 |
44 | results_fh.close()
45 | print("results written in %s" % results_path)
46 |
--------------------------------------------------------------------------------
/templates/privacy.html:
--------------------------------------------------------------------------------
Privacy Policy

Scope
The following privacy protection policy applies to the collection, use, and protection of the personal data involved when you use this event website.

Data Collection
Depending on the services this website provides, the following basic personal data may be collected from users:
- Member data: when you start participating in this website's activities, we will ask for your consent to access the personal data you registered on facebook, including your name, e-mail address, and other personal information.
- General browsing: this event website keeps the records (logs) automatically generated by the server while users browse or search, including the connecting device's IP address, time of use, browser used, and browsing and click records, in order to analyze which pages users visit on this event website and for how long, so as to improve the website's quality of service.

Sharing, Disclosure, and Use of User Data
This event website will not arbitrarily sell, exchange, or rent any of your personal data to other groups or individuals. Only in the following situations will this event website use your personal data, under the principles of this Privacy Policy:
- Statistics and analysis: this event website compiles and analyzes user account data, voting results, and server log files to produce the results of this event's polls. It does not analyze individual users, nor does it provide analysis reports on specific individuals to any party.
- When judicial authorities request, for reasons of public safety, that this event website disclose specific personal data, this event website will cooperate as necessary, subject to the legality of the judicial authorities' formal procedures and with the safety of all users of this site in mind. Except as needed for legal proceedings, government review, judicial investigation, crime prevention, action against illegal activities, suspected fraud, or incidents threatening anyone's personal safety, all data remains confidential.

Privacy Policy Revisions
This event website will revise this policy from time to time to comply with the latest privacy protection regulations. When we make major changes to how personal data is used, we will post a notice on the website to inform you of the relevant revisions.
--------------------------------------------------------------------------------
/lib/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def params_setup(cmdline=None):
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--mode', type=str, required=True, help='work mode: train/test/chat')
6 |
7 | # path ctrl
8 | parser.add_argument('--model_name', type=str, default='movie_subtitles_en', help='model name, affects data, model, result save path')
9 | parser.add_argument('--scope_name', type=str, help='separate namespace, for multi-models working together')
10 | parser.add_argument('--work_root', type=str, default='works', help='root dir for data, model, result save path')
11 |
12 | # training params
13 | parser.add_argument('--learning_rate', type=float, default=0.5, help='Learning rate.')
14 | parser.add_argument('--learning_rate_decay_factor', type=float, default=0.99, help='Learning rate decays by this much.')
15 | parser.add_argument('--max_gradient_norm', type=float, default=5.0, help='Clip gradients to this norm.')
16 | parser.add_argument('--batch_size', type=int, default=64, help='Batch size to use during training.')
17 |
18 | parser.add_argument('--vocab_size', type=int, default=100000, help='Dialog vocabulary size.')
19 | parser.add_argument('--size', type=int, default=256, help='Size of each model layer.')
20 | parser.add_argument('--num_layers', type=int, default=4, help='Number of layers in the model.')
21 |
22 | parser.add_argument('--max_train_data_size', type=int, default=0, help='Limit on the size of training data (0: no limit)')
23 | parser.add_argument('--steps_per_checkpoint', type=int, default=500, help='How many training steps to do per checkpoint')
24 |
25 | # predicting params
26 | parser.add_argument('--beam_size', type=int, default=1, help='beam search size')
27 | parser.add_argument('--antilm', type=float, default=0, help='anti-language model weight')
28 | parser.add_argument('--n_bonus', type=float, default=0, help='bonus with sentence length')
29 |
30 | # environment params
31 | parser.add_argument('--gpu_usage', type=float, default=1.0, help='tensorflow gpu memory fraction used')
32 | parser.add_argument('--rev_model', type=int, default=0, help='reverse Q-A pair, for bi-direction model')
33 | parser.add_argument('--reinforce_learn', type=int, default=0, help='1 to enable reinforcement learning mode')
34 | parser.add_argument('--en_tfboard', type=int, default=0, help='Enable writing out tensorboard meta data')
35 |
36 |
37 | if cmdline:
38 | args = parser.parse_args(cmdline)
39 | else:
40 | args = parser.parse_args()
41 |
42 | if not args.scope_name: args.scope_name = args.model_name
43 | if args.rev_model: args.model_name += '_bidi' # bi-direction model
44 |
45 | # We use a number of buckets and pad to the closest one for efficiency.
46 | # See seq2seq_model.Seq2SeqModel for details of how they work.
47 | args.buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
48 |
49 | # post-process
50 | args.workspace = '%s/%s' % (args.work_root, args.model_name)
51 | args.test_dataset_path = '%s/data/test/test_set.txt' % (args.workspace)
52 | args.mert_dataset_path = '%s/data/test/mert_set.txt' % (args.workspace)
53 | args.data_dir = '%s/data' % args.workspace
54 | args.model_dir = '%s/nn_models' % args.workspace
55 | args.results_dir = '%s/results' % args.workspace
56 | args.tf_board_dir = '%s/tf_board' % args.workspace
57 | return args
58 |
59 |
--------------------------------------------------------------------------------
/example_chat.md:
--------------------------------------------------------------------------------
1 | ## 2017.03.14
2 |
3 | This is a chat history for the demo model of [tf_chatbot_seq2seq_antilm](https://github.com/Marsan-Ma/tf_chatbot_seq2seq_antilm).
4 |
5 | You could use `./go_example` to get exactly the same model and the same results; you could also try `beam_search` and `antiLM` by uncommenting the options in `go_example`.
6 |
7 |
8 | > Tell me about Trump.
9 | (1) -> sorry funny for that though .
10 | > How about Hillary
11 | (1) -> better check up .
12 | > check up what?
13 | (1) -> where ? are you at ?
14 | > why you ask me where am I?
15 | (1) -> welcome to the gay party ! hugs to you anytime ! ! !
16 | > no, I am not gay...
17 | (1) -> you are the sassy part of the country
18 | > thank you!
19 | (1) -> you're welcome !
20 | > how about order some pizza for the lunch?
21 | (1) -> second pink pink league
22 | > what do you mean pink league?
23 | (1) -> do you think you're talking to me ? ? ?
24 | > are you argry?
25 | (1) -> yeah yes , thanks 😉
26 | > sorry for being rude.
27 | (1) -> what worry college use a twitter ?
28 | > which college do you mean?
29 | (1) -> i agree than i am , never understand about my point
30 | > yes I never understand your point, too...
31 | (1) -> who were that
32 | > who?
33 | (1) -> peach you on
34 | > who is peach?
35 | (1) -> playing jason
36 | > who is jason?
37 | (1) -> explosion for the hotel recorded .
38 | > ok, bye.
39 | (1) -> magic good show later scared
40 | > you idiot.
41 | (1) -> really .
42 | > so you think you know nothing?
43 | (1) -> i know yes
44 | > that's a good point...
45 | (1) -> i want these car keep brought up and getting regular
46 | > well, it's nice to chat with you.
47 | (1) -> i have no idea
48 | > why?
49 | (1) -> i feel from the level of anyone .
50 | > I am an engineer.
51 | (1) -> too babe ! good fries
52 | > yes, engineers love fries, and cokes.
53 | (1) -> i have the full of dreams in la .
54 | > I've been in LA last month
55 | (1) -> bake just watch - and watch the whole phone out
56 | > I can't understand you.
57 | (1) -> i have to research with you
58 | > That's great idea...
59 | (1) -> islam is not the truth
60 | > don't comment on islam...
61 | (1) -> haha ? sec of the jets are related .
62 | > well at least you are not always answering i don't know.
63 | (1) -> i wish i should be scared too
64 | > you never scare someone.
65 | (1) -> i thought that was
66 | > that was ... what?
67 | (1) -> and again you didn't enjoy em
68 | > Well, I enjoy talking with you.
69 | (1) -> don't even bring shot for the show ?
70 | > no I won't shot you.
71 | (1) -> im too expensive with that
72 | > yes, you cost a lot, lol.
73 | (1) -> just then i saw it .
74 | > an AWS g2 instance is not cheap at all
75 | (1) -> honestly rose can’t use law mode without a kind of eyes
76 | > that's a anology hard to understand...
77 | (1) -> why are you there ? ? ?
78 | > because i am talking with you.
79 | (1) -> i actually want to read that .
80 | > read about what? more training corpus?
81 | (1) -> by the android of a nuclear first .
82 | > why android got to do with nuclear?
83 | (1) -> unlike gary gold , you expect ? you are the muslim .
84 | > no i am not ...
85 | (1) -> i can't believe twitter
86 | > but i use twitter corpus to train your model ...
87 | (1) -> rocking hair ass at my house
88 | > you high?
89 | (1) -> nope , i plan to see it .
90 | > you plan to see rocking hair ass ... whatever.
91 | (1) -> i still have a chance but i was coming rn
92 | > yeah you have many chances, i believe you are promising.
93 | (1) -> i heard of this as well .
94 | > yeah, that's cool.
95 | (1) -> you will be a team this year .
96 | > that's cool, i've tired of work alone.
97 | (1) -> who really
98 | > ok, seems you are good enough for publish this time.
99 | (1) -> yeah , thanks ! all more him with letting it know something better .
100 | > whatever you say ... being a language model you are good.
101 | (1) -> congrats , - so many need to have a better place : )
102 | > I'll find some way to make your response more consistent.
103 | (1) -> i want to buy the experience , and i wear a phone button .
104 | > yeah you did need much experience by some other learning ways!
105 | (1) -> with new york ?
106 | > with new network!
107 | (1) -> and i have a great weekend , bruce :
108 | > no I am not bruce ...
109 | (1) -> i can't believe i was sorry for him
110 | > you mean bruce lee?
111 | (1) -> was the part of the new york .
112 | > no york is absense today.
113 | (1) -> i will have been watching a season lately on !
114 | > see you!
115 | (1) -> thanks buddy ! you rock 😘
116 | > thanks :)
117 | (1) -> high night
118 |
119 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os, json, yaml, requests, jieba
3 | from datetime import datetime
4 | from flask import Flask, request, render_template
5 | from OpenSSL import SSL
6 | from random import random, choice
7 |
8 | app = Flask(__name__)
9 |
10 | #---------------------------
11 | # Load Model
12 | #---------------------------
13 | import tensorflow as tf
14 | from lib import data_utils
15 | from lib.config import params_setup
16 | from lib.seq2seq_model_utils import create_model, get_predicted_sentence
17 |
18 |
19 | class ChatBot(object):
20 |
21 | def __init__(self, args, debug=False):
22 | start_time = datetime.now()
23 |
24 | # flow ctrl
25 | self.args = args
26 | self.debug = debug
27 | self.fbm_processed = []
28 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_usage)
29 | self.sess = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=gpu_options))
30 |
31 | # Create model and load parameters.
32 | self.args.batch_size = 1 # We decode one sentence at a time.
33 | self.model = create_model(self.sess, self.args)
34 |
35 | # Load vocabularies.
36 | self.vocab_path = os.path.join(self.args.data_dir, "vocab%d.in" % self.args.vocab_size)
37 | self.vocab, self.rev_vocab = data_utils.initialize_vocabulary(self.vocab_path)
38 | print("[ChatBot] model initialize, cost %i secs" % (datetime.now() - start_time).seconds)
39 |
40 | # load yaml setup
41 | self.FBM_API = "https://graph.facebook.com/v2.6/me/messages"
42 | with open("config.yaml", 'rt') as stream:
43 | try:
44 | cfg = yaml.load(stream)
45 | self.FACEBOOK_TOKEN = cfg.get('FACEBOOK_TOKEN')
46 | self.VERIFY_TOKEN = cfg.get('VERIFY_TOKEN')
47 | except yaml.YAMLError as exc:
48 | print(exc)
49 |
50 |
51 | def process_fbm(self, payload):
52 | for sender, msg in self.fbm_events(payload):
53 | self.fbm_api({"recipient": {"id": sender}, "sender_action": 'typing_on'})
54 | resp = self.gen_response(msg)
55 | self.fbm_api({"recipient": {"id": sender}, "message": {"text": resp}})
56 | if self.debug: print("%s: %s => resp: %s" % (sender, msg, resp))
57 |
58 |
59 | def gen_response(self, sent, max_cand=100):
60 | sent = " ".join([w.lower() for w in jieba.cut(sent) if w not in [' ']])
61 | # if self.debug: return sent
62 | raw = get_predicted_sentence(self.args, sent, self.vocab, self.rev_vocab, self.model, self.sess, debug=False)
63 | # find bests candidates
64 | cands = sorted(raw, key=lambda v: v['prob'], reverse=True)[:max_cand]
65 |
66 | if max_cand == -1: # return all cands for debug
67 | cands = [(r['prob'], ' '.join([w for w in r['dec_inp'].split() if w[0] != '_'])) for r in cands]
68 | return cands
69 | else:
70 | cands = [[w for w in r['dec_inp'].split() if w[0] != '_'] for r in cands]
71 | return ' '.join(choice(cands)) or 'No comment'
72 |
73 |
74 | def gen_response_debug(self, sent, args=None):
75 | sent = " ".join([w.lower() for w in jieba.cut(sent) if w not in [' ']])
76 | raw = get_predicted_sentence(args, sent, self.vocab, self.rev_vocab, self.model, self.sess, debug=False, return_raw=True)
77 | return raw
78 |
79 |
80 | #------------------------------
81 | # FB Messenger API
82 | #------------------------------
83 | def fbm_events(self, payload):
84 | data = json.loads(payload.decode('utf8'))
85 | if self.debug: print("[fbm_payload]", data)
86 | for event in data["entry"][0]["messaging"]:
87 | if "message" in event and "text" in event["message"]:
88 | q = (event["sender"]["id"], event["message"]["seq"])
89 | if q in self.fbm_processed:
90 | continue
91 | else:
92 | self.fbm_processed.append(q)
93 | yield event["sender"]["id"], event["message"]["text"]
94 |
95 |
96 | def fbm_api(self, data):
97 | r = requests.post(self.FBM_API,
98 | params={"access_token": self.FACEBOOK_TOKEN},
99 | data=json.dumps(data),
100 | headers={'Content-type': 'application/json'})
101 | if r.status_code != requests.codes.ok:
102 | print("fb error:", r.text)
103 | if self.debug: print("fbm_send", r.status_code, r.text)
104 |
105 |
106 | #---------------------------
107 | # Server
108 | #---------------------------
109 | @app.route('/chat', methods=['GET'])
110 | def verify():
111 | if request.args.get('hub.verify_token', '') == chatbot.VERIFY_TOKEN:
112 | return request.args.get('hub.challenge', '')
113 | else:
114 | return 'Error, wrong validation token'
115 |
116 | @app.route('/chat', methods=['POST'])
117 | def chat():
118 | payload = request.get_data()
119 | chatbot.process_fbm(payload)
120 | return "ok"
121 |
122 |
123 | @app.route('/', methods=['GET'])
124 | def home():
125 | return render_template('index.html')
126 |
127 |
128 | @app.route('/privacy', methods=['GET'])
129 | def privacy():
130 | return render_template('privacy.html')
131 |
132 |
133 | #---------------------------
134 | # Start Server
135 | #---------------------------
136 | if __name__ == '__main__':
137 | # check ssl files
138 | if not os.path.exists('ssl/server.crt'):
139 | print("SSL certificate not found! (should placed in ./ssl/server.crt)")
140 | elif not os.path.exists('ssl/server.key'):
141 | print("SSL key not found! (should placed in ./ssl/server.key)")
142 | else:
143 | # initialize model
144 | args = params_setup()
145 | chatbot = ChatBot(args, debug=False)
146 | # start server
147 | context = ('ssl/server.crt', 'ssl/server.key')
148 | app.run(host='0.0.0.0', port=443, debug=False, ssl_context=context)
149 |
150 |
--------------------------------------------------------------------------------
/lib/train.py:
--------------------------------------------------------------------------------
1 | import sys, os, math, time, argparse, shutil, gzip
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6 |
7 | from six.moves import xrange # pylint: disable=redefined-builtin
8 | from datetime import datetime
9 | from lib import seq2seq_model_utils, data_utils
10 |
11 |
12 | def setup_workpath(workspace):
13 | for p in ['data', 'nn_models', 'results']:
14 | wp = "%s/%s" % (workspace, p)
15 | if not os.path.exists(wp): os.mkdir(wp)
16 |
17 | data_dir = "%s/data" % (workspace)
18 | # training data
19 | if not os.path.exists("%s/chat.in" % data_dir):
20 | n = 0
21 | f_zip = gzip.open("%s/train/chat.txt.gz" % data_dir, 'rt')
22 | f_train = open("%s/chat.in" % data_dir, 'w')
23 | f_dev = open("%s/chat_test.in" % data_dir, 'w')
24 | for line in f_zip:
25 | f_train.write(line)
26 | if n < 10000:
27 | f_dev.write(line)
28 | n += 1
29 |
30 |
31 | def train(args):
32 | print("[%s] Preparing dialog data in %s" % (args.model_name, args.data_dir))
33 | setup_workpath(workspace=args.workspace)
34 | train_data, dev_data, _ = data_utils.prepare_dialog_data(args.data_dir, args.vocab_size)
35 |
36 | if args.reinforce_learn:
37 | args.batch_size = 1 # We decode one sentence at a time.
38 |
39 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_usage)
40 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
41 |
42 | # Create model.
43 | print("Creating %d layers of %d units." % (args.num_layers, args.size))
44 | model = seq2seq_model_utils.create_model(sess, args, forward_only=False)
45 |
46 | # Read data into buckets and compute their sizes.
47 | print("Reading development and training data (limit: %d)." % args.max_train_data_size)
48 | dev_set = data_utils.read_data(dev_data, args.buckets, reversed=args.rev_model)
49 | train_set = data_utils.read_data(train_data, args.buckets, args.max_train_data_size, reversed=args.rev_model)
50 | train_bucket_sizes = [len(train_set[b]) for b in xrange(len(args.buckets))]
51 | train_total_size = float(sum(train_bucket_sizes))
52 |
53 | # A bucket scale is a list of increasing numbers from 0 to 1 that we'll use
54 | # to select a bucket. Length of [scale[i], scale[i+1]] is proportional to
55 | # the size of the i-th training bucket, as used later.
56 | train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size
57 | for i in xrange(len(train_bucket_sizes))]
58 |
59 | # This is the training loop.
60 | step_time, loss = 0.0, 0.0
61 | current_step = 0
62 | previous_losses = []
63 |
64 | # Load vocabularies.
65 | vocab_path = os.path.join(args.data_dir, "vocab%d.in" % args.vocab_size)
66 | vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
67 |
68 | while True:
69 | # Choose a bucket according to data distribution. We pick a random number
70 | # in [0, 1] and use the corresponding interval in train_buckets_scale.
71 | random_number_01 = np.random.random_sample()
72 | bucket_id = min([i for i in xrange(len(train_buckets_scale))
73 | if train_buckets_scale[i] > random_number_01])
74 |
75 | # Get a batch and make a step.
76 | start_time = time.time()
77 | encoder_inputs, decoder_inputs, target_weights = model.get_batch(
78 | train_set, bucket_id)
79 |
80 | # print("[shape]", np.shape(encoder_inputs), np.shape(decoder_inputs), np.shape(target_weights))
81 | if args.reinforce_learn:
82 | _, step_loss, _ = model.step_rf(args, sess, encoder_inputs, decoder_inputs,
83 | target_weights, bucket_id, rev_vocab=rev_vocab)
84 | else:
85 | _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
86 | target_weights, bucket_id, forward_only=False, force_dec_input=True)
87 |
88 | step_time += (time.time() - start_time) / args.steps_per_checkpoint
89 | loss += step_loss / args.steps_per_checkpoint
90 | current_step += 1
91 |
92 | # Once in a while, we save checkpoint, print statistics, and run evals.
93 | if (current_step % args.steps_per_checkpoint == 0) and (not args.reinforce_learn):
94 | # Print statistics for the previous epoch.
95 | perplexity = math.exp(loss) if loss < 300 else float('inf')
96 | print ("global step %d learning rate %.4f step-time %.2f perplexity %.2f @ %s" %
97 | (model.global_step.eval(), model.learning_rate.eval(), step_time, perplexity, datetime.now()))
98 |
99 | # Decrease learning rate if no improvement was seen over last 3 times.
100 | if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
101 | sess.run(model.learning_rate_decay_op)
102 |
103 | previous_losses.append(loss)
104 |
105 | # # Save checkpoint and zero timer and loss.
106 | checkpoint_path = os.path.join(args.model_dir, "model.ckpt")
107 | model.saver.save(sess, checkpoint_path, global_step=model.global_step)
108 | step_time, loss = 0.0, 0.0
109 |
110 | # Run evals on development set and print their perplexity.
111 | for bucket_id in xrange(len(args.buckets)):
112 | encoder_inputs, decoder_inputs, target_weights = model.get_batch(dev_set, bucket_id)
113 | _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
114 | target_weights, bucket_id, forward_only=True, force_dec_input=False)
115 |
116 | eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
117 | print(" eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx))
118 |
119 | sys.stdout.flush()
120 |
--------------------------------------------------------------------------------
/README2.md:
--------------------------------------------------------------------------------
1 | # A more detailed explanation of "the tensorflow chatbot"
2 |
3 | Here I'll try to explain some of the algorithms and implementation details of [this work][a1] in layman's terms.
4 |
5 | [a1]: https://github.com/Marsan-Ma/tf_chatbot_seq2seq_antilm
6 |
7 |
8 | ## Sequence to sequence model
9 |
10 | ### What is a language model?
11 |
12 | Let's say a language model is ...
13 | a) Trained on a large corpus.
14 | b) Able to predict the **probability of the next word** given the foregoing words.
15 | => It's just the conditional probability **P(next_word | foregoing_words)**.
16 | c) Since we could predict the next word:
17 | => we could then predict the word after that, according to the words just generated
18 | => and continuing like this, we could produce sentences, even paragraphs.
19 |
20 | We could easily achieve this with a simple [LSTM model][b1]; a minimal sketch of the generation loop in point (c) follows.
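Here is a hypothetical sketch (not code from this repo): given any function `next_word_prob` that returns **P(next_word | foregoing_words)**, e.g. backed by an LSTM, we can grow a sentence word by word.

```python
# Minimal sketch of generation with a language model. `next_word_prob` is a
# hypothetical hook: it takes the words so far and returns {word: probability}.
def generate(next_word_prob, seed_words, max_len=20, eos="_EOS"):
    words = list(seed_words)
    for _ in range(max_len):
        probs = next_word_prob(words)        # P(next_word | foregoing_words)
        word = max(probs, key=probs.get)     # greedily take the most probable word
        if word == eos:
            break
        words.append(word)
    return " ".join(words)
```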
21 |
22 |
23 | ### The seq2seq model architecture
24 |
25 | Again, we quote this seq2seq architecture from [Google's blogpost][b3]:
26 | [![seq2seq][b2]][b3]
27 |
28 | It's composed of two language models: an encoder and a decoder. Both could be the LSTM model we just mentioned.
29 |
30 | The encoder part accepts the input tokens and transforms the whole input sentence into an embedded **"thought vector"**, which expresses the meaning of the input sentence in our language model's domain.
31 |
32 | Then the decoder is just a language model; as we just said, a language model could generate a new sentence according to the foregoing text. Here we use this **"thought vector"** as the kick-off, receive the corresponding mapping, and decode it into the response.
33 |
34 |
35 | ### Reversed encoder input and Attention mechanism
36 |
37 | Now you might wonder:
38 | a) Considering this architecture, will the "thought vector" be dominated by the later stages of the encoder?
39 | b) Is a single vector enough to represent the meaning of the whole input sentence?
40 |
41 |
42 | For (a), there is actually an implementation detail we didn't mention before: the input sentence is reversed before being fed to the encoder. Thus we shorten the distance between the head of the input sentence and the head of the response sentence. Empirically, it achieves better results. (This trick is not shown in the architecture figure above, for ease of understanding.)
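As a tiny illustration of that trick (hypothetical token ids, not this repo's exact code), the reversal is just:

```python
# Reverse the input token ids before feeding the encoder, so the beginning of the
# input sentence ends up closest to the beginning of the response.
input_token_ids = [34, 7, 913, 5]                   # e.g. "how are you ?" as made-up ids
encoder_inputs = list(reversed(input_token_ids))    # -> [5, 913, 7, 34]
```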
43 |
44 | For (b), another method to expose more information to the decoder is the [attention mechanism][b4]. The idea is simple: allow each decoder stage to peek at any encoder stage it finds useful during the training phase. So the decoder could understand the input sentence better and automagically peek at the suitable positions while generating the response.
45 |
46 |
47 |
48 | [b1]: http://colah.github.io/posts/2015-08-Understanding-LSTMs
49 | [b2]: http://4.bp.blogspot.com/-aArS0l1pjHQ/Vjj71pKAaEI/AAAAAAAAAxE/Nvy1FSbD_Vs/s640/2TFstaticgraphic_alt-01.png
50 | [b3]: http://googleresearch.blogspot.ru/2015/11/computer-respond-to-this-email.html
51 | [b4]: http://arxiv.org/abs/1412.7449
52 |
53 |
54 |
55 | ## Techniques about language model
56 |
57 | ### Dictionary space compressing and projection
58 |
59 | A naive implementation of a language model is: suppose we are training an English language model, for which a dictionary size of 80,000 is roughly enough. As we one-hot code each word in our dictionary, our LSTM cell would need 80,000 outputs, and we would do a softmax over them to choose the word with the best probability...
60 |
61 | ... even if you have lots of computing resources, you don't need to waste them like that. Especially if you are dealing with languages that have more words, like Chinese, for which 200,000 words is barely enough.
62 |
63 | Practically, we could reduce this 80,000-way one-hot coding into an embedding space: we could use, say, 64, 128 or 256 dimensions to embed our 80,000-word dictionary, and train our model only in this lower dimension. Then, finally, when we are generating the response, we project the embedding back into the one-hot coding space for dictionary lookup.
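Here is a toy numpy sketch of the idea (not the repo's actual implementation; the same embedding matrix is reused for the output projection only to keep the sketch short):

```python
import numpy as np

# Toy sketch: work in a small embedding space, and project back to the full
# vocabulary only when we need to look up the output word.
vocab_size, embed_dim = 80000, 128
embedding = np.random.randn(vocab_size, embed_dim) * 0.01   # id -> 128-d vector

word_id = 123
x = embedding[word_id]            # the 80,000-way one-hot collapses to 128 dims
h = np.tanh(x)                    # stand-in for the LSTM's 128-d hidden state
logits = embedding @ h            # project back to 80,000 scores for dictionary lookup
predicted_id = int(np.argmax(logits))
```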
64 |
65 |
66 | ### Beam search
67 |
68 | The original tensorflow implementation decodes the response sentence greedily. Empirically this traps the result in a local optimum and yields dumb responses that merely maximize the probability of the first couple of words.
69 |
70 | So we do a beam search: keep the best N candidates at each step and move forward. Thus we could avoid local optima and find longer, more interesting responses closer to the global optimum. A minimal sketch of the idea follows.
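This is a simplified, self-contained version of what `get_predicted_sentence` in `lib/seq2seq_model_utils.py` does; `step_probs` is a hypothetical hook that returns the next-word distribution for a given prefix.

```python
import heapq
import math

def beam_search(step_probs, beam_size=5, max_len=20, eos="_EOS"):
    """Keep the beam_size most probable partial sentences at every decoding step."""
    beams = [(0.0, [])]                              # (negative log-prob, words so far)
    for _ in range(max_len):
        candidates = []
        for neg_logp, words in beams:
            if words and words[-1] == eos:           # finished candidates carry over
                candidates.append((neg_logp, words))
                continue
            for word, p in step_probs(words).items():
                candidates.append((neg_logp - math.log(p), words + [word]))
        beams = heapq.nsmallest(beam_size, candidates)   # lowest neg-log-prob = best
    return [" ".join(words) for _, words in beams]
```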
71 |
72 | In [this paper][b4], the Google Brain team found that beam search didn't benefit machine translation much; I guess that's why they didn't implement it. But in my experience, chatbots do benefit a lot from beam search.
73 |
74 |
75 | ## Anti-Language Model
76 |
77 | ### Generic response problem
78 |
79 | As the seq2seq model is trained by [MLE][c1] (maximum likelihood estimation), the model follows this objective function well by finding the "most probable" response. But in human dialogue, a high-probability response like "thank you", "I don't know", "I love you" is not informative at all.
80 |
81 | Since we haven't yet found a good enough objective function to replace MLE, there are some works that try to suppress this "generic response problem".
82 |
83 |
84 | ### Suppressing generic response
85 |
86 | The work of [Li et al.][c2] from Stanford and Microsoft Research tries to suppress generic responses by lowering their probability among the candidates while doing the beam search.
87 |
88 | The idea is somewhat like Tf-Idf: if a response is suitable after all kinds of foregoing sentences, it is not a specific answer to the current sentence, so we discard it.
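In this repo the adjustment happens during beam search in `lib/seq2seq_model_utils.py`: the decoder is run a second time with dummy (all-PAD) encoder inputs to estimate how probable each word is regardless of the input, and that "language-model-only" probability is subtracted. A small sketch of just that step:

```python
import numpy as np

def antilm_adjust(prob_ts, prob_t, antilm=0.7):
    """prob_ts: P(word | input, prefix) from the seq2seq step with real encoder inputs.
    prob_t : P(word | prefix) from the same step with dummy PAD encoder inputs.
    antilm : the --antilm weight; 0 disables the penalty."""
    return np.asarray(prob_ts) - antilm * np.asarray(prob_t)

# A word that is likely no matter what the input was (a generic response) gets
# pushed down, so the input-specific word now wins the ranking:
print(antilm_adjust([0.30, 0.25], [0.28, 0.02]))   # -> [0.104 0.236]
```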
89 |
90 | According to my own experimental results, this helps a lot! The cost is that we sometimes choose something grammatically not so correct, but most of the time the effect is acceptable. It does generate more interesting, informative responses.
91 |
92 |
93 | [c1]: https://en.wikipedia.org/wiki/Maximum_likelihood_estimation
94 | [c2]: https://arxiv.org/abs/1606.01541
95 |
96 |
97 |
98 | ## Deep Reinforcement Learning
99 |
100 | ### Reinforcement learning
101 |
102 | Reinforcement learning is a promising domain now (in 2016). It's promising because it solves the delayed reward problem, and that's a huge merit for chatbot training, since we could judge a continuous dialogue including several sentences, rather than one single sentence at a time. We could design more sophisticated metrics to reward the model and make it learn more abstract ideas.
103 |
104 |
105 | ### Implement tricks in tensorflow
106 |
107 | The magic of tensorflow is that it constructs a graph, in which all the computation could be dispatched automagically to a CPU, a GPU, or even a distributed system (more CPUs/GPUs).
108 |
109 | So far tensorflow has no native operations supporting delayed rewards, so we have to use a work-around: we calculate the gradients in-graph, accumulate and post-process them out-of-graph, and finally inject them back to do the `apply_gradients()`. You could find a minimal example in [this ipython notebook][d1].
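A minimal, hypothetical sketch of that work-around (TF 1.x style; this is not the repo's actual `step_rf()` code, just the shape of the trick):

```python
import numpy as np
import tensorflow as tf

# Tiny example: compute gradients in-graph, scale them out-of-graph once the
# delayed reward is known, then feed them back through placeholders.
x = tf.placeholder(tf.float32, [None, 4])
y = tf.placeholder(tf.float32, [None, 2])
W = tf.Variable(tf.zeros([4, 2]))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=tf.matmul(x, W)))

opt = tf.train.GradientDescentOptimizer(0.1)
grads_and_vars = opt.compute_gradients(loss)                 # gradients defined in-graph
grad_holders = [tf.placeholder(tf.float32, g.get_shape()) for g, _ in grads_and_vars]
apply_op = opt.apply_gradients(list(zip(grad_holders, [v for _, v in grads_and_vars])))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch_x = np.random.randn(8, 4)
    batch_y = np.eye(2)[np.random.randint(2, size=8)]
    grads = sess.run([g for g, _ in grads_and_vars], {x: batch_x, y: batch_y})
    reward = 1.5                              # delayed reward, only known after the episode
    scaled = [reward * g for g in grads]      # out-of-graph accumulation / post-processing
    sess.run(apply_op, feed_dict=dict(zip(grad_holders, scaled)))
```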
110 |
111 |
112 | [d1]: https://github.com/awjuliani/DeepRL-Agents/blob/master/Policy-Network.ipynb
113 |
114 |
115 |
116 |
117 |
118 |
119 |
--------------------------------------------------------------------------------
/lib/seq2seq_model_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from random import random
8 | from datetime import datetime
9 | from tensorflow.python.platform import gfile
10 |
11 | from lib import data_utils
12 | from lib import seq2seq_model
13 |
14 |
15 | import heapq
from nltk.translate.bleu_score import sentence_bleu  # needed by cal_bleu() below; assumes nltk is installed
16 |
17 | def create_model(session, args, forward_only=True):
18 | """Create translation model and initialize or load parameters in session."""
19 | model = seq2seq_model.Seq2SeqModel(
20 | source_vocab_size=args.vocab_size,
21 | target_vocab_size=args.vocab_size,
22 | buckets=args.buckets,
23 | size=args.size,
24 | num_layers=args.num_layers,
25 | max_gradient_norm=args.max_gradient_norm,
26 | batch_size=args.batch_size,
27 | learning_rate=args.learning_rate,
28 | learning_rate_decay_factor=args.learning_rate_decay_factor,
29 | forward_only=forward_only,
30 | )
31 |
32 | # for tensorboard
33 | if args.en_tfboard:
34 | summary_writer = tf.train.SummaryWriter(args.tf_board_dir, session.graph)
35 |
36 | ckpt = tf.train.get_checkpoint_state(args.model_dir)
37 | # if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
38 | if ckpt and ckpt.model_checkpoint_path:
39 | print("Reading model parameters from %s @ %s" % (ckpt.model_checkpoint_path, datetime.now()))
40 | model.saver.restore(session, ckpt.model_checkpoint_path)
41 | print("Model reloaded @ %s" % (datetime.now()))
42 | else:
43 | print("Created model with fresh parameters.")
44 | session.run(tf.global_variables_initializer())
45 | return model
46 |
47 |
48 | def dict_lookup(rev_vocab, out):
49 | word = rev_vocab[out] if (out < len(rev_vocab)) else data_utils._UNK
50 | if isinstance(word, bytes):
51 | word = word.decode()
52 | return word
53 |
54 |
55 |
56 | def softmax(x):
57 | prob = np.exp(x) / np.sum(np.exp(x), axis=0)
58 | return prob
59 |
60 |
61 |
62 | def cal_bleu(cands, ref, stopwords=['的', '嗎']):
63 | cands = [s['dec_inp'].split() for s in cands]
64 | cands = [[w for w in sent if w[0] != '_'] for sent in cands]
65 | refs = [w for w in ref.split() if w not in stopwords]
66 | bleus = []
67 | for cand in cands:
68 | if len(cand) < 4: cand += [''] * (4 - len(cand))
69 | bleu = sentence_bleu(refs, cand)
70 | bleus.append(bleu)
71 | print(refs, cand, bleu)
72 | return np.average(bleus)
73 |
74 |
75 |
76 | def get_predicted_sentence(args, input_sentence, vocab, rev_vocab, model, sess, debug=False, return_raw=False):
77 | def model_step(enc_inp, dec_inp, dptr, target_weights, bucket_id):
78 | _, _, logits = model.step(sess, enc_inp, dec_inp, target_weights, bucket_id, forward_only=True)
79 | prob = softmax(logits[dptr][0])
80 | # print("model_step @ %s" % (datetime.now()))
81 | return prob
82 |
83 | def greedy_dec(output_logits, rev_vocab):
84 | selected_token_ids = [int(np.argmax(logit, axis=1)) for logit in output_logits]
85 | if data_utils.EOS_ID in selected_token_ids:
86 | eos = selected_token_ids.index(data_utils.EOS_ID)
87 | selected_token_ids = selected_token_ids[:eos]
88 | output_sentence = ' '.join([dict_lookup(rev_vocab, t) for t in selected_token_ids])
89 | return output_sentence
90 |
91 | input_token_ids = data_utils.sentence_to_token_ids(input_sentence, vocab)
92 |
93 | # Which bucket does it belong to?
94 | bucket_id = min([b for b in range(len(args.buckets)) if args.buckets[b][0] > len(input_token_ids)])
95 | outputs = []
96 | feed_data = {bucket_id: [(input_token_ids, outputs)]}
97 |
98 | # Get a 1-element batch to feed the sentence to the model.
99 | encoder_inputs, decoder_inputs, target_weights = model.get_batch(feed_data, bucket_id)
100 | if debug: print("\n[get_batch]\n", encoder_inputs, decoder_inputs, target_weights)
101 |
102 | ### Original greedy decoding
103 | if args.beam_size == 1:
104 | _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, forward_only=True)
105 | return [{"dec_inp": greedy_dec(output_logits, rev_vocab), 'prob': 1}]
106 |
107 | # Get output logits for the sentence.
108 | beams, new_beams, results = [(1, 0, {'eos': 0, 'dec_inp': decoder_inputs, 'prob': 1, 'prob_ts': 1, 'prob_t': 1})], [], [] # initialize beams as (log_prob, empty_string, eos)
109 | dummy_encoder_inputs = [np.array([data_utils.PAD_ID]) for _ in range(len(encoder_inputs))]
110 |
111 | for dptr in range(len(decoder_inputs)-1):
112 | if dptr > 0:
113 | target_weights[dptr] = [1.]
114 | beams, new_beams = new_beams[:args.beam_size], []
115 | if debug: print("=====[beams]=====", beams)
116 | heapq.heapify(beams) # since we will remove something
117 | for prob, _, cand in beams:
118 | if cand['eos']:
119 | results += [(prob, 0, cand)]
120 | continue
121 |
122 | # normal seq2seq
123 | if debug: print(cand['prob'], " ".join([dict_lookup(rev_vocab, w) for w in cand['dec_inp']]))
124 |
125 | all_prob_ts = model_step(encoder_inputs, cand['dec_inp'], dptr, target_weights, bucket_id)
126 | if args.antilm:
127 | # anti-lm
128 | all_prob_t = model_step(dummy_encoder_inputs, cand['dec_inp'], dptr, target_weights, bucket_id)
129 | # adjusted probability
130 | all_prob = all_prob_ts - args.antilm * all_prob_t #+ args.n_bonus * dptr + random() * 1e-50
131 | else:
132 | all_prob_t = [0]*len(all_prob_ts)
133 | all_prob = all_prob_ts
134 |
135 | # suppress copy-cat (respond the same as input)
136 | if dptr < len(input_token_ids):
137 | all_prob[input_token_ids[dptr]] = all_prob[input_token_ids[dptr]] * 0.01
138 |
139 | # for debug use
140 | if return_raw: return all_prob, all_prob_ts, all_prob_t
141 |
142 | # beam search
143 | for c in np.argsort(all_prob)[::-1][:args.beam_size]:
144 | new_cand = {
145 | 'eos' : (c == data_utils.EOS_ID),
146 | 'dec_inp' : [(np.array([c]) if i == (dptr+1) else k) for i, k in enumerate(cand['dec_inp'])],
147 | 'prob_ts' : cand['prob_ts'] * all_prob_ts[c],
148 | 'prob_t' : cand['prob_t'] * all_prob_t[c],
149 | 'prob' : cand['prob'] * all_prob[c],
150 | }
151 | new_cand = (new_cand['prob'], random(), new_cand) # stuff a random to prevent comparing new_cand
152 |
153 | try:
154 | if (len(new_beams) < args.beam_size):
155 | heapq.heappush(new_beams, new_cand)
156 | elif (new_cand[0] > new_beams[0][0]):
157 | heapq.heapreplace(new_beams, new_cand)
158 | except Exception as e:
159 | print("[Error]", e)
160 | print("-----[new_beams]-----\n", new_beams)
161 | print("-----[new_cand]-----\n", new_cand)
162 |
163 | results += new_beams # flush last cands
164 |
165 | # post-process results
166 | res_cands = []
167 | for prob, _, cand in sorted(results, reverse=True):
168 | cand['dec_inp'] = " ".join([dict_lookup(rev_vocab, w) for w in cand['dec_inp']])
169 | res_cands.append(cand)
170 | return res_cands[:args.beam_size]
171 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tensorflow chatbot
2 | ### (with seq2seq + attention + dict-compress + beam search + anti-LM + facebook messenger server)
3 |
4 |
5 | > ####[Update 2017-03-14]
6 | > 1. Upgrade to tensorflow v1.0.0; not backward compatible, since tensorflow has changed so much.
7 | > 2. A pre-trained model with twitter corpus is added, just `./go_example` to chat! (or preview my [chat example](https://github.com/Marsan-Ma/tf_chatbot_seq2seq_antilm/blob/master/example_chat.md))
8 | > 3. You could start by tracing this `go_example` script to see how things work!
9 |
10 |
11 | ## Briefing
12 | This is a [seq2seq model][a1] modified from [tensorflow example][a2].
13 |
14 | 1. The original tensorflow seq2seq has the [attention mechanism][a3] implemented out of the box.
15 | 2. It speeds up training by [dictionary space compression][a4], then decompresses by projecting the embedding back while decoding.
16 | 3. This work adds an option to do [beam search][a5] in the decoding procedure, which usually finds better, more interesting responses.
17 | 4. Added an [anti-language model][a6] to suppress the generic response problem intrinsic to the seq2seq model.
18 | 5. Implemented [this deep reinforcement learning architecture][a7] as an option to enhance the semantic coherence and perplexity of responses.
19 | 6. A lightweight [Flask][a8] server, `app.py`, is included to serve as the Facebook Messenger App backend.
20 |
21 |
22 | [a1]: http://arxiv.org/abs/1406.1078
23 | [a2]: https://www.tensorflow.org/versions/r0.10/tutorials/seq2seq/index.html
24 | [a3]: http://arxiv.org/abs/1412.7449
25 | [a4]: https://arxiv.org/pdf/1412.2007v2.pdf
26 | [a5]: https://en.wikipedia.org/wiki/Beam_search
27 | [a6]: https://arxiv.org/abs/1510.03055
28 | [a7]: https://arxiv.org/abs/1606.01541
29 | [a8]: http://flask.pocoo.org/
30 | [a9]: https://github.com/Marsan-Ma/tf_chatbot_seq2seq_antilm/blob/master/README2.md
31 |
32 |
33 | ## In Layman's terms
34 |
35 | I explained some detail about the features and some implementation tricks [here][a9].
36 |
37 |
38 | ## Just tell me how it works
39 |
40 | #### Clone the repository
41 |
42 | git clone https://github.com/Marsan-Ma/tf_chatbot_seq2seq_antilm.git
43 |
44 | #### Prepare for Corpus
45 | You may find corpora such as twitter chat, open movie subtitles, or ptt forums in [my chat corpus repository][b1]. You need to put it under a path like:
46 |
47 | tf_chatbot_seq2seq_antilm/works/<model_name>/data/train/chat.txt.gz
48 |
49 | And hand-craft some test sentences (one sentence per line) in:
50 |
51 | tf_chatbot_seq2seq_antilm/works/<model_name>/data/test/test_set.txt
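For example, `test_set.txt` could just be a handful of lines like (illustrative):

    how are you ?
    tell me about trump .
    what do you do for a living ?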
52 |
53 | #### Train the model
54 |
55 | python3 main.py --mode train --model_name <model_name>
56 |
57 | #### Run some test example and see the bot response
58 |
59 | After you have trained your model until the perplexity is under 50 or so, you could do:
60 |
61 | python3 main.py --mode test --model_name <model_name>
62 |
63 |
64 | **[Note!!!] If you override any parameter in this main.py command, be sure to apply it to both train and test, or just modify lib/config.py as a failsafe.**
65 |
66 |
67 |
68 | ## Start your Facebook Messenger backend server
69 |
70 | python3 app.py --model_name <model_name>
71 |
72 | You may see this [minimal fb_messenger example][b2] for more details, like setting up SSL, the webhook, and work-arounds for known bugs.
73 |
74 | Here's an interesting comparison: the left conversation enabled beam search with beam = 10; the responses are barely better than always "i don't know". The right conversation also used beam search and additionally enabled the anti-language model. This is supposed to suppress generic responses, and the responses do seem better.
75 |
76 | ![messenger.png][h1]
77 |
78 | [h1]: https://raw.githubusercontent.com/Marsan-Ma/tf_chatbot_seq2seq_antilm/master/doc/messenger.png
79 |
80 |
81 |
82 |
83 | ## Deep reinforcement learning
84 |
85 | > [Update 2017-03-09] Reinforcement learning does not work now, wait for fix.
86 |
87 | If you want a chance to further improve your model, here I implemented a reinforcement learning architecture inspired by [Li et al., 2016][b3]. Just enable the reinforce_learn option in `config.py`; you might want to add your own rules in the `step_rf()` function in `lib/seq2seq_model.py`.
88 |
89 | Note that you should **train in normal mode to get a decent model first!**, since the reinforcement learning will explore from this pre-trained model. It will take forever to improve itself if you start with a bad model.
90 |
91 | [b1]: https://github.com/Marsan-Ma/chat_corpus
92 | [b2]: https://github.com/Marsan-Ma/fb_messenger
93 | [b3]: https://arxiv.org/abs/1606.01541
94 |
95 | ## Introduction
96 |
97 | Seq2seq is a great model released by [Cho et al., 2014][c1]. At first it was used for machine translation, and soon people found that anything about **mapping one thing to another** could also be achieved with a seq2seq model. Chatbots are one of these miracles, where we consider consecutive dialog as a kind of "mapping" relationship.
98 |
99 | Here is the classic intro picture showing the seq2seq model architecture, quoted from this [blogpost about the gmail auto-reply feature][c2].
100 |
101 | [![seq2seq][c3]][c3]
102 |
103 |
104 | The problem is, so far we haven't found a better objective function for chatbots. We are still using [MLE (maximum likelihood estimation)][c4], which works well for machine translation but always generates generic responses like "me too", "I think so", "I love you" when used for chat.
105 |
106 | These responses are not informative, but they do have large probability, since they tend to appear many times in the training corpus. We don't want our chatbot to always reply with this nonsense, so we need to find some way to make our bot more "interesting"; technically speaking, to increase the "perplexity" of the response.
107 |
108 | Here we reproduce the work of [Li et al., 2016][c5] to try to solve this problem. The main idea is to use the same seq2seq model as a language model, to get the candidate words with high probability at each decoding timestep as an anti-model, and then penalize those words that are always highly probable for any input. With this anti-model, we could get more specific, non-generic, informative responses.
109 |
110 | The original work of [Li et al.][c5] uses [MERT (Och, 2003)][c6] with [BLEU][c7] as the metric to find the best probability weighting (the **λ** and **γ** in
111 | **Score(T) = p(T|S) − λU(T) + γNt**) of the corresponding anti-language model. But I find that the BLEU score on a chat corpus tends to always be zero, so I can't get a meaningful result here. If anyone has any idea about this, drop me a message, thanks!
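For reference, a small sketch of that scoring function as I read it (log-probabilities, as in Li et al.; `lam` and `gamma` correspond to the `--antilm` and `--n_bonus` options here):

```python
import math

def score(p_t_given_s, u_t, n_tokens, lam=0.7, gamma=0.05):
    """Score(T) = log p(T|S) - lam * log U(T) + gamma * Nt  (anti-LM penalty + length bonus)."""
    return math.log(p_t_given_s) - lam * math.log(u_t) + gamma * n_tokens
```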
112 |
113 |
114 | [c1]: http://arxiv.org/abs/1406.1078
115 | [c2]: http://googleresearch.blogspot.ru/2015/11/computer-respond-to-this-email.html
116 | [c3]: http://4.bp.blogspot.com/-aArS0l1pjHQ/Vjj71pKAaEI/AAAAAAAAAxE/Nvy1FSbD_Vs/s640/2TFstaticgraphic_alt-01.png
117 | [c4]: https://en.wikipedia.org/wiki/Maximum_likelihood_estimation
118 | [c5]: http://arxiv.org/pdf/1510.03055v3.pdf
119 | [c6]: http://delivery.acm.org/10.1145/1080000/1075117/p160-och.pdf
120 | [c7]: https://en.wikipedia.org/wiki/BLEU
121 |
122 |
123 | ## Parameters
124 |
125 | There are some options for model training and predicting in lib/config.py. Basically they are self-explanatory and work with their default values for most cases. Here we only list the ones you may need to configure:
126 |
127 | **About environment**
128 |
129 | name | type | Description
130 | ---- | ---- | -----------
131 | mode | string | work mode: train/test/chat
132 | model_name | string | model name, affects your working path (storing the data, nn_model, result folders)
133 | scope_name | string | In tensorflow, if you need to load two graphs at the same time, you need to save/load them in different namespaces. (If you need only one seq2seq model, leave it as the default.)
134 | vocab_size | integer | depends on your corpus language: for English, 60000 is good enough; for Chinese you need at least 100000 or 200000.
135 | gpu_usage | float | tensorflow gpu memory fraction used; the default is 1, and tensorflow will occupy 100% of your GPU. If you have multiple jobs sharing your GPU, make it 0.5 or 0.3 for 2 or 3 jobs.
136 | reinforce_learn | int | set 1 to enable reinforcement learning mode
137 |
138 |
139 | **About decoding**
140 |
141 | name | type | default | Description
142 | ---- | ---- | ------- | -------
143 | beam_size | int | 1 | beam search size; setting it to 1 is equivalent to greedy search
144 | antilm | float | 0 (disabled) | penalty weight of the [anti-language model][d1]
145 | n_bonus | float | 0 (disabled) | reward weight of sentence length
146 |
147 |
148 | The anti-LM function is disabled by default; you may start by setting antilm=0.5~0.7 and n_bonus=0.05 to see if you like the difference in the results.
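For example, against the pre-trained `example` model from `go_example` (illustrative values, using the same flags as that script):

    python3 main.py --mode chat --model_name example --vocab_size 100000 --size 128 --beam_size 5 --antilm 0.7 --n_bonus 0.05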
149 |
150 | [d1]: http://arxiv.org/pdf/1510.03055v3.pdf
151 |
152 |
153 | ## Requirements
154 |
155 | 1. For training, a GPU is recommended: seq2seq is a large model, and you need a certain amount of computing power to train and predict efficiently, especially when you set a large beam-search size.
156 |
157 | 2. The DRAM requirement is not as strict as the CPU/GPU requirement, since we are doing stochastic gradient descent.
158 |
159 | 3. If you are new to deep learning and setting up things like the GPU and python environment is annoying to you, here are dockers of my machine learning environment:
160 | [(non-gpu version docker)][e1] / [(gpu version docker)][e2]
161 |
162 | [e1]: https://github.com/Marsan-Ma/docker_mldm
163 | [e2]: https://github.com/Marsan-Ma/docker_mldm_gpu
164 |
165 |
166 | ## References
167 |
168 | Seq2seq is a model with many preliminaries. I've spent quite some time surveying, and here are some of the best materials, which benefited me a lot:
169 |
170 | 1. The best blogpost explaining RNN, LSTM, GRU and seq2seq model: [Understanding LSTM Networks][f1] by Christopher Olah.
171 |
172 | 2. This work, [sherjilozair/char-rnn-tensorflow][f2], helped me learn a lot about language models and building the computation graph in tensorflow.
173 |
174 | 3. If you are interested in more magic about RNN, here is a MUST-READ blogpost: [The Unreasonable Effectiveness of Recurrent Neural Networks][f3] by Andrej Karpathy.
175 |
176 | 4. The vanilla seq2seq+attention version: [nicolas-ivanov/tf_seq2seq_chatbot][f4]. This will help you figure out the main flow of the vanilla seq2seq model, and I built this repository based on that work.
177 |
178 | [f1]: http://colah.github.io/posts/2015-08-Understanding-LSTMs/
179 | [f2]: https://github.com/sherjilozair/char-rnn-tensorflow
180 | [f3]: http://karpathy.github.io/2015/05/21/rnn-effectiveness/
181 | [f4]: https://github.com/nicolas-ivanov/tf_seq2seq_chatbot
182 |
183 |
184 | ## TODOs
185 | 1. Currently I build beam search out of the graph, which means it's very slow. There are discussions about building it in-graph [here][g1] and [there][g2]. But unfortunately, if you want to add something more than beam search, like this anti-LM work, you need much more than just the beam search to be in-graph.
186 |
187 | 2. I haven't figured out how MERT with BLEU can optimize the weight of the anti-LM model, since currently the BLEU score is often zero.
188 |
189 | [g1]: https://github.com/tensorflow/tensorflow/issues/654#issuecomment-196168030
190 | [g2]: https://github.com/tensorflow/tensorflow/pull/3756
191 |
--------------------------------------------------------------------------------
/ref/data_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utilities for downloading data from WMT, tokenizing, vocabularies."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import gzip
22 | import os
23 | import re
24 | import tarfile
25 |
26 | from six.moves import urllib
27 |
28 | from tensorflow.python.platform import gfile
29 | import tensorflow as tf
30 |
31 | # Special vocabulary symbols - we always put them at the start.
32 | _PAD = b"_PAD"
33 | _GO = b"_GO"
34 | _EOS = b"_EOS"
35 | _UNK = b"_UNK"
36 | _START_VOCAB = [_PAD, _GO, _EOS, _UNK]
37 |
38 | PAD_ID = 0
39 | GO_ID = 1
40 | EOS_ID = 2
41 | UNK_ID = 3
42 |
43 | # Regular expressions used to tokenize.
44 | _WORD_SPLIT = re.compile(b"([.,!?\"':;)(])")
45 | _DIGIT_RE = re.compile(br"\d")
46 |
47 | # URLs for WMT data.
48 | _WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar"
49 | _WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz"
50 |
51 |
52 | def maybe_download(directory, filename, url):
53 | """Download filename from url unless it's already in directory."""
54 | if not os.path.exists(directory):
55 | print("Creating directory %s" % directory)
56 | os.mkdir(directory)
57 | filepath = os.path.join(directory, filename)
58 | if not os.path.exists(filepath):
59 | print("Downloading %s to %s" % (url, filepath))
60 | filepath, _ = urllib.request.urlretrieve(url, filepath)
61 | statinfo = os.stat(filepath)
62 | print("Successfully downloaded", filename, statinfo.st_size, "bytes")
63 | return filepath
64 |
65 |
66 | def gunzip_file(gz_path, new_path):
67 | """Unzips from gz_path into new_path."""
68 | print("Unpacking %s to %s" % (gz_path, new_path))
69 | with gzip.open(gz_path, "rb") as gz_file:
70 | with open(new_path, "wb") as new_file:
71 | for line in gz_file:
72 | new_file.write(line)
73 |
74 |
75 | def get_wmt_enfr_train_set(directory):
76 | """Download the WMT en-fr training corpus to directory unless it's there."""
77 | train_path = os.path.join(directory, "giga-fren.release2.fixed")
78 | if not (gfile.Exists(train_path +".fr") and gfile.Exists(train_path +".en")):
79 | corpus_file = maybe_download(directory, "training-giga-fren.tar",
80 | _WMT_ENFR_TRAIN_URL)
81 | print("Extracting tar file %s" % corpus_file)
82 | with tarfile.open(corpus_file, "r") as corpus_tar:
83 | corpus_tar.extractall(directory)
84 | gunzip_file(train_path + ".fr.gz", train_path + ".fr")
85 | gunzip_file(train_path + ".en.gz", train_path + ".en")
86 | return train_path
87 |
88 |
89 | def get_wmt_enfr_dev_set(directory):
90 | """Download the WMT en-fr training corpus to directory unless it's there."""
91 | dev_name = "newstest2013"
92 | dev_path = os.path.join(directory, dev_name)
93 | if not (gfile.Exists(dev_path + ".fr") and gfile.Exists(dev_path + ".en")):
94 | dev_file = maybe_download(directory, "dev-v2.tgz", _WMT_ENFR_DEV_URL)
95 | print("Extracting tgz file %s" % dev_file)
96 | with tarfile.open(dev_file, "r:gz") as dev_tar:
97 | fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr")
98 | en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en")
99 | fr_dev_file.name = dev_name + ".fr" # Extract without "dev/" prefix.
100 | en_dev_file.name = dev_name + ".en"
101 | dev_tar.extract(fr_dev_file, directory)
102 | dev_tar.extract(en_dev_file, directory)
103 | return dev_path
104 |
105 |
106 | def basic_tokenizer(sentence):
107 | """Very basic tokenizer: split the sentence into a list of tokens."""
108 | words = []
109 | for space_separated_fragment in sentence.strip().split():
110 | words.extend(_WORD_SPLIT.split(space_separated_fragment))
111 | return [w for w in words if w]
112 |
113 |
114 | def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size,
115 | tokenizer=None, normalize_digits=True):
116 | """Create vocabulary file (if it does not exist yet) from data file.
117 |
118 | Data file is assumed to contain one sentence per line. Each sentence is
119 | tokenized and digits are normalized (if normalize_digits is set).
120 | Vocabulary contains the most-frequent tokens up to max_vocabulary_size.
121 | We write it to vocabulary_path in a one-token-per-line format, so that later
122 | token in the first line gets id=0, second line gets id=1, and so on.
123 |
124 | Args:
125 | vocabulary_path: path where the vocabulary will be created.
126 | data_path: data file that will be used to create vocabulary.
127 | max_vocabulary_size: limit on the size of the created vocabulary.
128 | tokenizer: a function to use to tokenize each data sentence;
129 | if None, basic_tokenizer will be used.
130 | normalize_digits: Boolean; if true, all digits are replaced by 0s.
131 | """
132 | if not gfile.Exists(vocabulary_path):
133 | print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path))
134 | vocab = {}
135 | with gfile.GFile(data_path, mode="rb") as f:
136 | counter = 0
137 | for line in f:
138 | counter += 1
139 | if counter % 100000 == 0:
140 | print(" processing line %d" % counter)
141 | line = tf.compat.as_bytes(line)
142 | tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
143 | for w in tokens:
144 | word = _DIGIT_RE.sub(b"0", w) if normalize_digits else w
145 | if word in vocab:
146 | vocab[word] += 1
147 | else:
148 | vocab[word] = 1
149 | vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
150 | if len(vocab_list) > max_vocabulary_size:
151 | vocab_list = vocab_list[:max_vocabulary_size]
152 | with gfile.GFile(vocabulary_path, mode="wb") as vocab_file:
153 | for w in vocab_list:
154 | vocab_file.write(w + b"\n")
155 |
156 |
157 | def initialize_vocabulary(vocabulary_path):
158 | """Initialize vocabulary from file.
159 |
160 | We assume the vocabulary is stored one-item-per-line, so a file:
161 | dog
162 | cat
163 | will result in a vocabulary {"dog": 0, "cat": 1}, and this function will
164 | also return the reversed-vocabulary ["dog", "cat"].
165 |
166 | Args:
167 | vocabulary_path: path to the file containing the vocabulary.
168 |
169 | Returns:
170 | a pair: the vocabulary (a dictionary mapping string to integers), and
171 | the reversed vocabulary (a list, which reverses the vocabulary mapping).
172 |
173 | Raises:
174 | ValueError: if the provided vocabulary_path does not exist.
175 | """
176 | if gfile.Exists(vocabulary_path):
177 | rev_vocab = []
178 | with gfile.GFile(vocabulary_path, mode="rb") as f:
179 | rev_vocab.extend(f.readlines())
180 | rev_vocab = [tf.compat.as_bytes(line.strip()) for line in rev_vocab]
181 | vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])
182 | return vocab, rev_vocab
183 | else:
184 |     raise ValueError("Vocabulary file %s not found." % vocabulary_path)
185 |
186 |
187 | def sentence_to_token_ids(sentence, vocabulary,
188 | tokenizer=None, normalize_digits=True):
189 | """Convert a string to list of integers representing token-ids.
190 |
191 | For example, a sentence "I have a dog" may become tokenized into
192 | ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2,
193 |   "a": 4, "dog": 7} this function will return [1, 2, 4, 7].
194 |
195 | Args:
196 | sentence: the sentence in bytes format to convert to token-ids.
197 | vocabulary: a dictionary mapping tokens to integers.
198 | tokenizer: a function to use to tokenize each sentence;
199 | if None, basic_tokenizer will be used.
200 | normalize_digits: Boolean; if true, all digits are replaced by 0s.
201 |
202 | Returns:
203 | a list of integers, the token-ids for the sentence.
204 | """
205 |
206 | if tokenizer:
207 | words = tokenizer(sentence)
208 | else:
209 | words = basic_tokenizer(sentence)
210 | if not normalize_digits:
211 | return [vocabulary.get(w, UNK_ID) for w in words]
212 | # Normalize digits by 0 before looking words up in the vocabulary.
213 | return [vocabulary.get(_DIGIT_RE.sub(b"0", w), UNK_ID) for w in words]
214 |
215 |
216 | def data_to_token_ids(data_path, target_path, vocabulary_path,
217 | tokenizer=None, normalize_digits=True):
218 | """Tokenize data file and turn into token-ids using given vocabulary file.
219 |
220 | This function loads data line-by-line from data_path, calls the above
221 | sentence_to_token_ids, and saves the result to target_path. See comment
222 | for sentence_to_token_ids on the details of token-ids format.
223 |
224 | Args:
225 | data_path: path to the data file in one-sentence-per-line format.
226 | target_path: path where the file with token-ids will be created.
227 | vocabulary_path: path to the vocabulary file.
228 | tokenizer: a function to use to tokenize each sentence;
229 | if None, basic_tokenizer will be used.
230 | normalize_digits: Boolean; if true, all digits are replaced by 0s.
231 | """
232 | if not gfile.Exists(target_path):
233 | print("Tokenizing data in %s" % data_path)
234 | vocab, _ = initialize_vocabulary(vocabulary_path)
235 | with gfile.GFile(data_path, mode="rb") as data_file:
236 | with gfile.GFile(target_path, mode="w") as tokens_file:
237 | counter = 0
238 | for line in data_file:
239 | counter += 1
240 | if counter % 100000 == 0:
241 | print(" tokenizing line %d" % counter)
242 | token_ids = sentence_to_token_ids(tf.compat.as_bytes(line), vocab,
243 | tokenizer, normalize_digits)
244 | tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")
245 |
246 |
247 | def prepare_wmt_data(data_dir, en_vocabulary_size, fr_vocabulary_size, tokenizer=None):
248 | """Get WMT data into data_dir, create vocabularies and tokenize data.
249 |
250 | Args:
251 | data_dir: directory in which the data sets will be stored.
252 | en_vocabulary_size: size of the English vocabulary to create and use.
253 | fr_vocabulary_size: size of the French vocabulary to create and use.
254 | tokenizer: a function to use to tokenize each data sentence;
255 | if None, basic_tokenizer will be used.
256 |
257 | Returns:
258 | A tuple of 6 elements:
259 | (1) path to the token-ids for English training data-set,
260 | (2) path to the token-ids for French training data-set,
261 | (3) path to the token-ids for English development data-set,
262 | (4) path to the token-ids for French development data-set,
263 | (5) path to the English vocabulary file,
264 | (6) path to the French vocabulary file.
265 | """
266 | # Get wmt data to the specified directory.
267 | train_path = get_wmt_enfr_train_set(data_dir)
268 | dev_path = get_wmt_enfr_dev_set(data_dir)
269 |
270 | from_train_path = train_path + ".en"
271 | to_train_path = train_path + ".fr"
272 | from_dev_path = dev_path + ".en"
273 | to_dev_path = dev_path + ".fr"
274 | return prepare_data(data_dir, from_train_path, to_train_path, from_dev_path, to_dev_path, en_vocabulary_size,
275 | fr_vocabulary_size, tokenizer)
276 |
277 |
278 | def prepare_data(data_dir, from_train_path, to_train_path, from_dev_path, to_dev_path, from_vocabulary_size,
279 | to_vocabulary_size, tokenizer=None):
280 |   """Prepare all necessary files that are required for the training.
281 |
282 | Args:
283 | data_dir: directory in which the data sets will be stored.
284 | from_train_path: path to the file that includes "from" training samples.
285 | to_train_path: path to the file that includes "to" training samples.
286 | from_dev_path: path to the file that includes "from" dev samples.
287 | to_dev_path: path to the file that includes "to" dev samples.
288 | from_vocabulary_size: size of the "from language" vocabulary to create and use.
289 | to_vocabulary_size: size of the "to language" vocabulary to create and use.
290 | tokenizer: a function to use to tokenize each data sentence;
291 | if None, basic_tokenizer will be used.
292 |
293 | Returns:
294 | A tuple of 6 elements:
295 | (1) path to the token-ids for "from language" training data-set,
296 | (2) path to the token-ids for "to language" training data-set,
297 | (3) path to the token-ids for "from language" development data-set,
298 | (4) path to the token-ids for "to language" development data-set,
299 | (5) path to the "from language" vocabulary file,
300 | (6) path to the "to language" vocabulary file.
301 | """
302 | # Create vocabularies of the appropriate sizes.
303 | to_vocab_path = os.path.join(data_dir, "vocab%d.to" % to_vocabulary_size)
304 | from_vocab_path = os.path.join(data_dir, "vocab%d.from" % from_vocabulary_size)
305 | create_vocabulary(to_vocab_path, to_train_path , to_vocabulary_size, tokenizer)
306 | create_vocabulary(from_vocab_path, from_train_path , from_vocabulary_size, tokenizer)
307 |
308 | # Create token ids for the training data.
309 | to_train_ids_path = to_train_path + (".ids%d" % to_vocabulary_size)
310 | from_train_ids_path = from_train_path + (".ids%d" % from_vocabulary_size)
311 | data_to_token_ids(to_train_path, to_train_ids_path, to_vocab_path, tokenizer)
312 | data_to_token_ids(from_train_path, from_train_ids_path, from_vocab_path, tokenizer)
313 |
314 | # Create token ids for the development data.
315 | to_dev_ids_path = to_dev_path + (".ids%d" % to_vocabulary_size)
316 | from_dev_ids_path = from_dev_path + (".ids%d" % from_vocabulary_size)
317 | data_to_token_ids(to_dev_path, to_dev_ids_path, to_vocab_path, tokenizer)
318 | data_to_token_ids(from_dev_path, from_dev_ids_path, from_vocab_path, tokenizer)
319 |
320 | return (from_train_ids_path, to_train_ids_path,
321 | from_dev_ids_path, to_dev_ids_path,
322 | from_vocab_path, to_vocab_path)
323 |
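A minimal usage sketch of the helpers above; the corpus path, vocabulary file name, and vocabulary size are arbitrary, and the module is assumed to be importable as data_utils (as translate.py imports it):

    import data_utils

    # Build a 1000-token vocabulary from a one-sentence-per-line corpus, then map a sentence to ids.
    data_utils.create_vocabulary("vocab1000.in", "corpus.txt", 1000)
    vocab, rev_vocab = data_utils.initialize_vocabulary("vocab1000.in")
    ids = data_utils.sentence_to_token_ids(b"I have 2 dogs .", vocab)
    # Digits are rewritten to b"0" before lookup; out-of-vocabulary words map to UNK_ID.
    print(ids, [rev_vocab[i] for i in ids])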
--------------------------------------------------------------------------------
/ref/translate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Binary for training translation models and decoding from them.
17 |
18 | Running this program without --decode will download the WMT corpus into
19 | the directory specified as --data_dir and tokenize it in a very basic way,
20 | and then start training a model, saving checkpoints to --train_dir.
21 |
22 | Running with --decode starts an interactive loop so you can see how
23 | the current checkpoint translates English sentences into French.
24 |
25 | See the following papers for more information on neural translation models.
26 | * http://arxiv.org/abs/1409.3215
27 | * http://arxiv.org/abs/1409.0473
28 | * http://arxiv.org/abs/1412.2007
29 | """
30 | from __future__ import absolute_import
31 | from __future__ import division
32 | from __future__ import print_function
33 |
34 | import math
35 | import os
36 | import random
37 | import sys
38 | import time
39 | import logging
40 |
41 | import numpy as np
42 | from six.moves import xrange # pylint: disable=redefined-builtin
43 | import tensorflow as tf
44 |
45 | import data_utils
46 | import seq2seq_model
47 |
48 |
49 | tf.app.flags.DEFINE_float("learning_rate", 0.5, "Learning rate.")
50 | tf.app.flags.DEFINE_float("learning_rate_decay_factor", 0.99,
51 | "Learning rate decays by this much.")
52 | tf.app.flags.DEFINE_float("max_gradient_norm", 5.0,
53 | "Clip gradients to this norm.")
54 | tf.app.flags.DEFINE_integer("batch_size", 64,
55 | "Batch size to use during training.")
56 | tf.app.flags.DEFINE_integer("size", 1024, "Size of each model layer.")
57 | tf.app.flags.DEFINE_integer("num_layers", 3, "Number of layers in the model.")
58 | tf.app.flags.DEFINE_integer("from_vocab_size", 40000, "English vocabulary size.")
59 | tf.app.flags.DEFINE_integer("to_vocab_size", 40000, "French vocabulary size.")
60 | tf.app.flags.DEFINE_string("data_dir", "/tmp", "Data directory")
61 | tf.app.flags.DEFINE_string("train_dir", "/tmp", "Training directory.")
62 | tf.app.flags.DEFINE_string("from_train_data", None, "Source-language training data.")
63 | tf.app.flags.DEFINE_string("to_train_data", None, "Target-language training data.")
64 | tf.app.flags.DEFINE_string("from_dev_data", None, "Source-language development data.")
65 | tf.app.flags.DEFINE_string("to_dev_data", None, "Target-language development data.")
66 | tf.app.flags.DEFINE_integer("max_train_data_size", 0,
67 | "Limit on the size of training data (0: no limit).")
68 | tf.app.flags.DEFINE_integer("steps_per_checkpoint", 200,
69 | "How many training steps to do per checkpoint.")
70 | tf.app.flags.DEFINE_boolean("decode", False,
71 | "Set to True for interactive decoding.")
72 | tf.app.flags.DEFINE_boolean("self_test", False,
73 | "Run a self-test if this is set to True.")
74 | tf.app.flags.DEFINE_boolean("use_fp16", False,
75 | "Train using fp16 instead of fp32.")
76 |
77 | FLAGS = tf.app.flags.FLAGS
78 |
79 | # We use a number of buckets and pad to the closest one for efficiency.
80 | # See seq2seq_model.Seq2SeqModel for details of how they work.
81 | _buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
82 |
83 |
84 | def read_data(source_path, target_path, max_size=None):
85 | """Read data from source and target files and put into buckets.
86 |
87 | Args:
88 | source_path: path to the files with token-ids for the source language.
89 | target_path: path to the file with token-ids for the target language;
90 | it must be aligned with the source file: n-th line contains the desired
91 | output for n-th line from the source_path.
92 |     max_size: maximum number of lines to read; all others will be ignored;
93 | if 0 or None, data files will be read completely (no limit).
94 |
95 | Returns:
96 | data_set: a list of length len(_buckets); data_set[n] contains a list of
97 | (source, target) pairs read from the provided data files that fit
98 | into the n-th bucket, i.e., such that len(source) < _buckets[n][0] and
99 | len(target) < _buckets[n][1]; source and target are lists of token-ids.
100 | """
101 | data_set = [[] for _ in _buckets]
102 | with tf.gfile.GFile(source_path, mode="r") as source_file:
103 | with tf.gfile.GFile(target_path, mode="r") as target_file:
104 | source, target = source_file.readline(), target_file.readline()
105 | counter = 0
106 | while source and target and (not max_size or counter < max_size):
107 | counter += 1
108 | if counter % 100000 == 0:
109 | print(" reading data line %d" % counter)
110 | sys.stdout.flush()
111 | source_ids = [int(x) for x in source.split()]
112 | target_ids = [int(x) for x in target.split()]
113 | target_ids.append(data_utils.EOS_ID)
114 | for bucket_id, (source_size, target_size) in enumerate(_buckets):
115 | if len(source_ids) < source_size and len(target_ids) < target_size:
116 | data_set[bucket_id].append([source_ids, target_ids])
117 | break
118 | source, target = source_file.readline(), target_file.readline()
119 | return data_set
120 |
121 |
122 | def create_model(session, forward_only):
123 | """Create translation model and initialize or load parameters in session."""
124 | dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
125 | model = seq2seq_model.Seq2SeqModel(
126 | FLAGS.from_vocab_size,
127 | FLAGS.to_vocab_size,
128 | _buckets,
129 | FLAGS.size,
130 | FLAGS.num_layers,
131 | FLAGS.max_gradient_norm,
132 | FLAGS.batch_size,
133 | FLAGS.learning_rate,
134 | FLAGS.learning_rate_decay_factor,
135 | forward_only=forward_only,
136 | dtype=dtype)
137 | ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
138 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
139 | print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
140 | model.saver.restore(session, ckpt.model_checkpoint_path)
141 | else:
142 | print("Created model with fresh parameters.")
143 | session.run(tf.global_variables_initializer())
144 | return model
145 |
146 |
147 | def train():
148 | """Train a en->fr translation model using WMT data."""
149 | from_train = None
150 | to_train = None
151 | from_dev = None
152 | to_dev = None
153 | if FLAGS.from_train_data and FLAGS.to_train_data:
154 | from_train_data = FLAGS.from_train_data
155 | to_train_data = FLAGS.to_train_data
156 | from_dev_data = from_train_data
157 | to_dev_data = to_train_data
158 | if FLAGS.from_dev_data and FLAGS.to_dev_data:
159 | from_dev_data = FLAGS.from_dev_data
160 | to_dev_data = FLAGS.to_dev_data
161 | from_train, to_train, from_dev, to_dev, _, _ = data_utils.prepare_data(
162 | FLAGS.data_dir,
163 | from_train_data,
164 | to_train_data,
165 | from_dev_data,
166 | to_dev_data,
167 | FLAGS.from_vocab_size,
168 | FLAGS.to_vocab_size)
169 | else:
170 | # Prepare WMT data.
171 | print("Preparing WMT data in %s" % FLAGS.data_dir)
172 | from_train, to_train, from_dev, to_dev, _, _ = data_utils.prepare_wmt_data(
173 | FLAGS.data_dir, FLAGS.from_vocab_size, FLAGS.to_vocab_size)
174 |
175 | with tf.Session() as sess:
176 | # Create model.
177 | print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size))
178 | model = create_model(sess, False)
179 |
180 | # Read data into buckets and compute their sizes.
181 | print ("Reading development and training data (limit: %d)."
182 | % FLAGS.max_train_data_size)
183 | dev_set = read_data(from_dev, to_dev)
184 | train_set = read_data(from_train, to_train, FLAGS.max_train_data_size)
185 | train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
186 | train_total_size = float(sum(train_bucket_sizes))
187 |
188 | # A bucket scale is a list of increasing numbers from 0 to 1 that we'll use
189 | # to select a bucket. Length of [scale[i], scale[i+1]] is proportional to
190 |     # the size of the i-th training bucket, as used later.
191 | train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size
192 | for i in xrange(len(train_bucket_sizes))]
193 |
194 | # This is the training loop.
195 | step_time, loss = 0.0, 0.0
196 | current_step = 0
197 | previous_losses = []
198 | while True:
199 | # Choose a bucket according to data distribution. We pick a random number
200 | # in [0, 1] and use the corresponding interval in train_buckets_scale.
201 | random_number_01 = np.random.random_sample()
202 | bucket_id = min([i for i in xrange(len(train_buckets_scale))
203 | if train_buckets_scale[i] > random_number_01])
204 |
205 | # Get a batch and make a step.
206 | start_time = time.time()
207 | encoder_inputs, decoder_inputs, target_weights = model.get_batch(
208 | train_set, bucket_id)
209 | _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
210 | target_weights, bucket_id, False)
211 | step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint
212 | loss += step_loss / FLAGS.steps_per_checkpoint
213 | current_step += 1
214 |
215 | # Once in a while, we save checkpoint, print statistics, and run evals.
216 | if current_step % FLAGS.steps_per_checkpoint == 0:
217 |         # Print statistics for the previous checkpoint interval.
218 | perplexity = math.exp(float(loss)) if loss < 300 else float("inf")
219 | print ("global step %d learning rate %.4f step-time %.2f perplexity "
220 | "%.2f" % (model.global_step.eval(), model.learning_rate.eval(),
221 | step_time, perplexity))
222 | # Decrease learning rate if no improvement was seen over last 3 times.
223 | if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
224 | sess.run(model.learning_rate_decay_op)
225 | previous_losses.append(loss)
226 | # Save checkpoint and zero timer and loss.
227 | checkpoint_path = os.path.join(FLAGS.train_dir, "translate.ckpt")
228 | model.saver.save(sess, checkpoint_path, global_step=model.global_step)
229 | step_time, loss = 0.0, 0.0
230 | # Run evals on development set and print their perplexity.
231 | for bucket_id in xrange(len(_buckets)):
232 | if len(dev_set[bucket_id]) == 0:
233 | print(" eval: empty bucket %d" % (bucket_id))
234 | continue
235 | encoder_inputs, decoder_inputs, target_weights = model.get_batch(
236 | dev_set, bucket_id)
237 | _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
238 | target_weights, bucket_id, True)
239 | eval_ppx = math.exp(float(eval_loss)) if eval_loss < 300 else float(
240 | "inf")
241 | print(" eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx))
242 | sys.stdout.flush()
243 |
244 |
245 | def decode():
246 | with tf.Session() as sess:
247 | # Create model and load parameters.
248 | model = create_model(sess, True)
249 | model.batch_size = 1 # We decode one sentence at a time.
250 |
251 | # Load vocabularies.
252 | en_vocab_path = os.path.join(FLAGS.data_dir,
253 | "vocab%d.from" % FLAGS.from_vocab_size)
254 | fr_vocab_path = os.path.join(FLAGS.data_dir,
255 | "vocab%d.to" % FLAGS.to_vocab_size)
256 | en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
257 | _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)
258 |
259 | # Decode from standard input.
260 | sys.stdout.write("> ")
261 | sys.stdout.flush()
262 | sentence = sys.stdin.readline()
263 | while sentence:
264 | # Get token-ids for the input sentence.
265 | token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), en_vocab)
266 | # Which bucket does it belong to?
267 | bucket_id = len(_buckets) - 1
268 | for i, bucket in enumerate(_buckets):
269 | if bucket[0] >= len(token_ids):
270 | bucket_id = i
271 | break
272 | else:
273 | logging.warning("Sentence truncated: %s", sentence)
274 |
275 | # Get a 1-element batch to feed the sentence to the model.
276 | encoder_inputs, decoder_inputs, target_weights = model.get_batch(
277 | {bucket_id: [(token_ids, [])]}, bucket_id)
278 | # Get output logits for the sentence.
279 | _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
280 | target_weights, bucket_id, True)
281 | # This is a greedy decoder - outputs are just argmaxes of output_logits.
282 | outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
283 | # If there is an EOS symbol in outputs, cut them at that point.
284 | if data_utils.EOS_ID in outputs:
285 | outputs = outputs[:outputs.index(data_utils.EOS_ID)]
286 | # Print out French sentence corresponding to outputs.
287 | print(" ".join([tf.compat.as_str(rev_fr_vocab[output]) for output in outputs]))
288 | print("> ", end="")
289 | sys.stdout.flush()
290 | sentence = sys.stdin.readline()
291 |
292 |
293 | def self_test():
294 | """Test the translation model."""
295 | with tf.Session() as sess:
296 | print("Self-test for neural translation model.")
297 | # Create model with vocabularies of 10, 2 small buckets, 2 layers of 32.
298 | model = seq2seq_model.Seq2SeqModel(10, 10, [(3, 3), (6, 6)], 32, 2,
299 | 5.0, 32, 0.3, 0.99, num_samples=8)
300 | sess.run(tf.global_variables_initializer())
301 |
302 | # Fake data set for both the (3, 3) and (6, 6) bucket.
303 | data_set = ([([1, 1], [2, 2]), ([3, 3], [4]), ([5], [6])],
304 | [([1, 1, 1, 1, 1], [2, 2, 2, 2, 2]), ([3, 3, 3], [5, 6])])
305 | for _ in xrange(5): # Train the fake model for 5 steps.
306 | bucket_id = random.choice([0, 1])
307 | encoder_inputs, decoder_inputs, target_weights = model.get_batch(
308 | data_set, bucket_id)
309 | model.step(sess, encoder_inputs, decoder_inputs, target_weights,
310 | bucket_id, False)
311 |
312 |
313 | def main(_):
314 | if FLAGS.self_test:
315 | self_test()
316 | elif FLAGS.decode:
317 | decode()
318 | else:
319 | train()
320 |
321 | if __name__ == "__main__":
322 | tf.app.run()
323 |
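A standalone sketch of the bucket-sampling logic used in train() above, with made-up bucket sizes; a bucket is chosen with probability proportional to how much training data it holds:

    import numpy as np

    train_bucket_sizes = [1000, 3000, 4000, 2000]   # hypothetical pair counts per bucket
    train_total_size = float(sum(train_bucket_sizes))
    train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size
                           for i in range(len(train_bucket_sizes))]
    # train_buckets_scale == [0.1, 0.4, 0.8, 1.0]
    random_number_01 = np.random.random_sample()
    bucket_id = min([i for i in range(len(train_buckets_scale))
                     if train_buckets_scale[i] > random_number_01])
    # e.g. a draw of 0.55 falls in (0.4, 0.8], so bucket_id == 2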
--------------------------------------------------------------------------------
/ref/seq2seq_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Sequence-to-sequence model with an attention mechanism."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import random
23 |
24 | import numpy as np
25 | from six.moves import xrange # pylint: disable=redefined-builtin
26 | import tensorflow as tf
27 |
28 | import data_utils
29 |
30 |
31 | class Seq2SeqModel(object):
32 | """Sequence-to-sequence model with attention and for multiple buckets.
33 |
34 | This class implements a multi-layer recurrent neural network as encoder,
35 | and an attention-based decoder. This is the same as the model described in
36 | this paper: http://arxiv.org/abs/1412.7449 - please look there for details,
37 | or into the seq2seq library for complete model implementation.
38 |   This class also allows the use of GRU cells in addition to LSTM cells, and
39 | sampled softmax to handle large output vocabulary size. A single-layer
40 | version of this model, but with bi-directional encoder, was presented in
41 | http://arxiv.org/abs/1409.0473
42 | and sampled softmax is described in Section 3 of the following paper.
43 | http://arxiv.org/abs/1412.2007
44 | """
45 |
46 | def __init__(self,
47 | source_vocab_size,
48 | target_vocab_size,
49 | buckets,
50 | size,
51 | num_layers,
52 | max_gradient_norm,
53 | batch_size,
54 | learning_rate,
55 | learning_rate_decay_factor,
56 | use_lstm=False,
57 | num_samples=512,
58 | forward_only=False,
59 | dtype=tf.float32):
60 | """Create the model.
61 |
62 | Args:
63 | source_vocab_size: size of the source vocabulary.
64 | target_vocab_size: size of the target vocabulary.
65 | buckets: a list of pairs (I, O), where I specifies maximum input length
66 | that will be processed in that bucket, and O specifies maximum output
67 | length. Training instances that have inputs longer than I or outputs
68 | longer than O will be pushed to the next bucket and padded accordingly.
69 | We assume that the list is sorted, e.g., [(2, 4), (8, 16)].
70 | size: number of units in each layer of the model.
71 | num_layers: number of layers in the model.
72 | max_gradient_norm: gradients will be clipped to maximally this norm.
73 | batch_size: the size of the batches used during training;
74 | the model construction is independent of batch_size, so it can be
75 | changed after initialization if this is convenient, e.g., for decoding.
76 | learning_rate: learning rate to start with.
77 | learning_rate_decay_factor: decay learning rate by this much when needed.
78 | use_lstm: if true, we use LSTM cells instead of GRU cells.
79 | num_samples: number of samples for sampled softmax.
80 | forward_only: if set, we do not construct the backward pass in the model.
81 | dtype: the data type to use to store internal variables.
82 | """
83 | self.source_vocab_size = source_vocab_size
84 | self.target_vocab_size = target_vocab_size
85 | self.buckets = buckets
86 | self.batch_size = batch_size
87 | self.learning_rate = tf.Variable(
88 | float(learning_rate), trainable=False, dtype=dtype)
89 | self.learning_rate_decay_op = self.learning_rate.assign(
90 | self.learning_rate * learning_rate_decay_factor)
91 | self.global_step = tf.Variable(0, trainable=False)
92 |
93 | # If we use sampled softmax, we need an output projection.
94 | output_projection = None
95 | softmax_loss_function = None
96 | # Sampled softmax only makes sense if we sample less than vocabulary size.
97 | if num_samples > 0 and num_samples < self.target_vocab_size:
98 | w_t = tf.get_variable("proj_w", [self.target_vocab_size, size], dtype=dtype)
99 | w = tf.transpose(w_t)
100 | b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype)
101 | output_projection = (w, b)
102 |
103 | def sampled_loss(labels, inputs):
104 | labels = tf.reshape(labels, [-1, 1])
105 | # We need to compute the sampled_softmax_loss using 32bit floats to
106 | # avoid numerical instabilities.
107 | local_w_t = tf.cast(w_t, tf.float32)
108 | local_b = tf.cast(b, tf.float32)
109 | local_inputs = tf.cast(inputs, tf.float32)
110 | return tf.cast(
111 | tf.nn.sampled_softmax_loss(
112 | weights=local_w_t,
113 | biases=local_b,
114 | labels=labels,
115 | inputs=local_inputs,
116 | num_sampled=num_samples,
117 | num_classes=self.target_vocab_size),
118 | dtype)
119 | softmax_loss_function = sampled_loss
120 |
121 | # Create the internal multi-layer cell for our RNN.
122 | def single_cell():
123 | return tf.contrib.rnn.GRUCell(size)
124 | if use_lstm:
125 | def single_cell():
126 | return tf.contrib.rnn.BasicLSTMCell(size)
127 | cell = single_cell()
128 | if num_layers > 1:
129 | cell = tf.contrib.rnn.MultiRNNCell([single_cell() for _ in range(num_layers)])
130 |
131 | # The seq2seq function: we use embedding for the input and attention.
132 | def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
133 | return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
134 | encoder_inputs,
135 | decoder_inputs,
136 | cell,
137 | num_encoder_symbols=source_vocab_size,
138 | num_decoder_symbols=target_vocab_size,
139 | embedding_size=size,
140 | output_projection=output_projection,
141 | feed_previous=do_decode,
142 | dtype=dtype)
143 |
144 | # Feeds for inputs.
145 | self.encoder_inputs = []
146 | self.decoder_inputs = []
147 | self.target_weights = []
148 | for i in xrange(buckets[-1][0]): # Last bucket is the biggest one.
149 | self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
150 | name="encoder{0}".format(i)))
151 | for i in xrange(buckets[-1][1] + 1):
152 | self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
153 | name="decoder{0}".format(i)))
154 | self.target_weights.append(tf.placeholder(dtype, shape=[None],
155 | name="weight{0}".format(i)))
156 |
157 | # Our targets are decoder inputs shifted by one.
158 | targets = [self.decoder_inputs[i + 1]
159 | for i in xrange(len(self.decoder_inputs) - 1)]
160 |
161 | # Training outputs and losses.
162 | if forward_only:
163 | self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
164 | self.encoder_inputs, self.decoder_inputs, targets,
165 | self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, True),
166 | softmax_loss_function=softmax_loss_function)
167 | # If we use output projection, we need to project outputs for decoding.
168 | if output_projection is not None:
169 | for b in xrange(len(buckets)):
170 | self.outputs[b] = [
171 | tf.matmul(output, output_projection[0]) + output_projection[1]
172 | for output in self.outputs[b]
173 | ]
174 | else:
175 | self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
176 | self.encoder_inputs, self.decoder_inputs, targets,
177 | self.target_weights, buckets,
178 | lambda x, y: seq2seq_f(x, y, False),
179 | softmax_loss_function=softmax_loss_function)
180 |
181 | # Gradients and SGD update operation for training the model.
182 | params = tf.trainable_variables()
183 | if not forward_only:
184 | self.gradient_norms = []
185 | self.updates = []
186 | opt = tf.train.GradientDescentOptimizer(self.learning_rate)
187 | for b in xrange(len(buckets)):
188 | gradients = tf.gradients(self.losses[b], params)
189 | clipped_gradients, norm = tf.clip_by_global_norm(gradients,
190 | max_gradient_norm)
191 | self.gradient_norms.append(norm)
192 | self.updates.append(opt.apply_gradients(
193 | zip(clipped_gradients, params), global_step=self.global_step))
194 |
195 | self.saver = tf.train.Saver(tf.global_variables())
196 |
197 | def step(self, session, encoder_inputs, decoder_inputs, target_weights,
198 | bucket_id, forward_only):
199 | """Run a step of the model feeding the given inputs.
200 |
201 | Args:
202 | session: tensorflow session to use.
203 | encoder_inputs: list of numpy int vectors to feed as encoder inputs.
204 | decoder_inputs: list of numpy int vectors to feed as decoder inputs.
205 | target_weights: list of numpy float vectors to feed as target weights.
206 | bucket_id: which bucket of the model to use.
207 | forward_only: whether to do the backward step or only forward.
208 |
209 | Returns:
210 | A triple consisting of gradient norm (or None if we did not do backward),
211 |       the batch loss, and the outputs.
212 |
213 | Raises:
214 | ValueError: if length of encoder_inputs, decoder_inputs, or
215 | target_weights disagrees with bucket size for the specified bucket_id.
216 | """
217 | # Check if the sizes match.
218 | encoder_size, decoder_size = self.buckets[bucket_id]
219 | if len(encoder_inputs) != encoder_size:
220 | raise ValueError("Encoder length must be equal to the one in bucket,"
221 | " %d != %d." % (len(encoder_inputs), encoder_size))
222 | if len(decoder_inputs) != decoder_size:
223 | raise ValueError("Decoder length must be equal to the one in bucket,"
224 | " %d != %d." % (len(decoder_inputs), decoder_size))
225 | if len(target_weights) != decoder_size:
226 | raise ValueError("Weights length must be equal to the one in bucket,"
227 | " %d != %d." % (len(target_weights), decoder_size))
228 |
229 | # Input feed: encoder inputs, decoder inputs, target_weights, as provided.
230 | input_feed = {}
231 | for l in xrange(encoder_size):
232 | input_feed[self.encoder_inputs[l].name] = encoder_inputs[l]
233 | for l in xrange(decoder_size):
234 | input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
235 | input_feed[self.target_weights[l].name] = target_weights[l]
236 |
237 | # Since our targets are decoder inputs shifted by one, we need one more.
238 | last_target = self.decoder_inputs[decoder_size].name
239 | input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32)
240 |
241 | # Output feed: depends on whether we do a backward step or not.
242 | if not forward_only:
243 | output_feed = [self.updates[bucket_id], # Update Op that does SGD.
244 | self.gradient_norms[bucket_id], # Gradient norm.
245 | self.losses[bucket_id]] # Loss for this batch.
246 | else:
247 | output_feed = [self.losses[bucket_id]] # Loss for this batch.
248 | for l in xrange(decoder_size): # Output logits.
249 | output_feed.append(self.outputs[bucket_id][l])
250 |
251 | outputs = session.run(output_feed, input_feed)
252 | if not forward_only:
253 | return outputs[1], outputs[2], None # Gradient norm, loss, no outputs.
254 | else:
255 | return None, outputs[0], outputs[1:] # No gradient norm, loss, outputs.
256 |
257 | def get_batch(self, data, bucket_id):
258 | """Get a random batch of data from the specified bucket, prepare for step.
259 |
260 | To feed data in step(..) it must be a list of batch-major vectors, while
261 | data here contains single length-major cases. So the main logic of this
262 | function is to re-index data cases to be in the proper format for feeding.
263 |
264 | Args:
265 | data: a tuple of size len(self.buckets) in which each element contains
266 | lists of pairs of input and output data that we use to create a batch.
267 | bucket_id: integer, which bucket to get the batch for.
268 |
269 | Returns:
270 | The triple (encoder_inputs, decoder_inputs, target_weights) for
271 | the constructed batch that has the proper format to call step(...) later.
272 | """
273 | encoder_size, decoder_size = self.buckets[bucket_id]
274 | encoder_inputs, decoder_inputs = [], []
275 |
276 | # Get a random batch of encoder and decoder inputs from data,
277 | # pad them if needed, reverse encoder inputs and add GO to decoder.
278 | for _ in xrange(self.batch_size):
279 | encoder_input, decoder_input = random.choice(data[bucket_id])
280 |
281 | # Encoder inputs are padded and then reversed.
282 | encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(encoder_input))
283 | encoder_inputs.append(list(reversed(encoder_input + encoder_pad)))
284 |
285 | # Decoder inputs get an extra "GO" symbol, and are padded then.
286 | decoder_pad_size = decoder_size - len(decoder_input) - 1
287 | decoder_inputs.append([data_utils.GO_ID] + decoder_input +
288 | [data_utils.PAD_ID] * decoder_pad_size)
289 |
290 | # Now we create batch-major vectors from the data selected above.
291 | batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], []
292 |
293 | # Batch encoder inputs are just re-indexed encoder_inputs.
294 | for length_idx in xrange(encoder_size):
295 | batch_encoder_inputs.append(
296 | np.array([encoder_inputs[batch_idx][length_idx]
297 | for batch_idx in xrange(self.batch_size)], dtype=np.int32))
298 |
299 | # Batch decoder inputs are re-indexed decoder_inputs, we create weights.
300 | for length_idx in xrange(decoder_size):
301 | batch_decoder_inputs.append(
302 | np.array([decoder_inputs[batch_idx][length_idx]
303 | for batch_idx in xrange(self.batch_size)], dtype=np.int32))
304 |
305 | # Create target_weights to be 0 for targets that are padding.
306 | batch_weight = np.ones(self.batch_size, dtype=np.float32)
307 | for batch_idx in xrange(self.batch_size):
308 | # We set weight to 0 if the corresponding target is a PAD symbol.
309 | # The corresponding target is decoder_input shifted by 1 forward.
310 | if length_idx < decoder_size - 1:
311 | target = decoder_inputs[batch_idx][length_idx + 1]
312 | if length_idx == decoder_size - 1 or target == data_utils.PAD_ID:
313 | batch_weight[batch_idx] = 0.0
314 | batch_weights.append(batch_weight)
315 | return batch_encoder_inputs, batch_decoder_inputs, batch_weights
316 |
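A standalone illustration, with made-up ids, of the length-major to batch-major re-indexing that get_batch performs: each padded, reversed sentence becomes one numpy vector per time step.

    import numpy as np

    encoder_inputs = [[5, 4, 0], [7, 0, 0]]   # batch of 2, encoder_size 3, already padded and reversed
    batch_encoder_inputs = [np.array([encoder_inputs[b][t] for b in range(2)], dtype=np.int32)
                            for t in range(3)]
    # batch_encoder_inputs == [array([5, 7]), array([4, 0]), array([0, 0])]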
--------------------------------------------------------------------------------
/lib/data_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utilities for downloading data from WMT, tokenizing, vocabularies."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import sys, os, re, gzip, tarfile
22 |
23 | from six.moves import urllib
24 |
25 | from tensorflow.python.platform import gfile
26 | import tensorflow as tf
27 |
28 | # Special vocabulary symbols - we always put them at the start.
29 | _PAD = b"_PAD"
30 | _GO = b"_GO"
31 | _EOS = b"_EOS"
32 | _UNK = b"_UNK"
33 | _START_VOCAB = [_PAD, _GO, _EOS, _UNK]
34 |
35 | PAD_ID = 0
36 | GO_ID = 1
37 | EOS_ID = 2
38 | UNK_ID = 3
39 |
40 | # Regular expressions used to tokenize.
41 | #_WORD_SPLIT = re.compile(b"([.,!?\"':;,。!)(])")
42 | _WORD_SPLIT = re.compile(b"([.,!?\"':;)(])")
43 | # _DIGIT_RE = re.compile(br"\d{3,}")
44 | _DIGIT_RE = re.compile(br"\d")
45 |
46 | def get_dialog_train_set_path(path):
47 | return os.path.join(path, 'chat')
48 |
49 |
50 | # URLs for WMT data.
51 | _WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar"
52 | _WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz"
53 |
54 |
55 | def maybe_download(directory, filename, url):
56 | """Download filename from url unless it's already in directory."""
57 | if not os.path.exists(directory):
58 | print("Creating directory %s" % directory)
59 | os.mkdir(directory)
60 | filepath = os.path.join(directory, filename)
61 | if not os.path.exists(filepath):
62 | print("Downloading %s to %s" % (url, filepath))
63 | filepath, _ = urllib.request.urlretrieve(url, filepath)
64 | statinfo = os.stat(filepath)
65 | print("Successfully downloaded", filename, statinfo.st_size, "bytes")
66 | return filepath
67 |
68 |
69 | def gunzip_file(gz_path, new_path):
70 | """Unzips from gz_path into new_path."""
71 | print("Unpacking %s to %s" % (gz_path, new_path))
72 | with gzip.open(gz_path, "rb") as gz_file:
73 | with open(new_path, "wb") as new_file:
74 | for line in gz_file:
75 | new_file.write(line)
76 |
77 |
78 | def get_wmt_enfr_train_set(directory):
79 | """Download the WMT en-fr training corpus to directory unless it's there."""
80 | train_path = os.path.join(directory, "giga-fren.release2.fixed")
81 | if not (gfile.Exists(train_path +".fr") and gfile.Exists(train_path +".en")):
82 | corpus_file = maybe_download(directory, "training-giga-fren.tar",
83 | _WMT_ENFR_TRAIN_URL)
84 | print("Extracting tar file %s" % corpus_file)
85 | with tarfile.open(corpus_file, "r") as corpus_tar:
86 | corpus_tar.extractall(directory)
87 | gunzip_file(train_path + ".fr.gz", train_path + ".fr")
88 | gunzip_file(train_path + ".en.gz", train_path + ".en")
89 | return train_path
90 |
91 |
92 | def get_wmt_enfr_dev_set(directory):
93 |   """Download the WMT en-fr dev set to directory unless it's there."""
94 | dev_name = "newstest2013"
95 | dev_path = os.path.join(directory, dev_name)
96 | if not (gfile.Exists(dev_path + ".fr") and gfile.Exists(dev_path + ".en")):
97 | dev_file = maybe_download(directory, "dev-v2.tgz", _WMT_ENFR_DEV_URL)
98 | print("Extracting tgz file %s" % dev_file)
99 | with tarfile.open(dev_file, "r:gz") as dev_tar:
100 | fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr")
101 | en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en")
102 | fr_dev_file.name = dev_name + ".fr" # Extract without "dev/" prefix.
103 | en_dev_file.name = dev_name + ".en"
104 | dev_tar.extract(fr_dev_file, directory)
105 | dev_tar.extract(en_dev_file, directory)
106 | return dev_path
107 |
108 |
109 | def get_dialog_dev_set_path(path):
110 | return os.path.join(path, 'chat_test')
111 |
112 |
113 | def basic_tokenizer(sentence, en_jieba=False):
114 | """Very basic tokenizer: split the sentence into a list of tokens."""
115 |   if en_jieba:
116 |     import jieba  # lazy import: the third-party 'jieba' package is only needed for Chinese segmentation
117 |     return [w.lower() for w in jieba.cut(sentence) if w not in [' ']]
118 | else:
119 | words = []
120 | for space_separated_fragment in sentence.strip().split():
121 | if isinstance(space_separated_fragment, str):
122 | space_separated_fragment = space_separated_fragment.encode()
123 | words.extend(_WORD_SPLIT.split(space_separated_fragment))
124 | return [w.lower() for w in words if w]
125 |
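# Illustration (not in the original source): with en_jieba=False,
#   basic_tokenizer(b"Hello, world! 123")
# returns [b'hello', b',', b'world', b'!', b'123'] -- punctuation is split off and tokens
# are lower-cased; digit normalization to b"0" happens later, in sentence_to_token_ids.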
126 |
127 |
128 | def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size,
129 | tokenizer=None, normalize_digits=True):
130 | """Create vocabulary file (if it does not exist yet) from data file.
131 |
132 | Data file is assumed to contain one sentence per line. Each sentence is
133 | tokenized and digits are normalized (if normalize_digits is set).
134 | Vocabulary contains the most-frequent tokens up to max_vocabulary_size.
135 |   We write it to vocabulary_path in a one-token-per-line format, so that later
136 |   the token in the first line gets id=0, the token in the second line gets id=1, and so on.
137 |
138 | Args:
139 | vocabulary_path: path where the vocabulary will be created.
140 | data_path: data file that will be used to create vocabulary.
141 | max_vocabulary_size: limit on the size of the created vocabulary.
142 | tokenizer: a function to use to tokenize each data sentence;
143 | if None, basic_tokenizer will be used.
144 | normalize_digits: Boolean; if true, all digits are replaced by 0s.
145 | """
146 | if not gfile.Exists(vocabulary_path):
147 | print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path))
148 | vocab = {}
149 | with gfile.GFile(data_path, mode="rb") as f:
150 | counter = 0
151 | for line in f:
152 | # try:
153 | # line = line.decode('utf8', 'ignore')
154 | # except Exception as e:
155 | # print(e, line)
156 | # continue
157 | counter += 1
158 | if counter % 100000 == 0:
159 | print(" processing line %d" % counter)
160 | line = tf.compat.as_bytes(line)
161 | tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
162 | for w in tokens:
163 | word = _DIGIT_RE.sub(b"0", w) if normalize_digits else w
164 | if word in vocab:
165 | vocab[word] += 1
166 | else:
167 | vocab[word] = 1
168 | vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
169 | if len(vocab_list) > max_vocabulary_size:
170 | vocab_list = vocab_list[:max_vocabulary_size]
171 | with gfile.GFile(vocabulary_path, mode="wb") as vocab_file:
172 | for w in vocab_list:
173 | vocab_file.write(w + b"\n")
174 |
175 |
176 | def initialize_vocabulary(vocabulary_path):
177 | """Initialize vocabulary from file.
178 |
179 | We assume the vocabulary is stored one-item-per-line, so a file:
180 | dog
181 | cat
182 | will result in a vocabulary {"dog": 0, "cat": 1}, and this function will
183 | also return the reversed-vocabulary ["dog", "cat"].
184 |
185 | Args:
186 | vocabulary_path: path to the file containing the vocabulary.
187 |
188 | Returns:
189 | a pair: the vocabulary (a dictionary mapping string to integers), and
190 | the reversed vocabulary (a list, which reverses the vocabulary mapping).
191 |
192 | Raises:
193 | ValueError: if the provided vocabulary_path does not exist.
194 | """
195 | if gfile.Exists(vocabulary_path):
196 | rev_vocab = []
197 | with gfile.GFile(vocabulary_path, mode="rb") as f:
198 | rev_vocab.extend(f.readlines())
199 | rev_vocab = [tf.compat.as_bytes(line.strip()) for line in rev_vocab]
200 | vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])
201 | return vocab, rev_vocab
202 | else:
203 |     raise ValueError("Vocabulary file %s not found." % vocabulary_path)
204 |
205 |
206 | def sentence_to_token_ids(sentence, vocabulary,
207 | tokenizer=None, normalize_digits=True):
208 | """Convert a string to list of integers representing token-ids.
209 |
210 | For example, a sentence "I have a dog" may become tokenized into
211 | ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2,
212 |   "a": 4, "dog": 7} this function will return [1, 2, 4, 7].
213 |
214 | Args:
215 | sentence: the sentence in bytes format to convert to token-ids.
216 | vocabulary: a dictionary mapping tokens to integers.
217 | tokenizer: a function to use to tokenize each sentence;
218 | if None, basic_tokenizer will be used.
219 | normalize_digits: Boolean; if true, all digits are replaced by 0s.
220 |
221 | Returns:
222 | a list of integers, the token-ids for the sentence.
223 | """
224 |
225 | if tokenizer:
226 | words = tokenizer(sentence)
227 | else:
228 | words = basic_tokenizer(sentence)
229 | if not normalize_digits:
230 | return [vocabulary.get(w, UNK_ID) for w in words]
231 | # Normalize digits by 0 before looking words up in the vocabulary.
232 | return [vocabulary.get(_DIGIT_RE.sub(b"0", w), UNK_ID) for w in words]
233 |
234 |
235 | def data_to_token_ids(data_path, target_path, vocabulary_path,
236 | tokenizer=None, normalize_digits=True):
237 | """Tokenize data file and turn into token-ids using given vocabulary file.
238 |
239 | This function loads data line-by-line from data_path, calls the above
240 | sentence_to_token_ids, and saves the result to target_path. See comment
241 | for sentence_to_token_ids on the details of token-ids format.
242 |
243 | Args:
244 | data_path: path to the data file in one-sentence-per-line format.
245 | target_path: path where the file with token-ids will be created.
246 | vocabulary_path: path to the vocabulary file.
247 | tokenizer: a function to use to tokenize each sentence;
248 | if None, basic_tokenizer will be used.
249 | normalize_digits: Boolean; if true, all digits are replaced by 0s.
250 | """
251 | if not gfile.Exists(target_path):
252 | print("Tokenizing data in %s" % data_path)
253 | vocab, _ = initialize_vocabulary(vocabulary_path)
254 | with gfile.GFile(data_path, mode="rb") as data_file:
255 | with gfile.GFile(target_path, mode="w") as tokens_file:
256 | counter = 0
257 | for line in data_file:
258 | # try:
259 | # line = line.decode('utf8', 'ignore')
260 | # except Exception as e:
261 | # print(e, line)
262 | # continue
263 | counter += 1
264 | if counter % 100000 == 0:
265 | print(" tokenizing line %d" % counter)
266 | token_ids = sentence_to_token_ids(tf.compat.as_bytes(line), vocab,
267 | tokenizer, normalize_digits)
268 | tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")
269 |
270 |
271 | def prepare_wmt_data(data_dir, en_vocabulary_size, fr_vocabulary_size, tokenizer=None):
272 | """Get WMT data into data_dir, create vocabularies and tokenize data.
273 |
274 | Args:
275 | data_dir: directory in which the data sets will be stored.
276 | en_vocabulary_size: size of the English vocabulary to create and use.
277 | fr_vocabulary_size: size of the French vocabulary to create and use.
278 | tokenizer: a function to use to tokenize each data sentence;
279 | if None, basic_tokenizer will be used.
280 |
281 | Returns:
282 | A tuple of 6 elements:
283 | (1) path to the token-ids for English training data-set,
284 | (2) path to the token-ids for French training data-set,
285 | (3) path to the token-ids for English development data-set,
286 | (4) path to the token-ids for French development data-set,
287 | (5) path to the English vocabulary file,
288 | (6) path to the French vocabulary file.
289 | """
290 | # Get wmt data to the specified directory.
291 | train_path = get_wmt_enfr_train_set(data_dir)
292 | dev_path = get_wmt_enfr_dev_set(data_dir)
293 |
294 | from_train_path = train_path + ".en"
295 | to_train_path = train_path + ".fr"
296 | from_dev_path = dev_path + ".en"
297 | to_dev_path = dev_path + ".fr"
298 | return prepare_data(data_dir, from_train_path, to_train_path, from_dev_path, to_dev_path, en_vocabulary_size,
299 | fr_vocabulary_size, tokenizer)
300 |
301 |
302 | def prepare_data(data_dir, from_train_path, to_train_path, from_dev_path, to_dev_path, from_vocabulary_size,
303 | to_vocabulary_size, tokenizer=None):
304 |   """Prepare all necessary files that are required for the training.
305 |
306 | Args:
307 | data_dir: directory in which the data sets will be stored.
308 | from_train_path: path to the file that includes "from" training samples.
309 | to_train_path: path to the file that includes "to" training samples.
310 | from_dev_path: path to the file that includes "from" dev samples.
311 | to_dev_path: path to the file that includes "to" dev samples.
312 | from_vocabulary_size: size of the "from language" vocabulary to create and use.
313 | to_vocabulary_size: size of the "to language" vocabulary to create and use.
314 | tokenizer: a function to use to tokenize each data sentence;
315 | if None, basic_tokenizer will be used.
316 |
317 | Returns:
318 | A tuple of 6 elements:
319 | (1) path to the token-ids for "from language" training data-set,
320 | (2) path to the token-ids for "to language" training data-set,
321 | (3) path to the token-ids for "from language" development data-set,
322 | (4) path to the token-ids for "to language" development data-set,
323 | (5) path to the "from language" vocabulary file,
324 | (6) path to the "to language" vocabulary file.
325 | """
326 | # Create vocabularies of the appropriate sizes.
327 | to_vocab_path = os.path.join(data_dir, "vocab%d.to" % to_vocabulary_size)
328 | from_vocab_path = os.path.join(data_dir, "vocab%d.from" % from_vocabulary_size)
329 | create_vocabulary(to_vocab_path, to_train_path , to_vocabulary_size, tokenizer)
330 | create_vocabulary(from_vocab_path, from_train_path , from_vocabulary_size, tokenizer)
331 |
332 | # Create token ids for the training data.
333 | to_train_ids_path = to_train_path + (".ids%d" % to_vocabulary_size)
334 | from_train_ids_path = from_train_path + (".ids%d" % from_vocabulary_size)
335 | data_to_token_ids(to_train_path, to_train_ids_path, to_vocab_path, tokenizer)
336 | data_to_token_ids(from_train_path, from_train_ids_path, from_vocab_path, tokenizer)
337 |
338 | # Create token ids for the development data.
339 | to_dev_ids_path = to_dev_path + (".ids%d" % to_vocabulary_size)
340 | from_dev_ids_path = from_dev_path + (".ids%d" % from_vocabulary_size)
341 | data_to_token_ids(to_dev_path, to_dev_ids_path, to_vocab_path, tokenizer)
342 | data_to_token_ids(from_dev_path, from_dev_ids_path, from_vocab_path, tokenizer)
343 |
344 | return (from_train_ids_path, to_train_ids_path,
345 | from_dev_ids_path, to_dev_ids_path,
346 | from_vocab_path, to_vocab_path)
347 |
348 |
349 |
350 | def prepare_dialog_data(data_dir, vocabulary_size):
351 | """Get dialog data into data_dir, create vocabularies and tokenize data.
352 |
353 | Args:
354 | data_dir: directory in which the data sets will be stored.
355 |     vocabulary_size: size of the chat vocabulary to create and use.
356 |
357 | Returns:
358 | A tuple of 3 elements:
359 | (1) path to the token-ids for chat training data-set,
360 | (2) path to the token-ids for chat development data-set,
361 | (3) path to the chat vocabulary file
362 | """
363 | # Get dialog data to the specified directory.
364 | train_path = get_dialog_train_set_path(data_dir)
365 | dev_path = get_dialog_dev_set_path(data_dir)
366 |
367 | # Create vocabularies of the appropriate sizes.
368 | vocab_path = os.path.join(data_dir, "vocab%d.in" % vocabulary_size)
369 | create_vocabulary(vocab_path, train_path + ".in", vocabulary_size)
370 |
371 | # Create token ids for the training data.
372 | train_ids_path = train_path + (".ids%d.in" % vocabulary_size)
373 | data_to_token_ids(train_path + ".in", train_ids_path, vocab_path)
374 |
375 | # Create token ids for the development data.
376 | dev_ids_path = dev_path + (".ids%d.in" % vocabulary_size)
377 | data_to_token_ids(dev_path + ".in", dev_ids_path, vocab_path)
378 |
379 | return (train_ids_path, dev_ids_path, vocab_path)
380 |
381 |
382 | def read_data(tokenized_dialog_path, buckets, max_size=None, reversed=False):
383 | """Read data from source file and put into buckets.
384 |
385 | Args:
386 |     tokenized_dialog_path: path to the file with token-ids; question and answer lines alternate.
387 |     max_size: maximum number of lines to read; all others will be ignored;
388 |       if 0 or None, the data file will be read completely (no limit).
389 |
390 | Returns:
391 |     data_set: a list of length len(buckets); data_set[n] contains a list of
392 |       (source, target) pairs read from the provided data file that fit
393 |       into the n-th bucket, i.e., such that len(source) < buckets[n][0] and
394 |       len(target) < buckets[n][1]; source and target are lists of token-ids.
395 | """
396 | data_set = [[] for _ in buckets]
397 |
398 | with gfile.GFile(tokenized_dialog_path, mode="r") as fh:
399 | source, target = fh.readline(), fh.readline()
400 |     # When reversed=True, each (question, answer) pair is swapped inside the loop
401 |     # below, so the same corpus can train the backward (answer -> question) model.
402 | counter = 0
403 | while source and target and (not max_size or counter < max_size):
404 | counter += 1
405 | if counter % 100000 == 0:
406 | print(" reading data line %d" % counter)
407 | sys.stdout.flush()
408 |
409 |       src, tgt = (target, source) if reversed else (source, target)
410 |       source_ids, target_ids = [int(x) for x in src.split()], [int(x) for x in tgt.split()]
411 | target_ids.append(EOS_ID)
412 |
413 | for bucket_id, (source_size, target_size) in enumerate(buckets):
414 | if len(source_ids) < source_size and len(target_ids) < target_size:
415 | data_set[bucket_id].append([source_ids, target_ids])
416 | break
417 | source, target = fh.readline(), fh.readline()
418 | return data_set
419 |
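A usage sketch of the dialog helpers above; the data directory, vocabulary size, and buckets are assumptions, and the directory is expected to contain chat.in and chat_test.in with one sentence per line:

    from lib import data_utils

    buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
    train_ids_path, dev_ids_path, vocab_path = data_utils.prepare_dialog_data("./data", 20000)
    train_set = data_utils.read_data(train_ids_path, buckets)
    print([len(b) for b in train_set])   # number of (question, answer) pairs per bucket
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)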
--------------------------------------------------------------------------------
/lib/seq2seq_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Sequence-to-sequence model with an attention mechanism."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import random
23 | from math import log
24 | import numpy as np
25 | from six.moves import xrange # pylint: disable=redefined-builtin
26 | import tensorflow as tf
27 |
28 | from tensorflow.python.ops import control_flow_ops
29 | from lib import data_utils as data_utils
30 | from lib import seq2seq as tf_seq2seq
31 |
32 | class Seq2SeqModel(object):
33 | """Sequence-to-sequence model with attention and for multiple buckets.
34 |
35 | This class implements a multi-layer recurrent neural network as encoder,
36 | and an attention-based decoder. This is the same as the model described in
37 | this paper: http://arxiv.org/abs/1412.7449 - please look there for details,
38 | or into the seq2seq library for complete model implementation.
39 | This class also allows to use GRU cells in addition to LSTM cells, and
40 | sampled softmax to handle large output vocabulary size. A single-layer
41 | version of this model, but with bi-directional encoder, was presented in
42 | http://arxiv.org/abs/1409.0473
43 | and sampled softmax is described in Section 3 of the following paper.
44 | http://arxiv.org/abs/1412.2007
45 | """
46 |
47 | def __init__(self,
48 | source_vocab_size,
49 | target_vocab_size,
50 | buckets,
51 | size,
52 | num_layers,
53 | max_gradient_norm,
54 | batch_size,
55 | learning_rate,
56 | learning_rate_decay_factor,
57 | use_lstm=False,
58 | num_samples=512,
59 | forward_only=False,
60 | scope_name='seq2seq',
61 | dtype=tf.float32):
62 | """Create the model.
63 |
64 | Args:
65 | source_vocab_size: size of the source vocabulary.
66 | target_vocab_size: size of the target vocabulary.
67 | buckets: a list of pairs (I, O), where I specifies maximum input length
68 | that will be processed in that bucket, and O specifies maximum output
69 | length. Training instances that have inputs longer than I or outputs
70 | longer than O will be pushed to the next bucket and padded accordingly.
71 | We assume that the list is sorted, e.g., [(2, 4), (8, 16)].
72 | size: number of units in each layer of the model.
73 | num_layers: number of layers in the model.
74 | max_gradient_norm: gradients will be clipped to maximally this norm.
75 | batch_size: the size of the batches used during training;
76 | the model construction is independent of batch_size, so it can be
77 | changed after initialization if this is convenient, e.g., for decoding.
78 | learning_rate: learning rate to start with.
79 | learning_rate_decay_factor: decay learning rate by this much when needed.
80 | use_lstm: if true, we use LSTM cells instead of GRU cells.
81 | num_samples: number of samples for sampled softmax.
82 | forward_only: if set, we do not construct the backward pass in the model.
83 | dtype: the data type to use to store internal variables.
84 | """
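    # For example (hypothetical values): with buckets = [(5, 10), (10, 15)],
    # a pair whose source has 7 tokens and whose target has 12 tokens is
    # handled by the (10, 15) bucket and padded up to those lengths.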
85 | self.scope_name = scope_name
86 | with tf.variable_scope(self.scope_name):
87 | self.source_vocab_size = source_vocab_size
88 | self.target_vocab_size = target_vocab_size
89 | self.buckets = buckets
90 | self.batch_size = batch_size
91 | self.learning_rate = tf.Variable(
92 | float(learning_rate), trainable=False, dtype=dtype)
93 | self.learning_rate_decay_op = self.learning_rate.assign(
94 | self.learning_rate * learning_rate_decay_factor)
95 | self.global_step = tf.Variable(0, trainable=False)
96 | self.dummy_dialogs = [] # [TODO] load dummy sentences
97 |
98 | # If we use sampled softmax, we need an output projection.
99 | output_projection = None
100 | softmax_loss_function = None
101 | # Sampled softmax only makes sense if we sample less than vocabulary size.
102 | if num_samples > 0 and num_samples < self.target_vocab_size:
103 | w_t = tf.get_variable("proj_w", [self.target_vocab_size, size], dtype=dtype)
104 | w = tf.transpose(w_t)
105 | b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype)
106 | output_projection = (w, b)
107 |
108 | def sampled_loss(labels, inputs):
109 | labels = tf.reshape(labels, [-1, 1])
110 | # We need to compute the sampled_softmax_loss using 32bit floats to
111 | # avoid numerical instabilities.
112 | local_w_t = tf.cast(w_t, tf.float32)
113 | local_b = tf.cast(b, tf.float32)
114 | local_inputs = tf.cast(inputs, tf.float32)
115 | return tf.cast(
116 | tf.nn.sampled_softmax_loss(
117 | weights=local_w_t,
118 | biases=local_b,
119 | labels=labels,
120 | inputs=local_inputs,
121 | num_sampled=num_samples,
122 | num_classes=self.target_vocab_size),
123 | dtype)
124 | softmax_loss_function = sampled_loss
125 |
126 | # Create the internal multi-layer cell for our RNN.
127 | def single_cell():
128 | return tf.contrib.rnn.GRUCell(size)
129 | if use_lstm:
130 | def single_cell():
131 | return tf.contrib.rnn.BasicLSTMCell(size)
132 | cell = single_cell()
133 | if num_layers > 1:
134 | cell = tf.contrib.rnn.MultiRNNCell([single_cell() for _ in range(num_layers)])
135 |
136 | # The seq2seq function: we use embedding for the input and attention.
137 | def seq2seq_f(encoder_inputs, decoder_inputs, feed_previous):
138 | return tf_seq2seq.embedding_attention_seq2seq(
139 | encoder_inputs,
140 | decoder_inputs,
141 | cell,
142 | num_encoder_symbols=source_vocab_size,
143 | num_decoder_symbols=target_vocab_size,
144 | embedding_size=size,
145 | output_projection=output_projection,
146 |           feed_previous=feed_previous,  # do_decode
147 | dtype=dtype)
148 |
149 | # Feeds for inputs.
150 | self.encoder_inputs = []
151 | self.decoder_inputs = []
152 | self.target_weights = []
153 | for i in xrange(buckets[-1][0]): # Last bucket is the biggest one.
154 | self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
155 | name="encoder{0}".format(i)))
156 | for i in xrange(buckets[-1][1] + 1):
157 | self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
158 | name="decoder{0}".format(i)))
159 | self.target_weights.append(tf.placeholder(dtype, shape=[None],
160 | name="weight{0}".format(i)))
161 |
162 | # Our targets are decoder inputs shifted by one.
163 | targets = [self.decoder_inputs[i + 1]
164 | for i in xrange(len(self.decoder_inputs) - 1)]
165 |
166 | # for reinforcement learning
167 | # self.force_dec_input = tf.placeholder(tf.bool, name="force_dec_input")
168 | # self.en_output_proj = tf.placeholder(tf.bool, name="en_output_proj")
169 |
170 | # Training outputs and losses.
171 | if forward_only:
172 | self.outputs, self.losses, self.encoder_state = tf_seq2seq.model_with_buckets(
173 | self.encoder_inputs, self.decoder_inputs, targets,
174 | self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, True),
175 | softmax_loss_function=softmax_loss_function)
176 | # If we use output projection, we need to project outputs for decoding.
177 | if output_projection is not None:
178 | for b in xrange(len(buckets)):
179 | self.outputs[b] = [
180 | tf.matmul(output, output_projection[0]) + output_projection[1]
181 | for output in self.outputs[b]
182 | ]
183 | else:
184 | self.outputs, self.losses, self.encoder_state = tf_seq2seq.model_with_buckets(
185 | self.encoder_inputs, self.decoder_inputs, targets,
186 | self.target_weights, buckets,
187 | lambda x, y: seq2seq_f(x, y, False),
188 | softmax_loss_function=softmax_loss_function)
189 |
190 | # # Training outputs and losses.
191 | # self.outputs, self.losses, self.encoder_state = tf_seq2seq.model_with_buckets(
192 | # self.encoder_inputs, self.decoder_inputs, targets,
193 | # self.target_weights, buckets,
194 | # lambda x, y: seq2seq_f(x, y, tf.where(self.force_dec_input, False, True)),
195 | # softmax_loss_function=softmax_loss_function
196 | # )
197 | # # If we use output projection, we need to project outputs for decoding.
198 | # # if output_projection is not None:
199 | # for b in xrange(len(buckets)):
200 | # self.outputs[b] = [
201 | # control_flow_ops.cond(
202 | # self.en_output_proj,
203 | # lambda: tf.matmul(output, output_projection[0]) + output_projection[1],
204 | # lambda: output
205 | # )
206 | # for output in self.outputs[b]
207 | # ]
208 |
209 | # Gradients and SGD update operation for training the model.
210 | params = tf.trainable_variables()
211 | # if not forward_only:
212 | self.gradient_norms = []
213 | self.updates = []
214 | self.advantage = [tf.placeholder(tf.float32, name="advantage_%i" % i) for i in xrange(len(buckets))]
215 | opt = tf.train.GradientDescentOptimizer(self.learning_rate)
216 | for b in xrange(len(buckets)):
217 | # self.losses[b] = tf.subtract(self.losses[b], self.advantage[b])
218 | gradients = tf.gradients(self.losses[b], params)
219 | clipped_gradients, norm = tf.clip_by_global_norm(gradients,
220 | max_gradient_norm)
221 | self.gradient_norms.append(norm)
222 | self.updates.append(opt.apply_gradients(
223 | zip(clipped_gradients, params), global_step=self.global_step))
224 |
225 |       # Save/restore only the variables created under this model's scope.
226 |       all_variables = [k for k in tf.global_variables() if k.name.startswith(self.scope_name)]
227 | self.saver = tf.train.Saver(all_variables)
228 |
229 |
230 | def step(self, session, encoder_inputs, decoder_inputs, target_weights,
231 | bucket_id, forward_only, force_dec_input=False, advantage=None):
232 |
233 | """Run a step of the model feeding the given inputs.
234 |
235 | Args:
236 | session: tensorflow session to use.
237 | encoder_inputs: list of numpy int vectors to feed as encoder inputs.
238 | decoder_inputs: list of numpy int vectors to feed as decoder inputs.
239 | target_weights: list of numpy float vectors to feed as target weights.
240 | bucket_id: which bucket of the model to use.
241 | forward_only: whether to do the backward step or only forward.
242 |
243 | Returns:
244 | A triple consisting of gradient norm (or None if we did not do backward),
245 | average perplexity, and the outputs.
246 |
247 | Raises:
248 | ValueError: if length of encoder_inputs, decoder_inputs, or
249 | target_weights disagrees with bucket size for the specified bucket_id.
250 | """
251 | # Check if the sizes match.
252 | encoder_size, decoder_size = self.buckets[bucket_id]
253 | if len(encoder_inputs) != encoder_size:
254 | raise ValueError("Encoder length must be equal to the one in bucket,"
255 | " %d != %d." % (len(encoder_inputs), encoder_size))
256 | if len(decoder_inputs) != decoder_size:
257 | raise ValueError("Decoder length must be equal to the one in bucket,"
258 | " %d != %d." % (len(decoder_inputs), decoder_size))
259 | if len(target_weights) != decoder_size:
260 | raise ValueError("Weights length must be equal to the one in bucket,"
261 | " %d != %d." % (len(target_weights), decoder_size))
262 |
263 | # Input feed: encoder inputs, decoder inputs, target_weights, as provided.
264 | input_feed = {
265 | # self.force_dec_input.name: force_dec_input,
266 | # self.en_output_proj.name: forward_only,
267 | }
268 | for l in xrange(len(self.buckets)):
269 | input_feed[self.advantage[l].name] = advantage[l] if advantage else 0
270 | for l in xrange(encoder_size):
271 | input_feed[self.encoder_inputs[l].name] = encoder_inputs[l]
272 | for l in xrange(decoder_size):
273 | input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
274 | input_feed[self.target_weights[l].name] = target_weights[l]
275 |
276 | # Since our targets are decoder inputs shifted by one, we need one more.
277 | last_target = self.decoder_inputs[decoder_size].name
278 | input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32)
279 |
280 |
281 | # Output feed: depends on whether we do a backward step or not.
282 | if not forward_only:
283 | output_feed = [self.updates[bucket_id], # Update Op that does SGD.
284 | self.gradient_norms[bucket_id], # Gradient norm.
285 | self.losses[bucket_id]] # Loss for this batch.
286 | else:
287 | output_feed = [self.encoder_state[bucket_id],
288 | self.losses[bucket_id]] # Loss for this batch.
289 | for l in xrange(decoder_size): # Output logits.
290 | output_feed.append(self.outputs[bucket_id][l])
291 |
292 | outputs = session.run(output_feed, input_feed)
293 | if not forward_only:
294 | return outputs[1], outputs[2], None # Gradient norm, loss, no outputs.
295 | else:
296 | return outputs[0], outputs[1], outputs[2:] # encoder_state, loss, outputs.
297 |
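  # Usage sketch (illustrative; assumes a live session and a batch produced by
  # get_batch for the chosen bucket_id):
  #
  #   enc, dec, w = model.get_batch(data_set, bucket_id)
  #   # training step -> (gradient norm, loss, None)
  #   norm, loss, _ = model.step(sess, enc, dec, w, bucket_id, forward_only=False)
  #   # decode-only step -> (encoder_state, loss, output logits per time-step)
  #   _, loss, logits = model.step(sess, enc, dec, w, bucket_id, forward_only=True)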
298 |
299 | # # Output feed: depends on whether we do a backward step or not.
300 | # if training: # normal training
301 | # output_feed = [self.updates[bucket_id], # Update Op that does SGD.
302 | # self.gradient_norms[bucket_id], # Gradient norm.
303 | # self.losses[bucket_id]] # Loss for this batch.
304 | # else: # testing or reinforcement learning
305 | # output_feed = [self.encoder_state[bucket_id], self.losses[bucket_id]] # Loss for this batch.
306 | # for l in xrange(decoder_size): # Output logits.
307 | # output_feed.append(self.outputs[bucket_id][l])
308 |
309 | # outputs = session.run(output_feed, input_feed)
310 | # if training:
311 | # return outputs[1], outputs[2], None # Gradient norm, loss, no outputs.
312 | # else:
313 | # return outputs[0], outputs[1], outputs[2:] # encoder_state, loss, outputs.
314 |
315 |
316 |
317 | def step_rf(self, args, session, encoder_inputs, decoder_inputs, target_weights,
318 | bucket_id, rev_vocab=None, debug=True):
319 |
320 | # initialize
321 | init_inputs = [encoder_inputs, decoder_inputs, target_weights, bucket_id]
322 | sent_max_length = args.buckets[-1][0]
323 | resp_tokens, resp_txt = self.logits2tokens(encoder_inputs, rev_vocab, sent_max_length, reverse=True)
324 | if debug: print("[INPUT]:", resp_txt)
325 |
326 | # Initialize
327 | ep_rewards, ep_step_loss, enc_states = [], [], []
328 | ep_encoder_inputs, ep_target_weights, ep_bucket_id = [], [], []
329 |
330 | # [Episode] per episode = n steps, until break
331 | while True:
332 | #----[Step]----------------------------------------
333 |       encoder_state, step_loss, output_logits = self.step(session, encoder_inputs, decoder_inputs, target_weights,
334 |                                   bucket_id, forward_only=True, force_dec_input=False)
335 |
336 | # memorize inputs for reproducing curriculum with adjusted losses
337 | ep_encoder_inputs.append(encoder_inputs)
338 | ep_target_weights.append(target_weights)
339 | ep_bucket_id.append(bucket_id)
340 | ep_step_loss.append(step_loss)
341 | enc_states_vec = np.reshape(np.squeeze(encoder_state, axis=1), (-1))
342 | enc_states.append(enc_states_vec)
343 |
344 | # process response
345 | resp_tokens, resp_txt = self.logits2tokens(output_logits, rev_vocab, sent_max_length)
346 | if debug: print("[RESP]: (%.4f) %s" % (step_loss, resp_txt))
347 |
348 | # prepare for next dialogue
349 | bucket_id = min([b for b in range(len(args.buckets)) if args.buckets[b][0] > len(resp_tokens)])
350 | feed_data = {bucket_id: [(resp_tokens, [])]}
351 | encoder_inputs, decoder_inputs, target_weights = self.get_batch(feed_data, bucket_id)
352 |
353 | #----[Reward]----------------------------------------
354 | # r1: Ease of answering
355 | r1 = [self.logProb(session, args.buckets, resp_tokens, d) for d in self.dummy_dialogs]
356 | r1 = -np.mean(r1) if r1 else 0
357 |
358 | # r2: Information Flow
359 | if len(enc_states) < 2:
360 | r2 = 0
361 | else:
362 | vec_a, vec_b = enc_states[-2], enc_states[-1]
363 | r2 = sum(vec_a*vec_b) / sum(abs(vec_a)*abs(vec_b))
364 | r2 = -log(r2)
365 |
366 | # r3: Semantic Coherence
367 | r3 = -self.logProb(session, args.buckets, resp_tokens, ep_encoder_inputs[-1])
368 |
369 | # Episode total reward
370 | R = 0.25*r1 + 0.25*r2 + 0.5*r3
371 |       ep_rewards.append(R)
372 | #----------------------------------------------------
373 | if (resp_txt in self.dummy_dialogs) or (len(resp_tokens) <= 3) or (encoder_inputs in ep_encoder_inputs):
374 | break # check if dialog ended
375 |
376 |     # gradient descent according to batch rewards
377 |     rto = (max(ep_step_loss) - min(ep_step_loss)) / (max(ep_rewards) - min(ep_rewards))
378 |     advantage = [np.mean(ep_rewards)*rto] * len(args.buckets)
379 |     _, step_loss, _ = self.step(session, init_inputs[0], init_inputs[1], init_inputs[2], init_inputs[3],
380 |                       forward_only=False, force_dec_input=False, advantage=advantage)
381 |
382 | return None, step_loss, None
383 |
384 |
385 |
386 |   # log(P(b|a)), the conditional likelihood
387 | def logProb(self, session, buckets, tokens_a, tokens_b):
388 | def softmax(x):
389 | return np.exp(x) / np.sum(np.exp(x), axis=0)
390 |
391 | # prepare for next dialogue
392 | bucket_id = min([b for b in range(len(buckets)) if buckets[b][0] > len(tokens_a)])
393 | feed_data = {bucket_id: [(tokens_a, tokens_b)]}
394 | encoder_inputs, decoder_inputs, target_weights = self.get_batch(feed_data, bucket_id)
395 |
396 | # step
397 |     _, _, output_logits = self.step(session, encoder_inputs, decoder_inputs, target_weights,
398 |                         bucket_id, forward_only=True, force_dec_input=True)
399 |
400 | # p = log(P(b|a)) / N
401 | p = 1
402 | for t, logit in zip(tokens_b, output_logits):
403 | p *= softmax(logit[0])[t]
404 | p = log(p) / len(tokens_b)
405 | return p
406 |
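  # Note: logProb returns the length-normalised log-likelihood
  #   log P(b|a) / N = (1/N) * sum_t log softmax(logit_t)[b_t],
  # computed above by multiplying per-token probabilities and taking one log at
  # the end; accumulating log-probabilities instead would avoid underflow on
  # long responses.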
407 |
408 | def logits2tokens(self, logits, rev_vocab, sent_max_length=None, reverse=False):
409 | if reverse:
410 | tokens = [t[0] for t in reversed(logits)]
411 | else:
412 | tokens = [int(np.argmax(t, axis=1)) for t in logits]
413 | if data_utils.EOS_ID in tokens:
414 | eos = tokens.index(data_utils.EOS_ID)
415 | tokens = tokens[:eos]
416 | txt = [rev_vocab[t] for t in tokens]
417 | if sent_max_length:
418 | tokens, txt = tokens[:sent_max_length], txt[:sent_max_length]
419 | return tokens, txt
420 |
421 |
422 | def discount_rewards(self, r, gamma=0.99):
423 | """ take 1D float array of rewards and compute discounted reward """
424 | discounted_r = np.zeros_like(r)
425 | running_add = 0
426 | for t in reversed(xrange(0, r.size)):
427 | running_add = running_add * gamma + r[t]
428 | discounted_r[t] = running_add
429 | return discounted_r
430 |
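  # Worked example: with gamma = 0.99 and r = np.array([0., 0., 1.]),
  # discount_rewards returns [0.9801, 0.99, 1.0] -- each entry is the
  # discounted sum of all rewards from that step onward.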
431 |
432 | def get_batch(self, data, bucket_id):
433 | """Get a random batch of data from the specified bucket, prepare for step.
434 |
435 | To feed data in step(..) it must be a list of batch-major vectors, while
436 | data here contains single length-major cases. So the main logic of this
437 | function is to re-index data cases to be in the proper format for feeding.
438 |
439 | Args:
440 | data: a tuple of size len(self.buckets) in which each element contains
441 | lists of pairs of input and output data that we use to create a batch.
442 | bucket_id: integer, which bucket to get the batch for.
443 |
444 | Returns:
445 | The triple (encoder_inputs, decoder_inputs, target_weights) for
446 | the constructed batch that has the proper format to call step(...) later.
447 | """
448 | encoder_size, decoder_size = self.buckets[bucket_id]
449 | encoder_inputs, decoder_inputs = [], []
450 |
451 | # Get a random batch of encoder and decoder inputs from data,
452 | # pad them if needed, reverse encoder inputs and add GO to decoder.
453 | for _ in xrange(self.batch_size):
454 | encoder_input, decoder_input = random.choice(data[bucket_id])
455 |
456 | # Encoder inputs are padded and then reversed.
457 | encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(encoder_input))
458 | encoder_inputs.append(list(reversed(encoder_input + encoder_pad)))
459 |
460 | # Decoder inputs get an extra "GO" symbol, and are padded then.
461 | decoder_pad_size = decoder_size - len(decoder_input) - 1
462 | decoder_inputs.append([data_utils.GO_ID] + decoder_input +
463 | [data_utils.PAD_ID] * decoder_pad_size)
464 |
465 | # Now we create batch-major vectors from the data selected above.
466 | batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], []
467 |
468 | # Batch encoder inputs are just re-indexed encoder_inputs.
469 | for length_idx in xrange(encoder_size):
470 | batch_encoder_inputs.append(
471 | np.array([encoder_inputs[batch_idx][length_idx]
472 | for batch_idx in xrange(self.batch_size)], dtype=np.int32))
473 |
474 | # Batch decoder inputs are re-indexed decoder_inputs, we create weights.
475 | for length_idx in xrange(decoder_size):
476 | batch_decoder_inputs.append(
477 | np.array([decoder_inputs[batch_idx][length_idx]
478 | for batch_idx in xrange(self.batch_size)], dtype=np.int32))
479 |
480 | # Create target_weights to be 0 for targets that are padding.
481 | batch_weight = np.ones(self.batch_size, dtype=np.float32)
482 | for batch_idx in xrange(self.batch_size):
483 | # We set weight to 0 if the corresponding target is a PAD symbol.
484 | # The corresponding target is decoder_input shifted by 1 forward.
485 | if length_idx < decoder_size - 1:
486 | target = decoder_inputs[batch_idx][length_idx + 1]
487 | if length_idx == decoder_size - 1 or target == data_utils.PAD_ID:
488 | batch_weight[batch_idx] = 0.0
489 | batch_weights.append(batch_weight)
490 | return batch_encoder_inputs, batch_decoder_inputs, batch_weights
491 |
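  # Shape sketch (hypothetical single pair; batch_size copies drawn at random):
  #
  #   feed_data = {0: [([3, 7, 12], [4, 9])]}        # one pair for bucket 0
  #   enc, dec, w = model.get_batch(feed_data, bucket_id=0)
  #   # enc: buckets[0][0] arrays of shape (batch_size,), reversed source + PAD
  #   # dec: buckets[0][1] arrays of shape (batch_size,), GO + target + PAD
  #   # w:   buckets[0][1] arrays of shape (batch_size,), 0.0 wherever the
  #   #      shifted target is padding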
--------------------------------------------------------------------------------
/ref/seq2seq.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Library for creating sequence-to-sequence models in TensorFlow.
16 |
17 | Sequence-to-sequence recurrent neural networks can learn complex functions
18 | that map input sequences to output sequences. These models yield very good
19 | results on a number of tasks, such as speech recognition, parsing, machine
20 | translation, or even constructing automated replies to emails.
21 |
22 | Before using this module, it is recommended to read the TensorFlow tutorial
23 | on sequence-to-sequence models. It explains the basic concepts of this module
24 | and shows an end-to-end example of how to build a translation model.
25 | https://www.tensorflow.org/versions/master/tutorials/seq2seq/index.html
26 |
27 | Here is an overview of functions available in this module. They all use
28 | a very similar interface, so after reading the above tutorial and using
29 | one of them, others should be easy to substitute.
30 |
31 | * Full sequence-to-sequence models.
32 | - basic_rnn_seq2seq: The most basic RNN-RNN model.
33 | - tied_rnn_seq2seq: The basic model with tied encoder and decoder weights.
34 | - embedding_rnn_seq2seq: The basic model with input embedding.
35 | - embedding_tied_rnn_seq2seq: The tied model with input embedding.
36 | - embedding_attention_seq2seq: Advanced model with input embedding and
37 | the neural attention mechanism; recommended for complex tasks.
38 |
39 | * Multi-task sequence-to-sequence models.
40 | - one2many_rnn_seq2seq: The embedding model with multiple decoders.
41 |
42 | * Decoders (when you write your own encoder, you can use these to decode;
43 | e.g., if you want to write a model that generates captions for images).
44 | - rnn_decoder: The basic decoder based on a pure RNN.
45 | - attention_decoder: A decoder that uses the attention mechanism.
46 |
47 | * Losses.
48 | - sequence_loss: Loss for a sequence model returning average log-perplexity.
49 | - sequence_loss_by_example: As above, but not averaging over all examples.
50 |
51 | * model_with_buckets: A convenience function to create models with bucketing
52 | (see the tutorial above for an explanation of why and how to use it).
53 | """
54 |
55 | from __future__ import absolute_import
56 | from __future__ import division
57 | from __future__ import print_function
58 |
59 | import copy
60 |
61 | # We disable pylint because we need python3 compatibility.
62 | from six.moves import xrange # pylint: disable=redefined-builtin
63 | from six.moves import zip # pylint: disable=redefined-builtin
64 |
65 | from tensorflow.contrib.rnn.python.ops import core_rnn
66 | from tensorflow.contrib.rnn.python.ops import core_rnn_cell
67 | from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
68 | from tensorflow.python.framework import dtypes
69 | from tensorflow.python.framework import ops
70 | from tensorflow.python.ops import array_ops
71 | from tensorflow.python.ops import control_flow_ops
72 | from tensorflow.python.ops import embedding_ops
73 | from tensorflow.python.ops import math_ops
74 | from tensorflow.python.ops import nn_ops
75 | from tensorflow.python.ops import variable_scope
76 | from tensorflow.python.util import nest
77 |
78 | # TODO(ebrevdo): Remove once _linear is fully deprecated.
79 | linear = core_rnn_cell_impl._linear # pylint: disable=protected-access
80 |
81 |
82 | def _extract_argmax_and_embed(embedding,
83 | output_projection=None,
84 | update_embedding=True):
85 | """Get a loop_function that extracts the previous symbol and embeds it.
86 |
87 | Args:
88 | embedding: embedding tensor for symbols.
89 | output_projection: None or a pair (W, B). If provided, each fed previous
90 | output will first be multiplied by W and added B.
91 | update_embedding: Boolean; if False, the gradients will not propagate
92 | through the embeddings.
93 |
94 | Returns:
95 | A loop function.
96 | """
97 |
98 | def loop_function(prev, _):
99 | if output_projection is not None:
100 | prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1])
101 | prev_symbol = math_ops.argmax(prev, 1)
102 | # Note that gradients will not propagate through the second parameter of
103 | # embedding_lookup.
104 | emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
105 | if not update_embedding:
106 | emb_prev = array_ops.stop_gradient(emb_prev)
107 | return emb_prev
108 |
109 | return loop_function
110 |
111 |
112 | def rnn_decoder(decoder_inputs,
113 | initial_state,
114 | cell,
115 | loop_function=None,
116 | scope=None):
117 | """RNN decoder for the sequence-to-sequence model.
118 |
119 | Args:
120 | decoder_inputs: A list of 2D Tensors [batch_size x input_size].
121 | initial_state: 2D Tensor with shape [batch_size x cell.state_size].
122 | cell: core_rnn_cell.RNNCell defining the cell function and size.
123 | loop_function: If not None, this function will be applied to the i-th output
124 | in order to generate the i+1-st input, and decoder_inputs will be ignored,
125 | except for the first element ("GO" symbol). This can be used for decoding,
126 | but also for training to emulate http://arxiv.org/abs/1506.03099.
127 | Signature -- loop_function(prev, i) = next
128 | * prev is a 2D Tensor of shape [batch_size x output_size],
129 | * i is an integer, the step number (when advanced control is needed),
130 | * next is a 2D Tensor of shape [batch_size x input_size].
131 | scope: VariableScope for the created subgraph; defaults to "rnn_decoder".
132 |
133 | Returns:
134 | A tuple of the form (outputs, state), where:
135 | outputs: A list of the same length as decoder_inputs of 2D Tensors with
136 | shape [batch_size x output_size] containing generated outputs.
137 | state: The state of each cell at the final time-step.
138 | It is a 2D Tensor of shape [batch_size x cell.state_size].
139 | (Note that in some cases, like basic RNN cell or GRU cell, outputs and
140 | states can be the same. They are different for LSTM cells though.)
141 | """
142 | with variable_scope.variable_scope(scope or "rnn_decoder"):
143 | state = initial_state
144 | outputs = []
145 | prev = None
146 | for i, inp in enumerate(decoder_inputs):
147 | if loop_function is not None and prev is not None:
148 | with variable_scope.variable_scope("loop_function", reuse=True):
149 | inp = loop_function(prev, i)
150 | if i > 0:
151 | variable_scope.get_variable_scope().reuse_variables()
152 | output, state = cell(inp, state)
153 | outputs.append(output)
154 | if loop_function is not None:
155 | prev = output
156 | return outputs, state
157 |
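# Decoding sketch (illustrative): when loop_function is supplied, only the first
# decoder input (the "GO" symbol) is consumed and every later input is derived
# from the previous output:
#
#   step 0: inp = decoder_inputs[0]
#   step i: inp = loop_function(prev_output, i)   # e.g. embed argmax of prev_output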
158 |
159 | def basic_rnn_seq2seq(encoder_inputs,
160 | decoder_inputs,
161 | cell,
162 | dtype=dtypes.float32,
163 | scope=None):
164 | """Basic RNN sequence-to-sequence model.
165 |
166 | This model first runs an RNN to encode encoder_inputs into a state vector,
167 | then runs decoder, initialized with the last encoder state, on decoder_inputs.
168 | Encoder and decoder use the same RNN cell type, but don't share parameters.
169 |
170 | Args:
171 | encoder_inputs: A list of 2D Tensors [batch_size x input_size].
172 | decoder_inputs: A list of 2D Tensors [batch_size x input_size].
173 | cell: core_rnn_cell.RNNCell defining the cell function and size.
174 | dtype: The dtype of the initial state of the RNN cell (default: tf.float32).
175 | scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq".
176 |
177 | Returns:
178 | A tuple of the form (outputs, state), where:
179 | outputs: A list of the same length as decoder_inputs of 2D Tensors with
180 | shape [batch_size x output_size] containing the generated outputs.
181 | state: The state of each decoder cell in the final time-step.
182 | It is a 2D Tensor of shape [batch_size x cell.state_size].
183 | """
184 | with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
185 | enc_cell = copy.deepcopy(cell)
186 | _, enc_state = core_rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
187 | return rnn_decoder(decoder_inputs, enc_state, cell)
188 |
189 |
190 | def tied_rnn_seq2seq(encoder_inputs,
191 | decoder_inputs,
192 | cell,
193 | loop_function=None,
194 | dtype=dtypes.float32,
195 | scope=None):
196 | """RNN sequence-to-sequence model with tied encoder and decoder parameters.
197 |
198 | This model first runs an RNN to encode encoder_inputs into a state vector, and
199 | then runs decoder, initialized with the last encoder state, on decoder_inputs.
200 | Encoder and decoder use the same RNN cell and share parameters.
201 |
202 | Args:
203 | encoder_inputs: A list of 2D Tensors [batch_size x input_size].
204 | decoder_inputs: A list of 2D Tensors [batch_size x input_size].
205 | cell: core_rnn_cell.RNNCell defining the cell function and size.
206 | loop_function: If not None, this function will be applied to i-th output
207 | in order to generate i+1-th input, and decoder_inputs will be ignored,
208 | except for the first element ("GO" symbol), see rnn_decoder for details.
209 | dtype: The dtype of the initial state of the rnn cell (default: tf.float32).
210 | scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq".
211 |
212 | Returns:
213 | A tuple of the form (outputs, state), where:
214 | outputs: A list of the same length as decoder_inputs of 2D Tensors with
215 | shape [batch_size x output_size] containing the generated outputs.
216 | state: The state of each decoder cell in each time-step. This is a list
217 | with length len(decoder_inputs) -- one item for each time-step.
218 | It is a 2D Tensor of shape [batch_size x cell.state_size].
219 | """
220 | with variable_scope.variable_scope("combined_tied_rnn_seq2seq"):
221 | scope = scope or "tied_rnn_seq2seq"
222 | _, enc_state = core_rnn.static_rnn(
223 | cell, encoder_inputs, dtype=dtype, scope=scope)
224 | variable_scope.get_variable_scope().reuse_variables()
225 | return rnn_decoder(
226 | decoder_inputs,
227 | enc_state,
228 | cell,
229 | loop_function=loop_function,
230 | scope=scope)
231 |
232 |
233 | def embedding_rnn_decoder(decoder_inputs,
234 | initial_state,
235 | cell,
236 | num_symbols,
237 | embedding_size,
238 | output_projection=None,
239 | feed_previous=False,
240 | update_embedding_for_previous=True,
241 | scope=None):
242 | """RNN decoder with embedding and a pure-decoding option.
243 |
244 | Args:
245 | decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
246 | initial_state: 2D Tensor [batch_size x cell.state_size].
247 | cell: core_rnn_cell.RNNCell defining the cell function.
248 | num_symbols: Integer, how many symbols come into the embedding.
249 | embedding_size: Integer, the length of the embedding vector for each symbol.
250 | output_projection: None or a pair (W, B) of output projection weights and
251 | biases; W has shape [output_size x num_symbols] and B has
252 | shape [num_symbols]; if provided and feed_previous=True, each fed
253 | previous output will first be multiplied by W and added B.
254 | feed_previous: Boolean; if True, only the first of decoder_inputs will be
255 | used (the "GO" symbol), and all other decoder inputs will be generated by:
256 | next = embedding_lookup(embedding, argmax(previous_output)),
257 | In effect, this implements a greedy decoder. It can also be used
258 | during training to emulate http://arxiv.org/abs/1506.03099.
259 | If False, decoder_inputs are used as given (the standard decoder case).
260 | update_embedding_for_previous: Boolean; if False and feed_previous=True,
261 | only the embedding for the first symbol of decoder_inputs (the "GO"
262 | symbol) will be updated by back propagation. Embeddings for the symbols
263 | generated from the decoder itself remain unchanged. This parameter has
264 | no effect if feed_previous=False.
265 | scope: VariableScope for the created subgraph; defaults to
266 | "embedding_rnn_decoder".
267 |
268 | Returns:
269 | A tuple of the form (outputs, state), where:
270 | outputs: A list of the same length as decoder_inputs of 2D Tensors. The
271 | output is of shape [batch_size x cell.output_size] when
272 | output_projection is not None (and represents the dense representation
273 | of predicted tokens). It is of shape [batch_size x num_decoder_symbols]
274 | when output_projection is None.
275 | state: The state of each decoder cell in each time-step. This is a list
276 | with length len(decoder_inputs) -- one item for each time-step.
277 | It is a 2D Tensor of shape [batch_size x cell.state_size].
278 |
279 | Raises:
280 | ValueError: When output_projection has the wrong shape.
281 | """
282 | with variable_scope.variable_scope(scope or "embedding_rnn_decoder") as scope:
283 | if output_projection is not None:
284 | dtype = scope.dtype
285 | proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype)
286 | proj_weights.get_shape().assert_is_compatible_with([None, num_symbols])
287 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
288 | proj_biases.get_shape().assert_is_compatible_with([num_symbols])
289 |
290 | embedding = variable_scope.get_variable("embedding",
291 | [num_symbols, embedding_size])
292 | loop_function = _extract_argmax_and_embed(
293 | embedding, output_projection,
294 | update_embedding_for_previous) if feed_previous else None
295 | emb_inp = (embedding_ops.embedding_lookup(embedding, i)
296 | for i in decoder_inputs)
297 | return rnn_decoder(
298 | emb_inp, initial_state, cell, loop_function=loop_function)
299 |
300 |
301 | def embedding_rnn_seq2seq(encoder_inputs,
302 | decoder_inputs,
303 | cell,
304 | num_encoder_symbols,
305 | num_decoder_symbols,
306 | embedding_size,
307 | output_projection=None,
308 | feed_previous=False,
309 | dtype=None,
310 | scope=None):
311 | """Embedding RNN sequence-to-sequence model.
312 |
313 | This model first embeds encoder_inputs by a newly created embedding (of shape
314 | [num_encoder_symbols x input_size]). Then it runs an RNN to encode
315 | embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs
316 | by another newly created embedding (of shape [num_decoder_symbols x
317 | input_size]). Then it runs RNN decoder, initialized with the last
318 | encoder state, on embedded decoder_inputs.
319 |
320 | Args:
321 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
322 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
323 | cell: core_rnn_cell.RNNCell defining the cell function and size.
324 | num_encoder_symbols: Integer; number of symbols on the encoder side.
325 | num_decoder_symbols: Integer; number of symbols on the decoder side.
326 | embedding_size: Integer, the length of the embedding vector for each symbol.
327 | output_projection: None or a pair (W, B) of output projection weights and
328 | biases; W has shape [output_size x num_decoder_symbols] and B has
329 | shape [num_decoder_symbols]; if provided and feed_previous=True, each
330 | fed previous output will first be multiplied by W and added B.
331 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
332 | of decoder_inputs will be used (the "GO" symbol), and all other decoder
333 | inputs will be taken from previous outputs (as in embedding_rnn_decoder).
334 | If False, decoder_inputs are used as given (the standard decoder case).
335 | dtype: The dtype of the initial state for both the encoder and encoder
336 | rnn cells (default: tf.float32).
337 | scope: VariableScope for the created subgraph; defaults to
338 | "embedding_rnn_seq2seq"
339 |
340 | Returns:
341 | A tuple of the form (outputs, state), where:
342 | outputs: A list of the same length as decoder_inputs of 2D Tensors. The
343 | output is of shape [batch_size x cell.output_size] when
344 | output_projection is not None (and represents the dense representation
345 | of predicted tokens). It is of shape [batch_size x num_decoder_symbols]
346 | when output_projection is None.
347 | state: The state of each decoder cell in each time-step. This is a list
348 | with length len(decoder_inputs) -- one item for each time-step.
349 | It is a 2D Tensor of shape [batch_size x cell.state_size].
350 | """
351 | with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq") as scope:
352 | if dtype is not None:
353 | scope.set_dtype(dtype)
354 | else:
355 | dtype = scope.dtype
356 |
357 | # Encoder.
358 | encoder_cell = copy.deepcopy(cell)
359 | encoder_cell = core_rnn_cell.EmbeddingWrapper(
360 | encoder_cell,
361 | embedding_classes=num_encoder_symbols,
362 | embedding_size=embedding_size)
363 | _, encoder_state = core_rnn.static_rnn(
364 | encoder_cell, encoder_inputs, dtype=dtype)
365 |
366 | # Decoder.
367 | if output_projection is None:
368 | cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
369 |
370 | if isinstance(feed_previous, bool):
371 | return embedding_rnn_decoder(
372 | decoder_inputs,
373 | encoder_state,
374 | cell,
375 | num_decoder_symbols,
376 | embedding_size,
377 | output_projection=output_projection,
378 | feed_previous=feed_previous)
379 |
380 | # If feed_previous is a Tensor, we construct 2 graphs and use cond.
381 | def decoder(feed_previous_bool):
382 | reuse = None if feed_previous_bool else True
383 | with variable_scope.variable_scope(
384 | variable_scope.get_variable_scope(), reuse=reuse) as scope:
385 | outputs, state = embedding_rnn_decoder(
386 | decoder_inputs,
387 | encoder_state,
388 | cell,
389 | num_decoder_symbols,
390 | embedding_size,
391 | output_projection=output_projection,
392 | feed_previous=feed_previous_bool,
393 | update_embedding_for_previous=False)
394 | state_list = [state]
395 | if nest.is_sequence(state):
396 | state_list = nest.flatten(state)
397 | return outputs + state_list
398 |
399 | outputs_and_state = control_flow_ops.cond(feed_previous,
400 | lambda: decoder(True),
401 | lambda: decoder(False))
402 | outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs.
403 | state_list = outputs_and_state[outputs_len:]
404 | state = state_list[0]
405 | if nest.is_sequence(encoder_state):
406 | state = nest.pack_sequence_as(
407 | structure=encoder_state, flat_sequence=state_list)
408 | return outputs_and_state[:outputs_len], state
409 |
410 |
411 | def embedding_tied_rnn_seq2seq(encoder_inputs,
412 | decoder_inputs,
413 | cell,
414 | num_symbols,
415 | embedding_size,
416 | num_decoder_symbols=None,
417 | output_projection=None,
418 | feed_previous=False,
419 | dtype=None,
420 | scope=None):
421 | """Embedding RNN sequence-to-sequence model with tied (shared) parameters.
422 |
423 | This model first embeds encoder_inputs by a newly created embedding (of shape
424 | [num_symbols x input_size]). Then it runs an RNN to encode embedded
425 | encoder_inputs into a state vector. Next, it embeds decoder_inputs using
426 | the same embedding. Then it runs RNN decoder, initialized with the last
427 | encoder state, on embedded decoder_inputs. The decoder output is over symbols
428 |   from 0 to num_symbols - 1 if num_decoder_symbols is None; otherwise it
429 |   is over 0 to num_decoder_symbols - 1.
430 |
431 | Args:
432 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
433 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
434 | cell: core_rnn_cell.RNNCell defining the cell function and size.
435 | num_symbols: Integer; number of symbols for both encoder and decoder.
436 | embedding_size: Integer, the length of the embedding vector for each symbol.
437 | num_decoder_symbols: Integer; number of output symbols for decoder. If
438 | provided, the decoder output is over symbols 0 to num_decoder_symbols - 1.
439 | Otherwise, decoder output is over symbols 0 to num_symbols - 1. Note that
440 | this assumes that the vocabulary is set up such that the first
441 | num_decoder_symbols of num_symbols are part of decoding.
442 | output_projection: None or a pair (W, B) of output projection weights and
443 | biases; W has shape [output_size x num_symbols] and B has
444 | shape [num_symbols]; if provided and feed_previous=True, each
445 | fed previous output will first be multiplied by W and added B.
446 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
447 | of decoder_inputs will be used (the "GO" symbol), and all other decoder
448 | inputs will be taken from previous outputs (as in embedding_rnn_decoder).
449 | If False, decoder_inputs are used as given (the standard decoder case).
450 | dtype: The dtype to use for the initial RNN states (default: tf.float32).
451 | scope: VariableScope for the created subgraph; defaults to
452 | "embedding_tied_rnn_seq2seq".
453 |
454 | Returns:
455 | A tuple of the form (outputs, state), where:
456 | outputs: A list of the same length as decoder_inputs of 2D Tensors with
457 | shape [batch_size x output_symbols] containing the generated
458 | outputs where output_symbols = num_decoder_symbols if
459 | num_decoder_symbols is not None otherwise output_symbols = num_symbols.
460 | state: The state of each decoder cell at the final time-step.
461 | It is a 2D Tensor of shape [batch_size x cell.state_size].
462 |
463 | Raises:
464 | ValueError: When output_projection has the wrong shape.
465 | """
466 | with variable_scope.variable_scope(
467 | scope or "embedding_tied_rnn_seq2seq", dtype=dtype) as scope:
468 | dtype = scope.dtype
469 |
470 | if output_projection is not None:
471 | proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype)
472 | proj_weights.get_shape().assert_is_compatible_with([None, num_symbols])
473 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
474 | proj_biases.get_shape().assert_is_compatible_with([num_symbols])
475 |
476 | embedding = variable_scope.get_variable(
477 | "embedding", [num_symbols, embedding_size], dtype=dtype)
478 |
479 | emb_encoder_inputs = [
480 | embedding_ops.embedding_lookup(embedding, x) for x in encoder_inputs
481 | ]
482 | emb_decoder_inputs = [
483 | embedding_ops.embedding_lookup(embedding, x) for x in decoder_inputs
484 | ]
485 |
486 | output_symbols = num_symbols
487 | if num_decoder_symbols is not None:
488 | output_symbols = num_decoder_symbols
489 | if output_projection is None:
490 | cell = core_rnn_cell.OutputProjectionWrapper(cell, output_symbols)
491 |
492 | if isinstance(feed_previous, bool):
493 | loop_function = _extract_argmax_and_embed(embedding, output_projection,
494 | True) if feed_previous else None
495 | return tied_rnn_seq2seq(
496 | emb_encoder_inputs,
497 | emb_decoder_inputs,
498 | cell,
499 | loop_function=loop_function,
500 | dtype=dtype)
501 |
502 | # If feed_previous is a Tensor, we construct 2 graphs and use cond.
503 | def decoder(feed_previous_bool):
504 | loop_function = _extract_argmax_and_embed(
505 | embedding, output_projection, False) if feed_previous_bool else None
506 | reuse = None if feed_previous_bool else True
507 | with variable_scope.variable_scope(
508 | variable_scope.get_variable_scope(), reuse=reuse):
509 | outputs, state = tied_rnn_seq2seq(
510 | emb_encoder_inputs,
511 | emb_decoder_inputs,
512 | cell,
513 | loop_function=loop_function,
514 | dtype=dtype)
515 | state_list = [state]
516 | if nest.is_sequence(state):
517 | state_list = nest.flatten(state)
518 | return outputs + state_list
519 |
520 | outputs_and_state = control_flow_ops.cond(feed_previous,
521 | lambda: decoder(True),
522 | lambda: decoder(False))
523 | outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs.
524 | state_list = outputs_and_state[outputs_len:]
525 | state = state_list[0]
526 | # Calculate zero-state to know it's structure.
527 | static_batch_size = encoder_inputs[0].get_shape()[0]
528 | for inp in encoder_inputs[1:]:
529 | static_batch_size.merge_with(inp.get_shape()[0])
530 | batch_size = static_batch_size.value
531 | if batch_size is None:
532 | batch_size = array_ops.shape(encoder_inputs[0])[0]
533 | zero_state = cell.zero_state(batch_size, dtype)
534 | if nest.is_sequence(zero_state):
535 | state = nest.pack_sequence_as(
536 | structure=zero_state, flat_sequence=state_list)
537 | return outputs_and_state[:outputs_len], state
538 |
539 |
540 | def attention_decoder(decoder_inputs,
541 | initial_state,
542 | attention_states,
543 | cell,
544 | output_size=None,
545 | num_heads=1,
546 | loop_function=None,
547 | dtype=None,
548 | scope=None,
549 | initial_state_attention=False):
550 | """RNN decoder with attention for the sequence-to-sequence model.
551 |
552 | In this context "attention" means that, during decoding, the RNN can look up
553 | information in the additional tensor attention_states, and it does this by
554 | focusing on a few entries from the tensor. This model has proven to yield
555 | especially good results in a number of sequence-to-sequence tasks. This
556 | implementation is based on http://arxiv.org/abs/1412.7449 (see below for
557 | details). It is recommended for complex sequence-to-sequence tasks.
558 |
559 | Args:
560 | decoder_inputs: A list of 2D Tensors [batch_size x input_size].
561 | initial_state: 2D Tensor [batch_size x cell.state_size].
562 | attention_states: 3D Tensor [batch_size x attn_length x attn_size].
563 | cell: core_rnn_cell.RNNCell defining the cell function and size.
564 | output_size: Size of the output vectors; if None, we use cell.output_size.
565 | num_heads: Number of attention heads that read from attention_states.
566 | loop_function: If not None, this function will be applied to i-th output
567 | in order to generate i+1-th input, and decoder_inputs will be ignored,
568 | except for the first element ("GO" symbol). This can be used for decoding,
569 | but also for training to emulate http://arxiv.org/abs/1506.03099.
570 | Signature -- loop_function(prev, i) = next
571 | * prev is a 2D Tensor of shape [batch_size x output_size],
572 | * i is an integer, the step number (when advanced control is needed),
573 | * next is a 2D Tensor of shape [batch_size x input_size].
574 | dtype: The dtype to use for the RNN initial state (default: tf.float32).
575 | scope: VariableScope for the created subgraph; default: "attention_decoder".
576 | initial_state_attention: If False (default), initial attentions are zero.
577 | If True, initialize the attentions from the initial state and attention
578 | states -- useful when we wish to resume decoding from a previously
579 | stored decoder state and attention states.
580 |
581 | Returns:
582 | A tuple of the form (outputs, state), where:
583 | outputs: A list of the same length as decoder_inputs of 2D Tensors of
584 | shape [batch_size x output_size]. These represent the generated outputs.
585 | Output i is computed from input i (which is either the i-th element
586 | of decoder_inputs or loop_function(output {i-1}, i)) as follows.
587 | First, we run the cell on a combination of the input and previous
588 | attention masks:
589 | cell_output, new_state = cell(linear(input, prev_attn), prev_state).
590 | Then, we calculate new attention masks:
591 | new_attn = softmax(V^T * tanh(W * attention_states + U * new_state))
592 | and then we calculate the output:
593 | output = linear(cell_output, new_attn).
594 |     state: The state of each decoder cell at the final time-step.
595 | It is a 2D Tensor of shape [batch_size x cell.state_size].
596 |
597 | Raises:
598 | ValueError: when num_heads is not positive, there are no inputs, shapes
599 | of attention_states are not set, or input size cannot be inferred
600 | from the input.
601 | """
602 | if not decoder_inputs:
603 | raise ValueError("Must provide at least 1 input to attention decoder.")
604 | if num_heads < 1:
605 |     raise ValueError("With fewer than 1 attention head, use a non-attention decoder.")
606 | if attention_states.get_shape()[2].value is None:
607 | raise ValueError("Shape[2] of attention_states must be known: %s" %
608 | attention_states.get_shape())
609 | if output_size is None:
610 | output_size = cell.output_size
611 |
612 | with variable_scope.variable_scope(
613 | scope or "attention_decoder", dtype=dtype) as scope:
614 | dtype = scope.dtype
615 |
616 | batch_size = array_ops.shape(decoder_inputs[0])[0] # Needed for reshaping.
617 | attn_length = attention_states.get_shape()[1].value
618 | if attn_length is None:
619 | attn_length = array_ops.shape(attention_states)[1]
620 | attn_size = attention_states.get_shape()[2].value
621 |
622 | # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
623 | hidden = array_ops.reshape(attention_states,
624 | [-1, attn_length, 1, attn_size])
625 | hidden_features = []
626 | v = []
627 | attention_vec_size = attn_size # Size of query vectors for attention.
628 | for a in xrange(num_heads):
629 | k = variable_scope.get_variable("AttnW_%d" % a,
630 | [1, 1, attn_size, attention_vec_size])
631 | hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
632 | v.append(
633 | variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size]))
634 |
635 | state = initial_state
636 |
637 | def attention(query):
638 | """Put attention masks on hidden using hidden_features and query."""
639 | ds = [] # Results of attention reads will be stored here.
640 | if nest.is_sequence(query): # If the query is a tuple, flatten it.
641 | query_list = nest.flatten(query)
642 | for q in query_list: # Check that ndims == 2 if specified.
643 | ndims = q.get_shape().ndims
644 | if ndims:
645 | assert ndims == 2
646 | query = array_ops.concat(query_list, 1)
647 | for a in xrange(num_heads):
648 | with variable_scope.variable_scope("Attention_%d" % a):
649 | y = linear(query, attention_vec_size, True)
650 | y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
651 | # Attention mask is a softmax of v^T * tanh(...).
652 | s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y),
653 | [2, 3])
654 | a = nn_ops.softmax(s)
655 | # Now calculate the attention-weighted vector d.
656 | d = math_ops.reduce_sum(
657 | array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
658 | ds.append(array_ops.reshape(d, [-1, attn_size]))
659 | return ds
660 |
661 | outputs = []
662 | prev = None
663 | batch_attn_size = array_ops.stack([batch_size, attn_size])
664 | attns = [
665 | array_ops.zeros(
666 | batch_attn_size, dtype=dtype) for _ in xrange(num_heads)
667 | ]
668 | for a in attns: # Ensure the second shape of attention vectors is set.
669 | a.set_shape([None, attn_size])
670 | if initial_state_attention:
671 | attns = attention(initial_state)
672 | for i, inp in enumerate(decoder_inputs):
673 | if i > 0:
674 | variable_scope.get_variable_scope().reuse_variables()
675 | # If loop_function is set, we use it instead of decoder_inputs.
676 | if loop_function is not None and prev is not None:
677 | with variable_scope.variable_scope("loop_function", reuse=True):
678 | inp = loop_function(prev, i)
679 | # Merge input and previous attentions into one vector of the right size.
680 | input_size = inp.get_shape().with_rank(2)[1]
681 | if input_size.value is None:
682 | raise ValueError("Could not infer input size from input: %s" % inp.name)
683 | x = linear([inp] + attns, input_size, True)
684 | # Run the RNN.
685 | cell_output, state = cell(x, state)
686 | # Run the attention mechanism.
687 | if i == 0 and initial_state_attention:
688 | with variable_scope.variable_scope(
689 | variable_scope.get_variable_scope(), reuse=True):
690 | attns = attention(state)
691 | else:
692 | attns = attention(state)
693 |
694 | with variable_scope.variable_scope("AttnOutputProjection"):
695 | output = linear([cell_output] + attns, output_size, True)
696 | if loop_function is not None:
697 | prev = output
698 | outputs.append(output)
699 |
700 | return outputs, state
701 |
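# Shape walk-through of one attention head (illustrative):
#
#   hidden          : [batch, attn_length, 1, attn_size]   (reshaped attention_states)
#   W1 * hidden     : 1x1 convolution with k               -> same shape
#   W2 * query      : linear(query, attention_vec_size)    -> [batch, 1, 1, attn_vec_size]
#   s               : reduce_sum(v * tanh(W1*h + W2*q), [2, 3]) -> [batch, attn_length]
#   a = softmax(s)  : attention mask over encoder positions
#   d               : reduce_sum(a * hidden, [1, 2])       -> [batch, attn_size]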
702 |
703 | def embedding_attention_decoder(decoder_inputs,
704 | initial_state,
705 | attention_states,
706 | cell,
707 | num_symbols,
708 | embedding_size,
709 | num_heads=1,
710 | output_size=None,
711 | output_projection=None,
712 | feed_previous=False,
713 | update_embedding_for_previous=True,
714 | dtype=None,
715 | scope=None,
716 | initial_state_attention=False):
717 | """RNN decoder with embedding and attention and a pure-decoding option.
718 |
719 | Args:
720 | decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
721 | initial_state: 2D Tensor [batch_size x cell.state_size].
722 | attention_states: 3D Tensor [batch_size x attn_length x attn_size].
723 | cell: core_rnn_cell.RNNCell defining the cell function.
724 | num_symbols: Integer, how many symbols come into the embedding.
725 | embedding_size: Integer, the length of the embedding vector for each symbol.
726 | num_heads: Number of attention heads that read from attention_states.
727 |     output_size: Size of the output vectors; if None, use cell.output_size.
728 | output_projection: None or a pair (W, B) of output projection weights and
729 | biases; W has shape [output_size x num_symbols] and B has shape
730 | [num_symbols]; if provided and feed_previous=True, each fed previous
731 | output will first be multiplied by W and added B.
732 | feed_previous: Boolean; if True, only the first of decoder_inputs will be
733 | used (the "GO" symbol), and all other decoder inputs will be generated by:
734 | next = embedding_lookup(embedding, argmax(previous_output)),
735 | In effect, this implements a greedy decoder. It can also be used
736 | during training to emulate http://arxiv.org/abs/1506.03099.
737 | If False, decoder_inputs are used as given (the standard decoder case).
738 | update_embedding_for_previous: Boolean; if False and feed_previous=True,
739 | only the embedding for the first symbol of decoder_inputs (the "GO"
740 | symbol) will be updated by back propagation. Embeddings for the symbols
741 | generated from the decoder itself remain unchanged. This parameter has
742 | no effect if feed_previous=False.
743 | dtype: The dtype to use for the RNN initial states (default: tf.float32).
744 | scope: VariableScope for the created subgraph; defaults to
745 | "embedding_attention_decoder".
746 | initial_state_attention: If False (default), initial attentions are zero.
747 | If True, initialize the attentions from the initial state and attention
748 | states -- useful when we wish to resume decoding from a previously
749 | stored decoder state and attention states.
750 |
751 | Returns:
752 | A tuple of the form (outputs, state), where:
753 | outputs: A list of the same length as decoder_inputs of 2D Tensors with
754 | shape [batch_size x output_size] containing the generated outputs.
755 | state: The state of each decoder cell at the final time-step.
756 | It is a 2D Tensor of shape [batch_size x cell.state_size].
757 |
758 | Raises:
759 | ValueError: When output_projection has the wrong shape.
760 | """
761 | if output_size is None:
762 | output_size = cell.output_size
763 | if output_projection is not None:
764 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
765 | proj_biases.get_shape().assert_is_compatible_with([num_symbols])
766 |
767 | with variable_scope.variable_scope(
768 | scope or "embedding_attention_decoder", dtype=dtype) as scope:
769 |
770 | embedding = variable_scope.get_variable("embedding",
771 | [num_symbols, embedding_size])
772 | loop_function = _extract_argmax_and_embed(
773 | embedding, output_projection,
774 | update_embedding_for_previous) if feed_previous else None
775 | emb_inp = [
776 | embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs
777 | ]
778 | return attention_decoder(
779 | emb_inp,
780 | initial_state,
781 | attention_states,
782 | cell,
783 | output_size=output_size,
784 | num_heads=num_heads,
785 | loop_function=loop_function,
786 | initial_state_attention=initial_state_attention)
787 |
788 |
789 | def embedding_attention_seq2seq(encoder_inputs,
790 | decoder_inputs,
791 | cell,
792 | num_encoder_symbols,
793 | num_decoder_symbols,
794 | embedding_size,
795 | num_heads=1,
796 | output_projection=None,
797 | feed_previous=False,
798 | dtype=None,
799 | scope=None,
800 | initial_state_attention=False):
801 | """Embedding sequence-to-sequence model with attention.
802 |
803 | This model first embeds encoder_inputs by a newly created embedding (of shape
804 | [num_encoder_symbols x input_size]). Then it runs an RNN to encode
805 | embedded encoder_inputs into a state vector. It keeps the outputs of this
806 | RNN at every step to use for attention later. Next, it embeds decoder_inputs
807 | by another newly created embedding (of shape [num_decoder_symbols x
808 | input_size]). Then it runs attention decoder, initialized with the last
809 | encoder state, on embedded decoder_inputs and attending to encoder outputs.
810 |
811 | Warning: when output_projection is None, the size of the attention vectors
812 |   and variables will be made proportional to num_decoder_symbols, which can be large.
813 |
814 | Args:
815 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
816 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
817 | cell: core_rnn_cell.RNNCell defining the cell function and size.
818 | num_encoder_symbols: Integer; number of symbols on the encoder side.
819 | num_decoder_symbols: Integer; number of symbols on the decoder side.
820 | embedding_size: Integer, the length of the embedding vector for each symbol.
821 | num_heads: Number of attention heads that read from attention_states.
822 | output_projection: None or a pair (W, B) of output projection weights and
823 | biases; W has shape [output_size x num_decoder_symbols] and B has
824 | shape [num_decoder_symbols]; if provided and feed_previous=True, each
825 |       fed previous output will first be multiplied by W and have B added.
826 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
827 | of decoder_inputs will be used (the "GO" symbol), and all other decoder
828 | inputs will be taken from previous outputs (as in embedding_rnn_decoder).
829 | If False, decoder_inputs are used as given (the standard decoder case).
830 | dtype: The dtype of the initial RNN state (default: tf.float32).
831 | scope: VariableScope for the created subgraph; defaults to
832 | "embedding_attention_seq2seq".
833 | initial_state_attention: If False (default), initial attentions are zero.
834 | If True, initialize the attentions from the initial state and attention
835 | states.
836 |
837 | Returns:
838 | A tuple of the form (outputs, state), where:
839 | outputs: A list of the same length as decoder_inputs of 2D Tensors with
840 | shape [batch_size x num_decoder_symbols] containing the generated
841 | outputs.
842 | state: The state of each decoder cell at the final time-step.
843 | It is a 2D Tensor of shape [batch_size x cell.state_size].
844 | """
845 | with variable_scope.variable_scope(
846 | scope or "embedding_attention_seq2seq", dtype=dtype) as scope:
847 | dtype = scope.dtype
848 | # Encoder.
849 | encoder_cell = copy.deepcopy(cell)
850 | encoder_cell = core_rnn_cell.EmbeddingWrapper(
851 | encoder_cell,
852 | embedding_classes=num_encoder_symbols,
853 | embedding_size=embedding_size)
854 | encoder_outputs, encoder_state = core_rnn.static_rnn(
855 | encoder_cell, encoder_inputs, dtype=dtype)
856 |
857 | # First calculate a concatenation of encoder outputs to put attention on.
858 | top_states = [
859 | array_ops.reshape(e, [-1, 1, cell.output_size]) for e in encoder_outputs
860 | ]
861 | attention_states = array_ops.concat(top_states, 1)
862 |
863 | # Decoder.
864 | output_size = None
865 | if output_projection is None:
866 | cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
867 | output_size = num_decoder_symbols
868 |
869 | if isinstance(feed_previous, bool):
870 | return embedding_attention_decoder(
871 | decoder_inputs,
872 | encoder_state,
873 | attention_states,
874 | cell,
875 | num_decoder_symbols,
876 | embedding_size,
877 | num_heads=num_heads,
878 | output_size=output_size,
879 | output_projection=output_projection,
880 | feed_previous=feed_previous,
881 | initial_state_attention=initial_state_attention)
882 |
883 | # If feed_previous is a Tensor, we construct 2 graphs and use cond.
884 | def decoder(feed_previous_bool):
885 | reuse = None if feed_previous_bool else True
886 | with variable_scope.variable_scope(
887 | variable_scope.get_variable_scope(), reuse=reuse) as scope:
888 | outputs, state = embedding_attention_decoder(
889 | decoder_inputs,
890 | encoder_state,
891 | attention_states,
892 | cell,
893 | num_decoder_symbols,
894 | embedding_size,
895 | num_heads=num_heads,
896 | output_size=output_size,
897 | output_projection=output_projection,
898 | feed_previous=feed_previous_bool,
899 | update_embedding_for_previous=False,
900 | initial_state_attention=initial_state_attention)
901 | state_list = [state]
902 | if nest.is_sequence(state):
903 | state_list = nest.flatten(state)
904 | return outputs + state_list
905 |
906 | outputs_and_state = control_flow_ops.cond(feed_previous,
907 | lambda: decoder(True),
908 | lambda: decoder(False))
909 | outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs.
910 | state_list = outputs_and_state[outputs_len:]
911 | state = state_list[0]
912 | if nest.is_sequence(encoder_state):
913 | state = nest.pack_sequence_as(
914 | structure=encoder_state, flat_sequence=state_list)
915 | return outputs_and_state[:outputs_len], state
916 |
917 |
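# Usage sketch (added note, not from the original module; assumes the TF 1.x
# graph-mode API used throughout this file and a tf.contrib.rnn cell that is
# compatible with the core_rnn_cell wrappers above). Sizes are hypothetical:
#
#   cell = tf.contrib.rnn.GRUCell(128)
#   enc_inp = [tf.placeholder(tf.int32, shape=[None]) for _ in range(10)]
#   dec_inp = [tf.placeholder(tf.int32, shape=[None]) for _ in range(12)]
#   outputs, state = embedding_attention_seq2seq(
#       enc_inp, dec_inp, cell,
#       num_encoder_symbols=100000, num_decoder_symbols=100000,
#       embedding_size=128, feed_previous=True)
#   # outputs: 12 tensors of shape [batch_size, num_decoder_symbols]
#   # state:   final decoder state, [batch_size, cell.state_size]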
918 | def one2many_rnn_seq2seq(encoder_inputs,
919 | decoder_inputs_dict,
920 | enc_cell,
921 | dec_cells_dict,
922 | num_encoder_symbols,
923 | num_decoder_symbols_dict,
924 | embedding_size,
925 | feed_previous=False,
926 | dtype=None,
927 | scope=None):
928 | """One-to-many RNN sequence-to-sequence model (multi-task).
929 |
930 | This is a multi-task sequence-to-sequence model with one encoder and multiple
931 | decoders. Reference to multi-task sequence-to-sequence learning can be found
932 | here: http://arxiv.org/abs/1511.06114
933 |
934 | Args:
935 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
936 | decoder_inputs_dict: A dictionary mapping decoder name (string) to
937 | the corresponding decoder_inputs; each decoder_inputs is a list of 1D
938 | Tensors of shape [batch_size]; num_decoders is defined as
939 | len(decoder_inputs_dict).
940 | enc_cell: core_rnn_cell.RNNCell defining the encoder cell function and size.
941 |     dec_cells_dict: A dictionary mapping decoder name (string) to an
942 | instance of core_rnn_cell.RNNCell.
943 | num_encoder_symbols: Integer; number of symbols on the encoder side.
944 | num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an
945 | integer specifying number of symbols for the corresponding decoder;
946 | len(num_decoder_symbols_dict) must be equal to num_decoders.
947 | embedding_size: Integer, the length of the embedding vector for each symbol.
948 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first of
949 | decoder_inputs will be used (the "GO" symbol), and all other decoder
950 | inputs will be taken from previous outputs (as in embedding_rnn_decoder).
951 | If False, decoder_inputs are used as given (the standard decoder case).
952 |     dtype: The dtype of the initial state for both the encoder and decoder
953 | rnn cells (default: tf.float32).
954 | scope: VariableScope for the created subgraph; defaults to
955 | "one2many_rnn_seq2seq"
956 |
957 | Returns:
958 | A tuple of the form (outputs_dict, state_dict), where:
959 | outputs_dict: A mapping from decoder name (string) to a list of the same
960 | length as decoder_inputs_dict[name]; each element in the list is a 2D
961 |       Tensor with shape [batch_size x num_decoder_symbols_dict[name]]
962 | containing the generated outputs.
963 | state_dict: A mapping from decoder name (string) to the final state of the
964 | corresponding decoder RNN; it is a 2D Tensor of shape
965 | [batch_size x cell.state_size].
966 |
967 | Raises:
968 |     TypeError: if enc_cell or any cell in dec_cells_dict is not an RNNCell.
969 |     ValueError: if the keys of dec_cells_dict do not match those of decoder_inputs_dict.
970 | """
971 | outputs_dict = {}
972 | state_dict = {}
973 |
974 | if not isinstance(enc_cell, core_rnn_cell.RNNCell):
975 | raise TypeError("enc_cell is not an RNNCell: %s" % type(enc_cell))
976 | if set(dec_cells_dict) != set(decoder_inputs_dict):
977 |     raise ValueError("keys of dec_cells_dict != keys of decoder_inputs_dict")
978 | for dec_cell in dec_cells_dict.values():
979 | if not isinstance(dec_cell, core_rnn_cell.RNNCell):
980 | raise TypeError("dec_cell is not an RNNCell: %s" % type(dec_cell))
981 |
982 | with variable_scope.variable_scope(
983 | scope or "one2many_rnn_seq2seq", dtype=dtype) as scope:
984 | dtype = scope.dtype
985 |
986 | # Encoder.
987 | enc_cell = core_rnn_cell.EmbeddingWrapper(
988 | enc_cell,
989 | embedding_classes=num_encoder_symbols,
990 | embedding_size=embedding_size)
991 | _, encoder_state = core_rnn.static_rnn(
992 | enc_cell, encoder_inputs, dtype=dtype)
993 |
994 | # Decoder.
995 | for name, decoder_inputs in decoder_inputs_dict.items():
996 | num_decoder_symbols = num_decoder_symbols_dict[name]
997 | dec_cell = dec_cells_dict[name]
998 |
999 | with variable_scope.variable_scope("one2many_decoder_" + str(
1000 | name)) as scope:
1001 | dec_cell = core_rnn_cell.OutputProjectionWrapper(
1002 | dec_cell, num_decoder_symbols)
1003 | if isinstance(feed_previous, bool):
1004 | outputs, state = embedding_rnn_decoder(
1005 | decoder_inputs,
1006 | encoder_state,
1007 | dec_cell,
1008 | num_decoder_symbols,
1009 | embedding_size,
1010 | feed_previous=feed_previous)
1011 | else:
1012 | # If feed_previous is a Tensor, we construct 2 graphs and use cond.
1013 | def filled_embedding_rnn_decoder(feed_previous):
1014 | """The current decoder with a fixed feed_previous parameter."""
1015 | # pylint: disable=cell-var-from-loop
1016 | reuse = None if feed_previous else True
1017 | vs = variable_scope.get_variable_scope()
1018 | with variable_scope.variable_scope(vs, reuse=reuse):
1019 | outputs, state = embedding_rnn_decoder(
1020 | decoder_inputs,
1021 | encoder_state,
1022 | dec_cell,
1023 | num_decoder_symbols,
1024 | embedding_size,
1025 | feed_previous=feed_previous)
1026 | # pylint: enable=cell-var-from-loop
1027 | state_list = [state]
1028 | if nest.is_sequence(state):
1029 | state_list = nest.flatten(state)
1030 | return outputs + state_list
1031 |
1032 | outputs_and_state = control_flow_ops.cond(
1033 | feed_previous, lambda: filled_embedding_rnn_decoder(True),
1034 | lambda: filled_embedding_rnn_decoder(False))
1035 | # Outputs length is the same as for decoder inputs.
1036 | outputs_len = len(decoder_inputs)
1037 | outputs = outputs_and_state[:outputs_len]
1038 | state_list = outputs_and_state[outputs_len:]
1039 | state = state_list[0]
1040 | if nest.is_sequence(encoder_state):
1041 | state = nest.pack_sequence_as(
1042 | structure=encoder_state, flat_sequence=state_list)
1043 | outputs_dict[name] = outputs
1044 | state_dict[name] = state
1045 |
1046 | return outputs_dict, state_dict
1047 |
1048 |
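# Usage sketch (added note, not from the original module): one encoder shared
# by two named decoders, as in the multi-task setting described above. The
# decoder names, cells, and vocabulary sizes are hypothetical:
#
#   enc_cell = tf.contrib.rnn.GRUCell(64)
#   dec_cells = {"fr": tf.contrib.rnn.GRUCell(64), "de": tf.contrib.rnn.GRUCell(64)}
#   enc_inp = [tf.placeholder(tf.int32, [None]) for _ in range(8)]
#   dec_inp = {name: [tf.placeholder(tf.int32, [None]) for _ in range(10)]
#              for name in dec_cells}
#   outputs_dict, state_dict = one2many_rnn_seq2seq(
#       enc_inp, dec_inp, enc_cell, dec_cells,
#       num_encoder_symbols=5000,
#       num_decoder_symbols_dict={"fr": 6000, "de": 7000},
#       embedding_size=64, feed_previous=True)
#   # outputs_dict["fr"]: 10 tensors of shape [batch_size, 6000]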
1049 | def sequence_loss_by_example(logits,
1050 | targets,
1051 | weights,
1052 | average_across_timesteps=True,
1053 | softmax_loss_function=None,
1054 | name=None):
1055 | """Weighted cross-entropy loss for a sequence of logits (per example).
1056 |
1057 | Args:
1058 | logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
1059 | targets: List of 1D batch-sized int32 Tensors of the same length as logits.
1060 | weights: List of 1D batch-sized float-Tensors of the same length as logits.
1061 | average_across_timesteps: If set, divide the returned cost by the total
1062 | label weight.
1063 | softmax_loss_function: Function (labels-batch, inputs-batch) -> loss-batch
1064 | to be used instead of the standard softmax (the default if this is None).
1065 | name: Optional name for this operation, default: "sequence_loss_by_example".
1066 |
1067 | Returns:
1068 | 1D batch-sized float Tensor: The log-perplexity for each sequence.
1069 |
1070 | Raises:
1071 | ValueError: If len(logits) is different from len(targets) or len(weights).
1072 | """
1073 | if len(targets) != len(logits) or len(weights) != len(logits):
1074 | raise ValueError("Lengths of logits, weights, and targets must be the same "
1075 | "%d, %d, %d." % (len(logits), len(weights), len(targets)))
1076 | with ops.name_scope(name, "sequence_loss_by_example",
1077 | logits + targets + weights):
1078 | log_perp_list = []
1079 | for logit, target, weight in zip(logits, targets, weights):
1080 | if softmax_loss_function is None:
1081 | # TODO(irving,ebrevdo): This reshape is needed because
1082 | # sequence_loss_by_example is called with scalars sometimes, which
1083 | # violates our general scalar strictness policy.
1084 | target = array_ops.reshape(target, [-1])
1085 | crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
1086 | labels=target, logits=logit)
1087 | else:
1088 | crossent = softmax_loss_function(target, logit)
1089 | log_perp_list.append(crossent * weight)
1090 | log_perps = math_ops.add_n(log_perp_list)
1091 | if average_across_timesteps:
1092 | total_size = math_ops.add_n(weights)
1093 | total_size += 1e-12 # Just to avoid division by 0 for all-0 weights.
1094 | log_perps /= total_size
1095 | return log_perps
1096 |
1097 |
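# Shape sketch (added note; names and sizes are hypothetical): with
# average_across_timesteps=True the value returned per example is
# sum_t(weight_t * xent_t) / sum_t(weight_t), i.e. a per-sequence
# log-perplexity.
#
#   steps, batch, vocab = 5, 4, 100
#   logits  = [tf.random_normal([batch, vocab]) for _ in range(steps)]
#   targets = [tf.zeros([batch], dtype=tf.int32) for _ in range(steps)]
#   weights = [tf.ones([batch]) for _ in range(steps)]
#   per_example = sequence_loss_by_example(logits, targets, weights)
#   # per_example: 1D float tensor of shape [batch]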
1098 | def sequence_loss(logits,
1099 | targets,
1100 | weights,
1101 | average_across_timesteps=True,
1102 | average_across_batch=True,
1103 | softmax_loss_function=None,
1104 | name=None):
1105 | """Weighted cross-entropy loss for a sequence of logits, batch-collapsed.
1106 |
1107 | Args:
1108 | logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
1109 | targets: List of 1D batch-sized int32 Tensors of the same length as logits.
1110 | weights: List of 1D batch-sized float-Tensors of the same length as logits.
1111 | average_across_timesteps: If set, divide the returned cost by the total
1112 | label weight.
1113 | average_across_batch: If set, divide the returned cost by the batch size.
1114 | softmax_loss_function: Function (labels-batch, inputs-batch) -> loss-batch
1115 | to be used instead of the standard softmax (the default if this is None).
1116 | name: Optional name for this operation, defaults to "sequence_loss".
1117 |
1118 | Returns:
1119 | A scalar float Tensor: The average log-perplexity per symbol (weighted).
1120 |
1121 | Raises:
1122 | ValueError: If len(logits) is different from len(targets) or len(weights).
1123 | """
1124 | with ops.name_scope(name, "sequence_loss", logits + targets + weights):
1125 | cost = math_ops.reduce_sum(
1126 | sequence_loss_by_example(
1127 | logits,
1128 | targets,
1129 | weights,
1130 | average_across_timesteps=average_across_timesteps,
1131 | softmax_loss_function=softmax_loss_function))
1132 | if average_across_batch:
1133 | batch_size = array_ops.shape(targets[0])[0]
1134 | return cost / math_ops.cast(batch_size, cost.dtype)
1135 | else:
1136 | return cost
1137 |
1138 |
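# Relationship note (added; follows from the implementation above): with
# average_across_batch=True and average_across_timesteps=True,
#
#   sequence_loss(logits, targets, weights)
#     == tf.reduce_mean(sequence_loss_by_example(logits, targets, weights))
#
# since the summed per-example log-perplexities are divided by the batch size.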
1139 | def model_with_buckets(encoder_inputs,
1140 | decoder_inputs,
1141 | targets,
1142 | weights,
1143 | buckets,
1144 | seq2seq,
1145 | softmax_loss_function=None,
1146 | per_example_loss=False,
1147 | name=None):
1148 | """Create a sequence-to-sequence model with support for bucketing.
1149 |
1150 | The seq2seq argument is a function that defines a sequence-to-sequence model,
1151 | e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(
1152 | x, y, core_rnn_cell.GRUCell(24))
1153 |
1154 | Args:
1155 | encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.
1156 | decoder_inputs: A list of Tensors to feed the decoder; second seq2seq input.
1157 | targets: A list of 1D batch-sized int32 Tensors (desired output sequence).
1158 | weights: List of 1D batch-sized float-Tensors to weight the targets.
1159 | buckets: A list of pairs of (input size, output size) for each bucket.
1160 |     seq2seq: A sequence-to-sequence model function; it takes 2 inputs that
1161 | agree with encoder_inputs and decoder_inputs, and returns a pair
1162 | consisting of outputs and states (as, e.g., basic_rnn_seq2seq).
1163 | softmax_loss_function: Function (labels-batch, inputs-batch) -> loss-batch
1164 | to be used instead of the standard softmax (the default if this is None).
1165 | per_example_loss: Boolean. If set, the returned loss will be a batch-sized
1166 | tensor of losses for each sequence in the batch. If unset, it will be
1167 | a scalar with the averaged loss from all examples.
1168 | name: Optional name for this operation, defaults to "model_with_buckets".
1169 |
1170 | Returns:
1171 | A tuple of the form (outputs, losses), where:
1172 | outputs: The outputs for each bucket. Its j'th element consists of a list
1173 | of 2D Tensors. The shape of output tensors can be either
1174 | [batch_size x output_size] or [batch_size x num_decoder_symbols]
1175 | depending on the seq2seq model used.
1176 | losses: List of scalar Tensors, representing losses for each bucket, or,
1177 | if per_example_loss is set, a list of 1D batch-sized float Tensors.
1178 |
1179 | Raises:
1180 | ValueError: If length of encoder_inputs, targets, or weights is smaller
1181 | than the largest (last) bucket.
1182 | """
1183 | if len(encoder_inputs) < buckets[-1][0]:
1184 | raise ValueError("Length of encoder_inputs (%d) must be at least that of la"
1185 | "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
1186 | if len(targets) < buckets[-1][1]:
1187 |     raise ValueError("Length of targets (%d) must be at least that of last "
1188 | "bucket (%d)." % (len(targets), buckets[-1][1]))
1189 | if len(weights) < buckets[-1][1]:
1190 |     raise ValueError("Length of weights (%d) must be at least that of last "
1191 | "bucket (%d)." % (len(weights), buckets[-1][1]))
1192 |
1193 | all_inputs = encoder_inputs + decoder_inputs + targets + weights
1194 | losses = []
1195 | outputs = []
1196 | with ops.name_scope(name, "model_with_buckets", all_inputs):
1197 | for j, bucket in enumerate(buckets):
1198 | with variable_scope.variable_scope(
1199 | variable_scope.get_variable_scope(), reuse=True if j > 0 else None):
1200 | bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]],
1201 | decoder_inputs[:bucket[1]])
1202 | outputs.append(bucket_outputs)
1203 | if per_example_loss:
1204 | losses.append(
1205 | sequence_loss_by_example(
1206 | outputs[-1],
1207 | targets[:bucket[1]],
1208 | weights[:bucket[1]],
1209 | softmax_loss_function=softmax_loss_function))
1210 | else:
1211 | losses.append(
1212 | sequence_loss(
1213 | outputs[-1],
1214 | targets[:bucket[1]],
1215 | weights[:bucket[1]],
1216 | softmax_loss_function=softmax_loss_function))
1217 |
1218 | return outputs, losses
1219 |
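# Usage sketch (added note, not from this repo's model class; assumes the
# TF 1.x graph-mode API used throughout this file). Mirrors the GRUCell(24)
# example from the docstring above, with two hypothetical buckets:
#
#   buckets = [(5, 10), (10, 15)]
#   cell = tf.contrib.rnn.GRUCell(24)
#   seq2seq_f = lambda x, y: embedding_attention_seq2seq(
#       x, y, cell, num_encoder_symbols=1000, num_decoder_symbols=1000,
#       embedding_size=24, feed_previous=False)
#   enc_inp = [tf.placeholder(tf.int32, [None]) for _ in range(buckets[-1][0])]
#   dec_inp = [tf.placeholder(tf.int32, [None]) for _ in range(buckets[-1][1])]
#   targets = dec_inp[1:] + [tf.zeros_like(dec_inp[0])]   # shift-by-one targets
#   weights = [tf.ones_like(t, dtype=tf.float32) for t in targets]
#   outputs, losses = model_with_buckets(enc_inp, dec_inp, targets, weights,
#                                        buckets, seq2seq_f)
#   # outputs[j], losses[j] correspond to buckets[j]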
--------------------------------------------------------------------------------