├── .gitignore
├── LICENSE
├── README.md
├── analyse_communication.ipynb
├── binary_vectors.py
├── misc.py
├── model.py
├── requirements.txt
├── sparks.py
└── utils
    ├── build_datasets.sh
    ├── descriptions.csv
    ├── descriptions_mammals.csv
    ├── download_data.py
    ├── imagenet.synset
    ├── imgs
    │   └── .gitkeep
    └── package_data.py
/.gitignore:
--------------------------------------------------------------------------------
1 | utils/downloaded
2 | *.hdf5
3 | bin_vec
4 | conf_mat
5 | logs
6 | *.pyc
7 | urls
8 | .ipynb_checkpoints
9 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Copyright (c) 2017, New York University (Kyunghyun Cho)
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | * Redistributions of source code must retain the above copyright notice, this
10 |   list of conditions and the following disclaimer.
11 | 
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 |   this list of conditions and the following disclaimer in the documentation
14 |   and/or other materials provided with the distribution.
15 | 
16 | * Neither the name of the copyright holder nor the names of its
17 |   contributors may be used to endorse or promote products derived from
18 |   this software without specific prior written permission.
19 | 
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MultimodalGame
2 | 
3 | Source code for [Emergent Communication in a Multi-Modal, Multi-Step Referential Game](https://arxiv.org/abs/1705.10369).
4 | 
5 | ## Dependencies
6 | 
7 | - Python 2.7
8 | - PyTorch
9 | 
10 | You should install PyTorch using the instructions from [here](http://pytorch.org/). The remaining dependencies can be installed with pip: `pip install -r requirements.txt`
11 | 
12 | ## Building the Datasets
13 | 
14 | This model requires an HDF5 file containing image features and a CSV file containing class descriptions. To build such a dataset using images from ImageNet, simply run the following script:
15 | 
16 | ```
17 | cd ./utils
18 | bash build_datasets.sh
19 | ```
20 | 
21 | This will download image URLs from ImageNet (~300 MB compressed), save the URLs for 30 classes, split them into train/dev/test, download the relevant images, extract the necessary features using a pretrained ResNet-34, and build the descriptions file (an example of its format is shown below).
22 | 
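Each row of the descriptions file has the form `label_id,label,description`, where `label_id` must match the "Target" attribute stored in the HDF5 files (see the notes at the top of `misc.py`, from which this example is taken):

```
3,aardvark,nocturnal burrowing mammal of the grasslands of Africa that feeds on termites; sole extant representative of the order Tubulidentata
11,armadillo,burrowing chiefly nocturnal mammal with body covered with strong horny plates
```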
23 | This model also depends on pretrained word embeddings. We recommend using the `6B.100d` GloVe embeddings available [here](https://nlp.stanford.edu/projects/glove/).
24 | 
25 | ## Running the Code
26 | 
27 | Here is an example command for running the agents in an "Adaptive" setting, where the Receiver has the option to terminate the conversation and make a prediction before the maximum number of exchange steps has been reached.
28 | 
29 | ```
30 | python model.py \
31 | -experiment_name demo \ # used to save various log files
32 | -exchange_samples 5 \ # print samples of the communication
33 | -model_type Adaptive \ # the receiver will determine when to stop the conversation
34 | -max_exchange 10 \ # max number of exchange steps in the agents' conversation
35 | -batch_size 64 \
36 | -rec_w_dim 32 \ # message dimension of the receiver (this should match the sender)
37 | -sender_out_dim 32 \ # message dimension of the sender (this should match the receiver)
38 | -img_h_dim 256 \ # hidden dimension of the sender
39 | -rec_hidden 64 \ # hidden dimension of the receiver
40 | -learning_rate 1e-4 \ # learning rate for gradient descent
41 | -entropy_rec 0.01 \ # regularize the receiver's messages
42 | -entropy_sen 0.01 \ # regularize the sender's messages
43 | -entropy_s 0.08 \ # regularize the stop bit
44 | -use_binary \ # specify binary communication (continuous values are also an option)
45 | -max_epoch 500 \ # number of epochs to train
46 | -top_k_dev 6 \ # specify top-k for dev
47 | -top_k_train 6 \ # specify top-k for train
48 | -descr_train ./utils/descriptions.csv \
49 | -descr_dev ./utils/descriptions.csv \
50 | -train_file ./utils/train.hdf5 \
51 | -dev_file ./utils/dev.hdf5 \
52 | -wv_dim 100 \ # dimension of the word vectors
53 | -glove_path ~/data/glove/glove.6B.100d.txt
54 | ```
55 | 
56 | ## Message Analysis
57 | 
58 | After training a model, it is often useful to examine the binary messages exchanged between the Sender and Receiver. These can be retrieved with a command along the lines of the following:
59 | 
60 | ```
61 | EXPERIMENT_NAME="demo"; \
62 | python model.py \
63 | -log_load ./logs/${EXPERIMENT_NAME}.json \ # load model configuration from here
64 | -binary_only \ # specify to only extract binary messages (`eval_only` is also an option)
65 | -experiment_name demo-binary \ # write output to a log file different from training
66 | -checkpoint ./logs/${EXPERIMENT_NAME}.pt_best \ # load this checkpoint
67 | -binary_output ./logs/${EXPERIMENT_NAME}.bv.hdf5 \ # save messages as an HDF5 file
68 | -fixed_exchange # use `fixed_exchange` since the adaptive length can be determined from the stop bits
69 | ```
70 | 
71 | We've included a notebook with a couple of examples of how you might analyse the binary messages [here](https://github.com/nyu-dl/MultimodalGame/blob/master/analyse_communication.ipynb).
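If you would rather work outside the notebook, the extracted messages can also be loaded directly with `h5py`. A minimal sketch (assuming the run above, so the output lives at `./logs/demo.bv.hdf5`; the `Communication` and `Predictions` dataset layouts are the ones defined in `binary_vectors.py`):

```
import h5py

# Messages saved by the `-binary_only` run above.
with h5py.File("./logs/demo.bv.hdf5", "r") as f:
    comm = f["Communication"][:]
    preds = f["Predictions"][:]

# Each `Communication` row records one message: the speaking agent ('S' or 'R'),
# the exchange step ('Index'), the target class and its rank, plus the message
# bits ('BinaryVec') and their probabilities ('BinaryProb').
sender_msgs = comm[comm["AgentId"] == b"S"]["BinaryVec"]
receiver_msgs = comm[comm["AgentId"] == b"R"]["BinaryVec"]
print(sender_msgs.shape, receiver_msgs.shape)

# `Predictions` rows additionally hold the Receiver's class scores and stop bit.
print(preds["Rank"][:10])
```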
72 | -------------------------------------------------------------------------------- /binary_vectors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import sys 4 | import h5py 5 | import numpy as np 6 | 7 | import gflags 8 | 9 | FLAGS = gflags.FLAGS 10 | 11 | 12 | def extract_binary(FLAGS, load_hdf5, exchange, dev_file, batch_size, epoch, shuffle, cuda, top_k, 13 | sender, receiver, desc_dict, map_labels, file_name): 14 | sender_out_dim = FLAGS.sender_out_dim 15 | output_path = FLAGS.binary_output 16 | 17 | desc = desc_dict["desc"] 18 | desc_set = desc_dict.get("desc_set", None) 19 | desc_set_lens = desc_dict.get("desc_set_lens", None) 20 | 21 | # Create hdf5 binary vectors file 22 | bin_vec_file = h5py.File(output_path, "w") 23 | 24 | bin_vec_format = np.dtype([('ExampleId', np.str_, 50), 25 | ('AgentId', np.str_, 1), 26 | ('Index', 'i'), 27 | ('Target', 'i'), 28 | ('Rank', 'i'), 29 | ('BinaryProb', np.float32, (sender_out_dim, )), 30 | ('BinaryVec', np.float32, (sender_out_dim, ))]) 31 | bin_vec_communication = bin_vec_file.create_dataset("Communication", 32 | (0, ), maxshape=(None, ), dtype=bin_vec_format) 33 | 34 | # Create hdf5 predictions file 35 | preds_format = np.dtype([('ExampleId', np.str_, 50), 36 | ('AgentId', np.str_, 1), 37 | ('Index', 'i'), 38 | ('Target', 'i'), 39 | ('Rank', 'i'), 40 | ('Predictions', np.float32, (len(desc), )), 41 | ('StopProb', np.float32, (1, )), 42 | ('StopVec', np.float32, (1, )), 43 | ('StopMask', np.float32, (1, )), 44 | ]) 45 | preds_communication = bin_vec_file.create_dataset("Predictions", 46 | (0, ), maxshape=(None, ), dtype=preds_format) 47 | 48 | # Load development images 49 | dev_loader = load_hdf5(dev_file, batch_size, epoch, shuffle, 50 | truncate_final_batch=True, map_labels=map_labels) 51 | 52 | for batch in dev_loader: 53 | # Extract images and targets 54 | 55 | target = batch["target"] 56 | data = batch[FLAGS.img_feat] 57 | example_ids = batch["example_ids"] 58 | batch_size = target.size(0) 59 | 60 | # GPU support 61 | if cuda: 62 | data = data.cuda() 63 | target = target.cuda() 64 | desc = desc.cuda() 65 | 66 | exchange_args = dict() 67 | exchange_args["data"] = data 68 | if FLAGS.attn_extra_context: 69 | exchange_args["data_context"] = batch[FLAGS.data_context] 70 | exchange_args["target"] = target 71 | exchange_args["desc"] = desc 72 | exchange_args["desc_set"] = desc_set 73 | exchange_args["desc_set_lens"] = desc_set_lens 74 | exchange_args["train"] = False 75 | exchange_args["break_early"] = not FLAGS.fixed_exchange 76 | 77 | s, sen_w, rec_w, y, bs, br = exchange( 78 | sender, receiver, None, None, exchange_args) 79 | 80 | s_masks, s_feats, s_probs = s 81 | sen_feats, sen_probs = sen_w 82 | rec_feats, rec_probs = rec_w 83 | 84 | # TODO: Use masks. This can be tricky! 85 | timesteps = zip(sen_feats, sen_probs, rec_feats, 86 | rec_probs, y, s_feats, s_probs, s_masks) 87 | 88 | for i_exchange, (_z_binary, _z_probs, _w_binary, _w_probs, _y, _s_feats, _s_probs, _s_masks) in enumerate(timesteps): 89 | 90 | i_exchange_batch = np.full(batch_size, i_exchange, dtype=int) 91 | 92 | # Extract predictions and rank of target class. 
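            # (Annotation, not in the original source: the rank extraction
            # below only supports batches in which every example shares the
            # same target class; the assert a few lines down enforces this.)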
93 | np_preds = _y.data.cpu().numpy() 94 | nclasses = np_preds.shape[1] 95 | target_set = set(target.tolist()) 96 | assert len( 97 | target_set) == 1, "Rank only works if there is one target" 98 | single_target = target[0] 99 | np_rank = np.abs(np_preds.argsort(1) - nclasses)[:, single_target] 100 | 101 | # Store Sender binary features and probabilities locally 102 | np_agent_ids = np.full(batch_size, 'S', dtype=np.dtype('S1')) 103 | np_index_sen = i_exchange_batch * 2 104 | np_target = target.cpu().numpy() 105 | np_probs = _z_probs.data.cpu().numpy() 106 | np_bin_vec = _z_binary.data.cpu().numpy() 107 | zipped = zip(example_ids, np_agent_ids, np_index_sen, 108 | np_target, np_rank, np_probs, np_bin_vec) 109 | bin_vec_communication.resize( 110 | bin_vec_communication.shape[0] + batch_size, axis=0) 111 | try: 112 | bin_vec_communication[-batch_size:] = zipped 113 | except: 114 | import ipdb 115 | ipdb.set_trace() 116 | 117 | # Store Receiver binary features and probabilities locally 118 | np_agent_ids = np.full(batch_size, 'R', dtype=np.dtype('S1')) 119 | np_index_rec = i_exchange_batch * 2 + 1 120 | np_probs = _w_probs.data.cpu().numpy() 121 | np_bin_vec = _w_binary.data.cpu().numpy() 122 | np_s_feats = _s_feats.data.cpu().numpy() 123 | np_s_probs = _s_probs.data.cpu().numpy() 124 | np_s_masks = _s_masks.data.cpu().numpy() 125 | zipped = zip(example_ids, np_agent_ids, np_index_rec, 126 | np_target, np_rank, np_probs, np_bin_vec) 127 | bin_vec_communication.resize( 128 | bin_vec_communication.shape[0] + batch_size, axis=0) 129 | bin_vec_communication[-batch_size:] = zipped 130 | # Store Receiver's prediction scores locally 131 | zipped = zip(example_ids, np_agent_ids, np_index_rec, np_target, 132 | np_rank, np_preds, np_s_probs, np_s_feats, np_s_masks) 133 | preds_communication.resize( 134 | preds_communication.shape[0] + batch_size, axis=0) 135 | preds_communication[-batch_size:] = zipped 136 | -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import datetime 5 | import os 6 | import sys 7 | import json 8 | import h5py 9 | import random 10 | from nltk.tokenize import word_tokenize 11 | from nltk.corpus import stopwords 12 | import string 13 | import itertools 14 | 15 | try: 16 | from visdom import Visdom 17 | except: 18 | pass 19 | 20 | 21 | """ 22 | Notes 23 | 24 | A. Loading Description File and Assigning Labels 25 | 26 | Description File should be in the CSV format, 27 | 28 | label_id,label,description 29 | 30 | Concretely, 31 | 32 | 3,aardvark,nocturnal burrowing mammal of the grasslands of Africa that feeds on termites; sole extant representative of the order Tubulidentata 33 | 11,armadillo,burrowing chiefly nocturnal mammal with body covered with strong horny plates 34 | 35 | Note that the label_id need not be ordered nor within any predefined range. Should simply match 36 | the "Target" attribute of the data.hdf5. Once the dataset has been loaded, the label_ids will be 37 | converted into range(len(classes)), and there will be a mapping from the label_ids to this range. 
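As a worked example (an added illustration, inferred from read_data below): rows are read in
file order, so for the two example rows above the mappings would be
label_id_to_idx == {3: 0, 11: 1} and idx_to_label == {0: "aardvark", 1: "armadillo"}.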
38 | 39 | """ 40 | 41 | 42 | def recursively_set_device(inp, gpu): 43 | if hasattr(inp, 'keys'): 44 | for k in inp.keys(): 45 | inp[k] = recursively_set_device(inp[k], gpu) 46 | elif isinstance(inp, list): 47 | return [recursively_set_device(ii, gpu) for ii in inp] 48 | elif isinstance(inp, tuple): 49 | return (recursively_set_device(ii, gpu) for ii in inp) 50 | elif hasattr(inp, 'cpu'): 51 | if gpu >= 0: 52 | inp = inp.cuda() 53 | else: 54 | inp = inp.cpu() 55 | return inp 56 | 57 | 58 | def torch_save(filename, data, models_dict, optimizers_dict, gpu): 59 | models_to_save = {k: recursively_set_device( 60 | v.state_dict(), gpu=-1) for k, v in models_dict.items()} 61 | optimizers_to_save = {k: recursively_set_device( 62 | v.state_dict(), gpu=-1) for k, v in optimizers_dict.items()} 63 | 64 | # Always sends Tensors to CPU. 65 | torch.save({ 66 | 'data': data, 67 | 'optimizers': optimizers_to_save, 68 | 'models': models_to_save, 69 | }, filename) 70 | 71 | if gpu >= 0: 72 | for m in models_dict.values(): 73 | recursively_set_device(m.state_dict(), gpu=gpu) 74 | for o in optimizers_dict.values(): 75 | recursively_set_device(o.state_dict(), gpu=gpu) 76 | 77 | 78 | def torch_load(filename, models_dict, optimizers_dict): 79 | filename = os.path.expanduser(filename) 80 | 81 | if not os.path.exists(filename): 82 | raise Exception("File does not exist: " + filename) 83 | 84 | checkpoint = torch.load(filename) 85 | 86 | for k, v in models_dict.items(): 87 | v.load_state_dict(checkpoint['models'][k]) 88 | 89 | for k, v in optimizers_dict.items(): 90 | v.load_state_dict(checkpoint['optimizers'][k]) 91 | 92 | return checkpoint['data'] 93 | 94 | 95 | class VisdomLogger(object): 96 | """ 97 | Logs data to visdom 98 | 99 | """ 100 | 101 | def __init__(self, env, experiment_name, minimum=2, enabled=False): 102 | self.enabled = enabled 103 | self.experiment_name = experiment_name 104 | self.env = env 105 | self.minimum = minimum 106 | 107 | self.q = dict() 108 | 109 | if enabled: 110 | self.viz = Visdom() 111 | 112 | def get_metrics(self, key, val, step): 113 | metric = self.q.setdefault(key, []) 114 | metric.append((step, val)) 115 | if len(metric) >= self.minimum: 116 | del self.q[key] 117 | return metric 118 | return None 119 | 120 | def viz_success(self, win): 121 | if win == "win does not exist": 122 | return False 123 | return True 124 | 125 | def log(self, key, val, step): 126 | if not self.enabled: 127 | return 128 | 129 | metrics = self.get_metrics(key, val, step) 130 | 131 | # Visdom requires 2+ data points to be written. 132 | if metrics is None: 133 | return 134 | 135 | steps, vals = zip(*metrics) 136 | steps = np.array(steps, dtype=np.int32) 137 | vals = np.array(vals, dtype=np.float32) 138 | 139 | viz = self.viz 140 | experiment_name = self.experiment_name 141 | env = self.env 142 | 143 | win = viz.updateTrace(X=steps, Y=vals, 144 | name=experiment_name, win=key, env=env, 145 | append=True) 146 | 147 | if not self.viz_success(win): 148 | viz.line(X=steps, Y=vals, 149 | win=key, env=env, 150 | opts={"legend": [experiment_name], "title": key}) 151 | 152 | 153 | class FileLogger(object): 154 | # A logging alternative that doesn't leave logs open between writes, 155 | # so as to allow AFS synchronization. 156 | 157 | # Level constants 158 | DEBUG = 0 159 | INFO = 1 160 | WARNING = 2 161 | ERROR = 3 162 | 163 | def __init__(self, log_path=None, json_log_path=None, min_print_level=0, min_file_level=0): 164 | # log_path: The full path for the log file to write. 
The file will be appended 165 | # to if it exists. 166 | # min_print_level: Only messages with level above this level will be printed to stderr. 167 | # min_file_level: Only messages with level above this level will be 168 | # written to disk. 169 | self.log_path = log_path 170 | self.json_log_path = json_log_path 171 | self.min_print_level = min_print_level 172 | self.min_file_level = min_file_level 173 | 174 | def Log(self, message, level=INFO): 175 | if level >= self.min_print_level: 176 | # Write to STDERR 177 | sys.stderr.write("[%i] %s\n" % (level, message)) 178 | if self.log_path and level >= self.min_file_level: 179 | # Write to the log file then close it 180 | with open(self.log_path, 'a') as f: 181 | datetime_string = datetime.datetime.now().strftime( 182 | "%y-%m-%d %H:%M:%S") 183 | f.write("%s [%i] %s\n" % (datetime_string, level, message)) 184 | 185 | def LogJSON(self, message_obj, level=INFO): 186 | if self.json_log_path and level >= self.min_file_level: 187 | with open(self.json_log_path, 'w') as f: 188 | print >>f, json.dumps(message_obj) 189 | else: 190 | sys.stderr.write('WARNING: No JSON log filename.') 191 | 192 | 193 | def read_log_load(filename, last=True): 194 | ret = None 195 | cur = None 196 | reading = False 197 | begin = "Flag Values" 198 | end = "}" 199 | 200 | with open(filename) as f: 201 | for line in f: 202 | if begin in line and not reading: 203 | cur = "" 204 | reading = True 205 | continue 206 | 207 | if reading: 208 | cur += line.strip() 209 | 210 | if end in line: 211 | ret = json.loads(cur) 212 | reading = False 213 | 214 | if not last: 215 | return ret 216 | 217 | return ret 218 | 219 | 220 | def clean_desc(desc): 221 | words = word_tokenize(desc.lower()) # lowercase and tokenize 222 | words = list(set(words)) # remove duplicates 223 | words = [w for w in words if w not in stopwords.words( 224 | 'english')] # remove stopwords 225 | words = [w for w in words if w not in string.punctuation] 226 | return words 227 | 228 | 229 | def read_data(input_descr): 230 | descr = {} 231 | word_dict = {} 232 | dict_size = 0 233 | num_descr = 0 234 | label_id_to_idx = {} 235 | idx_to_label = {} 236 | with open(input_descr, "r") as f: 237 | for i, line in enumerate(f): 238 | line = line.strip() 239 | parts = line.split(",") 240 | label_id, label = parts[:2] 241 | desc = line[len(label_id) + len(label) + 2:] 242 | desc = clean_desc(desc) 243 | # print label, sorted(desc) 244 | for w in desc: 245 | if w not in word_dict: 246 | dict_size += 1 247 | word_dict[w] = {"id": dict_size} 248 | descr[num_descr] = {"name": label, "desc": desc} 249 | num_descr += 1 250 | label_id_to_idx[int(label_id)] = i 251 | idx_to_label[i] = label 252 | _desc = set([w for ii in descr.keys() for w in descr[ii]['desc']]) 253 | # print sorted(_desc) 254 | return descr, word_dict, dict_size, label_id_to_idx, idx_to_label 255 | 256 | 257 | def load_hdf5(hdf5_file, batch_size, random_seed, shuffle, truncate_final_batch=False, map_labels=int): 258 | """ 259 | Reads images into random batches 260 | """ 261 | # Read data 262 | f = h5py.File(os.path.expanduser(hdf5_file), "r") 263 | target = f["Target"] 264 | dataset_size = target.shape[0] 265 | f.close() 266 | order = range(dataset_size) 267 | 268 | # Shuffle 269 | if shuffle: 270 | random.seed(11 + random_seed) 271 | random.shuffle(order) 272 | 273 | # Generate batches 274 | num_batches = dataset_size // batch_size 275 | 276 | if truncate_final_batch: 277 | if dataset_size - (num_batches * batch_size) > 0: 278 | num_batches = num_batches + 1 279 | 280 
| for i in range(num_batches): 281 | 282 | batch_indices = sorted(order[i * batch_size:(i + 1) * batch_size]) 283 | 284 | f = h5py.File(os.path.expanduser(hdf5_file), "r") 285 | 286 | batch = dict() 287 | 288 | # TODO: We probably need to map the label_ids some way. 289 | batch['target'] = torch.LongTensor( 290 | map(map_labels, f["Target"][batch_indices])) 291 | batch['example_ids'] = f["Location"][batch_indices] 292 | 293 | batch['layer4_2'] = torch.from_numpy( 294 | f["layer4_2"][batch_indices]).float().squeeze() 295 | batch['avgpool_512'] = torch.from_numpy( 296 | f["avgpool_512"][batch_indices]).float().squeeze() 297 | batch['fc'] = torch.from_numpy( 298 | f["fc"][batch_indices]).float().squeeze() 299 | 300 | f.close() 301 | 302 | yield batch 303 | 304 | 305 | # Function returning word embeddings from GloVe 306 | def embed(word_dict, emb): 307 | glove = {} 308 | print("Vocab Size: {}".format(len(word_dict.keys()))) 309 | with open(emb, "r") as f: 310 | for line in f: 311 | word = line.strip() 312 | word = word.split(" ") 313 | if word[0] in word_dict: 314 | embed = torch.Tensor([float(s) for s in word[1:]]) 315 | glove[word[0]] = embed 316 | print("Found {} in glove.".format(len(glove.keys()))) 317 | for k in word_dict: 318 | embed = glove.get(k, None) 319 | word_dict[k]["emb"] = embed 320 | return word_dict 321 | 322 | 323 | # Function computing CBOW for each description 324 | def cbow(descr, word_dict): 325 | # TODO: Faster summing please! 326 | emb_size = len(word_dict.values()[0]["emb"]) 327 | for mammal in descr: 328 | num_w = 0 329 | desc_len = len(descr[mammal]["desc"]) 330 | desc_set = torch.FloatTensor(desc_len, emb_size).fill_(0) 331 | for i_w, w in enumerate(descr[mammal]["desc"]): 332 | if word_dict[w]["emb"] is not None: 333 | desc_set[i_w] = word_dict[w]["emb"] 334 | num_w += 1 335 | desc_cbow = desc_set.clone().sum(0).squeeze() 336 | if num_w > 0: 337 | desc_cbow = desc_cbow / num_w 338 | descr[mammal]["cbow"] = desc_cbow 339 | descr[mammal]["set"] = desc_set 340 | return descr 341 | 342 | 343 | """ 344 | Initialization Schemes 345 | Source: https://github.com/alykhantejani/nninit/blob/master/nninit.py 346 | """ 347 | 348 | 349 | def _calculate_fan_in_and_fan_out(tensor): 350 | if tensor.ndimension() < 2: 351 | raise ValueError( 352 | "fan in and fan out can not be computed for tensor of size ", tensor.size()) 353 | 354 | if tensor.ndimension() == 2: # Linear 355 | fan_in = tensor.size(1) 356 | fan_out = tensor.size(0) 357 | else: 358 | num_input_fmaps = tensor.size(1) 359 | num_output_fmaps = tensor.size(0) 360 | receptive_field_size = np.prod(tensor.numpy().shape[2:]) 361 | fan_in = num_input_fmaps * receptive_field_size 362 | fan_out = num_output_fmaps * receptive_field_size 363 | 364 | return fan_in, fan_out 365 | 366 | 367 | def xavier_normal(tensor, gain=1): 368 | """Fills the input Tensor or Variable with values according to the method described in "Understanding the difficulty of training 369 | deep feedforward neural networks" - Glorot, X. and Bengio, Y., using a normal distribution. 
370 | The resulting tensor will have values sampled from normal distribution with mean=0 and 371 | std = gain * sqrt(2/(fan_in + fan_out)) 372 | Args: 373 | tensor: a n-dimension torch.Tensor 374 | gain: an optional scaling factor to be applied 375 | Examples: 376 | >>> w = torch.Tensor(3, 5) 377 | >>> nninit.xavier_normal(w, gain=np.sqrt(2.0)) 378 | """ 379 | if isinstance(tensor, Variable): 380 | xavier_normal(tensor.data, gain=gain) 381 | return tensor 382 | else: 383 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 384 | std = gain * np.sqrt(2.0 / (fan_in + fan_out)) 385 | return tensor.normal_(0, std) 386 | 387 | 388 | def build_mask(region_str, size): 389 | # Read input string 390 | regions = region_str.split(',') 391 | regions = [r.split(':') for r in regions] 392 | regions = [[int(r[0])] if len(r) == 1 else 393 | list(range(int(r[0]), int(r[1]))) for r in regions] # python style indexing 394 | 395 | # Flattens the list of lists 396 | index = torch.LongTensor(list(itertools.chain(*regions))) 397 | 398 | # Generate mask 399 | mask = torch.FloatTensor(size, 1).fill_(0) 400 | mask[index] = 1 401 | 402 | return mask 403 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | import numpy as np 6 | import random 7 | import h5py 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable as _Variable 13 | import torch.optim as optim 14 | from torch.nn.parameter import Parameter 15 | 16 | import torchvision.models as models 17 | import torchvision.datasets as dset 18 | import torchvision.transforms as transforms 19 | 20 | from sklearn.metrics import confusion_matrix 21 | 22 | from misc import recursively_set_device, torch_save, torch_load 23 | from misc import VisdomLogger as Logger 24 | from misc import FileLogger 25 | from misc import read_log_load 26 | from misc import load_hdf5 27 | from misc import read_data 28 | from misc import embed 29 | from misc import cbow 30 | from misc import xavier_normal 31 | from misc import build_mask 32 | 33 | from sparks import sparks 34 | 35 | from binary_vectors import extract_binary 36 | 37 | import gflags 38 | 39 | FLAGS = gflags.FLAGS 40 | 41 | 42 | def Variable(*args, **kwargs): 43 | var = _Variable(*args, **kwargs) 44 | if FLAGS.cuda: 45 | var = var.cuda() 46 | return var 47 | 48 | 49 | class Sender(nn.Module): 50 | """Agent 1 Network: Sender 51 | """ 52 | 53 | def __init__(self, feature_type, feat_dim, h_dim, w_dim, bin_dim_out, use_binary, 54 | use_attn, attn_dim, attn_extra_context, attn_context_dim): 55 | super(Sender, self).__init__() 56 | self.feature_type = feature_type 57 | self.feat_dim = feat_dim 58 | self.h_dim = h_dim 59 | self.w_dim = w_dim 60 | self.bin_dim_out = bin_dim_out 61 | self.use_binary = use_binary 62 | self.use_attn = use_attn 63 | self.attn_dim = attn_dim 64 | self.attn_extra_context = attn_extra_context 65 | self.attn_context_dim = attn_context_dim 66 | 67 | self.image_layer = nn.Linear(self.feat_dim, self.h_dim) 68 | self.code_layer = nn.Linear(self.w_dim, self.h_dim) 69 | self.code_bias = Parameter(torch.Tensor(self.bin_dim_out)) 70 | # Layer for binary vector 71 | if FLAGS.sender_mix == "mou": 72 | self.binary_layer = nn.Linear(self.h_dim * 4, self.bin_dim_out) 73 | if FLAGS.ignore_code: 74 | self.code_bias_mou = Parameter(torch.Tensor(self.bin_dim_out)) 75 | 
else: 76 | self.binary_layer = nn.Linear(self.h_dim, self.bin_dim_out) 77 | 78 | # self.binary_layer.bias.data.fill_(-2.) 79 | 80 | if self.use_attn: 81 | self.attn_W_x = nn.Linear(self.feat_dim, self.attn_dim) 82 | self.attn_W_w = nn.Linear(self.w_dim, self.attn_dim) 83 | self.attn_U = nn.Linear(self.attn_dim, 1) 84 | 85 | if FLAGS.attn_extra_context: 86 | self.attn_W_g = nn.Linear(self.attn_context_dim, self.attn_dim) 87 | 88 | self.reset_parameters() 89 | 90 | def reset_parameters(self): 91 | for m in self.modules(): 92 | if isinstance(m, nn.Linear): 93 | m.weight.data.set_(xavier_normal(m.weight.data)) 94 | if m.bias is not None: 95 | m.bias.data.zero_() 96 | if hasattr(self, 'code_bias'): 97 | self.code_bias.data.normal_() 98 | 99 | def reset_state(self): 100 | """Initialize state for Sender. 101 | 102 | The Sender is stateless in its decisions, but some computation 103 | can be reused at each time step. 104 | 105 | """ 106 | # Cached computation. 107 | self.h_x_attn_flat = None 108 | self.h_g_flat = None 109 | self.fn_x = None 110 | 111 | # Used for debugging. 112 | self.attn_scores = [] 113 | 114 | def attention_func(self, w, x, g): 115 | batch_size, n_feats, channels = x.size() 116 | 117 | h_w_attn = self.attn_W_w(w) 118 | h_w_attn_broadcast = h_w_attn.contiguous().unsqueeze( 119 | 1).expand(batch_size, n_feats, self.attn_dim) 120 | h_w_attn_flat = h_w_attn_broadcast.contiguous().view( 121 | batch_size * n_feats, self.attn_dim) 122 | 123 | if not self.h_x_attn_flat: 124 | x_flat = x.contiguous().view(batch_size * n_feats, channels) 125 | self.h_x_attn_flat = self.attn_W_x(x_flat) 126 | 127 | if self.attn_extra_context: 128 | if not self.h_g_flat: 129 | h_g = self.attn_W_g(g) 130 | hg_broadcast = h_g.contiguous().unsqueeze( 131 | 1).expand(batch_size, n_feats, self.attn_dim) 132 | self.h_g_flat = hg_broadcast.contiguous().view( 133 | batch_size * n_feats, self.attn_dim) 134 | 135 | if self.attn_extra_context: 136 | attn_U_inp = nn.Tanh()(h_w_attn_flat + self.h_x_attn_flat + self.h_g_flat) 137 | else: 138 | attn_U_inp = nn.Tanh()(h_w_attn_flat + self.h_x_attn_flat) 139 | 140 | attn_scores_flat = self.attn_U(attn_U_inp) 141 | 142 | return attn_scores_flat 143 | 144 | def forward(self, x, w, g, t): 145 | """Respond to communication query. 146 | 147 | Communication Response: 148 | z_hat = U_z(U_x x + U_w w) 149 | z = bernoulli(sig(z_hat)) or round(sig(z_hat)) 150 | 151 | Image Attention (https://arxiv.org/pdf/1502.03044.pdf): 152 | \beta_i = U tanh(W_r z_r + W_x x_i [+ W_g g]) 153 | \alpha = 1 / |x| if t == 0 154 | \alpha = softmax(\beta) otherwise 155 | x = \sum_i \alpha x_i 156 | 157 | Args: 158 | x: Image features. 159 | w: Communication query from Receiver. 160 | g: (attention) Image features used as query in attention. 161 | t: (attention) Timestep. Used to change attention equation in first iteration. 162 | Output: 163 | features: A binary (or continuous) message in response to Receiver's query. 164 | feature_probs: If the message is binary, then these are probability of ``1`` for each bit 165 | in the message. 
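        Shape reference (added note, inferred from the layer definitions above):
        without attention, x is (batch, feat_dim); with attention, x is a
        feature map (batch, channels, height, width). w is (batch, w_dim),
        and both outputs are (batch, bin_dim_out).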
166 | """ 167 | 168 | if self.use_attn: 169 | batch_size, channels, height, width = x.size() 170 | n_feats = height * width 171 | x = x.view(batch_size, channels, n_feats) 172 | x = x.transpose(1, 2) 173 | 174 | attn_scores_flat = self.attention_func(w, x, g) 175 | 176 | # attention scores 177 | if t == 0: 178 | attn_scores = Variable(torch.FloatTensor( 179 | batch_size, n_feats).fill_(1), volatile=not self.training) 180 | attn_scores = attn_scores / n_feats 181 | else: 182 | attn_scores = F.softmax( 183 | attn_scores_flat.view(batch_size, n_feats)) 184 | 185 | # x = \sum_i a_i x_i 186 | x_attn = torch.bmm(attn_scores.unsqueeze(1), x).squeeze() 187 | 188 | # Cache values for inspection 189 | self.attn_scores.append(attn_scores) 190 | 191 | _x = x_attn 192 | else: 193 | _x = x 194 | 195 | self.h_x = h_x = self.image_layer(_x) 196 | if t == 0: 197 | batch_size = x.size(0) 198 | # Same first code for all batch items. 199 | first_code = F.sigmoid(self.code_bias.view(1, -1)) 200 | h_w = self.code_layer(first_code).expand(batch_size, self.h_dim) 201 | elif t > 0 and FLAGS.ignore_code and FLAGS.sender_mix == "mou": 202 | batch_size = x.size(0) 203 | # Same code for all batch items. 204 | code_mou = F.sigmoid(self.code_bias_mou.view(1, -1)) 205 | h_w = self.code_layer(code_mou).expand(batch_size, self.h_dim) 206 | else: 207 | h_w = self.code_layer(w) 208 | if FLAGS.ignore_code: 209 | if FLAGS.sender_mix == "sum" or FLAGS.sender_mix == "prod": 210 | features = self.binary_layer(F.tanh(h_x)) 211 | elif FLAGS.sender_mix == "mou": 212 | features = self.binary_layer( 213 | F.tanh(torch.cat([h_x, h_w, h_x - h_w, h_x * h_w], 1))) 214 | else: 215 | if FLAGS.sender_mix == "sum": 216 | features = self.binary_layer(F.tanh(h_x + h_w)) 217 | elif FLAGS.sender_mix == "prod": 218 | features = self.binary_layer(F.tanh(h_x * h_w)) 219 | elif FLAGS.sender_mix == "mou": 220 | features = self.binary_layer( 221 | F.tanh(torch.cat([h_x, h_w, h_x - h_w, h_x * h_w], 1))) 222 | if self.use_binary: 223 | probs = F.sigmoid(features) 224 | if self.training: 225 | probs_ = probs.data.cpu().numpy() 226 | binary_features = Variable(torch.from_numpy( 227 | (np.random.rand(*probs_.shape) < probs_).astype('float32'))) 228 | else: 229 | binary_features = torch.round(probs).detach() 230 | if probs.is_cuda: 231 | binary_features = binary_features.cuda() 232 | 233 | if FLAGS.flipout_sen is not None and (self.training or FLAGS.flipout_dev): 234 | binary_features = flipout(binary_features, FLAGS.flipout_sen) 235 | 236 | return binary_features, probs 237 | else: 238 | return features, None 239 | 240 | 241 | class Receiver(nn.Module): 242 | """Agent 2 Network: Receiver 243 | """ 244 | 245 | def __init__(self, z_dim, desc_dim, hid_dim, out_dim, w_dim, s_dim, use_binary): 246 | super(Receiver, self).__init__() 247 | self.z_dim = z_dim 248 | self.desc_dim = desc_dim 249 | self.hid_dim = hid_dim 250 | self.out_dim = out_dim 251 | self.w_dim = w_dim 252 | self.s_dim = s_dim 253 | self.use_binary = use_binary 254 | 255 | # RNN network 256 | self.rnn = nn.GRUCell(self.z_dim, self.hid_dim) 257 | # Network for Receiver communications 258 | self.w_h = nn.Linear(self.hid_dim, self.hid_dim, bias=True) 259 | self.w_d = nn.Linear(self.desc_dim, self.hid_dim, bias=False) 260 | self.w = nn.Linear(self.hid_dim, self.w_dim) 261 | # Network for Receiver predicitons 262 | self.y1 = nn.Linear(self.hid_dim + self.desc_dim, self.hid_dim) 263 | self.y2 = nn.Linear(self.hid_dim, self.out_dim) 264 | # Network for Receiver decisions 265 | self.s = 
nn.Linear(self.hid_dim, self.s_dim)
266 | 
267 |         if FLAGS.desc_attn:
268 |             self.attn_dim = FLAGS.desc_attn_dim
269 |             self.d_d = nn.Linear(self.desc_dim, self.attn_dim)
270 |             self.d_h = nn.Linear(self.hid_dim, self.attn_dim)
271 |             self.d_attn = nn.Linear(self.attn_dim, 1)
272 | 
273 |         self.reset_parameters()
274 | 
275 |     def reset_parameters(self):
276 |         for m in self.modules():
277 |             if isinstance(m, nn.Linear):
278 |                 m.weight.data.set_(xavier_normal(m.weight.data))
279 |                 if m.bias is not None:
280 |                     m.bias.data.zero_()
281 |             elif isinstance(m, nn.GRUCell):
282 |                 for mm in m.parameters():
283 |                     if mm.data.ndimension() == 2:
284 |                         mm.data.set_(xavier_normal(mm.data))
285 |                     elif mm.data.ndimension() == 1:  # Bias
286 |                         mm.data.zero_()
287 |         if hasattr(self, 'code_bias'):
288 |             self.code_bias.data.uniform_(-1, 1)
289 | 
290 |     def reset_state(self):
291 |         """Initialize state for Receiver.
292 | 
293 |         The Receiver is stateful, keeping track of the previous messages it
294 |         has sent and received.
295 | 
296 |         """
297 |         self.h_z = None
298 |         self.s_prob_prod = None
299 | 
300 |     def initial_state(self, batch_size):
301 |         return Variable(torch.zeros(batch_size, self.hid_dim))
302 | 
303 |     def forward(self, z, desc, desc_set=None, desc_set_lens=None):
304 |         """Send communication query.
305 | 
306 |         Update State:
307 |             h_z = rnn(z, h_z)
308 | 
309 |         Predictions:
310 |             y_i = f_y(h_z, desc_i)
311 | 
312 |         Communication Query:
313 |             desc = \sum_i y_i desc_i
314 |             w_hat = tanh(W_h h_z + W_d desc)
315 |             w = bernoulli(sig(w_hat)) or round(sig(w_hat))
316 | 
317 |         STOP Bit:
318 |             s_hat = W_s h_z
319 |             s = bernoulli(sig(s_hat)) or round(sig(s_hat))
320 | 
321 |         Args:
322 |             z: Communication response from the Sender.
323 |             desc: List of description vectors used in communication and predictions.
324 |         Output:
325 |             s, s_probs: A STOP bit and its associated probability, indicating whether the Receiver has decided to stop
326 |                 or continue its conversation with the Sender.
327 |             w, w_probs: A binary (or continuous) message, which is a query incorporating the descriptions. If the
328 |                 message is binary, then the probability of each bit in the message being ``1`` is included.
329 |             y: A prediction for each class described in the descriptions.
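        Shape reference (added note, inferred from the layer definitions above):
        z is (batch, z_dim), s and s_probs are (batch, s_dim), w and w_probs
        are (batch, w_dim), and y is (batch, num_descriptions).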
330 | """ 331 | 332 | # BatchSize x BinaryDim 333 | batch_size, binary_dim = z.size() 334 | 335 | # Initialize hidden state if necessary 336 | if self.h_z is None: 337 | self.h_z = self.initial_state(batch_size) 338 | 339 | # Run z through RNN 340 | self.h_z = self.rnn(z, self.h_z) 341 | 342 | # Build input for prediction using descriptions 343 | # size of inp_with_desc: B*D x (WV+h) 344 | if FLAGS.desc_attn: 345 | nwords, desc_dim = desc_set.size() 346 | desc_set = Variable(desc_set) # NW x WV 347 | # Broadcast and Flatten 348 | desc_set_broadcast = desc_set.unsqueeze(0).expand( 349 | batch_size, nwords, self.desc_dim) # B x NW x WV 350 | 351 | # Broadcast and Flatten 352 | dd = self.d_d(desc_set) # NW x A 353 | dd_broadcast = dd.unsqueeze(0).expand( 354 | batch_size, nwords, self.attn_dim) # B x NW x A 355 | dd_flat = dd_broadcast.contiguous().view( 356 | batch_size * nwords, self.attn_dim) # B*NW x A 357 | 358 | # Broadcast and Flatten 359 | dh = self.d_h(self.h_z) # B x A 360 | dh_broadcast = dh.unsqueeze(1).expand( 361 | batch_size, nwords, self.attn_dim) # B x NW x A 362 | dh_flat = dh_broadcast.contiguous().view( 363 | batch_size * nwords, self.attn_dim) # B*NW x A 364 | 365 | # Get and Apply Attention Scores 366 | d_attn = self.d_attn(F.tanh(dd_flat + dh_flat) 367 | ).view(batch_size, nwords) # B x NW 368 | 369 | # Partitioned Scores 370 | cumlen = 0 371 | d_attn_scores = [] 372 | for idesc, ndesc in enumerate(desc_set_lens): 373 | start = cumlen 374 | end = cumlen + ndesc 375 | cumlen = end 376 | 377 | scores = F.softmax(d_attn[:, start:end]) # B x NW_i 378 | d_attn_scores.append(scores) 379 | self.d_attn_scores = d_attn_scores = torch.cat( 380 | d_attn_scores, 1) # B x NW 381 | 382 | # Attend 383 | d_attn_broadcast = d_attn_scores.unsqueeze( 384 | 2).expand_as(desc_set_broadcast) # B x NW x WV 385 | desc_set_weighted = desc_set_broadcast * d_attn_broadcast # B x NW x WV 386 | 387 | # Partitioned Weighted Sum 388 | cumlen = 0 389 | weighted_desc = [] 390 | for idesc, ndesc in enumerate(desc_set_lens): 391 | start = cumlen 392 | end = cumlen + ndesc 393 | cumlen = end 394 | 395 | cbow = desc_set_weighted[:, start:end, :].sum(1) # B x 1 x WV 396 | weighted_desc.append(cbow) 397 | weighted_desc = torch.cat(weighted_desc, 1) # B x D x WV 398 | 399 | # Build Input 400 | nclasses = weighted_desc.size(1) 401 | weighted_desc = weighted_desc.view( 402 | batch_size * nclasses, self.desc_dim) # B*D x WV 403 | 404 | h_z_broadcast = self.h_z.unsqueeze(1).expand( 405 | batch_size, nclasses, self.hid_dim) # B x D x h 406 | h_z_flat = h_z_broadcast.contiguous().view( 407 | batch_size * nclasses, self.hid_dim) # B*D x h 408 | 409 | inp_with_desc = torch.cat( 410 | [weighted_desc, h_z_flat], 1) # B*D x (WV+h) 411 | else: 412 | inp_with_desc = build_inp(self.h_z, desc) # B*D x (WV+h) 413 | 414 | s_score = self.s(self.h_z) 415 | s_prob = F.sigmoid(s_score) 416 | if self.training: 417 | # Sample decisions 418 | prob_ = s_prob.data.cpu().numpy() 419 | s_binary = Variable(torch.from_numpy( 420 | (np.random.rand(*prob_.shape) < prob_).astype('float32'))) 421 | else: 422 | # Infer decisions 423 | if not self.s_prob_prod or not FLAGS.s_prob_prod: 424 | self.s_prob_prod = s_prob 425 | else: 426 | self.s_prob_prod = self.s_prob_prod * s_prob 427 | s_binary = torch.round(self.s_prob_prod).detach() 428 | if s_prob.is_cuda: 429 | s_binary = s_binary.cuda() 430 | 431 | # Obtain predictions 432 | y = self.y1(inp_with_desc).clamp(min=0) 433 | y = self.y2(y).view(batch_size, -1) 434 | 435 | # Obtain communications 
436 | # size of y = batch_size x # descriptions 437 | # size of desc = # descriptions x self.desc_dim 438 | # size of wd_inp = batch_size x self.desc_dim 439 | n_desc = y.size(1) 440 | # Reweight descriptions based on current model confidence 441 | y_scores = F.softmax(y).detach() 442 | y_broadcast = y_scores.unsqueeze(2).expand( 443 | batch_size, n_desc, self.desc_dim) 444 | if FLAGS.desc_attn: 445 | wd_inp = weighted_desc.view(batch_size, nclasses, self.desc_dim) 446 | else: 447 | wd_inp = desc.unsqueeze(0).expand( 448 | batch_size, n_desc, self.desc_dim) 449 | wd_inp = (y_broadcast * wd_inp).sum(1).squeeze(1) 450 | 451 | # Hidden state for Receiver message 452 | self.h_w = F.tanh(self.w_h(self.h_z) + self.w_d(wd_inp)) 453 | 454 | w_scores = self.w(self.h_w) 455 | if self.use_binary: 456 | w_probs = F.sigmoid(w_scores) 457 | if self.training: 458 | probs_ = w_probs.data.cpu().numpy() 459 | w_binary = Variable(torch.from_numpy( 460 | (np.random.rand(*probs_.shape) < probs_).astype('float32'))) 461 | else: 462 | w_binary = torch.round(w_probs).detach() 463 | if w_probs.is_cuda: 464 | w_binary = w_binary.cuda() 465 | w_feats = w_binary 466 | 467 | if FLAGS.flipout_rec is not None and (self.training or FLAGS.flipout_dev): 468 | w_feats = flipout(w_feats, FLAGS.flipout_rec) 469 | 470 | if FLAGS.ignore_receiver: 471 | w_feats = Variable(torch.zeros(w_feats.size()), 472 | volatile=not self.training) 473 | else: 474 | w_feats = w_scores 475 | w_probs = None 476 | 477 | return (s_binary, s_prob), (w_feats, w_probs), y 478 | 479 | 480 | class Baseline(nn.Module): 481 | """Baseline 482 | """ 483 | 484 | def __init__(self, hid_dim, x_dim, binary_dim, inp_dim): 485 | super(Baseline, self).__init__() 486 | self.x_dim = x_dim 487 | self.binary_dim = binary_dim 488 | self.inp_dim = inp_dim 489 | self.hid_dim = hid_dim 490 | 491 | # Additional layers on top of feature extractor 492 | self.linear1 = nn.Linear( 493 | x_dim + self.binary_dim + self.inp_dim, self.hid_dim) 494 | self.linear2 = nn.Linear(self.hid_dim, 1) 495 | 496 | def forward(self, x, binary, inp): 497 | """Estimate agent's loss based on the agent's input. 498 | 499 | Args: 500 | x: Image features. 501 | binary: Communication message. 502 | inp: Hidden state (used when agent is the Receiver). 503 | Output: 504 | score: An estimate of the agent's loss. 505 | """ 506 | features = [] 507 | if x is not None: 508 | features.append(x) 509 | if binary is not None: 510 | features.append(binary) 511 | if inp is not None: 512 | features.append(inp) 513 | features = torch.cat(features, 1) 514 | hidden = self.linear1(features).clamp(min=0) 515 | pred_score = self.linear2(hidden) 516 | return pred_score 517 | 518 | 519 | def build_inp(binary_features, descs): 520 | """Function preparing input for Receiver network 521 | 522 | Args: 523 | binary_features: List of communication vectors, length ``B``. 524 | descs: List of description vectors, length ``D``. 525 | Output: 526 | b_cat_d: The cartesian product of binary features and descriptions, length ``B`` x ``D``. 527 | """ 528 | if descs is not None: 529 | batch_size = binary_features.size(0) 530 | num_desc, desc_dim = descs.size() 531 | 532 | # Expand binary features. 533 | binary_index = torch.from_numpy( 534 | np.arange(batch_size).repeat(num_desc).astype(np.int32)).long() 535 | if binary_features.is_cuda: 536 | binary_index = binary_index.cuda() 537 | binary_copied = torch.index_select( 538 | binary_features, 0, Variable(binary_index)) 539 | 540 | # Expand descriptions. 
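        # (Added example: with batch_size=2 and num_desc=3, binary_index
        # above is [0, 0, 0, 1, 1, 1] while desc_index below is
        # [0, 1, 2, 0, 1, 2], pairing every message with every description.)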
541 | desc_index = torch.from_numpy(np.concatenate( 542 | [np.arange(num_desc)] * batch_size).astype(np.int32)).long() 543 | if descs.is_cuda: 544 | desc_index = desc_index.cuda() 545 | desc_copied = torch.index_select(descs, 0, Variable(desc_index)) 546 | 547 | # Concat binary vector with description vectors 548 | inp = torch.cat([binary_copied, desc_copied], 1) 549 | return inp 550 | else: 551 | return binary_features 552 | 553 | 554 | def flipout(binary, p): 555 | """ 556 | Args: 557 | binary: Tensor of binary values. 558 | p: Probability of flipping a binary value. 559 | Output: 560 | outp: Tensor with same size as `binary` where bits have been 561 | flipped with probability `p`. 562 | """ 563 | mask = torch.FloatTensor(binary.size()).fill_(p).numpy() 564 | mask = Variable(torch.from_numpy( 565 | (np.random.rand(*mask.shape) < mask).astype('float32'))) 566 | outp = (binary - mask).abs() 567 | 568 | return outp 569 | 570 | 571 | def loglikelihood(log_prob, target): 572 | """ 573 | Args: log softmax scores (N, C) where N is the batch size 574 | and C is the number of classes 575 | Output: log likelihood (N) 576 | """ 577 | return log_prob.gather(1, target) 578 | 579 | 580 | def eval_dev(dev_file, batch_size, epoch, shuffle, cuda, top_k, 581 | sender, receiver, desc_dict, map_labels, file_name, 582 | callback=None): 583 | """ 584 | Function computing development accuracy 585 | """ 586 | 587 | desc = desc_dict["desc"] 588 | desc_set = desc_dict.get("desc_set", None) 589 | desc_set_lens = desc_dict.get("desc_set_lens", None) 590 | 591 | extra = dict() 592 | 593 | # Keep track of conversation lengths 594 | conversation_lengths = [] 595 | 596 | # Keep track of message diversity 597 | hamming_sen = [] 598 | hamming_rec = [] 599 | 600 | # Keep track of labels 601 | true_labels = [] 602 | pred_labels = [] 603 | 604 | # Keep track of number of correct observations 605 | total = 0 606 | correct = 0 607 | 608 | # Load development images 609 | dev_loader = load_hdf5(dev_file, batch_size, epoch, shuffle, 610 | truncate_final_batch=True, map_labels=map_labels) 611 | 612 | for batch in dev_loader: 613 | # Extract images and targets 614 | 615 | target = batch["target"] 616 | data = batch[FLAGS.img_feat] 617 | _batch_size = target.size(0) 618 | 619 | true_labels.append(target.cpu().numpy().reshape(-1)) 620 | 621 | # GPU support 622 | if cuda: 623 | data = data.cuda() 624 | target = target.cuda() 625 | desc = desc.cuda() 626 | 627 | exchange_args = dict() 628 | exchange_args["data"] = data 629 | if FLAGS.attn_extra_context: 630 | exchange_args["data_context"] = batch[FLAGS.data_context] 631 | exchange_args["target"] = target 632 | exchange_args["desc"] = desc 633 | exchange_args["desc_set"] = desc_set 634 | exchange_args["desc_set_lens"] = desc_set_lens 635 | exchange_args["train"] = False 636 | exchange_args["break_early"] = not FLAGS.fixed_exchange 637 | exchange_args["corrupt"] = FLAGS.bit_flip 638 | exchange_args["corrupt_region"] = FLAGS.corrupt_region 639 | 640 | s, sen_w, rec_w, y, bs, br = exchange( 641 | sender, receiver, None, None, exchange_args) 642 | 643 | s_masks, s_feats, s_probs = s 644 | sen_feats, sen_probs = sen_w 645 | rec_feats, rec_probs = rec_w 646 | 647 | # Mask if dynamic exchange length 648 | if FLAGS.fixed_exchange: 649 | y_masks = None 650 | else: 651 | y_masks = [torch.min(1 - m1, m2) 652 | for m1, m2 in zip(s_masks[1:], s_masks[:-1])] 653 | 654 | outp, _ = get_rec_outp(y, y_masks) 655 | 656 | # Obtain top k predictions 657 | dist = F.log_softmax(outp) 658 | top_k_ind = 
torch.from_numpy(
659 |             dist.data.cpu().numpy().argsort()[:, -top_k:]).long()
660 |         target = target.view(-1, 1).expand(_batch_size, top_k)
661 | 
662 |         # Store top 1 prediction for confusion matrix
663 |         _, argmax = dist.data.max(1)
664 |         pred_labels.append(argmax.cpu().numpy())
665 | 
666 |         # Update accuracy counts (use the actual batch size; the final batch may be truncated)
667 |         total += float(_batch_size)
668 |         correct += (top_k_ind == target.cpu()).sum()
669 | 
670 |         # Keep track of conversation lengths
671 |         conversation_lengths += torch.cat(s_feats,
672 |                                           1).data.float().sum(1).view(-1).tolist()
673 | 
674 |         # Keep track of message diversity
675 |         mean_hamming_rec = 0
676 |         mean_hamming_sen = 0
677 |         prev_rec = torch.FloatTensor(_batch_size, FLAGS.rec_w_dim).fill_(0)
678 |         prev_sen = torch.FloatTensor(_batch_size, FLAGS.rec_w_dim).fill_(0)
679 | 
680 |         for msg in sen_feats:
681 |             mean_hamming_sen += (msg.data.cpu() - prev_sen).abs().sum(1).mean()
682 |             prev_sen = msg.data.cpu()
683 |         mean_hamming_sen = mean_hamming_sen / float(len(sen_feats))
684 | 
685 |         for msg in rec_feats:
686 |             mean_hamming_rec += (msg.data.cpu() - prev_rec).abs().sum(1).mean()
687 |             prev_rec = msg.data.cpu()
688 |         mean_hamming_rec = mean_hamming_rec / float(len(rec_feats))
689 | 
690 |         hamming_sen.append(mean_hamming_sen)
691 |         hamming_rec.append(mean_hamming_rec)
692 | 
693 |         if callback is not None:
694 |             callback_dict = dict(
695 |                 s_masks=s_masks,
696 |                 s_feats=s_feats,
697 |                 s_probs=s_probs,
698 |                 sen_feats=sen_feats,
699 |                 sen_probs=sen_probs,
700 |                 rec_feats=rec_feats,
701 |                 rec_probs=rec_probs,
702 |                 y=y)
703 |             callback(sender, receiver, batch, callback_dict)
704 | 
705 |     # Print confusion matrix
706 |     true_labels = np.concatenate(true_labels).reshape(-1)
707 |     pred_labels = np.concatenate(pred_labels).reshape(-1)
708 | 
709 |     np.savetxt(FLAGS.conf_mat, confusion_matrix(
710 |         true_labels, pred_labels), delimiter=',', fmt='%d')
711 | 
712 |     # Compute statistics
713 |     conversation_lengths = np.array(conversation_lengths)
714 |     hamming_sen = np.array(hamming_sen)
715 |     hamming_rec = np.array(hamming_rec)
716 |     extra['conversation_lengths_mean'] = conversation_lengths.mean()
717 |     extra['conversation_lengths_std'] = conversation_lengths.std()
718 |     extra['hamming_sen_mean'] = hamming_sen.mean()
719 |     extra['hamming_rec_mean'] = hamming_rec.mean()
720 | 
721 |     # Return accuracy
722 |     return correct / total, extra
723 | 
724 | 
725 | def exchange(sender, receiver, baseline_sen, baseline_rec, exchange_args):
726 |     """Run a batched conversation between Sender and Receiver.
727 | 
728 |     The Sender has only the image, and the Receiver has descriptions of each of the image's
729 |     possible classes and a history of each message it has sent and received.
730 | 
731 |     The Receiver begins the conversation by sending a query of 0s. The Sender inspects this query
732 |     and the image, then formulates a response. The Receiver inspects the response and its set of
733 |     descriptions, then formulates a new query. The conversation continues this way until it has
734 |     reached some predetermined length, or the Receiver has decided it has processed a sufficient
735 |     amount of information, at which point it ignores all future conversation. When each Receiver
736 |     in the batch has received sufficient information, then the batched conversation may terminate
737 |     early.
738 | 
739 |     Exchange Args:
740 |         data: Image features.
741 |         data_context: Optional additional image features that can be used as query in visual attention.
742 |         target: Class labels.
743 |         desc: List of description vectors.
744 | train: Boolean value indicating training mode (True) or evaluation mode (False). 745 | break_early: Boolean value. If True, then terminate batched conversation if all Receivers are satisfied. 746 | Args: 747 | sender: Agent 1. The Sender. 748 | receiver: Agent 2. The Receiver. 749 | baseline_sen: Baseline network for Sender. 750 | baseline_rec: Baseline network for Receiver. 751 | exchange_args: Other useful arguments. 752 | Output: 753 | s: All STOP bits. (Masks, Values, Probabilities) 754 | sen_w: All sender messages. (Values, Probabilities) 755 | rec_w: All receiver messages. (Values, Probabilities) 756 | y: All predictions that were made. 757 | bs: Estimated loss of sender. 758 | br: Estimated loss of receiver. 759 | """ 760 | 761 | data = exchange_args["data"] 762 | data_context = exchange_args.get("data_context", None) 763 | target = exchange_args["target"] 764 | desc = exchange_args["desc"] 765 | desc_set = exchange_args.get("desc_set", None) 766 | desc_set_lens = exchange_args.get("desc_set_lens", None) 767 | train = exchange_args["train"] 768 | break_early = exchange_args.get("break_early", False) 769 | corrupt = exchange_args.get("corrupt", False) 770 | corrupt_region = exchange_args.get("corrupt_region", None) 771 | 772 | batch_size = data.size(0) 773 | 774 | # Pad with one column of ones. 775 | stop_mask = [Variable(torch.ones(batch_size, 1).byte())] 776 | stop_feat = [] 777 | stop_prob = [] 778 | sen_feats = [] 779 | sen_probs = [] 780 | rec_feats = [] 781 | rec_probs = [] 782 | y = [] 783 | bs = [] 784 | br = [] 785 | 786 | w_binary = Variable(torch.FloatTensor(batch_size, sender.w_dim).fill_( 787 | FLAGS.first_rec), volatile=not train) 788 | 789 | if train: 790 | sender.train() 791 | receiver.train() 792 | baseline_sen.train() 793 | baseline_rec.train() 794 | else: 795 | sender.eval() 796 | receiver.eval() 797 | 798 | sender.reset_state() # only for debugging/performance 799 | receiver.reset_state() 800 | 801 | for i_exchange in range(FLAGS.max_exchange): 802 | 803 | z_r = w_binary # rename variable to z_r which makes more sense 804 | 805 | # Run data through Sender 806 | if data_context is not None: 807 | z_binary, z_probs = sender(Variable(data, volatile=not train), Variable(z_r.data, volatile=not train), 808 | Variable(data_context, volatile=not train), i_exchange) 809 | else: 810 | z_binary, z_probs = sender(Variable(data, volatile=not train), Variable(z_r.data, volatile=not train), 811 | None, i_exchange) 812 | 813 | # Optionally corrupt Sender's message 814 | if corrupt: 815 | # Obtain mask 816 | mask = Variable(build_mask(corrupt_region, sender.w_dim)) 817 | mask_broadcast = mask.view(1, sender.w_dim).expand_as(z_binary) 818 | # Subtract the mask to change values, but need to get absolute value 819 | # to set -1 values to 1 to essentially "flip" all the bits. 
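            # (Added example: for a masked bit m=1, z=1 -> |1-1| = 0 and
            # z=0 -> |0-1| = 1, i.e. the bit is flipped; bits with m=0 pass
            # through unchanged.)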
820 | z_binary = (z_binary - mask_broadcast).abs() 821 | 822 | # Generate input for Receiver 823 | z_s = z_binary # rename variable to z_s which makes more sense 824 | 825 | # Run batch through Receiver 826 | (s_binary, s_prob), (w_binary, w_probs), outp = receiver( 827 | Variable(z_s.data, volatile=not train), Variable( 828 | desc.data, volatile=not train), 829 | desc_set, desc_set_lens) 830 | 831 | if train: 832 | sen_h_x = sender.h_x 833 | 834 | # Score from Baseline (Sender) 835 | baseline_sen_scores = baseline_sen( 836 | Variable(sen_h_x.data), Variable(z_r.data), None) 837 | 838 | rec_h_z = receiver.h_z if receiver.h_z else receiver.initial_state( 839 | batch_size) 840 | 841 | # Score from Baseline (Receiver) 842 | baseline_rec_scores = baseline_rec( 843 | None, Variable(z_s.data), Variable(rec_h_z.data)) 844 | 845 | outp = outp.view(batch_size, -1) 846 | 847 | # Obtain predictions 848 | dist = F.log_softmax(outp) 849 | maxdist, argmax = dist.data.max(1) 850 | 851 | # Save for later 852 | stop_mask.append(torch.min(stop_mask[-1], s_binary.byte())) 853 | stop_feat.append(s_binary) 854 | stop_prob.append(s_prob) 855 | sen_feats.append(z_binary) 856 | sen_probs.append(z_probs) 857 | rec_feats.append(w_binary) 858 | rec_probs.append(w_probs) 859 | y.append(outp) 860 | 861 | if train: 862 | br.append(baseline_rec_scores) 863 | bs.append(baseline_sen_scores) 864 | 865 | # Terminate exchange if everyone is done conversing 866 | if break_early and stop_mask[-1].float().sum().data[0] == 0: 867 | break 868 | 869 | # The final mask must always be zero. 870 | stop_mask[-1].data.fill_(0) 871 | 872 | s = (stop_mask, stop_feat, stop_prob) 873 | sen_w = (sen_feats, sen_probs) 874 | rec_w = (rec_feats, rec_probs) 875 | 876 | return s, sen_w, rec_w, y, bs, br 877 | 878 | 879 | def get_rec_outp(y, masks): 880 | def negent(yy): 881 | probs = F.softmax(yy) 882 | return (torch.log(probs + 1e-8) * probs).sum(1).mean() 883 | 884 | # TODO: This is wrong for the dynamic exchange, and we might want a "per example" 885 | # entropy for either exchange (this version is mean across batch). 886 | negentropy = map(negent, y) 887 | 888 | if masks is not None: 889 | 890 | batch_size = y[0].size(0) 891 | exchange_steps = len(masks) 892 | 893 | inp = torch.cat([yy.view(batch_size, 1, -1) for yy in y], 1) 894 | mask = torch.cat(masks, 1).view( 895 | batch_size, exchange_steps, 1).expand_as(inp) 896 | outp = torch.masked_select(inp, mask.detach()).view(batch_size, -1) 897 | 898 | if FLAGS.debug: 899 | # Each mask index should have exactly 1 true value. 900 | assert all([mm.data[0] == 1 for mm in torch.cat(masks, 1).sum(1)]) 901 | 902 | return outp, negentropy 903 | else: 904 | return y[-1], negentropy 905 | 906 | 907 | def calculate_loss_binary(binary_features, binary_probs, logs, baseline_scores, entropy_penalty): 908 | log_p_z = Variable(binary_features.data) * torch.log(binary_probs + 1e-8) + \ 909 | (1 - Variable(binary_features.data)) * \ 910 | torch.log(1 - binary_probs + 1e-8) 911 | log_p_z = log_p_z.sum(1) 912 | weight = Variable(logs.data) - \ 913 | Variable(baseline_scores.clone().detach().data) 914 | if logs.size(0) > 1: 915 | weight = weight / np.maximum(1., torch.std(weight.data)) 916 | loss = torch.mean(-1 * weight * log_p_z) 917 | 918 | # Must do both sides of negent, otherwise is skewed towards 0. 919 | initial_negent = (torch.log(binary_probs + 1e-8) 920 | * binary_probs).sum(1).mean() 921 | inverse_negent = (torch.log((1. - binary_probs) + 1e-8) 922 | * (1. 
- binary_probs)).sum(1).mean() 923 | negentropy = initial_negent + inverse_negent 924 | 925 | if entropy_penalty is not None: 926 | loss = (loss + entropy_penalty * negentropy) 927 | return loss, negentropy 928 | 929 | 930 | def multistep_loss_binary(binary_features, binary_probs, logs, baseline_scores, masks, entropy_penalty): 931 | if masks is not None: 932 | def mapped_fn(feat, prob, scores, mask, mask_sums): 933 | if mask_sums == 0: 934 | return Variable(torch.zeros(1)) 935 | 936 | feat_size = feat.size() 937 | prob_size = prob.size() 938 | logs_size = logs.size() 939 | scores_size = scores.size() 940 | 941 | feat = feat[mask.expand_as(feat)].view(-1, feat_size[1]) 942 | prob = prob[mask.expand_as(prob)].view(-1, prob_size[1]) 943 | _logs = logs[mask.expand_as(logs)].view(-1, logs_size[1]) 944 | scores = scores[mask.expand_as(scores)].view(-1, scores_size[1]) 945 | return calculate_loss_binary(feat, prob, _logs, scores, entropy_penalty) 946 | 947 | _mask_sums = [m.float().sum().data[0] for m in masks] 948 | 949 | if FLAGS.debug: 950 | assert len(masks) > 0 951 | assert len(masks) == len(binary_features) 952 | assert len(masks) == len(binary_probs) 953 | assert len(masks) == len(baseline_scores) 954 | assert sum(_mask_sums) > 0 955 | 956 | outp = map(mapped_fn, binary_features, binary_probs, 957 | baseline_scores, masks, _mask_sums) 958 | losses = [o[0] for o in outp] 959 | entropies = [o[1] for o in outp] 960 | _losses = [l * ms for l, ms in zip(losses, _mask_sums)] 961 | loss = sum(_losses) / sum(_mask_sums) 962 | else: 963 | outp = map(lambda feat, prob, scores: calculate_loss_binary(feat, prob, logs, scores, entropy_penalty), 964 | binary_features, binary_probs, baseline_scores) 965 | losses = [o[0] for o in outp] 966 | entropies = [o[1] for o in outp] 967 | loss = sum(losses) / len(binary_features) 968 | return loss, entropies 969 | 970 | 971 | def calculate_loss_bas(baseline_scores, logs): 972 | loss_bas = nn.MSELoss()(baseline_scores, Variable(logs.data)) 973 | return loss_bas 974 | 975 | 976 | def multistep_loss_bas(baseline_scores, logs, masks): 977 | if masks is not None: 978 | losses = map(lambda scores, mask: calculate_loss_bas( 979 | scores[mask].view(-1, 1), logs[mask].view(-1, 1)), 980 | baseline_scores, masks) 981 | _mask_sums = [m.sum().float() for m in masks] 982 | _losses = [l * ms for l, ms in zip(losses, _mask_sums)] 983 | loss = sum(_losses) / sum(_mask_sums) 984 | else: 985 | losses = map(lambda scores: calculate_loss_bas(scores, logs), 986 | baseline_scores) 987 | loss = sum(losses) / len(baseline_scores) 988 | return loss 989 | 990 | 991 | def bin_to_alpha(binary): 992 | ret = [] 993 | interval = 5 994 | offset = 65 995 | for i in range(0, len(binary), interval): 996 | val = int(binary[i:i + interval], 2) 997 | ret.append(unichr(offset + val)) 998 | return " ".join(ret) 999 | 1000 | 1001 | def run(): 1002 | flogger = FileLogger(FLAGS.log_file) 1003 | logger = Logger( 1004 | env=FLAGS.env, experiment_name=FLAGS.experiment_name, enabled=FLAGS.visdom) 1005 | 1006 | flogger.Log("Flag Values:\n" + 1007 | json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True)) 1008 | 1009 | if not os.path.exists(FLAGS.json_file): 1010 | with open(FLAGS.json_file, "w") as f: 1011 | f.write(json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True)) 1012 | 1013 | # Sender model 1014 | sender = Sender(feature_type=FLAGS.img_feat, 1015 | feat_dim=FLAGS.img_feat_dim, 1016 | h_dim=FLAGS.img_h_dim, 1017 | w_dim=FLAGS.rec_w_dim, 1018 | bin_dim_out=FLAGS.sender_out_dim, 1019 | 
use_binary=FLAGS.use_binary, 1020 | use_attn=FLAGS.visual_attn, 1021 | attn_dim=FLAGS.attn_dim, 1022 | attn_extra_context=FLAGS.attn_extra_context, 1023 | attn_context_dim=FLAGS.attn_context_dim) 1024 | 1025 | flogger.Log("Architecture: {}".format(sender)) 1026 | total_params = sum([reduce(lambda x, y: x * y, p.size(), 1.0) 1027 | for p in sender.parameters()]) 1028 | flogger.Log("Total Parameters: {}".format(total_params)) 1029 | 1030 | # Baseline model 1031 | baseline_sen = Baseline(hid_dim=FLAGS.baseline_hid_dim, 1032 | x_dim=FLAGS.img_h_dim, 1033 | binary_dim=FLAGS.rec_w_dim, 1034 | inp_dim=0) 1035 | 1036 | flogger.Log("Architecture: {}".format(baseline_sen)) 1037 | total_params = sum([reduce(lambda x, y: x * y, p.size(), 1.0) 1038 | for p in baseline_sen.parameters()]) 1039 | flogger.Log("Total Parameters: {}".format(total_params)) 1040 | 1041 | # Receiver network 1042 | receiver = Receiver(hid_dim=FLAGS.rec_hidden, 1043 | out_dim=FLAGS.rec_out_dim, 1044 | z_dim=FLAGS.sender_out_dim, 1045 | desc_dim=FLAGS.wv_dim, 1046 | w_dim=FLAGS.rec_w_dim, 1047 | s_dim=FLAGS.rec_s_dim, 1048 | use_binary=FLAGS.use_binary) 1049 | 1050 | flogger.Log("Architecture: {}".format(receiver)) 1051 | total_params = sum([reduce(lambda x, y: x * y, p.size(), 1.0) 1052 | for p in receiver.parameters()]) 1053 | flogger.Log("Total Parameters: {}".format(total_params)) 1054 | 1055 | # Baseline model 1056 | baseline_rec = Baseline(hid_dim=FLAGS.baseline_hid_dim, 1057 | x_dim=0, 1058 | binary_dim=FLAGS.rec_w_dim, 1059 | inp_dim=FLAGS.rec_hidden) 1060 | 1061 | flogger.Log("Architecture: {}".format(baseline_rec)) 1062 | total_params = sum([reduce(lambda x, y: x * y, p.size(), 1.0) 1063 | for p in baseline_rec.parameters()]) 1064 | flogger.Log("Total Parameters: {}".format(total_params)) 1065 | 1066 | # Get description vectors 1067 | if FLAGS.wv_type == "fake": 1068 | num_desc = 10 1069 | desc = Variable(torch.randn(num_desc, FLAGS.wv_dim).float()) 1070 | elif FLAGS.wv_type == "glove.6B": 1071 | # Train 1072 | descr_train, word_dict_train, dict_size_train, label_id_to_idx_train, idx_to_label_train = read_data( 1073 | FLAGS.descr_train) 1074 | 1075 | def map_labels_train(x): return label_id_to_idx_train.get(x) 1076 | word_dict_train = embed(word_dict_train, FLAGS.glove_path) 1077 | descr_train = cbow(descr_train, word_dict_train) 1078 | desc_train = torch.cat([descr_train[i]["cbow"].view(1, -1) 1079 | for i in descr_train.keys()], 0) 1080 | desc_train = Variable(desc_train) 1081 | desc_train_set = torch.cat( 1082 | [descr_train[i]["set"].view(-1, FLAGS.wv_dim) for i in descr_train.keys()], 0) 1083 | desc_train_set_lens = [len(descr_train[i]["desc"]) 1084 | for i in descr_train.keys()] 1085 | 1086 | # Dev 1087 | descr_dev, word_dict_dev, dict_size_dev, label_id_to_idx_dev, idx_to_label_dev = read_data( 1088 | FLAGS.descr_dev) 1089 | 1090 | def map_labels_dev(x): return label_id_to_idx_dev.get(x) 1091 | word_dict_dev = embed(word_dict_dev, FLAGS.glove_path) 1092 | descr_dev = cbow(descr_dev, word_dict_dev) 1093 | desc_dev = torch.cat([descr_dev[i]["cbow"].view(1, -1) 1094 | for i in descr_dev.keys()], 0) 1095 | desc_dev = Variable(desc_dev) 1096 | desc_dev_set = torch.cat( 1097 | [descr_dev[i]["set"].view(-1, FLAGS.wv_dim) for i in descr_dev.keys()], 0) 1098 | desc_dev_set_lens = [len(descr_dev[i]["desc"]) 1099 | for i in descr_dev.keys()] 1100 | 1101 | desc_dev_dict = dict( 1102 | desc=desc_dev, 1103 | desc_set=desc_dev_set, 1104 | desc_set_lens=desc_dev_set_lens) 1105 | elif FLAGS.wv_type == "none": 1106 | desc = None 
1107 | else: 1108 | raise NotImplementedError 1109 | 1110 | # Optimizer 1111 | if FLAGS.optim_type == "SGD": 1112 | optimizer_rec = optim.SGD( 1113 | receiver.parameters(), lr=FLAGS.learning_rate) 1114 | optimizer_sen = optim.SGD(sender.parameters(), lr=FLAGS.learning_rate) 1115 | optimizer_bas_rec = optim.SGD( 1116 | baseline_rec.parameters(), lr=FLAGS.learning_rate) 1117 | optimizer_bas_sen = optim.SGD( 1118 | baseline_sen.parameters(), lr=FLAGS.learning_rate) 1119 | elif FLAGS.optim_type == "Adam": 1120 | optimizer_rec = optim.Adam( 1121 | receiver.parameters(), lr=FLAGS.learning_rate) 1122 | optimizer_sen = optim.Adam(sender.parameters(), lr=FLAGS.learning_rate) 1123 | optimizer_bas_rec = optim.Adam( 1124 | baseline_rec.parameters(), lr=FLAGS.learning_rate) 1125 | optimizer_bas_sen = optim.Adam( 1126 | baseline_sen.parameters(), lr=FLAGS.learning_rate) 1127 | elif FLAGS.optim_type == "RMSprop": 1128 | optimizer_rec = optim.RMSprop( 1129 | receiver.parameters(), lr=FLAGS.learning_rate) 1130 | optimizer_sen = optim.RMSprop( 1131 | sender.parameters(), lr=FLAGS.learning_rate) 1132 | optimizer_bas_rec = optim.RMSprop( 1133 | baseline_rec.parameters(), lr=FLAGS.learning_rate) 1134 | optimizer_bas_sen = optim.RMSprop( 1135 | baseline_sen.parameters(), lr=FLAGS.learning_rate) 1136 | else: 1137 | raise NotImplementedError 1138 | 1139 | optimizers_dict = dict(optimizer_rec=optimizer_rec, optimizer_sen=optimizer_sen, 1140 | optimizer_bas_rec=optimizer_bas_rec, optimizer_bas_sen=optimizer_bas_sen) 1141 | models_dict = dict(receiver=receiver, sender=sender, 1142 | baseline_rec=baseline_rec, baseline_sen=baseline_sen) 1143 | 1144 | # Training metrics 1145 | epoch = 0 1146 | step = 0 1147 | best_dev_acc = 0 1148 | 1149 | # Optionally load previously saved model 1150 | if os.path.exists(FLAGS.checkpoint): 1151 | flogger.Log("Loading from: " + FLAGS.checkpoint) 1152 | data = torch_load(FLAGS.checkpoint, models_dict, optimizers_dict) 1153 | flogger.Log("Loaded at step: {} and best dev acc: {}".format( 1154 | data['step'], data['best_dev_acc'])) 1155 | step = data['step'] 1156 | best_dev_acc = data['best_dev_acc'] 1157 | 1158 | # GPU support 1159 | if FLAGS.cuda: 1160 | for m in models_dict.values(): 1161 | m.cuda() 1162 | for o in optimizers_dict.values(): 1163 | recursively_set_device(o.state_dict(), gpu=0) 1164 | 1165 | # Alternatives to training. 
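# A hedged illustration of the eval-only path below (paths hypothetical): a
# trained checkpoint can be scored on the dev set without any training, e.g.
#   python model.py -log_load ./logs/demo.json -eval_only \
#       -checkpoint ./logs/demo.pt_best -experiment_name demo-eval
# which writes a one-row summary to `eval_csv_file` and exits.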
1166 | if FLAGS.eval_only: 1167 | if not os.path.exists(FLAGS.checkpoint): 1168 | raise Exception("Must provide valid checkpoint.") 1169 | dev_acc, extra = eval_dev(FLAGS.dev_file, FLAGS.batch_size_dev, epoch, 1170 | FLAGS.shuffle_dev, FLAGS.cuda, FLAGS.top_k_dev, 1171 | sender, receiver, desc_dev_dict, map_labels_dev, FLAGS.experiment_name) 1172 | flogger.Log("Dev Accuracy: " + str(dev_acc)) 1173 | with open(FLAGS.eval_csv_file, 'w') as f: 1174 | f.write( 1175 | "checkpoint,eval_file,topk,step,best_dev_acc,eval_acc,convlen_mean,convlen_std\n") 1176 | f.write("{},{},{},{},{},{},{},{}\n".format( 1177 | FLAGS.checkpoint, FLAGS.dev_file, FLAGS.top_k_dev, 1178 | step, best_dev_acc, dev_acc, 1179 | extra['conversation_lengths_mean'], extra['conversation_lengths_std'])) 1180 | sys.exit() 1181 | elif FLAGS.binary_only: 1182 | if not os.path.exists(FLAGS.checkpoint): 1183 | raise Exception("Must provide valid checkpoint.") 1184 | extract_binary(FLAGS, load_hdf5, exchange, FLAGS.dev_file, FLAGS.batch_size_dev, epoch, 1185 | FLAGS.shuffle_dev, FLAGS.cuda, FLAGS.top_k_dev, 1186 | sender, receiver, desc_dev_dict, map_labels_dev, FLAGS.experiment_name) 1187 | sys.exit() 1188 | 1189 | # Training loop 1190 | while epoch < FLAGS.max_epoch: 1191 | 1192 | flogger.Log("Starting epoch: {}".format(epoch)) 1193 | 1194 | # Read images randomly into batches - image_dim = [3, 227, 227] 1195 | if FLAGS.images == "cifar": 1196 | dataset = dset.CIFAR10(root="./", download=True, train=False, 1197 | transform=transforms.Compose([ 1198 | transforms.Scale(227), 1199 | transforms.ToTensor(), 1200 | transforms.Normalize( 1201 | (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 1202 | ]) 1203 | ) 1204 | dataloader = torch.utils.data.DataLoader(dataset, 1205 | batch_size=FLAGS.batch_size, 1206 | shuffle=True) 1207 | elif FLAGS.images == "mammal": 1208 | dataloader = load_hdf5(FLAGS.train_file, FLAGS.batch_size, 1209 | epoch, FLAGS.shuffle_train, map_labels=map_labels_train) 1210 | else: 1211 | raise NotImplementedError 1212 | 1213 | # Keep track of metrics 1214 | batch_accuracy = [] 1215 | dev_accuracy = [] 1216 | 1217 | # Iterate through batches 1218 | for i_batch, batch in enumerate(dataloader): 1219 | target = batch["target"] 1220 | data = batch[FLAGS.img_feat] 1221 | 1222 | # GPU support 1223 | if FLAGS.cuda: 1224 | data = data.cuda() 1225 | target = target.cuda() 1226 | desc_train = desc_train.cuda() 1227 | desc_train_set = desc_train_set.cuda() 1228 | 1229 | exchange_args = dict() 1230 | exchange_args["data"] = data 1231 | if FLAGS.attn_extra_context: 1232 | exchange_args["data_context"] = batch[FLAGS.data_context] 1233 | exchange_args["target"] = target 1234 | exchange_args["desc"] = desc_train 1235 | exchange_args["desc_set"] = desc_train_set 1236 | exchange_args["desc_set_lens"] = desc_train_set_lens 1237 | exchange_args["train"] = True 1238 | exchange_args["break_early"] = not FLAGS.fixed_exchange 1239 | 1240 | s, sen_w, rec_w, y, bs, br = exchange( 1241 | sender, receiver, baseline_sen, baseline_rec, exchange_args) 1242 | 1243 | s_masks, s_feats, s_probs = s 1244 | sen_feats, sen_probs = sen_w 1245 | rec_feats, rec_probs = rec_w 1246 | 1247 | # Mask loss if dynamic exchange length 1248 | if FLAGS.fixed_exchange: 1249 | binary_s_masks = None 1250 | binary_rec_masks = None 1251 | binary_sen_masks = None 1252 | bas_rec_masks = None 1253 | bas_sen_masks = None 1254 | y_masks = None 1255 | else: 1256 | binary_s_masks = s_masks[:-1] 1257 | binary_rec_masks = s_masks[1:-1] 1258 | binary_sen_masks = s_masks[:-1] 1259 | bas_rec_masks = 
s_masks[:-1] 1260 | bas_sen_masks = s_masks[:-1] 1261 | y_masks = [torch.min(1 - m1, m2) 1262 | for m1, m2 in zip(s_masks[1:], s_masks[:-1])] 1263 | 1264 | outp, ent_y_rec = get_rec_outp(y, y_masks) 1265 | 1266 | # Obtain predictions 1267 | dist = F.log_softmax(outp) 1268 | maxdist, argmax = dist.data.max(1) 1269 | 1270 | # Receiver classification loss 1271 | nll_loss = nn.NLLLoss()(dist, Variable(target)) 1272 | 1273 | # Individual log-likelihoods across the batch 1274 | logs = loglikelihood(Variable(dist.data), 1275 | Variable(target.view(-1, 1))) 1276 | 1277 | if FLAGS.use_binary: 1278 | if not FLAGS.fixed_exchange: 1279 | loss_binary_s, ent_binary_s = multistep_loss_binary( 1280 | s_feats, s_probs, logs, br, binary_s_masks, FLAGS.entropy_s) 1281 | 1282 | # The receiver might have no z-loss if we stop after first 1283 | # message from sender. 1284 | if len(rec_feats[:-1]) > 0: 1285 | loss_binary_rec, ent_binary_rec = multistep_loss_binary( 1286 | rec_feats[:-1], rec_probs[:-1], logs, br[:-1], binary_rec_masks, FLAGS.entropy_rec) 1287 | else: 1288 | loss_binary_rec, ent_binary_rec = Variable( 1289 | torch.zeros(1)), [] 1290 | 1291 | loss_binary_sen, ent_binary_sen = multistep_loss_binary( 1292 | sen_feats, sen_probs, logs, bs, binary_sen_masks, FLAGS.entropy_sen) 1293 | loss_bas_rec = multistep_loss_bas(br, logs, bas_rec_masks) 1294 | loss_bas_sen = multistep_loss_bas(bs, logs, bas_sen_masks) 1295 | 1296 | loss_rec = nll_loss 1297 | if FLAGS.use_binary: 1298 | loss_rec = loss_rec + loss_binary_rec 1299 | if not FLAGS.fixed_exchange: 1300 | loss_rec = loss_rec + loss_binary_s 1301 | loss_sen = loss_binary_sen 1302 | else: 1303 | loss_sen = Variable(torch.zeros(1)) 1304 | loss_bas_rec = Variable(torch.zeros(1)) 1305 | loss_bas_sen = Variable(torch.zeros(1)) 1306 | 1307 | # Update receiver 1308 | optimizer_rec.zero_grad() 1309 | loss_rec.backward() 1310 | nn.utils.clip_grad_norm(receiver.parameters(), max_norm=1.) 1311 | optimizer_rec.step() 1312 | 1313 | if FLAGS.use_binary: 1314 | # Update sender 1315 | optimizer_sen.zero_grad() 1316 | loss_sen.backward() 1317 | nn.utils.clip_grad_norm(sender.parameters(), max_norm=1.) 1318 | optimizer_sen.step() 1319 | 1320 | # Update baseline 1321 | optimizer_bas_rec.zero_grad() 1322 | loss_bas_rec.backward() 1323 | nn.utils.clip_grad_norm(baseline_rec.parameters(), max_norm=1.) 1324 | optimizer_bas_rec.step() 1325 | 1326 | # Update baseline 1327 | optimizer_bas_sen.zero_grad() 1328 | loss_bas_sen.backward() 1329 | nn.utils.clip_grad_norm(baseline_sen.parameters(), max_norm=1.) 
1330 | optimizer_bas_sen.step()
1331 |
1332 | # Obtain top-k accuracy
1333 | top_k_ind = torch.from_numpy(dist.data.cpu().numpy().argsort()[
1334 | :, -FLAGS.top_k_train:]).long()
1335 | target_exp = target.view(-1,
1336 | 1).expand(FLAGS.batch_size, FLAGS.top_k_train)
1337 | accuracy = (top_k_ind == target_exp.cpu()).sum() / \
1338 | float(FLAGS.batch_size)
1339 | batch_accuracy.append(accuracy)
1340 |
1341 | # Print logs regularly
1342 | if step % FLAGS.log_interval == 0:
1343 | # Average batch accuracy
1344 | avg_batch_acc = np.array(
1345 | batch_accuracy[-FLAGS.log_interval:]).mean()
1346 |
1347 | # Log accuracy
1348 | log_acc = "Epoch: {} Step: {} Batch: {} Training Accuracy: {}"\
1349 | .format(epoch, step, i_batch, avg_batch_acc)
1350 | flogger.Log(log_acc)
1351 |
1352 | # Sender
1353 | log_loss_sen = "Epoch: {} Step: {} Batch: {} Loss Sender: {}".format(
1354 | epoch, step, i_batch, loss_sen.data[0])
1355 | flogger.Log(log_loss_sen)
1356 |
1357 | # Receiver
1358 | log_loss_rec_y = "Epoch: {} Step: {} Batch: {} Loss Receiver (Y): {}".format(
1359 | epoch, step, i_batch, nll_loss.data[0])
1360 | flogger.Log(log_loss_rec_y)
1361 | if FLAGS.use_binary:
1362 | log_loss_rec_z = "Epoch: {} Step: {} Batch: {} Loss Receiver (Z): {}".format(
1363 | epoch, step, i_batch, loss_binary_rec.data[0])
1364 | flogger.Log(log_loss_rec_z)
1365 | if not FLAGS.fixed_exchange:
1366 | log_loss_rec_s = "Epoch: {} Step: {} Batch: {} Loss Receiver (S): {}".format(
1367 | epoch, step, i_batch, loss_binary_s.data[0])
1368 | flogger.Log(log_loss_rec_s)
1369 |
1370 | # Baselines
1371 | if FLAGS.use_binary:
1372 | log_loss_bas_s = "Epoch: {} Step: {} Batch: {} Loss Baseline (S): {}".format(
1373 | epoch, step, i_batch, loss_bas_sen.data[0])
1374 | flogger.Log(log_loss_bas_s)
1375 | log_loss_bas_r = "Epoch: {} Step: {} Batch: {} Loss Baseline (R): {}".format(
1376 | epoch, step, i_batch, loss_bas_rec.data[0])
1377 | flogger.Log(log_loss_bas_r)
1378 |
1379 | # Log predictions
1380 | log_pred = "Predictions: {}".format(
1381 | torch.cat([target, argmax], 0).view(-1, FLAGS.batch_size))
1382 | flogger.Log(log_pred)
1383 |
1384 | # Log Entropy
1385 | if FLAGS.use_binary:
1386 | if len(ent_binary_sen) > 0:
1387 | log_ent_sen_bin = "Entropy Sender Binary"
1388 | for i, ent in enumerate(ent_binary_sen):
1389 | log_ent_sen_bin += "\n{}. {}".format(
1390 | i, -ent.data[0])
1391 | log_ent_sen_bin += "\n"
1392 | flogger.Log(log_ent_sen_bin)
1393 |
1394 | if len(ent_binary_rec) > 0:
1395 | log_ent_rec_bin = "Entropy Receiver Binary"
1396 | for i, ent in enumerate(ent_binary_rec):
1397 | log_ent_rec_bin += "\n{}. {}".format(
1398 | i, -ent.data[0])
1399 | log_ent_rec_bin += "\n"
1400 | flogger.Log(log_ent_rec_bin)
1401 |
1402 | if len(ent_y_rec) > 0:
1403 | log_ent_rec_y = "Entropy Receiver Predictions"
1404 | for i, ent in enumerate(ent_y_rec):
1405 | log_ent_rec_y += "\n{}. {}".format(i, -ent.data[0])
1406 | log_ent_rec_y += "\n"
1407 | flogger.Log(log_ent_rec_y)
1408 |
1409 | # Optionally print sampled and inferred binary vectors from
1410 | # most recent exchange.
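# Schematic of the sample log assembled below (glyphs and bits are
# illustrative, not real output): each sampled example prints, per exchange
# step, a sparkline row of the sender/stop/receiver probabilities, then the
# sampled binary messages, each with its Hamming distance to the previous
# step's message and the current stop bit, roughly:
#   0  ▃▁▇▂  ▁ ▂▆▁▃
#     0 S: 0101 2.0 s=1 R: 0110 3.0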
1411 | if FLAGS.exchange_samples > 0: 1412 | 1413 | current_exchange = len(sen_feats) 1414 | 1415 | log_train = "Train:" 1416 | for i_sample in range(FLAGS.exchange_samples): 1417 | prev_sen = torch.FloatTensor(FLAGS.rec_w_dim).fill_(0) 1418 | prev_rec = torch.FloatTensor(FLAGS.rec_w_dim).fill_(0) 1419 | for i_exchange in range(current_exchange): 1420 | sen_probs_i = sen_probs[i_exchange][i_sample].data.tolist( 1421 | ) 1422 | sen_spark = sparks( 1423 | [1] + sen_probs_i)[1:].encode('utf-8') 1424 | rec_probs_i = rec_probs[i_exchange][i_sample].data.tolist( 1425 | ) 1426 | rec_spark = sparks( 1427 | [1] + rec_probs_i)[1:].encode('utf-8') 1428 | s_probs_i = s_probs[i_exchange][i_sample].data.tolist( 1429 | ) 1430 | s_spark = sparks( 1431 | [1] + s_probs_i)[1:].encode('utf-8') 1432 | 1433 | sen_binary = sen_feats[i_exchange][i_sample].data.cpu( 1434 | ) 1435 | sen_hamming = (prev_sen - sen_binary).abs().sum() 1436 | prev_sen = sen_binary 1437 | rec_binary = rec_feats[i_exchange][i_sample].data.cpu( 1438 | ) 1439 | rec_hamming = (prev_rec - rec_binary).abs().sum() 1440 | prev_rec = rec_binary 1441 | 1442 | sen_msg = "".join( 1443 | map(str, map(int, sen_binary.tolist()))) 1444 | rec_msg = "".join( 1445 | map(str, map(int, rec_binary.tolist()))) 1446 | if FLAGS.use_alpha: 1447 | sen_msg = bin_to_alpha(sen_msg) 1448 | rec_msg = bin_to_alpha(rec_msg) 1449 | if i_exchange == 0: 1450 | log_train += "\n{:>3}".format(i_sample) 1451 | else: 1452 | log_train += "\n " 1453 | log_train += " {}".format(sen_spark) 1454 | log_train += " {} {}".format( 1455 | s_spark, rec_spark) 1456 | log_train += "\n {:>3} S: {} {:4}".format( 1457 | i_exchange, sen_msg, sen_hamming) 1458 | log_train += " s={} R: {} {:4}".format( 1459 | s_masks[1:][i_exchange][i_sample].data[0], rec_msg, rec_hamming) 1460 | log_train += "\n" 1461 | flogger.Log(log_train) 1462 | 1463 | exchange_args["train"] = False 1464 | s, sen_w, rec_w, y, bs, br = exchange( 1465 | sender, receiver, baseline_sen, baseline_rec, exchange_args) 1466 | s_masks, s_feats, s_probs = s 1467 | sen_feats, sen_probs = sen_w 1468 | rec_feats, rec_probs = rec_w 1469 | 1470 | current_exchange = len(sen_feats) 1471 | 1472 | log_eval = "Eval:" 1473 | for i_sample in range(FLAGS.exchange_samples): 1474 | prev_sen = torch.FloatTensor(FLAGS.rec_w_dim).fill_(0) 1475 | prev_rec = torch.FloatTensor(FLAGS.rec_w_dim).fill_(0) 1476 | for i_exchange in range(current_exchange): 1477 | sen_probs_i = sen_probs[i_exchange][i_sample].data.tolist( 1478 | ) 1479 | sen_spark = sparks( 1480 | [1] + sen_probs_i)[1:].encode('utf-8') 1481 | rec_probs_i = rec_probs[i_exchange][i_sample].data.tolist( 1482 | ) 1483 | rec_spark = sparks( 1484 | [1] + rec_probs_i)[1:].encode('utf-8') 1485 | s_probs_i = s_probs[i_exchange][i_sample].data.tolist( 1486 | ) 1487 | s_spark = sparks( 1488 | [1] + s_probs_i)[1:].encode('utf-8') 1489 | 1490 | sen_binary = sen_feats[i_exchange][i_sample].data.cpu( 1491 | ) 1492 | sen_hamming = (prev_sen - sen_binary).abs().sum() 1493 | prev_sen = sen_binary 1494 | rec_binary = rec_feats[i_exchange][i_sample].data.cpu( 1495 | ) 1496 | rec_hamming = (prev_rec - rec_binary).abs().sum() 1497 | prev_rec = rec_binary 1498 | 1499 | sen_msg = "".join( 1500 | map(str, map(int, sen_binary.tolist()))) 1501 | rec_msg = "".join( 1502 | map(str, map(int, rec_binary.tolist()))) 1503 | if FLAGS.use_alpha: 1504 | sen_msg = bin_to_alpha(sen_msg) 1505 | rec_msg = bin_to_alpha(rec_msg) 1506 | if i_exchange == 0: 1507 | log_eval += "\n{:>3}".format(i_sample) 1508 | else: 1509 | log_eval += "\n 
" 1510 | log_eval += " {}".format(sen_spark) 1511 | log_eval += " {} {}".format( 1512 | s_spark, rec_spark) 1513 | log_eval += "\n {:>3} S: {} {:4}".format( 1514 | i_exchange, sen_msg, sen_hamming) 1515 | log_eval += " s={} R: {} {:4}".format( 1516 | s_masks[1:][i_exchange][i_sample].data[0], rec_msg, rec_hamming) 1517 | log_eval += "\n" 1518 | flogger.Log(log_eval) 1519 | 1520 | # Sender 1521 | logger.log(key="Loss Sender", 1522 | val=loss_sen.data[0], step=step) 1523 | 1524 | # Receiver 1525 | logger.log(key="Loss Receiver (Y)", 1526 | val=nll_loss.data[0], step=step) 1527 | if FLAGS.use_binary: 1528 | logger.log(key="Loss Receiver (Z)", 1529 | val=loss_binary_rec.data[0], step=step) 1530 | if not FLAGS.fixed_exchange: 1531 | logger.log(key="Loss Receiver (S)", 1532 | val=loss_binary_s.data[0], step=step) 1533 | 1534 | # Baselines 1535 | if FLAGS.use_binary: 1536 | logger.log(key="Loss Baseline (S)", 1537 | val=loss_bas_sen.data[0], step=step) 1538 | logger.log(key="Loss Baseline (R)", 1539 | val=loss_bas_rec.data[0], step=step) 1540 | 1541 | logger.log(key="Training Accuracy", 1542 | val=avg_batch_acc, step=step) 1543 | 1544 | # Report development accuracy 1545 | if step % FLAGS.log_dev == 0: 1546 | dev_acc, extra = eval_dev(FLAGS.dev_file, FLAGS.batch_size_dev, epoch, 1547 | FLAGS.shuffle_dev, FLAGS.cuda, FLAGS.top_k_dev, 1548 | sender, receiver, desc_dev_dict, map_labels_dev, FLAGS.experiment_name) 1549 | dev_accuracy.append(dev_acc) 1550 | logger.log(key="Development Accuracy", 1551 | val=dev_accuracy[-1], step=step) 1552 | logger.log(key="Conversation Length (avg)", 1553 | val=extra['conversation_lengths_mean'], step=step) 1554 | logger.log(key="Conversation Length (std)", 1555 | val=extra['conversation_lengths_std'], step=step) 1556 | logger.log(key="Hamming Receiver (avg)", 1557 | val=extra['hamming_rec_mean'], step=step) 1558 | logger.log(key="Hamming Sender (avg)", 1559 | val=extra['hamming_sen_mean'], step=step) 1560 | 1561 | flogger.Log("Epoch: {} Step: {} Batch: {} Development Accuracy: {}" 1562 | .format(epoch, step, i_batch, dev_accuracy[-1])) 1563 | flogger.Log("Epoch: {} Step: {} Batch: {} Conversation Length (avg/std): {}/{}" 1564 | .format(epoch, step, i_batch, 1565 | extra['conversation_lengths_mean'], 1566 | extra['conversation_lengths_std'])) 1567 | flogger.Log("Epoch: {} Step: {} Batch: {} Mean Hamming Distance (R/S): {}/{}" 1568 | .format(epoch, step, i_batch, extra['hamming_rec_mean'], extra['hamming_sen_mean'])) 1569 | if step >= FLAGS.save_after and dev_acc > best_dev_acc: 1570 | best_dev_acc = dev_acc 1571 | flogger.Log( 1572 | "Checkpointing with best Development Accuracy: {}".format(best_dev_acc)) 1573 | # Optionally store additional information 1574 | data = dict(step=step, best_dev_acc=best_dev_acc) 1575 | torch_save(FLAGS.checkpoint + "_best", data, models_dict, 1576 | optimizers_dict, gpu=0 if FLAGS.cuda else -1) 1577 | 1578 | # Save model periodically 1579 | if step >= FLAGS.save_after and step % FLAGS.save_interval == 0: 1580 | flogger.Log("Checkpointing.") 1581 | # Optionally store additional information 1582 | data = dict(step=step, best_dev_acc=best_dev_acc) 1583 | torch_save(FLAGS.checkpoint, data, models_dict, 1584 | optimizers_dict, gpu=0 if FLAGS.cuda else -1) 1585 | 1586 | # Increment batch step 1587 | step += 1 1588 | 1589 | # Increment epoch 1590 | epoch += 1 1591 | 1592 | flogger.Log("Finished training.") 1593 | 1594 | 1595 | """ 1596 | Preset Model Configurations 1597 | 1598 | 1. Fixed - Fixed conversation length. 1599 | 2. 
Adaptive - Adaptive conversation length using STOP bit. 1600 | 3. FixedAttention - Fixed with Visual Attention. 1601 | 4. AdaptiveAttention - Adaptive with Visual Attention. 1602 | """ 1603 | 1604 | 1605 | def Fixed(): 1606 | FLAGS.img_feat = "avgpool_512" 1607 | FLAGS.img_feat_dim = 512 1608 | FLAGS.fixed_exchange = True 1609 | FLAGS.visual_attn = False 1610 | 1611 | 1612 | def Adaptive(): 1613 | FLAGS.img_feat = "avgpool_512" 1614 | FLAGS.img_feat_dim = 512 1615 | FLAGS.fixed_exchange = False 1616 | FLAGS.visual_attn = False 1617 | 1618 | 1619 | def FixedAttention(): 1620 | FLAGS.img_feat = "layer4_2" 1621 | FLAGS.img_feat_dim = 512 1622 | FLAGS.fixed_exchange = True 1623 | FLAGS.visual_attn = True 1624 | FLAGS.attn_dim = 256 1625 | FLAGS.attn_extra_context = True 1626 | FLAGS.attn_context_dim = 1000 1627 | 1628 | 1629 | def AdaptiveAttention(): 1630 | FLAGS.img_feat = "layer4_2" 1631 | FLAGS.img_feat_dim = 512 1632 | FLAGS.fixed_exchange = False 1633 | FLAGS.visual_attn = True 1634 | FLAGS.attn_dim = 256 1635 | FLAGS.attn_extra_context = True 1636 | FLAGS.attn_context_dim = 1000 1637 | 1638 | 1639 | def flags(): 1640 | # Debug settings 1641 | gflags.DEFINE_string("branch", None, "") 1642 | gflags.DEFINE_string("sha", None, "") 1643 | gflags.DEFINE_boolean("debug", False, "") 1644 | 1645 | # Convenience settings 1646 | gflags.DEFINE_integer("save_after", 1000, "") 1647 | gflags.DEFINE_integer("save_interval", 100, "") 1648 | gflags.DEFINE_string("checkpoint", None, "") 1649 | gflags.DEFINE_string("conf_mat", None, "") 1650 | gflags.DEFINE_string("log_path", "./logs", "") 1651 | gflags.DEFINE_string("log_file", None, "") 1652 | gflags.DEFINE_string("eval_csv_file", None, "") 1653 | gflags.DEFINE_string("json_file", None, "") 1654 | gflags.DEFINE_string("log_load", None, "") 1655 | gflags.DEFINE_boolean("eval_only", False, "") 1656 | 1657 | # Extract Settings 1658 | gflags.DEFINE_boolean("binary_only", False, "") 1659 | gflags.DEFINE_string("binary_output", None, "") 1660 | 1661 | # Performance settings 1662 | gflags.DEFINE_boolean("cuda", False, "") 1663 | 1664 | # Display settings 1665 | gflags.DEFINE_string("env", "main", "") 1666 | gflags.DEFINE_boolean("visdom", False, "") 1667 | gflags.DEFINE_boolean("use_alpha", False, "") 1668 | gflags.DEFINE_string("experiment_name", None, "") 1669 | gflags.DEFINE_integer("log_interval", 50, "") 1670 | gflags.DEFINE_integer("log_dev", 1000, "") 1671 | 1672 | # Data settings 1673 | gflags.DEFINE_enum("wv_type", "glove.6B", ["fake", "glove.6B", "none"], "") 1674 | gflags.DEFINE_integer("wv_dim", 100, "") 1675 | gflags.DEFINE_string("descr_train", "descriptions.csv", "") 1676 | gflags.DEFINE_string("descr_dev", "descriptions.csv", "") 1677 | gflags.DEFINE_string("train_file", "train.hdf5", "") 1678 | gflags.DEFINE_string("dev_file", "dev.hdf5", "") 1679 | gflags.DEFINE_enum("images", "mammal", ["cifar", "mammal"], "") 1680 | gflags.DEFINE_string( 1681 | "glove_path", "./glove.6B/glove.6B.100d.txt", "") 1682 | gflags.DEFINE_boolean("shuffle_train", True, "") 1683 | gflags.DEFINE_boolean("shuffle_dev", False, "") 1684 | 1685 | # Model settings 1686 | gflags.DEFINE_enum("model_type", None, [ 1687 | "Fixed", "Adaptive", "FixedAttention", "AdaptiveAttention"], "Preset model configurations.") 1688 | gflags.DEFINE_enum("img_feat", "avgpool_512", [ 1689 | "layer4_2", "avgpool_512", "fc"], "Specify which layer output to use as image") 1690 | gflags.DEFINE_enum("data_context", "fc", [ 1691 | "fc"], "Specify which layer output to use as context for attention") 
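# (Reference note tying the presets above to the image-feature flags defined
# here: Fixed and Adaptive read the pooled "avgpool_512" features with
# img_feat_dim=512, while FixedAttention and AdaptiveAttention read the
# spatial "layer4_2" features and enable visual_attn with extra "fc" context.)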
1692 | gflags.DEFINE_enum("sender_mix", "sum", ["sum", "prod", "mou"], "") 1693 | gflags.DEFINE_integer("img_feat_dim", 4096, "") 1694 | gflags.DEFINE_integer("img_h_dim", 100, "") 1695 | gflags.DEFINE_integer("baseline_hid_dim", 500, "") 1696 | gflags.DEFINE_integer("sender_out_dim", 50, "") 1697 | gflags.DEFINE_integer("rec_hidden", 128, "") 1698 | gflags.DEFINE_integer("rec_out_dim", 1, "") 1699 | gflags.DEFINE_integer("rec_w_dim", 50, "") 1700 | gflags.DEFINE_integer("rec_s_dim", 1, "") 1701 | gflags.DEFINE_boolean("use_binary", True, 1702 | "Encoding whether Sender uses binary features") 1703 | gflags.DEFINE_boolean("ignore_receiver", False, 1704 | "Sender ignores messages from Receiver") 1705 | gflags.DEFINE_boolean("ignore_code", False, 1706 | "Sender ignores messages from Receiver") 1707 | gflags.DEFINE_boolean( 1708 | "block_y", True, "Halt gradient flow through description scores") 1709 | gflags.DEFINE_float("first_rec", 0, "") 1710 | gflags.DEFINE_float("flipout_rec", None, "Dropout for bit flipping") 1711 | gflags.DEFINE_float("flipout_sen", None, "Dropout for bit flipping") 1712 | gflags.DEFINE_boolean("flipout_dev", False, "Dropout for bit flipping") 1713 | gflags.DEFINE_boolean("s_prob_prod", True, 1714 | "Simulate sampling during test time") 1715 | gflags.DEFINE_boolean("visual_attn", False, "Sender attends over image") 1716 | gflags.DEFINE_integer("attn_dim", 256, "") 1717 | gflags.DEFINE_boolean("attn_extra_context", False, "") 1718 | gflags.DEFINE_integer("attn_context_dim", 4096, "") 1719 | gflags.DEFINE_boolean("desc_attn", False, "Receiver attends over text") 1720 | gflags.DEFINE_integer("desc_attn_dim", 64, "Receiver attends over text") 1721 | gflags.DEFINE_integer("top_k_dev", 6, "Top-k error in development") 1722 | gflags.DEFINE_integer("top_k_train", 6, "Top-k error in training") 1723 | 1724 | # Optimization settings 1725 | gflags.DEFINE_enum("optim_type", "RMSprop", ["Adam", "SGD", "RMSprop"], "") 1726 | gflags.DEFINE_integer("batch_size", 32, "Minibatch size for train set.") 1727 | gflags.DEFINE_integer("batch_size_dev", 50, "Minibatch size for dev set.") 1728 | gflags.DEFINE_float("learning_rate", 1e-4, "Used in optimizer.") 1729 | gflags.DEFINE_integer("max_epoch", 500, "") 1730 | gflags.DEFINE_float("entropy_s", None, "") 1731 | gflags.DEFINE_float("entropy_sen", None, "") 1732 | gflags.DEFINE_float("entropy_rec", None, "") 1733 | 1734 | # Conversation settings 1735 | gflags.DEFINE_integer("exchange_samples", 3, "") 1736 | gflags.DEFINE_integer("max_exchange", 3, "") 1737 | gflags.DEFINE_boolean("fixed_exchange", True, "") 1738 | gflags.DEFINE_boolean( 1739 | "bit_flip", False, "Whether sender's messages are corrupted.") 1740 | gflags.DEFINE_string("corrupt_region", None, 1741 | "Comma-separated ranges of bit indexes (e.g. ``0:3,5'').") 1742 | 1743 | 1744 | def default_flags(): 1745 | if FLAGS.log_load: 1746 | log_flags = json.loads(open(FLAGS.log_load).read()) 1747 | for k in log_flags.keys(): 1748 | if k in FLAGS.FlagValuesDict().keys(): 1749 | setattr(FLAGS, k, log_flags[k]) 1750 | FLAGS(sys.argv) # Optionally override predefined flags. 1751 | 1752 | if FLAGS.model_type: 1753 | eval(FLAGS.model_type)() 1754 | FLAGS(sys.argv) # Optionally override predefined flags. 1755 | 1756 | assert FLAGS.sender_out_dim == FLAGS.rec_w_dim, \ 1757 | "Both sender and receiver should communicate with same dim vectors for now." 
1758 | 1759 | if not FLAGS.use_binary: 1760 | FLAGS.exchange_samples = 0 1761 | 1762 | if not FLAGS.experiment_name: 1763 | timestamp = str(int(time.time())) 1764 | FLAGS.experiment_name = "{}-so_{}-wv_{}-bs_{}-{}".format( 1765 | FLAGS.images, 1766 | FLAGS.sender_out_dim, 1767 | FLAGS.wv_dim, 1768 | FLAGS.batch_size, 1769 | timestamp, 1770 | ) 1771 | 1772 | if not FLAGS.conf_mat: 1773 | FLAGS.conf_mat = os.path.join( 1774 | FLAGS.log_path, FLAGS.experiment_name + ".conf_mat.txt") 1775 | 1776 | if not FLAGS.log_file: 1777 | FLAGS.log_file = os.path.join( 1778 | FLAGS.log_path, FLAGS.experiment_name + ".log") 1779 | 1780 | if not FLAGS.eval_csv_file: 1781 | FLAGS.eval_csv_file = os.path.join( 1782 | FLAGS.log_path, FLAGS.experiment_name + ".eval.csv") 1783 | 1784 | if not FLAGS.json_file: 1785 | FLAGS.json_file = os.path.join( 1786 | FLAGS.log_path, FLAGS.experiment_name + ".json") 1787 | 1788 | if not FLAGS.checkpoint: 1789 | FLAGS.checkpoint = os.path.join( 1790 | FLAGS.log_path, FLAGS.experiment_name + ".pt") 1791 | 1792 | if not FLAGS.binary_output: 1793 | FLAGS.binary_output = os.path.join( 1794 | FLAGS.log_path, FLAGS.experiment_name + ".bv.hdf5") 1795 | 1796 | if not FLAGS.branch: 1797 | FLAGS.branch = os.popen( 1798 | 'git rev-parse --abbrev-ref HEAD').read().strip() 1799 | 1800 | if not FLAGS.sha: 1801 | FLAGS.sha = os.popen('git rev-parse HEAD').read().strip() 1802 | 1803 | if not torch.cuda.is_available(): 1804 | FLAGS.cuda = False 1805 | 1806 | if FLAGS.debug: 1807 | np.seterr(all='raise') 1808 | 1809 | # silly expanduser 1810 | FLAGS.glove_path = os.path.expanduser(FLAGS.glove_path) 1811 | 1812 | 1813 | if __name__ == '__main__': 1814 | flags() 1815 | 1816 | FLAGS(sys.argv) 1817 | 1818 | default_flags() 1819 | 1820 | run() 1821 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py 2 | nltk 3 | numpy 4 | parse 5 | python-gflags 6 | sklearn 7 | tqdm -------------------------------------------------------------------------------- /sparks.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Sparklines in ascii 5 | source: https://github.com/rory/ascii_sparks/blob/master/ascii_sparks.py 6 | 7 | """ 8 | 9 | parts = u' ▁▂▃▄▅▆▇▉' 10 | 11 | 12 | def sparks(nums): 13 | fraction = max(nums) / float(len(parts) - 1) 14 | return ''.join(parts[int(round(x / fraction))] for x in nums) 15 | -------------------------------------------------------------------------------- /utils/build_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This is an example of how one might build train/dev/test splits for MultimodalGame. 4 | 5 | # Download imagenet urls. 6 | wget http://image-net.org/imagenet_data/urls/imagenet_fall11_urls.tgz 7 | tar -xvzf imagenet_fall11_urls.tgz 8 | 9 | # Download images and build descriptions. 10 | python download_data.py \ 11 | -cmd_urls \ 12 | -cmd_split \ 13 | -cmd_desc \ 14 | -cmd_download 15 | 16 | # Build hdf5 files. 17 | python package_data.py -cuda -load_desc descriptions.csv -load_imgs ./imgs/train -save_hdf5 train.hdf5 18 | python package_data.py -cuda -load_desc descriptions.csv -load_imgs ./imgs/dev -save_hdf5 dev.hdf5 19 | python package_data.py -cuda -load_desc descriptions.csv -load_imgs ./imgs/test -save_hdf5 test.hdf5 20 | 21 | # Finished! 
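# (Hypothetical variation.) download_data.py exposes flags for building other
# class sets, e.g. a mammal version matching descriptions_mammals.csv:
#   python download_data.py -cmd_urls -cmd_split -cmd_desc -cmd_download \
#       -synsets <comma-separated synset ids> -classes <comma-separated class names> \
#       -save_descriptions_path ./descriptions_mammals.csv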
22 | -------------------------------------------------------------------------------- /utils/descriptions.csv: -------------------------------------------------------------------------------- 1 | 0,agama,small terrestrial lizard of warm regions of the Old World 2 | 1,bullfrog,largest North American frog; highly aquatic with a deep-pitched voice 3 | 2,centipede,chiefly nocturnal predacious arthropod having a flattened body of 15 to 173 segments each with a pair of legs the foremost pair being modified as prehensors 4 | 3,chickadee,any of various small grey-and-black songbirds of North America 5 | 4,drake,adult male of a wild or domestic duck 6 | 5,goldfinch,American finch whose male has yellow body plumage in summer 7 | 6,goose,web-footed long-necked typically gregarious migratory aquatic birds usually larger and less aquatic than ducks 8 | 7,hen,adult female chicken 9 | 8,hummingbird,tiny American bird having brilliant iridescent plumage and long slender bills; wings are specialized for vibrating flight 10 | 9,jacamar,tropical American insectivorous bird having a long sharp bill and iridescent green or bronze plumage 11 | 10,jay,crested largely blue bird 12 | 11,jellyfish,large siphonophore having a bladderlike float and stinging tentacles 13 | 12,koala,sluggish tailless Australian arboreal marsupial with grey furry ears and coat; feeds on eucalyptus leaves and bark 14 | 13,lorikeet,any of various small lories 15 | 14,macaw,long-tailed brilliantly colored parrot of Central America and South America; among the largest and showiest of parrots 16 | 15,magpie,long-tailed black-and-white crow that utters a raucous chattering call 17 | 16,ostrich,fast-running African flightless bird with two-toed feet; largest living bird 18 | 17,partridge,heavy-bodied small-winged South American game bird resembling a gallinaceous bird but related to the ratite birds 19 | 18,peacock,European butterfly having reddish-brown wings each marked with a purple eyespot 20 | 19,quail,small gallinaceous game birds 21 | 20,robin,small Old World songbird with a reddish breast 22 | 21,scorpion,arachnid of warm dry regions having a long segmented tail ending in a venomous stinger 23 | 22,slug,any of various terrestrial gastropods having an elongated slimy body and no external shell 24 | 23,snail,freshwater or marine or terrestrial gastropod mollusk usually having an external enclosing spiral shell 25 | 24,stingray,large venomous ray with large barbed spines near the base of a thin whiplike tail capable of inflicting severe wounds 26 | 25,tarantula,large southern European spider once thought to be the cause of tarantism (uncontrollable bodily movement) 27 | 26,terrapin,any of various edible North American web-footed turtles living in fresh or brackish water 28 | 27,triceratops,huge ceratopsian dinosaur having three horns and the neck heavily armored with a very solid frill 29 | 28,trilobite,an extinct arthropod that was abundant in Paleozoic times; had an exoskeleton divided into three parts 30 | 29,vulture,any of various large diurnal birds of prey having naked heads and weak claws and feeding chiefly on carrion 31 | -------------------------------------------------------------------------------- /utils/descriptions_mammals.csv: -------------------------------------------------------------------------------- 1 | 0,baboon,large terrestrial monkeys having doglike muzzles 2 | 1,bat,nocturnal mouselike mammal with forelimbs modified to form membranous wings and anatomical adaptations for echolocation by which they navigate 3 | 
2,bear,massive plantigrade carnivorous or omnivorous mammals with long shaggy coats and strong claws 4 | 3,beaver,large semiaquatic rodent with webbed hind feet and a broad flat tail; construct complex dams and underwater lodges 5 | 4,bobcat,small lynx of North America 6 | 5,camel,cud-chewing mammal used as a draft or saddle animal in desert regions 7 | 6,caribou,Arctic deer with large antlers in both sexes; called `reindeer' in Eurasia and `caribou' in North America 8 | 7,cat,feline mammal usually having thick soft fur and no ability to roar: domestic cats; wildcats 9 | 8,cheetah,long-legged spotted cat of Africa and southwestern Asia having nonretractile claws; the swiftest mammal; can be trained to run down game 10 | 9,chimpanzee,intelligent somewhat arboreal ape of equatorial African forests 11 | 10,chipmunk,a burrowing ground squirrel of western America and Asia; has cheek pouches and a light and dark stripe running down the body 12 | 11,coati,omnivorous mammal of Central America and South America 13 | 12,cougar,large American feline resembling a lion 14 | 13,coyote,small wolf native to western North America 15 | 14,dingo,wolflike yellowish-brown wild dog of Australia 16 | 15,dog,a member of the genus Canis (probably descended from the common wolf) that has been domesticated by man since prehistoric times; occurs in many breeds 17 | 16,dolphin,any of various small toothed whales with a beaklike snout; larger than porpoises 18 | 17,echidna,a burrowing monotreme mammal covered with spines and having a long snout and claws for hunting ants and termites; native to New Guinea 19 | 18,elephant,five-toed pachyderm 20 | 19,elk,large northern deer with enormous flattened antlers in the male; called `elk' in Europe and `moose' in North America 21 | 20,fox,alert carnivorous mammal with pointed muzzle and ears and a bushy tail; most are predators that do not hunt in packs 22 | 21,gazelle,small swift graceful antelope of Africa and Asia having lustrous eyes 23 | 22,gibbon,smallest and most perfectly anthropoid arboreal ape having long arms and no tail; of southern Asia and East Indies 24 | 23,giraffe,tallest living quadruped; having a spotted coat and small horns and very long neck and legs; of savannahs of tropical Africa 25 | 24,goat,any of numerous agile ruminants related to sheep but having a beard and straight horns 26 | 25,gorilla,largest anthropoid ape; terrestrial and vegetarian; of forests of central west Africa 27 | 26,grizzly,powerful brownish-yellow bear of the uplands of western North America 28 | 27,groundhog,reddish brown North American marmot 29 | 28,hare,swift timid long-eared mammal larger than a rabbit having a divided upper lip and long hind legs; young born furred and with open eyes 30 | 29,hippopotamus,massive thick-skinned herbivorous animal living in or around rivers of tropical Africa 31 | 30,horse,solid-hoofed herbivorous quadruped domesticated since prehistoric times 32 | 31,hyena,doglike nocturnal mammal of Africa and southern Asia that feeds chiefly on carrion 33 | 32,impala,African antelope with ridged curved horns; moves with enormous leaps 34 | 33,jaguar,a large spotted feline of tropical America similar to the leopard; in some classifications considered a member of the genus Felis 35 | 34,kangaroo,any of several herbivorous leaping marsupials of Australia and New Guinea having large powerful hind legs and a long thick tail 36 | 35,koala,sluggish tailless Australian arboreal marsupial with grey furry ears and coat; feeds on eucalyptus leaves and bark 37 | 36,kob,an 
orange-brown antelope of southeast Africa 38 | 37,lemur,large-eyed arboreal prosimian having foxy faces and long furry tails 39 | 38,leopard,large feline of African and Asian forests usually having a tawny coat with black spots 40 | 39,lion,large gregarious predatory feline of Africa and India having a tawny coat with a shaggy mane in the male 41 | 40,llama,wild or domesticated South American cud-chewing animal related to camels but smaller and lacking a hump 42 | 41,lynx,short-tailed wildcats with usually tufted ears; valued for their fur 43 | 42,manatee,sirenian mammal of tropical coastal waters of America; the flat tail is rounded 44 | 43,mandrill,baboon of west Africa with a bright red and blue muzzle and blue hindquarters 45 | 44,meerkat,a mongoose-like viverrine of South Africa having a face like a lemur and only four toes 46 | 45,mongoose,agile grizzled Old World viverrine; preys on snakes and rodents 47 | 46,monkey,any of various long-tailed primates (excluding the prosimians) 48 | 47,mouse,any of numerous small rodents typically resembling diminutive rats having pointed snouts and small ears on elongated bodies with slender usually hairless tails 49 | 48,ocelot,nocturnal wildcat of Central America and South America having a dark-spotted buff-brown coat 50 | 49,orangutan,large long-armed ape of Borneo and Sumatra having arboreal habits 51 | 50,otter,freshwater carnivorous mammal having webbed and clawed feet and dark brown fur 52 | 51,porcupine,relatively large rodents with sharp erectile bristles mingled with the fur 53 | 52,pronghorn,fleet antelope-like ruminant of western North American plains with small branched horns 54 | 53,rabbit,any of various burrowing animals of the family Leporidae having long ears and short tails; some domesticated and raised for pets or food 55 | 54,raccoon,an omnivorous nocturnal mammal native to North America and Central America 56 | 55,rat,any of various long-tailed rodents similar to but larger than a mouse 57 | 56,rhinoceros,massive powerful herbivorous odd-toed ungulate of southeast Asia and Africa having very thick skin and one or two horns on the snout 58 | 57,sheep,woolly usually horned ruminant mammal related to the goat 59 | 58,sloth,any of several slow-moving arboreal mammals of South America and Central America; they hang from branches back downward and feed on leaves and fruits 60 | 59,squirrel,a kind of arboreal rodent having a long bushy tail 61 | 60,tamarin,small South American marmoset with silky fur and long nonprehensile tail 62 | 61,tapir,large inoffensive chiefly nocturnal ungulate of tropical America and southeast Asia having a heavy body and fleshy snout 63 | 62,tiger,large feline of forests in most of Asia having a tawny coat with black stripes; endangered 64 | 63,wallaby,any of various small or medium-sized kangaroos; often brightly colored 65 | 64,walrus,either of two large northern marine mammals having ivory tusks and tough hide over thick blubber 66 | 65,warthog,African wild swine with warty protuberances on the face and large protruding tusks 67 | 66,whale,any of the larger cetacean mammals having a streamlined body and breathing through a blowhole on the head 68 | 67,wolf,any of various predatory carnivorous canine mammals of North America and Eurasia that usually hunt in packs 69 | 68,wombat,burrowing herbivorous Australian marsupials about the size of a badger 70 | 69,zebra,any of several fleet black-and-white striped African equines 71 | -------------------------------------------------------------------------------- 
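Both description files above share one layout, `index,label,definition`, with commas stripped from the definitions so every row splits into exactly three fields. A minimal parsing sketch (a hypothetical `load_descriptions` helper, not the repo's `read_data`):

```python
import csv

def load_descriptions(path):
    # Each row: class index, class label, comma-free WordNet definition.
    idx_to_label, label_to_desc = {}, {}
    with open(path) as f:
        for idx, label, desc in csv.reader(f):
            idx_to_label[int(idx)] = label
            label_to_desc[label] = desc
    return idx_to_label, label_to_desc

idx_to_label, label_to_desc = load_descriptions("./utils/descriptions.csv")
```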
/utils/download_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to download images and create descriptions file.
3 |
4 | Usage:
5 |
6 | wget http://image-net.org/imagenet_data/urls/imagenet_fall11_urls.tgz # download imagenet urls, then decompress.
7 | python download_data.py -cmd_urls # get relevant urls from imagenet
8 | python download_data.py -cmd_split # create train/dev/test splits of urls
9 | python download_data.py -cmd_desc # create descriptions file
10 | python download_data.py -cmd_download # download files for each split/class
11 |
12 | Some sample synsets from Imagenet:
13 |
14 | n01498041 stingray
15 | n01514859 hen
16 | n01518878 ostrich
17 | n01531178 goldfinch
18 | n01558993 robin
19 | n01580077 jay
20 | n01582220 magpie
21 | n01592084 chickadee
22 | n01616318 vulture
23 | n01641577 bullfrog
24 | n01667778 terrapin
25 | n01687978 agama
26 | n01704323 triceratops
27 | n01768244 trilobite
28 | n01770393 scorpion
29 | n01774750 tarantula
30 | n01784675 centipede
31 | n01806143 peacock
32 | n01806567 quail
33 | n01807496 partridge
34 | n01818515 macaw
35 | n01820546 lorikeet
36 | n01833805 hummingbird
37 | n01843065 jacamar
38 | n01847000 drake
39 | n01855672 goose
40 | n01910747 jellyfish
41 | n01944390 snail
42 | n01945685 slug
43 | n01882714 koala
44 |
45 | """
46 |
47 | from __future__ import print_function
48 |
49 | from collections import OrderedDict
50 | import os
51 | import sys
52 | import json
53 | import time
54 | import random
55 | import urllib
56 | import threading
57 | from tqdm import tqdm
58 | from parse import *
59 |
60 | from nltk.corpus import wordnet as wn
61 |
62 | import gflags
63 |
64 | FLAGS = gflags.FLAGS
65 |
66 |
67 | def try_mkdir(path):
68 | try:
69 | os.mkdir(path)
70 | return 1
71 | except BaseException as e:
72 | # directory already exists
73 | return 0
74 |
75 |
76 | def flickr_name(url):
77 | tpl = "http://{subdomain}.flickr.com/{part1}/{part2}.{suffix}"
78 | data = parse(tpl, url)
79 | return "{subdomain}_{part1}_{part2}.{suffix}".format(**data.named)
80 |
81 |
82 | class MultiThreadedDownloader(object):
83 |
84 | def __init__(self, download_path, num_threads, urls, time_wait):
85 | self.lock = threading.Lock()
86 | self.download_path = download_path
87 | self.num_threads = num_threads
88 | self.urls = urls
89 | self.index = 0
90 | self.time_wait = time_wait
91 | self.pbar = tqdm(total=len(self.urls))
92 |
93 | def worker(self):
94 | finished = False
95 | while True:
96 | self.lock.acquire()
97 | try:
98 | if self.index < len(self.urls):
99 | # atomically acquire index
100 | url = self.urls[self.index]
101 | _filename = flickr_name(url)
102 | _save_path = os.path.join(self.download_path, _filename)
103 |
104 | # increment index
105 | self.index = self.index + 1
106 | self.pbar.update(1)
107 | else:
108 | finished = True
109 | finally:
110 | self.lock.release()
111 |
112 | # if no urls left, break loop
113 | if finished:
114 | break
115 |
116 | # download url
117 | if not os.path.exists(_save_path):
118 | urllib.urlretrieve(url, _save_path)
119 | saved = True
120 | time.sleep(self.time_wait)
121 |
122 | def run(self):
123 | # start threads
124 | threads = []
125 | for i in range(self.num_threads):
126 | t = threading.Thread(target=self.worker, args=())
127 | t.start()
128 | threads.append(t)
129 | time.sleep(self.time_wait)
130 |
131 | # wait until all threads complete
132 | for t in threads:
133 | t.join()
134 |
135 | self.pbar.close()
136 |
137 |
138 | def cmd_urls():
139 |
140 | random.seed(FLAGS.seed)
141 |
142 | assert
os.path.exists(FLAGS.save_urls_path), "Make sure to create urls directory: {}".format(FLAGS.save_urls_path) 143 | 144 | synsets = FLAGS.synsets.split(',') 145 | classes = FLAGS.classes.split(',') 146 | synsets_to_class = {ss: cc for ss, cc in zip(synsets, classes)} 147 | urls = OrderedDict() 148 | for k in classes: 149 | urls[k] = [] 150 | 151 | # read urls 152 | with open(FLAGS.load_imagenet_path) as f: 153 | for ii, line in enumerate(f): 154 | try: 155 | line = line.strip() 156 | _synset, _url = line.split('\t') 157 | _synset = _synset.split('_')[0] 158 | if _synset in synsets and FLAGS.filter_url in _url: 159 | _class = synsets_to_class[_synset] 160 | urls[_class].append(_url) 161 | except: 162 | print("skipping line {}: {}".format(ii, line)) 163 | 164 | # randomize and restrict to limit 165 | for k in urls.keys(): 166 | random.shuffle(urls[k]) 167 | urls[k] = urls[k][:FLAGS.class_size] 168 | assert len(urls[k]) == FLAGS.class_size, "Not enough urls for: {} ({})".format(k, len(urls[k])) 169 | 170 | # write to file 171 | for k in urls.keys(): 172 | with open("{}/{}.txt".format(FLAGS.save_urls_path, k), "w") as f: 173 | for _url in urls[k]: 174 | f.write(_url + '\n') 175 | 176 | 177 | def cmd_split(): 178 | 179 | random.seed(FLAGS.seed) 180 | 181 | datasets = dict(train=dict(), dev=dict(), test=dict()) 182 | 183 | for cls in FLAGS.classes.split(','): 184 | 185 | with open("{}/{}.txt".format(FLAGS.load_urls_path, cls)) as f: 186 | urls = [line.strip() for line in f] 187 | 188 | assert len(urls) >= FLAGS.train_size + FLAGS.dev_size + FLAGS.test_size, \ 189 | "There are not sufficient urls for class: {}".format(cls) 190 | 191 | random.shuffle(urls) 192 | 193 | # Train 194 | offset = 0 195 | size = FLAGS.train_size 196 | datasets['train'][cls] = urls[offset:offset + size] 197 | 198 | # Dev 199 | offset += FLAGS.train_size 200 | size = FLAGS.dev_size 201 | datasets['dev'][cls] = urls[offset:offset + size] 202 | 203 | # Test 204 | offset += FLAGS.dev_size 205 | size = FLAGS.test_size 206 | datasets['test'][cls] = urls[offset:offset + size] 207 | 208 | with open(FLAGS.save_datasets_path, "w") as f: 209 | f.write(json.dumps(datasets, indent=4, sort_keys=True)) 210 | 211 | 212 | def cmd_desc(): 213 | animal = wn.synset('animal.n.01') 214 | 215 | descriptions = OrderedDict() 216 | 217 | # get animal synset for each class, and the class's wordnet description 218 | for cls in FLAGS.classes.split(','): 219 | for i in range(1, 10): 220 | _synset = wn.synset('{}.n.0{}'.format(cls, i)) 221 | if _synset.lowest_common_hypernyms(animal)[0] == animal: 222 | break 223 | 224 | if _synset.lowest_common_hypernyms(animal)[0] != animal: 225 | raise BaseException("No animal synset found for: {}".format(cls)) 226 | 227 | descriptions[cls] = _synset.definition() 228 | 229 | # write to descriptions file 230 | with open(FLAGS.save_descriptions_path, "w") as f: 231 | for ii, cls in enumerate(sorted(descriptions.keys())): 232 | desc = descriptions[cls].replace(',', '') 233 | f.write("{},{},{}\n".format(ii, cls, desc)) 234 | 235 | 236 | def cmd_download(): 237 | 238 | with open(FLAGS.load_datasets_path) as f: 239 | datasets = json.loads(f.read()) 240 | 241 | for _d in ['train', 'dev', 'test']: 242 | _dataset_path = os.path.join(FLAGS.save_images, _d) 243 | try_mkdir(_dataset_path) 244 | 245 | for cls in FLAGS.classes.split(','): 246 | _dataset_cls_path = os.path.join(_dataset_path, cls) 247 | try_mkdir(_dataset_cls_path) 248 | 249 | print("Downloading images for {}/{}".format(_d, cls)) 250 | 251 | urls = datasets[_d][cls] 
252 | downloader = MultiThreadedDownloader(_dataset_cls_path, FLAGS.num_threads, urls, FLAGS.throttle)
253 | downloader.run()
254 |
255 |
256 | if __name__ == '__main__':
257 | gflags.DEFINE_string("synsets", "n01498041,n01514859,n01518878,n01531178,n01558993,n01580077" \
258 | ",n01582220,n01592084,n01616318,n01641577,n01667778,n01687978,n01704323,n01768244,n01770393" \
259 | ",n01774750,n01784675,n01806143,n01806567,n01807496,n01818515,n01820546,n01833805,n01843065" \
260 | ",n01847000,n01855672,n01910747,n01944390,n01945685,n01882714", "Comma-delimited list of synset ids to use.")
261 | gflags.DEFINE_string("classes", "stingray,hen,ostrich,goldfinch,robin,jay,magpie" \
262 | ",chickadee,vulture,bullfrog,terrapin,agama,triceratops,trilobite,scorpion,tarantula" \
263 | ",centipede,peacock,quail,partridge,macaw,lorikeet,hummingbird,jacamar,drake,goose" \
264 | ",jellyfish,snail,slug,koala", "Comma-delimited list of classes to use. Should match synset ids.")
265 | gflags.DEFINE_integer("seed", 11, "Seed for shuffling urls.")
266 |
267 | # urls args
268 | gflags.DEFINE_string("load_imagenet_path", "./fall11_urls.txt", "Path to imagenet urls.")
269 | gflags.DEFINE_string("save_urls_path", "./urls", "Path to directory with url files.")
270 | gflags.DEFINE_integer("class_size", 500, "Size of urls to keep (in images per class).")
271 | gflags.DEFINE_string("filter_url", "static.flickr", "String to filter urls.")
272 |
273 | # split args
274 | gflags.DEFINE_string("load_urls_path", "./urls", "Path to directory with url files.")
275 | gflags.DEFINE_string("save_datasets_path", "datasets.json", "Single JSON file defining train/dev/test splits.")
276 | gflags.DEFINE_integer("train_size", 100, "Size of dataset (in images per class).")
277 | gflags.DEFINE_integer("dev_size", 100, "Size of dataset (in images per class).")
278 | gflags.DEFINE_integer("test_size", 100, "Size of dataset (in images per class).")
279 |
280 | # description args
281 | gflags.DEFINE_string("save_descriptions_path", "./descriptions.csv", "Path to descriptions file.")
282 |
283 | # download args
284 | gflags.DEFINE_string("load_datasets_path", "datasets.json", "Single JSON file defining train/dev/test splits.")
285 | gflags.DEFINE_string("save_images", "./imgs", "Path to save images.")
286 | gflags.DEFINE_integer("num_threads", 8, "Use a multi-threaded image downloader.")
287 | gflags.DEFINE_float("throttle", 0.01, "Throttle the downloader (seconds to sleep between downloads).")
288 |
289 | # commands
290 | gflags.DEFINE_boolean("cmd_urls", False, "Extract relevant urls from imagenet.")
291 | gflags.DEFINE_boolean("cmd_split", False, "Split urls into datasets.")
292 | gflags.DEFINE_boolean("cmd_desc", False, "Create descriptions file.")
293 | gflags.DEFINE_boolean("cmd_download", False, "Download images from flickr.")
294 |
295 | FLAGS(sys.argv)
296 |
297 | print("Flag Values:\n" + json.dumps(FLAGS.flag_values_dict(), indent=4, sort_keys=True))
298 |
299 | if FLAGS.cmd_urls:
300 | cmd_urls()
301 | if FLAGS.cmd_split:
302 | cmd_split()
303 | if FLAGS.cmd_desc:
304 | cmd_desc()
305 | if FLAGS.cmd_download:
306 | cmd_download()
307 |
--------------------------------------------------------------------------------
/utils/imagenet.synset:
--------------------------------------------------------------------------------
1 | n01440764 tench, Tinca tinca
2 | n01443537 goldfish, Carassius auratus
3 | n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
4 | n01491361 tiger shark, Galeocerdo cuvieri
5 | n01494475 hammerhead,
hammerhead shark 6 | n01496331 electric ray, crampfish, numbfish, torpedo 7 | n01498041 stingray 8 | n01514668 cock 9 | n01514859 hen 10 | n01518878 ostrich, Struthio camelus 11 | n01530575 brambling, Fringilla montifringilla 12 | n01531178 goldfinch, Carduelis carduelis 13 | n01532829 house finch, linnet, Carpodacus mexicanus 14 | n01534433 junco, snowbird 15 | n01537544 indigo bunting, indigo finch, indigo bird, Passerina cyanea 16 | n01558993 robin, American robin, Turdus migratorius 17 | n01560419 bulbul 18 | n01580077 jay 19 | n01582220 magpie 20 | n01592084 chickadee 21 | n01601694 water ouzel, dipper 22 | n01608432 kite 23 | n01614925 bald eagle, American eagle, Haliaeetus leucocephalus 24 | n01616318 vulture 25 | n01622779 great grey owl, great gray owl, Strix nebulosa 26 | n01629819 European fire salamander, Salamandra salamandra 27 | n01630670 common newt, Triturus vulgaris 28 | n01631663 eft 29 | n01632458 spotted salamander, Ambystoma maculatum 30 | n01632777 axolotl, mud puppy, Ambystoma mexicanum 31 | n01641577 bullfrog, Rana catesbeiana 32 | n01644373 tree frog, tree-frog 33 | n01644900 tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui 34 | n01664065 loggerhead, loggerhead turtle, Caretta caretta 35 | n01665541 leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea 36 | n01667114 mud turtle 37 | n01667778 terrapin 38 | n01669191 box turtle, box tortoise 39 | n01675722 banded gecko 40 | n01677366 common iguana, iguana, Iguana iguana 41 | n01682714 American chameleon, anole, Anolis carolinensis 42 | n01685808 whiptail, whiptail lizard 43 | n01687978 agama 44 | n01688243 frilled lizard, Chlamydosaurus kingi 45 | n01689811 alligator lizard 46 | n01692333 Gila monster, Heloderma suspectum 47 | n01693334 green lizard, Lacerta viridis 48 | n01694178 African chameleon, Chamaeleo chamaeleon 49 | n01695060 Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis 50 | n01697457 African crocodile, Nile crocodile, Crocodylus niloticus 51 | n01698640 American alligator, Alligator mississipiensis 52 | n01704323 triceratops 53 | n01728572 thunder snake, worm snake, Carphophis amoenus 54 | n01728920 ringneck snake, ring-necked snake, ring snake 55 | n01729322 hognose snake, puff adder, sand viper 56 | n01729977 green snake, grass snake 57 | n01734418 king snake, kingsnake 58 | n01735189 garter snake, grass snake 59 | n01737021 water snake 60 | n01739381 vine snake 61 | n01740131 night snake, Hypsiglena torquata 62 | n01742172 boa constrictor, Constrictor constrictor 63 | n01744401 rock python, rock snake, Python sebae 64 | n01748264 Indian cobra, Naja naja 65 | n01749939 green mamba 66 | n01751748 sea snake 67 | n01753488 horned viper, cerastes, sand viper, horned asp, Cerastes cornutus 68 | n01755581 diamondback, diamondback rattlesnake, Crotalus adamanteus 69 | n01756291 sidewinder, horned rattlesnake, Crotalus cerastes 70 | n01768244 trilobite 71 | n01770081 harvestman, daddy longlegs, Phalangium opilio 72 | n01770393 scorpion 73 | n01773157 black and gold garden spider, Argiope aurantia 74 | n01773549 barn spider, Araneus cavaticus 75 | n01773797 garden spider, Aranea diademata 76 | n01774384 black widow, Latrodectus mactans 77 | n01774750 tarantula 78 | n01775062 wolf spider, hunting spider 79 | n01776313 tick 80 | n01784675 centipede 81 | n01795545 black grouse 82 | n01796340 ptarmigan 83 | n01797886 ruffed grouse, partridge, Bonasa umbellus 84 | n01798484 prairie chicken, prairie grouse, prairie fowl 85 | n01806143 peacock 86 | 
n01806567 quail 87 | n01807496 partridge 88 | n01817953 African grey, African gray, Psittacus erithacus 89 | n01818515 macaw 90 | n01819313 sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita 91 | n01820546 lorikeet 92 | n01824575 coucal 93 | n01828970 bee eater 94 | n01829413 hornbill 95 | n01833805 hummingbird 96 | n01843065 jacamar 97 | n01843383 toucan 98 | n01847000 drake 99 | n01855032 red-breasted merganser, Mergus serrator 100 | n01855672 goose 101 | n01860187 black swan, Cygnus atratus 102 | n01871265 tusker 103 | n01872401 echidna, spiny anteater, anteater 104 | n01873310 platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus 105 | n01877812 wallaby, brush kangaroo 106 | n01882714 koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus 107 | n01883070 wombat 108 | n01910747 jellyfish 109 | n01914609 sea anemone, anemone 110 | n01917289 brain coral 111 | n01924916 flatworm, platyhelminth 112 | n01930112 nematode, nematode worm, roundworm 113 | n01943899 conch 114 | n01944390 snail 115 | n01945685 slug 116 | n01950731 sea slug, nudibranch 117 | n01955084 chiton, coat-of-mail shell, sea cradle, polyplacophore 118 | n01968897 chambered nautilus, pearly nautilus, nautilus 119 | n01978287 Dungeness crab, Cancer magister 120 | n01978455 rock crab, Cancer irroratus 121 | n01980166 fiddler crab 122 | n01981276 king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica 123 | n01983481 American lobster, Northern lobster, Maine lobster, Homarus americanus 124 | n01984695 spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish 125 | n01985128 crayfish, crawfish, crawdad, crawdaddy 126 | n01986214 hermit crab 127 | n01990800 isopod 128 | n02002556 white stork, Ciconia ciconia 129 | n02002724 black stork, Ciconia nigra 130 | n02006656 spoonbill 131 | n02007558 flamingo 132 | n02009229 little blue heron, Egretta caerulea 133 | n02009912 American egret, great white heron, Egretta albus 134 | n02011460 bittern 135 | n02012849 crane 136 | n02013706 limpkin, Aramus pictus 137 | n02017213 European gallinule, Porphyrio porphyrio 138 | n02018207 American coot, marsh hen, mud hen, water hen, Fulica americana 139 | n02018795 bustard 140 | n02025239 ruddy turnstone, Arenaria interpres 141 | n02027492 red-backed sandpiper, dunlin, Erolia alpina 142 | n02028035 redshank, Tringa totanus 143 | n02033041 dowitcher 144 | n02037110 oystercatcher, oyster catcher 145 | n02051845 pelican 146 | n02056570 king penguin, Aptenodytes patagonica 147 | n02058221 albatross, mollymawk 148 | n02066245 grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus 149 | n02071294 killer whale, killer, orca, grampus, sea wolf, Orcinus orca 150 | n02074367 dugong, Dugong dugon 151 | n02077923 sea lion 152 | n02085620 Chihuahua 153 | n02085782 Japanese spaniel 154 | n02085936 Maltese dog, Maltese terrier, Maltese 155 | n02086079 Pekinese, Pekingese, Peke 156 | n02086240 Shih-Tzu 157 | n02086646 Blenheim spaniel 158 | n02086910 papillon 159 | n02087046 toy terrier 160 | n02087394 Rhodesian ridgeback 161 | n02088094 Afghan hound, Afghan 162 | n02088238 basset, basset hound 163 | n02088364 beagle 164 | n02088466 bloodhound, sleuthhound 165 | n02088632 bluetick 166 | n02089078 black-and-tan coonhound 167 | n02089867 Walker hound, Walker foxhound 168 | n02089973 English foxhound 169 | n02090379 redbone 170 | n02090622 borzoi, Russian wolfhound 171 | n02090721 Irish wolfhound 172 | n02091032 Italian greyhound 173 | 
n02091134 whippet 174 | n02091244 Ibizan hound, Ibizan Podenco 175 | n02091467 Norwegian elkhound, elkhound 176 | n02091635 otterhound, otter hound 177 | n02091831 Saluki, gazelle hound 178 | n02092002 Scottish deerhound, deerhound 179 | n02092339 Weimaraner 180 | n02093256 Staffordshire bullterrier, Staffordshire bull terrier 181 | n02093428 American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier 182 | n02093647 Bedlington terrier 183 | n02093754 Border terrier 184 | n02093859 Kerry blue terrier 185 | n02093991 Irish terrier 186 | n02094114 Norfolk terrier 187 | n02094258 Norwich terrier 188 | n02094433 Yorkshire terrier 189 | n02095314 wire-haired fox terrier 190 | n02095570 Lakeland terrier 191 | n02095889 Sealyham terrier, Sealyham 192 | n02096051 Airedale, Airedale terrier 193 | n02096177 cairn, cairn terrier 194 | n02096294 Australian terrier 195 | n02096437 Dandie Dinmont, Dandie Dinmont terrier 196 | n02096585 Boston bull, Boston terrier 197 | n02097047 miniature schnauzer 198 | n02097130 giant schnauzer 199 | n02097209 standard schnauzer 200 | n02097298 Scotch terrier, Scottish terrier, Scottie 201 | n02097474 Tibetan terrier, chrysanthemum dog 202 | n02097658 silky terrier, Sydney silky 203 | n02098105 soft-coated wheaten terrier 204 | n02098286 West Highland white terrier 205 | n02098413 Lhasa, Lhasa apso 206 | n02099267 flat-coated retriever 207 | n02099429 curly-coated retriever 208 | n02099601 golden retriever 209 | n02099712 Labrador retriever 210 | n02099849 Chesapeake Bay retriever 211 | n02100236 German short-haired pointer 212 | n02100583 vizsla, Hungarian pointer 213 | n02100735 English setter 214 | n02100877 Irish setter, red setter 215 | n02101006 Gordon setter 216 | n02101388 Brittany spaniel 217 | n02101556 clumber, clumber spaniel 218 | n02102040 English springer, English springer spaniel 219 | n02102177 Welsh springer spaniel 220 | n02102318 cocker spaniel, English cocker spaniel, cocker 221 | n02102480 Sussex spaniel 222 | n02102973 Irish water spaniel 223 | n02104029 kuvasz 224 | n02104365 schipperke 225 | n02105056 groenendael 226 | n02105162 malinois 227 | n02105251 briard 228 | n02105412 kelpie 229 | n02105505 komondor 230 | n02105641 Old English sheepdog, bobtail 231 | n02105855 Shetland sheepdog, Shetland sheep dog, Shetland 232 | n02106030 collie 233 | n02106166 Border collie 234 | n02106382 Bouvier des Flandres, Bouviers des Flandres 235 | n02106550 Rottweiler 236 | n02106662 German shepherd, German shepherd dog, German police dog, alsatian 237 | n02107142 Doberman, Doberman pinscher 238 | n02107312 miniature pinscher 239 | n02107574 Greater Swiss Mountain dog 240 | n02107683 Bernese mountain dog 241 | n02107908 Appenzeller 242 | n02108000 EntleBucher 243 | n02108089 boxer 244 | n02108422 bull mastiff 245 | n02108551 Tibetan mastiff 246 | n02108915 French bulldog 247 | n02109047 Great Dane 248 | n02109525 Saint Bernard, St Bernard 249 | n02109961 Eskimo dog, husky 250 | n02110063 malamute, malemute, Alaskan malamute 251 | n02110185 Siberian husky 252 | n02110341 dalmatian, coach dog, carriage dog 253 | n02110627 affenpinscher, monkey pinscher, monkey dog 254 | n02110806 basenji 255 | n02110958 pug, pug-dog 256 | n02111129 Leonberg 257 | n02111277 Newfoundland, Newfoundland dog 258 | n02111500 Great Pyrenees 259 | n02111889 Samoyed, Samoyede 260 | n02112018 Pomeranian 261 | n02112137 chow, chow chow 262 | n02112350 keeshond 263 | n02112706 Brabancon griffon 264 | n02113023 Pembroke, Pembroke Welsh corgi 265 | 
n02113186 Cardigan, Cardigan Welsh corgi 266 | n02113624 toy poodle 267 | n02113712 miniature poodle 268 | n02113799 standard poodle 269 | n02113978 Mexican hairless 270 | n02114367 timber wolf, grey wolf, gray wolf, Canis lupus 271 | n02114548 white wolf, Arctic wolf, Canis lupus tundrarum 272 | n02114712 red wolf, maned wolf, Canis rufus, Canis niger 273 | n02114855 coyote, prairie wolf, brush wolf, Canis latrans 274 | n02115641 dingo, warrigal, warragal, Canis dingo 275 | n02115913 dhole, Cuon alpinus 276 | n02116738 African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus 277 | n02117135 hyena, hyaena 278 | n02119022 red fox, Vulpes vulpes 279 | n02119789 kit fox, Vulpes macrotis 280 | n02120079 Arctic fox, white fox, Alopex lagopus 281 | n02120505 grey fox, gray fox, Urocyon cinereoargenteus 282 | n02123045 tabby, tabby cat 283 | n02123159 tiger cat 284 | n02123394 Persian cat 285 | n02123597 Siamese cat, Siamese 286 | n02124075 Egyptian cat 287 | n02125311 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor 288 | n02127052 lynx, catamount 289 | n02128385 leopard, Panthera pardus 290 | n02128757 snow leopard, ounce, Panthera uncia 291 | n02128925 jaguar, panther, Panthera onca, Felis onca 292 | n02129165 lion, king of beasts, Panthera leo 293 | n02129604 tiger, Panthera tigris 294 | n02130308 cheetah, chetah, Acinonyx jubatus 295 | n02132136 brown bear, bruin, Ursus arctos 296 | n02133161 American black bear, black bear, Ursus americanus, Euarctos americanus 297 | n02134084 ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus 298 | n02134418 sloth bear, Melursus ursinus, Ursus ursinus 299 | n02137549 mongoose 300 | n02138441 meerkat, mierkat 301 | n02165105 tiger beetle 302 | n02165456 ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle 303 | n02167151 ground beetle, carabid beetle 304 | n02168699 long-horned beetle, longicorn, longicorn beetle 305 | n02169497 leaf beetle, chrysomelid 306 | n02172182 dung beetle 307 | n02174001 rhinoceros beetle 308 | n02177972 weevil 309 | n02190166 fly 310 | n02206856 bee 311 | n02219486 ant, emmet, pismire 312 | n02226429 grasshopper, hopper 313 | n02229544 cricket 314 | n02231487 walking stick, walkingstick, stick insect 315 | n02233338 cockroach, roach 316 | n02236044 mantis, mantid 317 | n02256656 cicada, cicala 318 | n02259212 leafhopper 319 | n02264363 lacewing, lacewing fly 320 | n02268443 dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk 321 | n02268853 damselfly 322 | n02276258 admiral 323 | n02277742 ringlet, ringlet butterfly 324 | n02279972 monarch, monarch butterfly, milkweed butterfly, Danaus plexippus 325 | n02280649 cabbage butterfly 326 | n02281406 sulphur butterfly, sulfur butterfly 327 | n02281787 lycaenid, lycaenid butterfly 328 | n02317335 starfish, sea star 329 | n02319095 sea urchin 330 | n02321529 sea cucumber, holothurian 331 | n02325366 wood rabbit, cottontail, cottontail rabbit 332 | n02326432 hare 333 | n02328150 Angora, Angora rabbit 334 | n02342885 hamster 335 | n02346627 porcupine, hedgehog 336 | n02356798 fox squirrel, eastern fox squirrel, Sciurus niger 337 | n02361337 marmot 338 | n02363005 beaver 339 | n02364673 guinea pig, Cavia cobaya 340 | n02389026 sorrel 341 | n02391049 zebra 342 | n02395406 hog, pig, grunter, squealer, Sus scrofa 343 | n02396427 wild boar, boar, Sus scrofa 344 | n02397096 warthog 345 | n02398521 hippopotamus, hippo, river horse, Hippopotamus amphibius 346 | n02403003 ox 347 | 
n02408429 water buffalo, water ox, Asiatic buffalo, Bubalus bubalis 348 | n02410509 bison 349 | n02412080 ram, tup 350 | n02415577 bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis 351 | n02417914 ibex, Capra ibex 352 | n02422106 hartebeest 353 | n02422699 impala, Aepyceros melampus 354 | n02423022 gazelle 355 | n02437312 Arabian camel, dromedary, Camelus dromedarius 356 | n02437616 llama 357 | n02441942 weasel 358 | n02442845 mink 359 | n02443114 polecat, fitch, foulmart, foumart, Mustela putorius 360 | n02443484 black-footed ferret, ferret, Mustela nigripes 361 | n02444819 otter 362 | n02445715 skunk, polecat, wood pussy 363 | n02447366 badger 364 | n02454379 armadillo 365 | n02457408 three-toed sloth, ai, Bradypus tridactylus 366 | n02480495 orangutan, orang, orangutang, Pongo pygmaeus 367 | n02480855 gorilla, Gorilla gorilla 368 | n02481823 chimpanzee, chimp, Pan troglodytes 369 | n02483362 gibbon, Hylobates lar 370 | n02483708 siamang, Hylobates syndactylus, Symphalangus syndactylus 371 | n02484975 guenon, guenon monkey 372 | n02486261 patas, hussar monkey, Erythrocebus patas 373 | n02486410 baboon 374 | n02487347 macaque 375 | n02488291 langur 376 | n02488702 colobus, colobus monkey 377 | n02489166 proboscis monkey, Nasalis larvatus 378 | n02490219 marmoset 379 | n02492035 capuchin, ringtail, Cebus capucinus 380 | n02492660 howler monkey, howler 381 | n02493509 titi, titi monkey 382 | n02493793 spider monkey, Ateles geoffroyi 383 | n02494079 squirrel monkey, Saimiri sciureus 384 | n02497673 Madagascar cat, ring-tailed lemur, Lemur catta 385 | n02500267 indri, indris, Indri indri, Indri brevicaudatus 386 | n02504013 Indian elephant, Elephas maximus 387 | n02504458 African elephant, Loxodonta africana 388 | n02509815 lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens 389 | n02510455 giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca 390 | n02514041 barracouta, snoek 391 | n02526121 eel 392 | n02536864 coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch 393 | n02606052 rock beauty, Holocanthus tricolor 394 | n02607072 anemone fish 395 | n02640242 sturgeon 396 | n02641379 gar, garfish, garpike, billfish, Lepisosteus osseus 397 | n02643566 lionfish 398 | n02655020 puffer, pufferfish, blowfish, globefish 399 | n02666196 abacus 400 | n02667093 abaya 401 | n02669723 academic gown, academic robe, judge's robe 402 | n02672831 accordion, piano accordion, squeeze box 403 | n02676566 acoustic guitar 404 | n02687172 aircraft carrier, carrier, flattop, attack aircraft carrier 405 | n02690373 airliner 406 | n02692877 airship, dirigible 407 | n02699494 altar 408 | n02701002 ambulance 409 | n02704792 amphibian, amphibious vehicle 410 | n02708093 analog clock 411 | n02727426 apiary, bee house 412 | n02730930 apron 413 | n02747177 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 414 | n02749479 assault rifle, assault gun 415 | n02769748 backpack, back pack, knapsack, packsack, rucksack, haversack 416 | n02776631 bakery, bakeshop, bakehouse 417 | n02777292 balance beam, beam 418 | n02782093 balloon 419 | n02783161 ballpoint, ballpoint pen, ballpen, Biro 420 | n02786058 Band Aid 421 | n02787622 banjo 422 | n02788148 bannister, banister, balustrade, balusters, handrail 423 | n02790996 barbell 424 | n02791124 barber chair 425 | n02791270 barbershop 426 | n02793495 barn 427 | n02794156 barometer 428 | n02795169 barrel, cask 429 | n02797295 barrow, garden cart, 
lawn cart, wheelbarrow 430 | n02799071 baseball 431 | n02802426 basketball 432 | n02804414 bassinet 433 | n02804610 bassoon 434 | n02807133 bathing cap, swimming cap 435 | n02808304 bath towel 436 | n02808440 bathtub, bathing tub, bath, tub 437 | n02814533 beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon 438 | n02814860 beacon, lighthouse, beacon light, pharos 439 | n02815834 beaker 440 | n02817516 bearskin, busby, shako 441 | n02823428 beer bottle 442 | n02823750 beer glass 443 | n02825657 bell cote, bell cot 444 | n02834397 bib 445 | n02835271 bicycle-built-for-two, tandem bicycle, tandem 446 | n02837789 bikini, two-piece 447 | n02840245 binder, ring-binder 448 | n02841315 binoculars, field glasses, opera glasses 449 | n02843684 birdhouse 450 | n02859443 boathouse 451 | n02860847 bobsled, bobsleigh, bob 452 | n02865351 bolo tie, bolo, bola tie, bola 453 | n02869837 bonnet, poke bonnet 454 | n02870880 bookcase 455 | n02871525 bookshop, bookstore, bookstall 456 | n02877765 bottlecap 457 | n02879718 bow 458 | n02883205 bow tie, bow-tie, bowtie 459 | n02892201 brass, memorial tablet, plaque 460 | n02892767 brassiere, bra, bandeau 461 | n02894605 breakwater, groin, groyne, mole, bulwark, seawall, jetty 462 | n02895154 breastplate, aegis, egis 463 | n02906734 broom 464 | n02909870 bucket, pail 465 | n02910353 buckle 466 | n02916936 bulletproof vest 467 | n02917067 bullet train, bullet 468 | n02927161 butcher shop, meat market 469 | n02930766 cab, hack, taxi, taxicab 470 | n02939185 caldron, cauldron 471 | n02948072 candle, taper, wax light 472 | n02950826 cannon 473 | n02951358 canoe 474 | n02951585 can opener, tin opener 475 | n02963159 cardigan 476 | n02965783 car mirror 477 | n02966193 carousel, carrousel, merry-go-round, roundabout, whirligig 478 | n02966687 carpenter's kit, tool kit 479 | n02971356 carton 480 | n02974003 car wheel 481 | n02977058 cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM 482 | n02978881 cassette 483 | n02979186 cassette player 484 | n02980441 castle 485 | n02981792 catamaran 486 | n02988304 CD player 487 | n02992211 cello, violoncello 488 | n02992529 cellular telephone, cellular phone, cellphone, cell, mobile phone 489 | n02999410 chain 490 | n03000134 chainlink fence 491 | n03000247 chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour 492 | n03000684 chain saw, chainsaw 493 | n03014705 chest 494 | n03016953 chiffonier, commode 495 | n03017168 chime, bell, gong 496 | n03018349 china cabinet, china closet 497 | n03026506 Christmas stocking 498 | n03028079 church, church building 499 | n03032252 cinema, movie theater, movie theatre, movie house, picture palace 500 | n03041632 cleaver, meat cleaver, chopper 501 | n03042490 cliff dwelling 502 | n03045698 cloak 503 | n03047690 clog, geta, patten, sabot 504 | n03062245 cocktail shaker 505 | n03063599 coffee mug 506 | n03063689 coffeepot 507 | n03065424 coil, spiral, volute, whorl, helix 508 | n03075370 combination lock 509 | n03085013 computer keyboard, keypad 510 | n03089624 confectionery, confectionary, candy store 511 | n03095699 container ship, containership, container vessel 512 | n03100240 convertible 513 | n03109150 corkscrew, bottle screw 514 | n03110669 cornet, horn, trumpet, trump 515 | n03124043 cowboy boot 516 | n03124170 cowboy hat, ten-gallon hat 517 | n03125729 cradle 518 | n03126707 crane 519 | n03127747 crash helmet 520 | n03127925 crate 521 | n03131574 crib, cot 522 | 
n03133878 Crock Pot 523 | n03134739 croquet ball 524 | n03141823 crutch 525 | n03146219 cuirass 526 | n03160309 dam, dike, dyke 527 | n03179701 desk 528 | n03180011 desktop computer 529 | n03187595 dial telephone, dial phone 530 | n03188531 diaper, nappy, napkin 531 | n03196217 digital clock 532 | n03197337 digital watch 533 | n03201208 dining table, board 534 | n03207743 dishrag, dishcloth 535 | n03207941 dishwasher, dish washer, dishwashing machine 536 | n03208938 disk brake, disc brake 537 | n03216828 dock, dockage, docking facility 538 | n03218198 dogsled, dog sled, dog sleigh 539 | n03220513 dome 540 | n03223299 doormat, welcome mat 541 | n03240683 drilling platform, offshore rig 542 | n03249569 drum, membranophone, tympan 543 | n03250847 drumstick 544 | n03255030 dumbbell 545 | n03259280 Dutch oven 546 | n03271574 electric fan, blower 547 | n03272010 electric guitar 548 | n03272562 electric locomotive 549 | n03290653 entertainment center 550 | n03291819 envelope 551 | n03297495 espresso maker 552 | n03314780 face powder 553 | n03325584 feather boa, boa 554 | n03337140 file, file cabinet, filing cabinet 555 | n03344393 fireboat 556 | n03345487 fire engine, fire truck 557 | n03347037 fire screen, fireguard 558 | n03355925 flagpole, flagstaff 559 | n03372029 flute, transverse flute 560 | n03376595 folding chair 561 | n03379051 football helmet 562 | n03384352 forklift 563 | n03388043 fountain 564 | n03388183 fountain pen 565 | n03388549 four-poster 566 | n03393912 freight car 567 | n03394916 French horn, horn 568 | n03400231 frying pan, frypan, skillet 569 | n03404251 fur coat 570 | n03417042 garbage truck, dustcart 571 | n03424325 gasmask, respirator, gas helmet 572 | n03425413 gas pump, gasoline pump, petrol pump, island dispenser 573 | n03443371 goblet 574 | n03444034 go-kart 575 | n03445777 golf ball 576 | n03445924 golfcart, golf cart 577 | n03447447 gondola 578 | n03447721 gong, tam-tam 579 | n03450230 gown 580 | n03452741 grand piano, grand 581 | n03457902 greenhouse, nursery, glasshouse 582 | n03459775 grille, radiator grille 583 | n03461385 grocery store, grocery, food market, market 584 | n03467068 guillotine 585 | n03476684 hair slide 586 | n03476991 hair spray 587 | n03478589 half track 588 | n03481172 hammer 589 | n03482405 hamper 590 | n03483316 hand blower, blow dryer, blow drier, hair dryer, hair drier 591 | n03485407 hand-held computer, hand-held microcomputer 592 | n03485794 handkerchief, hankie, hanky, hankey 593 | n03492542 hard disc, hard disk, fixed disk 594 | n03494278 harmonica, mouth organ, harp, mouth harp 595 | n03495258 harp 596 | n03496892 harvester, reaper 597 | n03498962 hatchet 598 | n03527444 holster 599 | n03529860 home theater, home theatre 600 | n03530642 honeycomb 601 | n03532672 hook, claw 602 | n03534580 hoopskirt, crinoline 603 | n03535780 horizontal bar, high bar 604 | n03538406 horse cart, horse-cart 605 | n03544143 hourglass 606 | n03584254 iPod 607 | n03584829 iron, smoothing iron 608 | n03590841 jack-o'-lantern 609 | n03594734 jean, blue jean, denim 610 | n03594945 jeep, landrover 611 | n03595614 jersey, T-shirt, tee shirt 612 | n03598930 jigsaw puzzle 613 | n03599486 jinrikisha, ricksha, rickshaw 614 | n03602883 joystick 615 | n03617480 kimono 616 | n03623198 knee pad 617 | n03627232 knot 618 | n03630383 lab coat, laboratory coat 619 | n03633091 ladle 620 | n03637318 lampshade, lamp shade 621 | n03642806 laptop, laptop computer 622 | n03649909 lawn mower, mower 623 | n03657121 lens cap, lens cover 624 | n03658185 letter opener, paper knife, 
paperknife 625 | n03661043 library 626 | n03662601 lifeboat 627 | n03666591 lighter, light, igniter, ignitor 628 | n03670208 limousine, limo 629 | n03673027 liner, ocean liner 630 | n03676483 lipstick, lip rouge 631 | n03680355 Loafer 632 | n03690938 lotion 633 | n03691459 loudspeaker, speaker, speaker unit, loudspeaker system, speaker system 634 | n03692522 loupe, jeweler's loupe 635 | n03697007 lumbermill, sawmill 636 | n03706229 magnetic compass 637 | n03709823 mailbag, postbag 638 | n03710193 mailbox, letter box 639 | n03710637 maillot 640 | n03710721 maillot, tank suit 641 | n03717622 manhole cover 642 | n03720891 maraca 643 | n03721384 marimba, xylophone 644 | n03724870 mask 645 | n03729826 matchstick 646 | n03733131 maypole 647 | n03733281 maze, labyrinth 648 | n03733805 measuring cup 649 | n03742115 medicine chest, medicine cabinet 650 | n03743016 megalith, megalithic structure 651 | n03759954 microphone, mike 652 | n03761084 microwave, microwave oven 653 | n03763968 military uniform 654 | n03764736 milk can 655 | n03769881 minibus 656 | n03770439 miniskirt, mini 657 | n03770679 minivan 658 | n03773504 missile 659 | n03775071 mitten 660 | n03775546 mixing bowl 661 | n03776460 mobile home, manufactured home 662 | n03777568 Model T 663 | n03777754 modem 664 | n03781244 monastery 665 | n03782006 monitor 666 | n03785016 moped 667 | n03786901 mortar 668 | n03787032 mortarboard 669 | n03788195 mosque 670 | n03788365 mosquito net 671 | n03791053 motor scooter, scooter 672 | n03792782 mountain bike, all-terrain bike, off-roader 673 | n03792972 mountain tent 674 | n03793489 mouse, computer mouse 675 | n03794056 mousetrap 676 | n03796401 moving van 677 | n03803284 muzzle 678 | n03804744 nail 679 | n03814639 neck brace 680 | n03814906 necklace 681 | n03825788 nipple 682 | n03832673 notebook, notebook computer 683 | n03837869 obelisk 684 | n03838899 oboe, hautboy, hautbois 685 | n03840681 ocarina, sweet potato 686 | n03841143 odometer, hodometer, mileometer, milometer 687 | n03843555 oil filter 688 | n03854065 organ, pipe organ 689 | n03857828 oscilloscope, scope, cathode-ray oscilloscope, CRO 690 | n03866082 overskirt 691 | n03868242 oxcart 692 | n03868863 oxygen mask 693 | n03871628 packet 694 | n03873416 paddle, boat paddle 695 | n03874293 paddlewheel, paddle wheel 696 | n03874599 padlock 697 | n03876231 paintbrush 698 | n03877472 pajama, pyjama, pj's, jammies 699 | n03877845 palace 700 | n03884397 panpipe, pandean pipe, syrinx 701 | n03887697 paper towel 702 | n03888257 parachute, chute 703 | n03888605 parallel bars, bars 704 | n03891251 park bench 705 | n03891332 parking meter 706 | n03895866 passenger car, coach, carriage 707 | n03899768 patio, terrace 708 | n03902125 pay-phone, pay-station 709 | n03903868 pedestal, plinth, footstall 710 | n03908618 pencil box, pencil case 711 | n03908714 pencil sharpener 712 | n03916031 perfume, essence 713 | n03920288 Petri dish 714 | n03924679 photocopier 715 | n03929660 pick, plectrum, plectron 716 | n03929855 pickelhaube 717 | n03930313 picket fence, paling 718 | n03930630 pickup, pickup truck 719 | n03933933 pier 720 | n03935335 piggy bank, penny bank 721 | n03937543 pill bottle 722 | n03938244 pillow 723 | n03942813 ping-pong ball 724 | n03944341 pinwheel 725 | n03947888 pirate, pirate ship 726 | n03950228 pitcher, ewer 727 | n03954731 plane, carpenter's plane, woodworking plane 728 | n03956157 planetarium 729 | n03958227 plastic bag 730 | n03961711 plate rack 731 | n03967562 plow, plough 732 | n03970156 plunger, plumber's helper 733 | n03976467 
Polaroid camera, Polaroid Land camera 734 | n03976657 pole 735 | n03977966 police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria 736 | n03980874 poncho 737 | n03982430 pool table, billiard table, snooker table 738 | n03983396 pop bottle, soda bottle 739 | n03991062 pot, flowerpot 740 | n03992509 potter's wheel 741 | n03995372 power drill 742 | n03998194 prayer rug, prayer mat 743 | n04004767 printer 744 | n04005630 prison, prison house 745 | n04008634 projectile, missile 746 | n04009552 projector 747 | n04019541 puck, hockey puck 748 | n04023962 punching bag, punch bag, punching ball, punchball 749 | n04026417 purse 750 | n04033901 quill, quill pen 751 | n04033995 quilt, comforter, comfort, puff 752 | n04037443 racer, race car, racing car 753 | n04039381 racket, racquet 754 | n04040759 radiator 755 | n04041544 radio, wireless 756 | n04044716 radio telescope, radio reflector 757 | n04049303 rain barrel 758 | n04065272 recreational vehicle, RV, R.V. 759 | n04067472 reel 760 | n04069434 reflex camera 761 | n04070727 refrigerator, icebox 762 | n04074963 remote control, remote 763 | n04081281 restaurant, eating house, eating place, eatery 764 | n04086273 revolver, six-gun, six-shooter 765 | n04090263 rifle 766 | n04099969 rocking chair, rocker 767 | n04111531 rotisserie 768 | n04116512 rubber eraser, rubber, pencil eraser 769 | n04118538 rugby ball 770 | n04118776 rule, ruler 771 | n04120489 running shoe 772 | n04125021 safe 773 | n04127249 safety pin 774 | n04131690 saltshaker, salt shaker 775 | n04133789 sandal 776 | n04136333 sarong 777 | n04141076 sax, saxophone 778 | n04141327 scabbard 779 | n04141975 scale, weighing machine 780 | n04146614 school bus 781 | n04147183 schooner 782 | n04149813 scoreboard 783 | n04152593 screen, CRT screen 784 | n04153751 screw 785 | n04154565 screwdriver 786 | n04162706 seat belt, seatbelt 787 | n04179913 sewing machine 788 | n04192698 shield, buckler 789 | n04200800 shoe shop, shoe-shop, shoe store 790 | n04201297 shoji 791 | n04204238 shopping basket 792 | n04204347 shopping cart 793 | n04208210 shovel 794 | n04209133 shower cap 795 | n04209239 shower curtain 796 | n04228054 ski 797 | n04229816 ski mask 798 | n04235860 sleeping bag 799 | n04238763 slide rule, slipstick 800 | n04239074 sliding door 801 | n04243546 slot, one-armed bandit 802 | n04251144 snorkel 803 | n04252077 snowmobile 804 | n04252225 snowplow, snowplough 805 | n04254120 soap dispenser 806 | n04254680 soccer ball 807 | n04254777 sock 808 | n04258138 solar dish, solar collector, solar furnace 809 | n04259630 sombrero 810 | n04263257 soup bowl 811 | n04264628 space bar 812 | n04265275 space heater 813 | n04266014 space shuttle 814 | n04270147 spatula 815 | n04273569 speedboat 816 | n04275548 spider web, spider's web 817 | n04277352 spindle 818 | n04285008 sports car, sport car 819 | n04286575 spotlight, spot 820 | n04296562 stage 821 | n04310018 steam locomotive 822 | n04311004 steel arch bridge 823 | n04311174 steel drum 824 | n04317175 stethoscope 825 | n04325704 stole 826 | n04326547 stone wall 827 | n04328186 stopwatch, stop watch 828 | n04330267 stove 829 | n04332243 strainer 830 | n04335435 streetcar, tram, tramcar, trolley, trolley car 831 | n04336792 stretcher 832 | n04344873 studio couch, day bed 833 | n04346328 stupa, tope 834 | n04347754 submarine, pigboat, sub, U-boat 835 | n04350905 suit, suit of clothes 836 | n04355338 sundial 837 | n04355933 sunglass 838 | n04356056 sunglasses, dark glasses, shades 839 | n04357314 sunscreen, sunblock, sun blocker 840 | 
n04366367 suspension bridge 841 | n04367480 swab, swob, mop 842 | n04370456 sweatshirt 843 | n04371430 swimming trunks, bathing trunks 844 | n04371774 swing 845 | n04372370 switch, electric switch, electrical switch 846 | n04376876 syringe 847 | n04380533 table lamp 848 | n04389033 tank, army tank, armored combat vehicle, armoured combat vehicle 849 | n04392985 tape player 850 | n04398044 teapot 851 | n04399382 teddy, teddy bear 852 | n04404412 television, television system 853 | n04409515 tennis ball 854 | n04417672 thatch, thatched roof 855 | n04418357 theater curtain, theatre curtain 856 | n04423845 thimble 857 | n04428191 thresher, thrasher, threshing machine 858 | n04429376 throne 859 | n04435653 tile roof 860 | n04442312 toaster 861 | n04443257 tobacco shop, tobacconist shop, tobacconist 862 | n04447861 toilet seat 863 | n04456115 torch 864 | n04458633 totem pole 865 | n04461696 tow truck, tow car, wrecker 866 | n04462240 toyshop 867 | n04465501 tractor 868 | n04467665 trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi 869 | n04476259 tray 870 | n04479046 trench coat 871 | n04482393 tricycle, trike, velocipede 872 | n04483307 trimaran 873 | n04485082 tripod 874 | n04486054 triumphal arch 875 | n04487081 trolleybus, trolley coach, trackless trolley 876 | n04487394 trombone 877 | n04493381 tub, vat 878 | n04501370 turnstile 879 | n04505470 typewriter keyboard 880 | n04507155 umbrella 881 | n04509417 unicycle, monocycle 882 | n04515003 upright, upright piano 883 | n04517823 vacuum, vacuum cleaner 884 | n04522168 vase 885 | n04523525 vault 886 | n04525038 velvet 887 | n04525305 vending machine 888 | n04532106 vestment 889 | n04532670 viaduct 890 | n04536866 violin, fiddle 891 | n04540053 volleyball 892 | n04542943 waffle iron 893 | n04548280 wall clock 894 | n04548362 wallet, billfold, notecase, pocketbook 895 | n04550184 wardrobe, closet, press 896 | n04552348 warplane, military plane 897 | n04553703 washbasin, handbasin, washbowl, lavabo, wash-hand basin 898 | n04554684 washer, automatic washer, washing machine 899 | n04557648 water bottle 900 | n04560804 water jug 901 | n04562935 water tower 902 | n04579145 whiskey jug 903 | n04579432 whistle 904 | n04584207 wig 905 | n04589890 window screen 906 | n04590129 window shade 907 | n04591157 Windsor tie 908 | n04591713 wine bottle 909 | n04592741 wing 910 | n04596742 wok 911 | n04597913 wooden spoon 912 | n04599235 wool, woolen, woollen 913 | n04604644 worm fence, snake fence, snake-rail fence, Virginia fence 914 | n04606251 wreck 915 | n04612504 yawl 916 | n04613696 yurt 917 | n06359193 web site, website, internet site, site 918 | n06596364 comic book 919 | n06785654 crossword puzzle, crossword 920 | n06794110 street sign 921 | n06874185 traffic light, traffic signal, stoplight 922 | n07248320 book jacket, dust cover, dust jacket, dust wrapper 923 | n07565083 menu 924 | n07579787 plate 925 | n07583066 guacamole 926 | n07584110 consomme 927 | n07590611 hot pot, hotpot 928 | n07613480 trifle 929 | n07614500 ice cream, icecream 930 | n07615774 ice lolly, lolly, lollipop, popsicle 931 | n07684084 French loaf 932 | n07693725 bagel, beigel 933 | n07695742 pretzel 934 | n07697313 cheeseburger 935 | n07697537 hotdog, hot dog, red hot 936 | n07711569 mashed potato 937 | n07714571 head cabbage 938 | n07714990 broccoli 939 | n07715103 cauliflower 940 | n07716358 zucchini, courgette 941 | n07716906 spaghetti squash 942 | n07717410 acorn squash 943 | n07717556 butternut squash 944 | n07718472 cucumber, cuke 945 | n07718747 
artichoke, globe artichoke 946 | n07720875 bell pepper 947 | n07730033 cardoon 948 | n07734744 mushroom 949 | n07742313 Granny Smith 950 | n07745940 strawberry 951 | n07747607 orange 952 | n07749582 lemon 953 | n07753113 fig 954 | n07753275 pineapple, ananas 955 | n07753592 banana 956 | n07754684 jackfruit, jak, jack 957 | n07760859 custard apple 958 | n07768694 pomegranate 959 | n07802026 hay 960 | n07831146 carbonara 961 | n07836838 chocolate sauce, chocolate syrup 962 | n07860988 dough 963 | n07871810 meat loaf, meatloaf 964 | n07873807 pizza, pizza pie 965 | n07875152 potpie 966 | n07880968 burrito 967 | n07892512 red wine 968 | n07920052 espresso 969 | n07930864 cup 970 | n07932039 eggnog 971 | n09193705 alp 972 | n09229709 bubble 973 | n09246464 cliff, drop, drop-off 974 | n09256479 coral reef 975 | n09288635 geyser 976 | n09332890 lakeside, lakeshore 977 | n09399592 promontory, headland, head, foreland 978 | n09421951 sandbar, sand bar 979 | n09428293 seashore, coast, seacoast, sea-coast 980 | n09468604 valley, vale 981 | n09472597 volcano 982 | n09835506 ballplayer, baseball player 983 | n10148035 groom, bridegroom 984 | n10565667 scuba diver 985 | n11879895 rapeseed 986 | n11939491 daisy 987 | n12057211 yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum 988 | n12144580 corn 989 | n12267677 acorn 990 | n12620546 hip, rose hip, rosehip 991 | n12768682 buckeye, horse chestnut, conker 992 | n12985857 coral fungus 993 | n12998815 agaric 994 | n13037406 gyromitra 995 | n13040303 stinkhorn, carrion fungus 996 | n13044778 earthstar 997 | n13052670 hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa 998 | n13054560 bolete 999 | n13133613 ear, spike, capitulum 1000 | n15075141 toilet tissue, toilet paper, bathroom tissue -------------------------------------------------------------------------------- /utils/imgs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-dl/MultimodalGame/0782a7bf3cf5125cd7c35a243e97f0e9e016fca3/utils/imgs/.gitkeep -------------------------------------------------------------------------------- /utils/package_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Given a directory of images with the following hierarchy: 3 | ./imgs/ 4 | class-0/ 5 | img-0-01 6 | img-0-02 7 | ... 8 | class-1/ 9 | img-1-01 10 | ... 11 | ... 12 | 13 | Create an hdf5 file containing a preprocessed representation of each of these images.
14 | 15 | Layer Names 16 | =========== 17 | 18 | ResNet-34 19 | 20 | bn1 - (4L, 64L, 114L, 114L) 21 | relu - (4L, 64L, 114L, 114L) 22 | maxpool - (4L, 64L, 57L, 57L) 23 | layer1 - (4L, 64L, 57L, 57L) 24 | layer2 - (4L, 128L, 29L, 29L) 25 | layer3 - (4L, 256L, 15L, 15L) 26 | layer4_0_relu - (4L, 512L, 8L, 8L) 27 | layer4_1_relu - (4L, 512L, 8L, 8L) 28 | layer4_2 - (4L, 512L, 8L, 8L) #32768 29 | layer4_2_relu - (4L, 512L, 8L, 8L) 30 | avgpool - (4L, 512L, 1L, 1L) 31 | avgpool_512 - (4L, 512L) 32 | fc - (4L, 1000L) 33 | 34 | """ 35 | 36 | 37 | import sys 38 | import os 39 | 40 | import torch 41 | import torch.nn as nn 42 | from torch.autograd import Variable 43 | import torchvision.models as models 44 | import torchvision.datasets as dset 45 | import torchvision.transforms as transforms 46 | 47 | import numpy as np 48 | import h5py 49 | from tqdm import tqdm 50 | import gflags 51 | 52 | FLAGS = gflags.FLAGS 53 | 54 | 55 | def get_model(): 56 | return getattr(models, "resnet{}".format(FLAGS.resnet))  # e.g. models.resnet34 57 | 58 | 59 | def basic_block(layer, relu=False): 60 | def forward(x): 61 | residual = x 62 | 63 | out = layer.conv1(x) 64 | out = layer.bn1(out) 65 | out = layer.relu(out) 66 | 67 | out = layer.conv2(out) 68 | out = layer.bn2(out) 69 | 70 | if layer.downsample is not None: 71 | residual = layer.downsample(x) 72 | 73 | out += residual 74 | if relu: 75 | out = layer.relu(out) 76 | 77 | return out 78 | return forward 79 | 80 | 81 | class FeatureModel(nn.Module): 82 | def __init__(self): 83 | super(FeatureModel, self).__init__() 84 | self.fn = get_model()(pretrained=True) 85 | 86 | # Turn off inplace ReLUs so intermediate activations can be read 87 | for p in self.fn.modules(): 88 | if isinstance(p, nn.ReLU): 89 | p.inplace = False 90 | 91 | def forward(self, x, request=["layer4_2", "fc"]): 92 | model = self.fn 93 | 94 | ret = [] 95 | 96 | layers = [ 97 | (model.conv1, 'conv1'), 98 | (model.bn1, 'bn1'), 99 | (model.relu, 'relu'), 100 | (model.maxpool, 'maxpool'), 101 | (model.layer1, 'layer1'), 102 | (model.layer2, 'layer2'), 103 | (model.layer3, 'layer3'), 104 | ] 105 | 106 | if FLAGS.resnet == "34": 107 | layers += [ 108 | (model.layer4[0], 'layer4_0_relu'), 109 | (model.layer4[1], 'layer4_1_relu'), 110 | (basic_block(model.layer4[2]), 'layer4_2'), 111 | (lambda x: model.layer4[2].relu(x), 'layer4_2_relu'), 112 | ] 113 | else: 114 | raise NotImplementedError() 115 | 116 | layers += [ 117 | (model.avgpool, 'avgpool'), 118 | (lambda x: x.view(x.size(0), -1), 'avgpool_512'), 119 | (model.fc, 'fc'), 120 | ] 121 | 122 | for module, name in layers: 123 | x = module(x) 124 | # print(" N", x.data.numel()) 125 | # print("<0", (x.data < 0.).sum()) 126 | # print("=0", (x.data == 0.).sum()) 127 | # print("{} - {}".format(name, tuple(x.size()))) 128 | if name in request: 129 | ret.append(x) 130 | 131 | return ret 132 | 133 | 134 | def label_mapping(desc_path): 135 | label_to_id = dict() 136 | with open(desc_path) as f: 137 | for line in f: 138 | line = line.strip() 139 | label_id, label, desc = line.split(',', 2)  # allow commas inside the description text 140 | label_to_id[label] = int(label_id) 141 | return label_to_id 142 | 143 | 144 | def custom_dtype(outp, request): 145 | schema = [('Location', np.str_, 50), 146 | ('Target', 'i')] 147 | for o, r in zip(outp, request): 148 | size = tuple([1] + list(o.shape)[1:]) 149 | schema.append((r, np.float32, size)) 150 | dtype = np.dtype(schema) 151 | return dtype 152 | 153 | 154 | def multi_split(outp, batch_size): 155 | return [np.split(o, batch_size) for o in outp] 156 | 157 | 158 | def run(): 159 | dtype = None 160 | 161 | # Model Initialization 162 | model = 
FeatureModel() 163 | model.fn.eval() 164 | model.eval() 165 | 166 | if FLAGS.cuda: 167 | model.fn.cuda() 168 | model.cuda() 169 | 170 | # Load dataset and transform 171 | dataset = dset.ImageFolder(root=FLAGS.load_imgs, 172 | transform=transforms.Compose([ 173 | transforms.Scale(227),  # Scale was renamed Resize in later torchvision releases 174 | transforms.CenterCrop(227), 175 | transforms.ToTensor(), 176 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # note: not the standard ImageNet mean/std 177 | ]) 178 | ) 179 | 180 | # Read images 181 | dataloader = torch.utils.data.DataLoader(dataset, 182 | batch_size=FLAGS.batch_size, 183 | shuffle=False) 184 | 185 | # Used to get label ids 186 | label_to_id = label_mapping(FLAGS.load_desc) 187 | 188 | # Only keep relevant output 189 | request = FLAGS.request.split(',') 190 | 191 | # Preprocess images to new vector representation 192 | data = []  # note: rows are collected here but never written out 193 | 194 | targets = [] 195 | locations = [] 196 | other = dict() 197 | 198 | def data_it(dataloader): 199 | _it = iter(dataloader) 200 | 201 | while True: 202 | try: 203 | ret = next(_it) 204 | except StopIteration: 205 | break 206 | except BaseException:  # skip batches whose images fail to load 207 | continue 208 | yield ret 209 | 210 | for i, img in tqdm(enumerate(data_it(dataloader))): 211 | tensor, target = img 212 | 213 | if FLAGS.cuda: 214 | tensor = tensor.cuda() 215 | outp = model(Variable(tensor), request) 216 | 217 | np_outp = [o.data.cpu().numpy() for o in outp] 218 | 219 | batch_size = np_outp[0].shape[0] 220 | offset = i * FLAGS.batch_size  # use the configured size so a smaller final batch stays aligned (assumes data_it skipped nothing) 221 | 222 | for j, o in enumerate(zip(*multi_split(np_outp, batch_size))): 223 | filename = dataset.imgs[offset + j][0] 224 | parts = filename.split(os.sep) 225 | label = parts[-2] # the label as a string 226 | loc = parts[-1] # something like '1-100-251756690_e68ac649e3_z.jpg' 227 | label_id = label_to_id[label] # use the label id specified by the desc file 228 | 229 | # Save Image 230 | row = tuple([loc, label_id] + list(o)) 231 | data.append(row) 232 | targets.append(label_id) 233 | locations.append(loc) 234 | for r, oo in zip(request, list(o)): 235 | other.setdefault(r, []).append(oo) 236 | 237 | # Save hdf5 file 238 | hdf5_f = h5py.File(FLAGS.save_hdf5, 'w') 239 | hdf5_f.create_dataset("Target", data=np.array(targets)) 240 | hdf5_f.create_dataset("Location", data=np.array(locations)) 241 | for r in request: 242 | hdf5_f.create_dataset(r, data=np.array(other[r])) 243 | hdf5_f.close() 244 | 245 | 246 | if __name__ == "__main__": 247 | # Settings 248 | gflags.DEFINE_string("load_desc", "descriptions.csv", "Path to description file.") 249 | gflags.DEFINE_string("load_imgs", "./imgs/train", "Path to input data.") 250 | gflags.DEFINE_string("save_hdf5", "train.hdf5", "Path to store new dataset.") 251 | gflags.DEFINE_integer("batch_size", 4, "Minibatch size.") 252 | gflags.DEFINE_enum("resnet", "34", ["18", "34", "50", "101", "152"], "Specify Resnet variant.") 253 | gflags.DEFINE_string("request", "layer4_2,avgpool_512,fc", "Comma-delimited layer names whose outputs are saved.") 254 | gflags.DEFINE_boolean("cuda", False, "Run feature extraction on the GPU.") 255 | 256 | FLAGS(sys.argv) 257 | 258 | run() 259 | --------------------------------------------------------------------------------
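A quick way to sanity-check the output of `package_data.py` is to read the packaged features back with `h5py`. The sketch below is illustrative only, assuming the default flag values above (`-save_hdf5 train.hdf5`, `-request layer4_2,avgpool_512,fc`); the dataset names mirror the `create_dataset` calls in `run()`.

```python
import h5py

# Open the file written by package_data.py and inspect a few rows.
with h5py.File("train.hdf5", "r") as f:
    targets = f["Target"][:]       # integer class ids (from descriptions.csv)
    locations = f["Location"][:]   # image filenames, row-aligned with Target
    feats = f["avgpool_512"][:]    # per-image pooled ResNet features

# Each row keeps a singleton axis from the per-image np.split; flatten it.
feats = feats.reshape(feats.shape[0], -1)  # -> (N, 512)

print(feats.shape, targets.shape)
print(locations[:3])
```

Each requested layer becomes its own dataset, so files that include `layer4_2` (512x8x8 floats per image) grow quickly; requesting only `avgpool_512` keeps them compact.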