├── .gitignore
├── ByteNet
├── __init__.py
├── generator.py
├── ops.py
└── translator.py
├── Data
└── generator_training_data
│ └── shakespeare.txt
├── LICENSE
├── README.md
├── data_loader.py
├── generate.py
├── model_config.py
├── train_generator.py
├── train_translator.py
├── translate.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea*
3 | *.pdf
4 | *.jpg
5 | *.png
6 | *.pyc
7 | *.py.bak
8 | *.pem
9 | *.ckpt
10 | awstransfer.sh
11 | Data/Models/*
12 | logs/*
13 | Data/MachineTranslation
14 | source.txt
15 | data_loader_old.py
16 | connect.sh
17 | sample.sh
18 | tensorboard.sh
19 | Data/tb_summaries/*
20 | Data/*.txt
21 | dump/*
22 | Data/translator_model/*
--------------------------------------------------------------------------------
/ByteNet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/paarthneekhara/byteNet-tensorflow/9bba9352e5f8a89a32ab14e9546a158750cdbfaf/ByteNet/__init__.py
--------------------------------------------------------------------------------
/ByteNet/generator.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import ops
3 |
4 | class ByteNet_Generator:
5 | def __init__(self, options):
6 | self.options = options
7 | source_embedding_channels = 2 * options['residual_channels']
8 | self.w_sentence_embedding = tf.get_variable('w_sentence_embedding',
9 | [options['vocab_size'], source_embedding_channels],
10 | initializer=tf.truncated_normal_initializer(stddev=0.02))
11 |
12 | def build_model(self):
13 | options = self.options
14 | self.t_sentence = tf.placeholder('int32',
15 | [None, None], name = 't_sentence')
16 |
17 | source_sentence = self.t_sentence[:,0:-1]
18 | target_sentence = self.t_sentence[:,1:]
19 |
20 | source_embedding = tf.nn.embedding_lookup(self.w_sentence_embedding,
21 | source_sentence, name = "source_embedding")
22 |
23 | curr_input = source_embedding
24 | for layer_no, dilation in enumerate(options['dilations']):
25 | curr_input = ops.byetenet_residual_block(curr_input, dilation,
26 | layer_no, options['residual_channels'],
27 | options['filter_width'], causal = True, train = True)
28 |
29 | logits = ops.conv1d(tf.nn.relu(curr_input),
30 | options['vocab_size'], name = 'logits')
31 |
32 | logits_flat = tf.reshape(logits, [-1, options['vocab_size']])
33 | target_flat = tf.reshape(target_sentence, [-1])
34 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels = target_flat, logits = logits_flat)
35 | self.loss = tf.reduce_mean(loss)
36 |
37 | self.arg_max_prediction = tf.argmax(logits_flat, 1)
38 |
39 | tf.summary.scalar('loss', self.loss)
40 |
41 | def build_generator(self, reuse = False):
42 | if reuse:
43 | tf.get_variable_scope().reuse_variables()
44 |
45 | options = self.options
46 | self.seed_sentence = tf.placeholder('int32',
47 | [None, None], name = 'seed_sentence')
48 |
49 | source_embedding = tf.nn.embedding_lookup(self.w_sentence_embedding,
50 | self.seed_sentence, name = "source_embedding")
51 |
52 | curr_input = source_embedding
53 | for layer_no, dilation in enumerate(options['dilations']):
54 | curr_input = ops.byetenet_residual_block(curr_input, dilation,
55 | layer_no, options['residual_channels'],
56 | options['filter_width'], causal = True, train = False)
57 |
58 | logits = ops.conv1d(tf.nn.relu(curr_input),
59 | options['vocab_size'], name = 'logits')
60 | logits_flat = tf.reshape(logits, [-1, options['vocab_size']])
61 | probs_flat = tf.nn.softmax(logits_flat)
62 |
63 | self.g_probs = tf.reshape(probs_flat, [-1, tf.shape(self.seed_sentence)[1], options['vocab_size']])
64 |
65 |
66 | def main():
67 | options = {
68 | 'vocab_size' : 250,
69 | 'residual_channels' : 512,
70 | 'dilations' : [ 1,2,4,8,16,
71 | 1,2,4,8,16
72 | ],
73 | 'filter_width' : 3
74 | }
75 |
76 | model = ByteNet_Generator(options)
77 | model.build_model()
78 | model.build_generator(reuse = True)
79 |
80 | if __name__ == '__main__':
81 | main()
--------------------------------------------------------------------------------
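A minimal NumPy sketch (editorial illustration, not part of the repository) of the shift-by-one objective used in ByteNet_Generator.build_model above, where the network reads t_sentence[:, :-1] and is trained to predict t_sentence[:, 1:]:

```python
import numpy as np

# Toy batch of two "sentences" encoded as integer character ids.
batch = np.array([[7, 3, 9, 1, 4],
                  [2, 8, 5, 5, 0]], dtype=np.int32)

source = batch[:, :-1]   # what the causal decoder sees
target = batch[:, 1:]    # what the loss compares against, shifted by one

print(source)   # [[7 3 9 1]
                #  [2 8 5 5]]
print(target)   # [[3 9 1 4]
                #  [8 5 5 0]]
```

Because every convolution in the stack is causal, the prediction at position i can only depend on characters 0..i of the input, which is what makes this a valid next-character model.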
/ByteNet/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import math
3 |
4 | def fully_connected(input_, output_nodes, name, stddev=0.02):
5 | with tf.variable_scope(name):
6 | input_shape = input_.get_shape()
7 | input_nodes = input_shape[-1]
8 | w = tf.get_variable('w', [input_nodes, output_nodes],
9 | initializer=tf.truncated_normal_initializer(stddev=0.02))
10 | biases = tf.get_variable('b', [output_nodes],
11 | initializer=tf.constant_initializer(0.0))
12 | res = tf.matmul(input_, w) + biases
13 | return res
14 |
15 |
16 | # 1d CONVOLUTION WITH DILATION
17 | def conv1d(input_, output_channels,
18 | dilation = 1, filter_width = 1, causal = False,
19 | name = "dilated_conv"):
20 | with tf.variable_scope(name):
21 | w = tf.get_variable('w', [1, filter_width, input_.get_shape()[-1], output_channels ],
22 | initializer=tf.truncated_normal_initializer(stddev=0.02))
23 | b = tf.get_variable('b', [output_channels ],
24 | initializer=tf.constant_initializer(0.0))
25 |
26 | if causal:
27 | padding = [[0, 0], [(filter_width - 1) * dilation, 0], [0, 0]]
28 | padded = tf.pad(input_, padding)
29 | input_expanded = tf.expand_dims(padded, dim = 1)
30 | out = tf.nn.atrous_conv2d(input_expanded, w, rate = dilation, padding = 'VALID') + b
31 | else:
32 | input_expanded = tf.expand_dims(input_, dim = 1)
33 | out = tf.nn.atrous_conv2d(input_expanded, w, rate = dilation, padding = 'SAME') + b
34 |
35 | return tf.squeeze(out, [1])
36 |
37 | def layer_normalization(x, name, epsilon=1e-8, trainable = True):
38 | with tf.variable_scope(name):
39 | shape = x.get_shape()
40 | beta = tf.get_variable('beta', [ int(shape[-1])],
41 | initializer=tf.constant_initializer(0), trainable=trainable)
42 | gamma = tf.get_variable('gamma', [ int(shape[-1])],
43 | initializer=tf.constant_initializer(1), trainable=trainable)
44 |
45 | mean, variance = tf.nn.moments(x, axes=[len(shape) - 1], keep_dims=True)
46 |
47 | x = (x - mean) / tf.sqrt(variance + epsilon)
48 |
49 | return gamma * x + beta
50 |
51 | def byetenet_residual_block(input_, dilation, layer_no,
52 | residual_channels, filter_width,
53 | causal = True, train = True):
54 | block_type = "decoder" if causal else "encoder"
55 | block_name = "bytenet_{}_layer_{}_{}".format(block_type, layer_no, dilation)
56 | with tf.variable_scope(block_name):
57 | input_ln = layer_normalization(input_, name="ln1", trainable = train)
58 | relu1 = tf.nn.relu(input_ln)
59 | conv1 = conv1d(relu1, residual_channels, name = "conv1d_1")
60 | conv1 = layer_normalization(conv1, name="ln2", trainable = train)
61 | relu2 = tf.nn.relu(conv1)
62 |
63 | dilated_conv = conv1d(relu2, residual_channels,
64 | dilation, filter_width,
65 | causal = causal,
66 | name = "dilated_conv"
67 | )
68 | print dilated_conv
69 | dilated_conv = layer_normalization(dilated_conv, name="ln3", trainable = train)
70 | relu3 = tf.nn.relu(dilated_conv)
71 | conv2 = conv1d(relu3, 2 * residual_channels, name = 'conv1d_2')
72 | return input_ + conv2
73 |
74 | def init_weight(dim_in, dim_out, name=None, stddev=1.0):
75 | return tf.Variable(tf.truncated_normal([dim_in, dim_out], stddev=stddev/math.sqrt(float(dim_in))), name=name)
76 |
77 | def init_bias(dim_out, name=None):
78 | return tf.Variable(tf.zeros([dim_out]), name=name)
--------------------------------------------------------------------------------
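The causal branch of conv1d above left-pads the input by (filter_width - 1) * dilation positions before the dilated convolution, so the output at position t never sees inputs beyond t. A single-channel NumPy sketch of the same idea (editorial illustration; the repository itself uses tf.nn.atrous_conv2d over multi-channel tensors):

```python
import numpy as np

def causal_dilated_conv1d(x, w, dilation):
    """Single-channel causal dilated convolution.

    Pads (len(w) - 1) * dilation zeros on the left, mirroring ops.conv1d,
    so out[t] depends only on x[: t + 1].
    """
    filter_width = len(w)
    pad = (filter_width - 1) * dilation
    x_padded = np.concatenate([np.zeros(pad, dtype=float), x])
    out = np.zeros(len(x))
    for t in range(len(x)):
        # Taps are spaced `dilation` apart and end at the current position.
        taps = x_padded[t : t + pad + 1 : dilation]
        out[t] = np.dot(taps, w)
    return out

x = np.arange(1.0, 9.0)          # [1, 2, ..., 8]
w = np.array([1.0, 1.0, 1.0])    # filter_width = 3
print(causal_dilated_conv1d(x, w, dilation=2))
# [ 1.  2.  4.  6.  9. 12. 15. 18.]
```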
/ByteNet/translator.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import ops
3 |
4 | class ByteNet_Translator:
5 | def __init__(self, options):
6 | self.options = options
7 | embedding_channels = 2 * options['residual_channels']
8 |
9 | self.w_source_embedding = tf.get_variable('w_source_embedding',
10 | [options['source_vocab_size'], embedding_channels],
11 | initializer=tf.truncated_normal_initializer(stddev=0.02))
12 |
13 | self.w_target_embedding = tf.get_variable('w_target_embedding',
14 | [options['target_vocab_size'], embedding_channels],
15 | initializer=tf.truncated_normal_initializer(stddev=0.02))
16 |
17 | def build_model(self):
18 | options = self.options
19 | self.source_sentence = tf.placeholder('int32',
20 | [None, None], name = 'source_sentence')
21 | self.target_sentence = tf.placeholder('int32',
22 | [None, None], name = 'target_sentence')
23 |
24 | target_1 = self.target_sentence[:,0:-1]
25 | target_2 = self.target_sentence[:,1:]
26 |
27 | source_embedding = tf.nn.embedding_lookup(self.w_source_embedding,
28 | self.source_sentence, name = "source_embedding")
29 | target_1_embedding = tf.nn.embedding_lookup(self.w_target_embedding,
30 | target_1, name = "target_1_embedding")
31 |
32 |
33 | curr_input = source_embedding
34 | for layer_no, dilation in enumerate(options['encoder_dilations']):
35 | curr_input = ops.byetenet_residual_block(curr_input, dilation,
36 | layer_no, options['residual_channels'],
37 | options['encoder_filter_width'], causal = False, train = True)
38 |
39 | encoder_output = curr_input
40 | combined_embedding = target_1_embedding + encoder_output
41 | curr_input = combined_embedding
42 | for layer_no, dilation in enumerate(options['decoder_dilations']):
43 | curr_input = ops.byetenet_residual_block(curr_input, dilation,
44 | layer_no, options['residual_channels'],
45 | options['decoder_filter_width'], causal = True, train = True)
46 |
47 | logits = ops.conv1d(tf.nn.relu(curr_input),
48 | options['target_vocab_size'], name = 'logits')
49 | print "logits", logits
50 | logits_flat = tf.reshape(logits, [-1, options['target_vocab_size']])
51 | target_flat = tf.reshape(target_2, [-1])
52 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
53 | labels = target_flat, logits = logits_flat)
54 |
55 | self.loss = tf.reduce_mean(loss)
56 | self.arg_max_prediction = tf.argmax(logits_flat, 1)
57 | tf.summary.scalar('loss', self.loss)
58 |
59 | def build_translator(self, reuse = False):
60 | if reuse:
61 | tf.get_variable_scope().reuse_variables()
62 |
63 | options = self.options
64 | self.t_source_sentence = tf.placeholder('int32',
65 | [None, None], name = 'source_sentence')
66 | self.t_target_sentence = tf.placeholder('int32',
67 | [None, None], name = 'target_sentence')
68 |
69 | source_embedding = tf.nn.embedding_lookup(self.w_source_embedding,
70 | self.t_source_sentence, name = "source_embedding")
71 | target_embedding = tf.nn.embedding_lookup(self.w_target_embedding,
72 | self.t_target_sentence, name = "target_embedding")
73 |
74 | curr_input = source_embedding
75 | for layer_no, dilation in enumerate(options['encoder_dilations']):
76 | curr_input = ops.byetenet_residual_block(curr_input, dilation,
77 | layer_no, options['residual_channels'],
78 | options['encoder_filter_width'], causal = False, train = False)
79 |
80 | encoder_output = curr_input[:,0:tf.shape(self.t_target_sentence)[1],:]
81 |
82 | combined_embedding = target_embedding + encoder_output
83 | curr_input = combined_embedding
84 | for layer_no, dilation in enumerate(options['decoder_dilations']):
85 | curr_input = ops.byetenet_residual_block(curr_input, dilation,
86 | layer_no, options['residual_channels'],
87 | options['decoder_filter_width'], causal = True, train = False)
88 |
89 | logits = ops.conv1d(tf.nn.relu(curr_input),
90 | options['target_vocab_size'], name = 'logits')
91 | logits_flat = tf.reshape(logits, [-1, options['target_vocab_size']])
92 | probs_flat = tf.nn.softmax(logits_flat)
93 |
94 | self.t_probs = tf.reshape(probs_flat,
95 | [-1, tf.shape(logits)[1], options['target_vocab_size']])
96 |
97 | def main():
98 | options = {
99 | 'source_vocab_size' : 250,
100 | 'target_vocab_size' : 250,
101 | 'residual_channels' : 512,
102 | 'encoder_dilations' : [ 1,2,4,8,16,
103 | 1,2,4,8,16
104 | ],
105 | 'decoder_dilations' : [ 1,2,4,8,16,
106 | 1,2,4,8,16
107 | ],
108 | 'encoder_filter_width' : 3,
109 | 'decoder_filter_width' : 3
110 | }
111 | md = ByteNet_Translator(options)
112 | md.build_model()
113 | md.build_translator(reuse = True)
114 |
115 | if __name__ == '__main__':
116 | main()
--------------------------------------------------------------------------------
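During incremental decoding, build_translator above adds the embedding of the generated target prefix to the encoder output truncated to the prefix length (the curr_input[:, 0:tf.shape(self.t_target_sentence)[1], :] slice). A small NumPy shape sketch of that combination step (editorial illustration with made-up sizes, not the repository's graph):

```python
import numpy as np

batch, source_len, channels = 2, 10, 4          # toy sizes only
encoder_output = np.random.randn(batch, source_len, channels)

# Suppose three target characters have been generated so far.
prefix_len = 3
target_embedding = np.random.randn(batch, prefix_len, channels)

# Keep only the first prefix_len encoder positions so the two tensors
# can be added elementwise, as in build_translator.
combined = target_embedding + encoder_output[:, :prefix_len, :]
print(combined.shape)   # (2, 3, 4)
```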
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2014 Paarth Neekhara
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # byteNet-tensorflow
2 |
3 | [](https://gitter.im/byteNet-tensorflow/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
4 |
5 | This is a TensorFlow implementation of the ByteNet model from DeepMind's paper [Neural Machine Translation in Linear Time][1].
6 |
7 | From the abstract:
8 | >The ByteNet decoder attains state-of-the-art performance on character-level language modeling and outperforms the previous best results obtained with recurrent neural networks. The ByteNet also achieves a performance on raw character-level machine translation that approaches that of the best neural translation models that run in quadratic time. The implicit structure learnt by the ByteNet mirrors the expected alignments between the sequences.
9 |
10 | ## ByteNet Encoder-Decoder Model:
11 | 
12 |
13 | Image Source - [Neural Machine Translation in Linear Time][1] paper
14 |
15 | The model applies dilated 1D convolutions to the sequential data, layer by layer, to obtain the source encoding. The decoder then applies masked 1D convolutions to the target sequence (conditioned on the encoder output) to predict the next character in the target sequence. The character generation model is just the ByteNet decoder, while the machine translation model is the combined encoder and decoder.
16 |
17 | ## Implementation Notes
18 | 1. The character generation model is defined in ```ByteNet/generator.py``` and the translation model is defined in ```ByteNet/translator.py```. ```ByteNet/ops.py``` contains the ByteNet residual block, the dilated conv1d and layer normalization.
19 | 2. The model can be configured by editing ```model_config.py```.
20 | 3. The number of residual channels is 512 (configurable in ```model_config.py```).
21 |
22 | ## Requirements
23 | - Python 2.7.6
24 | - TensorFlow 1.2.0
25 |
26 | ## Datasets
27 | - The character generation model has been trained on [Shakespeare text][4]. I have included the text file in the repository ```Data/generator_training_data/shakespeare.txt```.
28 | - The machine translation model has been trained for German-to-English translation. You may download the news commentary dataset from [http://www.statmt.org/wmt16/translation-task.html][5].
29 |
30 | ## Training
31 | Create the following directories: ```Data/tb_summaries/translator_model```, ```Data/tb_summaries/generator_model```, ```Data/Models/generation_model```, ```Data/Models/translation_model```.
32 |
33 | - Text Generation
34 | * Configure the model by editing ```model_config.py```.
35 | * Save the text files to train on in ```Data/generator_training_data```. A sample ```shakespeare.txt``` is included in the repo.
36 | * Train the model with: ```python train_generator.py --text_dir="Data/generator_training_data"```
37 | * ```python train_generator.py --help``` for more options.
38 |
39 | - Machine Translation
40 | * Configure the model by editing ```model_config.py```.
41 | * Save the source and target sentences in separate files in ```Data/MachineTranslation```. You may download the news commentary training corpus using [this link][6].
42 | * The model is trained on buckets of sentence pairs whose lengths are rounded up to multiples of a configurable parameter ```bucket_quant```. Sentences are padded with a special character beyond their actual length.
43 | * Train the translation model using:
44 | - ```python train_translator.py --source_file= --target_file= --bucket_quant=50```
45 | - ```python train_translator.py --help``` for more options.
46 |
47 |
48 |
49 | ## Generating Samples
50 | - Generate new samples using:
51 | * ```python generate.py --seed="SOME_TEXT_TO_START_WITH" --sample_size=```
52 | - You can test sample translations from the dataset using ```python translate.py```.
53 | * This will pick random source sentences from the dataset and translate them.
54 |
55 | #### Sample Generations
56 |
57 | ```
58 | ANTONIO:
59 | What say you to this part of this to thee?
60 |
61 | KING PHILIP:
62 | What say these faith, madam?
63 |
64 | First Citizen:
65 | The king of England, the will of the state,
66 | That thou dost speak to me, and the thing that shall
67 | In this the son of this devil to the storm,
68 | That thou dost speak to thee to the world,
69 | That thou dost see the bear that was the foot,
70 |
71 | ```
72 |
73 | #### Translation Results (to be updated)
74 |
75 | ## TODO
76 | - Evaluate the translation model
77 | - Implement beam search (contributions welcome). Currently the model samples from the probability distribution over the top k most probable predictions.
78 | ## References
79 | - [Neural Machine Translation in Linear Time][1] paper
80 | - [Tensorflow Wavenet][2] code
81 | - [Sugar Tensor source code][7], used as a reference for implementing some ops.
82 |
83 | [1]:https://arxiv.org/abs/1610.10099
84 | [2]:https://github.com/ibab/tensorflow-wavenet
85 | [3]:https://drive.google.com/file/d/0B30fmeZ1slbBYWVSWnMyc3hXQVU/view?usp=sharing
86 | [4]:http://cs.stanford.edu/people/karpathy/char-rnn/
87 | [5]:http://www.statmt.org/wmt16/translation-task.html
88 | [6]:http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz
89 | [7]:https://github.com/buriburisuri/sugartensor
90 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import listdir
3 | from os.path import isfile, join
4 | import numpy as np
5 |
6 | class Data_Loader:
7 | def __init__(self, options):
8 | if options['model_type'] == 'translation':
9 | source_file = options['source_file']
10 | target_file = options['target_file']
11 |
12 | self.max_sentences = None
13 | if 'max_sentences' in options:
14 | self.max_sentences = options['max_sentences']
15 |
16 | with open(source_file) as f:
17 | self.source_lines = f.read().decode("utf-8").split('\n')
18 | with open(target_file) as f:
19 | self.target_lines = f.read().decode("utf-8").split('\n')
20 |
21 | if self.max_sentences:
22 | self.source_lines = self.source_lines[0:self.max_sentences]
23 | self.target_lines = self.target_lines[0:self.max_sentences]
24 |
25 | print "Source Sentences", len(self.source_lines)
26 | print "Target Sentences", len(self.target_lines)
27 |
28 | self.bucket_quant = options['bucket_quant']
29 | self.source_vocab = self.build_vocab(self.source_lines)
30 | self.target_vocab = self.build_vocab(self.target_lines)
31 |
32 | print "SOURCE VOCAB SIZE", len(self.source_vocab)
33 | print "TARGET VOCAB SIZE", len(self.target_vocab)
34 |
35 | elif options['model_type'] == 'generator':
36 | dir_name = options['dir_name']
37 | files = [ join(dir_name, f) for f in listdir(dir_name) if ( isfile(join(dir_name, f)) and ('.txt' in f) ) ]
38 | text = []
39 | for f in files:
40 | text += list(open(f).read())
41 |
42 | vocab = {ch : True for ch in text}
43 | print "Bool vocab", len(vocab)
44 | self.vocab_list = [ch for ch in vocab]
45 | print "vocab list", len(self.vocab_list)
46 | self.vocab_indexed = {ch : i for i, ch in enumerate(self.vocab_list)}
47 | print "vocab_indexed", len(self.vocab_indexed)
48 |
49 | for index, item in enumerate(text):
50 | text[index] = self.vocab_indexed[item]
51 | self.text = np.array(text, dtype='int32')
52 |
53 | def load_generator_data(self, sample_size):
54 | text = self.text
55 | mod_size = len(text) - len(text)%sample_size
56 | text = text[0:mod_size]
57 | text = text.reshape(-1, sample_size)
58 | return text, self.vocab_indexed
59 |
60 |
61 | def load_translation_data(self):
62 | source_lines = []
63 | target_lines = []
64 | for i in range(len(self.source_lines)):
65 | source_lines.append( self.string_to_indices(self.source_lines[i], self.source_vocab) )
66 | target_lines.append( self.string_to_indices(self.target_lines[i], self.target_vocab) )
67 |
68 | buckets = self.create_buckets(source_lines, target_lines)
69 |
70 | # frequent_keys = [ (-len(buckets[key]), key) for key in buckets ]
71 | # frequent_keys.sort()
72 |
73 | # print "Source", self.inidices_to_string( buckets[ frequent_keys[3][1] ][5][0], self.source_vocab)
74 | # print "Target", self.inidices_to_string( buckets[ frequent_keys[3][1] ][5][1], self.target_vocab)
75 |
76 | return buckets, self.source_vocab, self.target_vocab
77 |
78 |
79 |
80 | def create_buckets(self, source_lines, target_lines):
81 |
82 | bucket_quant = self.bucket_quant
83 | source_vocab = self.source_vocab
84 | target_vocab = self.target_vocab
85 |
86 | buckets = {}
87 | for i in xrange(len(source_lines)):
88 |
89 | source_lines[i] = np.concatenate( (source_lines[i], [source_vocab['eol']]) )
90 | target_lines[i] = np.concatenate( ([target_vocab['init']], target_lines[i], [target_vocab['eol']]) )
91 |
92 | sl = len(source_lines[i])
93 | tl = len(target_lines[i])
94 |
95 |
96 | new_length = max(sl, tl)
97 | if new_length % bucket_quant > 0:
98 | new_length = ((new_length/bucket_quant) + 1 ) * bucket_quant
99 |
100 | s_padding = np.array( [source_vocab['padding'] for ctr in xrange(sl, new_length) ] )
101 |
102 | # NEED EXTRA PADDING FOR TRAINING..
103 | t_padding = np.array( [target_vocab['padding'] for ctr in xrange(tl, new_length + 1) ] )
104 |
105 | source_lines[i] = np.concatenate( [ source_lines[i], s_padding ] )
106 | target_lines[i] = np.concatenate( [ target_lines[i], t_padding ] )
107 |
108 | if new_length in buckets:
109 | buckets[new_length].append( (source_lines[i], target_lines[i]) )
110 | else:
111 | buckets[new_length] = [(source_lines[i], target_lines[i])]
112 |
113 | if i%1000 == 0:
114 | print "Loading", i
115 |
116 | return buckets
117 |
118 | def build_vocab(self, sentences):
119 | vocab = {}
120 | ctr = 0
121 | for st in sentences:
122 | for ch in st:
123 | if ch not in vocab:
124 | vocab[ch] = ctr
125 | ctr += 1
126 |
127 | # SOME SPECIAL CHARACTERS
128 | vocab['eol'] = ctr
129 | vocab['padding'] = ctr + 1
130 | vocab['init'] = ctr + 2
131 |
132 | return vocab
133 |
134 | def string_to_indices(self, sentence, vocab):
135 | indices = [ vocab[s] for s in sentence ]
136 | return indices
137 |
138 | def inidices_to_string(self, sentence, vocab):
139 | id_ch = { vocab[ch] : ch for ch in vocab }
140 | sent = []
141 | for c in sentence:
142 | if id_ch[c] == 'eol':
143 | break
144 | sent += id_ch[c]
145 |
146 | return "".join(sent)
147 |
148 | def get_batch_from_pairs(self, pair_list):
149 | source_sentences = []
150 | target_sentences = []
151 | for s, t in pair_list:
152 | source_sentences.append(s)
153 | target_sentences.append(t)
154 |
155 | return np.array(source_sentences, dtype = 'int32'), np.array(target_sentences, dtype = 'int32')
156 |
157 |
158 | def main():
159 | # FOR TESTING ONLY
160 | trans_options = {
161 | 'model_type' : 'translation',
162 | 'source_file' : 'Data/MachineTranslation/news-commentary-v11.de-en.de',
163 | 'target_file' : 'Data/MachineTranslation/news-commentary-v11.de-en.en',
164 | 'bucket_quant' : 25,
165 | }
166 | gen_options = {
167 | 'model_type' : 'generator',
168 | 'dir_name' : 'Data',
169 | }
170 |
171 | dl = Data_Loader(gen_options)
172 | text_samples, vocab = dl.load_generator_data( 1000 )
173 | print dl.inidices_to_string(text_samples[1], vocab)
174 | print text_samples.shape
175 | print np.max(text_samples)
176 | # buckets, source_vocab, target_vocab = dl.load_translation_data()
177 |
178 | if __name__ == '__main__':
179 | main()
--------------------------------------------------------------------------------
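create_buckets above appends 'eol' to each source sentence, wraps each target in 'init'/'eol', rounds the longer of the two lengths up to the next multiple of bucket_quant, then pads the source to that length and the target to one position more. A worked example of the length arithmetic (editorial sketch, not code from the repository):

```python
bucket_quant = 25

# Lengths after the special symbols are added in create_buckets:
# the source gains 1 ('eol'), the target gains 2 ('init' and 'eol').
sl = 38 + 1            # source length with symbols: 39
tl = 41 + 2            # target length with symbols: 43

new_length = max(sl, tl)                 # 43
if new_length % bucket_quant > 0:
    new_length = ((new_length // bucket_quant) + 1) * bucket_quant
print(new_length)                        # 50 -> this pair lands in buckets[50]

print(new_length - sl)                   # 11 padding symbols for the source
print(new_length + 1 - tl)               # 8 padding symbols for the target
```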
/generate.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import argparse
4 | import model_config
5 | import data_loader
6 | from ByteNet import generator
7 | import utils
8 | import shutil
9 | import time
10 |
11 | def main():
12 | parser = argparse.ArgumentParser()
13 |
14 | parser.add_argument('--sample_size', type=int, default=300,
15 | help='Sampled output size')
16 | parser.add_argument('--top_k', type=int, default=5,
17 | help='Sample from top k predictions')
18 | parser.add_argument('--model_path', type=str, default=None,
19 | help='Pre-Trained Model Path, to resume from')
20 | parser.add_argument('--text_dir', type=str, default='Data/generator_training_data',
21 | help='Directory containing text files')
22 | parser.add_argument('--data_dir', type=str, default='Data',
23 | help='Data Directory')
24 | parser.add_argument('--seed', type=str, default='All',
25 | help='Seed for text generation')
26 |
27 |
28 |
29 | args = parser.parse_args()
30 |
31 | # model_config = json.loads( open('model_config.json').read() )
32 | config = model_config.predictor_config
33 |
34 | dl = data_loader.Data_Loader({'model_type' : 'generator', 'dir_name' : args.text_dir})
35 | _, vocab = dl.load_generator_data(config['sample_size'])
36 |
37 |
38 | model_options = {
39 | 'vocab_size' : len(vocab),
40 | 'residual_channels' : config['residual_channels'],
41 | 'dilations' : config['dilations'],
42 | 'filter_width' : config['filter_width'],
43 | }
44 |
45 | generator_model = generator.ByteNet_Generator( model_options )
46 | generator_model.build_generator()
47 |
48 |
49 | sess = tf.InteractiveSession()
50 | tf.initialize_all_variables().run()
51 | saver = tf.train.Saver()
52 |
53 | if args.model_path:
54 | saver.restore(sess, args.model_path)
55 |
56 | seed_sentence = np.array([dl.string_to_indices(args.seed, vocab)], dtype = 'int32' )
57 |
58 | for col in range(args.sample_size):
59 | [probs] = sess.run([generator_model.g_probs],
60 | feed_dict = {
61 | generator_model.seed_sentence :seed_sentence
62 | })
63 |
64 | curr_preds = []
65 | for bi in range(probs.shape[0]):
66 | pred_word = utils.sample_top(probs[bi][-1], top_k = args.top_k )
67 | curr_preds.append(pred_word)
68 |
69 | seed_sentence = np.insert(seed_sentence, seed_sentence.shape[1], curr_preds, axis = 1)
70 | print col, dl.inidices_to_string(seed_sentence[0], vocab)
71 |
72 | f = open('Data/generator_sample.txt', 'w')
73 | f.write(dl.inidices_to_string(seed_sentence[0], vocab))
74 | f.close()
75 |
76 | if __name__ == '__main__':
77 | main()
--------------------------------------------------------------------------------
/model_config.py:
--------------------------------------------------------------------------------
1 | predictor_config = {
2 | "filter_width": 3,
3 | "dilations": [1, 2, 4, 8, 16,
4 | 1, 2, 4, 8, 16,
5 | 1, 2, 4, 8, 16,
6 | 1, 2, 4, 8, 16,
7 | 1, 2, 4, 8, 16,
8 | ],
9 | "residual_channels": 512,
10 | "n_target_quant": 256,
11 | "n_source_quant": 256,
12 | "sample_size" : 1000
13 | }
14 |
15 | translator_config = {
16 | "decoder_filter_width": 3,
17 | "encoder_filter_width" : 5,
18 | "encoder_dilations": [1, 2, 4, 8, 16,
19 | 1, 2, 4, 8, 16,
20 | 1, 2, 4, 8, 16,
21 | ],
22 | "decoder_dilations": [1, 2, 4, 8, 16,
23 | 1, 2, 4, 8, 16,
24 | 1, 2, 4, 8, 16,
25 | ],
26 | "residual_channels": 512,
27 | }
--------------------------------------------------------------------------------
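One way to reason about these dilation stacks is their receptive field: each dilated layer with filter width f and dilation d widens it by (f - 1) * d positions, and the 1x1 convolutions inside each residual block do not widen it further. A short sketch computing this for the configurations above (editorial illustration):

```python
def receptive_field(dilations, filter_width):
    # Each dilated layer adds (filter_width - 1) * dilation positions.
    return 1 + sum((filter_width - 1) * d for d in dilations)

predictor = [1, 2, 4, 8, 16] * 5
print(receptive_field(predictor, filter_width=3))    # 311 characters

decoder = [1, 2, 4, 8, 16] * 3
print(receptive_field(decoder, filter_width=3))      # 187 characters

encoder = [1, 2, 4, 8, 16] * 3
print(receptive_field(encoder, filter_width=5))      # 373 characters
```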
/train_generator.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import argparse
4 | import model_config
5 | import data_loader
6 | from ByteNet import generator
7 | import utils
8 | import shutil
9 | import time
10 |
11 | def main():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--learning_rate', type=float, default=0.001,
14 | help='Learning Rate')
15 | parser.add_argument('--batch_size', type=int, default=1,
16 | help='Batch Size')
17 | parser.add_argument('--sample_every', type=int, default=500,
18 | help='Sample generator output every x steps')
19 | parser.add_argument('--summary_every', type=int, default=50,
20 | help='Write tensorboard summaries every x steps')
21 | parser.add_argument('--save_model_every', type=int, default=1500,
22 | help='Save model every x steps')
23 | parser.add_argument('--sample_size', type=int, default=300,
24 | help='Sampled output size')
25 | parser.add_argument('--top_k', type=int, default=5,
26 | help='Sample from top k predictions')
27 | parser.add_argument('--max_epochs', type=int, default=1000,
28 | help='Max Epochs')
29 | parser.add_argument('--beta1', type=float, default=0.5,
30 | help='Momentum for Adam Update')
31 | parser.add_argument('--resume_model', type=str, default=None,
32 | help='Pre-Trained Model Path, to resume from')
33 | parser.add_argument('--text_dir', type=str, default='Data/generator_training_data',
34 | help='Directory containing text files')
35 | parser.add_argument('--data_dir', type=str, default='Data',
36 | help='Data Directory')
37 | parser.add_argument('--seed', type=str, default='All',
38 | help='Seed for text generation')
39 |
40 |
41 |
42 | args = parser.parse_args()
43 |
44 | # model_config = json.loads( open('model_config.json').read() )
45 | config = model_config.predictor_config
46 |
47 | dl = data_loader.Data_Loader({'model_type' : 'generator', 'dir_name' : args.text_dir})
48 | text_samples, vocab = dl.load_generator_data(config['sample_size'])
49 | print text_samples.shape
50 |
51 | model_options = {
52 | 'vocab_size' : len(vocab),
53 | 'residual_channels' : config['residual_channels'],
54 | 'dilations' : config['dilations'],
55 | 'filter_width' : config['filter_width'],
56 | }
57 |
58 | generator_model = generator.ByteNet_Generator( model_options )
59 | generator_model.build_model()
60 |
61 | optim = tf.train.AdamOptimizer(
62 | args.learning_rate,
63 | beta1 = args.beta1).minimize(generator_model.loss)
64 |
65 | generator_model.build_generator(reuse = True)
66 | merged_summary = tf.summary.merge_all()
67 |
68 | sess = tf.InteractiveSession()
69 | tf.initialize_all_variables().run()
70 | saver = tf.train.Saver()
71 |
72 | if args.resume_model:
73 | saver.restore(sess, args.resume_model)
74 |
75 | shutil.rmtree('Data/tb_summaries/generator_model', ignore_errors=True)
76 | train_writer = tf.summary.FileWriter('Data/tb_summaries/generator_model', sess.graph)
77 |
78 | step = 0
79 | for epoch in range(args.max_epochs):
80 | batch_no = 0
81 | batch_size = args.batch_size
82 | while (batch_no+1) * batch_size < text_samples.shape[0]:
83 |
84 | start = time.clock()
85 |
86 | text_batch = text_samples[batch_no*batch_size : (batch_no + 1)*batch_size, :]
87 | _, loss, prediction = sess.run(
88 | [optim, generator_model.loss,
89 | generator_model.arg_max_prediction],
90 | feed_dict = {
91 | generator_model.t_sentence : text_batch
92 | })
93 | end = time.clock()
94 | print "-------------------------------------------------------"
95 | print "LOSS: {}\tEPOCH: {}\tBATCH_NO: {}\t STEP:{}\t total_batches:{}".format(
96 | loss, epoch, batch_no, step, text_samples.shape[0]/args.batch_size)
97 | print "TIME FOR BATCH", end - start
98 | print "TIME FOR EPOCH (mins)", (end - start) * (text_samples.shape[0]/args.batch_size)/60.0
99 |
100 | batch_no += 1
101 | step += 1
102 |
103 | if step % args.summary_every == 0:
104 | [summary] = sess.run([merged_summary], feed_dict = {
105 | generator_model.t_sentence : text_batch
106 | })
107 | train_writer.add_summary(summary, step)
108 | print dl.inidices_to_string(prediction, vocab)
109 |
110 | print "********************************************************"
111 |
112 | if step % args.sample_every == 0:
113 | seed_sentence = np.array([dl.string_to_indices(args.seed, vocab)], dtype = 'int32' )
114 |
115 | for col in range(args.sample_size):
116 | [probs] = sess.run([generator_model.g_probs],
117 | feed_dict = {
118 | generator_model.seed_sentence :seed_sentence
119 | })
120 |
121 | curr_preds = []
122 | for bi in range(probs.shape[0]):
123 | pred_word = utils.sample_top(probs[bi][-1], top_k = args.top_k )
124 | curr_preds.append(pred_word)
125 |
126 | seed_sentence = np.insert(seed_sentence, seed_sentence.shape[1], curr_preds, axis = 1)
127 | print col, dl.inidices_to_string(seed_sentence[0], vocab)
128 |
129 | f = open('Data/generator_sample.txt', 'wb')
130 | f.write(dl.inidices_to_string(seed_sentence[0], vocab))
131 | f.close()
132 |
133 | if step % args.save_model_every == 0:
134 | save_path = saver.save(sess, "Data/Models/generation_model/model_epoch_{}_{}.ckpt".format(epoch, step))
135 |
136 | if __name__ == '__main__':
137 | main()
--------------------------------------------------------------------------------
/train_translator.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import argparse
4 | import model_config
5 | import data_loader
6 | from ByteNet import translator
7 | import utils
8 | import shutil
9 | import time
10 |
11 | def main():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--learning_rate', type=float, default=0.001,
14 | help='Learning Rate')
15 | parser.add_argument('--batch_size', type=int, default=8,
16 | help='Batch Size')
17 | parser.add_argument('--bucket_quant', type=int, default=50,
18 | help='Bucket Quant')
19 | parser.add_argument('--max_epochs', type=int, default=1000,
20 | help='Max Epochs')
21 | parser.add_argument('--beta1', type=float, default=0.5,
22 | help='Momentum for Adam Update')
23 | parser.add_argument('--resume_model', type=str, default=None,
24 | help='Pre-Trained Model Path, to resume from')
25 | parser.add_argument('--source_file', type=str, default='Data/MachineTranslation/news-commentary-v11.de-en.de',
26 | help='Source File')
27 | parser.add_argument('--target_file', type=str, default='Data/MachineTranslation/news-commentary-v11.de-en.en',
28 | help='Target File')
29 | parser.add_argument('--sample_every', type=int, default=500,
30 | help='Sample translator output every x steps')
31 | parser.add_argument('--summary_every', type=int, default=50,
32 | help='Write tensorboard summaries every x steps')
33 | parser.add_argument('--top_k', type=int, default=5,
34 | help='Sample from top k predictions')
35 | parser.add_argument('--resume_from_bucket', type=int, default=0,
36 | help='Resume From Bucket')
37 | args = parser.parse_args()
38 |
39 | data_loader_options = {
40 | 'model_type' : 'translation',
41 | 'source_file' : args.source_file,
42 | 'target_file' : args.target_file,
43 | 'bucket_quant' : args.bucket_quant,
44 | }
45 |
46 | dl = data_loader.Data_Loader(data_loader_options)
47 | buckets, source_vocab, target_vocab = dl.load_translation_data()
48 | print "Number Of Buckets", len(buckets)
49 |
50 | config = model_config.translator_config
51 | model_options = {
52 | 'source_vocab_size' : len(source_vocab),
53 | 'target_vocab_size' : len(target_vocab),
54 | 'residual_channels' : config['residual_channels'],
55 | 'decoder_dilations' : config['decoder_dilations'],
56 | 'encoder_dilations' : config['encoder_dilations'],
57 | 'decoder_filter_width' : config['decoder_filter_width'],
58 | 'encoder_filter_width' : config['encoder_filter_width'],
59 | }
60 |
61 | translator_model = translator.ByteNet_Translator( model_options )
62 | translator_model.build_model()
63 |
64 | optim = tf.train.AdamOptimizer(
65 | args.learning_rate,
66 | beta1 = args.beta1).minimize(translator_model.loss)
67 |
68 | translator_model.build_translator(reuse = True)
69 | merged_summary = tf.summary.merge_all()
70 |
71 | sess = tf.InteractiveSession()
72 | tf.initialize_all_variables().run()
73 | saver = tf.train.Saver()
74 |
75 | if args.resume_model:
76 | saver.restore(sess, args.resume_model)
77 |
78 | shutil.rmtree('Data/tb_summaries/translator_model', ignore_errors=True)
79 | train_writer = tf.summary.FileWriter('Data/tb_summaries/translator_model', sess.graph)
80 |
81 | bucket_sizes = [bucket_size for bucket_size in buckets]
82 | bucket_sizes.sort()
83 |
84 | step = 0
85 | batch_size = args.batch_size
86 | for epoch in range(args.max_epochs):
87 | for bucket_size in bucket_sizes:
88 | if epoch == 0 and bucket_size < args.resume_from_bucket:
89 | continue
90 |
91 | batch_no = 0
92 | while (batch_no + 1) * batch_size < len(buckets[bucket_size]):
93 | start = time.clock()
94 | source, target = dl.get_batch_from_pairs(
95 | buckets[bucket_size][batch_no * batch_size : (batch_no+1) * batch_size]
96 | )
97 |
98 | _, loss, prediction = sess.run(
99 | [optim, translator_model.loss, translator_model.arg_max_prediction],
100 |
101 | feed_dict = {
102 | translator_model.source_sentence : source,
103 | translator_model.target_sentence : target,
104 | })
105 | end = time.clock()
106 |
107 | print "LOSS: {}\tEPOCH: {}\tBATCH_NO: {}\t STEP:{}\t total_batches:{}\t bucket_size:{}".format(
108 | loss, epoch, batch_no, step, len(buckets[bucket_size])/args.batch_size, bucket_size)
109 | print "TIME FOR BATCH", end - start
110 | print "TIME FOR BUCKET (mins)", (end - start) * (len(buckets[bucket_size])/args.batch_size)/60.0
111 |
112 | batch_no += 1
113 | step += 1
114 |
115 | if step % args.summary_every == 0:
116 | [summary] = sess.run([merged_summary], feed_dict = {
117 | translator_model.source_sentence : source,
118 | translator_model.target_sentence : target,
119 | })
120 | train_writer.add_summary(summary, step)
121 |
122 | print "******"
123 | print "Source ", dl.inidices_to_string(source[0], source_vocab)
124 | print "---------"
125 | print "Target ", dl.inidices_to_string(target[0], target_vocab)
126 | print "----------"
127 | print "Prediction ",dl.inidices_to_string(prediction[0:bucket_size], target_vocab)
128 | print "******"
129 |
130 | if step % args.sample_every == 0:
131 | log_file = open('Data/translator_sample.txt', 'wb')
132 | generated_target = target[:,0:1]
133 | for col in range(bucket_size):
134 | [probs] = sess.run([translator_model.t_probs],
135 | feed_dict = {
136 | translator_model.t_source_sentence : source,
137 | translator_model.t_target_sentence : generated_target,
138 | })
139 |
140 | curr_preds = []
141 | for bi in range(probs.shape[0]):
142 | pred_word = utils.sample_top(probs[bi][-1], top_k = args.top_k )
143 | curr_preds.append(pred_word)
144 |
145 | generated_target = np.insert(generated_target, generated_target.shape[1], curr_preds, axis = 1)
146 |
147 |
148 | for bi in range(probs.shape[0]):
149 |
150 | print col, dl.inidices_to_string(generated_target[bi], target_vocab)
151 | print col, dl.inidices_to_string(target[bi], target_vocab)
152 | print "***************"
153 |
154 | if col == bucket_size - 1:
155 | try:
156 | log_file.write("Predicted: " + dl.inidices_to_string(generated_target[bi], target_vocab) + '\n')
157 | log_file.write("Actual Target: " + dl.inidices_to_string(target[bi], target_vocab) + '\n')
158 | log_file.write("Actual Source: " + dl.inidices_to_string(source[bi], source_vocab) + '\n *******')
159 | except:
160 | pass
161 | print "***************"
162 | log_file.close()
163 |
164 | save_path = saver.save(sess, "Data/Models/translation_model/model_epoch_{}_{}.ckpt".format(epoch, bucket_size))
165 |
166 |
167 |
168 | if __name__ == '__main__':
169 | main()
170 |
171 |
172 |
173 |
--------------------------------------------------------------------------------
/translate.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import argparse
4 | import model_config
5 | import data_loader
6 | from ByteNet import translator
7 | import utils
8 | import shutil
9 | import time
10 | import random
11 |
12 | def main():
13 | parser = argparse.ArgumentParser()
14 |
15 | parser.add_argument('--bucket_quant', type=int, default=50,
16 | help='Bucket Quant')
17 | parser.add_argument('--model_path', type=str, default=None,
18 | help='Pre-Trained Model Path, to resume from')
19 | parser.add_argument('--source_file', type=str, default='Data/MachineTranslation/news-commentary-v11.de-en.de',
20 | help='Source File')
21 | parser.add_argument('--target_file', type=str, default='Data/MachineTranslation/news-commentary-v11.de-en.en',
22 | help='Target File')
23 | parser.add_argument('--top_k', type=int, default=5,
24 | help='Sample from top k predictions')
25 | parser.add_argument('--batch_size', type=int, default=16,
26 | help='Batch Size')
27 | parser.add_argument('--bucket_size', type=int, default=None,
28 | help='Bucket Size')
29 | args = parser.parse_args()
30 |
31 | data_loader_options = {
32 | 'model_type' : 'translation',
33 | 'source_file' : args.source_file,
34 | 'target_file' : args.target_file,
35 | 'bucket_quant' : args.bucket_quant,
36 | }
37 |
38 | dl = data_loader.Data_Loader(data_loader_options)
39 | buckets, source_vocab, target_vocab = dl.load_translation_data()
40 | print "Number Of Buckets", len(buckets)
41 |
42 | config = model_config.translator_config
43 | model_options = {
44 | 'source_vocab_size' : len(source_vocab),
45 | 'target_vocab_size' : len(target_vocab),
46 | 'residual_channels' : config['residual_channels'],
47 | 'decoder_dilations' : config['decoder_dilations'],
48 | 'encoder_dilations' : config['encoder_dilations'],
49 | 'decoder_filter_width' : config['decoder_filter_width'],
50 | 'encoder_filter_width' : config['encoder_filter_width'],
51 | }
52 |
53 | translator_model = translator.ByteNet_Translator( model_options )
54 | translator_model.build_translator()
55 |
56 | sess = tf.InteractiveSession()
57 | tf.initialize_all_variables().run()
58 | saver = tf.train.Saver()
59 |
60 | if args.model_path:
61 | saver.restore(sess, args.model_path)
62 |
63 |
64 |
65 | bucket_sizes = [bucket_size for bucket_size in buckets]
66 | bucket_sizes.sort()
67 |
68 | if not args.bucket_size:
69 | bucket_size = random.choice(bucket_sizes)
70 | else:
71 | bucket_size = args.bucket_size
72 |
73 | source, target = dl.get_batch_from_pairs(
74 | random.sample(buckets[bucket_size], args.batch_size)
75 | )
76 |
77 | log_file = open('Data/translator_sample.txt', 'wb')
78 | generated_target = target[:,0:1]
79 | for col in range(bucket_size):
80 | [probs] = sess.run([translator_model.t_probs],
81 | feed_dict = {
82 | translator_model.t_source_sentence : source,
83 | translator_model.t_target_sentence : generated_target,
84 | })
85 |
86 | curr_preds = []
87 | for bi in range(probs.shape[0]):
88 | pred_word = utils.sample_top(probs[bi][-1], top_k = args.top_k )
89 | curr_preds.append(pred_word)
90 |
91 | generated_target = np.insert(generated_target, generated_target.shape[1], curr_preds, axis = 1)
92 |
93 |
94 | for bi in range(probs.shape[0]):
95 |
96 | print col, dl.inidices_to_string(generated_target[bi], target_vocab)
97 | print col, dl.inidices_to_string(target[bi], target_vocab)
98 | print "***************"
99 |
100 | if col == bucket_size - 1:
101 | try:
102 | log_file.write("Predicted: " + dl.inidices_to_string(generated_target[bi], target_vocab) + '\n')
103 | log_file.write("Actual Target: " + dl.inidices_to_string(target[bi], target_vocab) + '\n')
104 | log_file.write("Actual Source: " + dl.inidices_to_string(source[bi], source_vocab) + '\n *******')
105 | except:
106 | pass
107 |
108 | log_file.close()
109 |
110 |
111 |
112 | if __name__ == '__main__':
113 | main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def sample_top(a, top_k=10):
4 | idx = np.argsort(a)[::-1]
5 | idx = idx[:top_k]
6 | probs = a[idx]
7 | probs = probs / np.sum(probs)
8 | choice = np.random.choice(idx, p=probs)
9 | return choice
10 |
--------------------------------------------------------------------------------
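sample_top above keeps the top_k highest-probability indices, renormalizes their probabilities, and samples one of them. A small usage sketch (editorial; the probability vector is made up, and the import assumes it is run from the repository root):

```python
import numpy as np
from utils import sample_top

# A made-up distribution over a 6-symbol vocabulary.
probs = np.array([0.02, 0.40, 0.05, 0.30, 0.03, 0.20])

np.random.seed(0)
# Only indices 1, 3 and 5 (the three most probable) can be returned;
# their probabilities are renormalized to sum to 1 before sampling.
draws = [sample_top(probs, top_k=3) for _ in range(10)]
print(draws)
```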