├── .gitignore ├── LICENSE ├── README.md ├── caption_infer.py ├── config.py ├── data ├── .gitignore ├── README.md ├── coco │ └── coco.names ├── coco_test.txt ├── coco_train.txt ├── coco_val.txt ├── glove_vocab.pkl ├── plural_words.json └── word_counts.txt ├── eval_all.py ├── im_caption_full.py ├── initialization ├── eval_obj2sen.py ├── gen_obj2sen_caption.py ├── im_caption.py ├── obj2sen.py ├── sentence_ae.py ├── sentence_gan.py ├── sentence_infer.py └── test_obj2sen.py ├── input_pipeline.py ├── misc_fn.py ├── preprocessing ├── crawl_descriptions.py ├── detect_objects.py ├── extract_descriptions.py ├── process_descriptions.py └── process_images.py ├── requirements.txt └── test_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yang Feng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Image Captioning 2 | by [Yang Feng](http://cs.rochester.edu/u/yfeng23/), Lin Ma, Wei Liu, and 3 | [Jiebo Luo](http://cs.rochester.edu/u/jluo) 4 | 5 | ### Introduction 6 | Most image captioning models are trained using paired image-sentence data, which 7 | are expensive to collect. We propose unsupervised image captioning to relax the 8 | reliance on paired data. For more details, please refer to our 9 | [paper](https://arxiv.org/abs/1811.10787). 
![alt text](http://cs.rochester.edu/u/yfeng23/cvpr19_captioning/framework.png
"Framework")

### Citation

    @InProceedings{feng2019unsupervised,
      author = {Feng, Yang and Ma, Lin and Liu, Wei and Luo, Jiebo},
      title = {Unsupervised Image Captioning},
      booktitle = {CVPR},
      year = {2019}
    }

### Requirements
```
mkdir ~/workspace
cd ~/workspace
git clone https://github.com/tensorflow/models.git tf_models
git clone https://github.com/tylin/coco-caption.git
touch tf_models/research/im2txt/im2txt/__init__.py
touch tf_models/research/im2txt/im2txt/data/__init__.py
touch tf_models/research/im2txt/im2txt/inference_utils/__init__.py
wget http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz
mkdir ckpt
tar zxvf inception_v4_2016_09_09.tar.gz -C ckpt
git clone https://github.com/fengyang0317/unsupervised_captioning.git
cd unsupervised_captioning
pip install -r requirements.txt
export PYTHONPATH=$PYTHONPATH:`pwd`
```

### Dataset (Optional. The files generated below can be found at [Gdrive][1]).
In case you cannot access Google Drive, the files are also available at
[One Drive][2].
1. Crawl image descriptions. The descriptions used for the experiments in the
paper are available at this
[link](https://drive.google.com/file/d/1z8JwNxER-ORWoAmVKBqM7MyPozk6St4M).
You may download the descriptions from the link and extract the files to
data/coco.
```
pip3 install absl-py
python3 preprocessing/crawl_descriptions.py
```

2. Extract the descriptions. NLTK changes from release to release, so the
number of descriptions obtained may differ slightly.
```
python -c "import nltk; nltk.download('punkt')"
python preprocessing/extract_descriptions.py
```

3. Preprocess the descriptions. You may need to change vocab_size, start_id,
and end_id in config.py if you generate a new dictionary.
```
python preprocessing/process_descriptions.py --word_counts_output_file \
data/word_counts.txt --new_dict
```

4. Download the MSCOCO images from [link](http://cocodataset.org/) and put
all the images into ~/dataset/mscoco/all_images.

5. Run object detection on the training images. You need to first download the
detection model from [here][detection_model] and then extract the model under
tf_models/research/object_detection.
```
python preprocessing/detect_objects.py --image_path\
 ~/dataset/mscoco/all_images --num_proc 2 --num_gpus 1
```

6. Generate tfrecord files for the images.
```
python preprocessing/process_images.py --image_path\
 ~/dataset/mscoco/all_images
```

### Training
7. Train the model without the initialization pipeline.
```
python im_caption_full.py --inc_ckpt ~/workspace/ckpt/inception_v4.ckpt\
 --multi_gpu --batch_size 512 --save_checkpoint_steps 1000\
 --gen_lr 0.001 --dis_lr 0.001
```

8. Evaluate the model. The last element in the b34.json file is the best
checkpoint (see the snippet after this step for one way to read it).
```
CUDA_VISIBLE_DEVICES='0,1' python eval_all.py\
 --inc_ckpt ~/workspace/ckpt/inception_v4.ckpt\
 --data_dir ~/dataset/mscoco/all_images
js-beautify saving/b34.json
```
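For convenience, here is a small sketch (not part of the repository) that picks the
best checkpoint out of b34.json. It assumes the format written by eval_all.py in this
repo: a list of [step, CIDEr, METEOR, BLEU-4, BLEU-3, BLEU-2] entries sorted by
BLEU-3 + BLEU-4 in ascending order, so the last entry is the best one.
```
import json

with open('saving/b34.json') as f:
  results = json.load(f)  # each entry: [step, CIDEr, METEOR, B4, B3, B2]

# The list is sorted ascending by BLEU-3 + BLEU-4, so the last entry is best.
step, cider, meteor, b4, b3, b2 = results[-1]
print('best checkpoint: saving/model.ckpt-%s (BLEU-4=%.3f, BLEU-3=%.3f)'
      % (step, b4, b3))
```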

9. Evaluate the model on the test set. Suppose the best validation checkpoint
is 20000.
```
python test_model.py --inc_ckpt ~/workspace/ckpt/inception_v4.ckpt\
 --data_dir ~/dataset/mscoco/all_images --job_dir saving/model.ckpt-20000
```

### Initialization (Optional. The files can be found [here][1]).

10. Train an object-to-sentence model, which is used to generate the
pseudo-captions.
```
python initialization/obj2sen.py
```

11. Find the best obj2sen model.
```
python initialization/eval_obj2sen.py --threads 8
```

12. Generate pseudo-captions. Suppose the best validation checkpoint is 35000.
```
python initialization/gen_obj2sen_caption.py --num_proc 8\
 --job_dir obj2sen/model.ckpt-35000
```

13. Train a captioning model using the pseudo-pairs.
```
python initialization/im_caption.py --o2s_ckpt obj2sen/model.ckpt-35000\
 --inc_ckpt ~/workspace/ckpt/inception_v4.ckpt
```

14. Evaluate the model.
```
CUDA_VISIBLE_DEVICES='0,1' python eval_all.py\
 --inc_ckpt ~/workspace/ckpt/inception_v4.ckpt\
 --data_dir ~/dataset/mscoco/all_images --job_dir saving_imcap
js-beautify saving_imcap/b34.json
```

15. Train the sentence auto-encoder, which is used to initialize the sentence GAN.
```
python initialization/sentence_ae.py
```

16. Train the sentence GAN.
```
python initialization/sentence_gan.py
```

17. Train the full model with initialization. Suppose the best imcap validation
checkpoint is 18000.
```
python im_caption_full.py --inc_ckpt ~/workspace/ckpt/inception_v4.ckpt\
 --imcap_ckpt saving_imcap/model.ckpt-18000\
 --sae_ckpt sen_gan/model.ckpt-30000 --multi_gpu --batch_size 512\
 --save_checkpoint_steps 1000 --gen_lr 0.001 --dis_lr 0.001
```

### Credits
Part of the code is from
[coco-caption](https://github.com/tylin/coco-caption),
[im2txt](https://github.com/tensorflow/models/tree/master/research/im2txt),
[tfgan](https://github.com/tensorflow/models/tree/master/research/gan),
[resnet](https://github.com/tensorflow/models/tree/master/official/resnet),
[Tensorflow Object Detection API](
https://github.com/tensorflow/models/tree/master/research/object_detection) and
[maskgan](https://github.com/tensorflow/models/tree/master/research/maskgan).

[Xinpeng](https://github.com/chenxinpeng) suggested the self-critic idea, which
is crucial to training.
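For reference, a minimal single-image inference sketch built on caption_infer.Infer
(not part of the repo; the script name and image file name are hypothetical, and the
flags mirror the ones used by eval_all.py). Note that read_image in caption_infer.py
forms the image path by plain string concatenation of --data_dir and the file name.
```
# demo_infer.py -- hypothetical usage sketch, assuming a trained checkpoint in saving/:
#   python demo_infer.py --job_dir saving --data_dir ~/dataset/mscoco/all_images/ \
#     --inc_ckpt ~/workspace/ckpt/inception_v4.ckpt
import tensorflow as tf

from caption_infer import Infer


def main(_):
  # Restores the latest Generator checkpoint in --job_dir
  # (or a specific model.ckpt-XXXX prefix).
  infer = Infer()
  # The file name is appended directly to --data_dir; no '/' is inserted.
  for sentence, prob in infer.infer('COCO_val2014_000000000042.jpg'):
    print('%s (p=%f)' % (sentence, prob))


if __name__ == '__main__':
  tf.app.run()
```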
173 | 174 | [1]: https://drive.google.com/drive/folders/1ol8gLj6hYgluldvdj9XFKm16TCqOr7EE 175 | [2]: https://uofr-my.sharepoint.com/:f:/g/personal/yfeng23_ur_rochester_edu/EgDosCuY5t9HmlBfFsVyxdAB4xGf6aTJ0DmQlYWASdjYsw?e=Rhc4nS 176 | [detection_model]: http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_resnet_v2_atrous_oid_2018_01_28.tar.gz -------------------------------------------------------------------------------- /caption_infer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import math 6 | import sys 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | import tensorflow.contrib.slim as slim 11 | 12 | from config import TF_MODELS_PATH 13 | 14 | sys.path.append(TF_MODELS_PATH + '/research/im2txt/im2txt') 15 | sys.path.append(TF_MODELS_PATH + '/research/slim') 16 | from inference_utils import vocabulary 17 | from inference_utils.caption_generator import Caption 18 | from inference_utils.caption_generator import TopN 19 | from nets import inception_v4 20 | 21 | FLAGS = tf.flags.FLAGS 22 | 23 | tf.flags.DEFINE_string('job_dir', 'saving', 'job dir') 24 | 25 | tf.flags.DEFINE_integer('emb_dim', 512, 'emb dim') 26 | 27 | tf.flags.DEFINE_integer('mem_dim', 512, 'mem dim') 28 | 29 | tf.flags.DEFINE_integer('batch_size', 1, 'batch size') 30 | 31 | tf.flags.DEFINE_string("vocab_file", "data/word_counts.txt", 32 | "Text file containing the vocabulary.") 33 | 34 | tf.flags.DEFINE_integer('beam_size', 3, 'beam size') 35 | 36 | tf.flags.DEFINE_integer('max_caption_length', 20, 'beam size') 37 | 38 | tf.flags.DEFINE_float('length_normalization_factor', 0.0, 'l n f') 39 | 40 | tf.flags.DEFINE_string('data_dir', None, 'path to all images') 41 | 42 | tf.flags.DEFINE_string('inc_ckpt', None, 'InceptionV4 checkpoint path') 43 | 44 | 45 | def _tower_fn(im, is_training=False): 46 | with slim.arg_scope(inception_v4.inception_v4_arg_scope()): 47 | net, _ = inception_v4.inception_v4(im, None, is_training=False) 48 | net = tf.squeeze(net, [1, 2]) 49 | 50 | with tf.variable_scope('Generator'): 51 | feat = slim.fully_connected(net, FLAGS.mem_dim, activation_fn=None) 52 | feat = tf.nn.l2_normalize(feat, axis=1) 53 | 54 | embedding = tf.get_variable( 55 | name='embedding', 56 | shape=[FLAGS.vocab_size, FLAGS.emb_dim], 57 | initializer=tf.random_uniform_initializer(-0.08, 0.08)) 58 | softmax_w = tf.matrix_transpose(embedding) 59 | softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size]) 60 | 61 | cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.mem_dim) 62 | if is_training: 63 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, FLAGS.keep_prob, 64 | FLAGS.keep_prob) 65 | zero_state = cell.zero_state(FLAGS.batch_size, tf.float32) 66 | _, state = cell(feat, zero_state) 67 | init_state = state 68 | tf.get_variable_scope().reuse_variables() 69 | 70 | state_feed = tf.placeholder(dtype=tf.float32, 71 | shape=[None, sum(cell.state_size)], 72 | name="state_feed") 73 | state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1) 74 | input_feed = tf.placeholder(dtype=tf.int64, 75 | shape=[None], # batch_size 76 | name="input_feed") 77 | inputs = tf.nn.embedding_lookup(embedding, input_feed) 78 | out, state_tuple = cell(inputs, state_tuple) 79 | tf.concat(axis=1, values=state_tuple, name="state") 80 | 81 | logits = tf.nn.bias_add(tf.matmul(out, softmax_w), softmax_b) 82 | tower_pred = tf.nn.softmax(logits, name="softmax") 83 | 
return tf.concat(init_state, axis=1, name='initial_state') 84 | 85 | 86 | def read_image(im): 87 | """Reads an image.""" 88 | filename = tf.string_join([FLAGS.data_dir, im]) 89 | image = tf.read_file(filename) 90 | image = tf.image.decode_jpeg(image, 3) 91 | image = tf.image.convert_image_dtype(image, tf.float32) 92 | image = tf.image.resize_images(image, [346, 346]) 93 | image = image[23:-24, 23:-24] 94 | image = image * 2 - 1 95 | return image 96 | 97 | 98 | class Infer: 99 | 100 | def __init__(self, job_dir=FLAGS.job_dir): 101 | im_inp = tf.placeholder(tf.string, []) 102 | im = read_image(im_inp) 103 | im = tf.expand_dims(im, 0) 104 | initial_state_op = _tower_fn(im) 105 | 106 | vocab = vocabulary.Vocabulary(FLAGS.vocab_file) 107 | self.saver = tf.train.Saver(tf.trainable_variables('Generator')) 108 | 109 | self.im_inp = im_inp 110 | self.init_state = initial_state_op 111 | self.vocab = vocab 112 | config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)) 113 | self.sess = tf.Session(config=config) 114 | 115 | inc_saver = tf.train.Saver(tf.global_variables('InceptionV4')) 116 | self.restore_fn(job_dir) 117 | inc_saver.restore(self.sess, FLAGS.inc_ckpt) 118 | 119 | def restore_fn(self, checkpoint_path): 120 | if tf.gfile.IsDirectory(checkpoint_path): 121 | checkpoint_path = tf.train.latest_checkpoint(checkpoint_path) 122 | if checkpoint_path: 123 | self.saver.restore(self.sess, checkpoint_path) 124 | else: 125 | self.sess.run(tf.global_variables_initializer()) 126 | 127 | def infer(self, im): 128 | vocab = self.vocab 129 | sess = self.sess 130 | im_inp = self.im_inp 131 | initial_state_op = self.init_state 132 | 133 | initial_state = sess.run(initial_state_op, feed_dict={im_inp: im}) 134 | 135 | initial_beam = Caption( 136 | sentence=[vocab.start_id], 137 | state=initial_state[0], 138 | logprob=0.0, 139 | score=0.0, 140 | metadata=[""]) 141 | partial_captions = TopN(FLAGS.beam_size) 142 | partial_captions.push(initial_beam) 143 | complete_captions = TopN(FLAGS.beam_size) 144 | 145 | # Run beam search. 146 | for _ in range(FLAGS.max_caption_length - 1): 147 | partial_captions_list = partial_captions.extract() 148 | partial_captions.reset() 149 | input_feed = np.array([c.sentence[-1] for c in partial_captions_list]) 150 | state_feed = np.array([c.state for c in partial_captions_list]) 151 | 152 | softmax, new_states = sess.run( 153 | fetches=["Generator/softmax:0", "Generator/state:0"], 154 | feed_dict={ 155 | "Generator/input_feed:0": input_feed, 156 | "Generator/state_feed:0": state_feed, 157 | }) 158 | metadata = None 159 | 160 | for i, partial_caption in enumerate(partial_captions_list): 161 | word_probabilities = softmax[i] 162 | word_probabilities[-1] = 0 163 | state = new_states[i] 164 | # For this partial caption, get the beam_size most probable next words. 165 | words_and_probs = list(enumerate(word_probabilities)) 166 | words_and_probs.sort(key=lambda x: -x[1]) 167 | words_and_probs = words_and_probs[0:FLAGS.beam_size] 168 | # Each next word gives a new partial caption. 169 | for w, p in words_and_probs: 170 | if p < 1e-12: 171 | continue # Avoid log(0). 
172 | sentence = partial_caption.sentence + [w] 173 | logprob = partial_caption.logprob + math.log(p) 174 | score = logprob 175 | if metadata: 176 | metadata_list = partial_caption.metadata + [metadata[i]] 177 | else: 178 | metadata_list = None 179 | if w == vocab.end_id: 180 | if FLAGS.length_normalization_factor > 0: 181 | score /= len(sentence) ** FLAGS.length_normalization_factor 182 | beam = Caption(sentence, state, logprob, score, metadata_list) 183 | complete_captions.push(beam) 184 | else: 185 | beam = Caption(sentence, state, logprob, score, metadata_list) 186 | partial_captions.push(beam) 187 | if partial_captions.size() == 0: 188 | # We have run out of partial candidates; happens when beam_size = 1. 189 | break 190 | 191 | # If we have no complete captions then fall back to the partial captions. 192 | # But never output a mixture of complete and partial captions because a 193 | # partial caption could have a higher score than all the complete captions. 194 | if not complete_captions.size(): 195 | complete_captions = partial_captions 196 | 197 | captions = complete_captions.extract(sort=True) 198 | ret = [] 199 | for i, caption in enumerate(captions): 200 | # Ignore begin and end words. 201 | sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]] 202 | sentence = " ".join(sentence) 203 | # print(" %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob))) 204 | ret.append((sentence, math.exp(caption.logprob))) 205 | return ret 206 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from absl import flags 4 | 5 | flags.DEFINE_integer('vocab_size', 18669, 'vocab size') 6 | 7 | flags.DEFINE_integer('start_id', 0, 'SOS') 8 | 9 | flags.DEFINE_integer('end_id', 1, 'EOS') 10 | 11 | HOME = os.getenv('HOME') 12 | TF_MODELS_PATH = HOME + '/workspace/tf_models' 13 | COCO_PATH = HOME + '/workspace/coco-caption' 14 | 15 | NUM_DESCRIPTIONS = 2282457 16 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ###Note 2 | 3 | 1. There is a sos and eos in each sentence in sentence.tfrec. 4 | 5 | 2. There is no sos and eos in obj2sen_captions.tfrec. 6 | 7 | 3. In full model, the sentence auto-encoder input only contains eos. 8 | 9 | 4. We need to pad sos and eos in initialization/im_caption.py. 
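
A minimal illustration of the padding mentioned in note 4 (not from the repo; it
assumes a 1-D integer sentence tensor and the start_id/end_id values defined in
config.py):
```
import tensorflow as tf

START_ID, END_ID = 0, 1  # see config.py

def pad_sos_eos(sentence):
  """Prepends the sos id and appends the eos id to a 1-D integer sentence tensor."""
  sos = tf.constant([START_ID], dtype=sentence.dtype)
  eos = tf.constant([END_ID], dtype=sentence.dtype)
  return tf.concat([sos, sentence, eos], axis=0)
```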
10 | -------------------------------------------------------------------------------- /data/coco/coco.names: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /data/plural_words.json: -------------------------------------------------------------------------------- 1 | { 2 | "crocodile": "crocodiles", 3 | "snowmobile": "snowmobiles", 4 | "chair": "chairs", 5 | "dumbbell": "dumbbells", 6 | "milk": "milks", 7 | "grape": "grapes", 8 | "swan": "swans", 9 | "bike": "bikes", 10 | "dairy": "dairies", 11 | "trousers": "trousers", 12 | "melon": "melons", 13 | "dragonfly": "dragonflies", 14 | "woman": "women", 15 | "ladle": "ladles", 16 | "dessert": "desserts", 17 | "vase": "vases", 18 | "spoon": "spoons", 19 | "fan": "fans", 20 | "fireplace": "fireplaces", 21 | "segway": "segwaies", 22 | "parrot": "parrots", 23 | "asparagus": "asparaguses", 24 | "dinosaur": "dinosaurs", 25 | "saxophone": "saxophones", 26 | "mammal": "mammals", 27 | "rack": "racks", 28 | "carnivore": "carnivores", 29 | "bicycle": "bicycles", 30 | "tea": "teas", 31 | "lavender": "lavenders", 32 | "horn": "horns", 33 | "panda": "pandas", 34 | "clock": "clocks", 35 | "stool": "stools", 36 | "paddle": "paddles", 37 | "uniform": "uniforms", 38 | "unicycle": "unicycles", 39 | "supplies": "supplies", 40 | "bird": "birds", 41 | "body": "bodies", 42 | "leg": "legs", 43 | "kangaroo": "kangaroos", 44 | "mule": "mules", 45 | "bookcase": "bookcases", 46 | "sink": "sinks", 47 | "pancake": "pancakes", 48 | "goldfish": "goldfishes", 49 | "shellfish": "shellfishes", 50 | "box": "boxes", 51 | "boy": "boys", 52 | "tortoise": "tortoises", 53 | "drawers": "drawers", 54 | "hamster": "hamsters", 55 | "popcorn": "popcorns", 56 | "pillow": "pillows", 57 | "guacamole": "guacamoles", 58 | "straw": "straws", 59 | "radish": "radishes", 60 | "snowboard": "snowboards", 61 | "duck": "ducks", 62 | "tap": "taps", 63 | "eye": "eyes", 64 | "frog": "frogs", 65 | "camera": "cameras", 66 | "crab": "crabs", 67 | "vehicle": "vehicles", 68 | "limousine": "limousines", 69 | "ladybug": "ladybugs", 70 | "bidet": "bidets", 71 | "door": "doors", 72 | "porch": "porches", 73 | "envelope": "envelopes", 74 | "missile": "missiles", 75 | "phone": "'phones", 76 | "flag": "flags", 77 | "train": "trains", 78 | "stethoscope": "stethoscopes", 79 | "rabbit": "rabbits", 80 | "salad": "salads", 81 | "car": "cars", 82 | "cap": "caps", 83 | "worm": "worms", 84 | "cat": "cats", 85 | 
"can": "cans", 86 | "drill": "drills", 87 | "control": "controls", 88 | "shark": "sharks", 89 | "hamburger": "hamburgers", 90 | "dolphin": "dolphins", 91 | "parachute": "parachutes", 92 | "drawer": "drawers", 93 | "carrot": "carrots", 94 | "sunglasses": "sunglasses", 95 | "airplane": "airplanes", 96 | "woodpecker": "woodpeckers", 97 | "clothing": "clothings", 98 | "dress": "dresses", 99 | "magpie": "magpies", 100 | "machine": "machines", 101 | "lamp": "lamps", 102 | "animal": "animals", 103 | "elephant": "elephants", 104 | "cheetah": "cheetahs", 105 | "goat": "goat", 106 | "pizza": "pizzas", 107 | "plant": "plants", 108 | "sandwich": "sandwiches", 109 | "cupboard": "cupboards", 110 | "briefcase": "briefcases", 111 | "cocktail": "cocktails", 112 | "helmet": "helmets", 113 | "ladder": "ladders", 114 | "accessory": "accessories", 115 | "taco": "tacos", 116 | "cabbage": "cabbages", 117 | "castle": "castles", 118 | "man": "men", 119 | "tomato": "tomatoes", 120 | "oboe": "oboes", 121 | "light": "lights", 122 | "wheelchair": "wheelchairs", 123 | "volleyball": "volleyballs", 124 | "beehive": "beehives", 125 | "switch": "switches", 126 | "mango": "mangoes", 127 | "truck": "trucks", 128 | "chopsticks": "chopsticks", 129 | "ambulance": "ambulances", 130 | "necklace": "necklaces", 131 | "egg": "eggs", 132 | "hedgehog": "hedgehogs", 133 | "antelope": "antelopes", 134 | "office": "offices", 135 | "muffin": "muffins", 136 | "canary": "canaries", 137 | "jacuzzi": "jacuzzis", 138 | "artichoke": "artichokes", 139 | "oven": "ovens", 140 | "keyboard": "keyboards", 141 | "skirt": "skirts", 142 | "rhinoceros": "rhinoceroses", 143 | "monitor": "monitors", 144 | "accordion": "accordions", 145 | "willow": "willows", 146 | "houseplant": "houseplants", 147 | "backpack": "backpacks", 148 | "window": "windows", 149 | "fig": "figs", 150 | "orange": "oranges", 151 | "tiara": "tiaras", 152 | "coffee": "coffees", 153 | "sword": "swords", 154 | "food": "foods", 155 | "skates": "skates", 156 | "caterpillar": "caterpillars", 157 | "giraffe": "giraffes", 158 | "snake": "snakes", 159 | "foot": "feet", 160 | "bread": "breads", 161 | "arrow": "arrows", 162 | "invertebrates": "invertebrates", 163 | "tray": "trays", 164 | "glasses": "glasses", 165 | "turtle": "turtles", 166 | "house": "houses", 167 | "fish": "fishes", 168 | "fixture": "fixtures", 169 | "spider": "spiders", 170 | "waffle": "waffles", 171 | "goose": "geese", 172 | "zebra": "zebras", 173 | "beetle": "beetles", 174 | "girl": "girls", 175 | "harp": "harps", 176 | "flower": "flowers", 177 | "container": "containers", 178 | "lipstick": "lipsticks", 179 | "fountain": "fountains", 180 | "eagle": "eagles", 181 | "umbrella": "umbrellas", 182 | "utensil": "utensils", 183 | "ipod": "ipods", 184 | "cart": "carts", 185 | "cookie": "cookies", 186 | "van": "vans", 187 | "care": "cares", 188 | "dishwasher": "dishwashers", 189 | "seahorse": "seahorses", 190 | "headphones": "headphones", 191 | "skateboard": "skateboards", 192 | "ostrich": "ostriches", 193 | "butterflies": "butterflies", 194 | "footwear": "footwears", 195 | "lifejacket": "lifejackets", 196 | "scarf": "scarfs", 197 | "blind": "blinds", 198 | "cheese": "cheeses", 199 | "scissors": "scissors", 200 | "sockets": "sockets", 201 | "maple": "maples", 202 | "sheep": "sheep", 203 | "horse": "horses", 204 | "toy": "toys", 205 | "hydrant": "hydrants", 206 | "jellyfish": "jellyfishes", 207 | "banana": "bananas", 208 | "store": "stores", 209 | "beard": "beards", 210 | "starfish": "starfishes", 211 | "shorts": "shorts", 212 | 
"shelf": "shelves", 213 | "tool": "tools", 214 | "poster": "posters", 215 | "part": "parts", 216 | "rifle": "rifles", 217 | "wardrobe": "wardrobes", 218 | "sign": "signs", 219 | "stairs": "stairs", 220 | "television": "televisions", 221 | "tree": "trees", 222 | "bed": "beds", 223 | "bee": "bees", 224 | "shower": "showers", 225 | "croissant": "croissants", 226 | "sculpture": "sculptures", 227 | "seal": "seals", 228 | "glass": "glasses", 229 | "ant": "ants", 230 | "juice": "juices", 231 | "pastry": "pastries", 232 | "microphone": "microphones", 233 | "koala": "koalas", 234 | "couch": "couches", 235 | "bagel": "bagels", 236 | "equipment": "equipments", 237 | "miniskirt": "miniskirts", 238 | "banjo": "banjos", 239 | "lily": "lilies", 240 | "honeycomb": "honeycombs", 241 | "butterfly": "butterflies", 242 | "printer": "printers", 243 | "flute": "flutes", 244 | "invertebrate": "invertebrates", 245 | "basket": "baskets", 246 | "mouth": "mouths", 247 | "handbag": "handbags", 248 | "coin": "coins", 249 | "refrigerator": "refrigerators", 250 | "seafood": "seafoods", 251 | "dog": "dogs", 252 | "face": "faces", 253 | "pineapple": "pineapples", 254 | "barrel": "barrels", 255 | "wine": "wines", 256 | "taxi": "taxis", 257 | "loveseat": "loveseats", 258 | "kite": "kites", 259 | "sparrow": "sparrows", 260 | "stove": "stoves", 261 | "scoreboard": "scoreboards", 262 | "chicken": "chickens", 263 | "blender": "blenders", 264 | "snack": "snacks", 265 | "dice": "dices", 266 | "tire": "tires", 267 | "jay": "jays", 268 | "scorpion": "scorpions", 269 | "bust": "busts", 270 | "piano": "pianos", 271 | "otter": "otters", 272 | "pumpkin": "pumpkins", 273 | "porcupine": "porcupines", 274 | "plate": "plates", 275 | "handle": "handles", 276 | "cannon": "cannons", 277 | "watch": "watches", 278 | "bear": "bears", 279 | "beam": "beams", 280 | "watermelon": "watermelons", 281 | "handgun": "handguns", 282 | "bat": "bats", 283 | "jaguar": "jaguars", 284 | "organ": "organs", 285 | "pretzel": "pretzels", 286 | "perfume": "perfumes", 287 | "bag": "bags", 288 | "paper": "papers", 289 | "sushi": "sushis", 290 | "motorcycle": "motorcycles", 291 | "appliance": "appliances", 292 | "frame": "frames", 293 | "lion": "lions", 294 | "computer": "computers", 295 | "cattle": "cattle", 296 | "racket": "rackets", 297 | "arm": "arms", 298 | "closet": "closets", 299 | "sombrero": "sombreros", 300 | "swimwear": "swimwears", 301 | "mug": "mugs", 302 | "skyscraper": "skyscrapers", 303 | "wok": "woks", 304 | "tent": "tents", 305 | "cantaloupe": "cantaloupes", 306 | "suitcase": "suitcases", 307 | "guitar": "guitars", 308 | "aircraft": "aircraft", 309 | "toothbrush": "toothbrushes", 310 | "drum": "drums", 311 | "cup": "cups", 312 | "stretcher": "stretchers", 313 | "shirt": "shirts", 314 | "jeans": "jeans", 315 | "cabinetry": "cabinetries", 316 | "isopod": "isopods", 317 | "table": "tables", 318 | "trumpet": "trumpets", 319 | "boat": "boats", 320 | "coffeemaker": "coffeemakers", 321 | "turkey": "turkeys", 322 | "lighthouse": "lighthouses", 323 | "watercraft": "watercrafts", 324 | "gondola": "gondolas", 325 | "submarine": "submarines", 326 | "beer": "beers", 327 | "monkey": "monkeys", 328 | "shakers": "shakers", 329 | "helicopter": "helicopters", 330 | "cosmetics": "cosmetics", 331 | "squirrel": "squirrels", 332 | "mushroom": "mushrooms", 333 | "squash": "squashes", 334 | "tiger": "tigers", 335 | "bull": "bulls", 336 | "bulb": "bulbs", 337 | "teapot": "teapots", 338 | "tank": "tanks", 339 | "ski": "skies", 340 | "lizard": "lizards", 341 | "bathtub": 
"bathtubs", 342 | "canoe": "canoes", 343 | "calculator": "calculators", 344 | "squid": "squids", 345 | "telephone": "telephones", 346 | "nightstand": "nightstands", 347 | "heels": "heels", 348 | "violin": "violins", 349 | "mouse": "mouses", 350 | "bags": "bags", 351 | "fedora": "fedoras", 352 | "balloon": "balloons", 353 | "bowl": "bowls", 354 | "vegetable": "vegetables", 355 | "pan": "pans", 356 | "wheel": "wheels", 357 | "ball": "balls", 358 | "snail": "snails", 359 | "pomegranate": "pomegranates", 360 | "drink": "drink", 361 | "leopard": "leopards", 362 | "hand": "hands", 363 | "binoculars": "binoculars", 364 | "raccoon": "raccoons", 365 | "fruit": "fruits", 366 | "cucumber": "cucumbers", 367 | "whale": "whales", 368 | "broccoli": "broccolis", 369 | "burrito": "burritos", 370 | "surfboard": "surfboards", 371 | "shotgun": "shotguns", 372 | "weapon": "weapons", 373 | "person": "persons", 374 | "bottle": "bottles", 375 | "snowman": "snowmen", 376 | "rocket": "rockets", 377 | "camel": "camels", 378 | "laptop": "laptops", 379 | "goggles": "goggles", 380 | "apple": "apples", 381 | "flashlight": "flashlights", 382 | "sandal": "sandals", 383 | "sunflower": "sunflowers", 384 | "brassiere": "brassieres", 385 | "rose": "roses", 386 | "bench": "benches", 387 | "instrument": "instruments", 388 | "grapefruit": "grapefruits", 389 | "board": "boards", 390 | "hat": "hats", 391 | "raven": "ravens", 392 | "ruler": "rulers", 393 | "chainsaw": "chainsaws", 394 | "barge": "barges", 395 | "desk": "desks", 396 | "falcon": "falcons", 397 | "doughnut": "doughnuts", 398 | "furniture": "furnitures", 399 | "towel": "towels", 400 | "candy": "candies", 401 | "nose": "noses", 402 | "tower": "towers", 403 | "zucchini": "zucchinis", 404 | "doll": "dolls", 405 | "crown": "crowns", 406 | "hair": "hairs", 407 | "lobster": "lobsters", 408 | "glove": "gloves", 409 | "mirror": "mirrors", 410 | "billboard": "billboards", 411 | "candle": "candles", 412 | "saucer": "saucers", 413 | "scale": "scales", 414 | "kettle": "kettles", 415 | "fox": "foxes", 416 | "pen": "pens", 417 | "treadmill": "treadmills", 418 | "earrings": "earrings", 419 | "belt": "belts", 420 | "knife": "knives", 421 | "pepper": "peppers", 422 | "pasta": "pastas", 423 | "goods": "goods", 424 | "whiteboard": "whiteboards", 425 | "mixer": "mixers", 426 | "shrimp": "shrimps", 427 | "stand": "stands", 428 | "lantern": "lanterns", 429 | "processor": "processors", 430 | "owl": "owls", 431 | "reptile": "reptiles", 432 | "dagger": "daggers", 433 | "tableware": "tablewares", 434 | "syringe": "syringes", 435 | "alpaca": "alpacas", 436 | "snowplow": "snowplows", 437 | "sock": "socks", 438 | "harpsichord": "harpsichords", 439 | "oyster": "oysters", 440 | "suit": "suits", 441 | "fork": "forks", 442 | "head": "heads", 443 | "jug": "jugs", 444 | "bus": "buses", 445 | "pitcher": "pitchers", 446 | "tripod": "tripods", 447 | "cabinet": "cabinets", 448 | "curtain": "curtains", 449 | "ear": "ears", 450 | "penguin": "penguins", 451 | "skull": "skulls", 452 | "cooker": "cookers", 453 | "fries": "fries", 454 | "strawberry": "strawberries", 455 | "centipede": "centipedes", 456 | "cello": "'cellos", 457 | "lynx": "lynxes", 458 | "cake": "cakes", 459 | "toilet": "toilets", 460 | "deer": "deers", 461 | "pig": "pigs", 462 | "trombone": "trombones", 463 | "cream": "creams", 464 | "lemon": "lemons", 465 | "peach": "peaches", 466 | "countertop": "countertops", 467 | "boot": "boots", 468 | "book": "books", 469 | "tie": "ties", 470 | "hippopotamus": "hippopotamuses", 471 | "tart": "tarts", 472 | 
"football": "footballs", 473 | "pear": "pears", 474 | "insect": "insects", 475 | "pool": "pools", 476 | "building": "buildings", 477 | "coat": "coats", 478 | "potato": "potatoes", 479 | "flowerpot": "flowerpots", 480 | "jacket": "jackets", 481 | "platter": "platters" 482 | } -------------------------------------------------------------------------------- /eval_all.py: -------------------------------------------------------------------------------- 1 | """Evaluates the performance of all the checkpoints on validation set.""" 2 | import glob 3 | import json 4 | import multiprocessing 5 | import os 6 | import sys 7 | 8 | from absl import app 9 | from absl import flags 10 | 11 | from config import COCO_PATH 12 | 13 | flags.DEFINE_integer('threads', 1, 'num of threads') 14 | 15 | from caption_infer import Infer 16 | 17 | sys.path.insert(0, COCO_PATH) 18 | from pycocotools.coco import COCO 19 | from pycocoevalcap.eval import COCOEvalCap 20 | 21 | FLAGS = flags.FLAGS 22 | 23 | 24 | def initializer(): 25 | """Decides which GPU is assigned to a worker. 26 | 27 | If your GPU memory is large enough, you may put several workers in one GPU. 28 | """ 29 | global tf, no_gpu 30 | no_gpu = False 31 | devices = os.getenv('CUDA_VISIBLE_DEVICES') 32 | if devices is None: 33 | print('Please set CUDA_VISIBLE_DEVICES') 34 | no_gpu = True 35 | return 36 | devices = devices.split(',') 37 | if len(devices) == 0: 38 | print('You should assign some gpus to this program.') 39 | no_gpu = True 40 | return 41 | current = multiprocessing.current_process() 42 | id = (current._identity[0] - 1) % len(devices) 43 | os.environ['CUDA_VISIBLE_DEVICES'] = devices[id] 44 | import tensorflow as tf 45 | 46 | 47 | def run(inp): 48 | if no_gpu: 49 | return 50 | out = FLAGS.job_dir + '/val_%s.json' % inp 51 | if not os.path.exists(out): 52 | with open(COCO_PATH + '/annotations/captions_val2014.json') as g: 53 | caption_data = json.load(g) 54 | name_to_id = [(x['file_name'], x['id']) for x in caption_data['images']] 55 | name_to_id = dict(name_to_id) 56 | 57 | ret = [] 58 | with tf.Graph().as_default(): 59 | infer = Infer(job_dir='%s/model.ckpt-%s' % (FLAGS.job_dir, inp)) 60 | with open('data/coco_val.txt', 'r') as g: 61 | for name in g: 62 | name = name.strip() 63 | sentences = infer.infer(name) 64 | cur = {} 65 | cur['image_id'] = name_to_id[name] 66 | cur['caption'] = sentences[0][0] 67 | ret.append(cur) 68 | with open(out, 'w') as g: 69 | json.dump(ret, g) 70 | 71 | coco = COCO(COCO_PATH + '/annotations/captions_val2014.json') 72 | cocoRes = coco.loadRes(out) 73 | # create cocoEval object by taking coco and cocoRes 74 | cocoEval = COCOEvalCap(coco, cocoRes) 75 | # evaluate on a subset of images by setting 76 | # cocoEval.params['image_id'] = cocoRes.getImgIds() 77 | # please remove this line when evaluating the full validation set 78 | cocoEval.params['image_id'] = cocoRes.getImgIds() 79 | # evaluate results 80 | cocoEval.evaluate() 81 | return (inp, cocoEval.eval['CIDEr'], cocoEval.eval['METEOR'], 82 | cocoEval.eval['Bleu_4'], cocoEval.eval['Bleu_3'], 83 | cocoEval.eval['Bleu_2']) 84 | 85 | 86 | def main(_): 87 | results = glob.glob(FLAGS.job_dir + '/model.ckpt-*') 88 | results = [os.path.splitext(i)[0] for i in results] 89 | results = set(results) 90 | gs_list = [i.split('-')[-1] for i in results] 91 | 92 | pool = multiprocessing.Pool(FLAGS.threads, initializer) 93 | ret = pool.map(run, gs_list) 94 | pool.close() 95 | pool.join() 96 | if not ret or ret[0] is None: 97 | return 98 | 99 | ret = sorted(ret, key=lambda x: x[1]) 100 | 
with open(FLAGS.job_dir + '/cider.json', 'w') as f: 101 | json.dump(ret, f) 102 | ret = sorted(ret, key=lambda x: x[2]) 103 | with open(FLAGS.job_dir + '/meteor.json', 'w') as f: 104 | json.dump(ret, f) 105 | ret = sorted(ret, key=lambda x: x[3]) 106 | with open(FLAGS.job_dir + '/b4.json', 'w') as f: 107 | json.dump(ret, f) 108 | ret = sorted(ret, key=lambda x: x[4]) 109 | with open(FLAGS.job_dir + '/b3.json', 'w') as f: 110 | json.dump(ret, f) 111 | ret = sorted(ret, key=lambda x: x[5]) 112 | with open(FLAGS.job_dir + '/b2.json', 'w') as f: 113 | json.dump(ret, f) 114 | ret = sorted(ret, key=lambda x: x[3] + x[4]) 115 | with open(FLAGS.job_dir + '/b34.json', 'w') as f: 116 | json.dump(ret, f) 117 | 118 | 119 | if __name__ == '__main__': 120 | app.run(main) 121 | -------------------------------------------------------------------------------- /im_caption_full.py: -------------------------------------------------------------------------------- 1 | """Train the full model. 2 | 3 | python im_caption_full.py --multi_gpu --batch_size 512 --save_checkpoint_steps\ 4 | 1000 --gen_lr 0.001 --dis_lr 0.001 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import functools 12 | import os 13 | import sys 14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | import tensorflow.contrib.gan as tfgan 18 | import tensorflow.contrib.slim as slim 19 | from tensorflow.contrib.framework import nest 20 | from tensorflow.contrib.gan.python.losses.python.losses_impl import modified_discriminator_loss 21 | from tensorflow.contrib.gan.python.train import get_sequential_train_hooks 22 | 23 | from config import TF_MODELS_PATH 24 | from input_pipeline import input_fn 25 | from misc_fn import crop_sentence 26 | from misc_fn import get_len 27 | from misc_fn import obj_rewards 28 | from misc_fn import transform_grads_fn 29 | from misc_fn import validate_batch_size_for_multi_gpu 30 | from misc_fn import variable_summaries 31 | 32 | sys.path.append(TF_MODELS_PATH + '/research/slim') 33 | from nets import inception_v4 34 | 35 | tf.logging.set_verbosity(tf.logging.INFO) 36 | 37 | tf.flags.DEFINE_integer('intra_op_parallelism_threads', 0, 'Number of threads') 38 | 39 | tf.flags.DEFINE_integer('inter_op_parallelism_threads', 0, 'Number of threads') 40 | 41 | tf.flags.DEFINE_bool('multi_gpu', False, 'use multi gpus') 42 | 43 | tf.flags.DEFINE_integer('emb_dim', 512, 'emb dim') 44 | 45 | tf.flags.DEFINE_integer('mem_dim', 512, 'mem dim') 46 | 47 | tf.flags.DEFINE_float('keep_prob', 0.8, 'keep prob') 48 | 49 | tf.flags.DEFINE_string('job_dir', 'saving', 'job dir') 50 | 51 | tf.flags.DEFINE_integer('batch_size', 64, 'batch size') 52 | 53 | tf.flags.DEFINE_integer('max_steps', 1000000, 'maximum training steps') 54 | 55 | tf.flags.DEFINE_float('gen_lr', 0.0001, 'learning rate') 56 | 57 | tf.flags.DEFINE_float('dis_lr', 0.0001, 'learning rate') 58 | 59 | tf.flags.DEFINE_integer('save_summary_steps', 100, 'save summary steps') 60 | 61 | tf.flags.DEFINE_integer('save_checkpoint_steps', 5000, 'save ckpt') 62 | 63 | tf.flags.DEFINE_integer('max_caption_length', 20, 'max len') 64 | 65 | tf.flags.DEFINE_bool('wass', False, 'use wass') 66 | 67 | tf.flags.DEFINE_bool('use_pool', False, 'use pool') 68 | 69 | tf.flags.DEFINE_integer('pool_size', 512, 'pool size') 70 | 71 | tf.flags.DEFINE_string('inc_ckpt', None, 'path to InceptionV4 checkpoint') 72 | 73 | tf.flags.DEFINE_string('imcap_ckpt', None, 'initialization checkpoint') 74 | 75 | 
tf.flags.DEFINE_string('sae_ckpt', None, 'initialization checkpoint') 76 | 77 | tf.flags.DEFINE_float('w_obj', 10, 'object weight') 78 | 79 | tf.flags.DEFINE_float('w_mse', 100, 'object weight') 80 | 81 | FLAGS = tf.flags.FLAGS 82 | 83 | 84 | def generator(inputs, is_training=True): 85 | """The sentence generator.""" 86 | embedding = tf.get_variable( 87 | name='embedding', 88 | shape=[FLAGS.vocab_size, FLAGS.emb_dim], 89 | initializer=tf.random_uniform_initializer(-0.08, 0.08)) 90 | softmax_w = tf.matrix_transpose(embedding) 91 | softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size]) 92 | 93 | inputs = inputs[0] 94 | feat = slim.fully_connected(inputs, FLAGS.mem_dim, activation_fn=None) 95 | feat = tf.nn.l2_normalize(feat, axis=1) 96 | 97 | batch_size = tf.shape(feat)[0] 98 | cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.mem_dim) 99 | if is_training: 100 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, FLAGS.keep_prob, FLAGS.keep_prob) 101 | zero_state = cell.zero_state(batch_size, tf.float32) 102 | 103 | sequence, logits, log_probs, rnn_outs = [], [], [], [] 104 | 105 | _, state = cell(feat, zero_state) 106 | state_bl = state 107 | tf.get_variable_scope().reuse_variables() 108 | for t in range(FLAGS.max_caption_length): 109 | if t == 0: 110 | rnn_inp = tf.zeros([batch_size], tf.int32) + FLAGS.start_id 111 | rnn_inp = tf.nn.embedding_lookup(embedding, rnn_inp) 112 | rnn_out, state = cell(rnn_inp, state) 113 | rnn_outs.append(rnn_out) 114 | logit = tf.nn.bias_add(tf.matmul(rnn_out, softmax_w), softmax_b) 115 | categorical = tf.contrib.distributions.Categorical(logits=logit) 116 | fake = categorical.sample() 117 | log_prob = categorical.log_prob(fake) 118 | sequence.append(fake) 119 | log_probs.append(log_prob) 120 | logits.append(logit) 121 | rnn_inp = fake 122 | sequence = tf.stack(sequence, axis=1) 123 | log_probs = tf.stack(log_probs, axis=1) 124 | logits = tf.stack(logits, axis=1) 125 | 126 | # Computes the baseline for self-critic. 127 | baseline = [] 128 | state = state_bl 129 | for t in range(FLAGS.max_caption_length): 130 | if t == 0: 131 | rnn_inp = tf.zeros([batch_size], tf.int32) + FLAGS.start_id 132 | rnn_inp = tf.nn.embedding_lookup(embedding, rnn_inp) 133 | rnn_out, state = cell(rnn_inp, state) 134 | logit = tf.nn.bias_add(tf.matmul(rnn_out, softmax_w), softmax_b) 135 | fake = tf.argmax(logit, axis=1, output_type=tf.int32) 136 | baseline.append(fake) 137 | rnn_inp = fake 138 | baseline = tf.stack(baseline, axis=1) 139 | 140 | return sequence, logits, log_probs, baseline, feat 141 | 142 | 143 | def discriminator(generated_data, generator_inputs, is_training=True): 144 | """The discriminator.""" 145 | if type(generated_data) is tuple: 146 | # When the sentences are generated, we need to compute their length. 147 | sequence = generated_data[0] 148 | length = get_len(sequence, FLAGS.end_id) 149 | else: 150 | # We already know the length of the sentences from the input pipeline. 
151 | sequence = generated_data 152 | length = generator_inputs[-1] 153 | embedding = tf.get_variable( 154 | name='embedding', 155 | shape=[FLAGS.vocab_size, FLAGS.emb_dim], 156 | initializer=tf.random_uniform_initializer(-0.08, 0.08)) 157 | cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.mem_dim) 158 | if is_training: 159 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, FLAGS.keep_prob, FLAGS.keep_prob) 160 | 161 | rnn_inputs = tf.nn.embedding_lookup(embedding, sequence) 162 | rnn_out, state = tf.nn.dynamic_rnn(cell, rnn_inputs, length, dtype=tf.float32) 163 | pred = slim.fully_connected(rnn_out, 1, activation_fn=None, scope='fc') 164 | pred = tf.squeeze(pred, 2) 165 | mask = tf.sequence_mask(length, tf.shape(sequence)[1]) 166 | 167 | idx = tf.transpose(tf.stack([tf.range(tf.shape(length)[0]), length - 1])) 168 | state_h = tf.gather_nd(rnn_out, idx) 169 | feat = slim.fully_connected(state_h, FLAGS.mem_dim, activation_fn=None, 170 | scope='recon') 171 | feat = tf.nn.l2_normalize(feat, axis=1) 172 | return pred, mask, feat 173 | 174 | 175 | def rl_loss(gan_model, gan_loss, classes, scores, num, add_summaries): 176 | """Reinforcement learning loss.""" 177 | eps = 1e-7 178 | gamma = 0.9 179 | sequence, _, log_probs, seq_bl, pca = gan_model.generated_data 180 | 181 | with tf.variable_scope(gan_model.discriminator_scope, reuse=True): 182 | baselines, _, feat_bl = discriminator((seq_bl, None, None, None, pca), None) 183 | baselines, feat_bl = nest.map_structure( 184 | tf.stop_gradient, (baselines, feat_bl)) 185 | 186 | logits, mask, feat = gan_model.discriminator_gen_outputs 187 | 188 | dist = tf.reduce_mean(tf.squared_difference(pca, feat), axis=1, 189 | keepdims=True) * FLAGS.w_mse 190 | loss_mse = tf.reduce_mean(dist) 191 | l_rewards = -dist 192 | l_rewards = tf.tile(l_rewards, [1, sequence.shape[1]]) 193 | l_rewards = tf.where(mask, l_rewards, tf.zeros_like(l_rewards)) 194 | l_rewards_mat = l_rewards 195 | l_rewards = tf.unstack(l_rewards, axis=1) 196 | 197 | dis_predictions = tf.nn.sigmoid(logits) 198 | d_rewards = tf.log(dis_predictions + eps) 199 | o_rewards = obj_rewards(sequence, mask, classes, scores, num) * FLAGS.w_obj 200 | rewards = d_rewards + o_rewards 201 | rewards = tf.where(mask, rewards, tf.zeros_like(rewards)) 202 | 203 | l_bl = -tf.reduce_mean(tf.squared_difference(pca, feat_bl), axis=1, 204 | keepdims=True) * FLAGS.w_mse 205 | l_bl = tf.tile(l_bl, [1, seq_bl.shape[1]]) 206 | l_bl = tf.where(mask, l_bl, tf.zeros_like(l_bl)) 207 | l_bl = tf.unstack(l_bl, axis=1) 208 | baselines = tf.nn.sigmoid(baselines) 209 | baselines = tf.log(baselines + eps) 210 | baselines += obj_rewards(seq_bl, mask, classes, scores, num) * FLAGS.w_obj 211 | baselines = tf.where(mask, baselines, tf.zeros_like(baselines)) 212 | 213 | log_prob_list = tf.unstack(log_probs, axis=1) 214 | rewards_list = tf.unstack(rewards, axis=1) 215 | cumulative_rewards = [] 216 | baseline_list = tf.unstack(baselines, axis=1) 217 | cumulative_baseline = [] 218 | for t in range(FLAGS.max_caption_length): 219 | cum_value = l_rewards[t] 220 | for s in range(t, FLAGS.max_caption_length): 221 | cum_value += np.power(gamma, s - t) * rewards_list[s] 222 | cumulative_rewards.append(cum_value) 223 | 224 | cum_value = l_bl[t] 225 | for s in range(t, FLAGS.max_caption_length): 226 | cum_value += np.power(gamma, s - t) * baseline_list[s] 227 | cumulative_baseline.append(cum_value) 228 | c_rewards = tf.stack(cumulative_rewards, axis=1) 229 | c_baseline = tf.stack(cumulative_baseline, axis=1) 230 | 231 | advantages = [] 232 | final_gen_objective 
= [] 233 | for t in range(FLAGS.max_caption_length): 234 | log_probability = log_prob_list[t] 235 | cum_advantage = cumulative_rewards[t] - cumulative_baseline[t] 236 | cum_advantage = tf.clip_by_value(cum_advantage, -5.0, 5.0) 237 | advantages.append(cum_advantage) 238 | final_gen_objective.append( 239 | log_probability * tf.stop_gradient(cum_advantage)) 240 | final_gen_objective = tf.stack(final_gen_objective, axis=1) 241 | final_gen_objective = tf.losses.compute_weighted_loss(final_gen_objective, 242 | tf.to_float(mask)) 243 | final_gen_objective = -final_gen_objective 244 | advantages = tf.stack(advantages, axis=1) 245 | 246 | if add_summaries: 247 | tf.summary.scalar('losses/mse', loss_mse) 248 | tf.summary.scalar('losses/gen_obj', final_gen_objective) 249 | with tf.name_scope('rewards'): 250 | variable_summaries(c_rewards, mask, 'rewards') 251 | 252 | with tf.name_scope('advantages'): 253 | variable_summaries(advantages, mask, 'advantages') 254 | 255 | with tf.name_scope('baselines'): 256 | variable_summaries(c_baseline, mask, 'baselines') 257 | 258 | with tf.name_scope('log_probs'): 259 | variable_summaries(log_probs, mask, 'log_probs') 260 | 261 | with tf.name_scope('d_rewards'): 262 | variable_summaries(d_rewards, mask, 'd_rewards') 263 | 264 | with tf.name_scope('l_rewards'): 265 | variable_summaries(l_rewards_mat, mask, 'l_rewards') 266 | 267 | with tf.name_scope('o_rewards'): 268 | variable_summaries(o_rewards, mask, 'o_rewards') 269 | o_rewards = tf.where(mask, o_rewards, tf.zeros_like(o_rewards)) 270 | minimum = tf.minimum(tf.reduce_min(o_rewards, axis=1, keepdims=True), 0.0) 271 | o_rewards = tf.reduce_sum( 272 | tf.to_float(tf.logical_and(o_rewards > minimum, mask)), axis=1) 273 | o_rewards = tf.reduce_mean(o_rewards) 274 | tf.summary.scalar('mean_found_obj', o_rewards) 275 | 276 | return gan_loss._replace(generator_loss=final_gen_objective, 277 | discriminator_loss=gan_loss.discriminator_loss + loss_mse) 278 | 279 | 280 | def sentence_ae(gan_model, features, labels, add_summaries=True): 281 | """Sentence auto-encoder.""" 282 | with tf.variable_scope(gan_model.discriminator_scope, reuse=True): 283 | feat = discriminator(features['key'], [None, features['lk']])[2] 284 | with tf.variable_scope(gan_model.generator_scope, reuse=True): 285 | embedding = tf.get_variable( 286 | name='embedding', 287 | shape=[FLAGS.vocab_size, FLAGS.emb_dim], 288 | initializer=tf.random_uniform_initializer(-0.08, 0.08)) 289 | softmax_w = tf.matrix_transpose(embedding) 290 | softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size]) 291 | 292 | sentence, ls = labels['sentence'], labels['len'] 293 | targets = sentence[:, 1:] 294 | sentence = sentence[:, :-1] 295 | ls -= 1 296 | sentence = tf.nn.embedding_lookup(embedding, sentence) 297 | 298 | batch_size = tf.shape(feat)[0] 299 | cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.mem_dim) 300 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, FLAGS.keep_prob, FLAGS.keep_prob) 301 | zero_state = cell.zero_state(batch_size, tf.float32) 302 | _, state = cell(feat, zero_state) 303 | tf.get_variable_scope().reuse_variables() 304 | out, state = tf.nn.dynamic_rnn(cell, sentence, ls, state) 305 | out = tf.reshape(out, [-1, FLAGS.mem_dim]) 306 | logits = tf.nn.bias_add(tf.matmul(out, softmax_w), softmax_b) 307 | logits = tf.reshape(logits, [batch_size, -1, FLAGS.vocab_size]) 308 | 309 | mask = tf.sequence_mask(ls, tf.shape(sentence)[1]) 310 | targets = tf.boolean_mask(targets, mask) 311 | logits = tf.boolean_mask(logits, mask) 312 | loss = 
tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, 313 | logits=logits) 314 | loss = tf.reduce_mean(loss) 315 | if add_summaries: 316 | tf.summary.scalar('losses/sentence_ae', loss) 317 | return loss 318 | 319 | 320 | def model_fn(features, labels, mode, params): 321 | """The full unsupervised captioning model.""" 322 | is_chief = not tf.get_variable_scope().reuse 323 | 324 | with slim.arg_scope(inception_v4.inception_v4_arg_scope()): 325 | net, _ = inception_v4.inception_v4(features['im'], None, is_training=False) 326 | net = tf.squeeze(net, [1, 2]) 327 | inc_saver = tf.train.Saver(tf.global_variables('InceptionV4')) 328 | 329 | gan_model = tfgan.gan_model( 330 | generator_fn=generator, 331 | discriminator_fn=discriminator, 332 | real_data=labels['sentence'][:, 1:], 333 | generator_inputs=(net, labels['len'] - 1), 334 | check_shapes=False) 335 | 336 | if is_chief: 337 | for variable in tf.trainable_variables(): 338 | tf.summary.histogram(variable.op.name, variable) 339 | tf.summary.histogram('logits/gen_logits', 340 | gan_model.discriminator_gen_outputs[0]) 341 | tf.summary.histogram('logits/real_logits', 342 | gan_model.discriminator_real_outputs[0]) 343 | 344 | def gen_loss_fn(gan_model, add_summaries): 345 | return 0 346 | 347 | def dis_loss_fn(gan_model, add_summaries): 348 | discriminator_real_outputs = gan_model.discriminator_real_outputs 349 | discriminator_gen_outputs = gan_model.discriminator_gen_outputs 350 | real_logits = tf.boolean_mask(discriminator_real_outputs[0], 351 | discriminator_real_outputs[1]) 352 | gen_logits = tf.boolean_mask(discriminator_gen_outputs[0], 353 | discriminator_gen_outputs[1]) 354 | return modified_discriminator_loss(real_logits, gen_logits, 355 | add_summaries=add_summaries) 356 | 357 | with tf.name_scope('losses'): 358 | pool_fn = functools.partial(tfgan.features.tensor_pool, 359 | pool_size=FLAGS.pool_size) 360 | gan_loss = tfgan.gan_loss( 361 | gan_model, 362 | generator_loss_fn=gen_loss_fn, 363 | discriminator_loss_fn=dis_loss_fn, 364 | gradient_penalty_weight=10 if FLAGS.wass else 0, 365 | tensor_pool_fn=pool_fn if FLAGS.use_pool else None, 366 | add_summaries=is_chief) 367 | if is_chief: 368 | tfgan.eval.add_regularization_loss_summaries(gan_model) 369 | gan_loss = rl_loss(gan_model, gan_loss, features['classes'], 370 | features['scores'], features['num'], 371 | add_summaries=is_chief) 372 | sen_ae_loss = sentence_ae(gan_model, features, labels, is_chief) 373 | loss = gan_loss.generator_loss + gan_loss.discriminator_loss + sen_ae_loss 374 | gan_loss = gan_loss._replace( 375 | generator_loss=gan_loss.generator_loss + sen_ae_loss) 376 | 377 | with tf.name_scope('train'): 378 | gen_opt = tf.train.AdamOptimizer(params.gen_lr, 0.5) 379 | dis_opt = tf.train.AdamOptimizer(params.dis_lr, 0.5) 380 | if params.multi_gpu: 381 | gen_opt = tf.contrib.estimator.TowerOptimizer(gen_opt) 382 | dis_opt = tf.contrib.estimator.TowerOptimizer(dis_opt) 383 | train_ops = tfgan.gan_train_ops( 384 | gan_model, 385 | gan_loss, 386 | generator_optimizer=gen_opt, 387 | discriminator_optimizer=dis_opt, 388 | transform_grads_fn=transform_grads_fn, 389 | summarize_gradients=is_chief, 390 | check_for_unused_update_ops=not FLAGS.use_pool, 391 | aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N) 392 | train_op = train_ops.global_step_inc_op 393 | train_hooks = get_sequential_train_hooks()(train_ops) 394 | 395 | # Summary the generated caption on the fly. 
396 | if is_chief: 397 | with open('data/word_counts.txt', 'r') as f: 398 | dic = list(f) 399 | dic = [i.split()[0] for i in dic] 400 | dic.append('') 401 | dic = tf.convert_to_tensor(dic) 402 | sentence = crop_sentence(gan_model.generated_data[0][0], FLAGS.end_id) 403 | sentence = tf.gather(dic, sentence) 404 | real = crop_sentence(gan_model.real_data[0], FLAGS.end_id) 405 | real = tf.gather(dic, real) 406 | train_hooks.append( 407 | tf.train.LoggingTensorHook({'fake': sentence, 'real': real}, 408 | every_n_iter=100)) 409 | tf.summary.text('fake', sentence) 410 | tf.summary.image('im', features['im'][None, 0]) 411 | 412 | gen_saver = tf.train.Saver(tf.trainable_variables('Generator')) 413 | dis_var = [] 414 | dis_var.extend(tf.trainable_variables('Discriminator/rnn')) 415 | dis_var.extend(tf.trainable_variables('Discriminator/embedding')) 416 | dis_var.extend(tf.trainable_variables('Discriminator/fc')) 417 | dis_saver = tf.train.Saver(dis_var) 418 | 419 | def init_fn(scaffold, session): 420 | inc_saver.restore(session, FLAGS.inc_ckpt) 421 | if FLAGS.imcap_ckpt: 422 | gen_saver.restore(session, FLAGS.imcap_ckpt) 423 | if FLAGS.sae_ckpt: 424 | dis_saver.restore(session, FLAGS.sae_ckpt) 425 | 426 | scaffold = tf.train.Scaffold(init_fn=init_fn) 427 | 428 | return tf.estimator.EstimatorSpec( 429 | mode=mode, 430 | loss=loss, 431 | train_op=train_op, 432 | scaffold=scaffold, 433 | training_hooks=train_hooks) 434 | 435 | 436 | def main(_): 437 | os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' 438 | 439 | if FLAGS.multi_gpu: 440 | validate_batch_size_for_multi_gpu(FLAGS.batch_size) 441 | model_function = tf.contrib.estimator.replicate_model_fn( 442 | model_fn, 443 | loss_reduction=tf.losses.Reduction.MEAN) 444 | else: 445 | model_function = model_fn 446 | 447 | sess_config = tf.ConfigProto( 448 | allow_soft_placement=True, 449 | intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads, 450 | inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads, 451 | gpu_options=tf.GPUOptions(allow_growth=True)) 452 | 453 | run_config = tf.estimator.RunConfig( 454 | session_config=sess_config, 455 | save_checkpoints_steps=FLAGS.save_checkpoint_steps, 456 | save_summary_steps=FLAGS.save_summary_steps, 457 | keep_checkpoint_max=100) 458 | 459 | train_input_fn = functools.partial(input_fn, batch_size=FLAGS.batch_size) 460 | 461 | estimator = tf.estimator.Estimator( 462 | model_fn=model_function, 463 | model_dir=FLAGS.job_dir, 464 | config=run_config, 465 | params=FLAGS) 466 | 467 | estimator.train(train_input_fn, max_steps=FLAGS.max_steps) 468 | 469 | 470 | if __name__ == '__main__': 471 | tf.app.run() 472 | -------------------------------------------------------------------------------- /initialization/eval_obj2sen.py: -------------------------------------------------------------------------------- 1 | """Evaluates the performance of all the checkpoints on validation set.""" 2 | import glob 3 | import json 4 | import multiprocessing 5 | import os 6 | import sys 7 | 8 | import tensorflow as tf 9 | from absl import app 10 | from absl import flags 11 | 12 | from config import COCO_PATH 13 | 14 | flags.DEFINE_integer('threads', 1, 'num of threads') 15 | 16 | from sentence_infer import Infer 17 | 18 | sys.path.insert(0, COCO_PATH) 19 | from pycocotools.coco import COCO 20 | from pycocoevalcap.eval import COCOEvalCap 21 | 22 | FLAGS = flags.FLAGS 23 | 24 | 25 | def initializer(): 26 | """Decides which GPU is assigned to a worker. 
27 | 28 | If your GPU memory is large enough, you may put several workers in one GPU. 29 | """ 30 | devices = os.getenv('CUDA_VISIBLE_DEVICES') 31 | if devices is None: 32 | devices = [] 33 | else: 34 | devices = devices.split(',') 35 | if len(devices) == 0: 36 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 37 | else: 38 | current = multiprocessing.current_process() 39 | id = (current._identity[0] - 1) % len(devices) 40 | os.environ['CUDA_VISIBLE_DEVICES'] = devices[id] 41 | 42 | 43 | def parse_image(serialized): 44 | """Parses a tensorflow.SequenceExample into an image and detected objects. 45 | 46 | Args: 47 | serialized: A scalar string Tensor; a single serialized SequenceExample. 48 | 49 | Returns: 50 | name: A scalar string Tensor containing the image name. 51 | classes: A 1-D int64 Tensor containing the detected objects. 52 | scores: A 1-D float32 Tensor containing the detection scores. 53 | """ 54 | context, sequence = tf.parse_single_sequence_example( 55 | serialized, 56 | context_features={ 57 | 'image/name': tf.FixedLenFeature([], dtype=tf.string) 58 | }, 59 | sequence_features={ 60 | 'classes': tf.FixedLenSequenceFeature([], dtype=tf.int64), 61 | 'scores': tf.FixedLenSequenceFeature([], dtype=tf.float32), 62 | }) 63 | 64 | name = context['image/name'] 65 | classes = tf.to_int32(sequence['classes']) 66 | scores = sequence['scores'] 67 | return name, classes, scores 68 | 69 | 70 | def run(inp): 71 | out = FLAGS.job_dir + '/val_%s.json' % inp 72 | if not os.path.exists(out): 73 | with open(COCO_PATH + '/annotations/captions_val2014.json') as g: 74 | caption_data = json.load(g) 75 | name_to_id = [(x['file_name'], x['id']) for x in caption_data['images']] 76 | name_to_id = dict(name_to_id) 77 | 78 | ret = [] 79 | with tf.Graph().as_default(), tf.Session() as sess: 80 | example = tf.placeholder(tf.string, []) 81 | name_op, class_op, _ = parse_image(example) 82 | infer = Infer(job_dir='%s/model.ckpt-%s' % (FLAGS.job_dir, inp)) 83 | for i in tf.io.tf_record_iterator('data/image_val.tfrec'): 84 | name, classes = sess.run([name_op, class_op], feed_dict={example: i}) 85 | sentences = infer.infer(classes[::-1]) 86 | cur = {} 87 | cur['image_id'] = name_to_id[name] 88 | cur['caption'] = sentences[0][0] 89 | ret.append(cur) 90 | with open(out, 'w') as g: 91 | json.dump(ret, g) 92 | 93 | coco = COCO(COCO_PATH + '/annotations/captions_val2014.json') 94 | cocoRes = coco.loadRes(out) 95 | # create cocoEval object by taking coco and cocoRes 96 | cocoEval = COCOEvalCap(coco, cocoRes) 97 | # evaluate on a subset of images by setting 98 | # cocoEval.params['image_id'] = cocoRes.getImgIds() 99 | # please remove this line when evaluating the full validation set 100 | cocoEval.params['image_id'] = cocoRes.getImgIds() 101 | # evaluate results 102 | cocoEval.evaluate() 103 | return (inp, cocoEval.eval['CIDEr'], cocoEval.eval['METEOR'], 104 | cocoEval.eval['Bleu_4'], cocoEval.eval['Bleu_3'], 105 | cocoEval.eval['Bleu_2']) 106 | 107 | 108 | def main(_): 109 | results = glob.glob(FLAGS.job_dir + '/model.ckpt-*') 110 | results = [os.path.splitext(i)[0] for i in results] 111 | results = set(results) 112 | gs_list = [i.split('-')[-1] for i in results] 113 | 114 | pool = multiprocessing.Pool(FLAGS.threads, initializer) 115 | ret = pool.map(run, gs_list) 116 | pool.close() 117 | pool.join() 118 | 119 | ret = sorted(ret, key=lambda x: x[1]) 120 | with open(FLAGS.job_dir + '/cider.json', 'w') as f: 121 | json.dump(ret, f) 122 | ret = sorted(ret, key=lambda x: x[2]) 123 | with open(FLAGS.job_dir + '/meteor.json', 'w') as 
f: 124 | json.dump(ret, f) 125 | ret = sorted(ret, key=lambda x: x[3]) 126 | with open(FLAGS.job_dir + '/b4.json', 'w') as f: 127 | json.dump(ret, f) 128 | ret = sorted(ret, key=lambda x: x[4]) 129 | with open(FLAGS.job_dir + '/b3.json', 'w') as f: 130 | json.dump(ret, f) 131 | ret = sorted(ret, key=lambda x: x[5]) 132 | with open(FLAGS.job_dir + '/b2.json', 'w') as f: 133 | json.dump(ret, f) 134 | ret = sorted(ret, key=lambda x: x[3] + x[4]) 135 | with open(FLAGS.job_dir + '/b34.json', 'w') as f: 136 | json.dump(ret, f) 137 | 138 | 139 | if __name__ == '__main__': 140 | app.run(main) 141 | -------------------------------------------------------------------------------- /initialization/gen_obj2sen_caption.py: -------------------------------------------------------------------------------- 1 | """Generate pseudo captions. 2 | 3 | python initialization/gen_obj2sen_caption.py --num_proc 64 4 | """ 5 | 6 | import multiprocessing 7 | import os 8 | from functools import partial 9 | 10 | from absl import app 11 | from absl import flags 12 | 13 | from misc_fn import _int64_feature_list 14 | 15 | flags.DEFINE_integer('num_proc', 1, 'number of processes') 16 | 17 | flags.DEFINE_integer('num_gpus', 1, 'number of gpus') 18 | 19 | from sentence_infer import Infer 20 | 21 | FLAGS = flags.FLAGS 22 | 23 | 24 | def initializer(): 25 | if FLAGS.num_gpus > 0: 26 | current = multiprocessing.current_process() 27 | id = current._identity[0] - 1 28 | os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % (id % FLAGS.num_gpus) 29 | else: 30 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 31 | global infer 32 | infer = Infer() 33 | 34 | 35 | def run(classes): 36 | tf = infer.tf 37 | sentences = infer.infer(classes[::-1]) 38 | sentence = sentences[0][0].split() 39 | sentence = [infer.vocab.word_to_id(i) for i in sentence] 40 | context = tf.train.Features() 41 | feature_lists = tf.train.FeatureLists(feature_list={ 42 | 'sentence': _int64_feature_list(sentence) 43 | }) 44 | sequence_example = tf.train.SequenceExample( 45 | context=context, feature_lists=feature_lists) 46 | return sequence_example.SerializeToString() 47 | 48 | 49 | def parse_image(serialized, tf): 50 | """Parses a tensorflow.SequenceExample into detected objects and scores. 51 | 52 | Args: 53 | serialized: A scalar string Tensor; a single serialized SequenceExample. 54 | 55 | Returns: 56 | classes: A 1-D int64 Tensor containing the detected objects. 57 | scores: A 1-D float32 Tensor containing the detection scores. 58 | 
59 | """ 60 | context, sequence = tf.parse_single_sequence_example( 61 | serialized, 62 | sequence_features={ 63 | 'classes': tf.FixedLenSequenceFeature([], dtype=tf.int64), 64 | 'scores': tf.FixedLenSequenceFeature([], dtype=tf.float32), 65 | }) 66 | 67 | classes = tf.to_int32(sequence['classes']) 68 | scores = sequence['scores'] 69 | return classes, scores 70 | 71 | 72 | def image_generator(tf): 73 | ds = tf.data.TFRecordDataset('data/image_train.tfrec') 74 | ds = ds.map( 75 | partial(parse_image, tf=tf), 76 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 77 | for classes, scores in ds: 78 | yield classes.numpy() 79 | 80 | 81 | def main(_): 82 | pool = multiprocessing.Pool(FLAGS.num_proc, initializer=initializer) 83 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 84 | import tensorflow as tf 85 | tf.enable_eager_execution() 86 | with tf.python_io.TFRecordWriter('data/obj2sen_captions.tfrec') as writer: 87 | for i in pool.imap(run, image_generator(tf)): 88 | writer.write(i) 89 | 90 | 91 | if __name__ == '__main__': 92 | app.run(main) 93 | -------------------------------------------------------------------------------- /initialization/im_caption.py: -------------------------------------------------------------------------------- 1 | """Train sentence gan model. 2 | 3 | python initialization/im_caption.py --batch_size 512 --multi_gpu\ 4 | --save_checkpoint_steps 2000 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import functools 12 | import os 13 | import sys 14 | 15 | import tensorflow as tf 16 | import tensorflow.contrib.slim as slim 17 | 18 | from config import TF_MODELS_PATH 19 | from misc_fn import crop_sentence 20 | from misc_fn import transform_grads_fn 21 | from misc_fn import validate_batch_size_for_multi_gpu 22 | from input_pipeline import parse_image 23 | from input_pipeline import preprocess_image 24 | from input_pipeline import AUTOTUNE 25 | from input_pipeline import parse_sentence 26 | 27 | sys.path.append(TF_MODELS_PATH + '/research/slim') 28 | from nets import inception_v4 29 | 30 | tf.logging.set_verbosity(tf.logging.INFO) 31 | 32 | tf.flags.DEFINE_integer('intra_op_parallelism_threads', 0, 'Number of threads') 33 | 34 | tf.flags.DEFINE_integer('inter_op_parallelism_threads', 0, 'Number of threads') 35 | 36 | tf.flags.DEFINE_bool('multi_gpu', False, 'use multi gpus') 37 | 38 | tf.flags.DEFINE_integer('emb_dim', 512, 'emb dim') 39 | 40 | tf.flags.DEFINE_integer('mem_dim', 512, 'mem dim') 41 | 42 | tf.flags.DEFINE_float('keep_prob', 0.8, 'keep prob') 43 | 44 | tf.flags.DEFINE_string('job_dir', 'saving_imcap', 'job dir') 45 | 46 | tf.flags.DEFINE_integer('batch_size', 512, 'batch size') 47 | 48 | tf.flags.DEFINE_integer('max_steps', 1000000, 'training steps') 49 | 50 | tf.flags.DEFINE_float('weight_decay', 0, 'weight decay') 51 | 52 | tf.flags.DEFINE_float('lr', 0.001, 'learning rate') 53 | 54 | tf.flags.DEFINE_integer('save_summary_steps', 100, 'save summary steps') 55 | 56 | tf.flags.DEFINE_integer('save_checkpoint_steps', 2000, 'save ckpt') 57 | 58 | tf.flags.DEFINE_string('inc_ckpt', None, 'InceptionV4 checkpoint') 59 | 60 | tf.flags.DEFINE_string('o2s_ckpt', None, 'ckpt') 61 | 62 | FLAGS = tf.flags.FLAGS 63 | 64 | 65 | def model_fn(features, labels, mode, params): 66 | is_chief = not tf.get_variable_scope().reuse 67 | is_training = mode == tf.estimator.ModeKeys.TRAIN 68 | batch_size = tf.shape(features)[0] 69 | 70 | with slim.arg_scope(inception_v4.inception_v4_arg_scope()): 71 
| net, _ = inception_v4.inception_v4(features, None, is_training=False) 72 | net = tf.squeeze(net, [1, 2]) 73 | inc_saver = tf.train.Saver(tf.global_variables('InceptionV4')) 74 | 75 | with tf.variable_scope('Generator'): 76 | feat = slim.fully_connected(net, FLAGS.mem_dim, activation_fn=None) 77 | feat = tf.nn.l2_normalize(feat, axis=1) 78 | sentence, ls = labels['sentence'], labels['len'] 79 | targets = sentence[:, 1:] 80 | sentence = sentence[:, :-1] 81 | ls -= 1 82 | 83 | embedding = tf.get_variable( 84 | name='embedding', 85 | shape=[FLAGS.vocab_size, FLAGS.emb_dim], 86 | initializer=tf.random_uniform_initializer(-0.08, 0.08)) 87 | softmax_w = tf.matrix_transpose(embedding) 88 | softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size]) 89 | sentence = tf.nn.embedding_lookup(embedding, sentence) 90 | 91 | cell = tf.nn.rnn_cell.BasicLSTMCell(params.mem_dim) 92 | if is_training: 93 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, params.keep_prob, 94 | params.keep_prob) 95 | zero_state = cell.zero_state(batch_size, tf.float32) 96 | _, state = cell(feat, zero_state) 97 | tf.get_variable_scope().reuse_variables() 98 | out, state = tf.nn.dynamic_rnn(cell, sentence, ls, state) 99 | out = tf.reshape(out, [-1, FLAGS.mem_dim]) 100 | logits = tf.nn.bias_add(tf.matmul(out, softmax_w), softmax_b) 101 | logits = tf.reshape(logits, [batch_size, -1, FLAGS.vocab_size]) 102 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 103 | 104 | mask = tf.sequence_mask(ls, tf.shape(sentence)[1]) 105 | targets = tf.boolean_mask(targets, mask) 106 | logits = tf.boolean_mask(logits, mask) 107 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, 108 | logits=logits) 109 | loss = tf.reduce_mean(loss) 110 | 111 | opt = tf.train.AdamOptimizer(params.lr) 112 | if params.multi_gpu: 113 | opt = tf.contrib.estimator.TowerOptimizer(opt) 114 | grads = opt.compute_gradients(loss, tf.trainable_variables('Generator')) 115 | grads[2] = (tf.convert_to_tensor(grads[2][0]), grads[2][1]) 116 | for i in range(2, len(grads)): 117 | grads[i] = (grads[i][0] * 0.1, grads[i][1]) 118 | grads = transform_grads_fn(grads) 119 | train_op = opt.apply_gradients(grads, global_step=tf.train.get_global_step()) 120 | 121 | train_hooks = None 122 | if is_chief: 123 | with open('data/word_counts.txt', 'r') as f: 124 | dic = list(f) 125 | dic = [i.split()[0] for i in dic] 126 | end_id = dic.index('') 127 | dic.append('') 128 | dic = tf.convert_to_tensor(dic) 129 | sentence = crop_sentence(predictions[0], end_id) 130 | sentence = tf.gather(dic, sentence) 131 | tf.summary.text('fake', sentence) 132 | tf.summary.image('im', features[None, 0]) 133 | for variable in tf.trainable_variables(): 134 | tf.summary.histogram(variable.op.name, variable) 135 | 136 | predictions = tf.boolean_mask(predictions, mask) 137 | metrics = { 138 | 'acc': tf.metrics.accuracy(targets, predictions) 139 | } 140 | 141 | gen_var = tf.trainable_variables('Generator')[2:] 142 | gen_saver = tf.train.Saver(gen_var) 143 | 144 | def init_fn(scaffold, session): 145 | inc_saver.restore(session, FLAGS.inc_ckpt) 146 | if FLAGS.o2s_ckpt: 147 | gen_saver.restore(session, FLAGS.o2s_ckpt) 148 | 149 | scaffold = tf.train.Scaffold(init_fn=init_fn) 150 | 151 | return tf.estimator.EstimatorSpec( 152 | mode=mode, 153 | loss=loss, 154 | train_op=train_op, 155 | scaffold=scaffold, 156 | training_hooks=train_hooks, 157 | eval_metric_ops=metrics) 158 | 159 | 160 | def batching_func(x, batch_size): 161 | return x.padded_batch( 162 | batch_size, 163 | padded_shapes=( 164 | 
tf.TensorShape([299, 299, 3]), 165 | tf.TensorShape([None]), 166 | tf.TensorShape([]))) 167 | 168 | 169 | def take(image, sentence): 170 | sentence = tf.concat([[FLAGS.start_id], sentence[2], [FLAGS.end_id]], axis=0) 171 | return image[0], sentence, tf.shape(sentence)[0] 172 | 173 | 174 | def input_fn(batch_size): 175 | image_ds = tf.data.TFRecordDataset('data/image_train.tfrec') 176 | image_ds = image_ds.map(parse_image, num_parallel_calls=AUTOTUNE) 177 | image_ds = image_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE) 178 | 179 | sentence_ds = tf.data.TFRecordDataset('data/obj2sen_captions.tfrec') 180 | sentence_ds = sentence_ds.map(parse_sentence, num_parallel_calls=AUTOTUNE) 181 | 182 | dataset = tf.data.Dataset.zip((image_ds, sentence_ds)) 183 | dataset = dataset.filter(lambda im, sen: tf.not_equal(im[3], 0)) 184 | dataset = dataset.map(take) 185 | dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(4096)) 186 | dataset = batching_func(dataset, batch_size) 187 | dataset = dataset.prefetch(AUTOTUNE) 188 | iterator = dataset.make_one_shot_iterator() 189 | im, sentence, ls = iterator.get_next() 190 | return im, {'sentence': sentence, 'len': ls} 191 | 192 | 193 | def main(_): 194 | os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' 195 | 196 | if FLAGS.multi_gpu: 197 | validate_batch_size_for_multi_gpu(FLAGS.batch_size) 198 | model_function = tf.contrib.estimator.replicate_model_fn( 199 | model_fn, 200 | loss_reduction=tf.losses.Reduction.MEAN) 201 | else: 202 | model_function = model_fn 203 | 204 | sess_config = tf.ConfigProto( 205 | allow_soft_placement=True, 206 | intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads, 207 | inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads, 208 | gpu_options=tf.GPUOptions(allow_growth=True)) 209 | 210 | run_config = tf.estimator.RunConfig( 211 | session_config=sess_config, 212 | save_checkpoints_steps=FLAGS.save_checkpoint_steps, 213 | save_summary_steps=FLAGS.save_summary_steps, 214 | keep_checkpoint_max=100) 215 | 216 | train_input_fn = functools.partial(input_fn, batch_size=FLAGS.batch_size) 217 | 218 | estimator = tf.estimator.Estimator( 219 | model_fn=model_function, 220 | model_dir=FLAGS.job_dir, 221 | config=run_config, 222 | params=FLAGS) 223 | 224 | estimator.train(train_input_fn, max_steps=FLAGS.max_steps) 225 | 226 | 227 | if __name__ == '__main__': 228 | tf.app.run() 229 | -------------------------------------------------------------------------------- /initialization/obj2sen.py: -------------------------------------------------------------------------------- 1 | """Train object-to-sentence model. 
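The model pairs two LSTMs: one (under the 'Discriminator' scope) encodes a bag of keyword ids read from data/sentence.tfrec, and the other (under the 'Generator' scope) decodes the full sentence from the encoder's final state. Once trained, initialization/gen_obj2sen_caption.py feeds the object labels detected in each image through this model to produce the pseudo captions consumed by initialization/im_caption.py.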
2 | 3 | python initialization/obj2sen.py --batch_size 512 --save_checkpoint_steps 5000 4 | """ 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import functools 11 | import os 12 | 13 | import tensorflow as tf 14 | 15 | from config import NUM_DESCRIPTIONS 16 | from misc_fn import crop_sentence 17 | from misc_fn import transform_grads_fn 18 | from misc_fn import validate_batch_size_for_multi_gpu 19 | from input_pipeline import AUTOTUNE 20 | 21 | tf.logging.set_verbosity(tf.logging.INFO) 22 | 23 | tf.flags.DEFINE_integer('intra_op_parallelism_threads', 0, 'Number of threads') 24 | 25 | tf.flags.DEFINE_integer('inter_op_parallelism_threads', 0, 'Number of threads') 26 | 27 | tf.flags.DEFINE_bool('multi_gpu', False, 'use multi gpus') 28 | 29 | tf.flags.DEFINE_integer('emb_dim', 512, 'emb dim') 30 | 31 | tf.flags.DEFINE_integer('mem_dim', 512, 'mem dim') 32 | 33 | tf.flags.DEFINE_float('keep_prob', 0.8, 'keep prob') 34 | 35 | tf.flags.DEFINE_string('job_dir', 'obj2sen', 'job dir') 36 | 37 | tf.flags.DEFINE_integer('batch_size', 512, 'batch size') 38 | 39 | tf.flags.DEFINE_integer('max_steps', 1000000, 'training steps') 40 | 41 | tf.flags.DEFINE_float('weight_decay', 0, 'weight decay') 42 | 43 | tf.flags.DEFINE_float('lr', 0.001, 'learning rate') 44 | 45 | tf.flags.DEFINE_integer('save_summary_steps', 100, 'save summary steps') 46 | 47 | tf.flags.DEFINE_integer('save_checkpoint_steps', 5000, 'save ckpt') 48 | 49 | FLAGS = tf.flags.FLAGS 50 | 51 | 52 | def model_fn(features, labels, mode, params): 53 | is_training = mode == tf.estimator.ModeKeys.TRAIN 54 | 55 | with tf.variable_scope('Discriminator'): 56 | embedding = tf.get_variable( 57 | name='embedding', 58 | shape=[FLAGS.vocab_size, FLAGS.emb_dim], 59 | initializer=tf.random_uniform_initializer(-0.08, 0.08)) 60 | 61 | key, lk = features['key'], features['len'] 62 | key = tf.nn.embedding_lookup(embedding, key) 63 | sentence, ls = labels['sentence'], labels['len'] 64 | targets = sentence[:, 1:] 65 | sentence = sentence[:, :-1] 66 | ls -= 1 67 | sentence = tf.nn.embedding_lookup(embedding, sentence) 68 | 69 | cell = tf.nn.rnn_cell.BasicLSTMCell(params.mem_dim) 70 | if is_training: 71 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, params.keep_prob, 72 | params.keep_prob) 73 | out, initial_state = tf.nn.dynamic_rnn(cell, key, lk, dtype=tf.float32) 74 | 75 | feat = tf.nn.l2_normalize(initial_state[1], axis=1) 76 | batch_size = tf.shape(feat)[0] 77 | 78 | with tf.variable_scope('Generator'): 79 | embedding = tf.get_variable( 80 | name='embedding', 81 | shape=[FLAGS.vocab_size, FLAGS.emb_dim], 82 | initializer=tf.random_uniform_initializer(-0.08, 0.08)) 83 | softmax_w = tf.matrix_transpose(embedding) 84 | softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size]) 85 | 86 | cell = tf.nn.rnn_cell.BasicLSTMCell(params.mem_dim) 87 | if is_training: 88 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, params.keep_prob, 89 | params.keep_prob) 90 | zero_state = cell.zero_state(batch_size, tf.float32) 91 | _, state = cell(feat, zero_state) 92 | tf.get_variable_scope().reuse_variables() 93 | out, state = tf.nn.dynamic_rnn(cell, sentence, ls, state) 94 | out = tf.reshape(out, [-1, FLAGS.mem_dim]) 95 | logits = tf.nn.bias_add(tf.matmul(out, softmax_w), softmax_b) 96 | logits = tf.reshape(logits, [batch_size, -1, FLAGS.vocab_size]) 97 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 98 | 99 | mask = tf.sequence_mask(ls, tf.shape(sentence)[1]) 100 | targets = 
tf.boolean_mask(targets, mask) 101 | logits = tf.boolean_mask(logits, mask) 102 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, 103 | logits=logits) 104 | loss = tf.reduce_mean(loss) 105 | 106 | opt = tf.train.AdamOptimizer(params.lr) 107 | if params.multi_gpu: 108 | opt = tf.contrib.estimator.TowerOptimizer(opt) 109 | grads = opt.compute_gradients(loss) 110 | grads = transform_grads_fn(grads) 111 | train_op = opt.apply_gradients(grads, global_step=tf.train.get_global_step()) 112 | 113 | train_hooks = None 114 | if not FLAGS.multi_gpu or opt._graph_state().is_the_last_tower: 115 | with open('data/word_counts.txt', 'r') as f: 116 | dic = list(f) 117 | dic = [i.split()[0] for i in dic] 118 | end_id = dic.index('') 119 | dic.append('') 120 | dic = tf.convert_to_tensor(dic) 121 | noise = features['key'][0] 122 | m = tf.sequence_mask(features['len'][0], tf.shape(noise)[0]) 123 | noise = tf.boolean_mask(noise, m) 124 | noise = tf.gather(dic, noise) 125 | sentence = crop_sentence(labels['sentence'][0], end_id) 126 | sentence = tf.gather(dic, sentence) 127 | pred = crop_sentence(predictions[0], end_id) 128 | pred = tf.gather(dic, pred) 129 | train_hooks = [tf.train.LoggingTensorHook( 130 | {'sentence': sentence, 'noise': noise, 'pred': pred}, every_n_iter=100)] 131 | for variable in tf.trainable_variables(): 132 | tf.summary.histogram(variable.op.name, variable) 133 | 134 | predictions = tf.boolean_mask(predictions, mask) 135 | metrics = { 136 | 'acc': tf.metrics.accuracy(targets, predictions) 137 | } 138 | 139 | return tf.estimator.EstimatorSpec( 140 | mode=mode, 141 | loss=loss, 142 | train_op=train_op, 143 | training_hooks=train_hooks, 144 | eval_metric_ops=metrics) 145 | 146 | 147 | def batching_func(x, batch_size): 148 | return x.padded_batch( 149 | batch_size, 150 | padded_shapes=( 151 | tf.TensorShape([None]), 152 | tf.TensorShape([]), 153 | tf.TensorShape([None]), 154 | tf.TensorShape([]))) 155 | 156 | 157 | def parse_sentence(serialized): 158 | """Parses a tensorflow.SequenceExample into an caption. 159 | 160 | Args: 161 | serialized: A scalar string Tensor; a single serialized SequenceExample. 162 | 163 | Returns: 164 | key: The keywords in a sentence. 165 | num_key: The number of keywords. 166 | sentence: A description. 167 | sentence_length: The length of the description. 
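  Each record is expected to carry two int64 feature lists, 'key' and 'sentence'.
  A minimal sketch of how such a record could be written, using the helpers in
  misc_fn.py (key_ids, sentence_ids and writer are illustrative names, not part
  of this file):

    feature_lists = tf.train.FeatureLists(feature_list={
        'key': _int64_feature_list(key_ids),
        'sentence': _int64_feature_list(sentence_ids),
    })
    example = tf.train.SequenceExample(feature_lists=feature_lists)
    writer.write(example.SerializeToString())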
168 | """ 169 | context, sequence = tf.parse_single_sequence_example( 170 | serialized, 171 | context_features={}, 172 | sequence_features={ 173 | 'key': tf.FixedLenSequenceFeature([], dtype=tf.int64), 174 | 'sentence': tf.FixedLenSequenceFeature([], dtype=tf.int64), 175 | }) 176 | key = tf.to_int32(sequence['key']) 177 | key = tf.random_shuffle(key) 178 | sentence = tf.to_int32(sequence['sentence']) 179 | return key, tf.shape(key)[0], sentence, tf.shape(sentence)[0] 180 | 181 | 182 | def input_fn(batch_size, subset='train'): 183 | sentence_ds = tf.data.TFRecordDataset('data/sentence.tfrec') 184 | num_val = NUM_DESCRIPTIONS // 50 185 | if subset == 'train': 186 | sentence_ds = sentence_ds.skip(num_val) 187 | else: 188 | sentence_ds = sentence_ds.take(num_val) 189 | sentence_ds = sentence_ds.map(parse_sentence, num_parallel_calls=AUTOTUNE) 190 | 191 | sentence_ds = sentence_ds.filter(lambda k, lk, s, ls: tf.not_equal(lk, 0)) 192 | if subset == 'train': 193 | sentence_ds = sentence_ds.apply(tf.contrib.data.shuffle_and_repeat(65536)) 194 | sentence_ds = batching_func(sentence_ds, batch_size) 195 | sentence_ds = sentence_ds.prefetch(AUTOTUNE) 196 | iterator = sentence_ds.make_one_shot_iterator() 197 | key, lk, sentence, ls = iterator.get_next() 198 | return {'key': key, 'len': lk}, {'sentence': sentence, 'len': ls} 199 | 200 | 201 | def main(_): 202 | os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' 203 | 204 | if FLAGS.multi_gpu: 205 | validate_batch_size_for_multi_gpu(FLAGS.batch_size) 206 | model_function = tf.contrib.estimator.replicate_model_fn( 207 | model_fn, 208 | loss_reduction=tf.losses.Reduction.MEAN) 209 | else: 210 | model_function = model_fn 211 | 212 | sess_config = tf.ConfigProto( 213 | allow_soft_placement=True, 214 | intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads, 215 | inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads, 216 | gpu_options=tf.GPUOptions(allow_growth=True)) 217 | 218 | run_config = tf.estimator.RunConfig( 219 | session_config=sess_config, 220 | save_checkpoints_steps=FLAGS.save_checkpoint_steps, 221 | save_summary_steps=FLAGS.save_summary_steps, 222 | keep_checkpoint_max=100) 223 | 224 | train_input_fn = functools.partial(input_fn, batch_size=FLAGS.batch_size) 225 | 226 | eval_input_fn = functools.partial(input_fn, batch_size=FLAGS.batch_size, 227 | subset='val') 228 | 229 | estimator = tf.estimator.Estimator( 230 | model_fn=model_function, 231 | model_dir=FLAGS.job_dir, 232 | config=run_config, 233 | params=FLAGS) 234 | 235 | train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, 236 | max_steps=FLAGS.max_steps) 237 | eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=None) 238 | 239 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 240 | 241 | 242 | if __name__ == '__main__': 243 | tf.app.run() 244 | -------------------------------------------------------------------------------- /initialization/sentence_ae.py: -------------------------------------------------------------------------------- 1 | """Train sentence autoencoder model. 
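This is a denoising autoencoder over the crawled sentences: the encoder (under the 'Discriminator' scope) reads a corrupted sentence produced by input_pipeline.parse_sentence (words shuffled within a small window and randomly dropped), and the decoder (under the 'Generator' scope) reconstructs the original sentence. The resulting checkpoint is what initialization/sentence_gan.py restores through its --sae_ckpt flag.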
2 | 3 | python initialization/sentence_ae.py --batch_size 512\ 4 | --save_checkpoint_steps 5000 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import functools 12 | import os 13 | 14 | import tensorflow as tf 15 | 16 | from config import NUM_DESCRIPTIONS 17 | from input_pipeline import AUTOTUNE 18 | from input_pipeline import parse_sentence 19 | from misc_fn import crop_sentence 20 | from misc_fn import transform_grads_fn 21 | from misc_fn import validate_batch_size_for_multi_gpu 22 | 23 | tf.logging.set_verbosity(tf.logging.INFO) 24 | 25 | tf.flags.DEFINE_integer('intra_op_parallelism_threads', 0, 'Number of threads') 26 | 27 | tf.flags.DEFINE_integer('inter_op_parallelism_threads', 0, 'Number of threads') 28 | 29 | tf.flags.DEFINE_bool('multi_gpu', False, 'use multi gpus') 30 | 31 | tf.flags.DEFINE_integer('emb_dim', 512, 'emb dim') 32 | 33 | tf.flags.DEFINE_integer('mem_dim', 512, 'mem dim') 34 | 35 | tf.flags.DEFINE_float('keep_prob', 0.8, 'keep prob') 36 | 37 | tf.flags.DEFINE_string('job_dir', 'sen_ae', 'job dir') 38 | 39 | tf.flags.DEFINE_integer('batch_size', 512, 'batch size') 40 | 41 | tf.flags.DEFINE_integer('max_steps', 1000000, 'training steps') 42 | 43 | tf.flags.DEFINE_float('weight_decay', 0, 'weight decay') 44 | 45 | tf.flags.DEFINE_float('lr', 0.001, 'learning rate') 46 | 47 | tf.flags.DEFINE_integer('save_summary_steps', 100, 'save summary steps') 48 | 49 | tf.flags.DEFINE_integer('save_checkpoint_steps', 5000, 'save ckpt') 50 | 51 | FLAGS = tf.flags.FLAGS 52 | 53 | 54 | def model_fn(features, labels, mode, params): 55 | is_training = mode == tf.estimator.ModeKeys.TRAIN 56 | 57 | with tf.variable_scope('Discriminator'): 58 | embedding = tf.get_variable( 59 | name='embedding', 60 | shape=[FLAGS.vocab_size, FLAGS.emb_dim], 61 | initializer=tf.random_uniform_initializer(-0.08, 0.08)) 62 | 63 | noisy_sentence, lns = features['noisy_sentence'], features['len'] 64 | noisy_sentence = tf.nn.embedding_lookup(embedding, noisy_sentence) 65 | 66 | cell = tf.nn.rnn_cell.BasicLSTMCell(params.mem_dim) 67 | if is_training: 68 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, params.keep_prob, 69 | params.keep_prob) 70 | out, initial_state = tf.nn.dynamic_rnn(cell, noisy_sentence, lns, 71 | dtype=tf.float32) 72 | 73 | feat = tf.nn.l2_normalize(initial_state[1], axis=1) 74 | batch_size = tf.shape(feat)[0] 75 | 76 | with tf.variable_scope('Generator'): 77 | embedding = tf.get_variable( 78 | name='embedding', 79 | shape=[FLAGS.vocab_size, FLAGS.emb_dim], 80 | initializer=tf.random_uniform_initializer(-0.08, 0.08)) 81 | softmax_w = tf.matrix_transpose(embedding) 82 | softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size]) 83 | 84 | sentence, ls = labels['sentence'], labels['len'] 85 | targets = sentence[:, 1:] 86 | sentence = sentence[:, :-1] 87 | ls -= 1 88 | sentence = tf.nn.embedding_lookup(embedding, sentence) 89 | 90 | cell = tf.nn.rnn_cell.BasicLSTMCell(params.mem_dim) 91 | if is_training: 92 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, params.keep_prob, 93 | params.keep_prob) 94 | zero_state = cell.zero_state(batch_size, tf.float32) 95 | _, state = cell(feat, zero_state) 96 | tf.get_variable_scope().reuse_variables() 97 | out, state = tf.nn.dynamic_rnn(cell, sentence, ls, state) 98 | out = tf.reshape(out, [-1, FLAGS.mem_dim]) 99 | logits = tf.nn.bias_add(tf.matmul(out, softmax_w), softmax_b) 100 | logits = tf.reshape(logits, [batch_size, -1, FLAGS.vocab_size]) 101 | predictions = 
tf.argmax(logits, axis=-1, output_type=tf.int32) 102 | 103 | mask = tf.sequence_mask(ls, tf.shape(sentence)[1]) 104 | targets = tf.boolean_mask(targets, mask) 105 | logits = tf.boolean_mask(logits, mask) 106 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, 107 | logits=logits) 108 | loss = tf.reduce_mean(loss) 109 | 110 | opt = tf.train.AdamOptimizer(params.lr) 111 | if params.multi_gpu: 112 | opt = tf.contrib.estimator.TowerOptimizer(opt) 113 | grads = opt.compute_gradients(loss) 114 | grads = transform_grads_fn(grads) 115 | train_op = opt.apply_gradients(grads, global_step=tf.train.get_global_step()) 116 | 117 | train_hooks = None 118 | if not FLAGS.multi_gpu or opt._graph_state().is_the_last_tower: 119 | with open('data/word_counts.txt', 'r') as f: 120 | dic = list(f) 121 | dic = [i.split()[0] for i in dic] 122 | end_id = dic.index('') 123 | dic.append('') 124 | dic = tf.convert_to_tensor(dic) 125 | noise = crop_sentence(features['noisy_sentence'][0], end_id) 126 | noise = tf.gather(dic, noise) 127 | sentence = crop_sentence(labels['sentence'][0], end_id) 128 | sentence = tf.gather(dic, sentence) 129 | pred = crop_sentence(predictions[0], end_id) 130 | pred = tf.gather(dic, pred) 131 | train_hooks = [tf.train.LoggingTensorHook( 132 | {'sentence': sentence, 'noise': noise, 'pred': pred}, every_n_iter=100)] 133 | for variable in tf.trainable_variables(): 134 | tf.summary.histogram(variable.op.name, variable) 135 | 136 | predictions = tf.boolean_mask(predictions, mask) 137 | metrics = { 138 | 'acc': tf.metrics.accuracy(targets, predictions) 139 | } 140 | 141 | return tf.estimator.EstimatorSpec( 142 | mode=mode, 143 | loss=loss, 144 | train_op=train_op, 145 | training_hooks=train_hooks, 146 | eval_metric_ops=metrics) 147 | 148 | 149 | def batching_func(x, batch_size): 150 | return x.padded_batch( 151 | batch_size, 152 | padded_shapes=( 153 | tf.TensorShape([None]), 154 | tf.TensorShape([]), 155 | tf.TensorShape([None]), 156 | tf.TensorShape([]))) 157 | 158 | 159 | def input_fn(batch_size, subset='train'): 160 | sentence_ds = tf.data.TFRecordDataset('data/sentence.tfrec') 161 | num_val = NUM_DESCRIPTIONS // 50 162 | if subset == 'train': 163 | sentence_ds = sentence_ds.skip(num_val) 164 | else: 165 | sentence_ds = sentence_ds.take(num_val) 166 | sentence_ds = sentence_ds.map(parse_sentence, num_parallel_calls=AUTOTUNE) 167 | 168 | if subset == 'train': 169 | sentence_ds = sentence_ds.apply(tf.contrib.data.shuffle_and_repeat(65536)) 170 | sentence_ds = batching_func(sentence_ds, batch_size) 171 | sentence_ds = sentence_ds.prefetch(AUTOTUNE) 172 | iterator = sentence_ds.make_one_shot_iterator() 173 | key, lk, sentence, ls = iterator.get_next() 174 | return {'noisy_sentence': key, 'len': lk}, {'sentence': sentence, 'len': ls} 175 | 176 | 177 | def main(_): 178 | os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' 179 | 180 | if FLAGS.multi_gpu: 181 | validate_batch_size_for_multi_gpu(FLAGS.batch_size) 182 | model_function = tf.contrib.estimator.replicate_model_fn( 183 | model_fn, 184 | loss_reduction=tf.losses.Reduction.MEAN) 185 | else: 186 | model_function = model_fn 187 | 188 | sess_config = tf.ConfigProto( 189 | allow_soft_placement=True, 190 | intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads, 191 | inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads, 192 | gpu_options=tf.GPUOptions(allow_growth=True)) 193 | 194 | run_config = tf.estimator.RunConfig( 195 | session_config=sess_config, 196 | save_checkpoints_steps=FLAGS.save_checkpoint_steps, 197 
| save_summary_steps=FLAGS.save_summary_steps, 198 | keep_checkpoint_max=100) 199 | 200 | train_input_fn = functools.partial(input_fn, batch_size=FLAGS.batch_size) 201 | 202 | eval_input_fn = functools.partial(input_fn, batch_size=FLAGS.batch_size, 203 | subset='val') 204 | 205 | estimator = tf.estimator.Estimator( 206 | model_fn=model_function, 207 | model_dir=FLAGS.job_dir, 208 | config=run_config, 209 | params=FLAGS) 210 | 211 | train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, 212 | max_steps=FLAGS.max_steps) 213 | eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=None) 214 | 215 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 216 | 217 | 218 | if __name__ == '__main__': 219 | tf.app.run() 220 | -------------------------------------------------------------------------------- /initialization/sentence_gan.py: -------------------------------------------------------------------------------- 1 | """Train sentence gan model. 2 | 3 | python initialization/sentence_gan.py --batch_size 512 4 | --save_checkpoint_steps 2000 --gen_lr 0.0001 --dis_lr 0.0001 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import functools 12 | import os 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | import tensorflow.contrib.gan as tfgan 17 | import tensorflow.contrib.slim as slim 18 | from tensorflow.contrib.gan.python.losses.python.losses_impl import modified_discriminator_loss 19 | from tensorflow.contrib.gan.python.train import get_sequential_train_hooks 20 | 21 | import config 22 | from input_pipeline import AUTOTUNE 23 | from input_pipeline import parse_sentence 24 | from misc_fn import crop_sentence 25 | from misc_fn import get_len 26 | from misc_fn import transform_grads_fn 27 | from misc_fn import validate_batch_size_for_multi_gpu 28 | from misc_fn import variable_summaries 29 | 30 | tf.logging.set_verbosity(tf.logging.INFO) 31 | 32 | tf.flags.DEFINE_integer('intra_op_parallelism_threads', 0, 'Number of threads') 33 | 34 | tf.flags.DEFINE_integer('inter_op_parallelism_threads', 0, 'Number of threads') 35 | 36 | tf.flags.DEFINE_bool('multi_gpu', False, 'use multi gpus') 37 | 38 | tf.flags.DEFINE_integer('emb_dim', 512, 'emb dim') 39 | 40 | tf.flags.DEFINE_integer('mem_dim', 512, 'mem dim') 41 | 42 | tf.flags.DEFINE_float('keep_prob', 0.8, 'keep prob') 43 | 44 | tf.flags.DEFINE_string('job_dir', 'sen_gan', 'job dir') 45 | 46 | tf.flags.DEFINE_integer('batch_size', 512, 'batch size') 47 | 48 | tf.flags.DEFINE_integer('max_steps', 1000000, 'training steps') 49 | 50 | tf.flags.DEFINE_float('weight_decay', 0, 'weight decay') 51 | 52 | tf.flags.DEFINE_float('gen_lr', 0.0001, 'learning rate') 53 | 54 | tf.flags.DEFINE_float('dis_lr', 0.0001, 'learning rate') 55 | 56 | tf.flags.DEFINE_integer('save_summary_steps', 100, 'save summary steps') 57 | 58 | tf.flags.DEFINE_integer('save_checkpoint_steps', 2000, 'save ckpt') 59 | 60 | tf.flags.DEFINE_integer('max_caption_length', 20, 'max len') 61 | 62 | tf.flags.DEFINE_bool('wass', False, 'use wass') 63 | 64 | tf.flags.DEFINE_string('sae_ckpt', 'sen_ae/model.ckpt-65000', 'ckpt') 65 | 66 | FLAGS = tf.flags.FLAGS 67 | 68 | 69 | def generator(inputs, is_training=True): 70 | feat, _ = inputs 71 | embedding = tf.get_variable( 72 | name='embedding', 73 | shape=[FLAGS.vocab_size, FLAGS.emb_dim], 74 | initializer=tf.random_uniform_initializer(-0.08, 0.08)) 75 | softmax_w = tf.matrix_transpose(embedding) 76 | softmax_b = 
tf.get_variable('softmax_b', [FLAGS.vocab_size]) 77 | 78 | batch_size = tf.shape(feat)[0] 79 | cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.mem_dim) 80 | if is_training: 81 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, FLAGS.keep_prob, FLAGS.keep_prob) 82 | zero_state = cell.zero_state(batch_size, tf.float32) 83 | 84 | sequence, logits, log_probs, rnn_outs = [], [], [], [] 85 | 86 | _, state = cell(feat, zero_state) 87 | state_bl = state 88 | tf.get_variable_scope().reuse_variables() 89 | for t in range(FLAGS.max_caption_length): 90 | if t == 0: 91 | rnn_inp = tf.zeros([batch_size], tf.int32) + FLAGS.start_id 92 | rnn_inp = tf.nn.embedding_lookup(embedding, rnn_inp) 93 | rnn_out, state = cell(rnn_inp, state) 94 | rnn_outs.append(rnn_out) 95 | logit = tf.nn.bias_add(tf.matmul(rnn_out, softmax_w), softmax_b) 96 | categorical = tf.contrib.distributions.Categorical(logits=logit) 97 | fake = categorical.sample() 98 | log_prob = categorical.log_prob(fake) 99 | sequence.append(fake) 100 | log_probs.append(log_prob) 101 | logits.append(logit) 102 | rnn_inp = fake 103 | sequence = tf.stack(sequence, axis=1) 104 | log_probs = tf.stack(log_probs, axis=1) 105 | logits = tf.stack(logits, axis=1) 106 | 107 | baseline = [] 108 | state = state_bl 109 | for t in range(FLAGS.max_caption_length): 110 | if t == 0: 111 | rnn_inp = tf.zeros([batch_size], tf.int32) + FLAGS.start_id 112 | rnn_inp = tf.nn.embedding_lookup(embedding, rnn_inp) 113 | rnn_out, state = cell(rnn_inp, state) 114 | logit = tf.nn.bias_add(tf.matmul(rnn_out, softmax_w), softmax_b) 115 | fake = tf.argmax(logit, axis=1, output_type=tf.int32) 116 | baseline.append(fake) 117 | rnn_inp = fake 118 | baseline = tf.stack(baseline, axis=1) 119 | 120 | return sequence, logits, log_probs, baseline 121 | 122 | 123 | def discriminator(generated_data, generator_inputs, is_training=True): 124 | if type(generated_data) is tuple: 125 | sequence = generated_data[0] 126 | length = get_len(sequence, FLAGS.end_id) 127 | else: 128 | sequence = generated_data 129 | length = generator_inputs[-1] 130 | embedding = tf.get_variable( 131 | name='embedding', 132 | shape=[FLAGS.vocab_size, FLAGS.emb_dim], 133 | initializer=tf.random_uniform_initializer(-0.08, 0.08)) 134 | cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.mem_dim) 135 | if is_training: 136 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, FLAGS.keep_prob, FLAGS.keep_prob) 137 | 138 | rnn_inputs = tf.nn.embedding_lookup(embedding, sequence) 139 | rnn_out, state = tf.nn.dynamic_rnn(cell, rnn_inputs, length, dtype=tf.float32) 140 | pred = slim.fully_connected(rnn_out, 1, activation_fn=None, scope='fc') 141 | pred = tf.squeeze(pred, 2) 142 | mask = tf.sequence_mask(length, tf.shape(sequence)[1]) 143 | return pred, mask 144 | 145 | 146 | def rl_loss(gan_model, gan_loss, add_summaries): 147 | eps = 1e-7 148 | gamma = 0.9 149 | sequence, _, log_probs, seq_bl = gan_model.generated_data 150 | 151 | with tf.variable_scope(gan_model.discriminator_scope, reuse=True): 152 | baselines, _ = discriminator((seq_bl, None, None, None), None) 153 | baselines = tf.stop_gradient(baselines) 154 | 155 | logits, mask = gan_model.discriminator_gen_outputs 156 | 157 | dis_predictions = tf.nn.sigmoid(logits) 158 | rewards = tf.log(dis_predictions + eps) 159 | rewards = tf.where(mask, rewards, tf.zeros_like(rewards)) 160 | 161 | baselines = tf.nn.sigmoid(baselines) 162 | baselines = tf.log(baselines + eps) 163 | baselines = tf.where(mask, baselines, tf.zeros_like(baselines)) 164 | 165 | log_prob_list = tf.unstack(log_probs, axis=1) 166 | 
rewards_list = tf.unstack(rewards, axis=1) 167 | cumulative_rewards = [] 168 | baseline_list = tf.unstack(baselines, axis=1) 169 | cumulative_baseline = [] 170 | for t in range(FLAGS.max_caption_length): 171 | cum_value = tf.zeros_like(rewards_list[0]) 172 | for s in range(t, FLAGS.max_caption_length): 173 | cum_value += np.power(gamma, s - t) * rewards_list[s] 174 | cumulative_rewards.append(cum_value) 175 | 176 | cum_value = tf.zeros_like(baseline_list[0]) 177 | for s in range(t, FLAGS.max_caption_length): 178 | cum_value += np.power(gamma, s - t) * baseline_list[s] 179 | cumulative_baseline.append(cum_value) 180 | c_rewards = tf.stack(cumulative_rewards, axis=1) 181 | c_baseline = tf.stack(cumulative_baseline, axis=1) 182 | 183 | advantages = [] 184 | final_gen_objective = [] 185 | for t in range(FLAGS.max_caption_length): 186 | log_probability = log_prob_list[t] 187 | cum_advantage = cumulative_rewards[t] - cumulative_baseline[t] 188 | cum_advantage = tf.clip_by_value(cum_advantage, -5.0, 5.0) 189 | advantages.append(cum_advantage) 190 | final_gen_objective.append( 191 | log_probability * tf.stop_gradient(cum_advantage)) 192 | final_gen_objective = tf.stack(final_gen_objective, axis=1) 193 | final_gen_objective = tf.losses.compute_weighted_loss(final_gen_objective, 194 | tf.to_float(mask)) 195 | final_gen_objective = -final_gen_objective 196 | advantages = tf.stack(advantages, axis=1) 197 | 198 | if add_summaries: 199 | tf.summary.scalar('gen_obj', final_gen_objective) 200 | 201 | with tf.name_scope('rewards'): 202 | variable_summaries(c_rewards, mask, 'rewards') 203 | 204 | with tf.name_scope('advantages'): 205 | variable_summaries(advantages, mask, 'advantages') 206 | 207 | with tf.name_scope('baselines'): 208 | variable_summaries(c_baseline, mask, 'baselines') 209 | 210 | with tf.name_scope('log_probs'): 211 | variable_summaries(log_probs, mask, 'log_probs') 212 | 213 | with tf.name_scope('d_rewards'): 214 | variable_summaries(rewards, mask, 'd_rewards') 215 | 216 | return gan_loss._replace(generator_loss=final_gen_objective) 217 | 218 | 219 | def model_fn(features, labels, mode, params): 220 | is_chief = not tf.get_variable_scope().reuse 221 | 222 | batch_size = tf.shape(labels)[0] 223 | noise = tf.random_normal([batch_size, FLAGS.emb_dim]) 224 | noise = tf.nn.l2_normalize(noise, axis=1) 225 | gan_model = tfgan.gan_model( 226 | generator_fn=generator, 227 | discriminator_fn=discriminator, 228 | real_data=features[:, 1:], 229 | generator_inputs=(noise, labels - 1), 230 | check_shapes=False) 231 | if is_chief: 232 | for variable in tf.trainable_variables(): 233 | tf.summary.histogram(variable.op.name, variable) 234 | tf.summary.histogram('logits/gen_logits', 235 | gan_model.discriminator_gen_outputs[0]) 236 | tf.summary.histogram('logits/real_logits', 237 | gan_model.discriminator_real_outputs[0]) 238 | 239 | def gen_loss_fn(gan_model, add_summaries): 240 | return 0 241 | 242 | def dis_loss_fn(gan_model, add_summaries): 243 | discriminator_real_outputs = gan_model.discriminator_real_outputs 244 | discriminator_gen_outputs = gan_model.discriminator_gen_outputs 245 | real_logits = tf.boolean_mask(discriminator_real_outputs[0], 246 | discriminator_real_outputs[1]) 247 | gen_logits = tf.boolean_mask(discriminator_gen_outputs[0], 248 | discriminator_gen_outputs[1]) 249 | return modified_discriminator_loss(real_logits, gen_logits, 250 | add_summaries=add_summaries) 251 | 252 | with tf.name_scope('losses'): 253 | gan_loss = tfgan.gan_loss( 254 | gan_model, 255 | 
generator_loss_fn=gen_loss_fn, 256 | discriminator_loss_fn=dis_loss_fn, 257 | gradient_penalty_weight=10 if FLAGS.wass else 0, 258 | add_summaries=is_chief) 259 | if is_chief: 260 | tfgan.eval.add_regularization_loss_summaries(gan_model) 261 | gan_loss = rl_loss(gan_model, gan_loss, add_summaries=is_chief) 262 | loss = gan_loss.generator_loss + gan_loss.discriminator_loss 263 | 264 | with tf.name_scope('train'): 265 | gen_opt = tf.train.AdamOptimizer(params.gen_lr, 0.5) 266 | dis_opt = tf.train.AdamOptimizer(params.dis_lr, 0.5) 267 | if params.multi_gpu: 268 | gen_opt = tf.contrib.estimator.TowerOptimizer(gen_opt) 269 | dis_opt = tf.contrib.estimator.TowerOptimizer(dis_opt) 270 | train_ops = tfgan.gan_train_ops( 271 | gan_model, 272 | gan_loss, 273 | generator_optimizer=gen_opt, 274 | discriminator_optimizer=dis_opt, 275 | transform_grads_fn=transform_grads_fn, 276 | summarize_gradients=is_chief, 277 | check_for_unused_update_ops=True, 278 | aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N) 279 | train_op = train_ops.global_step_inc_op 280 | train_hooks = get_sequential_train_hooks()(train_ops) 281 | 282 | if is_chief: 283 | with open('data/word_counts.txt', 'r') as f: 284 | dic = list(f) 285 | dic = [i.split()[0] for i in dic] 286 | dic.append('') 287 | dic = tf.convert_to_tensor(dic) 288 | sentence = crop_sentence(gan_model.generated_data[0][0], FLAGS.end_id) 289 | sentence = tf.gather(dic, sentence) 290 | real = crop_sentence(gan_model.real_data[0], FLAGS.end_id) 291 | real = tf.gather(dic, real) 292 | train_hooks.append( 293 | tf.train.LoggingTensorHook({'fake': sentence, 'real': real}, 294 | every_n_iter=100)) 295 | tf.summary.text('fake', sentence) 296 | 297 | gen_var = tf.trainable_variables('Generator') 298 | dis_var = [] 299 | dis_var.extend(tf.trainable_variables('Discriminator/rnn')) 300 | dis_var.extend(tf.trainable_variables('Discriminator/embedding')) 301 | saver = tf.train.Saver(gen_var + dis_var) 302 | 303 | def init_fn(scaffold, session): 304 | saver.restore(session, FLAGS.sae_ckpt) 305 | pass 306 | 307 | scaffold = tf.train.Scaffold(init_fn=init_fn) 308 | 309 | return tf.estimator.EstimatorSpec( 310 | mode=mode, 311 | loss=loss, 312 | train_op=train_op, 313 | scaffold=scaffold, 314 | training_hooks=train_hooks) 315 | 316 | 317 | def batching_func(x, batch_size): 318 | return x.padded_batch( 319 | batch_size, 320 | padded_shapes=( 321 | tf.TensorShape([None]), 322 | tf.TensorShape([])), 323 | drop_remainder=True) 324 | 325 | 326 | def take(key, lk, sentence, ls): 327 | return sentence, ls 328 | 329 | 330 | def input_fn(batch_size): 331 | sentence_ds = tf.data.TFRecordDataset('data/sentence.tfrec') 332 | sentence_ds = sentence_ds.map(parse_sentence, num_parallel_calls=AUTOTUNE) 333 | sentence_ds = sentence_ds.map(take, num_parallel_calls=AUTOTUNE) 334 | sentence_ds = sentence_ds.apply(tf.contrib.data.shuffle_and_repeat(65536)) 335 | sentence_ds = batching_func(sentence_ds, batch_size) 336 | sentence_ds = sentence_ds.prefetch(AUTOTUNE) 337 | iterator = sentence_ds.make_one_shot_iterator() 338 | sentence, ls = iterator.get_next() 339 | return sentence, ls 340 | 341 | 342 | def main(_): 343 | os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' 344 | 345 | if FLAGS.multi_gpu: 346 | validate_batch_size_for_multi_gpu(FLAGS.batch_size) 347 | model_function = tf.contrib.estimator.replicate_model_fn( 348 | model_fn, 349 | loss_reduction=tf.losses.Reduction.MEAN) 350 | else: 351 | model_function = model_fn 352 | 353 | sess_config = tf.ConfigProto( 354 | 
allow_soft_placement=True, 355 | intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads, 356 | inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads, 357 | gpu_options=tf.GPUOptions(allow_growth=True)) 358 | 359 | run_config = tf.estimator.RunConfig( 360 | session_config=sess_config, 361 | save_checkpoints_steps=FLAGS.save_checkpoint_steps, 362 | save_summary_steps=FLAGS.save_summary_steps, 363 | keep_checkpoint_max=100) 364 | 365 | train_input_fn = functools.partial(input_fn, batch_size=FLAGS.batch_size) 366 | 367 | estimator = tf.estimator.Estimator( 368 | model_fn=model_function, 369 | model_dir=FLAGS.job_dir, 370 | config=run_config, 371 | params=FLAGS) 372 | 373 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.max_steps) 374 | 375 | 376 | if __name__ == '__main__': 377 | tf.app.run() 378 | -------------------------------------------------------------------------------- /initialization/sentence_infer.py: -------------------------------------------------------------------------------- 1 | """Given some object words, infers a whole sentence.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import math 8 | import sys 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | from config import TF_MODELS_PATH 14 | 15 | sys.path.append(TF_MODELS_PATH + '/research/im2txt/') 16 | from im2txt.inference_utils.caption_generator import Caption 17 | from im2txt.inference_utils.caption_generator import TopN 18 | from im2txt.inference_utils import vocabulary 19 | 20 | FLAGS = tf.flags.FLAGS 21 | 22 | tf.flags.DEFINE_string('job_dir', 'obj2sen', 'job dir') 23 | 24 | tf.flags.DEFINE_integer('emb_dim', 512, 'emb dim') 25 | 26 | tf.flags.DEFINE_integer('mem_dim', 512, 'mem dim') 27 | 28 | tf.flags.DEFINE_float('keep_prob', 0.8, 'keep prob') 29 | 30 | tf.flags.DEFINE_integer('batch_size', 1, 'batch size') 31 | 32 | tf.flags.DEFINE_string("vocab_file", "data/word_counts.txt", 33 | "Text file containing the vocabulary.") 34 | 35 | tf.flags.DEFINE_integer('beam_size', 3, 'beam size') 36 | 37 | tf.flags.DEFINE_integer('max_caption_length', 20, 'beam size') 38 | 39 | tf.flags.DEFINE_float('length_normalization_factor', 0.0, 'l n f') 40 | 41 | 42 | def _tower_fn(key, lk, is_training=False): 43 | with tf.variable_scope('Discriminator'): 44 | embedding = tf.get_variable( 45 | name='embedding', 46 | shape=[FLAGS.vocab_size, FLAGS.emb_dim], 47 | initializer=tf.random_uniform_initializer(-0.08, 0.08)) 48 | 49 | key = tf.nn.embedding_lookup(embedding, key) 50 | 51 | cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.mem_dim) 52 | if is_training: 53 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, FLAGS.keep_prob, 54 | FLAGS.keep_prob) 55 | out, initial_state = tf.nn.dynamic_rnn(cell, key, lk, dtype=tf.float32) 56 | 57 | feat = tf.nn.l2_normalize(initial_state[1], axis=1) 58 | 59 | with tf.variable_scope('Generator'): 60 | w = tf.get_variable( 61 | name='embedding', 62 | shape=[FLAGS.vocab_size, FLAGS.emb_dim], 63 | initializer=tf.random_uniform_initializer(-0.08, 0.08)) 64 | softmax_w = tf.matrix_transpose(w) 65 | softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size]) 66 | 67 | cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.mem_dim) 68 | if is_training: 69 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, FLAGS.keep_prob, 70 | FLAGS.keep_prob) 71 | zero_state = cell.zero_state(FLAGS.batch_size, tf.float32) 72 | _, state = cell(feat, zero_state) 73 | init_state = state 74 | 
tf.get_variable_scope().reuse_variables() 75 | 76 | state_feed = tf.placeholder(dtype=tf.float32, 77 | shape=[None, sum(cell.state_size)], 78 | name="state_feed") 79 | state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1) 80 | input_feed = tf.placeholder(dtype=tf.int64, 81 | shape=[None], # batch_size 82 | name="input_feed") 83 | inputs = tf.nn.embedding_lookup(embedding, input_feed) 84 | out, state_tuple = cell(inputs, state_tuple) 85 | tf.concat(axis=1, values=state_tuple, name="state") 86 | 87 | logits = tf.nn.bias_add(tf.matmul(out, softmax_w), softmax_b) 88 | tower_pred = tf.nn.softmax(logits, name="softmax") 89 | return tf.concat(init_state, axis=1, name='initial_state') 90 | 91 | 92 | class Infer: 93 | 94 | def __init__(self, job_dir=FLAGS.job_dir): 95 | key_inp = tf.placeholder(tf.int32, [None]) 96 | lk = tf.shape(key_inp)[0] 97 | key = tf.expand_dims(key_inp, axis=0) 98 | lk = tf.expand_dims(lk, axis=0) 99 | initial_state_op = _tower_fn(key, lk) 100 | 101 | vocab = vocabulary.Vocabulary(FLAGS.vocab_file) 102 | self.saver = tf.train.Saver() 103 | 104 | self.key_inp = key_inp 105 | self.init_state = initial_state_op 106 | self.vocab = vocab 107 | config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)) 108 | self.sess = tf.Session(config=config) 109 | 110 | self.restore_fn(job_dir) 111 | self.tf = tf 112 | 113 | def restore_fn(self, checkpoint_path): 114 | if tf.gfile.IsDirectory(checkpoint_path): 115 | checkpoint_path = tf.train.latest_checkpoint(checkpoint_path) 116 | if checkpoint_path: 117 | self.saver.restore(self.sess, checkpoint_path) 118 | else: 119 | self.sess.run(tf.global_variables_initializer()) 120 | 121 | def infer(self, key_words): 122 | vocab = self.vocab 123 | sess = self.sess 124 | key_inp = self.key_inp 125 | initial_state_op = self.init_state 126 | if key_words.size > 5: 127 | key_words = key_words[-5:] 128 | 129 | initial_state = sess.run(initial_state_op, feed_dict={key_inp: key_words}) 130 | 131 | initial_beam = Caption( 132 | sentence=[vocab.start_id], 133 | state=initial_state[0], 134 | logprob=0.0, 135 | score=0.0, 136 | metadata=[""]) 137 | partial_captions = TopN(FLAGS.beam_size) 138 | partial_captions.push(initial_beam) 139 | complete_captions = TopN(FLAGS.beam_size) 140 | 141 | # Run beam search. 142 | for _ in range(FLAGS.max_caption_length - 1): 143 | partial_captions_list = partial_captions.extract() 144 | partial_captions.reset() 145 | input_feed = np.array([c.sentence[-1] for c in partial_captions_list]) 146 | state_feed = np.array([c.state for c in partial_captions_list]) 147 | 148 | softmax, new_states = sess.run( 149 | fetches=["Generator/softmax:0", "Generator/state:0"], 150 | feed_dict={ 151 | "Generator/input_feed:0": input_feed, 152 | "Generator/state_feed:0": state_feed, 153 | }) 154 | metadata = None 155 | 156 | for i, partial_caption in enumerate(partial_captions_list): 157 | word_probabilities = softmax[i] 158 | state = new_states[i] 159 | # For this partial caption, get the beam_size most probable next words. 160 | words_and_probs = list(enumerate(word_probabilities)) 161 | words_and_probs.sort(key=lambda x: -x[1]) 162 | words_and_probs = words_and_probs[0:FLAGS.beam_size] 163 | # Each next word gives a new partial caption. 164 | for w, p in words_and_probs: 165 | if p < 1e-12: 166 | continue # Avoid log(0). 
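          # Extend the current hypothesis with word w; the accumulated
          # log-probability also serves as the beam score.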
167 | sentence = partial_caption.sentence + [w] 168 | logprob = partial_caption.logprob + math.log(p) 169 | score = logprob 170 | if metadata: 171 | metadata_list = partial_caption.metadata + [metadata[i]] 172 | else: 173 | metadata_list = None 174 | if w == vocab.end_id: 175 | if FLAGS.length_normalization_factor > 0: 176 | score /= len(sentence) ** FLAGS.length_normalization_factor 177 | beam = Caption(sentence, state, logprob, score, metadata_list) 178 | complete_captions.push(beam) 179 | else: 180 | beam = Caption(sentence, state, logprob, score, metadata_list) 181 | partial_captions.push(beam) 182 | if partial_captions.size() == 0: 183 | # We have run out of partial candidates; happens when beam_size = 1. 184 | break 185 | 186 | # If we have no complete captions then fall back to the partial captions. 187 | # But never output a mixture of complete and partial captions because a 188 | # partial caption could have a higher score than all the complete captions. 189 | if not complete_captions.size(): 190 | complete_captions = partial_captions 191 | 192 | captions = complete_captions.extract(sort=True) 193 | ret = [] 194 | for i, caption in enumerate(captions): 195 | # Ignore begin and end words. 196 | sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]] 197 | sentence = " ".join(sentence) 198 | # print(" %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob))) 199 | ret.append((sentence, math.exp(caption.logprob))) 200 | return ret 201 | -------------------------------------------------------------------------------- /initialization/test_obj2sen.py: -------------------------------------------------------------------------------- 1 | """Evaluate the performance on test split.""" 2 | import json 3 | import os 4 | import sys 5 | 6 | import tensorflow as tf 7 | from absl import app 8 | from absl import flags 9 | from tqdm import tqdm 10 | 11 | from config import COCO_PATH 12 | from eval_obj2sen import parse_image 13 | from sentence_infer import Infer 14 | 15 | sys.path.insert(0, COCO_PATH) 16 | from pycocotools.coco import COCO 17 | from pycocoevalcap.eval import COCOEvalCap 18 | 19 | FLAGS = flags.FLAGS 20 | 21 | 22 | def main(_): 23 | infer = Infer() 24 | 25 | with open(COCO_PATH + '/annotations/captions_val2014.json') as g: 26 | caption_data = json.load(g) 27 | name_to_id = [(x['file_name'], x['id']) for x in caption_data['images']] 28 | name_to_id = dict(name_to_id) 29 | 30 | ret = [] 31 | with tf.Graph().as_default(), tf.Session() as sess: 32 | example = tf.placeholder(tf.string, []) 33 | name_op, class_op, _ = parse_image(example) 34 | for i in tqdm(tf.io.tf_record_iterator('data/image_test.tfrec'), 35 | total=5000): 36 | name, classes = sess.run([name_op, class_op], feed_dict={example: i}) 37 | sentences = infer.infer(classes[::-1]) 38 | cur = {} 39 | cur['image_id'] = name_to_id[name] 40 | cur['caption'] = sentences[0][0] 41 | ret.append(cur) 42 | 43 | if os.path.isdir(FLAGS.job_dir): 44 | out_dir = FLAGS.job_dir 45 | else: 46 | out_dir = os.path.split(FLAGS.job_dir)[0] 47 | out = out_dir + '/test.json' 48 | with open(out, 'w') as g: 49 | json.dump(ret, g) 50 | 51 | coco = COCO(COCO_PATH + '/annotations/captions_val2014.json') 52 | cocoRes = coco.loadRes(out) 53 | 54 | # create cocoEval object by taking coco and cocoRes 55 | cocoEval = COCOEvalCap(coco, cocoRes) 56 | 57 | # evaluate on a subset of images by setting 58 | # cocoEval.params['image_id'] = cocoRes.getImgIds() 59 | # please remove this line when evaluating the full validation set 60 | 
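  # Kept intentionally here: captions are generated only for the test-split
  # images in data/image_test.tfrec, so only those image ids can be scored.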
cocoEval.params['image_id'] = cocoRes.getImgIds() 61 | 62 | # evaluate results 63 | cocoEval.evaluate() 64 | 65 | # print output evaluation scores 66 | for metric, score in cocoEval.eval.items(): 67 | print('%s: %.3f' % (metric, score)) 68 | 69 | 70 | if __name__ == '__main__': 71 | app.run(main) 72 | -------------------------------------------------------------------------------- /input_pipeline.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from misc_fn import controlled_shuffle 3 | from misc_fn import random_drop 4 | 5 | FLAGS = tf.flags.FLAGS 6 | 7 | AUTOTUNE = tf.data.experimental.AUTOTUNE 8 | 9 | 10 | def batching_func(x, batch_size): 11 | """Forms a batch with dynamic padding.""" 12 | return x.padded_batch( 13 | batch_size, 14 | padded_shapes=(( 15 | tf.TensorShape([299, 299, 3]), 16 | tf.TensorShape([None]), 17 | tf.TensorShape([None]), 18 | tf.TensorShape([])), 19 | (tf.TensorShape([None]), 20 | tf.TensorShape([]), 21 | tf.TensorShape([None]), 22 | tf.TensorShape([]))), 23 | drop_remainder=True) 24 | 25 | 26 | def preprocess_image(encoded_image, classes, scores): 27 | """Decodes an image.""" 28 | image = tf.image.decode_jpeg(encoded_image, 3) 29 | image = tf.image.convert_image_dtype(image, tf.float32) 30 | image = tf.image.resize_images(image, [346, 346]) 31 | image = tf.random_crop(image, [299, 299, 3]) 32 | image = image * 2 - 1 33 | return image, classes, scores, tf.shape(classes)[0] 34 | 35 | 36 | def parse_image(serialized): 37 | """Parses a tensorflow.SequenceExample into an image and detected objects. 38 | 39 | Args: 40 | serialized: A scalar string Tensor; a single serialized SequenceExample. 41 | 42 | Returns: 43 | encoded_image: A scalar string Tensor containing a JPEG encoded image. 44 | classes: A 1-D int64 Tensor containing the detected objects. 45 | scores: A 1-D float32 Tensor containing the detection scores. 46 | """ 47 | context, sequence = tf.parse_single_sequence_example( 48 | serialized, 49 | context_features={ 50 | 'image/data': tf.FixedLenFeature([], dtype=tf.string) 51 | }, 52 | sequence_features={ 53 | 'classes': tf.FixedLenSequenceFeature([], dtype=tf.int64), 54 | 'scores': tf.FixedLenSequenceFeature([], dtype=tf.float32), 55 | }) 56 | 57 | encoded_image = context['image/data'] 58 | classes = tf.to_int32(sequence['classes']) 59 | scores = sequence['scores'] 60 | return encoded_image, classes, scores 61 | 62 | 63 | def parse_sentence(serialized): 64 | """Parses a tensorflow.SequenceExample into an caption. 65 | 66 | Args: 67 | serialized: A scalar string Tensor; a single serialized SequenceExample. 68 | 69 | Returns: 70 | key: The keywords in a sentence. 71 | num_key: The number of keywords. 72 | sentence: A description. 73 | sentence_length: The length of the description. 
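  Note that key is a corrupted copy of the sentence: the interior words are
  shuffled within a small window (controlled_shuffle, following
  arXiv:1711.00043) and randomly dropped (random_drop), with the end token
  appended. This noisy version is what the denoising models, e.g.
  initialization/sentence_ae.py, learn to reconstruct.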
74 | """ 75 | context, sequence = tf.parse_single_sequence_example( 76 | serialized, 77 | context_features={}, 78 | sequence_features={ 79 | 'sentence': tf.FixedLenSequenceFeature([], dtype=tf.int64), 80 | }) 81 | sentence = tf.to_int32(sequence['sentence']) 82 | key = controlled_shuffle(sentence[1:-1]) 83 | key = random_drop(key) 84 | key = tf.concat([key, [FLAGS.end_id]], axis=0) 85 | return key, tf.shape(key)[0], sentence, tf.shape(sentence)[0] 86 | 87 | 88 | def input_fn(batch_size): 89 | """Input function.""" 90 | image_ds = tf.data.TFRecordDataset('data/image_train.tfrec') 91 | image_ds = image_ds.map(parse_image, num_parallel_calls=AUTOTUNE) 92 | image_ds = image_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE) 93 | image_ds = image_ds.shuffle(8192).repeat() 94 | 95 | sentence_ds = tf.data.TFRecordDataset('data/sentence.tfrec') 96 | sentence_ds = sentence_ds.map(parse_sentence, num_parallel_calls=AUTOTUNE) 97 | sentence_ds = sentence_ds.shuffle(65536).repeat() 98 | 99 | dataset = tf.data.Dataset.zip((image_ds, sentence_ds)) 100 | 101 | dataset = batching_func(dataset, batch_size) 102 | dataset = dataset.prefetch(AUTOTUNE) 103 | iterator = dataset.make_one_shot_iterator() 104 | image, sentence = iterator.get_next() 105 | im, classes, scores, num = image 106 | key, lk, sentence, ls = sentence 107 | return {'im': im, 'classes': classes, 'scores': scores, 'num': num, 108 | 'key': key, 'lk': lk}, {'sentence': sentence, 'len': ls} 109 | -------------------------------------------------------------------------------- /misc_fn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def get_len(sequence, end): 5 | """Gets the length of a generated caption. 6 | 7 | Args: 8 | sequence: A tensor of size [batch, max_length]. 9 | end: The token. 10 | 11 | Returns: 12 | length: The length of each caption. 13 | """ 14 | 15 | def body(x): 16 | idx = tf.to_int32(tf.where(tf.equal(x, end))) 17 | idx = tf.cond(tf.shape(idx)[0] > 0, lambda: idx[0] + 1, lambda: tf.shape(x)) 18 | return idx[0] 19 | 20 | length = tf.map_fn(body, sequence, tf.int32) 21 | return length 22 | 23 | 24 | def variable_summaries(var, mask, name): 25 | """Attaches a lot of summaries to a Tensor. 26 | 27 | Args: 28 | var: A tensor to summary. 29 | mask: The mask indicating the valid elements in var. 30 | name: The name of the tensor in summary. 31 | """ 32 | var = tf.boolean_mask(var, mask) 33 | mean = tf.reduce_mean(var) 34 | tf.summary.scalar('mean/' + name, mean) 35 | with tf.name_scope('stddev'): 36 | stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean))) 37 | tf.summary.scalar('sttdev/' + name, stddev) 38 | tf.summary.scalar('max/' + name, tf.reduce_max(var)) 39 | tf.summary.scalar('min/' + name, tf.reduce_min(var)) 40 | tf.summary.histogram(name, var) 41 | 42 | 43 | def transform_grads_fn(grads): 44 | """Gradient clip.""" 45 | grads, vars = zip(*grads) 46 | grads, _ = tf.clip_by_global_norm(grads, 10) 47 | return list(zip(grads, vars)) 48 | 49 | 50 | def crop_sentence(sentence, end): 51 | """Sentence cropping for logging. Remove the tokens after .""" 52 | idx = tf.to_int32(tf.where(tf.equal(sentence, end))) 53 | idx = tf.cond(tf.shape(idx)[0] > 0, lambda: idx[0] + 1, 54 | lambda: tf.shape(sentence)) 55 | sentence = sentence[:idx[0]] 56 | return sentence 57 | 58 | 59 | def validate_batch_size_for_multi_gpu(batch_size): 60 | """For multi-gpu, batch-size must be a multiple of the number of GPUs. 
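  For example, a batch size of 512 passes the check with 1, 2, 4 or 8 visible GPUs, while a batch size of 510 on 4 GPUs raises a ValueError suggesting --batch_size=508.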
61 | 62 | Note that this should eventually be handled by replicate_model_fn 63 | directly. Multi-GPU support is currently experimental, however, 64 | so doing the work here until that feature is in place. 65 | 66 | Args: 67 | batch_size: the number of examples processed in each training batch. 68 | 69 | Raises: 70 | ValueError: if no GPUs are found, or selected batch_size is invalid. 71 | """ 72 | from tensorflow.python.client import \ 73 | device_lib # pylint: disable=g-import-not-at-top 74 | 75 | local_device_protos = device_lib.list_local_devices() 76 | num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU']) 77 | if not num_gpus: 78 | raise ValueError('Multi-GPU mode was specified, but no GPUs ' 79 | 'were found. To use CPU, run without --multi_gpu.') 80 | 81 | remainder = batch_size % num_gpus 82 | if remainder: 83 | err = ('When running with multiple GPUs, batch size ' 84 | 'must be a multiple of the number of available GPUs. ' 85 | 'Found {} GPUs with a batch size of {}; try --batch_size={} instead.' 86 | ).format(num_gpus, batch_size, batch_size - remainder) 87 | raise ValueError(err) 88 | 89 | 90 | def find_obj(sentence, s_mask, classes, scores, num): 91 | """Computes the object reward for one sentence.""" 92 | shape = tf.shape(sentence) 93 | sentence = tf.boolean_mask(sentence, s_mask) 94 | 95 | def body(x): 96 | idx = tf.to_int32(tf.where(tf.equal(sentence, x))) 97 | idx = tf.cond(tf.shape(idx)[0] > 0, lambda: idx[0, 0], 98 | lambda: tf.constant(999, tf.int32)) 99 | return idx 100 | 101 | classes = classes[:num] 102 | scores = scores[:num] 103 | ind = tf.map_fn(body, classes, tf.int32) 104 | mask = tf.not_equal(ind, 999) 105 | miss, detected = tf.dynamic_partition(scores, tf.to_int32(mask), 2) 106 | ind = tf.boolean_mask(ind, mask) 107 | ret = tf.scatter_nd(tf.expand_dims(ind, 1), detected, shape) 108 | return ret 109 | 110 | 111 | def obj_rewards(sentence, mask, classes, scores, num): 112 | """Computes the object reward. 113 | 114 | Args: 115 | sentence: A tensor of size [batch, max_length]. 116 | mask: The mask indicating the valid elements in sentence. 117 | classes: [batch, padded_size] int32 tensor of detected objects. 118 | scores: [batch, padded_size] float32 tensor of detection scores. 119 | num: [batch] int32 tensor of number of detections. 120 | 121 | Returns: 122 | rewards: [batch, max_length] float32 tensor of rewards. 
123 | """ 124 | 125 | def body(x): 126 | ret = find_obj(x[0], x[1], x[2], x[3], x[4]) 127 | return ret 128 | 129 | rewards = tf.map_fn(body, [sentence, mask, classes, scores, num], tf.float32) 130 | return rewards 131 | 132 | 133 | def random_drop(sentence): 134 | """Randomly drops some tokens.""" 135 | length = tf.shape(sentence)[0] 136 | rnd = tf.random_uniform([length]) + 0.9 137 | mask = tf.cast(tf.floor(rnd), tf.bool) 138 | sentence = tf.boolean_mask(sentence, mask) 139 | return sentence 140 | 141 | 142 | def controlled_shuffle(sentence, d=3.0): 143 | """Shuffles the sentence as described in https://arxiv.org/abs/1711.00043""" 144 | length = tf.shape(sentence)[0] 145 | rnd = tf.random_uniform([length]) * (d + 1) + tf.to_float(tf.range(length)) 146 | _, idx = tf.nn.top_k(rnd, length) 147 | idx = tf.reverse(idx, axis=[0]) 148 | sentence = tf.gather(sentence, idx) 149 | return sentence 150 | 151 | 152 | def _int64_feature(value): 153 | """Wrapper for inserting an int64 Feature into a SequenceExample proto.""" 154 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 155 | 156 | 157 | def _int64_feature_list(values): 158 | """Wrapper for inserting an int64 FeatureList into a SequenceExample proto.""" 159 | return tf.train.FeatureList(feature=[_int64_feature(v) for v in values]) 160 | 161 | 162 | def _float_feature(value): 163 | """Wrapper for inserting an float Feature into a SequenceExample proto.""" 164 | return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 165 | 166 | 167 | def _float_feature_list(values): 168 | """Wrapper for inserting an float FeatureList into a SequenceExample proto.""" 169 | return tf.train.FeatureList(feature=[_float_feature(v) for v in values]) 170 | 171 | 172 | def _bytes_feature(value): 173 | """Wrapper for inserting a bytes Feature into a SequenceExample proto.""" 174 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[str(value)])) 175 | -------------------------------------------------------------------------------- /preprocessing/crawl_descriptions.py: -------------------------------------------------------------------------------- 1 | """Crawl Shutterstock image descriptions.""" 2 | import json 3 | import os 4 | import re 5 | import time 6 | import urllib.request 7 | from multiprocessing import Pool 8 | 9 | from absl import app 10 | from absl import flags 11 | 12 | flags.DEFINE_string('data_dir', 'data/coco', 'data directory') 13 | 14 | flags.DEFINE_integer('num_pages', 1000, 'number of images') 15 | 16 | flags.DEFINE_integer('num_processes', 16, 'number of processes') 17 | 18 | FLAGS = flags.FLAGS 19 | 20 | url = ('https://www.shutterstock.com/search?language=en&image_type=photo&' 21 | 'searchterm=%s&page=%d') 22 | pattern = '' 23 | 24 | 25 | class Downloader(object): 26 | 27 | def __init__(self, label): 28 | self.label = label 29 | 30 | def __call__(self, page_id): 31 | attempt = 0 32 | while attempt < 5: 33 | req = urllib.request.Request(url % (self.label, page_id), 34 | headers={'User-Agent': "Magic Browser"}) 35 | with urllib.request.urlopen(req) as f: 36 | page = f.read() 37 | page = page.decode('utf-8') 38 | obj = re.search(pattern, page) 39 | if obj is None: 40 | time.sleep(5) 41 | attempt += 1 42 | else: 43 | break 44 | if obj is None: 45 | images = [] 46 | else: 47 | ret = obj.group(1) 48 | images = eval(ret) 49 | return page_id, images 50 | 51 | 52 | def get_num_pages(label): 53 | req = urllib.request.Request( 54 | url % (label, 1), 55 | headers={'User-Agent': 'Magic Browser'}) 56 | with 
urllib.request.urlopen(req) as f: 57 | page = f.read() 58 | page = page.decode('utf-8') 59 | obj = re.search('data-max="(\d*)"', page) 60 | num_pages = int(obj.group(1)) 61 | return num_pages 62 | 63 | 64 | def download(data_dir, num_pages, id, label): 65 | output = data_dir + '/%04d.json' % id 66 | if os.path.exists(output): 67 | with open(output, 'r') as f: 68 | images = json.load(f) 69 | print(label, len(images)) 70 | # print empty pages 71 | page_nums = [int(k) for k, v in images.items() if len(v) == 0] 72 | page_nums.sort() 73 | print(len(page_nums), page_nums) 74 | else: 75 | images = {} 76 | all_pages = get_num_pages(label) 77 | print(label, all_pages, 'pages available.') 78 | page_nums = list(range(1, min(num_pages, all_pages) + 1)) 79 | 80 | pool = Pool(FLAGS.num_processes) 81 | pages = pool.map(Downloader(label), page_nums) 82 | pages = [(str(i[0]), i[1]) for i in pages] 83 | pool.close() 84 | pool.join() 85 | if len(pages) > 0: 86 | images.update(dict(pages)) 87 | with open(output, 'w') as f: 88 | json.dump(images, f) 89 | 90 | 91 | def main(_): 92 | with open(FLAGS.data_dir + '/coco.names', 'r') as f: 93 | classes = list(f) 94 | classes = [i.strip().replace(' ', '+') for i in classes] 95 | for i, c in enumerate(classes): 96 | download(FLAGS.data_dir, FLAGS.num_pages, i, c) 97 | 98 | 99 | if __name__ == '__main__': 100 | app.run(main) 101 | -------------------------------------------------------------------------------- /preprocessing/detect_objects.py: -------------------------------------------------------------------------------- 1 | """Detect objects using a model pretrained on OpenImage.""" 2 | import multiprocessing 3 | import os 4 | 5 | import h5py 6 | import numpy as np 7 | from PIL import Image 8 | from absl import app 9 | from absl import flags 10 | from tqdm import tqdm 11 | 12 | from config import TF_MODELS_PATH 13 | 14 | flags.DEFINE_string('image_path', None, 'data dir') 15 | 16 | flags.DEFINE_integer('num_proc', 1, 'number of process') 17 | 18 | flags.DEFINE_integer('num_gpus', 4, 'number of gpus to use') 19 | 20 | FLAGS = flags.FLAGS 21 | 22 | 23 | def load_image_into_numpy_array(image): 24 | if image.mode != 'RGB': 25 | image = image.convert('RGB') 26 | (im_width, im_height) = image.size 27 | return np.array(image.getdata()).reshape( 28 | (im_height, im_width, 3)).astype(np.uint8) 29 | 30 | 31 | def initializer(): 32 | import tensorflow as tf 33 | current = multiprocessing.current_process() 34 | id = current._identity[0] - 1 35 | os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % (id % FLAGS.num_gpus) 36 | 37 | model_name = 'faster_rcnn_inception_resnet_v2_atrous_oid_2018_01_28' 38 | path_to_ckpt = (TF_MODELS_PATH + '/research/object_detection/' + model_name 39 | + '/frozen_inference_graph.pb') 40 | 41 | detection_graph = tf.Graph() 42 | with detection_graph.as_default(): 43 | od_graph_def = tf.GraphDef() 44 | with tf.gfile.GFile(path_to_ckpt, 'rb') as fid: 45 | od_graph_def.ParseFromString(fid.read()) 46 | tf.import_graph_def(od_graph_def, name='') 47 | 48 | global sess, tensor_dict, image_tensor 49 | with detection_graph.as_default(): 50 | sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions( 51 | allow_growth=True))) 52 | ops = tf.get_default_graph().get_operations() 53 | all_tensor_names = {output.name for op in ops for output in op.outputs} 54 | tensor_dict = {} 55 | for key in [ 56 | 'num_detections', 'detection_boxes', 'detection_scores', 57 | 'detection_classes', 'detection_masks' 58 | ]: 59 | tensor_name = key + ':0' 60 | if tensor_name in 
all_tensor_names: 61 | tensor_dict[key] = tf.get_default_graph().get_tensor_by_name( 62 | tensor_name) 63 | image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0') 64 | 65 | 66 | def run(i): 67 | global sess, tensor_dict, image_tensor 68 | image_path = FLAGS.image_path + '/' + i.strip() 69 | image = Image.open(image_path) 70 | image = load_image_into_numpy_array(image) 71 | image_np_expanded = np.expand_dims(image, axis=0) 72 | output_dict = sess.run(tensor_dict, 73 | feed_dict={image_tensor: image_np_expanded}) 74 | return i, output_dict 75 | 76 | 77 | def main(_): 78 | pool = multiprocessing.Pool(FLAGS.num_proc, initializer) 79 | with open('data/coco_train.txt', 'r') as f: 80 | train_images = list(f) 81 | with open('data/coco_val.txt', 'r') as f: 82 | val_images = list(f) 83 | with open('data/coco_test.txt', 'r') as f: 84 | test_images = list(f) 85 | all_images = train_images + val_images + test_images 86 | with h5py.File('data/object.hdf5', 'w') as f: 87 | for ret in tqdm(pool.imap_unordered(run, all_images), 88 | total=len(all_images)): 89 | name = os.path.splitext(ret[0])[0] 90 | g = f.create_group(name) 91 | output_dict = ret[1] 92 | n = int(output_dict['num_detections']) 93 | del output_dict['num_detections'] 94 | for k, v in output_dict.items(): 95 | g.create_dataset(k, data=v[0, :n]) 96 | 97 | 98 | if __name__ == '__main__': 99 | app.run(main) 100 | -------------------------------------------------------------------------------- /preprocessing/extract_descriptions.py: -------------------------------------------------------------------------------- 1 | """Extract image descriptions from the downloaded files.""" 2 | import cPickle as pkl 3 | import glob 4 | import json 5 | import sys 6 | from multiprocessing import Pool 7 | from unicodedata import normalize 8 | 9 | from absl import app 10 | from absl import flags 11 | from tqdm import tqdm 12 | 13 | from config import TF_MODELS_PATH 14 | 15 | sys.path.insert(0, TF_MODELS_PATH + '/research/im2txt/im2txt') 16 | from data.build_mscoco_data import _process_caption 17 | 18 | flags.DEFINE_string('data_dir', 'data/coco', 'data directory') 19 | 20 | FLAGS = flags.FLAGS 21 | 22 | 23 | def main(_): 24 | s = set() 25 | files = glob.glob(FLAGS.data_dir + '/*.json') 26 | files.sort() 27 | for i in tqdm(files): 28 | with open(i, 'r') as g: 29 | data = json.load(g) 30 | for k, v in data.items(): 31 | for j in v: 32 | if 'description' in j: 33 | c = normalize('NFKD', j['description']).encode('ascii', 'ignore') 34 | c = c.split('\n') 35 | s.update(c) 36 | 37 | pool = Pool() 38 | captions = pool.map(_process_caption, list(s)) 39 | pool.close() 40 | pool.join() 41 | # There is a sos and eos in each caption, so the actual length is at least 8. 
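# _process_caption (imported above from im2txt's build_mscoco_data) lower-cases
# each description, tokenizes it with NLTK and wraps it in start/end words, so
# the len(i) >= 10 filter below keeps captions with at least 8 real tokens.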
42 | captions = [i for i in captions if len(i) >= 10] 43 | print('%s captions parsed' % len(captions)) 44 | with open('data/sentences.pkl', 'w') as f: 45 | pkl.dump(captions, f) 46 | 47 | 48 | if __name__ == '__main__': 49 | app.run(main) 50 | -------------------------------------------------------------------------------- /preprocessing/process_descriptions.py: -------------------------------------------------------------------------------- 1 | """Convert the descriptions to tfrecords.""" 2 | import cPickle as pkl 3 | import json 4 | import os 5 | import random 6 | import re 7 | import sys 8 | from urllib2 import Request 9 | from urllib2 import urlopen 10 | 11 | import tensorflow as tf 12 | from absl import app 13 | from absl import flags 14 | from tqdm import tqdm 15 | 16 | from config import TF_MODELS_PATH 17 | from misc_fn import _int64_feature_list 18 | 19 | sys.path.insert(0, TF_MODELS_PATH + '/research/im2txt/im2txt') 20 | sys.path.append(TF_MODELS_PATH + '/research') 21 | sys.path.append(TF_MODELS_PATH + '/research/object_detection') 22 | from data.build_mscoco_data import _create_vocab 23 | from inference_utils import vocabulary 24 | from utils import label_map_util 25 | 26 | tf.enable_eager_execution() 27 | 28 | flags.DEFINE_bool('new_dict', False, 'generate a new dict') 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | def get_plural(word): 34 | c = re.compile('Noun

\(.*plural ([^\)]+)\)') 35 | req = Request('https://www.yourdictionary.com/' + word, headers={ 36 | 'User-Agent': 'Magic Browser'}) 37 | f = urlopen(req) 38 | html = f.read() 39 | f.close() 40 | html = html.decode('utf-8') 41 | plural_word = c.findall(html) 42 | if plural_word: 43 | plural_word = plural_word[0] 44 | plural_word = plural_word.lower() 45 | elif 'Noun

(plural only)' in html: 46 | plural_word = word 47 | else: 48 | plural_word = word 49 | if word[-1] != 's': 50 | plural_word += 's' 51 | return plural_word 52 | 53 | 54 | def get_open_image_categories(): 55 | path_to_labels = (TF_MODELS_PATH + '/research/object_detection/data/' 56 | 'oid_bbox_trainable_label_map.pbtxt') 57 | category_index = label_map_util.create_category_index_from_labelmap( 58 | path_to_labels, 59 | use_display_name=True) 60 | categories = dict([(v['id'], str(v['name'].lower()).split()[-1]) for k, v in 61 | category_index.items()]) 62 | category_name = list(set(categories.values())) 63 | category_name.sort() 64 | plural_file = 'data/plural_words.json' 65 | if os.path.exists(plural_file): 66 | with open(plural_file, 'r') as f: 67 | plural_dict = json.load(f) 68 | plural_name = [plural_dict[i] for i in category_name] 69 | else: 70 | plural_name = [] 71 | for i in tqdm(category_name): 72 | plural_name.append(get_plural(i)) 73 | with open(plural_file, 'w') as f: 74 | json.dump(dict(zip(category_name, plural_name)), f) 75 | return category_name, plural_name, categories 76 | 77 | 78 | def parse_key_words(caption, dic): 79 | key_words = dic.intersection(caption) 80 | return key_words 81 | 82 | 83 | def sentence_generator(): 84 | category_name, plural_name, categories = get_open_image_categories() 85 | replace = dict(zip(plural_name, category_name)) 86 | category_set = set(category_name) 87 | 88 | with open('data/sentences.pkl', 'r') as f: 89 | captions = pkl.load(f) 90 | 91 | if FLAGS.new_dict: 92 | _create_vocab(captions) 93 | with open('data/glove_vocab.pkl', 'r') as f: 94 | glove = pkl.load(f) 95 | glove.append('') 96 | glove.append('') 97 | glove = set(glove) 98 | with open(FLAGS.word_counts_output_file, 'r') as f: 99 | vocab = list(f) 100 | vocab = [i.strip() for i in vocab] 101 | vocab = [i.split() for i in vocab] 102 | vocab = [(i, int(j)) for i, j in vocab if i in glove] 103 | word_counts = [i for i in vocab if i[0] in category_set or i[1] >= 40] 104 | words = set([i[0] for i in word_counts]) 105 | for i in category_name: 106 | if i not in words: 107 | word_counts.append((i, 0)) 108 | with open(FLAGS.word_counts_output_file, 'w') as f: 109 | f.write('\n'.join(['%s %d' % (w, c) for w, c in word_counts])) 110 | 111 | vocab = vocabulary.Vocabulary(FLAGS.word_counts_output_file) 112 | 113 | all_ids = dict([(k, vocab.word_to_id(v)) for k, v in categories.items()]) 114 | with open('data/all_ids.pkl', 'w') as f: 115 | pkl.dump(all_ids, f) 116 | 117 | context = tf.train.Features() 118 | random.shuffle(captions) 119 | for c in captions: 120 | for i, w in enumerate(c): 121 | if w in replace: 122 | c[i] = replace[w] 123 | k = parse_key_words(c, category_set) 124 | c = [vocab.word_to_id(word) for word in c] 125 | if c.count(vocab.unk_id) > len(c) * 0.15: 126 | continue 127 | k = [vocab.word_to_id(i) for i in k] 128 | feature_lists = tf.train.FeatureLists(feature_list={ 129 | 'key': _int64_feature_list(k), 130 | 'sentence': _int64_feature_list(c) 131 | }) 132 | sequence_example = tf.train.SequenceExample( 133 | context=context, feature_lists=feature_lists) 134 | yield sequence_example.SerializeToString() 135 | 136 | 137 | def main(_): 138 | ds = tf.data.Dataset.from_generator(sentence_generator, 139 | output_types=tf.string, output_shapes=()) 140 | tfrec = tf.data.experimental.TFRecordWriter('data/sentence.tfrec') 141 | tfrec.write(ds) 142 | 143 | 144 | if __name__ == '__main__': 145 | app.run(main) 146 | 
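The sentence records written by the script above can be spot-checked before training. Below is a minimal sketch (not part of the repository) that decodes a few of the serialized SequenceExamples produced by `sentence_generator()` and prints the token ids stored in the `key` and `sentence` feature lists:

```
import tensorflow as tf

tf.enable_eager_execution()  # TF 1.13-style eager mode, as used elsewhere in preprocessing

# Read a few serialized SequenceExamples from the file written by
# process_descriptions.py and unpack the 'key' and 'sentence' int64 lists.
for serialized in tf.data.TFRecordDataset('data/sentence.tfrec').take(3):
  ex = tf.train.SequenceExample()
  ex.ParseFromString(serialized.numpy())
  key = [f.int64_list.value[0] for f in ex.feature_lists.feature_list['key'].feature]
  sentence = [f.int64_list.value[0] for f in ex.feature_lists.feature_list['sentence'].feature]
  print('key ids:', key)
  print('sentence ids:', sentence)
```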
-------------------------------------------------------------------------------- /preprocessing/process_images.py: -------------------------------------------------------------------------------- 1 | """Convert the images and the detected object labels to tfrecords.""" 2 | import os 3 | import pickle as pkl 4 | import random 5 | 6 | import h5py 7 | import numpy as np 8 | import tensorflow as tf 9 | from absl import app 10 | from absl import flags 11 | from functools import partial 12 | 13 | from misc_fn import _bytes_feature 14 | from misc_fn import _float_feature_list 15 | from misc_fn import _int64_feature_list 16 | 17 | tf.enable_eager_execution() 18 | 19 | flags.DEFINE_string('image_path', None, 'Path to all coco images.') 20 | 21 | FLAGS = flags.FLAGS 22 | 23 | 24 | def image_generator(split): 25 | with open('data/coco_%s.txt' % split, 'r') as f: 26 | filename = list(f) 27 | filename = [i.strip() for i in filename] 28 | if split == 'train': 29 | random.shuffle(filename) 30 | with open('data/all_ids.pkl', 'r') as f: 31 | all_ids = pkl.load(f) 32 | with h5py.File('data/object.hdf5', 'r') as f: 33 | for i in filename: 34 | name = os.path.splitext(i)[0] 35 | detection_classes = f[name + '/detection_classes'][:].astype(np.int32) 36 | detection_scores = f[name + '/detection_scores'][:] 37 | detection_classes, ind = np.unique(detection_classes, return_index=True) 38 | detection_scores = detection_scores[ind] 39 | detection_classes = [all_ids[j] for j in detection_classes] 40 | image_path = FLAGS.image_path + '/' + i 41 | with tf.gfile.FastGFile(image_path, 'r') as g: 42 | image = g.read() 43 | context = tf.train.Features(feature={ 44 | 'image/name': _bytes_feature(i), 45 | 'image/data': _bytes_feature(image), 46 | }) 47 | feature_lists = tf.train.FeatureLists(feature_list={ 48 | 'classes': _int64_feature_list(detection_classes), 49 | 'scores': _float_feature_list(detection_scores) 50 | }) 51 | sequence_example = tf.train.SequenceExample( 52 | context=context, feature_lists=feature_lists) 53 | 54 | yield sequence_example.SerializeToString() 55 | 56 | 57 | def gen_tfrec(split): 58 | ds = tf.data.Dataset.from_generator(partial(image_generator, split=split), 59 | output_types=tf.string, output_shapes=()) 60 | tfrec = tf.data.experimental.TFRecordWriter('data/image_%s.tfrec' % split) 61 | tfrec.write(ds) 62 | 63 | 64 | def main(_): 65 | for i in ['train', 'val', 'test']: 66 | gen_tfrec(i) 67 | 68 | 69 | if __name__ == '__main__': 70 | app.run(main) 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jsbeautifier 2 | matplotlib 3 | nltk 4 | scikit-image 5 | tensorflow-gpu==1.13.1 6 | tqdm 7 | -------------------------------------------------------------------------------- /test_model.py: -------------------------------------------------------------------------------- 1 | """Evaluate the performance on test split.""" 2 | import json 3 | import os 4 | import sys 5 | 6 | import cv2 7 | from absl import app 8 | from absl import flags 9 | from tqdm import tqdm 10 | 11 | from caption_infer import Infer 12 | from config import COCO_PATH 13 | 14 | sys.path.insert(0, COCO_PATH) 15 | from pycocotools.coco import COCO 16 | from pycocoevalcap.eval import COCOEvalCap 17 | 18 | flags.DEFINE_bool('vis', False, 'visulaize') 19 | 20 | FLAGS = flags.FLAGS 21 | 22 | 23 | def main(_): 24 | infer = Infer() 25 | 26 | with open(COCO_PATH + '/annotations/captions_val2014.json') as g: 
27 | caption_data = json.load(g) 28 | name_to_id = [(x['file_name'], x['id']) for x in caption_data['images']] 29 | name_to_id = dict(name_to_id) 30 | 31 | with open('data/coco_test.txt', 'r') as g: 32 | ret = [] 33 | for name in tqdm(g, total=5000): 34 | name = name.strip() 35 | sentences = infer.infer(name) 36 | cur = {} 37 | cur['image_id'] = name_to_id[name] 38 | cur['caption'] = sentences[0][0] 39 | ret.append(cur) 40 | if FLAGS.vis: 41 | im = cv2.imread(FLAGS.data_dir + name) 42 | print(sentences[0][0]) 43 | cv2.imshow('a', im) 44 | k = cv2.waitKey() 45 | if k & 0xff == 27: 46 | return 47 | 48 | if os.path.isdir(FLAGS.job_dir): 49 | out_dir = FLAGS.job_dir 50 | else: 51 | out_dir = os.path.split(FLAGS.job_dir)[0] 52 | out = out_dir + '/test.json' 53 | with open(out, 'w') as g: 54 | json.dump(ret, g) 55 | 56 | coco = COCO(COCO_PATH + '/annotations/captions_val2014.json') 57 | cocoRes = coco.loadRes(out) 58 | 59 | # create cocoEval object by taking coco and cocoRes 60 | cocoEval = COCOEvalCap(coco, cocoRes) 61 | 62 | # evaluate on a subset of images by setting 63 | # cocoEval.params['image_id'] = cocoRes.getImgIds() 64 | # please remove this line when evaluating the full validation set 65 | cocoEval.params['image_id'] = cocoRes.getImgIds() 66 | 67 | # evaluate results 68 | cocoEval.evaluate() 69 | 70 | # print output evaluation scores 71 | for metric, score in cocoEval.eval.items(): 72 | print('%s: %.3f' % (metric, score)) 73 | 74 | 75 | if __name__ == '__main__': 76 | app.run(main) 77 | --------------------------------------------------------------------------------
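As a quick, self-contained illustration of the sentence utilities defined in misc_fn.py, the snippet below (an editorial sketch, not a file in the repository; run from the repository root, and note that `end_id = 2` is a made-up value whose real counterpart lives in config.py) shows how `get_len`/`crop_sentence` treat the end token and how `controlled_shuffle`/`random_drop` produce the noisy `key` input used by `parse_sentence()` in input_pipeline.py:

```
import tensorflow as tf
from misc_fn import controlled_shuffle, crop_sentence, get_len, random_drop

tf.enable_eager_execution()

end_id = 2  # assumed end-of-sentence id; the real value is defined in config.py

# get_len/crop_sentence keep everything up to and including the first end token.
batch = tf.constant([[5, 7, end_id, 0, 0],
                     [4, 4, 4, 4, 4]], dtype=tf.int32)
print(get_len(batch, end_id).numpy())           # -> [3 5]
print(crop_sentence(batch[0], end_id).numpy())  # -> [5 7 2]

# controlled_shuffle displaces each token by at most d positions
# (following arXiv:1711.00043) and random_drop keeps each token with
# probability ~0.9; both are applied to sentences during training.
tokens = tf.range(10, 18)
print(controlled_shuffle(tokens, d=3.0).numpy())
print(random_drop(tokens).numpy())
```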