├── .gitignore ├── License.md ├── README.md ├── run.sh ├── src ├── __init__.py ├── data_util │ ├── __init__.py │ ├── bucketdata.py │ └── data_gen.py ├── exp_config.py ├── launcher.py └── model │ ├── __init__.py │ ├── cnn.py │ ├── model.py │ ├── seq2seq.py │ └── seq2seq_model.py ├── test_demo.sh ├── tmp.py └── train_demo.sh /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python template 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # PyBuilder 59 | target/ 60 | ### JetBrains template 61 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio 62 | 63 | *.iml 64 | 65 | ## Directory-based project format: 66 | .idea/ 67 | # if you remove the above rule, at least ignore the following: 68 | 69 | # User-specific stuff: 70 | # .idea/workspace.xml 71 | # .idea/tasks.xml 72 | # .idea/dictionaries 73 | 74 | # Sensitive or high-churn files: 75 | # .idea/dataSources.ids 76 | # .idea/dataSources.xml 77 | # .idea/sqlDataSources.xml 78 | # .idea/dynamic.xml 79 | # .idea/uiDesigner.xml 80 | 81 | # Gradle: 82 | # .idea/gradle.xml 83 | # .idea/libraries 84 | 85 | # Mongo Explorer plugin: 86 | # .idea/mongoSettings.xml 87 | 88 | ## File-based project format: 89 | *.ipr 90 | *.iws 91 | 92 | ## Plugin-specific files: 93 | 94 | # IntelliJ 95 | /out/ 96 | 97 | # mpeltonen/sbt-idea plugin 98 | .idea_modules/ 99 | 100 | # JIRA plugin 101 | atlassian-ide-plugin.xml 102 | 103 | # Crashlytics plugin (for Android Studio and IntelliJ) 104 | com_crashlytics_export_strings.xml 105 | crashlytics.properties 106 | crashlytics-build.properties 107 | 108 | # Created by .ignore support plugin (hsz.mobi) 109 | misc/ 110 | data/evaluation_data -------------------------------------------------------------------------------- /License.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Qi Guo and Yuntian Deng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attention-OCR
2 | Authors: [Qi Guo](http://qiguo.ml) and [Yuntian Deng](https://github.com/da03)
3 | 
4 | Visual attention-based OCR. The model first runs a sliding CNN on the image (images are resized to height 32 while preserving the aspect ratio). An LSTM is then stacked on top of the CNN. Finally, an attention model is used as a decoder to produce the final outputs.
5 | 
6 | ![example image 0](http://cs.cmu.edu/~yuntiand/OCR-2.jpg)
7 | 
8 | # Prerequisites
9 | Most of our code is written on top of TensorFlow, but we also use Keras for the convolutional part of our model. In addition, we use the Python package `distance` to compute the edit distance for evaluation. (This is optional: if `distance` is not installed, we fall back to exact match.)
10 | 
11 | ### TensorFlow: [Installation Instructions](https://www.tensorflow.org/get_started/os_setup#download-and-setup) (tested on 0.12.1)
12 | 
13 | ### Distance (Optional):
14 | 
15 | ```
16 | wget http://www.cs.cmu.edu/~yuntiand/Distance-0.1.3.tar.gz
17 | ```
18 | 
19 | ```
20 | tar zxf Distance-0.1.3.tar.gz
21 | ```
22 | 
23 | ```
24 | cd distance; sudo python setup.py install
25 | ```
26 | 
27 | # Usage:
28 | 
29 | Note: We assume that the working directory is `Attention-OCR`.
30 | 
31 | ## Train
32 | 
33 | ### Data Preparation
34 | We need a file (specified by the parameter `data-path`) containing the paths of the images and the corresponding characters, e.g.:
35 | 
36 | ```
37 | path/to/image1 abc
38 | path/to/image2 def
39 | ```
40 | 
41 | We also need to specify a `data-base-dir` parameter so that images are read from `data-base-dir/path/to/image`. If `data-path` contains absolute image paths, then `data-base-dir` needs to be set to `/`.
42 | 
43 | ### A Toy Example
44 | 
45 | For a toy example, we have prepared a training dataset in the specified format, which is a subset of [Synth 90k](http://www.robots.ox.ac.uk/~vgg/data/text/).
46 | 
47 | ```
48 | wget http://www.cs.cmu.edu/~yuntiand/sample.tgz
49 | ```
50 | 
51 | ```
52 | tar zxf sample.tgz
53 | ```
54 | 
55 | ```
56 | python src/launcher.py --phase=train --data-path=sample/sample.txt --data-base-dir=sample --log-path=log.txt --no-load-model
57 | ```
58 | 
59 | After a while, you will see something like the following output in `log.txt`:
60 | 
61 | ```
62 | ...
63 | 2016-06-08 20:47:22,335 root INFO Created model with fresh parameters.
64 | 2016-06-08 20:47:52,852 root INFO current_step: 0 65 | 2016-06-08 20:48:01,253 root INFO step_time: 8.400597, step perplexity: 38.998714 66 | 2016-06-08 20:48:01,385 root INFO current_step: 1 67 | 2016-06-08 20:48:07,166 root INFO step_time: 5.781749, step perplexity: 38.998445 68 | 2016-06-08 20:48:07,337 root INFO current_step: 2 69 | 2016-06-08 20:48:12,322 root INFO step_time: 4.984972, step perplexity: 39.006730 70 | 2016-06-08 20:48:12,347 root INFO current_step: 3 71 | 2016-06-08 20:48:16,821 root INFO step_time: 4.473902, step perplexity: 39.000267 72 | 2016-06-08 20:48:16,859 root INFO current_step: 4 73 | 2016-06-08 20:48:21,452 root INFO step_time: 4.593249, step perplexity: 39.009864 74 | 2016-06-08 20:48:21,530 root INFO current_step: 5 75 | 2016-06-08 20:48:25,878 root INFO step_time: 4.348195, step perplexity: 38.987707 76 | 2016-06-08 20:48:26,016 root INFO current_step: 6 77 | 2016-06-08 20:48:30,851 root INFO step_time: 4.835423, step perplexity: 39.022887 78 | ``` 79 | 80 | Note that it takes quite a long time to reach convergence, since we are training the CNN and attention model simultaneously. 81 | 82 | ## Test and visualize attention results 83 | 84 | The test data format shall be the same as training data format. We have also prepared a test dataset of the specified format, which includes ICDAR03, ICDAR13, IIIT5k and SVT. 85 | 86 | ``` 87 | wget http://www.cs.cmu.edu/~yuntiand/evaluation_data.tgz 88 | ``` 89 | 90 | ``` 91 | tar zxf evaluation_data.tgz 92 | ``` 93 | 94 | We also provide a trained model on Synth 90K: 95 | 96 | ``` 97 | wget http://www.cs.cmu.edu/~yuntiand/model.tgz 98 | ``` 99 | 100 | ``` 101 | tar zxf model.tgz 102 | ``` 103 | 104 | ``` 105 | python src/launcher.py --phase=test --visualize --data-path=evaluation_data/svt/test.txt --data-base-dir=evaluation_data/svt --log-path=log.txt --load-model --model-dir=model --output-dir=results 106 | ``` 107 | 108 | After a while, you will see something like the following output in `log.txt`: 109 | 110 | ``` 111 | 2016-06-08 22:36:31,638 root INFO Reading model parameters from model/translate.ckpt-47200 112 | 2016-06-08 22:36:40,529 root INFO Compare word based on edit distance. 113 | 2016-06-08 22:36:41,652 root INFO step_time: 1.119277, step perplexity: 1.056626 114 | 2016-06-08 22:36:41,660 root INFO 1.000000 out of 1 correct 115 | 2016-06-08 22:36:42,358 root INFO step_time: 0.696687, step perplexity: 2.003350 116 | 2016-06-08 22:36:42,363 root INFO 1.666667 out of 2 correct 117 | 2016-06-08 22:36:42,831 root INFO step_time: 0.466550, step perplexity: 1.501963 118 | 2016-06-08 22:36:42,835 root INFO 2.466667 out of 3 correct 119 | 2016-06-08 22:36:43,402 root INFO step_time: 0.562091, step perplexity: 1.269991 120 | 2016-06-08 22:36:43,418 root INFO 3.366667 out of 4 correct 121 | 2016-06-08 22:36:43,897 root INFO step_time: 0.477545, step perplexity: 1.072437 122 | 2016-06-08 22:36:43,905 root INFO 4.366667 out of 5 correct 123 | 2016-06-08 22:36:44,107 root INFO step_time: 0.195361, step perplexity: 2.071796 124 | 2016-06-08 22:36:44,127 root INFO 5.144444 out of 6 correct 125 | 126 | ``` 127 | 128 | Example output images in `results/correct` (the output directory is set via parameter `output-dir` and the default is `results`): (Look closer to see it clearly.) 
129 | 
130 | Format: Image `index` (`predicted`/`ground truth`) `Image file`
131 | 
132 | Image 0 (j/j): ![example image 0](http://cs.cmu.edu/~yuntiand/2evaluation_data_icdar13_images_word_370.png/image_0.jpg)
133 | 
134 | Image 1 (u/u): ![example image 1](http://cs.cmu.edu/~yuntiand/2evaluation_data_icdar13_images_word_370.png/image_1.jpg)
135 | 
136 | Image 2 (n/n): ![example image 2](http://cs.cmu.edu/~yuntiand/2evaluation_data_icdar13_images_word_370.png/image_2.jpg)
137 | 
138 | Image 3 (g/g): ![example image 3](http://cs.cmu.edu/~yuntiand/2evaluation_data_icdar13_images_word_370.png/image_3.jpg)
139 | 
140 | Image 4 (l/l): ![example image 4](http://cs.cmu.edu/~yuntiand/2evaluation_data_icdar13_images_word_370.png/image_4.jpg)
141 | 
142 | Image 5 (e/e): ![example image 5](http://cs.cmu.edu/~yuntiand/2evaluation_data_icdar13_images_word_370.png/image_5.jpg)
143 | 
144 | 
145 | # Parameters:
146 | 
147 | - Control
148 |   * `phase`: Determines whether to train or test.
149 |   * `visualize`: Only valid if `phase` is set to test. Outputs the attention maps over the original image.
150 |   * `load-model`: Whether or not to load the model from `model-dir`.
151 | 
152 | - Input and output
153 |   * `data-base-dir`: The base directory of the image paths in `data-path`. If the image paths in `data-path` are absolute, set it to `/`.
154 |   * `data-path`: The path of the file containing image paths and labels. Format per line: `image_path characters`.
155 |   * `model-dir`: The directory for saving and loading model parameters (the structure is not stored).
156 |   * `log-path`: The path of the log file.
157 |   * `output-dir`: The directory for visualization results if `visualize` is set to True.
158 |   * `steps-per-checkpoint`: How many steps between checkpoints (print perplexity, save model).
159 | 
160 | - Optimization
161 |   * `num-epoch`: The number of passes over the whole training data.
162 |   * `batch-size`: Batch size. Only valid if `phase` is set to train.
163 |   * `initial-learning-rate`: Initial learning rate; note that we use AdaDelta, so the initial value does not matter much.
164 | 
165 | - Network
166 |   * `target-embedding-size`: Embedding dimension for each target.
167 |   * `attn-use-lstm`: Whether or not to use an LSTM attention decoder cell.
168 |   * `attn-num-hidden`: Number of hidden units in the attention decoder cell.
169 |   * `attn-num-layers`: Number of layers in the attention decoder cell. (The number of encoder hidden units will be `attn-num-hidden`*`attn-num-layers`.)
170 |   * `target-vocab-size`: Target vocabulary size.
Default is = 26+10+3 # 0: PADDING, 1: GO, 2: EOS, >2: 0-9, a-z 171 | 172 | 173 | # References 174 | 175 | [Convert a formula to its LaTex source](https://github.com/harvardnlp/im2markup) 176 | 177 | [What You Get Is What You See: A Visual Markup Decompiler](https://arxiv.org/pdf/1609.04938.pdf) 178 | 179 | [Torch attention OCR](https://github.com/da03/torch-Attention-OCR) 180 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # train on iam (handwritten) 4 | python src/launcher.py --data-base-dir=/ --data-path=/home/sivankeret/wolf_dir/Dev2/Datasets/iam-words/images/tmp_images_lists/trainset.txt --model-dir=Workplace --log-path=Workplace/log.txt --steps-per-checkpoint=200 --phase=train 5 | 6 | # train on Synth90k subset toy example 7 | python src/launcher.py --data-base-dir=data/sample --data-path=data/sample/sample.txt --model-dir=Workplace/model --log-path=Workplace/model_log.txt --steps-per-checkpoint=200 --phase=train --no-load-model 8 | 9 | # train with load model 10 | python src/launcher.py --data-base-dir=data/sample --data-path=data/sample/sample.txt --model-dir=Workplace --log-path=Workplace/log.txt --phase=train --load-model 11 | 12 | python src/train.py --phase=train --train-data-path=data/sample/sample.txt --val-data-path=data/sample/sample.txt --train-data-base-dir=data/sample --val-data-base-dir=data/sample --log-path=Workplace/log_test.txt --model-dir=Workplace 13 | 14 | 15 | # test on same subset toy example 16 | python src/launcher.py --phase=test --data-path=data/sample/sample.txt --data-base-dir=data/sample --log-path=Workplace/log_test.txt --load-model --model-dir=Workplace --output-dir=Workplace/results 17 | 18 | 19 | 20 | python src/test.py --phase=test --data-path=data/sample/sample.txt --data-base-dir=data/sample --log-path=Workplace/log_test.txt --model-dir=Workplace --output-dir=Workplace/results 21 | 22 | 23 | python src/launcher.py --phase=train --data-path=/media/data2/sivankeret/Datasets/mnt/ramdisk/max/90kDICT32px/annotation_train_words.txt --data-base-dir=/media/data2/sivankeret/Datasets/mnt/ramdisk/max/90kDICT32px --log-path=Workplace/log_before_refactor.txt --model-dir=Workplace 24 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'moonkey' 2 | -------------------------------------------------------------------------------- /src/data_util/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'moonkey' 2 | -------------------------------------------------------------------------------- /src/data_util/bucketdata.py: -------------------------------------------------------------------------------- 1 | __author__ = 'moonkey' 2 | 3 | import os 4 | import numpy as np 5 | from PIL import Image 6 | # from keras.preprocessing.sequence import pad_sequences 7 | from collections import Counter 8 | import pickle as cPickle 9 | import random 10 | import math 11 | 12 | class BucketData(object): 13 | def __init__(self): 14 | self.max_width = 0 15 | self.max_label_len = 0 16 | self.data_list = [] 17 | self.data_len_list = [] 18 | self.label_list = [] 19 | self.file_list = [] 20 | 21 | def append(self, datum, label, filename): 22 | self.data_list.append(datum) 23 | self.data_len_list.append(int(math.floor(datum.shape[-1] 
/ 4)) - 1) 24 | self.label_list.append(label) 25 | self.file_list.append(filename) 26 | 27 | self.max_width = max(datum.shape[-1], self.max_width) 28 | self.max_label_len = max(len(label), self.max_label_len) 29 | 30 | return len(self.data_list) 31 | 32 | def flush_out(self, bucket_specs, valid_target_length=float('inf'), 33 | go_shift=1): 34 | # print self.max_width, self.max_label_len 35 | res = dict(bucket_id=None, 36 | data=None, zero_paddings=None, encoder_mask=None, 37 | decoder_inputs=None, target_weights=None) 38 | 39 | def get_bucket_id(): 40 | for idx in range(0, len(bucket_specs)): 41 | if bucket_specs[idx][0] >= self.max_width / 4 - 1 \ 42 | and bucket_specs[idx][1] >= self.max_label_len: 43 | return idx 44 | return None 45 | 46 | res['bucket_id'] = get_bucket_id() 47 | if res['bucket_id'] is None: 48 | self.data_list, self.data_len_list, self.label_list = [], [], [] 49 | self.max_width, self.max_label_len = 0, 0 50 | return None 51 | 52 | encoder_input_len, decoder_input_len = bucket_specs[res['bucket_id']] 53 | 54 | # ENCODER PART 55 | res['data_len'] = [a.astype(np.int32) for a in 56 | np.array(self.data_len_list)] 57 | res['data'] = np.array(self.data_list) 58 | real_len = max(int(math.floor(self.max_width / 4)) - 1, 0) 59 | padd_len = int(encoder_input_len) - real_len 60 | res['zero_paddings'] = np.zeros([len(self.data_list), padd_len, 512], 61 | dtype=np.float32) 62 | encoder_mask = np.concatenate( 63 | (np.ones([len(self.data_list), real_len], dtype=np.float32), 64 | np.zeros([len(self.data_list), padd_len], dtype=np.float32)), 65 | axis=1) 66 | res['encoder_mask'] = [a[:, np.newaxis] for a in encoder_mask.T] # 32, (100, ) 67 | res['real_len'] = self.max_width 68 | 69 | # DECODER PART 70 | target_weights = [] 71 | for l_idx in range(len(self.label_list)): 72 | label_len = len(self.label_list[l_idx]) 73 | if label_len <= decoder_input_len: 74 | self.label_list[l_idx] = np.concatenate(( 75 | self.label_list[l_idx], 76 | np.zeros(decoder_input_len - label_len, dtype=np.int32))) 77 | one_mask_len = min(label_len - go_shift, valid_target_length) 78 | target_weights.append(np.concatenate(( 79 | np.ones(one_mask_len, dtype=np.float32), 80 | np.zeros(decoder_input_len - one_mask_len, 81 | dtype=np.float32)))) 82 | else: 83 | raise NotImplementedError 84 | # self.label_list[l_idx] = \ 85 | # self.label_list[l_idx][:decoder_input_len] 86 | # target_weights.append([1]*decoder_input_len) 87 | 88 | res['decoder_inputs'] = [a.astype(np.int32) for a in 89 | np.array(self.label_list).T] 90 | res['target_weights'] = [a.astype(np.float32) for a in 91 | np.array(target_weights).T] 92 | #print (res['decoder_inputs'][0]) 93 | #assert False 94 | assert len(res['decoder_inputs']) == len(res['target_weights']) 95 | res['filenames'] = self.file_list 96 | 97 | self.data_list, self.label_list, self.file_list = [], [], [] 98 | self.max_width, self.max_label_len = 0, 0 99 | 100 | return res 101 | 102 | def __len__(self): 103 | return len(self.data_list) 104 | 105 | def __iadd__(self, other): 106 | self.data_list += other.data_list 107 | self.label_list += other.label_list 108 | self.max_label_len = max(self.max_label_len, other.max_label_len) 109 | self.max_width = max(self.max_width, other.max_width) 110 | 111 | def __add__(self, other): 112 | res = BucketData() 113 | res.data_list = self.data_list + other.data_list 114 | res.label_list = self.label_list + other.label_list 115 | res.max_width = max(self.max_width, other.max_width) 116 | res.max_label_len = max((self.max_label_len, 
other.max_label_len)) 117 | return res 118 | -------------------------------------------------------------------------------- /src/data_util/data_gen.py: -------------------------------------------------------------------------------- 1 | __author__ = 'moonkey' 2 | 3 | import os 4 | import numpy as np 5 | from PIL import Image 6 | from collections import Counter 7 | import pickle as cPickle 8 | import random, math 9 | from data_util.bucketdata import BucketData 10 | 11 | 12 | 13 | class DataGen(object): 14 | GO = 1 15 | EOS = 2 16 | 17 | def __init__(self, 18 | data_root, annotation_fn, 19 | evaluate = False, 20 | valid_target_len = float('inf'), 21 | img_width_range = (12, 320), 22 | word_len = 30): 23 | """ 24 | :param data_root: 25 | :param annotation_fn: 26 | :param lexicon_fn: 27 | :param img_width_range: only needed for training set 28 | :return: 29 | """ 30 | 31 | img_height = 32 32 | self.data_root = data_root 33 | if os.path.exists(annotation_fn): 34 | self.annotation_path = annotation_fn 35 | else: 36 | self.annotation_path = os.path.join(data_root, annotation_fn) 37 | 38 | if evaluate: 39 | self.bucket_specs = [(int(math.floor(64 / 4)), int(word_len + 2)), (int(math.floor(108 / 4)), int(word_len + 2)), 40 | (int(math.floor(140 / 4)), int(word_len + 2)), (int(math.floor(256 / 4)), int(word_len + 2)), 41 | (int(math.floor(img_width_range[1] / 4)), int(word_len + 2))] 42 | else: 43 | self.bucket_specs = [(int(64 / 4), 9 + 2), (int(108 / 4), 15 + 2), 44 | (int(140 / 4), 17 + 2), (int(256 / 4), 20 + 2), 45 | (int(math.ceil(img_width_range[1] / 4)), word_len + 2)] 46 | 47 | self.bucket_min_width, self.bucket_max_width = img_width_range 48 | self.image_height = img_height 49 | self.valid_target_len = valid_target_len 50 | 51 | self.bucket_data = {i: BucketData() 52 | for i in range(self.bucket_max_width + 1)} 53 | 54 | def clear(self): 55 | self.bucket_data = {i: BucketData() 56 | for i in range(self.bucket_max_width + 1)} 57 | 58 | def get_size(self): 59 | with open(self.annotation_path, 'r') as ann_file: 60 | return len(ann_file.readlines()) 61 | 62 | def gen(self, batch_size): 63 | valid_target_len = self.valid_target_len 64 | with open(self.annotation_path, 'r') as ann_file: 65 | lines = ann_file.readlines() 66 | random.shuffle(lines) 67 | for l in lines: 68 | img_path, lex = l.strip().split() 69 | try: 70 | img_bw, word = self.read_data(img_path, lex) 71 | if valid_target_len < float('inf'): 72 | word = word[:valid_target_len + 1] 73 | width = img_bw.shape[-1] 74 | 75 | # TODO:resize if > 320 76 | b_idx = min(width, self.bucket_max_width) 77 | bs = self.bucket_data[b_idx].append(img_bw, word, os.path.join(self.data_root,img_path)) 78 | if bs >= batch_size: 79 | b = self.bucket_data[b_idx].flush_out( 80 | self.bucket_specs, 81 | valid_target_length=valid_target_len, 82 | go_shift=1) 83 | if b is not None: 84 | yield b 85 | else: 86 | assert False, 'no valid bucket of width %d'%width 87 | except IOError: 88 | pass # ignore error images 89 | #with open('error_img.txt', 'a') as ef: 90 | # ef.write(img_path + '\n') 91 | self.clear() 92 | 93 | def read_data(self, img_path, lex): 94 | assert 0 < len(lex) < self.bucket_specs[-1][1] 95 | # L = R * 299/1000 + G * 587/1000 + B * 114/1000 96 | with open(os.path.join(self.data_root, img_path), 'rb') as img_file: 97 | img = Image.open(img_file) 98 | w, h = img.size 99 | aspect_ratio = float(w) / float(h) 100 | if aspect_ratio < float(self.bucket_min_width) / self.image_height: 101 | img = img.resize( 102 | (self.bucket_min_width, 
self.image_height), 103 | Image.ANTIALIAS) 104 | elif aspect_ratio > float( 105 | self.bucket_max_width) / self.image_height: 106 | img = img.resize( 107 | (self.bucket_max_width, self.image_height), 108 | Image.ANTIALIAS) 109 | elif h != self.image_height: 110 | img = img.resize( 111 | (int(aspect_ratio * self.image_height), self.image_height), 112 | Image.ANTIALIAS) 113 | 114 | img_bw = img.convert('L') 115 | img_bw = np.asarray(img_bw, dtype=np.uint8) 116 | img_bw = img_bw[np.newaxis, :] 117 | 118 | # 'a':97, '0':48 119 | word = [self.GO] 120 | for c in lex: 121 | assert 96 < ord(c) < 123 or 47 < ord(c) < 58 122 | word.append( 123 | ord(c) - 97 + 13 if ord(c) > 96 else ord(c) - 48 + 3) 124 | word.append(self.EOS) 125 | word = np.array(word, dtype=np.int32) 126 | # word = np.array( [self.GO] + 127 | # [ord(c) - 97 + 13 if ord(c) > 96 else ord(c) - 48 + 3 128 | # for c in lex] + [self.EOS], dtype=np.int32) 129 | 130 | return img_bw, word 131 | 132 | 133 | def test_gen(): 134 | print('testing gen_valid') 135 | # s_gen = EvalGen('../../data/evaluation_data/svt', 'test.txt') 136 | # s_gen = EvalGen('../../data/evaluation_data/iiit5k', 'test.txt') 137 | # s_gen = EvalGen('../../data/evaluation_data/icdar03', 'test.txt') 138 | s_gen = EvalGen('../../data/evaluation_data/icdar13', 'test.txt') 139 | count = 0 140 | for batch in s_gen.gen(1): 141 | count += 1 142 | print(str(batch['bucket_id']) + ' ' + str(batch['data'].shape[2:])) 143 | assert batch['data'].shape[2] == img_height 144 | print(count) 145 | 146 | 147 | if __name__ == '__main__': 148 | test_gen() 149 | -------------------------------------------------------------------------------- /src/exp_config.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | """ 4 | Default paramters for experiemnt 5 | """ 6 | 7 | 8 | class ExpConfig: 9 | 10 | GPU_ID = 0 11 | # phase 12 | PHASE = 'test' 13 | VISUALIZE = True 14 | 15 | # input and output 16 | DATA_BASE_DIR = '/mnt/90kDICT32px' 17 | DATA_PATH = '/mnt/train_shuffled_words.txt' # path containing data file names and labels. 
Format: 18 | MODEL_DIR = 'train' # the directory for saving and loading model parameters (structure is not stored) 19 | LOG_PATH = 'log.txt' 20 | OUTPUT_DIR = 'results' # output directory 21 | STEPS_PER_CHECKPOINT = 500 # checkpointing (print perplexity, save model) per how many steps 22 | 23 | # Optimization 24 | NUM_EPOCH = 1000 25 | BATCH_SIZE = 64 26 | INITIAL_LEARNING_RATE = 1.0 # initial learning rate, note the we use AdaDelta, so the initial value doe not matter much 27 | 28 | # Network parameters 29 | CLIP_GRADIENTS = True # whether to perform gradient clipping 30 | MAX_GRADIENT_NORM = 5.0 # Clip gradients to this norm 31 | TARGET_EMBEDDING_SIZE = 10 # embedding dimension for each target 32 | ATTN_USE_LSTM = True # whether or not use LSTM attention decoder cell 33 | ATTN_NUM_HIDDEN=128 # number of hidden units in attention decoder cell 34 | ATTN_NUM_LAYERS = 2 # number of layers in attention decoder cell 35 | # (Encoder number of hidden units will be ATTN_NUM_HIDDEN*ATTN_NUM_LAYERS) 36 | LOAD_MODEL = False 37 | OLD_MODEL_VERSION = False 38 | TARGET_VOCAB_SIZE = 26+10+3 # 0: PADDING, 1: GO, 2: EOS, >2: 0-9, a-z 39 | -------------------------------------------------------------------------------- /src/launcher.py: -------------------------------------------------------------------------------- 1 | __author__ = 'moonkey' 2 | 3 | import sys, argparse, logging 4 | 5 | import numpy as np 6 | from PIL import Image 7 | import tensorflow as tf 8 | tf.logging.set_verbosity(tf.logging.ERROR) 9 | 10 | 11 | from model.model import Model 12 | import exp_config 13 | 14 | def process_args(args, defaults): 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--gpu-id', dest="gpu_id", 18 | type=int, default=defaults.GPU_ID) 19 | 20 | parser.add_argument('--use-gru', dest='use_gru', action='store_true') 21 | 22 | parser.add_argument('--phase', dest="phase", 23 | type=str, default=defaults.PHASE, 24 | choices=['train', 'test'], 25 | help=('Phase of experiment, can be either' 26 | ' train or test, default=%s'%(defaults.PHASE))) 27 | parser.add_argument('--data-path', dest="data_path", 28 | type=str, default=defaults.DATA_PATH, 29 | help=('Path of file containing the path and labels' 30 | ' of training or testing data, default=%s' 31 | %(defaults.DATA_PATH))) 32 | parser.add_argument('--data-base-dir', dest="data_base_dir", 33 | type=str, default=defaults.DATA_BASE_DIR, 34 | help=('The base directory of the paths in the file ' 35 | 'containing the path and labels, default=%s' 36 | %(defaults.DATA_PATH))) 37 | parser.add_argument('--visualize', dest='visualize', action='store_true', 38 | help=('Visualize attentions or not' 39 | ', default=%s' %(defaults.VISUALIZE))) 40 | parser.add_argument('--no-visualize', dest='visualize', action='store_false') 41 | parser.set_defaults(visualize=defaults.VISUALIZE) 42 | parser.add_argument('--batch-size', dest="batch_size", 43 | type=int, default=defaults.BATCH_SIZE, 44 | help=('Batch size, default = %s' 45 | %(defaults.BATCH_SIZE))) 46 | parser.add_argument('--initial-learning-rate', dest="initial_learning_rate", 47 | type=float, default=defaults.INITIAL_LEARNING_RATE, 48 | help=('Initial learning rate, default = %s' 49 | %(defaults.INITIAL_LEARNING_RATE))) 50 | parser.add_argument('--num-epoch', dest="num_epoch", 51 | type=int, default=defaults.NUM_EPOCH, 52 | help=('Number of epochs, default = %s' 53 | %(defaults.NUM_EPOCH))) 54 | parser.add_argument('--steps-per-checkpoint', dest="steps_per_checkpoint", 55 | type=int, 
default=defaults.STEPS_PER_CHECKPOINT, 56 | help=('Checkpointing (print perplexity, save model) per' 57 | ' how many steps, default = %s' 58 | %(defaults.STEPS_PER_CHECKPOINT))) 59 | parser.add_argument('--target-vocab-size', dest="target_vocab_size", 60 | type=int, default=defaults.TARGET_VOCAB_SIZE, 61 | help=('Target vocabulary size, default=%s' 62 | %(defaults.TARGET_VOCAB_SIZE))) 63 | parser.add_argument('--model-dir', dest="model_dir", 64 | type=str, default=defaults.MODEL_DIR, 65 | help=('The directory for saving and loading model ' 66 | '(structure is not stored), ' 67 | 'default=%s' %(defaults.MODEL_DIR))) 68 | parser.add_argument('--target-embedding-size', dest="target_embedding_size", 69 | type=int, default=defaults.TARGET_EMBEDDING_SIZE, 70 | help=('Embedding dimension for each target, default=%s' 71 | %(defaults.TARGET_EMBEDDING_SIZE))) 72 | parser.add_argument('--attn-num-hidden', dest="attn_num_hidden", 73 | type=int, default=defaults.ATTN_NUM_HIDDEN, 74 | help=('number of hidden units in attention decoder cell' 75 | ', default=%s' 76 | %(defaults.ATTN_NUM_HIDDEN))) 77 | parser.add_argument('--attn-num-layers', dest="attn_num_layers", 78 | type=int, default=defaults.ATTN_NUM_LAYERS, 79 | help=('number of hidden layers in attention decoder cell' 80 | ', default=%s' 81 | %(defaults.ATTN_NUM_LAYERS))) 82 | parser.add_argument('--load-model', dest='load_model', action='store_true', 83 | help=('Load model from model-dir or not' 84 | ', default=%s' %(defaults.LOAD_MODEL))) 85 | parser.add_argument('--no-load-model', dest='load_model', action='store_false') 86 | parser.set_defaults(load_model=defaults.LOAD_MODEL) 87 | parser.add_argument('--log-path', dest="log_path", 88 | type=str, default=defaults.LOG_PATH, 89 | help=('Log file path, default=%s' 90 | %(defaults.LOG_PATH))) 91 | parser.add_argument('--output-dir', dest="output_dir", 92 | type=str, default=defaults.OUTPUT_DIR, 93 | help=('Output directory, default=%s' 94 | %(defaults.OUTPUT_DIR))) 95 | parser.add_argument('--max_gradient_norm', dest="max_gradient_norm", 96 | type=int, default=defaults.MAX_GRADIENT_NORM, 97 | help=('Clip gradients to this norm.' 
98 | ', default=%s' 99 | % (defaults.MAX_GRADIENT_NORM))) 100 | parser.add_argument('--no-gradient_clipping', dest='clip_gradients', action='store_false', 101 | help=('Do not perform gradient clipping, difault for clip_gradients is %s' % 102 | (defaults.CLIP_GRADIENTS))) 103 | parser.set_defaults(clip_gradients=defaults.CLIP_GRADIENTS) 104 | 105 | parameters = parser.parse_args(args) 106 | return parameters 107 | 108 | def main(args, defaults): 109 | parameters = process_args(args, defaults) 110 | logging.basicConfig( 111 | level=logging.DEBUG, 112 | format='%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s', 113 | filename=parameters.log_path) 114 | console = logging.StreamHandler() 115 | console.setLevel(logging.INFO) 116 | formatter = logging.Formatter('%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s') 117 | console.setFormatter(formatter) 118 | logging.getLogger('').addHandler(console) 119 | 120 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 121 | model = Model( 122 | phase = parameters.phase, 123 | visualize = parameters.visualize, 124 | data_path = parameters.data_path, 125 | data_base_dir = parameters.data_base_dir, 126 | output_dir = parameters.output_dir, 127 | batch_size = parameters.batch_size, 128 | initial_learning_rate = parameters.initial_learning_rate, 129 | num_epoch = parameters.num_epoch, 130 | steps_per_checkpoint = parameters.steps_per_checkpoint, 131 | target_vocab_size = parameters.target_vocab_size, 132 | model_dir = parameters.model_dir, 133 | target_embedding_size = parameters.target_embedding_size, 134 | attn_num_hidden = parameters.attn_num_hidden, 135 | attn_num_layers = parameters.attn_num_layers, 136 | clip_gradients = parameters.clip_gradients, 137 | max_gradient_norm = parameters.max_gradient_norm, 138 | load_model = parameters.load_model, 139 | valid_target_length = float('inf'), 140 | gpu_id=parameters.gpu_id, 141 | use_gru=parameters.use_gru, 142 | session = sess) 143 | model.launch() 144 | 145 | if __name__ == "__main__": 146 | main(sys.argv[1:], exp_config.ExpConfig) 147 | 148 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'moonkey' 2 | -------------------------------------------------------------------------------- /src/model/cnn.py: -------------------------------------------------------------------------------- 1 | __author__ = 'moonkey' 2 | 3 | #from keras import models, layers 4 | import logging 5 | import numpy as np 6 | # from src.data_util.synth_prepare import SynthGen 7 | 8 | #import keras.backend as K 9 | import tensorflow as tf 10 | 11 | 12 | def var_random(name, shape, regularizable=False): 13 | ''' 14 | Initialize a random variable using xavier initialization. 15 | Add regularization if regularizable=True 16 | :param name: 17 | :param shape: 18 | :param regularizable: 19 | :return: 20 | ''' 21 | v = tf.get_variable(name, shape=shape, initializer=tf.contrib.layers.xavier_initializer()) 22 | if regularizable: 23 | with tf.name_scope(name + '/Regularizer/'): 24 | tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, tf.nn.l2_loss(v)) 25 | return v 26 | 27 | def max_2x2pool(incoming, name): 28 | ''' 29 | max pooling on 2 dims. 
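(2x2 kernel with stride 2, i.e. this halves both the height and the width of the feature map)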
30 | :param incoming: 31 | :param name: 32 | :return: 33 | ''' 34 | with tf.variable_scope(name): 35 | return tf.nn.max_pool(incoming, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='VALID') 36 | 37 | def max_2x1pool(incoming, name): 38 | ''' 39 | max pooling only on image width 40 | :param incoming: 41 | :param name: 42 | :return: 43 | ''' 44 | with tf.variable_scope(name): 45 | return tf.nn.max_pool(incoming, ksize=(1, 2, 1, 1), strides=(1, 2, 1, 1), padding='VALID') 46 | 47 | def ConvRelu(incoming, num_filters, filter_size, name): 48 | ''' 49 | Add a convolution layer followed by a Relu layer. 50 | :param incoming: 51 | :param num_filters: 52 | :param filter_size: 53 | :param name: 54 | :return: 55 | ''' 56 | num_filters_from = incoming.get_shape().as_list()[3] 57 | with tf.variable_scope(name): 58 | conv_W = var_random('W', tuple(filter_size) + (num_filters_from, num_filters), regularizable=True) 59 | 60 | after_conv = tf.nn.conv2d(incoming, conv_W, strides=(1,1,1,1), padding='SAME') 61 | 62 | return tf.nn.relu(after_conv) 63 | 64 | 65 | def batch_norm(incoming, is_training): 66 | ''' 67 | batch normalization 68 | :param incoming: 69 | :param is_training: 70 | :return: 71 | ''' 72 | return tf.contrib.layers.batch_norm(incoming, is_training=is_training, scale=True, decay=0.99) 73 | 74 | 75 | def ConvReluBN(incoming, num_filters, filter_size, name, is_training, padding_type = 'SAME'): 76 | ''' 77 | Convolution -> Batch normalization -> Relu 78 | :param incoming: 79 | :param num_filters: 80 | :param filter_size: 81 | :param name: 82 | :param is_training: 83 | :param padding_type: 84 | :return: 85 | ''' 86 | num_filters_from = incoming.get_shape().as_list()[3] 87 | with tf.variable_scope(name): 88 | conv_W = var_random('W', tuple(filter_size) + (num_filters_from, num_filters), regularizable=True) 89 | 90 | after_conv = tf.nn.conv2d(incoming, conv_W, strides=(1,1,1,1), padding=padding_type) 91 | 92 | after_bn = batch_norm(after_conv, is_training) 93 | 94 | return tf.nn.relu(after_bn) 95 | 96 | def dropout(incoming, is_training, keep_prob=0.5): 97 | return tf.contrib.layers.dropout(incoming, keep_prob=keep_prob, is_training=is_training) 98 | 99 | def tf_create_attention_map(incoming): 100 | ''' 101 | flatten hight and width into one dimention of size attn_length 102 | :param incoming: 3D Tensor [batch_size x cur_h x cur_w x num_channels] 103 | :return: attention_map: 3D Tensor [batch_size x attn_length x attn_size]. 
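e.g. an incoming tensor of shape [batch_size, 1, 24, 512] is reshaped to [batch_size, 24, 512].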
104 | ''' 105 | shape = incoming.get_shape().as_list() 106 | print("shape of incoming is: {}".format(incoming.get_shape())) 107 | print(shape) 108 | return tf.reshape(incoming, (-1, np.prod(shape[1:3]), shape[3])) 109 | 110 | class CNN(object): 111 | """ 112 | Usage for tf tensor output: 113 | o = CNN(x).tf_output() 114 | 115 | """ 116 | 117 | def __init__(self, input_tensor, is_training): 118 | self._build_network(input_tensor, is_training) 119 | 120 | def _build_network(self, input_tensor, is_training): 121 | """ 122 | https://github.com/bgshih/crnn/blob/master/model/crnn_demo/config.lua 123 | :return: 124 | """ 125 | print('input_tensor dim: {}'.format(input_tensor.get_shape())) 126 | net = tf.transpose(input_tensor, perm=[0, 2, 3, 1]) 127 | net = tf.add(net, (-128.0)) 128 | net = tf.multiply(net, (1/128.0)) 129 | 130 | net = ConvRelu(net, 64, (3, 3), 'conv_conv1') 131 | net = max_2x2pool(net, 'conv_pool1') 132 | 133 | net = ConvRelu(net, 128, (3, 3), 'conv_conv2') 134 | net = max_2x2pool(net, 'conv_pool2') 135 | 136 | net = ConvReluBN(net, 256, (3, 3), 'conv_conv3', is_training) 137 | net = ConvRelu(net, 256, (3, 3), 'conv_conv4') 138 | net = max_2x1pool(net, 'conv_pool3') 139 | 140 | net = ConvReluBN(net, 512, (3, 3), 'conv_conv5', is_training) 141 | net = ConvRelu(net, 512, (3, 3), 'conv_conv6') 142 | net = max_2x1pool(net, 'conv_pool4') 143 | 144 | net = ConvReluBN(net, 512, (2, 2), 'conv_conv7', is_training, "VALID") 145 | net = dropout(net, is_training) 146 | 147 | print('CNN outdim before squeeze: {}'.format(net.get_shape())) # 1x32x100 -> 24x512 148 | 149 | net = tf.squeeze(net,axis=1) 150 | 151 | print('CNN outdim: {}'.format(net.get_shape())) 152 | self.model = net 153 | 154 | def tf_output(self): 155 | # if self.input_tensor is not None: 156 | return self.model 157 | ''' 158 | def __call__(self, input_tensor): 159 | return self.model(input_tensor) 160 | ''' 161 | def save(self): 162 | pass 163 | 164 | 165 | -------------------------------------------------------------------------------- /src/model/model.py: -------------------------------------------------------------------------------- 1 | """Visual Attention Based OCR Model.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import random, time, os, shutil, math, sys, logging 8 | #import ipdb 9 | import numpy as np 10 | from six.moves import xrange # pylint: disable=redefined-builtin 11 | from PIL import Image 12 | import tensorflow as tf 13 | #import keras.backend as K 14 | #from tensorflow.models.rnn.translate import data_utils 15 | 16 | from .cnn import CNN 17 | from .seq2seq_model import Seq2SeqModel 18 | from data_util.data_gen import DataGen 19 | from tqdm import tqdm 20 | 21 | try: 22 | import distance 23 | distance_loaded = True 24 | except ImportError: 25 | distance_loaded = False 26 | 27 | class Model(object): 28 | 29 | def __init__(self, 30 | phase, 31 | visualize, 32 | data_path, 33 | data_base_dir, 34 | output_dir, 35 | batch_size, 36 | initial_learning_rate, 37 | num_epoch, 38 | steps_per_checkpoint, 39 | target_vocab_size, 40 | model_dir, 41 | target_embedding_size, 42 | attn_num_hidden, 43 | attn_num_layers, 44 | clip_gradients, 45 | max_gradient_norm, 46 | session, 47 | load_model, 48 | gpu_id, 49 | use_gru, 50 | evaluate=False, 51 | valid_target_length=float('inf'), 52 | reg_val = 0 ): 53 | 54 | gpu_device_id = '/gpu:' + str(gpu_id) 55 | if not os.path.exists(model_dir): 56 | os.makedirs(model_dir) 57 | 
logging.info('loading data') 58 | # load data 59 | if phase == 'train': 60 | self.s_gen = DataGen( 61 | data_base_dir, data_path, valid_target_len=valid_target_length, evaluate=False) 62 | else: 63 | batch_size = 1 64 | self.s_gen = DataGen( 65 | data_base_dir, data_path, evaluate=True) 66 | 67 | 68 | #logging.info('valid_target_length: %s' %(str(valid_target_length))) 69 | logging.info('phase: %s' % phase) 70 | logging.info('model_dir: %s' % (model_dir)) 71 | logging.info('load_model: %s' % (load_model)) 72 | logging.info('output_dir: %s' % (output_dir)) 73 | logging.info('steps_per_checkpoint: %d' % (steps_per_checkpoint)) 74 | logging.info('batch_size: %d' %(batch_size)) 75 | logging.info('num_epoch: %d' %num_epoch) 76 | logging.info('learning_rate: %d' % initial_learning_rate) 77 | logging.info('reg_val: %d' % (reg_val)) 78 | logging.info('max_gradient_norm: %f' % max_gradient_norm) 79 | logging.info('clip_gradients: %s' % clip_gradients) 80 | logging.info('valid_target_length %f' %valid_target_length) 81 | logging.info('target_vocab_size: %d' %target_vocab_size) 82 | logging.info('target_embedding_size: %f' % target_embedding_size) 83 | logging.info('attn_num_hidden: %d' % attn_num_hidden) 84 | logging.info('attn_num_layers: %d' % attn_num_layers) 85 | logging.info('visualize: %s' % visualize) 86 | 87 | buckets = self.s_gen.bucket_specs 88 | logging.info('buckets') 89 | logging.info(buckets) 90 | if use_gru: 91 | logging.info('ues GRU in the decoder.') 92 | 93 | # variables 94 | self.img_data = tf.placeholder(tf.float32, shape=(None, 1, 32, None), name='img_data') 95 | self.zero_paddings = tf.placeholder(tf.float32, shape=(None, None, 512), name='zero_paddings') 96 | 97 | self.decoder_inputs = [] 98 | self.encoder_masks = [] 99 | self.target_weights = [] 100 | for i in xrange(int(buckets[-1][0] + 1)): 101 | self.encoder_masks.append(tf.placeholder(tf.float32, shape=[None, 1], 102 | name="encoder_mask{0}".format(i))) 103 | for i in xrange(buckets[-1][1] + 1): 104 | self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None], 105 | name="decoder{0}".format(i))) 106 | self.target_weights.append(tf.placeholder(tf.float32, shape=[None], 107 | name="weight{0}".format(i))) 108 | 109 | self.reg_val = reg_val 110 | self.sess = session 111 | self.evaluate = evaluate 112 | self.steps_per_checkpoint = steps_per_checkpoint 113 | self.model_dir = model_dir 114 | self.output_dir = output_dir 115 | self.buckets = buckets 116 | self.batch_size = batch_size 117 | self.num_epoch = num_epoch 118 | self.global_step = tf.Variable(0, trainable=False) 119 | self.valid_target_length = valid_target_length 120 | self.phase = phase 121 | self.visualize = visualize 122 | self.learning_rate = initial_learning_rate 123 | self.clip_gradients = clip_gradients 124 | 125 | if phase == 'train': 126 | self.forward_only = False 127 | elif phase == 'test': 128 | self.forward_only = True 129 | else: 130 | assert False, phase 131 | 132 | with tf.device(gpu_device_id): 133 | cnn_model = CNN(self.img_data, True) #(not self.forward_only)) 134 | self.conv_output = cnn_model.tf_output() 135 | self.concat_conv_output = tf.concat(axis=1, values=[self.conv_output, self.zero_paddings]) 136 | 137 | self.perm_conv_output = tf.transpose(self.concat_conv_output, perm=[1, 0, 2]) 138 | 139 | with tf.device(gpu_device_id): 140 | self.attention_decoder_model = Seq2SeqModel( 141 | encoder_masks = self.encoder_masks, 142 | encoder_inputs_tensor = self.perm_conv_output, 143 | decoder_inputs = self.decoder_inputs, 144 | target_weights = 
self.target_weights, 145 | target_vocab_size = target_vocab_size, 146 | buckets = buckets, 147 | target_embedding_size = target_embedding_size, 148 | attn_num_layers = attn_num_layers, 149 | attn_num_hidden = attn_num_hidden, 150 | forward_only = self.forward_only, 151 | use_gru = use_gru) 152 | 153 | 154 | 155 | 156 | if not self.forward_only: 157 | 158 | self.updates = [] 159 | self.summaries_by_bucket = [] 160 | with tf.device(gpu_device_id): 161 | params = tf.trainable_variables() 162 | # Gradients and SGD update operation for training the model. 163 | opt = tf.train.AdadeltaOptimizer(learning_rate=initial_learning_rate) 164 | for b in xrange(len(buckets)): 165 | if self.reg_val > 0: 166 | reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 167 | logging.info('Adding %s regularization losses', len(reg_losses)) 168 | logging.debug('REGULARIZATION_LOSSES: %s', reg_losses) 169 | loss_op = self.reg_val * tf.reduce_sum(reg_losses) + self.attention_decoder_model.losses[b] 170 | else: 171 | loss_op = self.attention_decoder_model.losses[b] 172 | 173 | gradients, params = zip(*opt.compute_gradients(loss_op, params)) 174 | if self.clip_gradients: 175 | gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm) 176 | # Add summaries for loss, variables, gradients, gradient norms and total gradient norm. 177 | summaries = [] 178 | ''' 179 | for gradient, variable in gradients: 180 | if isinstance(gradient, tf.IndexedSlices): 181 | grad_values = gradient.values 182 | else: 183 | grad_values = gradient 184 | summaries.append(tf.summary.histogram(variable.name, variable)) 185 | summaries.append(tf.summary.histogram(variable.name + "/gradients", grad_values)) 186 | summaries.append(tf.summary.scalar(variable.name + "/gradient_norm", 187 | tf.global_norm([grad_values]))) 188 | ''' 189 | summaries.append(tf.summary.scalar("loss", loss_op)) 190 | summaries.append(tf.summary.scalar("total_gradient_norm", tf.global_norm(gradients))) 191 | all_summaries = tf.summary.merge(summaries) 192 | self.summaries_by_bucket.append(all_summaries) 193 | # update op - apply gradients 194 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 195 | with tf.control_dependencies(update_ops): 196 | self.updates.append(opt.apply_gradients(zip(gradients, params), global_step=self.global_step)) 197 | 198 | self.saver_all = tf.train.Saver(tf.all_variables()) 199 | 200 | ckpt = tf.train.get_checkpoint_state(model_dir) 201 | if ckpt and load_model: 202 | logging.info("Reading model parameters from %s" % ckpt.model_checkpoint_path) 203 | #self.saver.restore(self.sess, ckpt.model_checkpoint_path) 204 | self.saver_all.restore(self.sess, ckpt.model_checkpoint_path) 205 | else: 206 | logging.info("Created model with fresh parameters.") 207 | self.sess.run(tf.initialize_all_variables()) 208 | #self.sess.run(init_new_vars_op) 209 | 210 | 211 | # train or test as specified by phase 212 | def launch(self): 213 | step_time, loss = 0.0, 0.0 214 | current_step = 0 215 | previous_losses = [] 216 | writer = tf.summary.FileWriter(self.model_dir, self.sess.graph) 217 | if self.phase == 'test': 218 | if not distance_loaded: 219 | logging.info('Warning: distance module not installed. Do whole sequence comparison instead.') 220 | else: 221 | logging.info('Compare word based on edit distance.') 222 | num_correct = 0 223 | num_total = 0 224 | for batch in self.s_gen.gen(self.batch_size): 225 | # Get a batch and make a step. 
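# (Each batch yielded by DataGen.gen() is the dict built in BucketData.flush_out();
#  the fields unpacked below are 'bucket_id', 'data', 'zero_paddings', 'decoder_inputs',
#  'target_weights', 'encoder_mask', 'filenames' and 'real_len'.)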
226 | start_time = time.time() 227 | bucket_id = batch['bucket_id'] 228 | img_data = batch['data'] 229 | zero_paddings = batch['zero_paddings'] 230 | decoder_inputs = batch['decoder_inputs'] 231 | target_weights = batch['target_weights'] 232 | encoder_masks = batch['encoder_mask'] 233 | file_list = batch['filenames'] 234 | real_len = batch['real_len'] 235 | 236 | grounds = [a for a in np.array([decoder_input.tolist() for decoder_input in decoder_inputs]).transpose()] 237 | _, step_loss, step_logits, step_attns = self.step(encoder_masks, img_data, zero_paddings, decoder_inputs, target_weights, bucket_id, self.forward_only) 238 | curr_step_time = (time.time() - start_time) 239 | step_time += curr_step_time / self.steps_per_checkpoint 240 | logging.info('step_time: %f, loss: %f, step perplexity: %f'%(curr_step_time, step_loss, math.exp(step_loss) if step_loss < 300 else float('inf'))) 241 | loss += step_loss / self.steps_per_checkpoint 242 | current_step += 1 243 | step_outputs = [b for b in np.array([np.argmax(logit, axis=1).tolist() for logit in step_logits]).transpose()] 244 | if self.visualize: 245 | step_attns = np.array([[a.tolist() for a in step_attn] for step_attn in step_attns]).transpose([1, 0, 2]) 246 | #print (step_attns) 247 | 248 | for idx, output, ground in zip(range(len(grounds)), step_outputs, grounds): 249 | flag_ground,flag_out = True, True 250 | num_total += 1 251 | output_valid = [] 252 | ground_valid = [] 253 | for j in range(1,len(ground)): 254 | s1 = output[j-1] 255 | s2 = ground[j] 256 | if s2 != 2 and flag_ground: 257 | ground_valid.append(s2) 258 | else: 259 | flag_ground = False 260 | if s1 != 2 and flag_out: 261 | output_valid.append(s1) 262 | else: 263 | flag_out = False 264 | if distance_loaded: 265 | num_incorrect = distance.levenshtein(output_valid, ground_valid) 266 | if self.visualize: 267 | self.visualize_attention(file_list[idx], step_attns[idx], output_valid, ground_valid, num_incorrect>0, real_len) 268 | num_incorrect = float(num_incorrect) / len(ground_valid) 269 | num_incorrect = min(1.0, num_incorrect) 270 | else: 271 | if output_valid == ground_valid: 272 | num_incorrect = 0 273 | else: 274 | num_incorrect = 1 275 | if self.visualize: 276 | self.visualize_attention(file_list[idx], step_attns[idx], output_valid, ground_valid, num_incorrect>0, real_len) 277 | num_correct += 1. - num_incorrect 278 | logging.info('%f out of %d correct' %(num_correct, num_total)) 279 | elif self.phase == 'train': 280 | total = (self.s_gen.get_size() // self.batch_size) 281 | with tqdm(desc='Train: ', total=total) as pbar: 282 | for epoch in range(self.num_epoch): 283 | 284 | logging.info('Generating first batch)') 285 | for i,batch in enumerate(self.s_gen.gen(self.batch_size)): 286 | # Get a batch and make a step. 
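# (In training mode, step() below runs the forward pass together with the AdaDelta
#  update op and returns the summaries, the loss and the output logits.)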
287 | num_total = 0 288 | num_correct = 0 289 | start_time = time.time() 290 | batch_len = batch['real_len'] 291 | bucket_id = batch['bucket_id'] 292 | img_data = batch['data'] 293 | zero_paddings = batch['zero_paddings'] 294 | decoder_inputs = batch['decoder_inputs'] 295 | target_weights = batch['target_weights'] 296 | encoder_masks = batch['encoder_mask'] 297 | #logging.info('current_step: %d'%current_step) 298 | #logging.info(np.array([decoder_input.tolist() for decoder_input in decoder_inputs]).transpose()[0]) 299 | #print (np.array([target_weight.tolist() for target_weight in target_weights]).transpose()[0]) 300 | summaries, step_loss, step_logits, _ = self.step(encoder_masks, img_data, zero_paddings, decoder_inputs, target_weights, bucket_id, self.forward_only) 301 | 302 | grounds = [a for a in 303 | np.array([decoder_input.tolist() for decoder_input in decoder_inputs]).transpose()] 304 | step_outputs = [b for b in 305 | np.array( 306 | [np.argmax(logit, axis=1).tolist() for logit in step_logits]).transpose()] 307 | 308 | for idx, output, ground in zip(range(len(grounds)), step_outputs, grounds): 309 | flag_ground, flag_out = True, True 310 | num_total += 1 311 | output_valid = [] 312 | ground_valid = [] 313 | for j in range(1, len(ground)): 314 | s1 = output[j - 1] 315 | s2 = ground[j] 316 | if s2 != 2 and flag_ground: 317 | ground_valid.append(s2) 318 | else: 319 | flag_ground = False 320 | if s1 != 2 and flag_out: 321 | output_valid.append(s1) 322 | else: 323 | flag_out = False 324 | if distance_loaded: 325 | num_incorrect = distance.levenshtein(output_valid, ground_valid) 326 | num_incorrect = float(num_incorrect) / len(ground_valid) 327 | num_incorrect = min(1.0, num_incorrect) 328 | else: 329 | if output_valid == ground_valid: 330 | num_incorrect = 0 331 | else: 332 | num_incorrect = 1 333 | num_correct += 1. - num_incorrect 334 | 335 | writer.add_summary(summaries, current_step) 336 | curr_step_time = (time.time() - start_time) 337 | step_time += curr_step_time / self.steps_per_checkpoint 338 | precision = num_correct / num_total 339 | logging.info('step %f - time: %f, loss: %f, perplexity: %f, precision: %f, batch_len: %f'%(current_step, curr_step_time, step_loss, math.exp(step_loss) if step_loss < 300 else float('inf'), precision, batch_len)) 340 | loss += step_loss / self.steps_per_checkpoint 341 | pbar.set_description('Train, loss={:.8f}'.format(step_loss)) 342 | pbar.update() 343 | current_step += 1 344 | # If there is an EOS symbol in outputs, cut them at that point. 345 | #if data_utils.EOS_ID in step_outputs: 346 | # step_outputs = step_outputs[:step_outputs.index(data_utils.EOS_ID)] 347 | #if data_utils.PAD_ID in decoder_inputs: 348 | #decoder_inputs = decoder_inputs[:decoder_inputs.index(data_utils.PAD_ID)] 349 | # print (step_outputs[0]) 350 | 351 | # Once in a while, we save checkpoint, print statistics, and run evals. 352 | if current_step % self.steps_per_checkpoint == 0: 353 | # Print statistics for the previous epoch. 354 | perplexity = math.exp(loss) if loss < 300 else float('inf') 355 | logging.info("global step %d step-time %.2f loss %f perplexity " 356 | "%.2f" % (self.global_step.eval(), step_time, loss, perplexity)) 357 | previous_losses.append(loss) 358 | # Save checkpoint and zero timer and loss. 
359 | if not self.forward_only: 360 | checkpoint_path = os.path.join(self.model_dir, "translate.ckpt") 361 | logging.info("Saving model, current_step: %d"%current_step) 362 | self.saver_all.save(self.sess, checkpoint_path, global_step=self.global_step) 363 | step_time, loss = 0.0, 0.0 364 | #sys.stdout.flush() 365 | 366 | # step, read one batch, generate gradients 367 | def step(self, encoder_masks, img_data, zero_paddings, decoder_inputs, target_weights, 368 | bucket_id, forward_only): 369 | # Check if the sizes match. 370 | encoder_size, decoder_size = self.buckets[bucket_id] 371 | if len(decoder_inputs) != decoder_size: 372 | raise ValueError("Decoder length must be equal to the one in bucket," 373 | " %d != %d." % (len(decoder_inputs), decoder_size)) 374 | if len(target_weights) != decoder_size: 375 | raise ValueError("Weights length must be equal to the one in bucket," 376 | " %d != %d." % (len(target_weights), decoder_size)) 377 | 378 | # Input feed: encoder inputs, decoder inputs, target_weights, as provided. 379 | input_feed = {} 380 | input_feed[self.img_data.name] = img_data 381 | input_feed[self.zero_paddings.name] = zero_paddings 382 | for l in xrange(decoder_size): 383 | input_feed[self.decoder_inputs[l].name] = decoder_inputs[l] 384 | input_feed[self.target_weights[l].name] = target_weights[l] 385 | for l in xrange(int(encoder_size)): 386 | try: 387 | input_feed[self.encoder_masks[l].name] = encoder_masks[l] 388 | except Exception as e: 389 | pass 390 | #ipdb.set_trace() 391 | 392 | # Since our targets are decoder inputs shifted by one, we need one more. 393 | last_target = self.decoder_inputs[decoder_size].name 394 | input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32) 395 | 396 | # Output feed: depends on whether we do a backward step or not. 397 | if not forward_only: 398 | output_feed = [self.updates[bucket_id], # Update Op that does SGD. 399 | #self.gradient_norms[bucket_id], # Gradient norm. 400 | self.attention_decoder_model.losses[bucket_id], 401 | self.summaries_by_bucket[bucket_id]] 402 | for l in xrange(decoder_size): # Output logits. 403 | output_feed.append(self.attention_decoder_model.outputs[bucket_id][l]) 404 | else: 405 | output_feed = [self.attention_decoder_model.losses[bucket_id]] # Loss for this batch. 406 | for l in xrange(decoder_size): # Output logits. 407 | output_feed.append(self.attention_decoder_model.outputs[bucket_id][l]) 408 | if self.visualize: 409 | output_feed += self.attention_decoder_model.attention_weights_histories[bucket_id] 410 | 411 | outputs = self.sess.run(output_feed, input_feed) 412 | if not forward_only: 413 | return outputs[2], outputs[1], outputs[3:(3+self.buckets[bucket_id][1])], None # Gradient norm summary, loss, no outputs, no attentions. 414 | else: 415 | return None, outputs[0], outputs[1:(1+self.buckets[bucket_id][1])], outputs[(1+self.buckets[bucket_id][1]):] # No gradient norm, loss, outputs, attentions. 
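# Summary of step() return values: during training it returns
# (summaries, loss, output_logits, None); during testing it returns
# (None, loss, output_logits, attention_weights).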
416 | 417 | 418 | def visualize_attention(self, filename, attentions, output_valid, ground_valid, flag_incorrect, real_len): 419 | if flag_incorrect: 420 | output_dir = os.path.join(self.output_dir, 'incorrect') 421 | else: 422 | output_dir = os.path.join(self.output_dir, 'correct') 423 | output_dir = os.path.join(output_dir, filename.replace('/', '_')) 424 | if not os.path.exists(output_dir): 425 | os.makedirs(output_dir) 426 | with open(os.path.join(output_dir, 'word.txt'), 'w') as fword: 427 | fword.write(' '.join([chr(c-13+97) if c-13+97>96 else chr(c-3+48) for c in ground_valid])+'\n') 428 | fword.write(' '.join([chr(c-13+97) if c-13+97>96 else chr(c-3+48) for c in output_valid])) 429 | with open(filename, 'rb') as img_file: 430 | img = Image.open(img_file) 431 | w, h = img.size 432 | h = 32 433 | img = img.resize( 434 | (real_len, h), 435 | Image.ANTIALIAS) 436 | img_data = np.asarray(img, dtype=np.uint8) 437 | for idx in range(len(output_valid)): 438 | output_filename = os.path.join(output_dir, 'image_%d.jpg'%(idx)) 439 | attention = attentions[idx][:(int(real_len/4)-1)] 440 | 441 | # I have got the attention_orig here, which is of size 32*len(ground_truth), the only thing left is to visualize it and save it to output_filename 442 | # TODO here 443 | attention_orig = np.zeros(real_len) 444 | for i in range(real_len): 445 | if 0 < i/4-1 and i/4-1 < len(attention): 446 | attention_orig[i] = attention[int(i/4)-1] 447 | attention_orig = np.convolve(attention_orig, [0.199547,0.200226,0.200454,0.200226,0.199547], mode='same') 448 | attention_orig = np.maximum(attention_orig, 0.3) 449 | attention_out = np.zeros((h, real_len)) 450 | for i in range(real_len): 451 | attention_out[:,i] = attention_orig[i] 452 | if len(img_data.shape) == 3: 453 | attention_out = attention_out[:,:,np.newaxis] 454 | img_out_data = img_data * attention_out 455 | img_out = Image.fromarray(img_out_data.astype(np.uint8)) 456 | img_out.save(output_filename) 457 | #print (output_filename) 458 | #assert False 459 | -------------------------------------------------------------------------------- /src/model/seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. # 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | 15 | """Library for creating sequence-to-sequence models in TensorFlow. 16 | 17 | Sequence-to-sequence recurrent neural networks can learn complex functions 18 | that map input sequences to output sequences. These models yield very good 19 | results on a number of tasks, such as speech recognition, parsing, machine 20 | translation, or even constructing automated replies to emails. 21 | 22 | Before using this module, it is recommended to read the TensorFlow tutorial 23 | on sequence-to-sequence models. 
It explains the basic concepts of this module 24 | and shows an end-to-end example of how to build a translation model. 25 | https://www.tensorflow.org/versions/master/tutorials/seq2seq/index.html 26 | 27 | Here is an overview of functions available in this module. They all use 28 | a very similar interface, so after reading the above tutorial and using 29 | one of them, others should be easy to substitute. 30 | 31 | * Full sequence-to-sequence models. 32 | - basic_rnn_seq2seq: The most basic RNN-RNN model. 33 | - tied_rnn_seq2seq: The basic model with tied encoder and decoder weights. 34 | - embedding_rnn_seq2seq: The basic model with input embedding. 35 | - embedding_tied_rnn_seq2seq: The tied model with input embedding. 36 | - embedding_attention_seq2seq: Advanced model with input embedding and 37 | the neural attention mechanism; recommended for complex tasks. 38 | 39 | * Multi-task sequence-to-sequence models. 40 | - one2many_rnn_seq2seq: The embedding model with multiple decoders. 41 | 42 | * Decoders (when you write your own encoder, you can use these to decode; 43 | e.g., if you want to write a model that generates captions for images). 44 | - rnn_decoder: The basic decoder based on a pure RNN. 45 | - attention_decoder: A decoder that uses the attention mechanism. 46 | 47 | * Losses. 48 | - sequence_loss: Loss for a sequence model returning average log-perplexity. 49 | - sequence_loss_by_example: As above, but not averaging over all examples. 50 | 51 | * model_with_buckets: A convenience function to create models with bucketing 52 | (see the tutorial above for an explanation of why and how to use it). 53 | """ 54 | 55 | from __future__ import absolute_import 56 | from __future__ import division 57 | from __future__ import print_function 58 | 59 | # We disable pylint because we need python3 compatibility. 60 | from six.moves import xrange # pylint: disable=redefined-builtin 61 | from six.moves import zip # pylint: disable=redefined-builtin 62 | 63 | import tensorflow as tf 64 | from tensorflow.python.framework import dtypes 65 | from tensorflow.python.framework import ops 66 | from tensorflow.python.ops import array_ops 67 | from tensorflow.python.ops import control_flow_ops 68 | from tensorflow.python.ops import embedding_ops 69 | from tensorflow.python.ops import math_ops 70 | from tensorflow.python.ops import nn_ops 71 | from tensorflow.contrib.rnn.python.ops import rnn, rnn_cell 72 | from tensorflow.python.ops import variable_scope 73 | linear = rnn_cell._linear # pylint: disable=protected-access 74 | 75 | def _extract_argmax_and_embed(embedding, output_projection=None, 76 | update_embedding=True): 77 | """Get a loop_function that extracts the previous symbol and embeds it. 78 | 79 | Args: 80 | embedding: embedding tensor for symbols. 81 | output_projection: None or a pair (W, B). If provided, each fed previous 82 | output will first be multiplied by W and added B. 83 | update_embedding: Boolean; if False, the gradients will not propagate 84 | through the embeddings. 85 | 86 | Returns: 87 | A loop function. 88 | """ 89 | def loop_function(prev, _): 90 | if output_projection is not None: 91 | prev = nn_ops.xw_plus_b( 92 | prev, output_projection[0], output_projection[1]) 93 | prev_symbol = math_ops.argmax(prev, 1) 94 | # Note that gradients will not propagate through the second parameter of 95 | # embedding_lookup. 
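# (When update_embedding is False, the stop_gradient call below additionally
# blocks gradients from flowing through the looked-up embedding of the
# previous symbol.)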
96 | emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol) 97 | if not update_embedding: 98 | emb_prev = array_ops.stop_gradient(emb_prev) 99 | return emb_prev 100 | return loop_function 101 | 102 | 103 | def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None, 104 | scope=None): 105 | """RNN decoder for the sequence-to-sequence model. 106 | 107 | Args: 108 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 109 | initial_state: 2D Tensor with shape [batch_size x cell.state_size]. 110 | cell: rnn_cell.RNNCell defining the cell function and size. 111 | loop_function: If not None, this function will be applied to the i-th output 112 | in order to generate the i+1-st input, and decoder_inputs will be ignored, 113 | except for the first element ("GO" symbol). This can be used for decoding, 114 | but also for training to emulate http://arxiv.org/abs/1506.03099. 115 | Signature -- loop_function(prev, i) = next 116 | * prev is a 2D Tensor of shape [batch_size x output_size], 117 | * i is an integer, the step number (when advanced control is needed), 118 | * next is a 2D Tensor of shape [batch_size x input_size]. 119 | scope: VariableScope for the created subgraph; defaults to "rnn_decoder". 120 | 121 | Returns: 122 | A tuple of the form (outputs, state), where: 123 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 124 | shape [batch_size x output_size] containing generated outputs. 125 | state: The state of each cell at the final time-step. 126 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 127 | (Note that in some cases, like basic RNN cell or GRU cell, outputs and 128 | states can be the same. They are different for LSTM cells though.) 129 | """ 130 | with variable_scope.variable_scope(scope or "rnn_decoder"): 131 | state = initial_state 132 | outputs = [] 133 | prev = None 134 | for i, inp in enumerate(decoder_inputs): 135 | if loop_function is not None and prev is not None: 136 | with variable_scope.variable_scope("loop_function", reuse=True): 137 | inp = loop_function(prev, i) 138 | if i > 0: 139 | variable_scope.get_variable_scope().reuse_variables() 140 | output, state = cell(inp, state) 141 | outputs.append(output) 142 | if loop_function is not None: 143 | prev = output 144 | return outputs, state 145 | 146 | 147 | def basic_rnn_seq2seq( 148 | encoder_inputs, decoder_inputs, cell, dtype=dtypes.float32, scope=None): 149 | """Basic RNN sequence-to-sequence model. 150 | 151 | This model first runs an RNN to encode encoder_inputs into a state vector, 152 | then runs decoder, initialized with the last encoder state, on decoder_inputs. 153 | Encoder and decoder use the same RNN cell type, but don't share parameters. 154 | 155 | Args: 156 | encoder_inputs: A list of 2D Tensors [batch_size x input_size]. 157 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 158 | cell: rnn_cell.RNNCell defining the cell function and size. 159 | dtype: The dtype of the initial state of the RNN cell (default: tf.float32). 160 | scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq". 161 | 162 | Returns: 163 | A tuple of the form (outputs, state), where: 164 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 165 | shape [batch_size x output_size] containing the generated outputs. 166 | state: The state of each decoder cell in the final time-step. 167 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 
168 | """ 169 | with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"): 170 | _, enc_state = rnn.rnn(cell, encoder_inputs, dtype=dtype) 171 | return rnn_decoder(decoder_inputs, enc_state, cell) 172 | 173 | 174 | def tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, 175 | loop_function=None, dtype=dtypes.float32, scope=None): 176 | """RNN sequence-to-sequence model with tied encoder and decoder parameters. 177 | 178 | This model first runs an RNN to encode encoder_inputs into a state vector, and 179 | then runs decoder, initialized with the last encoder state, on decoder_inputs. 180 | Encoder and decoder use the same RNN cell and share parameters. 181 | 182 | Args: 183 | encoder_inputs: A list of 2D Tensors [batch_size x input_size]. 184 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 185 | cell: rnn_cell.RNNCell defining the cell function and size. 186 | loop_function: If not None, this function will be applied to i-th output 187 | in order to generate i+1-th input, and decoder_inputs will be ignored, 188 | except for the first element ("GO" symbol), see rnn_decoder for details. 189 | dtype: The dtype of the initial state of the rnn cell (default: tf.float32). 190 | scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq". 191 | 192 | Returns: 193 | A tuple of the form (outputs, state), where: 194 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 195 | shape [batch_size x output_size] containing the generated outputs. 196 | state: The state of each decoder cell in each time-step. This is a list 197 | with length len(decoder_inputs) -- one item for each time-step. 198 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 199 | """ 200 | with variable_scope.variable_scope("combined_tied_rnn_seq2seq"): 201 | scope = scope or "tied_rnn_seq2seq" 202 | _, enc_state = rnn.rnn( 203 | cell, encoder_inputs, dtype=dtype, scope=scope) 204 | variable_scope.get_variable_scope().reuse_variables() 205 | return rnn_decoder(decoder_inputs, enc_state, cell, 206 | loop_function=loop_function, scope=scope) 207 | 208 | 209 | def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols, 210 | embedding_size, output_projection=None, 211 | feed_previous=False, 212 | update_embedding_for_previous=True, scope=None): 213 | """RNN decoder with embedding and a pure-decoding option. 214 | 215 | Args: 216 | decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). 217 | initial_state: 2D Tensor [batch_size x cell.state_size]. 218 | cell: rnn_cell.RNNCell defining the cell function. 219 | num_symbols: Integer, how many symbols come into the embedding. 220 | embedding_size: Integer, the length of the embedding vector for each symbol. 221 | output_projection: None or a pair (W, B) of output projection weights and 222 | biases; W has shape [output_size x num_symbols] and B has 223 | shape [num_symbols]; if provided and feed_previous=True, each fed 224 | previous output will first be multiplied by W and added B. 225 | feed_previous: Boolean; if True, only the first of decoder_inputs will be 226 | used (the "GO" symbol), and all other decoder inputs will be generated by: 227 | next = embedding_lookup(embedding, argmax(previous_output)), 228 | In effect, this implements a greedy decoder. It can also be used 229 | during training to emulate http://arxiv.org/abs/1506.03099. 230 | If False, decoder_inputs are used as given (the standard decoder case). 
231 | update_embedding_for_previous: Boolean; if False and feed_previous=True, 232 | only the embedding for the first symbol of decoder_inputs (the "GO" 233 | symbol) will be updated by back propagation. Embeddings for the symbols 234 | generated from the decoder itself remain unchanged. This parameter has 235 | no effect if feed_previous=False. 236 | scope: VariableScope for the created subgraph; defaults to 237 | "embedding_rnn_decoder". 238 | 239 | Returns: 240 | A tuple of the form (outputs, state), where: 241 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 242 | shape [batch_size x output_size] containing the generated outputs. 243 | state: The state of each decoder cell in each time-step. This is a list 244 | with length len(decoder_inputs) -- one item for each time-step. 245 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 246 | 247 | Raises: 248 | ValueError: When output_projection has the wrong shape. 249 | """ 250 | if output_projection is not None: 251 | proj_weights = ops.convert_to_tensor(output_projection[0], 252 | dtype=dtypes.float32) 253 | proj_weights.get_shape().assert_is_compatible_with([None, num_symbols]) 254 | proj_biases = ops.convert_to_tensor( 255 | output_projection[1], dtype=dtypes.float32) 256 | proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 257 | 258 | with variable_scope.variable_scope(scope or "embedding_rnn_decoder"): 259 | with ops.device("/cpu:0"): 260 | embedding = variable_scope.get_variable("embedding", 261 | [num_symbols, embedding_size]) 262 | loop_function = _extract_argmax_and_embed( 263 | embedding, output_projection, 264 | update_embedding_for_previous) if feed_previous else None 265 | emb_inp = ( 266 | embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs) 267 | return rnn_decoder(emb_inp, initial_state, cell, 268 | loop_function=loop_function) 269 | 270 | 271 | def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, 272 | num_encoder_symbols, num_decoder_symbols, 273 | embedding_size, output_projection=None, 274 | feed_previous=False, dtype=dtypes.float32, 275 | scope=None): 276 | """Embedding RNN sequence-to-sequence model. 277 | 278 | This model first embeds encoder_inputs by a newly created embedding (of shape 279 | [num_encoder_symbols x input_size]). Then it runs an RNN to encode 280 | embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs 281 | by another newly created embedding (of shape [num_decoder_symbols x 282 | input_size]). Then it runs RNN decoder, initialized with the last 283 | encoder state, on embedded decoder_inputs. 284 | 285 | Args: 286 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 287 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 288 | cell: rnn_cell.RNNCell defining the cell function and size. 289 | num_encoder_symbols: Integer; number of symbols on the encoder side. 290 | num_decoder_symbols: Integer; number of symbols on the decoder side. 291 | embedding_size: Integer, the length of the embedding vector for each symbol. 292 | output_projection: None or a pair (W, B) of output projection weights and 293 | biases; W has shape [output_size x num_decoder_symbols] and B has 294 | shape [num_decoder_symbols]; if provided and feed_previous=True, each 295 | fed previous output will first be multiplied by W and added B. 
296 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 297 | of decoder_inputs will be used (the "GO" symbol), and all other decoder 298 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 299 | If False, decoder_inputs are used as given (the standard decoder case). 300 | dtype: The dtype of the initial state for both the encoder and encoder 301 | rnn cells (default: tf.float32). 302 | scope: VariableScope for the created subgraph; defaults to 303 | "embedding_rnn_seq2seq" 304 | 305 | Returns: 306 | A tuple of the form (outputs, state), where: 307 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 308 | shape [batch_size x num_decoder_symbols] containing the generated 309 | outputs. 310 | state: The state of each decoder cell in each time-step. This is a list 311 | with length len(decoder_inputs) -- one item for each time-step. 312 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 313 | """ 314 | with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq"): 315 | # Encoder. 316 | encoder_cell = rnn_cell.EmbeddingWrapper( 317 | cell, embedding_classes=num_encoder_symbols, 318 | embedding_size=embedding_size) 319 | _, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype) 320 | 321 | # Decoder. 322 | if output_projection is None: 323 | cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) 324 | 325 | if isinstance(feed_previous, bool): 326 | return embedding_rnn_decoder( 327 | decoder_inputs, encoder_state, cell, num_decoder_symbols, 328 | embedding_size, output_projection=output_projection, 329 | feed_previous=feed_previous) 330 | 331 | # If feed_previous is a Tensor, we construct 2 graphs and use cond. 332 | def decoder(feed_previous_bool): 333 | reuse = None if feed_previous_bool else True 334 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), 335 | reuse=reuse): 336 | outputs, state = embedding_rnn_decoder( 337 | decoder_inputs, encoder_state, cell, num_decoder_symbols, 338 | embedding_size, output_projection=output_projection, 339 | feed_previous=feed_previous_bool, 340 | update_embedding_for_previous=False) 341 | return outputs + [state] 342 | 343 | outputs_and_state = control_flow_ops.cond(feed_previous, 344 | lambda: decoder(True), 345 | lambda: decoder(False)) 346 | return outputs_and_state[:-1], outputs_and_state[-1] 347 | 348 | 349 | def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, 350 | num_symbols, embedding_size, 351 | output_projection=None, feed_previous=False, 352 | dtype=dtypes.float32, scope=None): 353 | """Embedding RNN sequence-to-sequence model with tied (shared) parameters. 354 | 355 | This model first embeds encoder_inputs by a newly created embedding (of shape 356 | [num_symbols x input_size]). Then it runs an RNN to encode embedded 357 | encoder_inputs into a state vector. Next, it embeds decoder_inputs using 358 | the same embedding. Then it runs RNN decoder, initialized with the last 359 | encoder state, on embedded decoder_inputs. 360 | 361 | Args: 362 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 363 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 364 | cell: rnn_cell.RNNCell defining the cell function and size. 365 | num_symbols: Integer; number of symbols for both encoder and decoder. 366 | embedding_size: Integer, the length of the embedding vector for each symbol. 
367 | output_projection: None or a pair (W, B) of output projection weights and 368 | biases; W has shape [output_size x num_symbols] and B has 369 | shape [num_symbols]; if provided and feed_previous=True, each 370 | fed previous output will first be multiplied by W and added B. 371 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 372 | of decoder_inputs will be used (the "GO" symbol), and all other decoder 373 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 374 | If False, decoder_inputs are used as given (the standard decoder case). 375 | dtype: The dtype to use for the initial RNN states (default: tf.float32). 376 | scope: VariableScope for the created subgraph; defaults to 377 | "embedding_tied_rnn_seq2seq". 378 | 379 | Returns: 380 | A tuple of the form (outputs, state), where: 381 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 382 | shape [batch_size x num_decoder_symbols] containing the generated 383 | outputs. 384 | state: The state of each decoder cell at the final time-step. 385 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 386 | 387 | Raises: 388 | ValueError: When output_projection has the wrong shape. 389 | """ 390 | if output_projection is not None: 391 | proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype) 392 | proj_weights.get_shape().assert_is_compatible_with([None, num_symbols]) 393 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 394 | proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 395 | 396 | with variable_scope.variable_scope(scope or "embedding_tied_rnn_seq2seq"): 397 | with ops.device("/cpu:0"): 398 | embedding = variable_scope.get_variable("embedding", 399 | [num_symbols, embedding_size]) 400 | 401 | emb_encoder_inputs = [embedding_ops.embedding_lookup(embedding, x) 402 | for x in encoder_inputs] 403 | emb_decoder_inputs = [embedding_ops.embedding_lookup(embedding, x) 404 | for x in decoder_inputs] 405 | 406 | if output_projection is None: 407 | cell = rnn_cell.OutputProjectionWrapper(cell, num_symbols) 408 | 409 | if isinstance(feed_previous, bool): 410 | loop_function = _extract_argmax_and_embed( 411 | embedding, output_projection, True) if feed_previous else None 412 | return tied_rnn_seq2seq(emb_encoder_inputs, emb_decoder_inputs, cell, 413 | loop_function=loop_function, dtype=dtype) 414 | 415 | # If feed_previous is a Tensor, we construct 2 graphs and use cond. 416 | def decoder(feed_previous_bool): 417 | loop_function = _extract_argmax_and_embed( 418 | embedding, output_projection, False) if feed_previous_bool else None 419 | reuse = None if feed_previous_bool else True 420 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), 421 | reuse=reuse): 422 | outputs, state = tied_rnn_seq2seq( 423 | emb_encoder_inputs, emb_decoder_inputs, cell, 424 | loop_function=loop_function, dtype=dtype) 425 | return outputs + [state] 426 | 427 | outputs_and_state = control_flow_ops.cond(feed_previous, 428 | lambda: decoder(True), 429 | lambda: decoder(False)) 430 | return outputs_and_state[:-1], outputs_and_state[-1] 431 | 432 | 433 | def attention_decoder(decoder_inputs, initial_state, attention_states, cell, 434 | output_size=None, num_heads=1, loop_function=None, 435 | dtype=dtypes.float32, scope=None, 436 | initial_state_attention=False, attn_num_hidden=128): 437 | """RNN decoder with attention for the sequence-to-sequence model. 
438 | 439 | In this context "attention" means that, during decoding, the RNN can look up 440 | information in the additional tensor attention_states, and it does this by 441 | focusing on a few entries from the tensor. This model has proven to yield 442 | especially good results in a number of sequence-to-sequence tasks. This 443 | implementation is based on http://arxiv.org/abs/1412.7449 (see below for 444 | details). It is recommended for complex sequence-to-sequence tasks. 445 | 446 | Args: 447 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 448 | initial_state: 2D Tensor [batch_size x cell.state_size]. 449 | attention_states: 3D Tensor [batch_size x attn_length x attn_size]. 450 | cell: rnn_cell.RNNCell defining the cell function and size. 451 | output_size: Size of the output vectors; if None, we use cell.output_size. 452 | num_heads: Number of attention heads that read from attention_states. 453 | loop_function: If not None, this function will be applied to i-th output 454 | in order to generate i+1-th input, and decoder_inputs will be ignored, 455 | except for the first element ("GO" symbol). This can be used for decoding, 456 | but also for training to emulate http://arxiv.org/abs/1506.03099. 457 | Signature -- loop_function(prev, i) = next 458 | * prev is a 2D Tensor of shape [batch_size x output_size], 459 | * i is an integer, the step number (when advanced control is needed), 460 | * next is a 2D Tensor of shape [batch_size x input_size]. 461 | dtype: The dtype to use for the RNN initial state (default: tf.float32). 462 | scope: VariableScope for the created subgraph; default: "attention_decoder". 463 | initial_state_attention: If False (default), initial attentions are zero. 464 | If True, initialize the attentions from the initial state and attention 465 | states -- useful when we wish to resume decoding from a previously 466 | stored decoder state and attention states. 467 | 468 | Returns: 469 | A tuple of the form (outputs, state), where: 470 | outputs: A list of the same length as decoder_inputs of 2D Tensors of 471 | shape [batch_size x output_size]. These represent the generated outputs. 472 | Output i is computed from input i (which is either the i-th element 473 | of decoder_inputs or loop_function(output {i-1}, i)) as follows. 474 | First, we run the cell on a combination of the input and previous 475 | attention masks: 476 | cell_output, new_state = cell(linear(input, prev_attn), prev_state). 477 | Then, we calculate new attention masks: 478 | new_attn = softmax(V^T * tanh(W * attention_states + U * new_state)) 479 | and then we calculate the output: 480 | output = linear(cell_output, new_attn). 481 | state: The state of each decoder cell the final time-step. 482 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 483 | 484 | Raises: 485 | ValueError: when num_heads is not positive, there are no inputs, or shapes 486 | of attention_states are not set. 487 | """ 488 | # MODIFIED ADD START 489 | assert num_heads == 1, 'We only consider the case where num_heads=1!' 
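# (The attention() closure below records only a single head's weights in
# attention_weights_history, so the visualization path does not support
# multiple heads.)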
490 | # MODIFIED ADD END 491 | if not decoder_inputs: 492 | raise ValueError("Must provide at least 1 input to attention decoder.") 493 | if num_heads < 1: 494 | raise ValueError("With less than 1 heads, use a non-attention decoder.") 495 | if not attention_states.get_shape()[1:2].is_fully_defined(): 496 | raise ValueError("Shape[1] and [2] of attention_states must be known: %s" 497 | % attention_states.get_shape()) 498 | if output_size is None: 499 | output_size = cell.output_size 500 | 501 | with variable_scope.variable_scope(scope or "attention_decoder"): 502 | batch_size = array_ops.shape(decoder_inputs[0])[0] # Needed for reshaping. 503 | attn_length = attention_states.get_shape()[1].value 504 | attn_size = attention_states.get_shape()[2].value 505 | 506 | # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before. 507 | hidden = array_ops.reshape( 508 | attention_states, [-1, attn_length, 1, attn_size]) 509 | hidden_features = [] 510 | v = [] 511 | attention_vec_size = attn_size # Size of query vectors for attention. 512 | for a in xrange(num_heads): 513 | k = variable_scope.get_variable("AttnW_%d" % a, 514 | [1, 1, attn_size, attention_vec_size]) 515 | hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")) 516 | v.append(variable_scope.get_variable("AttnV_%d" % a, 517 | [attention_vec_size])) 518 | 519 | state = initial_state 520 | 521 | # MODIFIED: return both context vector and attention weights 522 | def attention(query): 523 | """Put attention masks on hidden using hidden_features and query.""" 524 | # MODIFIED ADD START 525 | ss = None # record attention weights 526 | # MODIFIED ADD END 527 | ds = [] # Results of attention reads will be stored here. 528 | for a in xrange(num_heads): 529 | with variable_scope.variable_scope("Attention_%d" % a): 530 | y = linear(query, attention_vec_size, True) 531 | y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) 532 | # Attention mask is a softmax of v^T * tanh(...). 533 | s = math_ops.reduce_sum( 534 | v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3]) 535 | a = nn_ops.softmax(s) 536 | ss = a 537 | #a = tf.Print(a, [a], message="a: ",summarize=30) 538 | # Now calculate the attention-weighted vector d. 539 | d = math_ops.reduce_sum( 540 | array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, 541 | [1, 2]) 542 | ds.append(array_ops.reshape(d, [-1, attn_size])) 543 | # MODIFIED DELETED return ds 544 | # MODIFIED ADD START 545 | return ds, ss 546 | # MODIFIED ADD END 547 | 548 | outputs = [] 549 | # MODIFIED ADD START 550 | attention_weights_history = [] 551 | # MODIFIED ADD END 552 | prev = None 553 | batch_attn_size = array_ops.stack([batch_size, attn_size]) 554 | attns = [array_ops.zeros(batch_attn_size, dtype=dtype) 555 | for _ in xrange(num_heads)] 556 | for a in attns: # Ensure the second shape of attention vectors is set. 557 | a.set_shape([None, attn_size]) 558 | if initial_state_attention: 559 | # MODIFIED DELETED attns = attention(initial_state) 560 | # MODIFIED ADD START 561 | attns, attn_weights = attention(initial_state) 562 | attention_weights_history.append(attn_weights) 563 | # MODIFIED ADD END 564 | for i, inp in enumerate(decoder_inputs): 565 | if i > 0: 566 | variable_scope.get_variable_scope().reuse_variables() 567 | # If loop_function is set, we use it instead of decoder_inputs. 
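# (prev holds the previous step's output, so at decode time loop_function
# feeds back the model's own prediction instead of the ground-truth symbol.)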
568 | if loop_function is not None and prev is not None: 569 | with variable_scope.variable_scope("loop_function", reuse=True): 570 | inp = loop_function(prev, i) 571 | # Merge input and previous attentions into one vector of the right size. 572 | #input_size = inp.get_shape().with_rank(2)[1] 573 | # TODO: use input_size 574 | input_size = attn_num_hidden 575 | x = linear([inp] + attns, input_size, True) 576 | # Run the RNN. 577 | cell_output, state = cell(x, state) 578 | # Run the attention mechanism. 579 | if i == 0 and initial_state_attention: 580 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), 581 | reuse=True): 582 | # MODIFIED DELETED attns = attention(state) 583 | # MODIFIED ADD START 584 | attns, attn_weights = attention(state) 585 | # MODIFIED ADD END 586 | else: 587 | # MODIFIED DELETED attns = attention(state) 588 | # MODIFIED ADD START 589 | attns, attn_weights = attention(state) 590 | attention_weights_history.append(attn_weights) 591 | # MODIFIED ADD END 592 | 593 | with variable_scope.variable_scope("AttnOutputProjection"): 594 | output = linear([cell_output] + attns, output_size, True) 595 | if loop_function is not None: 596 | prev = output 597 | outputs.append(output) 598 | 599 | # MODIFIED DELETED return outputs, state 600 | # MODIFIED ADD START 601 | return outputs, state, attention_weights_history 602 | # MODIFIED ADD END 603 | 604 | 605 | def embedding_attention_decoder(decoder_inputs, initial_state, attention_states, 606 | cell, num_symbols, embedding_size, num_heads=1, 607 | output_size=None, output_projection=None, 608 | feed_previous=False, 609 | update_embedding_for_previous=True, 610 | dtype=dtypes.float32, scope=None, 611 | initial_state_attention=False, 612 | attn_num_hidden=128): 613 | """RNN decoder with embedding and attention and a pure-decoding option. 614 | 615 | Args: 616 | decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). 617 | initial_state: 2D Tensor [batch_size x cell.state_size]. 618 | attention_states: 3D Tensor [batch_size x attn_length x attn_size]. 619 | cell: rnn_cell.RNNCell defining the cell function. 620 | num_symbols: Integer, how many symbols come into the embedding. 621 | embedding_size: Integer, the length of the embedding vector for each symbol. 622 | num_heads: Number of attention heads that read from attention_states. 623 | output_size: Size of the output vectors; if None, use output_size. 624 | output_projection: None or a pair (W, B) of output projection weights and 625 | biases; W has shape [output_size x num_symbols] and B has shape 626 | [num_symbols]; if provided and feed_previous=True, each fed previous 627 | output will first be multiplied by W and added B. 628 | feed_previous: Boolean; if True, only the first of decoder_inputs will be 629 | used (the "GO" symbol), and all other decoder inputs will be generated by: 630 | next = embedding_lookup(embedding, argmax(previous_output)), 631 | In effect, this implements a greedy decoder. It can also be used 632 | during training to emulate http://arxiv.org/abs/1506.03099. 633 | If False, decoder_inputs are used as given (the standard decoder case). 634 | update_embedding_for_previous: Boolean; if False and feed_previous=True, 635 | only the embedding for the first symbol of decoder_inputs (the "GO" 636 | symbol) will be updated by back propagation. Embeddings for the symbols 637 | generated from the decoder itself remain unchanged. This parameter has 638 | no effect if feed_previous=False. 
639 | dtype: The dtype to use for the RNN initial states (default: tf.float32). 640 | scope: VariableScope for the created subgraph; defaults to 641 | "embedding_attention_decoder". 642 | initial_state_attention: If False (default), initial attentions are zero. 643 | If True, initialize the attentions from the initial state and attention 644 | states -- useful when we wish to resume decoding from a previously 645 | stored decoder state and attention states. 646 | 647 | Returns: 648 | A tuple of the form (outputs, state), where: 649 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 650 | shape [batch_size x output_size] containing the generated outputs. 651 | state: The state of each decoder cell at the final time-step. 652 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 653 | 654 | Raises: 655 | ValueError: When output_projection has the wrong shape. 656 | """ 657 | if output_size is None: 658 | output_size = cell.output_size 659 | if output_projection is not None: 660 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 661 | proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 662 | 663 | with variable_scope.variable_scope(scope or "embedding_attention_decoder"): 664 | with ops.device("/cpu:0"): 665 | embedding = variable_scope.get_variable("embedding", 666 | [num_symbols, embedding_size]) 667 | loop_function = _extract_argmax_and_embed( 668 | embedding, output_projection, 669 | update_embedding_for_previous) if feed_previous else None 670 | emb_inp = [ 671 | embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs] 672 | return attention_decoder( 673 | emb_inp, initial_state, attention_states, cell, output_size=output_size, 674 | num_heads=num_heads, loop_function=loop_function, 675 | initial_state_attention=initial_state_attention, attn_num_hidden=attn_num_hidden) 676 | 677 | 678 | def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell, 679 | num_encoder_symbols, num_decoder_symbols, 680 | embedding_size, 681 | num_heads=1, output_projection=None, 682 | feed_previous=False, dtype=dtypes.float32, 683 | scope=None, initial_state_attention=False): 684 | """Embedding sequence-to-sequence model with attention. 685 | 686 | This model first embeds encoder_inputs by a newly created embedding (of shape 687 | [num_encoder_symbols x input_size]). Then it runs an RNN to encode 688 | embedded encoder_inputs into a state vector. It keeps the outputs of this 689 | RNN at every step to use for attention later. Next, it embeds decoder_inputs 690 | by another newly created embedding (of shape [num_decoder_symbols x 691 | input_size]). Then it runs attention decoder, initialized with the last 692 | encoder state, on embedded decoder_inputs and attending to encoder outputs. 693 | 694 | Args: 695 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 696 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 697 | cell: rnn_cell.RNNCell defining the cell function and size. 698 | num_encoder_symbols: Integer; number of symbols on the encoder side. 699 | num_decoder_symbols: Integer; number of symbols on the decoder side. 700 | embedding_size: Integer, the length of the embedding vector for each symbol. 701 | num_heads: Number of attention heads that read from attention_states. 
702 | output_projection: None or a pair (W, B) of output projection weights and 703 | biases; W has shape [output_size x num_decoder_symbols] and B has 704 | shape [num_decoder_symbols]; if provided and feed_previous=True, each 705 | fed previous output will first be multiplied by W and added B. 706 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 707 | of decoder_inputs will be used (the "GO" symbol), and all other decoder 708 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 709 | If False, decoder_inputs are used as given (the standard decoder case). 710 | dtype: The dtype of the initial RNN state (default: tf.float32). 711 | scope: VariableScope for the created subgraph; defaults to 712 | "embedding_attention_seq2seq". 713 | initial_state_attention: If False (default), initial attentions are zero. 714 | If True, initialize the attentions from the initial state and attention 715 | states. 716 | 717 | Returns: 718 | A tuple of the form (outputs, state), where: 719 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 720 | shape [batch_size x num_decoder_symbols] containing the generated 721 | outputs. 722 | state: The state of each decoder cell at the final time-step. 723 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 724 | """ 725 | with variable_scope.variable_scope(scope or "embedding_attention_seq2seq"): 726 | # Encoder. 727 | encoder_cell = rnn_cell.EmbeddingWrapper( 728 | cell, embedding_classes=num_encoder_symbols, 729 | embedding_size=embedding_size) 730 | encoder_outputs, encoder_state = rnn.rnn( 731 | encoder_cell, encoder_inputs, dtype=dtype) 732 | 733 | # First calculate a concatenation of encoder outputs to put attention on. 734 | top_states = [array_ops.reshape(e, [-1, 1, cell.output_size]) 735 | for e in encoder_outputs] 736 | attention_states = array_ops.concat(1, top_states) 737 | 738 | # Decoder. 739 | output_size = None 740 | if output_projection is None: 741 | cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) 742 | output_size = num_decoder_symbols 743 | 744 | if isinstance(feed_previous, bool): 745 | return embedding_attention_decoder( 746 | decoder_inputs, encoder_state, attention_states, cell, 747 | num_decoder_symbols, embedding_size, num_heads=num_heads, 748 | output_size=output_size, output_projection=output_projection, 749 | feed_previous=feed_previous, 750 | initial_state_attention=initial_state_attention) 751 | 752 | # If feed_previous is a Tensor, we construct 2 graphs and use cond. 
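# (Both decoder graphs are built; the second is created with reuse=True so it
# shares variables with the first, and tf.cond selects one at run time.)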
753 | def decoder(feed_previous_bool): 754 | reuse = None if feed_previous_bool else True 755 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), 756 | reuse=reuse): 757 | outputs, state = embedding_attention_decoder( 758 | decoder_inputs, encoder_state, attention_states, cell, 759 | num_decoder_symbols, embedding_size, num_heads=num_heads, 760 | output_size=output_size, output_projection=output_projection, 761 | feed_previous=feed_previous_bool, 762 | update_embedding_for_previous=False, 763 | initial_state_attention=initial_state_attention) 764 | return outputs + [state] 765 | 766 | outputs_and_state = control_flow_ops.cond(feed_previous, 767 | lambda: decoder(True), 768 | lambda: decoder(False)) 769 | return outputs_and_state[:-1], outputs_and_state[-1] 770 | 771 | 772 | def one2many_rnn_seq2seq(encoder_inputs, decoder_inputs_dict, cell, 773 | num_encoder_symbols, num_decoder_symbols_dict, 774 | embedding_size, feed_previous=False, 775 | dtype=dtypes.float32, scope=None): 776 | """One-to-many RNN sequence-to-sequence model (multi-task). 777 | 778 | This is a multi-task sequence-to-sequence model with one encoder and multiple 779 | decoders. Reference to multi-task sequence-to-sequence learning can be found 780 | here: http://arxiv.org/abs/1511.06114 781 | 782 | Args: 783 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 784 | decoder_inputs_dict: A dictionany mapping decoder name (string) to 785 | the corresponding decoder_inputs; each decoder_inputs is a list of 1D 786 | Tensors of shape [batch_size]; num_decoders is defined as 787 | len(decoder_inputs_dict). 788 | cell: rnn_cell.RNNCell defining the cell function and size. 789 | num_encoder_symbols: Integer; number of symbols on the encoder side. 790 | num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an 791 | integer specifying number of symbols for the corresponding decoder; 792 | len(num_decoder_symbols_dict) must be equal to num_decoders. 793 | embedding_size: Integer, the length of the embedding vector for each symbol. 794 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first of 795 | decoder_inputs will be used (the "GO" symbol), and all other decoder 796 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 797 | If False, decoder_inputs are used as given (the standard decoder case). 798 | dtype: The dtype of the initial state for both the encoder and encoder 799 | rnn cells (default: tf.float32). 800 | scope: VariableScope for the created subgraph; defaults to 801 | "one2many_rnn_seq2seq" 802 | 803 | Returns: 804 | A tuple of the form (outputs_dict, state_dict), where: 805 | outputs_dict: A mapping from decoder name (string) to a list of the same 806 | length as decoder_inputs_dict[name]; each element in the list is a 2D 807 | Tensors with shape [batch_size x num_decoder_symbol_list[name]] 808 | containing the generated outputs. 809 | state_dict: A mapping from decoder name (string) to the final state of the 810 | corresponding decoder RNN; it is a 2D Tensor of shape 811 | [batch_size x cell.state_size]. 812 | """ 813 | outputs_dict = {} 814 | state_dict = {} 815 | 816 | with variable_scope.variable_scope(scope or "one2many_rnn_seq2seq"): 817 | # Encoder. 818 | encoder_cell = rnn_cell.EmbeddingWrapper( 819 | cell, embedding_classes=num_encoder_symbols, 820 | embedding_size=embedding_size) 821 | _, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype) 822 | 823 | # Decoder. 
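# (One decoder is built per entry in decoder_inputs_dict, each under its own
# "one2many_decoder_<name>" scope, all initialized from the shared encoder state.)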
824 | for name, decoder_inputs in decoder_inputs_dict.items(): 825 | num_decoder_symbols = num_decoder_symbols_dict[name] 826 | 827 | with variable_scope.variable_scope("one2many_decoder_" + str(name)): 828 | decoder_cell = rnn_cell.OutputProjectionWrapper(cell, 829 | num_decoder_symbols) 830 | if isinstance(feed_previous, bool): 831 | outputs, state = embedding_rnn_decoder( 832 | decoder_inputs, encoder_state, decoder_cell, num_decoder_symbols, 833 | embedding_size, feed_previous=feed_previous) 834 | else: 835 | # If feed_previous is a Tensor, we construct 2 graphs and use cond. 836 | def filled_embedding_rnn_decoder(feed_previous): 837 | # pylint: disable=cell-var-from-loop 838 | reuse = None if feed_previous else True 839 | vs = variable_scope.get_variable_scope() 840 | with variable_scope.variable_scope(vs, reuse=reuse): 841 | outputs, state = embedding_rnn_decoder( 842 | decoder_inputs, encoder_state, decoder_cell, 843 | num_decoder_symbols, embedding_size, 844 | feed_previous=feed_previous) 845 | # pylint: enable=cell-var-from-loop 846 | return outputs + [state] 847 | outputs_and_state = control_flow_ops.cond( 848 | feed_previous, 849 | lambda: filled_embedding_rnn_decoder(True), 850 | lambda: filled_embedding_rnn_decoder(False)) 851 | outputs = outputs_and_state[:-1] 852 | state = outputs_and_state[-1] 853 | 854 | outputs_dict[name] = outputs 855 | state_dict[name] = state 856 | 857 | return outputs_dict, state_dict 858 | 859 | 860 | def sequence_loss_by_example(logits, targets, weights, 861 | average_across_timesteps=True, 862 | softmax_loss_function=None, name=None): 863 | """Weighted cross-entropy loss for a sequence of logits (per example). 864 | 865 | Args: 866 | logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols]. 867 | targets: List of 1D batch-sized int32 Tensors of the same length as logits. 868 | weights: List of 1D batch-sized float-Tensors of the same length as logits. 869 | average_across_timesteps: If set, divide the returned cost by the total 870 | label weight. 871 | softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch 872 | to be used instead of the standard softmax (the default if this is None). 873 | name: Optional name for this operation, default: "sequence_loss_by_example". 874 | 875 | Returns: 876 | 1D batch-sized float Tensor: The log-perplexity for each sequence. 877 | 878 | Raises: 879 | ValueError: If len(logits) is different from len(targets) or len(weights). 880 | """ 881 | if len(targets) != len(logits) or len(weights) != len(logits): 882 | raise ValueError("Lengths of logits, weights, and targets must be the same " 883 | "%d, %d, %d." % (len(logits), len(weights), len(targets))) 884 | with ops.name_scope(name, "sequence_loss_by_example", 885 | logits + targets + weights): 886 | log_perp_list = [] 887 | for logit, target, weight in zip(logits, targets, weights): 888 | if softmax_loss_function is None: 889 | # TODO(irving,ebrevdo): This reshape is needed because 890 | # sequence_loss_by_example is called with scalars sometimes, which 891 | # violates our general scalar strictness policy. 
892 | target = array_ops.reshape(target, [-1]) 893 | crossent = nn_ops.sparse_softmax_cross_entropy_with_logits( 894 | logits=logit, labels=target) 895 | else: 896 | crossent = softmax_loss_function(logits=logit, labels=target) 897 | log_perp_list.append(crossent * weight) 898 | log_perps = math_ops.add_n(log_perp_list) 899 | if average_across_timesteps: 900 | total_size = math_ops.add_n(weights) 901 | total_size += 1e-12 # Just to avoid division by 0 for all-0 weights. 902 | log_perps /= total_size 903 | return log_perps 904 | 905 | 906 | def sequence_loss(logits, targets, weights, 907 | average_across_timesteps=True, average_across_batch=True, 908 | softmax_loss_function=None, name=None): 909 | """Weighted cross-entropy loss for a sequence of logits, batch-collapsed. 910 | 911 | Args: 912 | logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols]. 913 | targets: List of 1D batch-sized int32 Tensors of the same length as logits. 914 | weights: List of 1D batch-sized float-Tensors of the same length as logits. 915 | average_across_timesteps: If set, divide the returned cost by the total 916 | label weight. 917 | average_across_batch: If set, divide the returned cost by the batch size. 918 | softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch 919 | to be used instead of the standard softmax (the default if this is None). 920 | name: Optional name for this operation, defaults to "sequence_loss". 921 | 922 | Returns: 923 | A scalar float Tensor: The average log-perplexity per symbol (weighted). 924 | 925 | Raises: 926 | ValueError: If len(logits) is different from len(targets) or len(weights). 927 | """ 928 | with ops.name_scope(name, "sequence_loss", logits + targets + weights): 929 | cost = math_ops.reduce_sum(sequence_loss_by_example( 930 | logits, targets, weights, 931 | average_across_timesteps=average_across_timesteps, 932 | softmax_loss_function=softmax_loss_function)) 933 | if average_across_batch: 934 | batch_size = array_ops.shape(targets[0])[0] 935 | return cost / math_ops.cast(batch_size, dtypes.float32) 936 | else: 937 | return cost 938 | 939 | 940 | def model_with_buckets(encoder_inputs_tensor, decoder_inputs, targets, weights, 941 | buckets, seq2seq, softmax_loss_function=None, 942 | per_example_loss=False, name=None): 943 | """Create a sequence-to-sequence model with support for bucketing. 944 | 945 | The seq2seq argument is a function that defines a sequence-to-sequence model, 946 | e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24)) 947 | 948 | Args: 949 | encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input. 950 | decoder_inputs: A list of Tensors to feed the decoder; second seq2seq input. 951 | targets: A list of 1D batch-sized int32 Tensors (desired output sequence). 952 | weights: List of 1D batch-sized float-Tensors to weight the targets. 953 | buckets: A list of pairs of (input size, output size) for each bucket. 954 | seq2seq: A sequence-to-sequence model function; it takes 2 input that 955 | agree with encoder_inputs and decoder_inputs, and returns a pair 956 | consisting of outputs and states (as, e.g., basic_rnn_seq2seq). 957 | softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch 958 | to be used instead of the standard softmax (the default if this is None). 959 | per_example_loss: Boolean. If set, the returned loss will be a batch-sized 960 | tensor of losses for each sequence in the batch. 
If unset, it will be 961 | a scalar with the averaged loss from all examples. 962 | name: Optional name for this operation, defaults to "model_with_buckets". 963 | 964 | Returns: 965 | A tuple of the form (outputs, losses), where: 966 | outputs: The outputs for each bucket. Its j'th element consists of a list 967 | of 2D Tensors of shape [batch_size x num_decoder_symbols] (jth outputs). 968 | losses: List of scalar Tensors, representing losses for each bucket, or, 969 | if per_example_loss is set, a list of 1D batch-sized float Tensors. 970 | 971 | Raises: 972 | ValueError: If length of encoder_inputsut, targets, or weights is smaller 973 | than the largest (last) bucket. 974 | """ 975 | if len(targets) < buckets[-1][1]: 976 | raise ValueError("Length of targets (%d) must be at least that of last" 977 | "bucket (%d)." % (len(targets), buckets[-1][1])) 978 | if len(weights) < buckets[-1][1]: 979 | raise ValueError("Length of weights (%d) must be at least that of last" 980 | "bucket (%d)." % (len(weights), buckets[-1][1])) 981 | 982 | all_inputs = [encoder_inputs_tensor] + decoder_inputs + targets + weights 983 | losses = [] 984 | outputs = [] 985 | attention_weights_histories = [] 986 | with ops.name_scope(name, "model_with_buckets", all_inputs): 987 | for j, bucket in enumerate(buckets): 988 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), 989 | reuse=True if j > 0 else None): 990 | encoder_inputs = tf.split(encoder_inputs_tensor, bucket[0], 0) 991 | encoder_inputs = [tf.squeeze(encoder_input,squeeze_dims=[0]) for encoder_input in encoder_inputs] 992 | bucket_outputs, attention_weights_history = seq2seq(encoder_inputs[:int(bucket[0])], 993 | decoder_inputs[:int(bucket[1])], int(bucket[0])) 994 | #bucket_outputs[0] = tf.Print(bucket_outputs[0], [bucket_outputs[0]], message="This is a: ",summarize=30) 995 | outputs.append(bucket_outputs) 996 | attention_weights_histories.append(attention_weights_history) 997 | if per_example_loss: 998 | losses.append(sequence_loss_by_example( 999 | outputs[-1], targets[:int(bucket[1])], weights[:int(bucket[1])], 1000 | average_across_timesteps=True, 1001 | softmax_loss_function=softmax_loss_function)) 1002 | else: 1003 | losses.append(sequence_loss( 1004 | outputs[-1], targets[:int(bucket[1])], weights[:int(bucket[1])], 1005 | average_across_timesteps=True, 1006 | softmax_loss_function=softmax_loss_function)) 1007 | #losses[0] = tf.Print(losses[0], [losses[0]], message="This is b: ",summarize=3) 1008 | 1009 | return outputs, losses, attention_weights_histories 1010 | -------------------------------------------------------------------------------- /src/model/seq2seq_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Sequence-to-sequence model with an attention mechanism.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import random 23 | 24 | import numpy as np 25 | from six.moves import xrange # pylint: disable=redefined-builtin 26 | import tensorflow as tf 27 | 28 | #from tensorflow.models.rnn.translate import data_utils 29 | #from tensorflow.nn import rnn, rnn_cell 30 | from tensorflow.python.ops import array_ops 31 | from tensorflow.python.ops import variable_scope 32 | 33 | from .seq2seq import model_with_buckets 34 | from .seq2seq import embedding_attention_decoder 35 | 36 | class Seq2SeqModel(object): 37 | """Sequence-to-sequence model with attention and for multiple buckets. 38 | This class implements a multi-layer recurrent neural network as encoder, 39 | and an attention-based decoder. This is the same as the model described in 40 | this paper: http://arxiv.org/abs/1412.7449 - please look there for details, 41 | or into the seq2seq library for complete model implementation. 42 | This class also allows to use GRU cells in addition to LSTM cells, and 43 | sampled softmax to handle large output vocabulary size. A single-layer 44 | version of this model, but with bi-directional encoder, was presented in 45 | http://arxiv.org/abs/1409.0473 46 | and sampled softmax is described in Section 3 of the following paper. 47 | http://arxiv.org/abs/1412.2007 48 | """ 49 | 50 | def __init__(self, encoder_masks, encoder_inputs_tensor, 51 | decoder_inputs, 52 | target_weights, 53 | target_vocab_size, 54 | buckets, 55 | target_embedding_size, 56 | attn_num_layers, 57 | attn_num_hidden, 58 | forward_only, 59 | use_gru): 60 | """Create the model. 61 | 62 | Args: 63 | source_vocab_size: size of the source vocabulary. 64 | target_vocab_size: size of the target vocabulary. 65 | buckets: a list of pairs (I, O), where I specifies maximum input length 66 | that will be processed in that bucket, and O specifies maximum output 67 | length. Training instances that have inputs longer than I or outputs 68 | longer than O will be pushed to the next bucket and padded accordingly. 69 | We assume that the list is sorted, e.g., [(2, 4), (8, 16)]. 70 | size: number of units in each layer of the model. 71 | num_layers: number of layers in the model. 72 | max_gradient_norm: gradients will be clipped to maximally this norm. 73 | learning_rate: learning rate to start with. 74 | learning_rate_decay_factor: decay learning rate by this much when needed. 75 | use_lstm: if true, we use LSTM cells instead of GRU cells. 76 | num_samples: number of samples for sampled softmax. 77 | forward_only: if set, we do not construct the backward pass in the model. 78 | """ 79 | self.encoder_inputs_tensor = encoder_inputs_tensor 80 | self.decoder_inputs = decoder_inputs 81 | self.target_weights = target_weights 82 | self.target_vocab_size = target_vocab_size 83 | self.buckets = buckets 84 | self.encoder_masks = encoder_masks 85 | 86 | # Create the internal multi-layer cell for our RNN. 
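# (state_is_tuple=False keeps the LSTM state as a single concatenated tensor,
# which matches the concatenated forward/backward encoder state used as the
# attention decoder's initial state below.)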
87 | single_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(attn_num_hidden, forget_bias=0.0, state_is_tuple=False) 88 | if use_gru: 89 | print("using GRU CELL in decoder") 90 | single_cell = tf.contrib.rnn.core_rnn_cell.GRUCell(attn_num_hidden) 91 | cell = single_cell 92 | 93 | if attn_num_layers > 1: 94 | cell = tf.contrib.rnn.core_rnn_cell.MultiRNNCell([single_cell] * attn_num_layers, state_is_tuple=False) 95 | 96 | # The seq2seq function: we use embedding for the input and attention. 97 | def seq2seq_f(lstm_inputs, decoder_inputs, seq_length, do_decode): 98 | 99 | num_hidden = attn_num_layers * attn_num_hidden 100 | lstm_fw_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(num_hidden, forget_bias=0.0, state_is_tuple=False) 101 | # Backward direction cell 102 | lstm_bw_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(num_hidden, forget_bias=0.0, state_is_tuple=False) 103 | 104 | pre_encoder_inputs, output_state_fw, output_state_bw = tf.contrib.rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, lstm_inputs, 105 | initial_state_fw=None, initial_state_bw=None, 106 | dtype=tf.float32, sequence_length=None, scope=None) 107 | 108 | encoder_inputs = [e*f for e,f in zip(pre_encoder_inputs,encoder_masks[:seq_length])] 109 | top_states = [array_ops.reshape(e, [-1, 1, num_hidden*2]) 110 | for e in encoder_inputs] 111 | attention_states = array_ops.concat(top_states, 1) 112 | initial_state = tf.concat(axis=1, values=[output_state_fw, output_state_bw]) 113 | outputs, _, attention_weights_history = embedding_attention_decoder( 114 | decoder_inputs, initial_state, attention_states, cell, 115 | num_symbols=target_vocab_size, 116 | embedding_size=target_embedding_size, 117 | num_heads=1, 118 | output_size=target_vocab_size, 119 | output_projection=None, 120 | feed_previous=do_decode, 121 | initial_state_attention=False, 122 | attn_num_hidden = attn_num_hidden) 123 | return outputs, attention_weights_history 124 | 125 | # Our targets are decoder inputs shifted by one. 126 | targets = [decoder_inputs[i + 1] 127 | for i in xrange(len(decoder_inputs) - 1)] 128 | 129 | softmax_loss_function = None # default to tf.nn.sparse_softmax_cross_entropy_with_logits 130 | 131 | # Training outputs and losses. 
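# (model_with_buckets builds one graph per bucket from seq2seq_f and returns
# per-bucket outputs, losses and attention-weight histories; forward_only only
# switches feed_previous, so evaluation feeds back the decoder's own predictions.)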
132 | if forward_only: 133 | self.outputs, self.losses, self.attention_weights_histories = model_with_buckets( 134 | encoder_inputs_tensor, decoder_inputs, targets, 135 | self.target_weights, buckets, lambda x, y, z: seq2seq_f(x, y, z, True), 136 | softmax_loss_function=softmax_loss_function) 137 | else: 138 | self.outputs, self.losses, self.attention_weights_histories = model_with_buckets( 139 | encoder_inputs_tensor, decoder_inputs, targets, 140 | self.target_weights, buckets, lambda x, y, z: seq2seq_f(x, y, z, False), 141 | softmax_loss_function=softmax_loss_function) 142 | -------------------------------------------------------------------------------- /test_demo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python src/launcher.py \ 4 | --phase=test \ 5 | --data-path=/media/data2/sivankeret/Datasets/mnt/ramdisk/max/90kDICT32px/annotation_train_words.txt \ 6 | --data-base-dir=/media/data2/sivankeret/Datasets/mnt/ramdisk/max/90kDICT32px \ 7 | --log-path=log_01_16_test.txt \ 8 | --attn-num-hidden 256 \ 9 | --batch-size 64 \ 10 | --model-dir=model_01_16 \ 11 | --load-model \ 12 | --num-epoch=3 \ 13 | --gpu-id=1 \ 14 | --output-dir=model_01_16/synth90 \ 15 | --use-gru \ 16 | --target-embedding-size=10 17 | -------------------------------------------------------------------------------- /tmp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | input_path = '/media/data2/sivankeret/Datasets/mnt/ramdisk/max/90kDICT32px/annotation_train.txt' 4 | lex_path = '/media/data2/sivankeret/Datasets/mnt/ramdisk/max/90kDICT32px/lexicon.txt' 5 | output_path = '/media/data2/sivankeret/Datasets/mnt/ramdisk/max/90kDICT32px/annotation_train_words.txt' 6 | 7 | with open(lex_path,'r') as lex_f: 8 | all_words = lex_f.readlines() 9 | word_dict = dict(enumerate(all_words)) 10 | 11 | with open(input_path,'r') as input_f: 12 | all_lines = input_f.readlines() 13 | 14 | new_lines = [line.split(' ')[0] + ' ' + word_dict[int(line.split(' ')[1])] for line in all_lines] 15 | 16 | with open(output_path, 'w') as out_f: 17 | out_f.writelines(new_lines) 18 | -------------------------------------------------------------------------------- /train_demo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python src/launcher.py \ 4 | --phase=train \ 5 | --data-path=/media/data2/sivankeret/Datasets/mnt/ramdisk/max/90kDICT32px/annotation_train_words.txt \ 6 | --data-base-dir=/media/data2/sivankeret/Datasets/mnt/ramdisk/max/90kDICT32px \ 7 | --log-path=log_01_16.txt \ 8 | --attn-num-hidden 256 \ 9 | --batch-size 64 \ 10 | --model-dir=model_01_16 \ 11 | --initial-learning-rate=1.0 \ 12 | --no-load-model \ 13 | --num-epoch=3 \ 14 | --gpu-id=0 \ 15 | --use-gru \ 16 | --steps-per-checkpoint=2000 \ 17 | --target-embedding-size=10 18 | --------------------------------------------------------------------------------
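Before launching `train_demo.sh` or `test_demo.sh` on your own data, a small sanity check can save a failed run. The sketch below verifies that every image referenced in the annotation file (the two-column `path label` layout produced by `tmp.py` above) exists under the data base directory; the script name, function name, and argument handling are illustrative, not part of the repository.

```
import os
import sys

def check_annotations(data_path, data_base_dir):
    """Report annotation lines whose image file is missing on disk."""
    missing = 0
    with open(data_path, 'r') as f:
        for line_no, line in enumerate(f, 1):
            parts = line.strip().split(' ', 1)
            if len(parts) != 2:
                print('line %d: malformed entry: %r' % (line_no, line))
                continue
            img_path = os.path.join(data_base_dir, parts[0])
            if not os.path.exists(img_path):
                missing += 1
                print('line %d: missing image %s' % (line_no, img_path))
    print('%d missing image(s)' % missing)

if __name__ == '__main__':
    # e.g. python check_annotations.py sample/sample.txt sample
    check_annotations(sys.argv[1], sys.argv[2])
```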