├── .gitignore ├── LICENSE ├── README.md └── ibr_game ├── beam_search.py ├── few_shot_learning_system.py ├── inner_loop_optimizers.py ├── lang_id.py ├── maml_speaker.py ├── models.py ├── referential_game.py ├── translate_caption.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 CLAW Lab @ CMU LTI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ToM 2 | Code accompanying the ICML 2021 paper "Few-shot Language Coordination by Modeling Theory of Mind". 3 | 4 | The core code for preparing data and training the MAML speaker is `ibr_game/maml_speaker.py` (see the example invocation at the end of this README). 5 | 6 | Stay tuned for the data and pretrained listeners. 7 | 8 | # Misc 9 | `ibr_game/beam_search.py` is a lightweight, batched, standalone beam search toolkit we designed while running our experiments. 10 | 11 | # Acknowledgement 12 | We used and adapted code from [MAML++](https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch) and [S2P](https://github.com/backpropper/s2p).
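# Example usage
A minimal invocation sketch for the MAML-speaker training entry point. The flag names are taken from the argparse definitions in `ibr_game/maml_speaker.py`; the paths and values below are placeholders rather than the settings used in the paper, and `maml_args.json` stands for a JSON file of MAML++ hyperparameters (e.g. `batch_size`, `total_epochs`, `number_of_training_steps_per_iter`) read by `get_maml_args`.

```bash
# Placeholder paths/values -- adjust to your own data and listener checkpoints.
python ibr_game/maml_speaker.py \
    --coco_path ./coco/ \
    --listener_template "save/listener_{}" \
    --listener_vocab_size 100 \
    --maml_args maml_args.json \
    --beam_size 10 \
    --maml_time_step 20 \
    --maml_save_dir save/maml_speaker
```

The script expects pretrained listener checkpoints (loaded through `sample_listener`/`load_model`) and caches caption beams under `save/`, so it only becomes runnable once the data and pretrained listeners above are released.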
13 | -------------------------------------------------------------------------------- /ibr_game/beam_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | import tqdm 5 | 6 | from typing import Callable, Generator 7 | 8 | HUGE = 1e15 9 | 10 | 11 | def beam_search(model_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], 12 | beam_size: int, 13 | max_len: int, 14 | eos_id: int, 15 | bos_id: int, 16 | dataloader: Generator[torch.Tensor, None, None]): 17 | 18 | return_outputs = [] 19 | return_logprobs = [] 20 | 21 | for batch in dataloader: 22 | device = batch.get_device() 23 | batch_size = batch.size()[0] 24 | beam_outputs = torch.full( 25 | (batch_size, 1, 1), bos_id, dtype=torch.long).to(device) 26 | beam_inputs = torch.full( 27 | (batch_size, ), bos_id, dtype=torch.long).to(device) 28 | beam_hiddens = batch 29 | beam_logprobs = torch.zeros(batch_size, 1).to(device) 30 | finish_mask = torch.zeros(batch_size, 1).to(device) 31 | 32 | for i in range(max_len): 33 | outputs, beam_hiddens_ = model_func(beam_hiddens, beam_inputs) 34 | vocabulary = outputs.size()[-1] 35 | 36 | # (B, b) -> (B, b, V) 37 | beam_logprobs = beam_logprobs.unsqueeze( 38 | 2).repeat(1, 1, vocabulary) 39 | # (B x b, V) -> (B, b, V) 40 | outputs = outputs.view(beam_logprobs.size()) 41 | 42 | finish_mask = finish_mask.unsqueeze(2).repeat(1, 1, vocabulary) 43 | outputs = outputs * (1 - finish_mask) - HUGE * finish_mask 44 | outputs[:, :, eos_id] = outputs[:, :, eos_id] * \ 45 | (1 - finish_mask[:, :, 0]) 46 | 47 | beam_logprobs = (beam_logprobs + outputs).view(batch_size, -1) 48 | beam_logprobs, indices = torch.topk(beam_logprobs, beam_size) 49 | 50 | beam_indices = indices // vocabulary 51 | word_indices = indices % vocabulary 52 | beam_inputs = word_indices.view(-1) 53 | finish_mask = (word_indices == eos_id).float() 54 | 55 | # (B, b, i+1) -> (B, b, i+1) 56 | beam_outputs = torch.gather( 57 | beam_outputs, 1, beam_indices.unsqueeze(2).repeat(1, 1, i+1)) 58 | # cat((B, b, i+1), (B, b, 1)) -> (B, b, i+2) 59 | beam_outputs = torch.cat( 60 | [beam_outputs, word_indices.unsqueeze(2)], dim=2) 61 | # (B, b, H) -> (B, b, H) -> (B x b, H) 62 | hid_size = beam_hiddens_.size()[-1] 63 | beam_hiddens = torch.gather( 64 | beam_hiddens_.view(batch_size, -1, hid_size), 65 | 1, 66 | beam_indices.unsqueeze(2).repeat(1, 1, hid_size))\ 67 | .view(-1, hid_size) 68 | 69 | return_outputs.append(beam_outputs) 70 | return_logprobs.append(beam_logprobs) 71 | 72 | return_outputs = torch.cat(return_outputs, dim=0) 73 | return_logprobs = torch.cat(return_logprobs, dim=0) 74 | 75 | return (return_outputs, return_logprobs) 76 | 77 | 78 | class LanguageModel(nn.Module): 79 | def __init__(self, vocabulary, hidden_size): 80 | super(LanguageModel, self).__init__() 81 | self.word_emb = nn.Embedding(vocabulary, hidden_size) 82 | self.gru = nn.GRUCell(hidden_size, hidden_size) 83 | self.output = nn.Sequential(nn.Linear(hidden_size, vocabulary), 84 | nn.LogSoftmax(dim=-1)) 85 | 86 | def forward(self, hidden, inputs): 87 | hid = self.gru(self.word_emb(inputs), hidden) 88 | return self.output(hid), hid 89 | 90 | 91 | def generating_random_data(batch_size: int, hidden_size: int, size: int): 92 | for i in tqdm.tqdm(range(size)): 93 | yield torch.randn(batch_size, hidden_size).cuda() 94 | 95 | 96 | if __name__ == "__main__": 97 | language_model = LanguageModel(200, 128).cuda() 98 | with torch.no_grad(): 99 | beam_search(language_model, 10, 15, 99, 0, 100 | 
generating_random_data(128, 128, 1000)) 101 | -------------------------------------------------------------------------------- /ibr_game/few_shot_learning_system.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | 9 | from fastcore.utils import mapped 10 | 11 | from inner_loop_optimizers import LSLRGradientDescentLearningRule 12 | 13 | 14 | def set_torch_seed(seed): 15 | """ 16 | Sets the pytorch seeds for current experiment run 17 | :param seed: The seed (int) 18 | :return: A random number generator to use 19 | """ 20 | rng = np.random.RandomState(seed=seed) 21 | torch_seed = rng.randint(0, 999999) 22 | torch.manual_seed(seed=torch_seed) 23 | 24 | return rng 25 | 26 | 27 | 28 | class MAMLFewShotClassifier(nn.Module): 29 | def __init__(self, classifier_class, args, listener_args): 30 | """ 31 | Initializes a MAML few shot learning system 32 | :param classifier_class: The classifier's class 33 | :param args: A namedtuple of arguments specifying various hyperparameters. 34 | :param seed 35 | :param number_of_training_steps_per_iter: 36 | :param learnable_per_layer_per_step_inner_loop_learning_rate 37 | :param total_epochs 38 | :param min_learning_rate 39 | :param multi_step_loss_num_epochs 40 | :param enable_inner_loop_optimizable_bn_params 41 | :param second_order 42 | :param first_order_to_second_order_epoch 43 | :param dataset_name 44 | :param use_multi_step_loss_optimization 45 | :param listener_args: Listener arguments 46 | """ 47 | super(MAMLFewShotClassifier, self).__init__() 48 | self.args = args 49 | self.classifier_class = classifier_class 50 | self.batch_size = args.batch_size 51 | self.device = listener_args.device 52 | self.current_epoch = 0 53 | 54 | self.rng = set_torch_seed(seed=listener_args.seed) 55 | self.classifier = classifier_class( 56 | args=listener_args).to(device=self.device) 57 | self.task_learning_rate = args.init_inner_loop_learning_rate 58 | 59 | self.inner_loop_optimizer = LSLRGradientDescentLearningRule(device=self.device, 60 | init_learning_rate=self.task_learning_rate, 61 | total_num_inner_loop_steps=self.args.number_of_training_steps_per_iter, 62 | use_learnable_learning_rates=self.args.learnable_per_layer_per_step_inner_loop_learning_rate) 63 | self.inner_loop_optimizer.initialise( 64 | names_weights_dict=self.get_inner_loop_parameter_dict(params=self.classifier.named_parameters(), excluded_params=self.args.excluded_params)) 65 | 66 | print("Inner Loop parameters") 67 | for key, value in self.inner_loop_optimizer.named_parameters(): 68 | print(key, value.shape) 69 | 70 | self.to(self.device) 71 | print("Outer Loop parameters") 72 | for name, param in self.named_parameters(): 73 | if param.requires_grad: 74 | print(name, param.shape, param.device, param.requires_grad) 75 | 76 | self.optimizer = optim.Adam( 77 | self.trainable_parameters(), lr=args.meta_learning_rate, amsgrad=False) 78 | self.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=self.optimizer, T_max=self.args.total_epochs, 79 | eta_min=self.args.min_learning_rate) 80 | 81 | if torch.cuda.is_available(): 82 | if torch.cuda.device_count() > 1: 83 | self.to(torch.cuda.current_device()) 84 | self.classifier = nn.DataParallel(module=self.classifier) 85 | else: 86 | self.to(torch.cuda.current_device()) 87 | 88 | self.device = torch.cuda.current_device() 89 | 90 | def move_to_cuda(self, x): 91 | return 
x.to(device=self.device) 92 | 93 | def get_per_step_loss_importance_vector(self): 94 | """ 95 | Generates a tensor of dimensionality (num_inner_loop_steps) indicating the importance of each step's target 96 | loss towards the optimization loss. 97 | :return: A tensor to be used to compute the weighted average of the loss, useful for 98 | the MSL (Multi Step Loss) mechanism. 99 | """ 100 | loss_weights = np.ones(shape=(self.args.number_of_training_steps_per_iter)) * ( 101 | 1.0 / self.args.number_of_training_steps_per_iter) 102 | decay_rate = 1.0 / self.args.number_of_training_steps_per_iter / \ 103 | self.args.multi_step_loss_num_epochs 104 | min_value_for_non_final_losses = 0.03 / \ 105 | self.args.number_of_training_steps_per_iter 106 | for i in range(len(loss_weights) - 1): 107 | curr_value = np.maximum( 108 | loss_weights[i] - (self.current_epoch * decay_rate), min_value_for_non_final_losses) 109 | loss_weights[i] = curr_value 110 | 111 | curr_value = np.minimum( 112 | loss_weights[-1] + (self.current_epoch * 113 | (self.args.number_of_training_steps_per_iter - 1) * decay_rate), 114 | 1.0 - ((self.args.number_of_training_steps_per_iter - 1) * min_value_for_non_final_losses)) 115 | loss_weights[-1] = curr_value 116 | loss_weights = torch.Tensor(loss_weights).to(device=self.device) 117 | return loss_weights 118 | 119 | def get_inner_loop_parameter_dict(self, params, excluded_params=[]): 120 | """ 121 | Returns a dictionary with the parameters to use for inner loop updates. 122 | :param params: A dictionary of the network's parameters. 123 | :return: A dictionary of the parameters to use for the inner loop optimization process. 124 | """ 125 | param_dict = dict() 126 | for name, param in params: 127 | if name in excluded_params: 128 | continue 129 | if param.requires_grad: 130 | if self.args.enable_inner_loop_optimizable_bn_params: 131 | param_dict[name] = param.to(device=self.device) 132 | else: 133 | if "norm_layer" not in name: 134 | param_dict[name] = param.to(device=self.device) 135 | 136 | return param_dict 137 | 138 | def apply_inner_loop_update(self, loss, names_weights_copy, use_second_order, current_step_idx): 139 | """ 140 | Applies an inner loop update given current step's loss, the weights to update, a flag indicating whether to use 141 | second order derivatives and the current step's index. 142 | :param loss: Current step's loss with respect to the support set. 143 | :param names_weights_copy: A dictionary with names to parameters to update. 144 | :param use_second_order: A boolean flag of whether to use second order derivatives. 145 | :param current_step_idx: Current step's index. 
146 | :return: A dictionary with the updated weights (name, param) 147 | """ 148 | num_gpus = torch.cuda.device_count() 149 | if num_gpus > 1: 150 | self.classifier.module.zero_grad(params=names_weights_copy) 151 | else: 152 | self.classifier.zero_grad(params=names_weights_copy) 153 | 154 | grads = torch.autograd.grad(loss, names_weights_copy.values(), 155 | create_graph=use_second_order, allow_unused=True) 156 | names_grads_copy = dict(zip(names_weights_copy.keys(), grads)) 157 | 158 | names_weights_copy = {key: value[0] 159 | for key, value in names_weights_copy.items()} 160 | 161 | for key, grad in names_grads_copy.items(): 162 | if grad is None: 163 | print('Grads not found for inner loop parameter', key) 164 | names_grads_copy[key] = names_grads_copy[key].sum(dim=0) 165 | 166 | names_weights_copy = self.inner_loop_optimizer.update_params(names_weights_dict=names_weights_copy, 167 | names_grads_wrt_params_dict=names_grads_copy, 168 | num_step=current_step_idx) 169 | 170 | num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 171 | names_weights_copy = { 172 | name.replace('module.', ''): value.unsqueeze(0).repeat( 173 | [num_devices] + [1 for i in range(len(value.shape))]) for 174 | name, value in names_weights_copy.items()} 175 | 176 | return names_weights_copy 177 | 178 | def get_across_task_loss_metrics(self, total_losses, total_accuracies): 179 | losses = dict() 180 | 181 | losses['loss'] = torch.mean(torch.stack(total_losses)) 182 | losses['accuracy'] = np.mean(total_accuracies) 183 | 184 | return losses 185 | 186 | def forward(self, data_batch, epoch, use_second_order, use_multi_step_loss_optimization, num_steps, training_phase): 187 | """ 188 | Runs a forward outer loop pass on the batch of tasks using the MAML/++ framework. 189 | :param data_batch: A data batch containing the support and target sets. 190 | :param epoch: Current epoch's index 191 | :param use_second_order: A boolean saying whether to use second order derivatives. 192 | :param use_multi_step_loss_optimization: Whether to optimize on the outer loop using just the last step's 193 | target loss (True) or whether to use multi step loss which improves the stability of the system (False) 194 | :param num_steps: Number of inner loop steps. 195 | :param training_phase: Whether this is a training phase (True) or an evaluation phase (False) 196 | :return: A dictionary with the collected losses of the current outer forward propagation. 
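        For each task in the batch, this routine copies the inner-loop parameters
        (get_inner_loop_parameter_dict), runs `num_steps` adaptation steps on the support set
        using the per-step LSLR learning rates (apply_inner_loop_update), and evaluates the
        adapted copy on the target set. When multi-step loss optimization is active (training
        phase and epoch < multi_step_loss_num_epochs), the per-step target losses are combined
        with the importance weights from get_per_step_loss_importance_vector; otherwise only
        the final step's target loss contributes to the outer-loop objective.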
197 | """ 198 | x_support_set, x_target_set, y_support_set, y_target_set = data_batch 199 | 200 | # _, ncs, _ = y_support_set.shape 201 | 202 | # self.num_classes_per_set = ncs 203 | 204 | total_losses = [] 205 | total_accuracies = [] 206 | per_task_target_preds = [[] for i in range(len(y_target_set))] 207 | self.classifier.zero_grad() 208 | for task_id, (x_support_set_task, y_support_set_task, x_target_set_task, y_target_set_task) in \ 209 | enumerate(zip(zip(*x_support_set), 210 | y_support_set, 211 | zip(*x_target_set), 212 | y_target_set)): 213 | task_losses = [] 214 | # task_accuracies = [] 215 | per_step_loss_importance_vectors = self.get_per_step_loss_importance_vector() 216 | names_weights_copy = self.get_inner_loop_parameter_dict( 217 | self.classifier.named_parameters(), self.args.excluded_params) 218 | 219 | 220 | num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 221 | 222 | names_weights_copy = { 223 | name.replace('module.', ''): value.unsqueeze(0).repeat( 224 | [num_devices] + [1 for i in range(len(value.shape))]) for 225 | name, value in names_weights_copy.items()} 226 | 227 | # _, _, c, h, w = x_target_set_task.shape 228 | 229 | # x_support_set_task = x_support_set_task.view(-1, c, h, w) 230 | # y_support_set_task = y_support_set_task.view(-1) 231 | # x_target_set_task = x_target_set_task.view(-1, c, h, w) 232 | # y_target_set_task = y_target_set_task.view(-1) 233 | 234 | for num_step in range(num_steps): 235 | 236 | support_loss, _ = self.net_forward(x=x_support_set_task, 237 | y=y_support_set_task, 238 | weights=names_weights_copy, 239 | backup_running_statistics=True if ( 240 | num_step == 0) else False, 241 | training=True, num_step=num_step) 242 | 243 | names_weights_copy = self.apply_inner_loop_update(loss=support_loss, 244 | names_weights_copy=names_weights_copy, 245 | use_second_order=use_second_order, 246 | current_step_idx=num_step) 247 | 248 | if use_multi_step_loss_optimization and training_phase and epoch < self.args.multi_step_loss_num_epochs: 249 | target_loss, target_preds = self.net_forward(x=x_target_set_task, 250 | y=y_target_set_task, weights=names_weights_copy, 251 | backup_running_statistics=False, training=True, 252 | num_step=num_step) 253 | 254 | task_losses.append( 255 | per_step_loss_importance_vectors[num_step] * target_loss) 256 | else: 257 | if num_step == (self.args.number_of_training_steps_per_iter - 1): 258 | target_loss, target_preds = self.net_forward(x=x_target_set_task, 259 | y=y_target_set_task, weights=names_weights_copy, 260 | backup_running_statistics=False, training=True, 261 | num_step=num_step) 262 | task_losses.append(target_loss) 263 | 264 | # per_task_target_preds[task_id] = target_preds.detach( 265 | # ).cpu().numpy() 266 | per_task_target_preds[task_id] = target_preds.detach() 267 | _, predicted = torch.max(target_preds.data, 1) 268 | 269 | if y_target_set_task.size() != predicted.size(): 270 | y_target_set_task_hard = torch.max(y_target_set_task, dim=-1)[1] 271 | accuracy = predicted.float().eq(y_target_set_task_hard.data.float()).cpu().float() 272 | else: 273 | accuracy = predicted.float().eq(y_target_set_task.data.float()).cpu().float() 274 | task_losses = torch.sum(torch.stack(task_losses)) 275 | total_losses.append(task_losses) 276 | total_accuracies.extend(accuracy) 277 | 278 | if not training_phase: 279 | self.classifier.restore_backup_stats() 280 | 281 | losses = self.get_across_task_loss_metrics(total_losses=total_losses, 282 | total_accuracies=total_accuracies) 283 | 284 | for idx, item in 
enumerate(per_step_loss_importance_vectors): 285 | losses['loss_importance_vector_{}'.format( 286 | idx)] = item.detach().cpu().numpy() 287 | 288 | return losses, per_task_target_preds 289 | 290 | def net_forward(self, x, y, weights, backup_running_statistics, training, num_step): 291 | """ 292 | A base model forward pass on some data points x. Using the parameters in the weights dictionary. Also requires 293 | boolean flags indicating whether to reset the running statistics at the end of the run (if at evaluation phase). 294 | A flag indicating whether this is the training session and an int indicating the current step's number in the 295 | inner loop. 296 | :param x: A data batch of shape b, c, h, w 297 | :param y: A data targets batch of shape b, n_classes 298 | :param weights: A dictionary containing the weights to pass to the network. 299 | :param backup_running_statistics: A flag indicating whether to reset the batch norm running statistics to their 300 | previous values after the run (only for evaluation) 301 | :param training: A flag indicating whether the current process phase is a training or evaluation. 302 | :param num_step: An integer indicating the number of the step in the inner loop. 303 | :return: the crossentropy losses with respect to the given y, the predictions of the base model. 304 | """ 305 | preds = self.classifier.maml_forward(x=x, params=weights, 306 | training=training, 307 | backup_running_statistics=backup_running_statistics, num_step=num_step) 308 | 309 | # self.classifier.sim(x, weights, y) 310 | 311 | if y.size() == preds.size(): # using soft label 312 | loss = -torch.mean(torch.sum( 313 | torch.log_softmax(preds, dim=-1) 314 | * torch.softmax(y, dim=-1), 315 | dim=-1 316 | ) 317 | ) 318 | else: 319 | loss = F.cross_entropy(input=preds, target=y) 320 | 321 | return loss, preds 322 | 323 | def trainable_parameters(self): 324 | """ 325 | Returns an iterator over the trainable parameters of the model. 326 | """ 327 | for param in self.parameters(): 328 | if param.requires_grad: 329 | yield param 330 | 331 | def train_forward_prop(self, data_batch, epoch): 332 | """ 333 | Runs an outer loop forward prop using the meta-model and base-model. 334 | :param data_batch: A data batch containing the support set and the target set input, output pairs. 335 | :param epoch: The index of the currrent epoch. 336 | :return: A dictionary of losses for the current step. 337 | """ 338 | losses, per_task_target_preds = self.forward(data_batch=data_batch, epoch=epoch, 339 | use_second_order=self.args.second_order and 340 | epoch > self.args.first_order_to_second_order_epoch, 341 | use_multi_step_loss_optimization=self.args.use_multi_step_loss_optimization, 342 | num_steps=self.args.number_of_training_steps_per_iter, 343 | training_phase=True) 344 | return losses, per_task_target_preds 345 | 346 | def evaluation_forward_prop(self, data_batch, epoch): 347 | """ 348 | Runs an outer loop evaluation forward prop using the meta-model and base-model. 349 | :param data_batch: A data batch containing the support set and the target set input, output pairs. 350 | :param epoch: The index of the currrent epoch. 351 | :return: A dictionary of losses for the current step. 
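        Evaluation runs number_of_evaluation_steps_per_iter inner-loop steps with first-order
        updates (use_second_order=False) and training_phase=False, so only the final
        adaptation step's target loss is reported.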
352 | """ 353 | losses, per_task_target_preds = self.forward(data_batch=data_batch, epoch=epoch, use_second_order=False, 354 | use_multi_step_loss_optimization=True, 355 | num_steps=self.args.number_of_evaluation_steps_per_iter, 356 | training_phase=False) 357 | 358 | return losses, per_task_target_preds 359 | 360 | def meta_update(self, loss): 361 | """ 362 | Applies an outer loop update on the meta-parameters of the model. 363 | :param loss: The current crossentropy loss. 364 | """ 365 | self.optimizer.zero_grad() 366 | loss.backward() 367 | if 'imagenet' in self.args.dataset_name: 368 | for _, param in self.classifier.named_parameters(): 369 | if param.requires_grad: 370 | # not sure if this is necessary, more experiments are needed 371 | param.grad.data.clamp_(-10, 10) 372 | self.optimizer.step() 373 | 374 | def run_train_iter(self, data_batch, epoch): 375 | """ 376 | Runs an outer loop update step on the meta-model's parameters. 377 | :param data_batch: input data batch containing the support set and target set input, output pairs 378 | :param epoch: the index of the current epoch 379 | :return: The losses of the ran iteration. 380 | """ 381 | epoch = int(epoch) 382 | self.scheduler.step(epoch=epoch) 383 | if self.current_epoch != epoch: 384 | self.current_epoch = epoch 385 | 386 | if not self.training: 387 | self.train() 388 | 389 | x_support_set, x_target_set, y_support_set, y_target_set = data_batch 390 | 391 | x_support_set = mapped(self.move_to_cuda, x_support_set) 392 | x_target_set = mapped(self.move_to_cuda, x_target_set) 393 | y_support_set = mapped(self.move_to_cuda, y_support_set) 394 | y_target_set = mapped(self.move_to_cuda, y_target_set) 395 | 396 | data_batch = (x_support_set, x_target_set, y_support_set, y_target_set) 397 | 398 | losses, per_task_target_preds = self.train_forward_prop( 399 | data_batch=data_batch, epoch=epoch) 400 | 401 | self.meta_update(loss=losses['loss']) 402 | losses['learning_rate'] = self.scheduler.get_lr()[0] 403 | self.optimizer.zero_grad() 404 | self.zero_grad() 405 | 406 | return losses, per_task_target_preds 407 | 408 | def run_validation_iter(self, data_batch): 409 | """ 410 | Runs an outer loop evaluation step on the meta-model's parameters. 411 | :param data_batch: input data batch containing the support set and target set input, output pairs 412 | :param epoch: the index of the current epoch 413 | :return: The losses of the ran iteration. 414 | """ 415 | 416 | if self.training: 417 | # self.eval() 418 | pass 419 | 420 | x_support_set, x_target_set, y_support_set, y_target_set = data_batch 421 | 422 | x_support_set = mapped(self.move_to_cuda, x_support_set) 423 | x_target_set = mapped(self.move_to_cuda, x_target_set) 424 | y_support_set = mapped(self.move_to_cuda, y_support_set) 425 | y_target_set = mapped(self.move_to_cuda, y_target_set) 426 | 427 | data_batch = (x_support_set, x_target_set, y_support_set, y_target_set) 428 | 429 | losses, per_task_target_preds = self.evaluation_forward_prop( 430 | data_batch=data_batch, epoch=self.current_epoch) 431 | 432 | # losses['loss'].backward() # uncomment if you get the weird memory error 433 | # self.zero_grad() 434 | # self.optimizer.zero_grad() 435 | 436 | return losses, per_task_target_preds 437 | 438 | def save_model(self, model_save_dir, state): 439 | """ 440 | Save the network parameter state and experiment state dictionary. 441 | :param model_save_dir: The directory to store the state at. 442 | :param state: The state containing the experiment state and the network. 
It's in the form of a dictionary 443 | object. 444 | """ 445 | state['network'] = self.state_dict() 446 | torch.save(state, f=model_save_dir) 447 | 448 | def load_model(self, model_save_dir, model_name, model_idx): 449 | """ 450 | Load checkpoint and return the state dictionary containing the network state params and experiment state. 451 | :param model_save_dir: The directory from which to load the files. 452 | :param model_name: The model_name to be loaded from the direcotry. 453 | :param model_idx: The index of the model (i.e. epoch number or 'latest' for the latest saved model of the current 454 | experiment) 455 | :return: A dictionary containing the experiment state and the saved model parameters. 456 | """ 457 | filepath = os.path.join( 458 | model_save_dir, "{}_{}".format(model_name, model_idx)) 459 | state = torch.load(filepath) 460 | state_dict_loaded = state['network'] 461 | self.load_state_dict(state_dict=state_dict_loaded) 462 | return state 463 | -------------------------------------------------------------------------------- /ibr_game/inner_loop_optimizers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | 11 | 12 | class GradientDescentLearningRule(nn.Module): 13 | """Simple (stochastic) gradient descent learning rule. 14 | For a scalar error function `E(p[0], p_[1] ... )` of some set of 15 | potentially multidimensional parameters this attempts to find a local 16 | minimum of the loss function by applying updates to each parameter of the 17 | form 18 | p[i] := p[i] - learning_rate * dE/dp[i] 19 | With `learning_rate` a positive scaling parameter. 20 | The error function used in successive applications of these updates may be 21 | a stochastic estimator of the true error function (e.g. when the error with 22 | respect to only a subset of data-points is calculated) in which case this 23 | will correspond to a stochastic gradient descent learning rule. 24 | """ 25 | 26 | def __init__(self, device, learning_rate=1e-3): 27 | """Creates a new learning rule object. 28 | Args: 29 | learning_rate: A postive scalar to scale gradient updates to the 30 | parameters by. This needs to be carefully set - if too large 31 | the learning dynamic will be unstable and may diverge, while 32 | if set too small learning will proceed very slowly. 33 | """ 34 | super(GradientDescentLearningRule, self).__init__() 35 | assert learning_rate >= 0., 'learning_rate should be positive.' 36 | self.learning_rate = torch.ones(1) * learning_rate 37 | self.learning_rate.to(device) 38 | 39 | def update_params(self, names_weights_dict, names_grads_wrt_params_dict, num_step, tau=0.9): 40 | """Applies a single gradient descent update to all parameters. 41 | All parameter updates are performed using in-place operations and so 42 | nothing is returned. 43 | Args: 44 | grads_wrt_params: A list of gradients of the scalar loss function 45 | with respect to each of the parameters passed to `initialise` 46 | previously, with this list expected to be in the same order. 
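        Note: despite the wording above, the updated weights are returned in a new
        dictionary rather than modified in place:
            updated[name] = names_weights_dict[name]
                            - learning_rate * names_grads_wrt_params_dict[name]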
47 | """ 48 | updated_names_weights_dict = dict() 49 | for key in names_weights_dict.keys(): 50 | updated_names_weights_dict[key] = names_weights_dict[key] - self.learning_rate * \ 51 | names_grads_wrt_params_dict[ 52 | key] 53 | 54 | return updated_names_weights_dict 55 | 56 | 57 | class LSLRGradientDescentLearningRule(nn.Module): 58 | """Simple (stochastic) gradient descent learning rule. 59 | For a scalar error function `E(p[0], p_[1] ... )` of some set of 60 | potentially multidimensional parameters this attempts to find a local 61 | minimum of the loss function by applying updates to each parameter of the 62 | form 63 | p[i] := p[i] - learning_rate * dE/dp[i] 64 | With `learning_rate` a positive scaling parameter. 65 | The error function used in successive applications of these updates may be 66 | a stochastic estimator of the true error function (e.g. when the error with 67 | respect to only a subset of data-points is calculated) in which case this 68 | will correspond to a stochastic gradient descent learning rule. 69 | """ 70 | 71 | def __init__(self, device, total_num_inner_loop_steps, use_learnable_learning_rates, init_learning_rate=1e-3): 72 | """Creates a new learning rule object. 73 | Args: 74 | init_learning_rate: A postive scalar to scale gradient updates to the 75 | parameters by. This needs to be carefully set - if too large 76 | the learning dynamic will be unstable and may diverge, while 77 | if set too small learning will proceed very slowly. 78 | """ 79 | super(LSLRGradientDescentLearningRule, self).__init__() 80 | print(init_learning_rate) 81 | assert init_learning_rate >= 0., 'learning_rate should be positive.' 82 | 83 | self.init_learning_rate = torch.ones(1) * init_learning_rate 84 | self.init_learning_rate.to(device) 85 | self.total_num_inner_loop_steps = total_num_inner_loop_steps 86 | self.use_learnable_learning_rates = use_learnable_learning_rates 87 | 88 | def initialise(self, names_weights_dict): 89 | self.names_learning_rates_dict = nn.ParameterDict() 90 | for idx, (key, param) in enumerate(names_weights_dict.items()): 91 | self.names_learning_rates_dict[key.replace(".", "-")] = nn.Parameter( 92 | data=torch.ones(self.total_num_inner_loop_steps + 1) * self.init_learning_rate, 93 | requires_grad=self.use_learnable_learning_rates) 94 | 95 | def reset(self): 96 | 97 | # for key, param in self.names_learning_rates_dict.items(): 98 | # param.fill_(self.init_learning_rate) 99 | pass 100 | 101 | def update_params(self, names_weights_dict, names_grads_wrt_params_dict, num_step, tau=0.1): 102 | """Applies a single gradient descent update to all parameters. 103 | All parameter updates are performed using in-place operations and so 104 | nothing is returned. 105 | Args: 106 | grads_wrt_params: A list of gradients of the scalar loss function 107 | with respect to each of the parameters passed to `initialise` 108 | previously, with this list expected to be in the same order. 
109 | """ 110 | updated_names_weights_dict = dict() 111 | for key in names_grads_wrt_params_dict.keys(): 112 | updated_names_weights_dict[key] = names_weights_dict[key] - \ 113 | self.names_learning_rates_dict[key.replace(".", "-")][num_step] \ 114 | * names_grads_wrt_params_dict[ 115 | key] 116 | 117 | return updated_names_weights_dict 118 | 119 | -------------------------------------------------------------------------------- /ibr_game/lang_id.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict 3 | from typing import List 4 | 5 | 6 | lang_set = torch.load("lang_id.data") 7 | 8 | 9 | def language_identification(sent: List[int]) -> int: 10 | result = set(range(10)) 11 | for word in sent: 12 | if len(lang_set[word]): 13 | result = result.intersection(lang_set[word]) 14 | try: 15 | return list(result)[0] 16 | except: 17 | return None 18 | 19 | 20 | if __name__ == "__main__": 21 | data = torch.load("coco_ml/labs/train_org_lang") 22 | lang_set = defaultdict(set) 23 | for image in data: 24 | for lang_id, lang in enumerate(image): 25 | for sent in lang: 26 | for word in sent: 27 | lang_set[word].add(lang_id) 28 | torch.save(lang_set, "lang_id.data") -------------------------------------------------------------------------------- /ibr_game/maml_speaker.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import json 3 | import argparse 4 | import random 5 | import copy 6 | import os 7 | import tqdm 8 | import time 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.autograd.profiler as profiler 16 | 17 | from models import Listener, AggregatedListener 18 | from few_shot_learning_system import MAMLFewShotClassifier 19 | from referential_game import get_caption_candidates, load_model 20 | from utils import truncate_dicts, nearest_images 21 | # from lang_id import language_identification 22 | 23 | # torch.backends.cudnn.enabled = False 24 | 25 | # torch.autograd.set_detect_anomaly(True) 26 | 27 | 28 | def get_maml_args(maml_args_file: str): 29 | class objectview(object): 30 | def __init__(self, d): 31 | self.__dict__ = d 32 | return objectview(json.load(open(maml_args_file, "r"))) 33 | 34 | 35 | def sample_listener(listener_template, listener_range, args, idx=None) \ 36 | -> Listener: 37 | if idx is None: 38 | listener_choice = random.choice(listener_range) 39 | else: 40 | listener_choice = listener_range[idx] 41 | new_args = copy.deepcopy(args) 42 | new_args.save_dir = listener_template.format(*listener_choice) 43 | new_args.vocab_size = new_args.listener_vocab_size 44 | new_args.seed = listener_choice[0] # hardcode warning!!! 
45 | return load_model(new_args).listener 46 | 47 | 48 | def gen_game(n_img: int, N: int, M: int, n_distrs: int, 49 | sample_candidates: torch.Tensor) \ 50 | -> Tuple[torch.Tensor, torch.Tensor]: 51 | device = sample_candidates.device 52 | n = sample_candidates.size()[1] 53 | target_images = torch.randint(n_img, size=(N, M)).to(device) 54 | distr_images = torch.randint(n, size=(N, M, n_distrs + 1)).to(device) 55 | target_candidates = torch.index_select( 56 | sample_candidates, 0, target_images.view(-1)).view(N, M, n) 57 | distr_images = torch.gather( 58 | target_candidates, 2, distr_images).view(N*M, n_distrs+1) 59 | target_indices = torch.randint(n_distrs + 1, size=(N*M,)).to(device) 60 | distr_images[range(N*M), target_indices] = target_images.view(N*M) 61 | return distr_images, target_indices 62 | 63 | 64 | def print_profile_of_tensor(x): 65 | print(x.size(), x.device) 66 | 67 | 68 | def rollout(images: torch.Tensor, caption_beams: torch.Tensor, 69 | caption_beams_len: torch.Tensor, nearest_images: torch.Tensor, 70 | N: int, T: int, M: int, args, listener_range, 71 | maml_listener: MAMLFewShotClassifier, epoch_id: int) \ 72 | -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: 73 | """ 74 | :param images: input images 75 | :param caption_beams: captions for the input images 76 | :param N: number of tasks 77 | :param T: number of interactions between speaker and listener 78 | :param M: number of target tasks (M-1 of them not used for training) 79 | """ 80 | b = caption_beams.size()[1] 81 | L = caption_beams.size()[2] 82 | n_distrs = args.num_distrs 83 | D_img = images.size()[-1] 84 | n_img = images.size()[0] 85 | listeners = list() 86 | for i in range(N): 87 | while True: 88 | try: 89 | listeners.append(sample_listener( 90 | args.listener_template, listener_range, args)) 91 | break 92 | except FileNotFoundError: 93 | pass 94 | imgs_till_now = torch.zeros(N, 1, M, n_distrs + 1, D_img).to(args.device) 95 | caps_till_now = torch.zeros(N, 1, M, L).long().to(args.device) 96 | cap_lens_till_now = torch.ones(N, 1, M).long().to(args.device) 97 | ys_till_now = torch.zeros(N, 1, M).long().to(args.device) 98 | ys_logp_till_now = torch.zeros(N, 1, M, n_distrs + 1).to(args.device) 99 | 100 | rollout_stat = {} 101 | rollout_accuracy = [] 102 | 103 | for i in range(T): 104 | imgs = imgs_till_now[:, :, 0] 105 | caps = caps_till_now[:, :, 0] 106 | cap_lens = cap_lens_till_now[:, :, 0] 107 | ys = ys_till_now[:, :, 0] 108 | ys_logp = ys_logp_till_now[:, :, 0] 109 | img_ids, gold_ys = gen_game(n_img, N, M, n_distrs, nearest_images) 110 | img_ids, gold_ys = img_ids.to(args.device), gold_ys.to(args.device) 111 | img_tgt_ids = img_ids[range(N * M), gold_ys] 112 | cap_candidates = torch.index_select(caption_beams, 0, img_tgt_ids) 113 | cap_len_candidates = torch.index_select( 114 | caption_beams_len, 0, img_tgt_ids) 115 | img_ids = img_ids.unsqueeze(1).expand(-1, b, -1) 116 | gold_ys = gold_ys.unsqueeze(1).expand(-1, b) 117 | img_ids = img_ids.reshape(-1) 118 | imgs_new = torch.index_select( 119 | images, 0, img_ids).view(N*M, b, n_distrs+1, D_img) 120 | img_ids = img_ids.view(N*M, b, n_distrs+1) 121 | 122 | imgs_new = imgs_new.view(N, M * b, n_distrs+1, D_img) 123 | cap_candidates = cap_candidates.view(N, M*b, L) 124 | cap_len_candidates = cap_len_candidates.view(N, M*b) 125 | 126 | x_support = (imgs, caps, cap_lens) 127 | y_support = ys if not args.maml_args.soft_y_support else ys_logp 128 | x_target = (imgs_new, cap_candidates, cap_len_candidates) 129 | y_target = torch.zeros(N, 
M*b).long().to(args.device) # dummy y_target 130 | _, preds = maml_listener.run_validation_iter( 131 | (x_support, x_target, y_support, y_target)) 132 | preds = torch.stack(preds, dim=0) 133 | preds = F.log_softmax(preds, dim=-1) 134 | preds_acc = torch.gather( 135 | preds, -1, gold_ys.reshape(N, M * b, 1)).view(N, M, b) 136 | caption_choice = torch.max(preds_acc, -1)[1].view(N * M) 137 | cap_candidates = cap_candidates.view(N*M, b, L) 138 | cap_new = cap_candidates[range(N * M), caption_choice] 139 | cap_len_candidates = cap_len_candidates.view(N*M, b) 140 | cap_len_new = cap_len_candidates[range(N * M), caption_choice] 141 | 142 | imgs_new = imgs_new.view(N, M, b, n_distrs+1, D_img) 143 | imgs_new = imgs_new[:, :, 0] 144 | cap_new = cap_new.view(N, M, L) 145 | cap_len_new = cap_len_new.view(N, M) 146 | 147 | y, y_logp = [], [] 148 | with torch.no_grad(): 149 | for j in range(N): 150 | preds_out, preds_logp = \ 151 | listeners[j].predict( 152 | imgs_new[j].view(M, n_distrs+1, D_img), 153 | cap_new[j].view(M, L), 154 | cap_len_new[j].view(M), 155 | output_logp=True) 156 | y.append(preds_out) 157 | y_logp.append(preds_logp) 158 | y = torch.stack(y, dim=0) 159 | y_logp = torch.stack(y_logp, dim=0) 160 | rollout_accuracy.append(gold_ys[:, 0] == y.view(-1)) 161 | 162 | imgs_till_now = torch.cat( 163 | [imgs_till_now, imgs_new.unsqueeze(1)], dim=1) 164 | caps_till_now = torch.cat([caps_till_now, cap_new.unsqueeze(1)], dim=1) 165 | cap_lens_till_now = torch.cat( 166 | [cap_lens_till_now, cap_len_new.unsqueeze(1)], dim=1) 167 | ys_till_now = torch.cat([ys_till_now, y.unsqueeze(1)], dim=1) 168 | ys_logp_till_now = torch.cat([ys_logp_till_now, y_logp.unsqueeze(1)], 169 | dim=1) 170 | 171 | rollout_accuracy = torch.cat(rollout_accuracy, dim=0) 172 | rollout_stat["acc"] = torch.mean(rollout_accuracy.float()) 173 | 174 | return (imgs_till_now, caps_till_now, cap_lens_till_now), ys_till_now, \ 175 | ys_logp_till_now, rollout_stat 176 | 177 | 178 | def rollout_case_study( 179 | images: torch.Tensor, caption_beams: torch.Tensor, 180 | caption_beams_len: torch.Tensor, nearest_images: torch.Tensor, 181 | N: int, T: int, M: int, args, listener_range, 182 | maml_listener: MAMLFewShotClassifier, epoch_id: int) \ 183 | -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: 184 | """ 185 | :param images: input images 186 | :param caption_beams: captions for the input images 187 | :param N: number of tasks 188 | :param T: number of interactions between speaker and listener 189 | :param M: number of target tasks (M-1 of them not used for training) 190 | """ 191 | b = caption_beams.size()[1] 192 | L = caption_beams.size()[2] 193 | n_distrs = args.num_distrs 194 | D_img = images.size()[-1] 195 | n_img = images.size()[0] 196 | listeners = list() 197 | for i in range(N): 198 | try: 199 | listeners.append(sample_listener( 200 | args.listener_template, listener_range, args, i)) 201 | except: 202 | pass 203 | N = len(listeners) 204 | imgs_till_now = torch.zeros(N, 1, M, n_distrs + 1, D_img).to(args.device) 205 | caps_till_now = torch.zeros(N, 1, M, L).long().to(args.device) 206 | cap_lens_till_now = torch.ones(N, 1, M).long().to(args.device) 207 | ys_till_now = torch.zeros(N, 1, M).long().to(args.device) 208 | ys_gold_till_now = torch.zeros(N, 1, M).long().to(args.device) 209 | ys_logp_till_now = torch.zeros(N, 1, M, n_distrs + 1).to(args.device) 210 | 211 | rollout_stat = {} 212 | rollout_accuracy = [] 213 | 214 | for i in range(T): 215 | imgs = imgs_till_now[:, :, 0] 216 | caps = caps_till_now[:, :, 
0] 217 | cap_lens = cap_lens_till_now[:, :, 0] 218 | ys = ys_till_now[:, :, 0] 219 | ys_gold = ys_gold_till_now[:, :, 0] 220 | ys_logp = ys_logp_till_now[:, :, 0] 221 | img_ids, gold_ys = gen_game(n_img, 1, M, n_distrs, nearest_images) 222 | img_ids = img_ids.repeat(N, 1) 223 | gold_ys = gold_ys.repeat(N) 224 | 225 | img_ids, gold_ys = img_ids.to(args.device), gold_ys.to(args.device) 226 | img_tgt_ids = img_ids[range(N * M), gold_ys] 227 | cap_candidates = torch.index_select(caption_beams, 0, img_tgt_ids) 228 | cap_len_candidates = torch.index_select( 229 | caption_beams_len, 0, img_tgt_ids) 230 | img_ids = img_ids.unsqueeze(1).expand(-1, b, -1) 231 | gold_ys = gold_ys.unsqueeze(1).expand(-1, b) 232 | img_ids = img_ids.reshape(-1) 233 | imgs_new = torch.index_select( 234 | images, 0, img_ids).view(N*M, b, n_distrs+1, D_img) 235 | img_ids = img_ids.view(N*M, b, n_distrs+1) 236 | 237 | imgs_new = imgs_new.view(N, M * b, n_distrs+1, D_img) 238 | cap_candidates = cap_candidates.view(N, M*b, L) 239 | cap_len_candidates = cap_len_candidates.view(N, M*b) 240 | 241 | x_support = (imgs, caps, cap_lens) 242 | y_support = ys if not args.maml_args.soft_y_support else ys_logp 243 | x_target = (imgs_new, cap_candidates, cap_len_candidates) 244 | y_target = torch.zeros(N, M*b).long().to(args.device) # dummy y_target 245 | if args.binary_loss: 246 | y_support = (y_support, ys_gold) 247 | y_target = (y_target, torch.zeros(N, M*b).long().to(args.device)) 248 | _, preds = maml_listener.run_validation_iter( 249 | (x_support, x_target, y_support, y_target)) 250 | preds = torch.stack(preds, dim=0) 251 | preds = F.log_softmax(preds, dim=-1) 252 | preds_acc = torch.gather( 253 | preds, -1, gold_ys.reshape(N, M * b, 1)).view(N, M, b) 254 | caption_choice = torch.max(preds_acc, -1)[1].view(N * M) 255 | cap_candidates = cap_candidates.view(N*M, b, L) 256 | cap_new = cap_candidates[range(N * M), caption_choice] 257 | cap_len_candidates = cap_len_candidates.view(N*M, b) 258 | cap_len_new = cap_len_candidates[range(N * M), caption_choice] 259 | 260 | imgs_new = imgs_new.view(N, M, b, n_distrs+1, D_img) 261 | imgs_new = imgs_new[:, :, 0] 262 | cap_new = cap_new.view(N, M, L) 263 | cap_len_new = cap_len_new.view(N, M) 264 | 265 | y, y_logp = [], [] 266 | with torch.no_grad(): 267 | for j in range(N): 268 | preds_out, preds_logp = \ 269 | listeners[j].predict( 270 | imgs_new[j].view(M, n_distrs+1, D_img), 271 | cap_new[j].view(M, L), 272 | cap_len_new[j].view(M), 273 | output_logp=True) 274 | y.append(preds_out) 275 | y_logp.append(preds_logp) 276 | y = torch.stack(y, dim=0) 277 | y_logp = torch.stack(y_logp, dim=0) 278 | rollout_accuracy.append(gold_ys[:, 0] == y.view(-1)) 279 | 280 | imgs_till_now = torch.cat( 281 | [imgs_till_now, imgs_new.unsqueeze(1)], dim=1) 282 | caps_till_now = torch.cat([caps_till_now, cap_new.unsqueeze(1)], dim=1) 283 | cap_lens_till_now = torch.cat( 284 | [cap_lens_till_now, cap_len_new.unsqueeze(1)], dim=1) 285 | ys_till_now = torch.cat([ys_till_now, y.unsqueeze(1)], dim=1) 286 | # ys_gold_till_now = torch.cat([ys_gold_till_now, gold_ys.unsqueeze(1)], dim=1) 287 | ys_logp_till_now = torch.cat([ys_logp_till_now, y_logp.unsqueeze(1)], 288 | dim=1) 289 | 290 | rollout_stat["per_step_acc"] \ 291 | = [torch.mean(i.float()).item() for i in rollout_accuracy] 292 | rollout_accuracy = torch.cat(rollout_accuracy, dim=0) 293 | rollout_stat["acc"] = torch.mean(rollout_accuracy.float()).item() 294 | 295 | return (imgs_till_now, caps_till_now, cap_lens_till_now), ys_till_now, \ 296 | ys_logp_till_now, 
rollout_stat 297 | 298 | 299 | def main(args): 300 | random.seed(args.seed) 301 | np.random.seed(args.seed) 302 | torch.manual_seed(args.seed) 303 | 304 | args.cuda = torch.cuda.is_available() 305 | args.device = torch.device('cuda' if args.cuda else 'cpu') 306 | print("OPTS:\n", vars(args)) 307 | feat_path = args.coco_path 308 | data_path = args.coco_path 309 | 310 | train_images, val_images, test_images \ 311 | = [torch.load('{}/feats/{}'.format(feat_path, x)) for x in 312 | "train_feats valid_feats test_feats".split()] 313 | train_images = train_images.to(device=args.device) 314 | val_images = val_images.to(device=args.device) 315 | test_images = test_images.to(device=args.device) 316 | 317 | if args.image_sample_ratio == 1.0: 318 | train_nearest_images = torch.arange( 319 | train_images.size()[0]).expand(train_images.size()[0], -1) 320 | val_nearest_images = torch.arange( 321 | val_images.size()[0]).expand(val_images.size()[0], -1) 322 | test_nearest_images = torch.arange( 323 | test_images.size()[0]).expand(test_images.size()[0], -1) 324 | else: 325 | train_nearest_images = nearest_images(train_images, 326 | n=int(args.image_sample_ratio * train_images.size()[0])) 327 | torch.cuda.empty_cache() 328 | val_nearest_images = nearest_images(val_images, n=int( 329 | args.image_sample_ratio * val_images.size()[0])) 330 | test_nearest_images = nearest_images(test_images, n=int( 331 | args.image_sample_ratio * test_images.size()[0])) 332 | 333 | (w2i, i2w) = [torch.load(data_path + 'dics/{}'.format(x)) 334 | for x in "w2i i2w".split()] 335 | w2i, i2w = truncate_dicts(w2i, i2w, args.num_words) 336 | args.vocab_size = len(w2i) 337 | args.w2i = w2i 338 | args.i2w = i2w 339 | 340 | if os.path.exists(f"save/train_caption_beams_{args.beam_size}.pt") \ 341 | and os.path.exists(f"save/train_cap_len_{args.beam_size}.pt"): 342 | train_caption_beams = torch.load( 343 | f"save/train_caption_beams_{args.beam_size}.pt").to(args.device) 344 | train_cap_len = torch.load(f"save/train_cap_len_{args.beam_size}.pt")\ 345 | .to(args.device) 346 | else: 347 | train_caption_beams, train_cap_len = get_caption_candidates( 348 | train_images, args) 349 | torch.save(train_caption_beams, 350 | f"save/train_caption_beams_{args.beam_size}.pt") 351 | torch.save(train_cap_len, f"save/train_cap_len_{args.beam_size}.pt") 352 | 353 | if os.path.exists(f"save/val_caption_beams_{args.beam_size}.pt") \ 354 | and os.path.exists(f"save/val_cap_len_{args.beam_size}.pt"): 355 | val_caption_beams = torch.load( 356 | f"save/val_caption_beams_{args.beam_size}.pt").to(args.device) 357 | val_cap_len = torch.load(f"save/val_cap_len_{args.beam_size}.pt")\ 358 | .to(args.device) 359 | else: 360 | val_caption_beams, val_cap_len = get_caption_candidates( 361 | val_images, args) 362 | torch.save(val_caption_beams, 363 | f"save/val_caption_beams_{args.beam_size}.pt") 364 | torch.save(val_cap_len, f"save/val_cap_len_{args.beam_size}.pt") 365 | 366 | args.maml_args = get_maml_args(args.maml_args) 367 | if args.use_mono_listeners: 368 | new_args = copy.deepcopy(args) 369 | new_args.vocab_size = new_args.listener_vocab_size 370 | new_args.D_hid = new_args.D_hid_maml_listener 371 | maml_listener = MAMLFewShotClassifier( 372 | AggregatedListener, args.maml_args, new_args 373 | ) 374 | for idx, i in enumerate(maml_listener.classifier.listeners): 375 | i.load_state_dict( 376 | torch.load(os.path.join(f"save/mono_student_{idx}", 377 | 'list_params', f"pop{idx}.pt")) 378 | ) 379 | else: 380 | new_args = copy.deepcopy(args) 381 | new_args.D_hid = 
new_args.D_hid_maml_listener 382 | maml_listener = MAMLFewShotClassifier( 383 | Listener, args.maml_args, new_args 384 | ) 385 | maml_listener.classifier.load_state_dict( 386 | torch.load(os.path.join(args.save_dir, 387 | 'list_params', f"pop{args.seed}.pt")) 388 | ) 389 | 390 | if args.fully_offline: 391 | maml_listener_freeze = copy.deepcopy(maml_listener) 392 | for i in maml_listener_freeze.parameters(): 393 | i.requires_grad_ = False 394 | 395 | max_acc = 0.0 396 | 397 | if args.evaluate_model_filename != "": 398 | maml_listener.load_state_dict(torch.load(args.evaluate_model_filename)) 399 | if os.path.exists(f"save/test_caption_beams_{args.beam_size}.pt") \ 400 | and os.path.exists(f"save/test_cap_len_{args.beam_size}.pt"): 401 | test_caption_beams = torch.load( 402 | f"save/test_caption_beams_{args.beam_size}.pt").to(args.device) 403 | test_cap_len = torch.load(f"save/test_cap_len_{args.beam_size}.pt")\ 404 | .to(args.device) 405 | else: 406 | test_caption_beams, test_cap_len = get_caption_candidates( 407 | test_images, args) 408 | torch.save(test_caption_beams, 409 | f"save/test_caption_beams_{args.beam_size}.pt") 410 | torch.save(test_cap_len, f"save/test_cap_len_{args.beam_size}.pt") 411 | average_stat = {'acc': [], 'per_step_acc': [[] 412 | for _ in range(args.maml_time_step)]} 413 | for i in range(1): 414 | (imgs, caps, cap_lens), ys, ys_logp, rollout_stat = \ 415 | rollout_case_study(test_images, test_caption_beams, test_cap_len, 416 | test_nearest_images, 417 | 20, args.maml_time_step, 50, args, 418 | [(i,) for i in range(80, 100)], maml_listener, 0) 419 | average_stat['acc'].append(rollout_stat['acc']) 420 | for j, k in zip(average_stat['per_step_acc'], rollout_stat['per_step_acc']): 421 | j.append(k) 422 | average_stat['acc'] = sum(average_stat['acc']) / \ 423 | len(average_stat['acc']) 424 | for i in range(args.maml_time_step): 425 | average_stat['per_step_acc'][i] = sum(average_stat['per_step_acc'][i]) / \ 426 | len(average_stat['per_step_acc'][i]) 427 | print(average_stat) 428 | # for i in range(100): 429 | # (imgs, caps, cap_lens), ys, ys_logp, rollout_stat = \ 430 | # rollout_case_study(test_images, test_caption_beams, test_cap_len, 431 | # test_nearest_images, 432 | # 20, args.maml_time_step, 1, args, 433 | # [(i,) for i in range(80, 100)], maml_listener, 0) 434 | 435 | # data = [] 436 | # for i in range(caps.size()[0]): 437 | # lang_list = [] 438 | # for j in range(args.maml_time_step): 439 | # sent = caps[i][j][0][:cap_lens[i][j][0]].tolist() 440 | # lang = language_identification(sent) 441 | # lang_list.append(lang) 442 | # data.append(lang_list) 443 | 444 | # last_lang = set() 445 | # for lang_list in data: 446 | # last_lang.add(lang_list[-1]) 447 | # if len(last_lang) > 3: 448 | # for i in range(caps.size()[0]): 449 | # line = [str(i)] 450 | # for j in range(args.maml_time_step): 451 | # line.append( 452 | # ' '.join( 453 | # map(lambda x: args.i2w[x.item()], caps[i][j][0][:cap_lens[i][j][0]])) 454 | # ) 455 | # for k in range(args.num_distrs+1): 456 | # line.append(str(torch.exp(ys_logp[i][j][0][k]).item())) 457 | # print('\t'.join(line)) 458 | exit() 459 | 460 | for epoch in range(args.maml_args.total_epochs): 461 | mean_losses = {} 462 | pbar = tqdm.tqdm( 463 | range(args.maml_args.total_iter_per_epoch), dynamic_ncols=True) 464 | for _ in pbar: 465 | # Rollout Phase: collecting data from current maml model 466 | rollout_start_time = time.time() 467 | (imgs, caps, cap_lens), ys, ys_logp, rollout_stat \ 468 | = rollout( 469 | train_images, train_caption_beams, 
train_cap_len, 470 | train_nearest_images, 471 | args.maml_args.batch_size, args.maml_time_step, 472 | args.maml_args.parallel_target_tasks, args, 473 | [(i,) for i in range(80)], 474 | maml_listener_freeze if args.fully_offline else maml_listener, 475 | epoch) 476 | rollout_acc = rollout_stat["acc"] 477 | if "rollout_acc" not in mean_losses: 478 | mean_losses["rollout_acc"] = [rollout_acc] 479 | else: 480 | mean_losses["rollout_acc"].append(rollout_acc) 481 | rollout_end_time = time.time() 482 | # MAML Phase: train maml model on the collected data 483 | maml_start_time = time.time() 484 | for t in range(1, args.maml_time_step+1): 485 | x_support = (imgs[:, :t, 0], caps[:, :t, 0], 486 | cap_lens[:, :t, 0]) 487 | y_support = ys[:, :t, 0] 488 | x_target = (imgs[:, t], caps[:, t], cap_lens[:, t]) 489 | # y_target = ys[:, t:t+1] 490 | y_target = ys_logp[:, t] 491 | losses, preds = maml_listener.run_train_iter( 492 | (x_support, x_target, y_support, y_target), epoch) 493 | preds = torch.stack(preds, dim=0) 494 | preds = F.log_softmax(preds, dim=-1) 495 | for loss in losses: 496 | if loss not in mean_losses: 497 | mean_losses[loss] = [losses[loss]] 498 | else: 499 | mean_losses[loss].append(losses[loss]) 500 | maml_end_time = time.time() 501 | pbar.set_postfix({'rollout time': rollout_end_time - rollout_start_time, 502 | 'maml time': maml_end_time - maml_start_time}) 503 | for i in mean_losses: 504 | mean_losses[i] = sum(mean_losses[i]) / len(mean_losses[i]) 505 | print(mean_losses) 506 | _, _, _, rollout_stat = rollout( 507 | val_images, val_caption_beams, val_cap_len, 508 | val_nearest_images, 509 | 100, args.maml_time_step, 10, args, 510 | [(i,) for i in range(80, 100)], maml_listener, epoch) 511 | print(f"Accuracy on the val set: {rollout_stat['acc']}") 512 | if rollout_stat['acc'] > max_acc and args.maml_save_dir != "": 513 | max_acc = rollout_stat['acc'] 514 | Path(args.maml_save_dir).mkdir(parents=True, exist_ok=True) 515 | save_path = os.path.join( 516 | args.maml_save_dir, f"maml_listener_{max_acc}.pt") 517 | print(f"Saving model to {save_path}") 518 | torch.save(maml_listener.state_dict(), save_path) 519 | 520 | 521 | if __name__ == "__main__": 522 | parser = argparse.ArgumentParser() 523 | parser.add_argument('--seed', type=int, default=0, 524 | help='seed') 525 | parser.add_argument("--num_seed_examples", type=int, default=1000, 526 | help="Number of seed examples") 527 | parser.add_argument("--num_distrs", type=int, default=9, 528 | help="Number of distractors") 529 | parser.add_argument("--s2p_schedule", type=str, default="sched", 530 | help="s2p schedule") 531 | parser.add_argument("--s2p_selfplay_updates", type=int, default=50, 532 | help="s2p self-play updates") 533 | parser.add_argument("--s2p_list_updates", type=int, default=50, 534 | help="s2p listener supervised updates") 535 | parser.add_argument("--s2p_spk_updates", type=int, default=50, 536 | help="s2p speaker supervised updates") 537 | parser.add_argument("--s2p_batch_size", type=int, default=1000, 538 | help="s2p batch size") 539 | parser.add_argument("--pop_batch_size", type=int, default=1000, 540 | help="Pop Batch size") 541 | parser.add_argument("--rand_perc", type=float, default=0.75, 542 | help="rand perc") 543 | parser.add_argument("--sched_rand_frz", type=int, default=0.5, 544 | help="sched_rand_frz perc") 545 | parser.add_argument("--num_words", type=int, default=100, 546 | help="Number of words in the vocabulary") 547 | parser.add_argument("--seq_len", type=int, default=15, 548 | help="Max Sequence length of 
speaker utterance") 549 | parser.add_argument("--unk_perc", type=float, default=0.3, 550 | help="Max percentage of ") 551 | parser.add_argument("--max_iters", type=int, default=300, 552 | help="max training iters") 553 | parser.add_argument("--D_img", type=int, default=2048, 554 | help="ResNet feature dimensionality. Can't change this") 555 | parser.add_argument("--D_hid", type=int, default=512, 556 | help="RNN hidden state dimensionality") 557 | parser.add_argument("--D_hid_maml_listener", type=int, default=512, 558 | help="RNN hidden state dimensionality") 559 | parser.add_argument("--D_emb", type=int, default=256, 560 | help="Token embedding (word) dimensionality") 561 | parser.add_argument("--lr", type=float, default=2e-4, 562 | help="Learning rate") 563 | parser.add_argument("--dropout", type=float, default=0.0, 564 | help="Dropout probability") 565 | parser.add_argument("--temp", type=float, default=1.0, 566 | help="Gumbel temperature") 567 | parser.add_argument("--hard", type=bool, default=True, 568 | help="Hard Gumbel-Softmax Sampling.") 569 | parser.add_argument("--min_list_steps", type=int, default=2000, 570 | help="Min num of listener supervised steps") 571 | parser.add_argument("--min_spk_steps", type=int, default=1000, 572 | help="Min num of speaker supervised steps") 573 | parser.add_argument("--test_every", type=int, default=10, 574 | help="test interval") 575 | parser.add_argument("--seed_val_pct", type=float, default=0.1, 576 | help="% of seed samples used as validation for early stopping") 577 | parser.add_argument('--coco_path', type=str, default="./coco/", 578 | help="MSCOCO dir path") 579 | parser.add_argument("--save_dir", type=str, default="", 580 | help="Save directory.") 581 | parser.add_argument("--sample_temp", type=float, default=30, 582 | help="Temperature used for sampling difficult distractors") 583 | parser.add_argument("--sample_lang", action="store_true", default=False, 584 | help="Sample the languges of the data") 585 | parser.add_argument("--alpha", type=float, default=0.5, 586 | help="Parameter of Dirichlet distribution") 587 | parser.add_argument("--beam_size", type=int, default=10, 588 | help="Beam size") 589 | parser.add_argument("--listener_save_dir", type=str, default="", 590 | help="Listener save dir") 591 | parser.add_argument("--listener_vocab_size", type=int, default=100, 592 | help="Listener vocab size") 593 | parser.add_argument("--maml_args", type=str, default="", 594 | help="MAML arugments") 595 | parser.add_argument("--listener_template", type=str, default="", 596 | help="listener's save dir template") 597 | parser.add_argument("--maml_time_step", type=int, default=20) 598 | parser.add_argument("--maml_save_dir", type=str, default="") 599 | parser.add_argument("--evaluate_model_filename", type=str, default="") 600 | parser.add_argument("--use_mono_listeners", action="store_true") 601 | parser.add_argument("--image_sample_ratio", type=float, default=1.0) 602 | parser.add_argument("--fully_offline", action="store_true", default=False, 603 | help="freeze the listener in rollout") 604 | parser.add_argument("--binary_loss", action="store_true", default=False, 605 | help="Using binary loss") 606 | args = parser.parse_args() 607 | main(args) 608 | -------------------------------------------------------------------------------- /ibr_game/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import optim 7 | 
import torch.nn.functional as F 8 | 9 | import utils as U 10 | 11 | from beam_search import beam_search 12 | 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | 16 | 17 | def extract_top_level_dict(current_dict): 18 | """ 19 | Builds a graph dictionary from the passed depth_keys, value pair. Useful for dynamically passing external params 20 | :param depth_keys: A list of strings making up the name of a variable. Used to make a graph for that params tree. 21 | :param value: Param value 22 | :param key_exists: If none then assume new dict, else load existing dict and add new key->value pairs to it. 23 | :return: A dictionary graph of the params already added to the graph. 24 | """ 25 | output_dict = dict() 26 | for key in current_dict.keys(): 27 | name = key.replace("layer_dict.", "") 28 | name = name.replace("layer_dict.", "") 29 | name = name.replace("block_dict.", "") 30 | name = name.replace("module-", "") 31 | top_level = name.split(".")[0] 32 | sub_level = ".".join(name.split(".")[1:]) 33 | 34 | if top_level not in output_dict: 35 | if sub_level == "": 36 | output_dict[top_level] = current_dict[key] 37 | else: 38 | output_dict[top_level] = {sub_level: current_dict[key]} 39 | else: 40 | new_item = {key: value for key, 41 | value in output_dict[top_level].items()} 42 | new_item[sub_level] = current_dict[key] 43 | output_dict[top_level] = new_item 44 | 45 | # print(current_dict.keys(), output_dict.keys()) 46 | return output_dict 47 | 48 | 49 | class Beholder(nn.Module): 50 | def __init__(self, args): 51 | super(Beholder, self).__init__() 52 | self.img_to_hid = nn.Linear(args.D_img, args.D_hid) 53 | self.drop = nn.Dropout(p=args.dropout) 54 | 55 | def forward(self, img): 56 | h_img = img 57 | h_img = self.img_to_hid(h_img) 58 | h_img = self.drop(h_img) 59 | return h_img 60 | 61 | 62 | class AggregatedListener(nn.Module): 63 | def __init__(self, args): 64 | super(AggregatedListener, self).__init__() 65 | self.listeners = nn.ModuleList( 66 | [Listener(args) for _ in range(10)]) 67 | for params in self.parameters(): 68 | params.requires_grad = False 69 | self.log_psi = nn.Parameter(torch.zeros(10)) 70 | 71 | def maml_forward(self, x, params, training, backup_running_statistics, num_step): 72 | for name, param in params.items(): 73 | submodule_names, parameter_name = name.split( 74 | '.')[:-1], name.split('.')[-1] 75 | current_module = self 76 | for i in submodule_names: 77 | current_module = current_module.__getattr__(i) 78 | object.__setattr__(current_module, parameter_name, param[0]) 79 | if len(submodule_names) and submodule_names[0] == "rnn": 80 | self.rnn._flat_weights = [ 81 | (lambda wn: getattr(self.rnn, wn) if hasattr( 82 | self.rnn, wn) else None)(wn) 83 | for wn in self.rnn._flat_weights_names] 84 | self.rnn.flatten_parameters() 85 | 86 | with torch.no_grad(): 87 | logits_list = [self.listeners[i].maml_forward( 88 | x) for i in range(10)] 89 | logits = torch.stack(logits_list) 90 | logits = torch.log_softmax(logits, dim=-1) 91 | log_psi = torch.log_softmax(self.log_psi, dim=0).unsqueeze( 92 | 1).unsqueeze(2).expand_as(logits) 93 | return torch.logsumexp(log_psi + logits, dim=0) 94 | 95 | def sim(self, x, params, y): 96 | for name, param in params.items(): 97 | submodule_names, parameter_name = name.split( 98 | '.')[:-1], name.split('.')[-1] 99 | current_module = self 100 | for i in submodule_names: 101 | current_module = current_module.__getattr__(i) 102 | object.__setattr__(current_module, parameter_name, param[0]) 103 | if len(submodule_names) and 
submodule_names[0] == "rnn": 104 | self.rnn._flat_weights = [ 105 | (lambda wn: getattr(self.rnn, wn) if hasattr( 106 | self.rnn, wn) else None)(wn) 107 | for wn in self.rnn._flat_weights_names] 108 | self.rnn.flatten_parameters() 109 | 110 | with torch.no_grad(): 111 | logits_list = [self.listeners[i].maml_forward( 112 | x) for i in range(10)] 113 | logits = torch.stack(logits_list) 114 | logits = torch.softmax(logits, dim=-1) 115 | try: 116 | y_ = torch.softmax(y, dim=-1).unsqueeze(0).expand_as(logits) 117 | nearest = torch.argmin( 118 | torch.sum(torch.abs(logits - y_), dim=-1), dim=0) 119 | import ipdb 120 | ipdb.set_trace() 121 | except: 122 | pass 123 | 124 | def zero_grad(self, params=None): 125 | if params is None: 126 | for param in self.parameters(): 127 | if param.requires_grad == True: 128 | if param.grad is not None: 129 | if torch.sum(param.grad) > 0: 130 | print(param.grad) 131 | param.grad.zero_() 132 | else: 133 | for name, param in params.items(): 134 | if param.requires_grad == True: 135 | if param.grad is not None: 136 | if torch.sum(param.grad) > 0: 137 | print(param.grad) 138 | param.grad.zero_() 139 | params[name].grad = None 140 | 141 | def restore_backup_stats(self): 142 | """ 143 | Reset stored batch statistics from the stored backup. 144 | """ 145 | pass 146 | 147 | 148 | class Listener(nn.Module): 149 | 150 | def __init__(self, args, beholder=None): 151 | super(Listener, self).__init__() 152 | self.rnn = nn.GRU(args.D_emb, args.D_hid, 1, batch_first=True) 153 | self.emb = nn.Linear(args.vocab_size, args.D_emb) 154 | self.hid_to_hid = nn.Linear(args.D_hid, args.D_hid) 155 | self.drop = nn.Dropout(p=args.dropout) 156 | self.D_hid = args.D_hid 157 | self.D_emb = args.D_emb 158 | self.vocab_size = args.vocab_size 159 | self.i2w = args.i2w 160 | self.w2i = args.w2i 161 | if beholder is None: 162 | self.beholder = Beholder(args) 163 | else: 164 | self.beholder = beholder 165 | # self.loss_fn = nn.CrossEntropyLoss(reduction='none').to(device=args.device) 166 | self.optimizer = optim.Adam(self.parameters(), lr=args.lr) 167 | 168 | def forward(self, spk_msg, spk_msg_lens): 169 | batch_size = spk_msg.shape[0] 170 | 171 | h_0 = torch.zeros(1, batch_size, self.D_hid, device=device) 172 | 173 | if spk_msg.type() in ['torch.FloatTensor', 'torch.cuda.FloatTensor']: 174 | spk_msg_emb = self.emb(spk_msg.float()) 175 | elif spk_msg.type() in ['torch.LongTensor', 'torch.cuda.LongTensor']: 176 | spk_msg[spk_msg > self.vocab_size] = self.w2i[""] 177 | spk_msg_emb = F.embedding( 178 | spk_msg.clone(), self.emb.weight.transpose(0, 1)) 179 | spk_msg_emb += self.emb.bias 180 | else: 181 | print(spk_msg.type()) 182 | raise NotImplementedError 183 | spk_msg_emb = self.drop(spk_msg_emb) 184 | 185 | try: 186 | pack = nn.utils.rnn.pack_padded_sequence( 187 | spk_msg_emb, spk_msg_lens, batch_first=True, enforce_sorted=False) 188 | except: 189 | import pdb 190 | pdb.set_trace() 191 | 192 | self.rnn.flatten_parameters() 193 | _, h_n = self.rnn(pack, h_0) 194 | h_n = h_n[-1:, :, :] 195 | out = h_n.transpose(0, 1).view(batch_size, self.D_hid) 196 | out = self.hid_to_hid(out) 197 | return out 198 | 199 | def maml_forward(self, x, params={}, training=None, 200 | backup_running_statistics=None, num_step=None): 201 | imgs, caps, cap_lens = x 202 | 203 | for name, param in params.items(): 204 | submodule_names, parameter_name = name.split( 205 | '.')[:-1], name.split('.')[-1] 206 | current_module = self 207 | for i in submodule_names: 208 | current_module = current_module.__getattr__(i) 209 | 
object.__setattr__(current_module, parameter_name, param[0]) 210 | if submodule_names[0] == "rnn": 211 | self.rnn._flat_weights = [ 212 | (lambda wn: getattr(self.rnn, wn) if hasattr( 213 | self.rnn, wn) else None)(wn) 214 | for wn in self.rnn._flat_weights_names] 215 | self.rnn.flatten_parameters() 216 | 217 | h_pred = self.forward(caps, cap_lens.cpu()) 218 | h_pred = h_pred.unsqueeze(1).repeat(1, imgs.size()[1], 1) 219 | 220 | h_img = self.beholder(imgs) 221 | 222 | logits = 1 / torch.mean(torch.pow(h_pred - h_img, 2), 223 | 2).view(-1, imgs.size()[1]) 224 | 225 | return logits 226 | 227 | def get_loss_acc(self, image, distractor_images, spk_msg, spk_msg_lens, 228 | reduction='mean', shuffle=True, output_pred=False, 229 | output_logits=False): 230 | batch_size = spk_msg.shape[0] 231 | 232 | if reduction != 'none': 233 | spk_msg_lens, sorted_indices = torch.sort( 234 | spk_msg_lens, descending=True) 235 | spk_msg = spk_msg.index_select(0, sorted_indices) 236 | image = image.index_select(0, sorted_indices) 237 | 238 | h_pred = self.forward(spk_msg, spk_msg_lens.cpu()) 239 | h_pred = h_pred.unsqueeze(1).repeat(1, 1 + len(distractor_images), 1) 240 | 241 | all_images = len(distractor_images) + 1 242 | img_idx = [list(range(all_images)) for _ in range(batch_size)] 243 | for c in img_idx: 244 | if shuffle: 245 | random.shuffle(c) 246 | 247 | target_idx = torch.tensor( 248 | np.argmax(np.array(img_idx) == 0, -1), dtype=torch.long, device=device) 249 | 250 | h_img = [self.beholder(image)] + [self.beholder(img) 251 | for img in distractor_images] 252 | h_img = torch.stack(h_img, dim=0).permute(1, 0, 2) 253 | for i in range(batch_size): 254 | h_img[i] = h_img[i, img_idx[i], :] 255 | 256 | logits = 1 / torch.mean(torch.pow(h_pred - h_img, 2), 257 | 2).view(-1, 1 + len(distractor_images)) 258 | 259 | pred_outs = torch.argmax(logits, dim=-1).cpu().numpy() 260 | batch_inds = target_idx.cpu().numpy() 261 | 262 | acc = np.mean(np.equal(batch_inds, pred_outs)) 263 | loss = F.cross_entropy(logits, target_idx, reduction=reduction) 264 | if not output_pred: 265 | if not output_logits: 266 | return loss, acc 267 | else: 268 | return loss, acc, logits 269 | else: 270 | if not output_logits: 271 | return loss, acc, pred_outs 272 | else: 273 | return loss, acc, pred_outs, logits 274 | 275 | def predict(self, images, spk_msg, spk_msg_lens, output_logp=False): 276 | h_pred = self.forward(spk_msg, spk_msg_lens.cpu()) 277 | h_pred = h_pred.unsqueeze(1).repeat(1, images.size()[1], 1) 278 | 279 | h_img = self.beholder(images) 280 | 281 | logits = 1 / torch.mean(torch.pow(h_pred - h_img, 2), 282 | 2).view(-1, images.size()[1]) 283 | 284 | pred_outs = torch.argmax(logits, dim=-1) 285 | if output_logp: 286 | return pred_outs, torch.log_softmax(logits, dim=-1) 287 | else: 288 | return pred_outs 289 | 290 | def test(self, image, distractor_images, spk_msg, spk_msg_lens): 291 | self.eval() 292 | loss, acc = self.get_loss_acc( 293 | image, distractor_images, spk_msg, spk_msg_lens) 294 | return loss.detach().cpu().numpy(), acc 295 | 296 | def update(self, image, distractor_images, spk_msg, spk_msg_lens): 297 | self.train() 298 | loss, acc = self.get_loss_acc( 299 | image, distractor_images, spk_msg, spk_msg_lens) 300 | self.optimizer.zero_grad() 301 | loss.backward() 302 | return loss, acc 303 | 304 | def zero_grad(self, params=None): 305 | if params is None: 306 | for param in self.parameters(): 307 | if param.requires_grad == True: 308 | if param.grad is not None: 309 | if torch.sum(param.grad) > 0: 310 | print(param.grad) 
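# Clear the stale gradient in place; the print above surfaces any non-zero
# gradients that are about to be discarded.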
311 | param.grad.zero_() 312 | else: 313 | for name, param in params.items(): 314 | if param.requires_grad == True: 315 | if param.grad is not None: 316 | if torch.sum(param.grad) > 0: 317 | print(param.grad) 318 | param.grad.zero_() 319 | params[name].grad = None 320 | 321 | def restore_backup_stats(self): 322 | """ 323 | Reset stored batch statistics from the stored backup. 324 | """ 325 | pass 326 | 327 | 328 | class Speaker(nn.Module): 329 | 330 | def __init__(self, args, beholder): 331 | super(Speaker, self).__init__() 332 | self.rnn = nn.GRU(args.D_emb, args.D_hid, 1, batch_first=True) 333 | self.emb = nn.Embedding(args.vocab_size, args.D_emb, padding_idx=0) 334 | self.hid_to_voc = nn.Linear(args.D_hid, args.vocab_size) 335 | self.D_emb = args.D_emb 336 | self.D_hid = args.D_hid 337 | self.drop = nn.Dropout(p=args.dropout) 338 | self.vocab_size = args.vocab_size 339 | self.i2w = args.i2w 340 | self.w2i = args.w2i 341 | self.temp = args.temp 342 | self.hard = args.hard 343 | self.seq_len = args.seq_len 344 | self.beholder = beholder 345 | self.loss_fn = nn.CrossEntropyLoss(reduce=False) 346 | self.optimizer = optim.Adam(self.parameters(), lr=args.lr) 347 | 348 | def forward(self, image): 349 | batch_size, gen_idx, done, hid, input, msg_lens \ 350 | = self.prepare_input(image) 351 | 352 | for idx in range(self.seq_len): 353 | # input = F.relu(input) 354 | output, hid = self.rnn_step(hid, input) 355 | 356 | top1, topi = U.gumbel_softmax(output, self.temp, self.hard) 357 | gen_idx.append(top1) 358 | 359 | for ii in range(batch_size): 360 | if topi[ii] == self.w2i[""]: 361 | done[ii] = True 362 | msg_lens[ii] = idx + 1 363 | if np.array_equal(done, np.array([True for _ in range(batch_size)])): 364 | break 365 | 366 | input = self.emb(topi) 367 | 368 | gen_idx = torch.stack(gen_idx).permute(1, 0, 2) 369 | msg_lens = torch.tensor(msg_lens, dtype=torch.long, device=device) 370 | return gen_idx, msg_lens 371 | 372 | def rnn_step(self, hid, input): 373 | self.rnn.flatten_parameters() 374 | output, hid = self.rnn(input, hid) 375 | output = output.view(-1, self.D_hid) 376 | output = self.hid_to_voc(output) 377 | output = output.view(-1, self.vocab_size) 378 | return output, hid 379 | 380 | def prepare_input(self, image): 381 | batch_size = image.shape[0] 382 | h_img = self.beholder(image).detach() 383 | start = [self.w2i[""] for _ in range(batch_size)] 384 | gen_idx = [] 385 | done = np.array([False for _ in range(batch_size)]) 386 | h_img = h_img.unsqueeze(0).view(1, -1, self.D_hid).repeat(1, 1, 1) 387 | hid = h_img 388 | ft = torch.tensor(start, dtype=torch.long, 389 | device=device).view(-1).unsqueeze(1) 390 | input = self.emb(ft) 391 | msg_lens = [self.seq_len for _ in range(batch_size)] 392 | return batch_size, gen_idx, done, hid, input, msg_lens 393 | 394 | def step(self, hid, inputs): 395 | self.rnn.flatten_parameters() 396 | # output, hid = self.rnn( 397 | # F.relu(self.emb(inputs.unsqueeze(1))), hid.unsqueeze(0)) 398 | output, hid = self.rnn( 399 | self.emb(inputs.unsqueeze(1)), hid.unsqueeze(0)) 400 | 401 | output = output.view(-1, self.D_hid) 402 | output = self.hid_to_voc(output) 403 | output = F.log_softmax(output.view(-1, self.vocab_size)) 404 | 405 | return output, hid 406 | 407 | def batchify(self, image, batch_size): 408 | h_img = self.beholder(image).detach().view(-1, self.D_hid) 409 | for i in range(0, image.shape[0], batch_size): 410 | yield h_img[i: i + batch_size] 411 | 412 | def bs(self, image, beam_size, max_len, choose_max=True): 413 | batch_size = image.shape[0] 414 | 
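# Beam-search decode: `self.step` scores one token per call, `self.batchify`
# streams the image embeddings (the GRU's initial hidden state) in chunks of
# batch_size // beam_size, and the two `self.w2i` lookups supply the start-
# and end-of-sequence token ids for the search.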
415 | with torch.no_grad(): 416 | generated_sents, _ = \ 417 | beam_search(self.step, beam_size, max_len, 418 | self.w2i[""], self.w2i[""], 419 | self.batchify(image, batch_size//beam_size)) 420 | 421 | if choose_max: 422 | generated_sents = generated_sents[:, 0, :] 423 | msg_lens = [max_len for i in range(batch_size)] 424 | else: 425 | generated_sents = generated_sents.view(batch_size * beam_size, -1) 426 | msg_lens = [max_len for i in range(batch_size * beam_size)] 427 | 428 | print("Beams generated, now calculating lengths") 429 | generated_sents_cpu = generated_sents.cpu() 430 | for idx, i in enumerate(generated_sents_cpu): 431 | for jdx, j in enumerate(i): 432 | if self.w2i[""] == j: 433 | msg_lens[idx] = jdx 434 | break 435 | 436 | msg_lens = torch.tensor(msg_lens, dtype=torch.long, 437 | device=generated_sents.get_device()) 438 | 439 | if not choose_max: 440 | generated_sents = generated_sents.view(batch_size, beam_size, -1) 441 | msg_lens = msg_lens.view(batch_size, beam_size) 442 | 443 | return generated_sents, msg_lens 444 | 445 | def get_loss(self, image, caps, caps_lens, word_loss=False): 446 | batch_size = caps.shape[0] 447 | mask = (torch.arange(self.seq_len, device=device).expand(batch_size, self.seq_len) < caps_lens.unsqueeze( 448 | 1)).float() 449 | 450 | caps_in = caps[:, :-1] 451 | caps_out = caps[:, 1:] 452 | 453 | h_img = self.beholder(image).detach() 454 | h_img = h_img.view(1, batch_size, self.D_hid).repeat(1, 1, 1) 455 | 456 | caps_in_emb = self.emb(caps_in) 457 | caps_in_emb = self.drop(caps_in_emb) 458 | 459 | self.rnn.flatten_parameters() 460 | output, _ = self.rnn(caps_in_emb, h_img) 461 | logits = self.hid_to_voc(output) 462 | 463 | loss = 0 464 | for j in range(logits.size(1)): 465 | flat_score = logits[:, j, :] 466 | flat_mask = mask[:, j] 467 | flat_tgt = caps_out[:, j] 468 | nll = self.loss_fn(flat_score, flat_tgt) 469 | loss += (flat_mask * nll).sum() 470 | 471 | if word_loss: 472 | loss /= mask.sum() 473 | return loss 474 | 475 | def test(self, image, caps, caps_lens): 476 | self.eval() 477 | loss = self.get_loss(image, caps, caps_lens) 478 | return loss.detach().cpu().numpy() 479 | 480 | def update(self, image, caps, caps_lens): 481 | self.train() 482 | loss = self.get_loss(image, caps, caps_lens) 483 | self.optimizer.zero_grad() 484 | loss.backward() 485 | return loss 486 | 487 | 488 | class ReinforceSpeaker(Speaker): 489 | def forward(self, image): 490 | batch_size, gen_idx, done, hid, input, msg_lens \ 491 | = self.prepare_input(image) 492 | 493 | running_logprob = torch.zeros(batch_size) 494 | for idx in range(self.seq_len): 495 | # input = F.relu(input) 496 | output, hid = self.rnn_step(hid, input) 497 | 498 | next_token, log_prob = U.sample(output, self.temp) 499 | gen_idx.append(next_token) 500 | 501 | for ii in range(batch_size): 502 | if next_token[ii] == self.w2i[""]: 503 | done[ii] = True 504 | msg_lens[ii] = idx + 1 505 | else: 506 | running_logprob[ii] += log_prob[ii] 507 | if np.array_equal(done, 508 | np.array([True for _ in range(batch_size)])): 509 | break 510 | 511 | input = self.emb(next_token) 512 | 513 | gen_idx = torch.stack(gen_idx).permute(1, 0, 2) 514 | msg_lens = torch.tensor(msg_lens, dtype=torch.long, device=device) 515 | return gen_idx, msg_lens, log_prob 516 | 517 | 518 | class ReinforceTrainer(nn.Module): 519 | def __init__(self, args, speaker: nn.Module, listener: nn.Module): 520 | self.speaker = speaker 521 | self.listener = speaker 522 | self.i2w = args.i2w 523 | self.w2i = args.w2i 524 | self.D_hid = args.D_hid 525 | 
self.loss_fn = nn.CrossEntropyLoss() 526 | self.optimizer = optim.Adam(self.parameters(), lr=args.lr) 527 | 528 | def forward(self, image: torch.Tensor, distractor_images: torch.Tensor): 529 | batch_size = image.shape[0] 530 | 531 | msg, msg_lens, log_prob = self.speaker(image) 532 | msg = msg.detach() 533 | 534 | h_pred = self.listener.forward(msg, msg_lens.cpu()) 535 | h_pred = h_pred.unsqueeze(1).repeat(1, 1 + len(distractor_images), 1) 536 | 537 | all_images = len(distractor_images) + 1 538 | img_idx = [list(range(all_images)) for _ in range(batch_size)] 539 | for c in img_idx: 540 | random.shuffle(c) 541 | 542 | target_idx = torch.tensor( 543 | np.argmax(np.array(img_idx) == 0, -1), dtype=torch.long, 544 | device=device) 545 | 546 | h_img = [self.beholder(image)] + [self.beholder(img) 547 | for img in distractor_images] 548 | h_img = torch.stack(h_img, dim=0).permute(1, 0, 2) 549 | for i in range(batch_size): 550 | h_img[i] = h_img[i, img_idx[i], :] 551 | 552 | logits = 1 / torch.mean(torch.pow(h_pred - h_img, 2), 553 | 2).view(-1, 1 + len(distractor_images)) 554 | pred_outs = torch.argmax(logits, dim=-1).cpu().numpy() 555 | batch_inds = target_idx.cpu().numpy() 556 | 557 | acc = np.mean(np.equal(batch_inds, pred_outs)) 558 | reward = np.equal(batch_inds, pred_outs) - acc 559 | loss = -torch.mean(log_prob * reward) 560 | return acc, loss 561 | 562 | def update(self, image, distractor_images): 563 | self.train() 564 | acc, loss = self.get_loss_acc(image, distractor_images) 565 | self.optimizer.zero_grad() 566 | loss.backward() 567 | self.optimizer.step() 568 | return acc, loss 569 | 570 | 571 | class SpeakerListener(nn.Module): 572 | 573 | def __init__(self, args): 574 | super(SpeakerListener, self).__init__() 575 | self.beholder = Beholder(args) 576 | self.speaker = Speaker(args, self.beholder) 577 | self.listener = Listener(args, self.beholder) 578 | self.i2w = args.i2w 579 | self.w2i = args.w2i 580 | self.D_hid = args.D_hid 581 | self.loss_fn = nn.CrossEntropyLoss() 582 | self.optimizer = optim.Adam(self.parameters(), lr=args.lr) 583 | 584 | def forward(self, image, spk_list=-1): 585 | msg, msg_lens = self.speaker.forward(image) 586 | 587 | if spk_list == 0: 588 | msg = msg.detach() 589 | 590 | msg_lens, sorted_indices = torch.sort(msg_lens, descending=True) 591 | msg = msg.index_select(0, sorted_indices) 592 | image = image.index_select(0, sorted_indices) 593 | 594 | h_pred = self.listener.forward(msg, msg_lens.cpu()) 595 | return h_pred, image 596 | 597 | def get_loss_acc(self, image, distractor_images, spk_list=-1): 598 | batch_size = image.shape[0] 599 | 600 | h_pred, image = self.forward(image, spk_list) 601 | h_pred = h_pred.unsqueeze(1).repeat(1, 1 + len(distractor_images), 1) 602 | 603 | all_images = len(distractor_images) + 1 604 | img_idx = [list(range(all_images)) for _ in range(batch_size)] 605 | for c in img_idx: 606 | random.shuffle(c) 607 | 608 | target_idx = torch.tensor( 609 | np.argmax(np.array(img_idx) == 0, -1), dtype=torch.long, device=device) 610 | 611 | h_img = [self.beholder(image)] + [self.beholder(img) 612 | for img in distractor_images] 613 | h_img = torch.stack(h_img, dim=0).permute(1, 0, 2) 614 | for i in range(batch_size): 615 | h_img[i] = h_img[i, img_idx[i], :] 616 | 617 | logits = 1 / torch.mean(torch.pow(h_pred - h_img, 2), 618 | 2).view(-1, 1 + len(distractor_images)) 619 | pred_outs = torch.argmax(logits, dim=-1).cpu().numpy() 620 | batch_inds = target_idx.cpu().numpy() 621 | 622 | acc = np.mean(np.equal(batch_inds, pred_outs)) 623 | loss = 
self.loss_fn(logits, target_idx) 624 | return acc, loss 625 | 626 | def update(self, image, distractor_images): 627 | self.eval() 628 | acc, loss = self.get_loss_acc(image, distractor_images) 629 | return acc, loss 630 | 631 | def update(self, image, distractor_images, spk_list=-1): 632 | self.train() 633 | acc, loss = self.get_loss_acc(image, distractor_images, spk_list) 634 | self.optimizer.zero_grad() 635 | loss.backward() 636 | return acc, loss 637 | -------------------------------------------------------------------------------- /ibr_game/referential_game.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | from typing import List 5 | import copy 6 | 7 | import numpy as np 8 | 9 | import torch 10 | from torch import nn 11 | import torch.nn.functional as F 12 | 13 | from models import Listener, SpeakerListener 14 | from utils import truncate_dicts, get_distractor_images 15 | 16 | 17 | def load_model(args) -> nn.Module: 18 | model = SpeakerListener(args).to(device=args.device) 19 | model.speaker.load_state_dict(torch.load(os.path.join(args.save_dir, 20 | 'spk_params', f"pop{args.seed}.pt"))) 21 | model.listener.load_state_dict(torch.load(os.path.join(args.save_dir, 22 | 'list_params', f"pop{args.seed}.pt"))) 23 | model.beholder.load_state_dict(torch.load(os.path.join(args.save_dir, 24 | 'bhd_params', f"pop{args.seed}.pt"))) 25 | return model 26 | 27 | 28 | def generate_beams(spk: nn.Module, images: torch.Tensor, args) -> torch.Tensor: 29 | beam_list, beam_len_list = [], [] 30 | # msg, msg_len = spk.bs(images[i * args.s2p_batch_size: 31 | # min((i+1) * args.s2p_batch_size, 32 | # images.size()[0])], 33 | # args.beam_size, args.seq_len, False) 34 | msg, msg_len = spk.bs(images, 35 | args.beam_size, args.seq_len, False) 36 | beam_list.append(msg) 37 | beam_len_list.append(msg_len) 38 | return torch.cat(beam_list, dim=0), torch.cat(beam_len_list, dim=0) 39 | 40 | 41 | def choose_captions_from_beams(listener: Listener, captions: torch.Tensor, 42 | cap_len: torch.Tensor, image: torch.Tensor, 43 | distractor_images: List[torch.Tensor]) \ 44 | -> torch.Tensor: 45 | """Choose captions from beams with a listener. 46 | 47 | Args: 48 | listener: The listener model. 49 | captions: The candidate captions (B, b, L). 50 | cap_len: Length of candidate captions (B, b). 51 | image: The target images (B, D_img). 52 | distractor_images: The distractor images (B, n_dis, D_img). 53 | 54 | B is the batch size, b being beam size, L sentence length, 55 | D_img image embedding size, and n_dis being the number of distractors. 
56 | 57 | Returns: 58 | indices of chosen captions 59 | 60 | """ 61 | B, b, L = captions.size() 62 | _, D_img = image.size() 63 | with torch.no_grad(): 64 | log_prob_list = [] 65 | batch_size = B // b 66 | for i in range(0, B, batch_size): 67 | batch_slice = slice(i, min(i + batch_size, B)) 68 | expanded_image = image[batch_slice].unsqueeze(1).expand(-1, b, -1)\ 69 | .reshape(-1, D_img) 70 | expanded_distractor_images = [i[batch_slice] 71 | .unsqueeze(1).expand(-1, b, -1) 72 | .reshape(-1, D_img) 73 | for i in distractor_images] 74 | log_prob = - listener.get_loss_acc(expanded_image, 75 | expanded_distractor_images, 76 | captions[batch_slice].view(-1, L), 77 | cap_len[batch_slice].view(-1), 78 | reduction='none')[0] 79 | log_prob_list.append(log_prob) 80 | log_prob = torch.cat(log_prob_list, dim=0).view(B, b) 81 | return torch.max(log_prob, dim=-1)[1] 82 | 83 | 84 | def get_caption_candidates(test_image, args): 85 | model = load_model(args) 86 | speaker_model = model.speaker 87 | test_caption_beams, test_cap_len = generate_beams( 88 | speaker_model, test_image, args) 89 | return test_caption_beams, test_cap_len 90 | 91 | 92 | def same_partner_baseline(test_image, test_distractor_images, 93 | test_caption_beams, test_cap_len, args): 94 | 95 | model = load_model(args) 96 | training_listener = model.listener 97 | chosen_indices = choose_captions_from_beams(training_listener, 98 | test_caption_beams, 99 | test_cap_len, test_image, 100 | test_distractor_images) 101 | test_caption = test_caption_beams[range(chosen_indices.size()[0]), 102 | chosen_indices] 103 | test_cap_len = test_cap_len[range(chosen_indices.size()[0]), 104 | chosen_indices] 105 | _, acc = training_listener.get_loss_acc(test_image, test_distractor_images, 106 | test_caption, test_cap_len) 107 | return acc 108 | 109 | 110 | def different_partner_baseline(test_image, test_distractor_images, 111 | test_caption_beams, test_cap_len, args, 112 | new_partner_args): 113 | model = load_model(args) 114 | training_listener = model.listener 115 | chosen_indices = choose_captions_from_beams(training_listener, 116 | test_caption_beams, 117 | test_cap_len, test_image, 118 | test_distractor_images) 119 | test_caption = test_caption_beams[range(chosen_indices.size()[0]), 120 | chosen_indices] 121 | test_cap_len = test_cap_len[range(chosen_indices.size()[0]), 122 | chosen_indices] 123 | 124 | model = load_model(new_partner_args) 125 | listener = model.listener 126 | _, acc = listener.get_loss_acc(test_image, test_distractor_images, 127 | test_caption, test_cap_len) 128 | return acc 129 | 130 | 131 | def overlap_with_training_listener(test_image, test_distractor_images, 132 | test_caption_beams, test_cap_len, args, 133 | new_partner_args): 134 | model = load_model(args) 135 | training_listener = model.listener 136 | chosen_indices = choose_captions_from_beams(training_listener, 137 | test_caption_beams, 138 | test_cap_len, test_image, 139 | test_distractor_images) 140 | test_caption = test_caption_beams[range(chosen_indices.size()[0]), 141 | chosen_indices] 142 | test_cap_len = test_cap_len[range(chosen_indices.size()[0]), 143 | chosen_indices] 144 | 145 | _, _, preds = training_listener.get_loss_acc(test_image, 146 | test_distractor_images, 147 | test_caption, test_cap_len, 148 | shuffle=False, output_pred=True) 149 | model = load_model(new_partner_args) 150 | listener = model.listener 151 | _, _, preds_new = listener.get_loss_acc(test_image, test_distractor_images, 152 | test_caption, test_cap_len, shuffle=False, 153 | output_pred=True) 154 | 
return np.mean(np.equal(preds, preds_new)) 155 | 156 | 157 | def different_partner_upperbound(test_image, test_distractor_images, 158 | test_caption_beams, test_cap_len, args, 159 | new_partner_args): 160 | model = load_model(new_partner_args) 161 | training_listener = model.listener 162 | chosen_indices = choose_captions_from_beams(training_listener, 163 | test_caption_beams, 164 | test_cap_len, test_image, 165 | test_distractor_images) 166 | test_caption = test_caption_beams[range(chosen_indices.size()[0]), 167 | chosen_indices] 168 | test_cap_len = test_cap_len[range(chosen_indices.size()[0]), 169 | chosen_indices] 170 | 171 | model = load_model(new_partner_args) 172 | listener = model.listener 173 | _, acc = listener.get_loss_acc(test_image, test_distractor_images, 174 | test_caption, test_cap_len) 175 | return acc 176 | 177 | 178 | def different_partner_upperbound_confidence( 179 | test_image, test_distractor_images, test_caption_beams, test_cap_len, args, 180 | new_partner_args): 181 | model = load_model(new_partner_args) 182 | training_listener = model.listener 183 | chosen_indices = choose_captions_from_beams(training_listener, 184 | test_caption_beams, 185 | test_cap_len, test_image, 186 | test_distractor_images) 187 | test_caption = test_caption_beams[range(chosen_indices.size()[0]), 188 | chosen_indices] 189 | test_cap_len = test_cap_len[range(chosen_indices.size()[0]), 190 | chosen_indices] 191 | 192 | model = load_model(new_partner_args) 193 | listener = model.listener 194 | _, _, logits = listener.get_loss_acc( 195 | test_image, test_distractor_images, test_caption, test_cap_len, 196 | output_logits=True) 197 | confidence = torch.mean(torch.max(logits, dim=-1)) 198 | return confidence 199 | 200 | 201 | def different_partner_nonchosen_baseline(test_image, test_distractor_images, 202 | test_caption_beams, test_cap_len, args, 203 | new_partner_args): 204 | test_caption = test_caption_beams[:, 0] 205 | test_cap_len = test_cap_len[:, 0] 206 | 207 | model = load_model(new_partner_args) 208 | listener = model.listener 209 | _, acc = listener.get_loss_acc(test_image, test_distractor_images, 210 | test_caption, test_cap_len) 211 | return acc 212 | 213 | 214 | def main(args): 215 | random.seed(args.seed) 216 | np.random.seed(args.seed) 217 | torch.manual_seed(args.seed) 218 | 219 | args.cuda = torch.cuda.is_available() 220 | args.device = torch.device('cuda' if args.cuda else 'cpu') 221 | print("OPTS:\n", vars(args)) 222 | feat_path = args.coco_path 223 | data_path = args.coco_path 224 | 225 | (_, _, test_images) \ 226 | = [torch.load('{}/feats/{}'.format(feat_path, x)) for x in 227 | "train_feats valid_feats test_feats".split()] 228 | test_images = test_images.to(device=args.device) 229 | 230 | (w2i, i2w) = [torch.load(data_path + 'dics/{}'.format(x)) 231 | for x in "w2i i2w".split()] 232 | w2i, i2w = truncate_dicts(w2i, i2w, args.num_words) 233 | args.vocab_size = len(w2i) 234 | args.w2i = w2i 235 | args.i2w = i2w 236 | 237 | test_caption_beams, test_cap_len = get_caption_candidates( 238 | test_images, args) 239 | 240 | distractors_images = get_distractor_images(test_images, args) 241 | print( 242 | f"Same partner baseline score: {same_partner_baseline(test_images, distractors_images, test_caption_beams, test_cap_len, args)}") 243 | 244 | new_args = copy.deepcopy(args) 245 | new_args.save_dir = new_args.listener_save_dir 246 | new_args.vocab_size = new_args.listener_vocab_size 247 | new_args.seed = new_args.listener_seed 248 | print( 249 | f"Different partner baseline score: 
{different_partner_baseline(test_images, distractors_images, test_caption_beams, test_cap_len, args, new_args)}") 250 | 251 | print( 252 | f"Different partner non-chosen baseline score: {different_partner_nonchosen_baseline(test_images, distractors_images, test_caption_beams, test_cap_len, args, new_args)}") 253 | 254 | print( 255 | f"Different partner upperbound score: {different_partner_upperbound(test_images, distractors_images, test_caption_beams, test_cap_len, args, new_args)}") 256 | 257 | print( 258 | f"Different partner overlap: {overlap_with_training_listener(test_images, distractors_images, test_caption_beams, test_cap_len, args, new_args)}") 259 | 260 | print( 261 | f"Different partner confidence score: {different_partner_upperbound_confidence(test_images, distractors_images, test_caption_beams, test_cap_len, args, new_args)}") 262 | 263 | if __name__ == "__main__": 264 | parser = argparse.ArgumentParser() 265 | parser.add_argument('--seed', type=int, default=0, 266 | help='seed') 267 | parser.add_argument("--num_seed_examples", type=int, default=1000, 268 | help="Number of seed examples") 269 | parser.add_argument("--num_distrs", type=int, default=9, 270 | help="Number of distractors") 271 | parser.add_argument("--s2p_schedule", type=str, default="sched", 272 | help="s2p schedule") 273 | parser.add_argument("--s2p_selfplay_updates", type=int, default=50, 274 | help="s2p self-play updates") 275 | parser.add_argument("--s2p_list_updates", type=int, default=50, 276 | help="s2p listener supervised updates") 277 | parser.add_argument("--s2p_spk_updates", type=int, default=50, 278 | help="s2p speaker supervised updates") 279 | parser.add_argument("--s2p_batch_size", type=int, default=1000, 280 | help="s2p batch size") 281 | parser.add_argument("--pop_batch_size", type=int, default=1000, 282 | help="Pop Batch size") 283 | parser.add_argument("--rand_perc", type=float, default=0.75, 284 | help="rand perc") 285 | parser.add_argument("--sched_rand_frz", type=int, default=0.5, 286 | help="sched_rand_frz perc") 287 | parser.add_argument("--num_words", type=int, default=100, 288 | help="Number of words in the vocabulary") 289 | parser.add_argument("--seq_len", type=int, default=15, 290 | help="Max Sequence length of speaker utterance") 291 | parser.add_argument("--unk_perc", type=float, default=0.3, 292 | help="Max percentage of ") 293 | parser.add_argument("--max_iters", type=int, default=300, 294 | help="max training iters") 295 | parser.add_argument("--D_img", type=int, default=2048, 296 | help="ResNet feature dimensionality. 
Can't change this") 297 | parser.add_argument("--D_hid", type=int, default=512, 298 | help="RNN hidden state dimensionality") 299 | parser.add_argument("--D_emb", type=int, default=256, 300 | help="Token embedding (word) dimensionality") 301 | parser.add_argument("--lr", type=float, default=2e-4, 302 | help="Learning rate") 303 | parser.add_argument("--dropout", type=float, default=0.0, 304 | help="Dropout probability") 305 | parser.add_argument("--temp", type=float, default=1.0, 306 | help="Gumbel temperature") 307 | parser.add_argument("--hard", type=bool, default=True, 308 | help="Hard Gumbel-Softmax Sampling.") 309 | parser.add_argument("--min_list_steps", type=int, default=2000, 310 | help="Min num of listener supervised steps") 311 | parser.add_argument("--min_spk_steps", type=int, default=1000, 312 | help="Min num of speaker supervised steps") 313 | parser.add_argument("--test_every", type=int, default=10, 314 | help="test interval") 315 | parser.add_argument("--seed_val_pct", type=float, default=0.1, 316 | help="% of seed samples used as validation for early stopping") 317 | parser.add_argument('--coco_path', type=str, default="./coco/", 318 | help="MSCOCO dir path") 319 | parser.add_argument("--save_dir", type=str, default="", 320 | help="Save directory.") 321 | parser.add_argument("--sample_temp", type=float, default=30, 322 | help="Temperature used for sampling difficult distractors") 323 | parser.add_argument("--sample_lang", action="store_true", default=False, 324 | help="Sample the languges of the data") 325 | parser.add_argument("--alpha", type=float, default=0.5, 326 | help="Parameter of Dirichlet distribution") 327 | parser.add_argument("--beam_size", type=int, default=10, 328 | help="Beam size") 329 | parser.add_argument("--listener_save_dir", type=str, default="", 330 | help="Listener save dir") 331 | parser.add_argument("--listener_vocab_size", type=int, default=100, 332 | help="Listener vocab size") 333 | parser.add_argument("--listener_seed", type=int, default=0, 334 | help="Listener seed") 335 | 336 | args = parser.parse_args() 337 | main(args) 338 | -------------------------------------------------------------------------------- /ibr_game/translate_caption.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from googletrans import Translator 3 | from varname import nameof 4 | 5 | import pickle 6 | import tqdm 7 | 8 | import argparse 9 | 10 | from multiprocessing import Pool 11 | from functools import reduce 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--lang", type=str) 15 | parser.add_argument("--split", type=str) 16 | parser.add_argument("--threads", type=int) 17 | 18 | args = parser.parse_args() 19 | 20 | translator = Translator() 21 | 22 | train_cap = torch.load("coco/labs/train_org") 23 | valid_cap = torch.load("coco/labs/valid_org") 24 | test_cap = torch.load("coco/labs/test_org") 25 | 26 | i2w = torch.load("coco/dics/i2w") 27 | w2i = torch.load("coco/dics/w2i") 28 | 29 | cnt = 0 30 | 31 | languages = ["de", "lt", "zh-cn", "it", "fr", "pt", "es", "ja", "el"] 32 | splits = ["train_cap", "valid_cap", "test_cap"] 33 | 34 | assert args.lang in languages and args.split in splits 35 | language = args.lang 36 | split = args.split 37 | print(f"Now start translating {language}") 38 | collection = locals()[split] 39 | 40 | def get_translated(end_points): 41 | begin, end = end_points 42 | caps = [] 43 | for image in collection[begin:end]: 44 | image_res_caps = [] 45 | for cap in image: 46 | 
cap = cap[1:-1] 47 | try: 48 | cap.remove(w2i[""]) 49 | except: 50 | pass 51 | try: 52 | cap.remove(w2i["'s"]) 53 | except: 54 | pass 55 | cap = ' '.join(map(lambda x: i2w[x], cap)) 56 | try: 57 | translated_cap = translator.translate(cap, dest=language).text 58 | except: 59 | translated_cap = "" 60 | image_res_caps.append(translated_cap) 61 | caps.append(image_res_caps) 62 | return caps 63 | 64 | partition = [] 65 | for i in range(args.threads): 66 | if i != args.threads-1: 67 | partition.append((len(collection)//args.threads * i, len(collection)//args.threads * (i+1))) 68 | else: 69 | partition.append((len(collection)//args.threads * i, len(collection))) 70 | 71 | import time 72 | start = time.time() 73 | with Pool(args.threads) as p: 74 | caps_to_merge = p.map(get_translated, partition) 75 | 76 | caps = reduce(lambda x, y: x + y, caps_to_merge) 77 | end = time.time() 78 | print(f"Multiprocessing uses {end - start} seconds.") 79 | 80 | pickle.dump(caps, open(f"coco/labs/{split}_{language}.pkl", "wb")) 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /ibr_game/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import OrderedDict 4 | import itertools 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from typing import Tuple, List, Dict 12 | 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | 16 | # def rejection_sampling(target_sample, seed_train_images, seed_train_captions, left_samples, batch_size): 17 | # T = 30 # rejection sampling temperature 18 | 19 | # def overlap_sents(lst1, lst2): 20 | # set1 = set(lst1) 21 | # set2 = set(lst2) 22 | # return len(set1 & set2) / len(set1 | set2) 23 | 24 | # def overlap(captions_0, captions_1): 25 | # overlap_sum = 0 26 | # for caption_0 in captions_0: 27 | # for caption_1 in captions_1: 28 | # overlap_sum += overlap_sents(caption_0, caption_1) 29 | # return overlap_sum / len(captions_0) / len(captions_1) 30 | 31 | # def unnorm_prob(a_id, b_id): 32 | # # return torch.exp(-torch.dist(seed_train_images[a_id], seed_train_images[b_id])/T) 33 | # return torch.exp(overlap(seed_train_captions[a_id], seed_train_captions[b_id])/T) 34 | 35 | # distrs = [] 36 | # distrs_app = distrs.append 37 | # for i in range(batch_size): 38 | # distr = random.choice(left_samples) 39 | # while(distr == target_sample[i] or unnorm_prob(target_sample[i], distr) < random.random()): 40 | # distr = random.choice(left_samples) 41 | # distrs_app(distr) 42 | # return distrs 43 | 44 | 45 | # def get_s2p_batch(seed_train_images, seed_train_captions, args, batch_size=None, dist_matrix=None): 46 | # data_size = seed_train_images.shape[0] 47 | # if batch_size is None: 48 | # batch_size = args.s2p_batch_size 49 | # if batch_size > data_size: 50 | # batch_size = data_size 51 | 52 | # target_sample = random.sample(list(range(data_size)), batch_size) 53 | # left_samples = list(range(data_size)) 54 | # distractor_samples = [[] for i in range(args.num_distrs)] 55 | # for i in range(batch_size): 56 | # if dist_matrix is None: 57 | # dsample = np.random.choice(left_samples, args.num_distrs, replace=False) 58 | # else: 59 | # dsample = np.random.choice(left_samples, args.num_distrs, replace=False, 60 | # p=dist_matrix[target_sample[i]]) 61 | # for j in range(args.num_distrs): 62 | # distractor_samples[j].append(dsample[j]) 63 | 64 | # image_batch = 
seed_train_images[target_sample] 65 | # distractor_image_batch = [] 66 | # for ds in distractor_samples: 67 | # d_image_batch = seed_train_images[ds] 68 | # distractor_image_batch.append(d_image_batch) 69 | # caption_batch = [seed_train_captions[t] for t in target_sample] 70 | # caption_len_batch = torch.tensor([len(c) for c in caption_batch], dtype=torch.long, device=device) 71 | 72 | # caption_batch = torch.tensor([np.pad(c, (0, args.seq_len - len(c))) for c in caption_batch], dtype=torch.long, device=device) 73 | 74 | # caption_batch_onehot = torch.zeros(caption_batch.shape[0], caption_batch.shape[1], args.vocab_size, 75 | # device=device).scatter_(-1, caption_batch.unsqueeze(-1), 1) 76 | 77 | # return image_batch, distractor_image_batch, caption_batch_onehot, caption_len_batch 78 | 79 | 80 | def get_s2p_batch(seed_train_images, seed_train_captions, args, all_seed_train_captions=None, batch_size=None): 81 | data_size = seed_train_images.shape[0] 82 | if batch_size is None: 83 | batch_size = args.s2p_batch_size 84 | if batch_size > data_size: 85 | batch_size = data_size 86 | 87 | target_sample = random.sample(list(range(data_size)), batch_size) 88 | left_samples = list(range(data_size)) 89 | distractor_samples = [] 90 | for i in range(args.num_distrs): 91 | dsample = random.sample(left_samples, batch_size) 92 | distractor_samples.append(dsample) 93 | 94 | image_batch = seed_train_images[target_sample] 95 | distractor_image_batch = [] 96 | for ds in distractor_samples: 97 | d_image_batch = seed_train_images[ds] 98 | distractor_image_batch.append(d_image_batch) 99 | caption_batch = [seed_train_captions[t] for t in target_sample] 100 | if all_seed_train_captions is not None: 101 | all_caption_batch = [all_seed_train_captions[t] for t in target_sample] 102 | caption_len_batch = torch.tensor( 103 | [len(c) for c in caption_batch], dtype=torch.long, device=device) 104 | 105 | caption_batch = torch.tensor([np.pad(c, (0, args.seq_len - len(c))) 106 | for c in caption_batch], dtype=torch.long, device=device) 107 | 108 | caption_batch_onehot = torch.zeros(caption_batch.shape[0], caption_batch.shape[1], args.vocab_size, 109 | device=device).scatter_(-1, caption_batch.unsqueeze(-1), 1) 110 | 111 | if all_seed_train_captions is not None: 112 | return image_batch, distractor_image_batch, caption_batch_onehot, caption_len_batch, all_caption_batch 113 | else: 114 | return image_batch, distractor_image_batch, caption_batch_onehot, caption_len_batch 115 | 116 | 117 | def get_batch_with_speaker(train_images, speaker, args, batch_size=None): 118 | data_size = train_images.shape[0] 119 | if batch_size is None: 120 | batch_size = args.s2p_batch_size 121 | if batch_size > data_size: 122 | batch_size = data_size 123 | 124 | target_sample = random.sample(list(range(data_size)), batch_size) 125 | left_samples = list(range(data_size)) 126 | distractor_samples = [] 127 | for i in range(args.num_distrs): 128 | dsample = random.sample(left_samples, batch_size) 129 | distractor_samples.append(dsample) 130 | 131 | image_batch = train_images[target_sample] 132 | distractor_image_batch = [] 133 | for ds in distractor_samples: 134 | d_image_batch = train_images[ds] 135 | distractor_image_batch.append(d_image_batch) 136 | 137 | train_captions, train_captions_len = speaker.forward(image_batch) 138 | train_captions = train_captions.detach() 139 | train_captions_len = train_captions_len.detach() 140 | 141 | return image_batch, distractor_image_batch, train_captions, train_captions_len 142 | 143 | 144 | def 
get_distractor_images(image_pool, args, batch_size=None): 145 | data_size = image_pool.shape[0] 146 | if batch_size is None: 147 | batch_size = data_size 148 | left_samples = list(range(data_size)) 149 | distractor_samples = [] 150 | for i in range(args.num_distrs): 151 | dsample = random.sample(left_samples, batch_size) 152 | distractor_samples.append(dsample) 153 | 154 | distractor_image_batch = [] 155 | for ds in distractor_samples: 156 | d_image_batch = image_pool[ds] 157 | distractor_image_batch.append(d_image_batch) 158 | 159 | return distractor_image_batch 160 | 161 | 162 | # def get_pop_batch(train_images, args, batch_size=None, dist_matrix=None): 163 | # data_size = train_images.shape[0] 164 | # if batch_size is None: 165 | # batch_size = args.pop_batch_size 166 | 167 | # target_sample = random.sample(list(range(data_size)), batch_size) 168 | # left_samples = list(range(data_size)) 169 | # distractor_samples = [[] for i in range(args.num_distrs)] 170 | # for i in range(batch_size): 171 | # if dist_matrix is None: 172 | # dsample = np.random.choice(left_samples, args.num_distrs, replace=False) 173 | # else: 174 | # dsample = np.random.choice(left_samples, args.num_distrs, replace=False, 175 | # p=dist_matrix[target_sample[i]]) 176 | # for j in range(args.num_distrs): 177 | # distractor_samples[j].append(dsample[j]) 178 | 179 | # image_batch = train_images[target_sample] 180 | # distractor_image_batch = [] 181 | # for ds in distractor_samples: 182 | # d_image_batch = train_images[ds] 183 | # distractor_image_batch.append(d_image_batch) 184 | 185 | # return image_batch, distractor_image_batch 186 | 187 | def get_pop_batch(train_images, args, batch_size=None): 188 | data_size = train_images.shape[0] 189 | if batch_size is None: 190 | batch_size = args.pop_batch_size 191 | 192 | target_sample = random.sample(list(range(data_size)), batch_size) 193 | left_samples = list(range(data_size)) 194 | distractor_samples = [] 195 | for i in range(args.num_distrs): 196 | dsample = random.sample(left_samples, batch_size) 197 | distractor_samples.append(dsample) 198 | 199 | image_batch = train_images[target_sample] 200 | distractor_image_batch = [] 201 | for ds in distractor_samples: 202 | d_image_batch = train_images[ds] 203 | distractor_image_batch.append(d_image_batch) 204 | 205 | return image_batch, distractor_image_batch 206 | 207 | 208 | def trim_caps(caps, minlen, maxlen): 209 | new_cap = [[cap for cap in cap_i if maxlen >= 210 | len(cap) >= minlen] for cap_i in caps] 211 | return new_cap 212 | 213 | 214 | def truncate_dicts(w2i, i2w, trunc_size): 215 | symbols_to_keep = ["", "", "", ""] 216 | inds_to_keep = [w2i[s] for s in symbols_to_keep] 217 | 218 | w2i_trunc = OrderedDict(itertools.islice(w2i.items(), trunc_size)) 219 | i2w_trunc = OrderedDict(itertools.islice(i2w.items(), trunc_size)) 220 | 221 | for s, i in zip(symbols_to_keep, inds_to_keep): 222 | w2i_trunc[s] = i 223 | i2w_trunc[i] = s 224 | 225 | return w2i_trunc, i2w_trunc 226 | 227 | 228 | def truncate_captions(train_captions, valid_captions, test_captions, w2i, i2w): 229 | unk_ind = w2i[""] 230 | 231 | def truncate_data(data): 232 | for i in range(len(data)): 233 | for ii in range(len(data[i])): 234 | for iii in range(len(data[i][ii])): 235 | if data[i][ii][iii] not in i2w: 236 | data[i][ii][iii] = unk_ind 237 | return data 238 | 239 | train_captions = truncate_data(train_captions) 240 | valid_captions = truncate_data(valid_captions) 241 | test_captions = truncate_data(test_captions) 242 | 243 | return train_captions, 
valid_captions, test_captions 244 | 245 | 246 | def load_model(model_dir, model, device): 247 | model_dicts = torch.load(os.path.join( 248 | model_dir, 'model.pt'), map_location=device) 249 | model.load_state_dict(model_dicts) 250 | iters = model_dicts['iters'] 251 | best_test_acc = model_dicts['test_acc'] 252 | print("Best Test acc:", best_test_acc, " at", iters, "iters") 253 | 254 | 255 | def save_to_file(vals, folder='', file=''): 256 | if not os.path.exists(folder): 257 | os.makedirs(folder) 258 | save_str = os.path.join(folder, file+'.pkl') 259 | with open(save_str, 'wb') as f1: 260 | pickle.dump(vals, f1) 261 | 262 | 263 | def torch_save_to_file(to_save, folder='', file=''): 264 | if not os.path.exists(folder): 265 | os.makedirs(folder) 266 | torch.save(to_save, os.path.join(folder, file)) 267 | 268 | 269 | def to_sentence(inds_list, i2w, trim=False): 270 | sentences = [] 271 | for inds in inds_list: 272 | if type(inds) is not list: 273 | inds = list(inds) 274 | sentence = [] 275 | for i in inds: 276 | sentence.append(i2w[i]) 277 | if i2w[i] == "" and trim: 278 | break 279 | sentences.append(' '.join(sentence)) 280 | return sentences 281 | 282 | 283 | def filter_caps(captions, images, w2i, perc, keep_caps=False, 284 | original_captions=None): 285 | if keep_caps: 286 | original_captions = captions 287 | new_train_captions = [] 288 | new_train_images = [] 289 | all_new_train_captions = [] 290 | for ci, cap in enumerate(captions): 291 | random.shuffle(cap) 292 | for cap_ in cap: 293 | if len(cap) > 0 and cap_.count(w2i[""]) / len(cap_) < perc: 294 | new_train_captions.append(cap_) 295 | new_train_images.append(images[ci]) 296 | if keep_caps: 297 | all_new_train_captions.append(original_captions[ci]) 298 | break 299 | if keep_caps: 300 | return new_train_captions, torch.stack(new_train_images), all_new_train_captions 301 | else: 302 | return new_train_captions, torch.stack(new_train_images) 303 | 304 | 305 | def sample_gumbel(shape, eps=1e-20): 306 | U = torch.empty(shape, device=device).uniform_(0, 1) 307 | return -torch.log(-torch.log(U + eps) + eps) 308 | 309 | 310 | def gumbel_softmax_sample(logits, temp): 311 | y = (logits + sample_gumbel(logits.shape)) / temp 312 | return F.softmax(y, dim=-1) 313 | 314 | 315 | def gumbel_softmax(logits, temp, hard): 316 | y = gumbel_softmax_sample(logits, temp) 317 | y_max, y_max_idx = torch.max(y, 1, keepdim=True) 318 | if hard: 319 | y_hard = torch.zeros(y.shape, device=device).scatter_(1, y_max_idx, 1) 320 | y = (y_hard - y).detach() + y 321 | return y, y_max_idx 322 | 323 | 324 | def sample(logits: torch.Tensor, temp: float = 1.0) \ 325 | -> Tuple[torch.Tensor, torch.Tensor]: 326 | dists = torch.distributions.Categorical(logits=logits * temp) 327 | result = dists.sample() 328 | logprob = dists.log_prob(result) 329 | return result, logprob 330 | 331 | 332 | def truncate_msg(msg, msg_lens): 333 | result_msg = [] 334 | for i, j in zip(msg, msg_lens): 335 | result_msg.append(i[1:j+1]) 336 | return result_msg 337 | 338 | 339 | def build_vocab(word_lang_list: dict, word_lang_dist: dict) -> set: 340 | vocab = set() 341 | for lang in word_lang_list: 342 | vocab.update(word_lang_list[lang][:word_lang_dist[lang]]) 343 | return vocab 344 | 345 | 346 | def update_vocab(w2i: dict, i2w: dict, vocab: set) -> Tuple[dict, dict]: 347 | symbols_to_keep = ["", "", "", ""] 348 | new_i2w = dict() 349 | new_w2i = dict() 350 | i2i = dict() 351 | for idx, w in enumerate(itertools.chain(symbols_to_keep, vocab)): 352 | assert i2w[w2i[w]] == w 353 | i2i[w2i[w]] = idx 354 
| new_w2i[w] = idx 355 | new_i2w[idx] = w 356 | return new_w2i, new_i2w, i2i 357 | 358 | 359 | def index_map(data, i2i, unk_ind): 360 | for i in range(len(data)): 361 | for ii in range(len(data[i])): 362 | for iii in range(len(data[i][ii])): 363 | if data[i][ii][iii] not in i2i: 364 | data[i][ii][iii] = unk_ind 365 | else: 366 | data[i][ii][iii] = i2i[data[i][ii][iii]] 367 | return data 368 | 369 | 370 | def sample_vocab_dirichlet(alpha: float, whole_ratio: float, 371 | lang_size: Dict[str, int]) -> Dict[str, int]: 372 | ratio_list = list(np.random.dirichlet( 373 | [alpha] * len(lang_size)) * whole_ratio) 374 | word_lang_dist = dict() 375 | for idx, i in enumerate(lang_size): 376 | word_lang_dist[i] = int(ratio_list[idx] * lang_size[i]) 377 | return word_lang_dist 378 | 379 | 380 | def calc_sim_matrix(images: torch.Tensor, dis_metric='cosine') -> torch.Tensor: 381 | n_img, D_img = images.size() 382 | if dis_metric == 'cosine': 383 | with torch.no_grad(): 384 | dot = torch.tensordot(images, images, dims=([1], [1])) 385 | norm = torch.norm(images, dim=1) 386 | return dot / norm / norm.unsqueeze(1).expand_as(dot) 387 | else: 388 | raise NotImplementedError 389 | 390 | 391 | def nearest_images(images: torch.Tensor, n: int): 392 | with torch.no_grad(): 393 | sim_matrix = calc_sim_matrix(images) 394 | sim_matrix.fill_diagonal_(-1) 395 | _, indices = torch.topk(sim_matrix, n, dim=1) 396 | return indices 397 | --------------------------------------------------------------------------------
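A minimal usage sketch of the sampling utilities in `ibr_game/utils.py` (run from the `ibr_game` directory). The language names, sizes, and feature tensors below are made up purely for illustration, and the connection to the `--sample_lang`/`--alpha` flags is inferred from their help strings rather than stated in the code.

import numpy as np
import torch

from utils import build_vocab, nearest_images, sample_vocab_dirichlet

np.random.seed(0)

# Hypothetical per-language vocabulary sizes and frequency-ordered word lists.
lang_size = {"en": 50, "de": 30, "fr": 20}
word_lang_list = {lang: [f"{lang}_{i}" for i in range(n)]
                  for lang, n in lang_size.items()}

# Dirichlet(alpha) split of a 60% overall vocabulary budget across languages,
# then keep the top words of each language up to its sampled budget.
word_lang_dist = sample_vocab_dirichlet(alpha=0.5, whole_ratio=0.6,
                                        lang_size=lang_size)
vocab = build_vocab(word_lang_list, word_lang_dist)
print(word_lang_dist, len(vocab))

# Hard-distractor candidates: for each image, the indices of its 5 most
# cosine-similar images (random features stand in for the ResNet ones here).
feats = torch.randn(100, 2048)
neighbour_idx = nearest_images(feats, 5)  # LongTensor of shape (100, 5)

As far as the argparse help strings indicate, the sampled `word_lang_dist` is the quantity that `--alpha` (the Dirichlet parameter) controls when `--sample_lang` is set.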