├── .gitignore ├── LICENSE ├── README.md ├── TRAINING.md ├── environment.yml ├── requirements.txt ├── scripts ├── examine_structure.py ├── generate_sqoop.py ├── print_programs.py ├── run_model.py ├── train │ ├── convlstm_flatqa.sh │ ├── ee_flatqa.sh │ ├── film_flatqa.sh │ ├── mac_flatqa.sh │ ├── rel_flatqa.sh │ └── shnmn_flatqa.sh └── train_model.py ├── setup.py └── vr ├── __init__.py ├── data.py ├── embedding.py ├── models ├── __init__.py ├── baselines.py ├── convlstm.py ├── film_gen.py ├── filmed_net.py ├── hetero_net.py ├── layers.py ├── maced_net.py ├── module_net.py ├── relation_net.py ├── seq2seq.py ├── seq2seq_att.py ├── shnmn.py └── simple_module_net.py ├── plotting.py ├── preprocess.py ├── programs.py ├── treeGenerator.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Data files 2 | data 3 | 4 | # Experiment files 5 | exp 6 | scripts/dev 7 | notebooks/ 8 | output/ 9 | 10 | # Image files 11 | img/cst 12 | 13 | # Editor files 14 | *.DS_Store 15 | 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | env/ 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | .hypothesis/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # dotenv 98 | .env 99 | 100 | # virtualenv 101 | .venv 102 | venv/ 103 | ENV/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | 118 | # eclipse 119 | .project 120 | .pydevproject 121 | 122 | # vim 123 | *.swp 124 | .editorconfig 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Systematic Generalization: What Is Required and Can It Be Learned 2 | 3 | The code used for the experiments in [the paper](https://openreview.net/forum?id=HkezXnA9YX). 4 | 5 | ### Setup 6 | 7 | Clone the repo 8 | ``` 9 | git clone https://github.com/rizar/systematic-generalization-sqoop.git 10 | cd systematic-generalization-sqoop 11 | export NMN=$PWD 12 | ``` 13 | Setup the environment using `conda` (recommended) and install this as a package in development mode 14 | ``` 15 | conda env create -f environment.yml 16 | conda activate sysgen 17 | pip install -e . 
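# Optional sanity check (not part of the original instructions): confirm that
# PyTorch was installed with CUDA support before launching any training runs.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"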
18 | ```
19 | If you don't use conda, you can instead run `pip install --user -r requirements.txt`.
20 | 
21 | 
22 | Download all versions of the SQOOP dataset from [here](https://drive.google.com/file/d/1yaXQL-MH0nQM9cqRbIrWkB3kBNM_ltY_/view?usp=sharing)
23 | and unpack them. Let `$DATA` be the location of the data on your system.
24 | 
25 | ### Running Experiments
26 | 
27 | In the examples below we use SQOOP with `#rhs/lhs=1`; other versions can be used by changing `--data_dir`.
28 | 
29 | #### FiLM
30 | 
31 | scripts/train/film_flatqa.sh --data_dir $DATA/sqoop-variety_1-repeats_30000 --checkpoint_path model.pt\
32 |     --num_iterations 200000
33 | 
34 | #### MAC
35 | 
36 | scripts/train/mac_flatqa.sh --data_dir $DATA/sqoop-variety_1-repeats_30000 --checkpoint_path model.pt\
37 |     --num_iterations 100000
38 | 
39 | #### Conv+LSTM
40 | 
41 | scripts/train/convlstm_flatqa.sh --data_dir $DATA/sqoop-variety_1-repeats_30000 --checkpoint_path model.pt\
42 |     --num_iterations 200000
43 | 
44 | #### RelNet
45 | 
46 | scripts/train/rel_flatqa.sh --data_dir $DATA/sqoop-variety_1-repeats_30000 --checkpoint_path model.pt\
47 |     --num_iterations 500000
48 | 
49 | #### NMN-Tree, NMN-Chain, NMN-Chain-Shortcut
50 | 
51 | scripts/train/shnmn_flatqa.sh --data_dir $DATA/sqoop-variety_1-repeats_30000\
52 |     --hard_code_tau --tau_init tree --hard_code_alpha --alpha_init correct\
53 |     --num_iterations 50000 --checkpoint_path model.pt
54 | 
55 | For a different layout use `--tau_init=chain` or `--tau_init=chain_shortcut`. For a different module, use `--use_module=find`; the default is Residual.
56 | Make sure to train for 200000 iterations if you use Find.
57 | 
58 | #### Stochastic-N2NMN
59 | 
60 | scripts/train/shnmn_flatqa.sh --data_dir $DATA/sqoop-variety_1-repeats_30000\
61 |     --shnmn_type hard --model_bernoulli 0.5 --hard_code_alpha --alpha_init=correct\
62 |     --num_iterations 200000 --checkpoint_path model.pt
63 | 
64 | `--model_bernoulli` is the initial probability of the model being a tree.
65 | 
66 | #### Attention-N2NMN
67 | 
68 | scripts/train/shnmn_flatqa.sh --data_dir $DATA/sqoop-variety_1-repeats_30000\
69 |     --hard_code_tau --tau_init tree --use_module=find --num_iterations 200000\
70 |     --checkpoint_path model.pt
71 | 
72 | ### Citation
73 | 
74 | **Bahdanau, D., Murty, S.**, Noukhovitch, M., Nguyen, T. H., de Vries, H., & Courville, A. (2018). Systematic Generalization: What Is Required and Can It Be Learned? ICLR 2019.
75 | 
76 | (the first two authors contributed equally)
77 | 
78 | ```
79 | @inproceedings{sysgen2019,
80 |   title = {Systematic Generalization: What Is Required and Can It Be Learned?},
81 |   booktitle = {International Conference on Learning Representations},
82 |   author = {Bahdanau, Dzmitry and Murty, Shikhar and Noukhovitch, Michael and Nguyen, Thien Huu and Vries, Harm de and Courville, Aaron},
83 |   year = {2019},
84 |   url = {https://openreview.net/forum?id=HkezXnA9YX},
85 | }
86 | ```
87 | 
88 | ### Acknowledgements
89 | 
90 | This code is based on the reference implementation for ["FiLM: Visual Reasoning with a General Conditioning Layer"](https://github.com/ethanjperez/film) by Ethan Perez, Florian Strub, Harm de Vries, Vincent Dumoulin, and Aaron Courville (AAAI 2018), which in turn was based on the reference implementation for ["Inferring and Executing Programs for Visual Reasoning"](https://github.com/facebookresearch/clevr-iep) by Justin Johnson, Bharath Hariharan, Laurens van der Maaten, Judy Hoffman, Fei-Fei Li, Larry Zitnick, and Ross Girshick (ICCV 2017).
91 | 
--------------------------------------------------------------------------------
/TRAINING.md:
--------------------------------------------------------------------------------
1 | # Training
2 | 
3 | Here we will walk through the process of training your own model on the CLEVR dataset,
4 | and finetuning the model on the CLEVR-Humans dataset.
5 | All training code runs on GPU, and assumes that CUDA and cuDNN have already been installed.
6 | 
7 | - [Preprocessing CLEVR](#preprocessing-clevr)
8 | - [Training on CLEVR](#training-on-clevr)
9 | - [Training baselines on CLEVR](#training-baselines-on-clevr)
10 | - [Preprocessing CLEVR-Humans](#preprocessing-clevr-humans)
11 | - [Finetuning on CLEVR-Humans](#finetuning-on-clevr-humans)
12 | - [Finetuning baselines on CLEVR-Humans](#finetuning-baselines-on-clevr-humans)
13 | 
14 | ## Preprocessing CLEVR
15 | 
16 | Before you can train any models, you need to download the
17 | [CLEVR dataset](http://cs.stanford.edu/people/jcjohns/clevr/);
18 | you also need to extract features for the images, and preprocess the questions and programs.
19 | 
20 | ### Step 1: Download the data
21 | 
22 | First you need to download and unpack the [CLEVR dataset](http://cs.stanford.edu/people/jcjohns/clevr/). 
23 | For the purpose of this tutorial we assume that all data will be stored in a new directory called `data/`: 24 | 25 | ```bash 26 | mkdir data 27 | wget https://s3-us-west-1.amazonaws.com/clevr/CLEVR_v1.0.zip -O data/CLEVR_v1.0.zip 28 | unzip data/CLEVR_v1.0.zip -d data 29 | ``` 30 | 31 | ### Step 2: Extract Image Features 32 | 33 | Extract ResNet-101 features for the CLEVR train, val, and test images with the following commands: 34 | 35 | ```bash 36 | python scripts/extract_features.py \ 37 | --input_image_dir data/CLEVR_v1.0/images/train \ 38 | --output_h5_file data/train_features.h5 39 | 40 | python scripts/extract_features.py \ 41 | --input_image_dir data/CLEVR_v1.0/images/val \ 42 | --output_h5_file data/val_features.h5 43 | 44 | python scripts/extract_features.py \ 45 | --input_image_dir data/CLEVR_v1.0/images/test \ 46 | --output_h5_file data/test_features.h5 47 | ``` 48 | 49 | ### Step 3: Preprocess Questions 50 | 51 | Preprocess the questions and programs for the CLEVR train, val, and test sets with the following commands: 52 | 53 | ```bash 54 | python scripts/preprocess_questions.py \ 55 | --input_questions_json data/CLEVR_v1.0/questions/CLEVR_train_questions.json \ 56 | --output_h5_file data/train_questions.h5 \ 57 | --output_vocab_json data/vocab.json 58 | 59 | python scripts/preprocess_questions.py \ 60 | --input_questions_json data/CLEVR_v1.0/questions/CLEVR_val_questions.json \ 61 | --output_h5_file data/val_questions.h5 \ 62 | --input_vocab_json data/vocab.json 63 | 64 | python scripts/preprocess_questions.py \ 65 | --input_questions_json data/CLEVR_v1.0/questions/CLEVR_test_questions.json \ 66 | --output_h5_file data/test_questions.h5 \ 67 | --input_vocab_json data/vocab.json 68 | ``` 69 | 70 | When preprocessing questions, we create a file `vocab.json` which stores the mapping between 71 | tokens and indices for questions and programs. We create this vocabulary when preprocessing 72 | the training questions, then reuse the same vocabulary file for the val and test questions. 73 | 74 | ## Training on CLEVR 75 | 76 | Models are trained through a three-step procedure: 77 | 78 | 1. Train the program generator using a small number of ground-truth programs 79 | 2. Train the execution engine using predicted outputs from the trained program generator 80 | 3. 
Jointly fine-tune both the program generator and the execution engine without any ground-truth programs
81 | 
82 | ### Step 1: Train the Program Generator
83 | 
84 | In this step we use a small number of ground-truth programs to train the program generator:
85 | 
86 | ```bash
87 | python scripts/train_model.py \
88 |   --model_type PG \
89 |   --num_train_samples 18000 \
90 |   --num_iterations 20000 \
91 |   --checkpoint_every 1000 \
92 |   --checkpoint_path data/program_generator.pt
93 | ```
94 | 
95 | ### Step 2: Train the Execution Engine
96 | 
97 | In this step we train the execution engine, using programs predicted from the program generator
98 | in the previous step:
99 | 
100 | ```bash
101 | python scripts/train_model.py \
102 |   --model_type EE \
103 |   --program_generator_start_from data/program_generator.pt \
104 |   --num_iterations 100000 \
105 |   --checkpoint_path data/execution_engine.pt
106 | ```
107 | 
108 | ### Step 3: Jointly train entire model
109 | 
110 | In this step we jointly train the program generator and execution engine using REINFORCE:
111 | 
112 | ```bash
113 | python scripts/train_model.py \
114 |   --model_type PG+EE \
115 |   --program_generator_start_from data/program_generator.pt \
116 |   --execution_engine_start_from data/execution_engine.pt \
117 |   --checkpoint_path data/joint_pg_ee.pt
118 | ```
119 | 
120 | ### Step 4: Test the model
121 | 
122 | You can use the `run_model.py` script to test your model on the entire validation
123 | and test sets. To test the version of the model before finetuning on the val set:
124 | 
125 | ```bash
126 | python scripts/run_model.py \
127 |   --program_generator data/program_generator.pt \
128 |   --execution_engine data/execution_engine.pt \
129 |   --input_question_h5 data/val_questions.h5 \
130 |   --input_features_h5 data/val_features.h5
131 | ```
132 | 
133 | You can test the jointly finetuned model like this:
134 | 
135 | ```bash
136 | python scripts/run_model.py \
137 |   --program_generator data/joint_pg_ee.pt \
138 |   --execution_engine data/joint_pg_ee.pt \
139 |   --input_question_h5 data/val_questions.h5 \
140 |   --input_features_h5 data/val_features.h5
141 | ```
142 | 
143 | ## Training baselines on CLEVR
144 | 
145 | ### Step 1: Train the model
146 | 
147 | You can use the `train_model.py` script to train the LSTM, CNN+LSTM, and CNN+LSTM+SA baselines.
148 | 
149 | For example you can train CNN+LSTM+SA+MLP like this:
150 | 
151 | ```bash
152 | python scripts/train_model.py \
153 |   --model_type CNN+LSTM+SA \
154 |   --classifier_fc_dims 1024 \
155 |   --num_iterations 400000 \
156 |   --checkpoint_path data/cnn_lstm_sa_mlp.pt
157 | ```
158 | 
159 | ### Step 2: Test the model
160 | 
161 | You can use the `run_model.py` script to test baseline models on the entire validation or test sets.
162 | You can run the model from the previous step on the entire val set like this:
163 | 
164 | ```bash
165 | python scripts/run_model.py \
166 |   --baseline_model data/cnn_lstm_sa_mlp.pt \
167 |   --input_question_h5 data/val_questions.h5 \
168 |   --input_features_h5 data/val_features.h5
169 | ```
170 | 
171 | ## Preprocessing CLEVR-Humans
172 | 
173 | ### Step 1: Download the data
174 | 
175 | You can download the CLEVR-Humans dataset like this:
176 | 
177 | ```bash
178 | wget http://cs.stanford.edu/people/jcjohns/iep/CLEVR-Humans.zip -O data/CLEVR-Humans.zip
179 | unzip data/CLEVR-Humans.zip -d data
180 | ```
181 | 
182 | ### Step 2: Preprocess the data
183 | 
184 | Preprocessing the CLEVR-Humans dataset is a bit tricky, since it contains words that do not appear
185 | in the CLEVR dataset. In addition, unlike CLEVR, we wish to replace infrequent words (which may be
186 | misspellings or typos) with a special `<UNK>` token. Furthermore, in order to use models trained on
187 | CLEVR on CLEVR-Humans, we need to ensure that the vocabulary we compute on CLEVR-Humans is compatible
188 | with the one from CLEVR that we preprocessed earlier.
189 | 
190 | All of these issues are handled by the `preprocess_questions.py` script, but we need to pass a few
191 | extra flags to control this behavior:
192 | 
193 | ```bash
194 | python scripts/preprocess_questions.py \
195 |   --input_questions_json data/CLEVR_humans/CLEVR_humans_train.json \
196 |   --input_vocab_json data/vocab.json \
197 |   --output_h5_file data/train_human_questions.h5 \
198 |   --output_vocab_json data/human_vocab.json \
199 |   --expand_vocab 1 \
200 |   --unk_threshold 10 \
201 |   --encode_unk 1
202 | 
203 | python scripts/preprocess_questions.py \
204 |   --input_questions_json data/CLEVR_humans/CLEVR_humans_val.json \
205 |   --input_vocab_json data/human_vocab.json \
206 |   --output_h5_file data/val_human_questions.h5 \
207 |   --encode_unk 1
208 | 
209 | python scripts/preprocess_questions.py \
210 |   --input_questions_json data/CLEVR_humans/CLEVR_humans_test.json \
211 |   --input_vocab_json data/human_vocab.json \
212 |   --output_h5_file data/test_human_questions.h5 \
213 |   --encode_unk 1
214 | ```
215 | 
216 | ## Finetuning on CLEVR-Humans
217 | 
218 | ### Step 1: Finetune the model
219 | 
220 | The CLEVR-Humans dataset does not provide ground-truth programs, but we can use REINFORCE to
221 | jointly train our entire model on this dataset regardless. When finetuning on CLEVR-Humans,
222 | we only update the program generator to prevent overfitting.
223 | 
224 | You can use the `train_model.py` script for finetuning like this:
225 | 
226 | ```bash
227 | python scripts/train_model.py \
228 |   --model_type PG+EE \
229 |   --train_question_h5 data/train_human_questions.h5 \
230 |   --train_features_h5 data/train_features.h5 \
231 |   --val_question_h5 data/val_human_questions.h5 \
232 |   --val_features_h5 data/val_features.h5 \
233 |   --vocab_json data/human_vocab.json \
234 |   --program_generator_start_from data/joint_pg_ee.pt \
235 |   --execution_engine_start_from data/joint_pg_ee.pt \
236 |   --train_program_generator 1 \
237 |   --train_execution_engine 0 \
238 |   --learning_rate 1e-4 \
239 |   --num_iterations 100000 \
240 |   --checkpoint_every 500 \
241 |   --checkpoint_path data/human_program_generator.pt
242 | ```
243 | 
244 | ### Step 2: Test the model
245 | 
246 | You can use the `run_model.py` script to run the model on the entire CLEVR-Humans
247 | validation or test set. In the previous step we only updated the program generator;
248 | when testing the model we use the execution engine that was trained on CLEVR.
249 | 
250 | ```bash
251 | python scripts/run_model.py \
252 |   --program_generator data/human_program_generator.pt \
253 |   --execution_engine data/joint_pg_ee.pt \
254 |   --input_question_h5 data/val_human_questions.h5 \
255 |   --input_features_h5 data/val_features.h5
256 | ```
257 | 
258 | ## Finetuning baselines on CLEVR-Humans
259 | 
260 | ### Step 1: Finetune the model
261 | 
262 | You can use the `train_model.py` script to finetune the LSTM, CNN+LSTM, and CNN+LSTM+SA
263 | models on the CLEVR-Humans dataset. When finetuning baselines on CLEVR-Humans we only
264 | update the RNN to prevent overfitting. For example you can finetune the CNN+LSTM+SA+MLP
265 | model we trained earlier like this:
266 | 
267 | ```bash
268 | python scripts/train_model.py \
269 |   --model_type CNN+LSTM+SA \
270 |   --train_question_h5 data/train_human_questions.h5 \
271 |   --train_features_h5 data/train_features.h5 \
272 |   --val_question_h5 data/val_human_questions.h5 \
273 |   --val_features_h5 data/val_features.h5 \
274 |   --vocab_json data/human_vocab.json \
275 |   --baseline_start_from data/cnn_lstm_sa_mlp.pt \
276 |   --baseline_train_only_rnn 1 \
277 |   --learning_rate 1e-4 \
278 |   --num_iterations 100000 \
279 |   --checkpoint_every 500 \
280 |   --checkpoint_path data/cnn_lstm_sa_mlp_human.pt
281 | ```
282 | 
283 | ### Step 2: Test the model
284 | 
285 | You can use the `run_model.py` script to test the finetuned baseline models on
286 | the entire val or test sets of the CLEVR-Humans dataset. For example you can
287 | run the finetuned CNN+LSTM+SA+MLP model on the entire validation set like this:
288 | 
289 | ```bash
290 | python scripts/run_model.py \
291 |   --baseline_model data/cnn_lstm_sa_mlp_human.pt \
292 |   --input_question_h5 data/val_human_questions.h5 \
293 |   --input_features_h5 data/val_features.h5
294 | ```
295 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: sysgen
2 | channels:
3 |   - pytorch
4 |   - defaults
5 | dependencies:
6 |   - python=3.6
7 |   - pytorch
8 |   - cuda90
9 |   - torchvision
10 |   - h5py
11 |   - numpy
12 |   - scipy
13 |   - tqdm
14 |   - termcolor
15 |   - pillow
16 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | numpy
3 | scipy
4 | torchvision
5 | h5py
6 | tqdm
7 | Pillow
8 | termcolor
9 | 
--------------------------------------------------------------------------------
/scripts/examine_structure.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 | import torchvision
6 | import numpy as np
7 | import h5py
8 | from scipy.misc import imread, imresize, imsave
9 | import time
10 | import glob
11 | import vr.utils as utils
12 | import vr.programs
13 | from vr.data import ClevrDataset, ClevrDataLoader
14 | from vr.preprocess import tokenize, encode
15 | from tqdm import tqdm
16 | 
17 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
18 | 
19 | 
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--model_path', default=None)
22 | parser.add_argument('--data_dir', default=None, type=str)
23 | 
24 | loss_fn = torch.nn.CrossEntropyLoss().to(device)
25 | 
26 | def run_our_model_batch(args, model, loader, dtype):
27 |     model.type(dtype)
28 |     model.eval()
29 | 
30 |     num_correct, num_samples = 0, 0
31 |     total_loss = 0.0
32 |     q_types = []
33 | 
34 |     start = time.time()
35 |     for batch in tqdm(loader):
36 |         questions, images, feats, answers, programs, program_lists = batch
37 |         if isinstance(questions, list):
38 |             questions_var = Variable(questions[0].type(dtype).long(), volatile=True)
39 |             q_types += [questions[1].cpu().numpy()]
40 |         else:
41 |             questions_var = Variable(questions.type(dtype).long(), volatile=True)
42 |         feats_var = Variable(feats.type(dtype), volatile=True)
43 |         answers_var = Variable(answers.to(device))
44 |         scores = model(feats_var, questions_var)
45 |         loss = loss_fn(scores, answers_var)
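        # Accumulate loss and accuracy statistics over the whole evaluation set.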
46 |         probs = F.softmax(scores, dim=1)
47 |         total_loss += loss.data.cpu()
48 |         _, preds = scores.data.cpu().max(1)
49 |         num_correct += np.sum(preds == answers)
50 |         num_samples += preds.size(0)
51 | 
52 |     acc = float(num_correct) / num_samples
53 |     print('Got %d / %d = %.2f correct' % (num_correct, num_samples, 100 * acc))
54 |     print('%.2fs to evaluate' % (time.time() - start))
55 | 
56 |     print(loss)
57 | 
58 | 
59 | 
60 | 
61 | def main(args):
62 |     all_checkpoints = glob.glob('%s/*.pt' % args.model_path)
63 |     print(all_checkpoints)
64 | 
65 |     for i, checkpoint in enumerate(all_checkpoints):
66 | 
67 |         model, _ = utils.load_execution_engine(checkpoint, False, 'SHNMN')
68 |         for name, param in model.named_parameters():
69 |             if param.requires_grad:
70 |                 print(name)
71 | 
72 |         f = open('f_%d.txt' % i, 'w')
73 |         f.write('%s\n' % checkpoint)
74 |         f.write('HARD_TAU | HARD_ALPHA \n')
75 |         f.write('%s-%s\n' % (model.hard_code_tau, model.hard_code_alpha))
76 |         f.write('TAUS\n')
77 |         f.write('p(model) : %s\n' % str(F.sigmoid(model.model_bernoulli)))
78 |         for i in range(3):
79 |             tau0 = model.tau_0[i, :(i+2)] if model.hard_code_tau else F.softmax(model.tau_0[i, :(i+2)])
80 |             f.write('tau0: %s\n' % str(tau0.data.cpu().numpy()))
81 |             tau1 = model.tau_1[i, :(i+2)] if model.hard_code_tau else F.softmax(model.tau_1[i, :(i+2)])
82 |             f.write('tau1: %s\n' % str(tau1.data.cpu().numpy()))
83 | 
84 |         f.write('ALPHAS\n')
85 |         for i in range(3):
86 |             alpha = model.alpha[i] if model.hard_code_alpha else F.softmax(model.alpha[i])
87 |             f.write('alpha: %s\n' % " ".join(['{:.3f}'.format(float(x)) for x in alpha.view(-1).data.numpy()]))
88 | 
89 |         f.close()
90 | 
91 | if __name__ == '__main__':
92 |     args = parser.parse_args()
93 |     main(args)
94 | 
95 | 
--------------------------------------------------------------------------------
/scripts/print_programs.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import sys
3 | import json
4 | import numpy
5 | import time
6 | import timeit
7 | from vr.data import ClevrDataset
8 | from vr.utils import load_vocab
9 | import cProfile, pstats, io
10 | 
11 | 
12 | def print_program_tree(program, prefix):
13 |     token = program_vocab[program[0]]
14 |     cur_arity = arity[token]
15 |     print("{}{} {}".format(prefix, token, str(cur_arity)))
16 |     if cur_arity == 0:
17 |         return 1
18 |     if cur_arity == 1:
19 |         return 1 + print_program_tree(program[1:], prefix + " ")
20 |     if cur_arity == 2:
21 |         right_subtree = 1 + print_program_tree(program[1:], prefix + " ")
22 |         return right_subtree + print_program_tree(program[right_subtree:], prefix + " ")
23 |     raise ValueError()
24 | 
25 | 
26 | f = h5py.File(sys.argv[1], 'r')
27 | num = int(sys.argv[2]) if len(sys.argv) > 2 else 10
28 | vocab = load_vocab('vocab.json')
29 | arity = vocab['program_token_arity']
30 | program_vocab = vocab['program_idx_to_token']
31 | question_vocab = vocab['question_idx_to_token']
32 | programs = None
33 | if 'programs' in f:
34 |     programs = f['programs']
35 | questions = f['questions']
36 | if 'answers' in f:
37 |     answers = f['answers']
38 | for i in range(num):
39 |     if programs is not None:
40 |         prog = programs[i]
41 |         print_program_tree(programs[i], "")
42 |     quest = questions[i]
43 |     print(" ".join(question_vocab[quest[j]] for j in range(len(quest))))
44 |     if 'answers' in f:
45 |         print(answers[i])
46 | 
--------------------------------------------------------------------------------
/scripts/run_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019-present, Mila
2 | 
# Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import argparse 9 | import json 10 | import random 11 | import shutil 12 | from termcolor import colored 13 | import time 14 | from tqdm import tqdm 15 | import sys 16 | import os 17 | sys.path.insert(0, os.path.abspath('.')) 18 | 19 | import torch 20 | from torch.autograd import Variable 21 | import torch.nn.functional as F 22 | import torchvision 23 | import numpy as np 24 | import h5py 25 | from scipy.misc import imread, imresize, imsave 26 | 27 | import vr.utils as utils 28 | import vr.programs 29 | from vr.data import ClevrDataset, ClevrDataLoader 30 | from vr.preprocess import tokenize, encode 31 | from vr.models import * 32 | 33 | 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--program_generator', default=None) 36 | parser.add_argument('--execution_engine', default=None) 37 | parser.add_argument('--baseline_model', default=None) 38 | parser.add_argument('--debug_every', default=float('inf'), type=float) 39 | parser.add_argument('--use_gpu', default=torch.cuda.is_available(), type=int) 40 | 41 | # For running on a preprocessed dataset 42 | parser.add_argument('--data_dir', default=None, type=str) 43 | parser.add_argument('--part', default='val', type=str) 44 | 45 | # This will override the vocab stored in the checkpoint; 46 | # we need this to run CLEVR models on human data 47 | parser.add_argument('--vocab_json', default=None) 48 | 49 | # For running on a single example 50 | parser.add_argument('--question', default=None) 51 | parser.add_argument('--image', default='img/CLEVR_val_000017.png') 52 | parser.add_argument('--cnn_model', default='resnet101') 53 | parser.add_argument('--cnn_model_stage', default=3, type=int) 54 | parser.add_argument('--image_width', default=224, type=int) 55 | parser.add_argument('--image_height', default=224, type=int) 56 | parser.add_argument('--enforce_clevr_vocab', default=1, type=int) 57 | 58 | parser.add_argument('--batch_size', default=64, type=int) 59 | parser.add_argument('--num_samples', default=None, type=int) 60 | parser.add_argument('--num_last_words_shuffled', default=0, type=int) # -1 for all shuffled 61 | parser.add_argument('--family_split_file', default=None) 62 | 63 | parser.add_argument('--sample_argmax', type=int, default=1) 64 | parser.add_argument('--temperature', default=1.0, type=float) 65 | 66 | # FiLM models only 67 | parser.add_argument('--gamma_option', default='linear', 68 | choices=['linear', 'sigmoid', 'tanh', 'exp', 'relu', 'softplus']) 69 | parser.add_argument('--gamma_scale', default=1, type=float) 70 | parser.add_argument('--gamma_shift', default=0, type=float) 71 | parser.add_argument('--gammas_from', default=None) # Load gammas from file 72 | parser.add_argument('--beta_option', default='linear', 73 | choices=['linear', 'sigmoid', 'tanh', 'exp', 'relu', 'softplus']) 74 | parser.add_argument('--beta_scale', default=1, type=float) 75 | parser.add_argument('--beta_shift', default=0, type=float) 76 | parser.add_argument('--betas_from', default=None) # Load betas from file 77 | 78 | # If this is passed, then save all predictions to this file 79 | parser.add_argument('--output_h5', default=None) 80 | parser.add_argument('--output_preds', default=None) 81 | parser.add_argument('--output_viz_dir', default='img/') 82 | parser.add_argument('--output_program_stats_dir', default=None) 83 | 84 | grads = {} 
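# `grads` is filled in by the save_grad() hooks defined near the bottom of this
# file when gradients are inspected in debug mode.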
85 | programs = {} # NOTE: Useful for zero-shot program manipulation when in debug mode 86 | 87 | def main(args): 88 | if not args.program_generator: 89 | args.program_generator = args.execution_engine 90 | input_question_h5 = os.path.join(args.data_dir, '{}_questions.h5'.format(args.part)) 91 | input_features_h5 = os.path.join(args.data_dir, '{}_features.h5'.format(args.part)) 92 | 93 | model = None 94 | if args.baseline_model is not None: 95 | print('Loading baseline model from ', args.baseline_model) 96 | model, _ = utils.load_baseline(args.baseline_model) 97 | if args.vocab_json is not None: 98 | new_vocab = utils.load_vocab(args.vocab_json) 99 | model.rnn.expand_vocab(new_vocab['question_token_to_idx']) 100 | elif args.program_generator is not None and args.execution_engine is not None: 101 | pg, _ = utils.load_program_generator(args.program_generator) 102 | ee, _ = utils.load_execution_engine( 103 | args.execution_engine, verbose=False) 104 | if args.vocab_json is not None: 105 | new_vocab = utils.load_vocab(args.vocab_json) 106 | pg.expand_encoder_vocab(new_vocab['question_token_to_idx']) 107 | model = (pg, ee) 108 | else: 109 | print('Must give either --baseline_model or --program_generator and --execution_engine') 110 | return 111 | 112 | dtype = torch.FloatTensor 113 | if args.use_gpu == 1: 114 | dtype = torch.cuda.FloatTensor 115 | if args.question is not None and args.image is not None: 116 | run_single_example(args, model, dtype, args.question) 117 | else: 118 | vocab = load_vocab(args) 119 | loader_kwargs = { 120 | 'question_h5': input_question_h5, 121 | 'feature_h5': input_features_h5, 122 | 'vocab': vocab, 123 | 'batch_size': args.batch_size, 124 | } 125 | if args.num_samples is not None and args.num_samples > 0: 126 | loader_kwargs['max_samples'] = args.num_samples 127 | if args.family_split_file is not None: 128 | with open(args.family_split_file, 'r') as f: 129 | loader_kwargs['question_families'] = json.load(f) 130 | with ClevrDataLoader(**loader_kwargs) as loader: 131 | run_batch(args, model, dtype, loader) 132 | 133 | 134 | def extract_image_features(args, dtype): 135 | # Build the CNN to use for feature extraction 136 | print('Extracting image features...') 137 | cnn = build_cnn(args, dtype) 138 | 139 | # Load and preprocess the image 140 | img_size = (args.image_height, args.image_width) 141 | img = imread(args.image, mode='RGB') 142 | img = imresize(img, img_size, interp='bicubic') 143 | img = img.transpose(2, 0, 1)[None] 144 | mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1) 145 | std = np.array([0.229, 0.224, 0.224]).reshape(1, 3, 1, 1) 146 | img = (img.astype(np.float32) / 255.0 - mean) / std 147 | 148 | # Use CNN to extract features for the image 149 | img_var = Variable(torch.FloatTensor(img).type(dtype), volatile=False, requires_grad=True) 150 | feats_var = cnn(img_var) 151 | return feats_var 152 | 153 | 154 | def run_our_model_batch(args, pg, ee, loader, dtype): 155 | if pg: 156 | pg.type(dtype) 157 | pg.eval() 158 | ee.type(dtype) 159 | ee.eval() 160 | 161 | all_scores = [] 162 | all_programs = [] 163 | all_correct = [] 164 | all_probs = [] 165 | all_preds = [] 166 | all_film_scores = [] 167 | all_read_scores = [] 168 | all_control_scores = [] 169 | all_connections = [] 170 | all_vib_costs = [] 171 | num_correct, num_samples = 0, 0 172 | 173 | q_types = [] 174 | 175 | start = time.time() 176 | for batch in tqdm(loader): 177 | assert(not pg or not pg.training) 178 | assert(not ee.training) 179 | questions, images, feats, answers, programs, 
program_lists = batch 180 | 181 | if isinstance(questions, list): 182 | questions_var = questions[0].type(dtype).long() 183 | q_types += [questions[1].cpu().numpy()] 184 | else: 185 | questions_var = questions.type(dtype).long() 186 | feats_var = feats.type(dtype) 187 | if pg: 188 | programs_pred = pg(questions_var) 189 | else: 190 | programs_pred = programs 191 | 192 | kwargs = ({'save_activations': True} 193 | if isinstance(ee, (FiLMedNet, ModuleNet, MAC)) 194 | else {}) 195 | pos_args = [feats_var] 196 | if isinstance(ee, SHNMN): 197 | pos_args.append(questions_var) 198 | else: 199 | pos_args.append(programs_pred) 200 | scores = ee(*pos_args, **kwargs) 201 | probs = F.softmax(scores, dim=1) 202 | 203 | #loss = torch.nn.CrossEntropyLoss()(scores, answers.cuda()) 204 | #loss.backward() 205 | 206 | #for i, output in enumerate(ee.stem.outputs): 207 | # print('module_{}:'.format(i), output.mean().item(), 208 | # ((output ** 2).mean() ** 0.5).item(), 209 | # output.min().item(), 210 | # output.max().item()) 211 | 212 | _, preds = scores.data.cpu().max(1) 213 | # all_programs.append(programs_pred.data.cpu().clone()) 214 | all_scores.append(scores.data.cpu().clone()) 215 | all_probs.append(probs.data.cpu().clone()) 216 | all_preds.append(preds.cpu().clone()) 217 | all_correct.append(preds == answers) 218 | if isinstance(pg, FiLMGen) and pg.scores is not None: 219 | all_film_scores.append(pg.scores.data.cpu().clone()) 220 | if isinstance(ee, MAC): 221 | all_control_scores.append(ee.control_scores.data.cpu().clone()) 222 | all_read_scores.append(ee.read_scores.data.cpu().clone()) 223 | if hasattr(ee, 'vib_costs'): 224 | all_vib_costs.append(ee.vib_costs.data.cpu().clone()) 225 | if hasattr(ee, 'connections') and ee.connections: 226 | all_connections.append(torch.cat([conn.unsqueeze(1) for conn in ee.connections], 1).data.cpu().clone()) 227 | if answers[0] is not None: 228 | num_correct += (preds == answers).sum() 229 | num_samples += preds.size(0) 230 | 231 | acc = float(num_correct) / num_samples 232 | print('Got %d / %d = %.2f correct' % (num_correct, num_samples, 100 * acc)) 233 | print('%.2fs to evaluate' % (start - time.time())) 234 | if all_control_scores: 235 | max_len = max(cs.size(2) for cs in all_control_scores) 236 | for i in range(len(all_control_scores)): 237 | tmp = torch.zeros( 238 | (all_control_scores[i].size(0), all_control_scores[i].size(1), max_len)) 239 | tmp[:, :, :all_control_scores[i].size(2)] = all_control_scores[i] 240 | all_control_scores[i] = tmp 241 | 242 | output_path = ('output_' + args.execution_engine[:-3] + ".h5" 243 | if not args.output_h5 244 | else args.output_h5) 245 | 246 | print('Writing output to "%s"' % output_path) 247 | with h5py.File(output_path, 'w') as fout: 248 | fout.create_dataset('scores', data=torch.cat(all_scores, 0).numpy()) 249 | fout.create_dataset('probs', data=torch.cat(all_probs, 0).numpy()) 250 | fout.create_dataset('correct', data=torch.cat(all_correct, 0).numpy()) 251 | if all_film_scores: 252 | fout.create_dataset('film_scores', data=torch.cat(all_film_scores, 1).numpy()) 253 | if all_vib_costs: 254 | fout.create_dataset('vib_costs', data=torch.cat(all_vib_costs, 0).numpy()) 255 | if all_read_scores: 256 | fout.create_dataset('read_scores', data=torch.cat(all_read_scores, 0).numpy()) 257 | if all_control_scores: 258 | fout.create_dataset('control_scores', data=torch.cat(all_control_scores, 0).numpy()) 259 | if all_connections: 260 | fout.create_dataset('connections', data=torch.cat(all_connections, 0).numpy()) 261 | 262 | # Save FiLM 
param stats 263 | if args.output_program_stats_dir: 264 | if not os.path.isdir(args.output_program_stats_dir): 265 | os.mkdir(args.output_program_stats_dir) 266 | gammas = all_programs[:,:,:pg.module_dim] 267 | betas = all_programs[:,:,pg.module_dim:2*pg.module_dim] 268 | gamma_means = gammas.mean(0) 269 | torch.save(gamma_means, os.path.join(args.output_program_stats_dir, 'gamma_means')) 270 | beta_means = betas.mean(0) 271 | torch.save(beta_means, os.path.join(args.output_program_stats_dir, 'beta_means')) 272 | gamma_medians = gammas.median(0)[0] 273 | torch.save(gamma_medians, os.path.join(args.output_program_stats_dir, 'gamma_medians')) 274 | beta_medians = betas.median(0)[0] 275 | torch.save(beta_medians, os.path.join(args.output_program_stats_dir, 'beta_medians')) 276 | 277 | # Note: Takes O(10GB) space 278 | torch.save(gammas, os.path.join(args.output_program_stats_dir, 'gammas')) 279 | torch.save(betas, os.path.join(args.output_program_stats_dir, 'betas')) 280 | 281 | if args.output_preds is not None: 282 | vocab = load_vocab(args) 283 | all_preds_strings = [] 284 | for i in range(len(all_preds)): 285 | all_preds_strings.append(vocab['answer_idx_to_token'][all_preds[i]]) 286 | save_to_file(all_preds_strings, args.output_preds) 287 | 288 | if args.debug_every <= 1: 289 | pdb.set_trace() 290 | return 291 | 292 | 293 | def visualize(features, args, file_name=None): 294 | """ 295 | Converts a 4d map of features to alpha attention weights, 296 | According to their 2-Norm across dimensions 0 and 1. 297 | Then saves the input RGB image as an RGBA image using an upsampling of this attention map. 298 | """ 299 | save_file = os.path.join(args.viz_dir, file_name) 300 | img_path = args.image 301 | 302 | # Scale map to [0, 1] 303 | f_map = (features ** 2).mean(0, keepdim=True).mean(1, keepdim=True).squeeze().sqrt() 304 | f_map_shifted = f_map - f_map.min().expand_as(f_map) 305 | f_map_scaled = f_map_shifted / f_map_shifted.max().expand_as(f_map_shifted) 306 | 307 | if save_file is None: 308 | print(f_map_scaled) 309 | else: 310 | # Read original image 311 | img = imread(img_path, mode='RGB') 312 | orig_img_size = img.shape 313 | 314 | # Convert to image format 315 | alpha = (255 * f_map_scaled).round() 316 | alpha4d = alpha.unsqueeze(0).unsqueeze(0) 317 | alpha_upsampled = torch.nn.functional.upsample_bilinear( 318 | alpha4d, size=torch.Size(orig_img_size)).squeeze(0).transpose(1, 0).transpose(1, 2) 319 | alpha_upsampled_np = alpha_upsampled.cpu().data.numpy() 320 | 321 | # Create and save visualization 322 | imga = np.concatenate([img, alpha_upsampled_np], axis=2) 323 | if save_file[-4:] != '.png': save_file += '.png' 324 | imsave(save_file, imga) 325 | 326 | return f_map_scaled 327 | 328 | 329 | def build_cnn(args, dtype): 330 | if not hasattr(torchvision.models, args.cnn_model): 331 | raise ValueError('Invalid model "%s"' % args.cnn_model) 332 | if not 'resnet' in args.cnn_model: 333 | raise ValueError('Feature extraction only supports ResNets') 334 | whole_cnn = getattr(torchvision.models, args.cnn_model)(pretrained=True) 335 | layers = [ 336 | whole_cnn.conv1, 337 | whole_cnn.bn1, 338 | whole_cnn.relu, 339 | whole_cnn.maxpool, 340 | ] 341 | for i in range(args.cnn_model_stage): 342 | name = 'layer%d' % (i + 1) 343 | layers.append(getattr(whole_cnn, name)) 344 | cnn = torch.nn.Sequential(*layers) 345 | cnn.type(dtype) 346 | cnn.eval() 347 | return cnn 348 | 349 | 350 | def run_batch(args, model, dtype, loader): 351 | if type(model) is tuple: 352 | pg, ee = model 353 | 
run_our_model_batch(args, pg, ee, loader, dtype) 354 | else: 355 | run_baseline_batch(args, model, loader, dtype) 356 | 357 | 358 | def run_baseline_batch(args, model, loader, dtype): 359 | model.type(dtype) 360 | model.eval() 361 | 362 | all_scores, all_probs = [], [] 363 | num_correct, num_samples = 0, 0 364 | for batch in loader: 365 | questions, images, feats, answers, programs, program_lists = batch 366 | 367 | questions_var = Variable(questions.type(dtype).long(), volatile=True) 368 | feats_var = Variable(feats.type(dtype), volatile=True) 369 | scores = model(questions_var, feats_var) 370 | probs = F.softmax(scores) 371 | 372 | _, preds = scores.data.cpu().max(1) 373 | all_scores.append(scores.data.cpu().clone()) 374 | all_probs.append(probs.data.cpu().clone()) 375 | 376 | num_correct += (preds == answers).sum() 377 | num_samples += preds.size(0) 378 | print('Ran %d samples' % num_samples) 379 | 380 | acc = float(num_correct) / num_samples 381 | print('Got %d / %d = %.2f correct' % (num_correct, num_samples, 100 * acc)) 382 | 383 | all_scores = torch.cat(all_scores, 0) 384 | all_probs = torch.cat(all_probs, 0) 385 | if args.output_h5 is not None: 386 | print('Writing output to %s' % args.output_h5) 387 | with h5py.File(args.output_h5, 'w') as fout: 388 | fout.create_dataset('scores', data=all_scores.numpy()) 389 | fout.create_dataset('probs', data=all_probs.numpy()) 390 | 391 | 392 | def load_vocab(args): 393 | path = None 394 | if args.baseline_model is not None: 395 | path = args.baseline_model 396 | elif args.program_generator is not None: 397 | path = args.program_generator 398 | elif args.execution_engine is not None: 399 | path = args.execution_engine 400 | return utils.load_cpu(path)['vocab'] 401 | 402 | 403 | def save_grad(name): 404 | def hook(grad): 405 | grads[name] = grad 406 | return hook 407 | 408 | 409 | def save_to_file(text, filename): 410 | with open(filename, mode='wt', encoding='utf-8') as myfile: 411 | myfile.write('\n'.join(text)) 412 | myfile.write('\n') 413 | 414 | 415 | def get_index(l, index, default=-1): 416 | try: 417 | return l.index(index) 418 | except ValueError: 419 | return default 420 | 421 | 422 | if __name__ == '__main__': 423 | args = parser.parse_args() 424 | main(args) 425 | -------------------------------------------------------------------------------- /scripts/train/convlstm_flatqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python $NMN/scripts/train_model.py \ 4 | --feature_dim 3,64,64 \ 5 | --num_val_samples 1000 \ 6 | --checkpoint_every 1000 \ 7 | --record_loss_every 10 \ 8 | \ 9 | --model_type ConvLSTM \ 10 | --num_iterations 100000 \ 11 | \ 12 | --optimizer Adam \ 13 | --batch_size 128 \ 14 | --learning_rate 1e-4 \ 15 | \ 16 | --rnn_num_layers 1 \ 17 | --rnn_wordvec_dim 64 \ 18 | --rnn_hidden_dim 128 \ 19 | \ 20 | --module_stem_num_layers 6 \ 21 | --module_stem_subsample_layers 1,3 \ 22 | --module_stem_batchnorm 1 \ 23 | --stem_dim 64 \ 24 | \ 25 | --module_dim 64 \ 26 | \ 27 | --classifier_fc_dims 1024 \ 28 | --classifier_downsample none \ 29 | $@ 30 | #--module_stem_num_layers 8 \ 31 | #--module_stem_subsample_layers 1,3,5 \ 32 | 33 | -------------------------------------------------------------------------------- /scripts/train/ee_flatqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python $NMN/scripts/train_model.py \ 4 | --model_type EE \ 5 | --feature_dim=3,64,64 \ 6 | --num_iterations=50000 \ 7 | 
--checkpoint_every 1000 \ 8 | --record_loss_every 10 \ 9 | --num_val_samples 1000 \ 10 | --optimizer Adam \ 11 | --learning_rate 1e-4 \ 12 | --use_coords 1 \ 13 | --module_stem_batchnorm 1 \ 14 | --module_stem_num_layers 6 \ 15 | --module_stem_subsample_layers 1,3 \ 16 | --module_intermediate_batchnorm 0 \ 17 | --module_batchnorm 0 \ 18 | --module_dim 64 \ 19 | --classifier_batchnorm 1 \ 20 | --classifier_downsample maxpoolfull \ 21 | --classifier_proj_dim 512 \ 22 | --program_generator_parameter_efficient 1 $@ 23 | -------------------------------------------------------------------------------- /scripts/train/film_flatqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python $NMN/scripts/train_model.py \ 4 | --model_type FiLM \ 5 | --num_iterations 50000 \ 6 | --feature_dim=3,64,64 \ 7 | --checkpoint_every 1000 \ 8 | --record_loss_every 10 \ 9 | --num_val_samples 1000 \ 10 | --optimizer Adam \ 11 | --learning_rate 3e-4 \ 12 | --batch_size 64 \ 13 | --use_coords 1 \ 14 | --module_stem_batchnorm 1 \ 15 | --module_stem_num_layers 6 \ 16 | --module_stem_subsample_layers 1,3\ 17 | --module_batchnorm 1 \ 18 | --classifier_batchnorm 1 \ 19 | --bidirectional 0 \ 20 | --decoder_type linear \ 21 | --encoder_type gru \ 22 | --rnn_num_layers 1 \ 23 | --rnn_wordvec_dim 200 \ 24 | --rnn_hidden_dim 1024 `#was 4096 in original FiLM` \ 25 | --rnn_output_batchnorm 0 \ 26 | --classifier_downsample maxpoolfull \ 27 | --classifier_proj_dim 512 \ 28 | --classifier_fc_dims 1024 \ 29 | --module_input_proj 1 \ 30 | --module_residual 1 \ 31 | --module_dim 64 `#was 128 in original FiLM`\ 32 | --module_dropout 0e-2 \ 33 | --module_stem_kernel_size 3 \ 34 | --module_kernel_size 3 \ 35 | --module_batchnorm_affine 0 \ 36 | --module_num_layers 1 \ 37 | --num_modules 4 \ 38 | --condition_pattern 1,1,1,1 \ 39 | --gamma_option linear \ 40 | --gamma_baseline 1 \ 41 | --use_gamma 1 \ 42 | --use_beta 1 \ 43 | --condition_method bn-film \ 44 | --program_generator_parameter_efficient 1 $@ 45 | -------------------------------------------------------------------------------- /scripts/train/mac_flatqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python $NMN/scripts/train_model.py \ 4 | --feature_dim=3,64,64 \ 5 | --model_type MAC \ 6 | --num_iterations 50000 \ 7 | --checkpoint_every 1000 \ 8 | --record_loss_every 10 \ 9 | --num_val_samples 1000 \ 10 | --optimizer Adam \ 11 | --learning_rate 1e-4 \ 12 | --batch_size 128 \ 13 | --use_coords 1 \ 14 | --module_stem_batchnorm 1 \ 15 | --module_stem_num_layers 6 \ 16 | --module_stem_subsample_layers 1,3 \ 17 | --module_stem_kernel_size 3 \ 18 | --mac_question_embedding_dropout 0. \ 19 | --mac_stem_dropout 0. \ 20 | --mac_memory_dropout 0. \ 21 | --mac_read_dropout 0. \ 22 | --mac_use_prior_control_in_control_unit 0 \ 23 | --mac_embedding_uniform_boundary 1.0 \ 24 | --mac_nonlinearity ReLU \ 25 | --variational_embedding_dropout 0. \ 26 | --module_dim 128 \ 27 | --num_modules 12 \ 28 | --mac_use_self_attention 0 \ 29 | --mac_use_memory_gate 0 \ 30 | --bidirectional 1 \ 31 | --encoder_type lstm \ 32 | --rnn_num_layers 1 \ 33 | --rnn_wordvec_dim 300 \ 34 | --rnn_hidden_dim 128 \ 35 | --rnn_dropout 0 \ 36 | --rnn_output_batchnorm 0 \ 37 | --classifier_fc_dims 1024 \ 38 | --classifier_batchnorm 0 \ 39 | --classifier_dropout 0. \ 40 | --use_local_copies 0 \ 41 | --grad_clip 8. 
\ 42 | --program_generator_parameter_efficient 1 $@ 43 | -------------------------------------------------------------------------------- /scripts/train/rel_flatqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python $NMN/scripts/train_model.py \ 3 | --feature_dim 3,64,64 \ 4 | --checkpoint_every 1000 \ 5 | --record_loss_every 10 \ 6 | --num_val_samples 1000 \ 7 | \ 8 | --model_type RelNet \ 9 | --num_iterations 200000 \ 10 | \ 11 | --optimizer Adam \ 12 | --learning_rate 1e-4 \ 13 | --batch_size 64 \ 14 | \ 15 | --rnn_num_layers 1 \ 16 | --rnn_wordvec_dim 64 \ 17 | --rnn_hidden_dim 128 \ 18 | \ 19 | --module_stem_num_layers 8 \ 20 | --module_stem_subsample_layers 1,3,5 \ 21 | --module_stem_batchnorm 1 \ 22 | --stem_dim 64 \ 23 | \ 24 | --module_dim 256 \ 25 | --module_num_layers 4 \ 26 | --module_dropout 0 \ 27 | \ 28 | --classifier_fc_dims 1024 \ 29 | --classifier_dropout 0 \ 30 | --classifier_batchnorm 0 \ 31 | --classifier_downsample none \ 32 | --classifier_proj_dim 0 \ 33 | $@ 34 | 35 | #--module_stem_num_layers 4 \ 36 | #--module_stem_stride 2 \ 37 | 38 | #RelNet CLEVR 39 | #--learning_rate 2.5e-4 \ 40 | #--module_stem_num_layers 4 \ 41 | #--module_stem_batchnorm 1 \ 42 | #--module_stem_kernel_size 3 \ 43 | #--module_stem_stride 2 \ 44 | #--module_batchnorm 0 \ 45 | #--module_dim 256 \ 46 | #--module_num_layers 4 \ 47 | #--module_dropout 0,0.5,0 \ 48 | #--classifier_fc_dims 256,256,29 \ 49 | #--rnn_hidden_dim 128 `#was 4096 in original FiLM` \ 50 | #--rnn_wordvec_dim 32 \ 51 | 52 | #RelNet SortOf-CLEVR 53 | #--learning_rate 1e-4 \ 54 | #--module_stem_num_layers 4 \ 55 | #--module_stem_batchnorm 1 \ 56 | #--module_stem_kernel_size 3 \ 57 | #--module_stem_stride 2 \ 58 | #--module_batchnorm 0 \ 59 | #--module_dim 2000 \ 60 | #--module_num_layers 4 \ 61 | #--module_dropout 0 \ 62 | #--classifier_fc_dims 2000,1000,500,100 \ 63 | #--classifier_dropout 0 \ 64 | -------------------------------------------------------------------------------- /scripts/train/shnmn_flatqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python $NMN/scripts/train_model.py \ 4 | --model_type SHNMN \ 5 | --feature_dim=3,64,64 \ 6 | --num_iterations=50000 \ 7 | --checkpoint_every 1000 \ 8 | --record_loss_every 10 \ 9 | --num_val_samples 1000 \ 10 | --optimizer Adam \ 11 | --learning_rate 1e-4 \ 12 | --use_coords 1 \ 13 | --module_stem_batchnorm 1 \ 14 | --module_stem_num_layers 6 \ 15 | --module_stem_subsample_layers 1,3 \ 16 | --module_intermediate_batchnorm 0 \ 17 | --module_batchnorm 0 \ 18 | --module_dim 64 \ 19 | --classifier_batchnorm 1 \ 20 | --classifier_downsample maxpoolfull \ 21 | --classifier_proj_dim 512 \ 22 | --program_generator_parameter_efficient 1 $@ 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="nmn-iwp", 5 | version="0.1", 6 | keywords="", 7 | packages=["vr", "vr.models"] 8 | ) 9 | -------------------------------------------------------------------------------- /vr/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-present, Mila 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /vr/data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019-present, Mila 4 | # Copyright 2017-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | 10 | import numpy as np 11 | import PIL.Image 12 | import h5py 13 | import io 14 | import torch 15 | from torch.utils.data import Dataset, DataLoader 16 | from torch.utils.data.dataloader import default_collate 17 | import random, math 18 | import vr.programs 19 | from vr.programs import ProgramConverter 20 | 21 | 22 | def _dataset_to_tensor(dset, mask=None, dtype=None): 23 | arr = np.asarray(dset, dtype=np.int64 if dtype is None else dtype) 24 | if mask is not None: 25 | arr = arr[mask] 26 | tensor = torch.LongTensor(arr) 27 | return tensor 28 | 29 | def _gen_subsample_mask(num, percent=1.0): 30 | chosen_num = math.floor(num * percent) 31 | mask = np.full((num,), False) 32 | selected_ids = np.asarray(random.sample(range(num), chosen_num), dtype='int32') 33 | mask[selected_ids] = True 34 | return mask 35 | 36 | 37 | class ClevrDataset(Dataset): 38 | def __init__(self, question_h5, feature_h5_path, vocab, mode='prefix', 39 | image_h5=None, load_features=False, max_samples=None, question_families=None, 40 | image_idx_start_from=None, percent_of_data=1.0): 41 | mode_choices = ['prefix', 'postfix'] 42 | if mode not in mode_choices: 43 | raise ValueError('Invalid mode "%s"' % mode) 44 | self.image_h5 = image_h5 45 | self.vocab = vocab 46 | self.program_converter = ProgramConverter(vocab) 47 | self.feature_h5_path = feature_h5_path 48 | self.feature_h5 = None 49 | self.all_features = None 50 | self.load_features = load_features 51 | self.mode = mode 52 | self.max_samples = max_samples 53 | 54 | # Compute the mask 55 | mask = None 56 | if question_families is not None: 57 | # Use only the specified families 58 | all_families = np.asarray(question_h5['question_families']) 59 | N = all_families.shape[0] 60 | print(question_families) 61 | target_families = np.asarray(question_families)[:, None] 62 | mask = (all_families == target_families).any(axis=0) 63 | if image_idx_start_from is not None: 64 | all_image_idxs = np.asarray(question_h5['image_idxs']) 65 | mask = all_image_idxs >= image_idx_start_from 66 | if percent_of_data < 1.0: 67 | num_example = np.asarray(question_h5['image_idxs']).shape[0] 68 | mask = _gen_subsample_mask(num_example, percent_of_data) 69 | self.mask = mask 70 | 71 | # Data from the question file is small, so read it all into memory 72 | print('Reading question data into memory') 73 | self.all_types = None 74 | if 'types' in question_h5: 75 | self.all_types = _dataset_to_tensor(question_h5['types'], mask) 76 | self.all_question_families = None 77 | if 'question_families' in question_h5: 78 | self.all_question_families = _dataset_to_tensor(question_h5['question_families'], mask) 79 | self.all_questions = _dataset_to_tensor(question_h5['questions'], mask) 80 | self.all_image_idxs = _dataset_to_tensor(question_h5['image_idxs'], mask) 81 | self.all_programs = None 82 | if 'programs' in question_h5: 83 | self.all_programs = _dataset_to_tensor(question_h5['programs'], mask) 84 | self.all_answers = None 85 
| if 'answers' in question_h5: 86 | self.all_answers = _dataset_to_tensor(question_h5['answers'], mask) 87 | 88 | def __getitem__(self, index): 89 | # Open the feature or load them if requested 90 | if not self.feature_h5: 91 | self.feature_h5 = h5py.File(self.feature_h5_path, 'r') 92 | if self.load_features: 93 | self.features = self.feature_h5['features'].value 94 | 95 | if self.all_question_families is not None: 96 | question_family = self.all_question_families[index] 97 | q_type = None if self.all_types is None else self.all_types[index] 98 | question = self.all_questions[index] 99 | image_idx = self.all_image_idxs[index] 100 | answer = None 101 | if self.all_answers is not None: 102 | answer = self.all_answers[index] 103 | program_seq = None 104 | if self.all_programs is not None: 105 | program_seq = self.all_programs[index] 106 | 107 | image = None 108 | if self.image_h5 is not None: 109 | image = self.image_h5['images'][image_idx] 110 | image = torch.FloatTensor(np.asarray(image, dtype=np.float32)) 111 | 112 | if self.load_features: 113 | feats = self.features[image_idx] 114 | else: 115 | feats = self.feature_h5['features'][image_idx] 116 | if feats.ndim == 1: 117 | feats = np.array(PIL.Image.open(io.BytesIO(feats))).transpose(2, 0, 1) / 255.0 118 | feats = torch.FloatTensor(np.asarray(feats, dtype=np.float32)) 119 | 120 | program_json = None 121 | if program_seq is not None: 122 | program_json_seq = [] 123 | for fn_idx in program_seq: 124 | fn_str = self.vocab['program_idx_to_token'][fn_idx.item()] 125 | if fn_str == '' or fn_str == '': 126 | continue 127 | fn = vr.programs.str_to_function(fn_str) 128 | program_json_seq.append(fn) 129 | if self.mode == 'prefix': 130 | program_json = self.program_converter.prefix_to_list(program_json_seq) 131 | elif self.mode == 'postfix': 132 | program_json = self.program_converter.to_list(program_json_seq) 133 | 134 | if q_type is None: 135 | return (question, image, feats, answer, program_seq, program_json) 136 | return ([question, q_type], image, feats, answer, program_seq, program_json) 137 | 138 | def __len__(self): 139 | if self.max_samples is None: 140 | return self.all_questions.size(0) 141 | else: 142 | return min(self.max_samples, self.all_questions.size(0)) 143 | 144 | 145 | class ClevrDataLoader(DataLoader): 146 | def __init__(self, **kwargs): 147 | if 'question_h5' not in kwargs: 148 | raise ValueError('Must give question_h5') 149 | if 'feature_h5' not in kwargs: 150 | raise ValueError('Must give feature_h5') 151 | if 'vocab' not in kwargs: 152 | raise ValueError('Must give vocab') 153 | 154 | feature_h5_path = kwargs.pop('feature_h5') 155 | print('Reading features from ', feature_h5_path) 156 | 157 | self.image_h5 = None 158 | if 'image_h5' in kwargs: 159 | image_h5_path = kwargs.pop('image_h5') 160 | print('Reading images from ', image_h5_path) 161 | self.image_h5 = h5py.File(image_h5_path, 'r') 162 | 163 | vocab = kwargs.pop('vocab') 164 | mode = kwargs.pop('mode', 'prefix') 165 | load_features = kwargs.pop('load_features', False) 166 | percent_of_data = kwargs.pop('percent_of_data', 1.) 
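    # Options popped here configure the ClevrDataset; anything left in kwargs
    # (e.g. batch_size) is forwarded to the underlying torch DataLoader below.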
167 | question_families = kwargs.pop('question_families', None) 168 | max_samples = kwargs.pop('max_samples', None) 169 | question_h5_path = kwargs.pop('question_h5') 170 | image_idx_start_from = kwargs.pop('image_idx_start_from', None) 171 | print('Reading questions from ', question_h5_path) 172 | with h5py.File(question_h5_path, 'r') as question_h5: 173 | self.dataset = ClevrDataset(question_h5, feature_h5_path, vocab, mode, 174 | image_h5=self.image_h5, 175 | load_features=load_features, 176 | max_samples=max_samples, 177 | question_families=question_families, 178 | image_idx_start_from=image_idx_start_from, 179 | percent_of_data=percent_of_data) 180 | kwargs['collate_fn'] = clevr_collate 181 | super(ClevrDataLoader, self).__init__(self.dataset, **kwargs) 182 | 183 | def close(self): 184 | if self.image_h5 is not None: 185 | self.image_h5.close() 186 | 187 | def __enter__(self): 188 | return self 189 | 190 | def __exit__(self, exc_type, exc_value, traceback): 191 | self.close() 192 | 193 | 194 | def clevr_collate(batch): 195 | transposed = list(zip(*batch)) 196 | question_batch = default_collate(transposed[0]) 197 | 198 | image_batch = transposed[1] 199 | if all(img is not None for img in image_batch): 200 | image_batch = default_collate(image_batch) 201 | 202 | feat_batch = transposed[2] 203 | if all(f is not None for f in feat_batch): 204 | feat_batch = default_collate(feat_batch) 205 | 206 | answer_batch = transposed[3] 207 | if transposed[3][0] is not None: 208 | answer_batch = default_collate(answer_batch) 209 | 210 | program_seq_batch = transposed[4] 211 | if transposed[4][0] is not None: 212 | program_seq_batch = default_collate(program_seq_batch) 213 | 214 | program_struct_batch = transposed[5] 215 | 216 | return [question_batch, image_batch, feat_batch, answer_batch, 217 | program_seq_batch, program_struct_batch] 218 | -------------------------------------------------------------------------------- /vr/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-present, Mila 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Utilities for dealing with embeddings. 
10 | """ 11 | 12 | 13 | def convert_pretrained_wordvecs(vocab, word2vec): 14 | N = len(vocab['question_idx_to_token']) 15 | D = word2vec['vecs'].size(1) 16 | embed = torch.nn.Embedding(N, D) 17 | print(type(embed.weight)) 18 | word2vec_word_to_idx = {w: i for i, w in enumerate(word2vec['words'])} 19 | print(type(word2vec['vecs'])) 20 | for idx, word in vocab['question_idx_to_token'].items(): 21 | word2vec_idx = word2vec_word_to_idx.get(word, None) 22 | if word2vec_idx is not None: 23 | embed.weight.data[idx] = word2vec['vecs'][word2vec_idx] 24 | return embed 25 | 26 | 27 | def expand_embedding_vocab(embed, token_to_idx, word2vec=None, std=0.01): 28 | old_weight = embed.weight.data 29 | old_N, D = old_weight.size() 30 | new_N = 1 + max(idx for idx in token_to_idx.values()) 31 | new_weight = old_weight.new(new_N, D).normal_().mul_(std) 32 | new_weight[:old_N].copy_(old_weight) 33 | 34 | if word2vec is not None: 35 | num_found = 0 36 | assert D == word2vec['vecs'].size(1), 'Word vector dimension mismatch' 37 | word2vec_token_to_idx = {w: i for i, w in enumerate(word2vec['words'])} 38 | for token, idx in token_to_idx.items(): 39 | word2vec_idx = word2vec_token_to_idx.get(token, None) 40 | if idx >= old_N and word2vec_idx is not None: 41 | vec = word2vec['vecs'][word2vec_idx] 42 | new_weight[idx].copy_(vec) 43 | num_found += 1 44 | embed.num_embeddings = new_N 45 | embed.weight.data = new_weight 46 | return embed 47 | -------------------------------------------------------------------------------- /vr/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019-present, Mila 4 | # Copyright 2017-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | 10 | from vr.models.module_net import ModuleNet 11 | from vr.models.simple_module_net import SimpleModuleNet, forward_chain1, forward_chain2, forward_chain3 12 | from vr.models.shnmn import SHNMN 13 | from vr.models.hetero_net import HeteroModuleNet 14 | from vr.models.filmed_net import FiLMedNet 15 | from vr.models.seq2seq import Seq2Seq 16 | from vr.models.seq2seq_att import Seq2SeqAtt 17 | from vr.models.film_gen import FiLMGen 18 | from vr.models.maced_net import MAC 19 | from vr.models.baselines import LstmModel, CnnLstmModel, CnnLstmSaModel 20 | from vr.models.relation_net import RelationNet 21 | from vr.models.convlstm import ConvLSTM 22 | -------------------------------------------------------------------------------- /vr/models/baselines.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019-present, Mila 4 | # Copyright 2017-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 
9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.autograd import Variable 14 | 15 | from vr.models.layers import init_modules, ResidualBlock 16 | from vr.embedding import expand_embedding_vocab 17 | 18 | 19 | class StackedAttention(nn.Module): 20 | def __init__(self, input_dim, hidden_dim): 21 | super(StackedAttention, self).__init__() 22 | self.Wv = nn.Conv2d(input_dim, hidden_dim, kernel_size=1, padding=0) 23 | self.Wu = nn.Linear(input_dim, hidden_dim) 24 | self.Wp = nn.Conv2d(hidden_dim, 1, kernel_size=1, padding=0) 25 | self.hidden_dim = hidden_dim 26 | self.attention_maps = None 27 | init_modules(self.modules(), init='normal') 28 | 29 | def forward(self, v, u): 30 | """ 31 | Input: 32 | - v: N x D x H x W 33 | - u: N x D 34 | 35 | Returns: 36 | - next_u: N x D 37 | """ 38 | N, K = v.size(0), self.hidden_dim 39 | D, H, W = v.size(1), v.size(2), v.size(3) 40 | v_proj = self.Wv(v) # N x K x H x W 41 | u_proj = self.Wu(u) # N x K 42 | u_proj_expand = u_proj.view(N, K, 1, 1).expand(N, K, H, W) 43 | h = F.tanh(v_proj + u_proj_expand) 44 | p = F.softmax(self.Wp(h).view(N, H * W)).view(N, 1, H, W) 45 | self.attention_maps = p.data.clone() 46 | 47 | v_tilde = (p.expand_as(v) * v).sum(3).sum(2).view(N, D) 48 | next_u = u + v_tilde 49 | return next_u 50 | 51 | 52 | class LstmEncoder(nn.Module): 53 | def __init__(self, token_to_idx, wordvec_dim=300, 54 | rnn_dim=256, rnn_num_layers=2, rnn_dropout=0): 55 | super(LstmEncoder, self).__init__() 56 | self.token_to_idx = token_to_idx 57 | self.NULL = token_to_idx[''] 58 | self.START = token_to_idx[''] 59 | self.END = token_to_idx[''] 60 | 61 | self.embed = nn.Embedding(len(token_to_idx), wordvec_dim) 62 | self.rnn = nn.LSTM(wordvec_dim, rnn_dim, rnn_num_layers, 63 | dropout=rnn_dropout, batch_first=True) 64 | 65 | def expand_vocab(self, token_to_idx, word2vec=None, std=0.01): 66 | expand_embedding_vocab(self.embed, token_to_idx, 67 | word2vec=word2vec, std=std) 68 | 69 | def forward(self, x): 70 | N, T = x.size() 71 | idx = torch.LongTensor(N).fill_(T - 1) 72 | 73 | # Find the last non-null element in each sequence 74 | x_cpu = x.data.cpu() 75 | for i in range(N): 76 | for t in range(T - 1): 77 | if x_cpu[i, t] != self.NULL and x_cpu[i, t + 1] == self.NULL: 78 | idx[i] = t 79 | break 80 | idx = idx.type_as(x.data).long() 81 | idx = Variable(idx, requires_grad=False) 82 | 83 | hs, _ = self.rnn(self.embed(x)) 84 | idx = idx.view(N, 1, 1).expand(N, 1, hs.size(2)) 85 | H = hs.size(2) 86 | return hs.gather(1, idx).view(N, H) 87 | 88 | 89 | def build_cnn(feat_dim=(1024, 14, 14), 90 | res_block_dim=128, 91 | num_res_blocks=0, 92 | proj_dim=512, 93 | pooling='maxpool2'): 94 | C, H, W = feat_dim 95 | layers = [] 96 | if num_res_blocks > 0: 97 | layers.append(nn.Conv2d(C, res_block_dim, kernel_size=3, padding=1)) 98 | layers.append(nn.ReLU(inplace=True)) 99 | C = res_block_dim 100 | for _ in range(num_res_blocks): 101 | layers.append(ResidualBlock(C)) 102 | if proj_dim > 0: 103 | layers.append(nn.Conv2d(C, proj_dim, kernel_size=1, padding=0)) 104 | layers.append(nn.ReLU(inplace=True)) 105 | C = proj_dim 106 | if pooling == 'maxpool2': 107 | layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) 108 | H, W = H // 2, W // 2 109 | return nn.Sequential(*layers), (C, H, W) 110 | 111 | 112 | def build_mlp(input_dim, hidden_dims, output_dim, 113 | use_batchnorm=False, dropout=0): 114 | layers = [] 115 | D = input_dim 116 | if dropout > 0: 117 | layers.append(nn.Dropout(p=dropout)) 118 | if use_batchnorm: 119 | 
layers.append(nn.BatchNorm1d(input_dim)) 120 | for dim in hidden_dims: 121 | layers.append(nn.Linear(D, dim)) 122 | if use_batchnorm: 123 | layers.append(nn.BatchNorm1d(dim)) 124 | if dropout > 0: 125 | layers.append(nn.Dropout(p=dropout)) 126 | layers.append(nn.ReLU(inplace=True)) 127 | D = dim 128 | layers.append(nn.Linear(D, output_dim)) 129 | return nn.Sequential(*layers) 130 | 131 | 132 | class LstmModel(nn.Module): 133 | def __init__(self, vocab, 134 | rnn_wordvec_dim=300, rnn_dim=256, rnn_num_layers=2, rnn_dropout=0, 135 | fc_use_batchnorm=False, fc_dropout=0, fc_dims=(1024,)): 136 | super(LstmModel, self).__init__() 137 | rnn_kwargs = { 138 | 'token_to_idx': vocab['question_token_to_idx'], 139 | 'wordvec_dim': rnn_wordvec_dim, 140 | 'rnn_dim': rnn_dim, 141 | 'rnn_num_layers': rnn_num_layers, 142 | 'rnn_dropout': rnn_dropout, 143 | } 144 | self.rnn = LstmEncoder(**rnn_kwargs) 145 | 146 | classifier_kwargs = { 147 | 'input_dim': rnn_dim, 148 | 'hidden_dims': fc_dims, 149 | 'output_dim': len(vocab['answer_token_to_idx']), 150 | 'use_batchnorm': fc_use_batchnorm, 151 | 'dropout': fc_dropout, 152 | } 153 | self.classifier = build_mlp(**classifier_kwargs) 154 | 155 | def forward(self, questions, feats): 156 | q_feats = self.rnn(questions) 157 | scores = self.classifier(q_feats) 158 | return scores 159 | 160 | 161 | class CnnLstmModel(nn.Module): 162 | def __init__(self, vocab, 163 | rnn_wordvec_dim=300, rnn_dim=256, rnn_num_layers=2, rnn_dropout=0, 164 | cnn_feat_dim=(1024,14,14), 165 | cnn_res_block_dim=128, cnn_num_res_blocks=0, 166 | cnn_proj_dim=512, cnn_pooling='maxpool2', 167 | fc_dims=(1024,), fc_use_batchnorm=False, fc_dropout=0): 168 | super(CnnLstmModel, self).__init__() 169 | rnn_kwargs = { 170 | 'token_to_idx': vocab['question_token_to_idx'], 171 | 'wordvec_dim': rnn_wordvec_dim, 172 | 'rnn_dim': rnn_dim, 173 | 'rnn_num_layers': rnn_num_layers, 174 | 'rnn_dropout': rnn_dropout, 175 | } 176 | self.rnn = LstmEncoder(**rnn_kwargs) 177 | 178 | cnn_kwargs = { 179 | 'feat_dim': cnn_feat_dim, 180 | 'res_block_dim': cnn_res_block_dim, 181 | 'num_res_blocks': cnn_num_res_blocks, 182 | 'proj_dim': cnn_proj_dim, 183 | 'pooling': cnn_pooling, 184 | } 185 | self.cnn, (C, H, W) = build_cnn(**cnn_kwargs) 186 | 187 | classifier_kwargs = { 188 | 'input_dim': C * H * W + rnn_dim, 189 | 'hidden_dims': fc_dims, 190 | 'output_dim': len(vocab['answer_token_to_idx']), 191 | 'use_batchnorm': fc_use_batchnorm, 192 | 'dropout': fc_dropout, 193 | } 194 | self.classifier = build_mlp(**classifier_kwargs) 195 | 196 | def forward(self, questions, feats): 197 | N = questions.size(0) 198 | assert N == feats.size(0) 199 | q_feats = self.rnn(questions) 200 | img_feats = self.cnn(feats) 201 | cat_feats = torch.cat([q_feats, img_feats.view(N, -1)], 1) 202 | scores = self.classifier(cat_feats) 203 | return scores 204 | 205 | 206 | class CnnLstmSaModel(nn.Module): 207 | def __init__(self, vocab, 208 | rnn_wordvec_dim=300, rnn_dim=256, rnn_num_layers=2, rnn_dropout=0, 209 | cnn_feat_dim=(1024,14,14), 210 | stacked_attn_dim=512, num_stacked_attn=2, 211 | fc_use_batchnorm=False, fc_dropout=0, fc_dims=(1024,)): 212 | super(CnnLstmSaModel, self).__init__() 213 | rnn_kwargs = { 214 | 'token_to_idx': vocab['question_token_to_idx'], 215 | 'wordvec_dim': rnn_wordvec_dim, 216 | 'rnn_dim': rnn_dim, 217 | 'rnn_num_layers': rnn_num_layers, 218 | 'rnn_dropout': rnn_dropout, 219 | } 220 | self.rnn = LstmEncoder(**rnn_kwargs) 221 | 222 | C, H, W = cnn_feat_dim 223 | self.image_proj = nn.Conv2d(C, rnn_dim, kernel_size=1, 
padding=0) 224 | self.stacked_attns = [] 225 | for i in range(num_stacked_attn): 226 | sa = StackedAttention(rnn_dim, stacked_attn_dim) 227 | self.stacked_attns.append(sa) 228 | self.add_module('stacked-attn-%d' % i, sa) 229 | 230 | classifier_args = { 231 | 'input_dim': rnn_dim, 232 | 'hidden_dims': fc_dims, 233 | 'output_dim': len(vocab['answer_token_to_idx']), 234 | 'use_batchnorm': fc_use_batchnorm, 235 | 'dropout': fc_dropout, 236 | } 237 | self.classifier = build_mlp(**classifier_args) 238 | init_modules(self.modules(), init='normal') 239 | 240 | def forward(self, questions, feats): 241 | u = self.rnn(questions) 242 | v = self.image_proj(feats) 243 | 244 | for sa in self.stacked_attns: 245 | u = sa(v, u) 246 | 247 | scores = self.classifier(u) 248 | return scores 249 | -------------------------------------------------------------------------------- /vr/models/convlstm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | from vr.models.layers import (build_classifier, 8 | build_stem, 9 | init_modules) 10 | 11 | 12 | class ConvLSTM(nn.Module): 13 | def __init__(self, 14 | vocab, 15 | feature_dim=[3, 64, 64], 16 | stem_dim=128, 17 | module_dim=128, 18 | stem_num_layers=2, 19 | stem_batchnorm=True, 20 | stem_kernel_size=3, 21 | stem_stride=1, 22 | stem_padding=None, 23 | stem_feature_dim=24, 24 | stem_subsample_layers=None, 25 | classifier_fc_layers=(1024,), 26 | classifier_batchnorm=False, 27 | classifier_dropout=0, 28 | rnn_hidden_dim=128, 29 | **kwargs): 30 | super().__init__() 31 | 32 | # initialize stem 33 | self.stem = build_stem(feature_dim[0], 34 | stem_dim, 35 | module_dim, 36 | num_layers=stem_num_layers, 37 | with_batchnorm=stem_batchnorm, 38 | kernel_size=stem_kernel_size, 39 | stride=stem_stride, 40 | padding=stem_padding, 41 | subsample_layers=stem_subsample_layers) 42 | tmp = self.stem(Variable(torch.zeros([1] + feature_dim))) 43 | _, F, H, W = tmp.size() 44 | 45 | # initialize classifier 46 | # TODO(mnoukhov): fix this for >1 layer RNN 47 | question_dim = rnn_hidden_dim 48 | image_dim = F*H*W 49 | num_answers = len(vocab['answer_idx_to_token']) 50 | self.classifier = build_classifier(image_dim + question_dim, 51 | 1, 52 | 1, 53 | num_answers, 54 | classifier_fc_layers, 55 | None, 56 | None, 57 | classifier_batchnorm, 58 | classifier_dropout) 59 | 60 | init_modules(self.modules()) 61 | 62 | def forward(self, image, question): 63 | # convert image to features 64 | img_feats = self.stem(image) # N x F x H x W 65 | img_feats = img_feats.view(img_feats.size(0), -1) # N x F*H*W 66 | 67 | # get hidden state from question 68 | _, q_feats, _ = question # N x Q 69 | 70 | # concatenate feats 71 | feats = torch.cat([img_feats, q_feats], dim=1) # N x F*H*W+Q 72 | 73 | # pass through classifier 74 | out = self.classifier(feats) 75 | 76 | return out 77 | -------------------------------------------------------------------------------- /vr/models/film_gen.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | import torch.cuda 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 9 | 10 | from vr.embedding import expand_embedding_vocab 11 | from vr.models.layers import init_modules 12 | from torch.nn.init import uniform_, 
xavier_uniform_, constant_ 13 | 14 | 15 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | class FiLMGen(nn.Module): 19 | def __init__(self, 20 | null_token=0, 21 | start_token=1, 22 | end_token=2, 23 | encoder_embed=None, 24 | encoder_vocab_size=100, 25 | decoder_vocab_size=100, 26 | wordvec_dim=200, 27 | hidden_dim=512, 28 | rnn_num_layers=1, 29 | rnn_dropout=0, 30 | output_batchnorm=False, 31 | bidirectional=False, 32 | encoder_type='gru', 33 | decoder_type='linear', 34 | gamma_option='linear', 35 | gamma_baseline=1, 36 | num_modules=4, 37 | module_num_layers=1, 38 | module_dim=128, 39 | parameter_efficient=False, 40 | debug_every=float('inf'), 41 | 42 | taking_context=False, 43 | variational_embedding_dropout=0., 44 | embedding_uniform_boundary=0., 45 | 46 | use_attention=False, 47 | ): 48 | super(FiLMGen, self).__init__() 49 | 50 | self.use_attention = use_attention 51 | 52 | self.taking_context = taking_context 53 | if self.use_attention: 54 | #if we want to use attention, the full context should be computed 55 | self.taking_context = True 56 | if self.taking_context: 57 | #if we want to use the full context, it makes sense to use bidirectional modeling. 58 | bidirectional = True 59 | 60 | self.encoder_type = encoder_type 61 | self.decoder_type = decoder_type 62 | self.output_batchnorm = output_batchnorm 63 | self.bidirectional = bidirectional 64 | self.num_dir = 2 if self.bidirectional else 1 65 | self.gamma_option = gamma_option 66 | self.gamma_baseline = gamma_baseline 67 | self.num_modules = num_modules 68 | self.module_num_layers = module_num_layers 69 | self.module_dim = module_dim 70 | self.debug_every = debug_every 71 | self.NULL = null_token 72 | self.START = start_token 73 | self.END = end_token 74 | 75 | self.variational_embedding_dropout = variational_embedding_dropout 76 | 77 | if self.bidirectional: # and not self.taking_context: 78 | if decoder_type != 'linear': 79 | raise(NotImplementedError) 80 | hidden_dim = (int) (hidden_dim / self.num_dir) 81 | 82 | self.func_list = { 83 | 'linear': None, 84 | 'sigmoid': F.sigmoid, 85 | 'tanh': F.tanh, 86 | 'exp': torch.exp, 87 | } 88 | 89 | self.cond_feat_size = 2 * self.module_dim * self.module_num_layers # FiLM params per ResBlock 90 | if not parameter_efficient: # parameter_efficient=False only used to load older trained models 91 | self.cond_feat_size = 4 * self.module_dim + 2 * self.num_modules 92 | 93 | self.encoder_embed = nn.Embedding(encoder_vocab_size, wordvec_dim) 94 | self.encoder_rnn = init_rnn(self.encoder_type, wordvec_dim, hidden_dim, rnn_num_layers, 95 | dropout=rnn_dropout, bidirectional=self.bidirectional) 96 | self.decoder_rnn = init_rnn(self.decoder_type, hidden_dim, hidden_dim, rnn_num_layers, 97 | dropout=rnn_dropout, bidirectional=self.bidirectional) 98 | 99 | if self.taking_context: 100 | self.decoder_linear = None #nn.Linear(2 * hidden_dim, hidden_dim) 101 | for n, p in self.encoder_rnn.named_parameters(): 102 | if n.startswith('weight'): xavier_uniform_(p) 103 | elif n.startswith('bias'): constant_(p, 0.) 
104 | else: 105 | self.decoder_linear = nn.Linear(hidden_dim * self.num_dir, self.num_modules * self.cond_feat_size) 106 | 107 | if self.use_attention: 108 | # Florian Strub used Tanh here, but let's use identity to make this model 109 | # closer to the baseline film version 110 | #Need to change this if we want a different mechanism to compute attention weights 111 | attention_dim = self.module_dim 112 | self.context2key = nn.Linear(hidden_dim * self.num_dir, self.module_dim) 113 | # to transform control vector to film coefficients 114 | self.last_vector2key = [] 115 | self.decoders_att = [] 116 | for i in range(num_modules): 117 | mod = nn.Linear(hidden_dim * self.num_dir, attention_dim) 118 | self.add_module("last_vector2key{}".format(i), mod) 119 | self.last_vector2key.append(mod) 120 | mod = nn.Linear(hidden_dim * self.num_dir, 2*self.module_dim) 121 | self.add_module("decoders_att{}".format(i), mod) 122 | self.decoders_att.append(mod) 123 | 124 | if self.output_batchnorm: 125 | self.output_bn = nn.BatchNorm1d(self.cond_feat_size, affine=True) 126 | 127 | init_modules(self.modules()) 128 | if embedding_uniform_boundary > 0.: 129 | uniform_(self.encoder_embed.weight, -1.*embedding_uniform_boundary, embedding_uniform_boundary) 130 | 131 | # The attention scores will be saved here if the attention is used. 132 | self.scores = None 133 | 134 | def expand_encoder_vocab(self, token_to_idx, word2vec=None, std=0.01): 135 | expand_embedding_vocab(self.encoder_embed, token_to_idx, 136 | word2vec=word2vec, std=std) 137 | 138 | def get_dims(self, x=None): 139 | V_in = self.encoder_embed.num_embeddings 140 | V_out = self.cond_feat_size 141 | D = self.encoder_embed.embedding_dim 142 | H = self.encoder_rnn.hidden_size 143 | H_full = self.encoder_rnn.hidden_size * self.num_dir 144 | L = self.encoder_rnn.num_layers * self.num_dir 145 | 146 | N = x.size(0) if x is not None else None 147 | T_in = x.size(1) if x is not None else None 148 | T_out = self.num_modules 149 | return V_in, V_out, D, H, H_full, L, N, T_in, T_out 150 | 151 | def before_rnn(self, x, replace=0): 152 | N, T = x.size() 153 | idx = torch.LongTensor(N).fill_(T - 1) 154 | 155 | #mask to specify non-null tokens 156 | mask = torch.FloatTensor(N, T).zero_() 157 | 158 | # Find the last non-null element in each sequence. 159 | x_cpu = x.cpu() 160 | for i in range(N): 161 | for t in range(T - 1): 162 | if x_cpu.data[i, t] != self.NULL and x_cpu.data[i, t + 1] == self.NULL: 163 | idx[i] = t 164 | break 165 | 166 | for i in range(N): 167 | for t in range(T): 168 | if x_cpu.data[i, t] not in [self.NULL]: 169 | mask[i, t] = 1. 
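`before_rnn` above scans every padded question for the last non-`NULL` position and builds a binary mask over real tokens; that mask is later used to zero out attention scores on padding. A vectorised equivalent of the same bookkeeping (illustrative only; the module itself keeps the explicit loops):

```
import torch

NULL = 0
x = torch.tensor([[1, 5, 7, 0, 0],
                  [1, 2, 0, 0, 0]])   # padded token ids, NULL = 0

mask = (x != NULL).float()            # 1 for real tokens, 0 for padding
idx = mask.sum(dim=1).long() - 1      # index of the last non-NULL token
assert idx.tolist() == [2, 1]
```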
170 | 171 | idx = idx.type_as(x.data) 172 | x[x.data == self.NULL] = replace 173 | return x, idx, mask.to(device) 174 | 175 | def encoder(self, x, isTest=False): 176 | V_in, V_out, D, H, H_full, L, N, T_in, T_out = self.get_dims(x=x) 177 | x, idx, mask = self.before_rnn(x) # Tokenized word sequences (questions), end index 178 | 179 | if self.taking_context: 180 | lengths = torch.LongTensor(idx.shape).fill_(1) + idx.data.cpu() 181 | lengths = lengths.to(device) 182 | seq_lengths, perm_idx = lengths.sort(0, descending=True) 183 | iperm_idx = torch.LongTensor(perm_idx.shape).fill_(0).to(device) 184 | for i, v in enumerate(perm_idx): 185 | iperm_idx[v.data] = i 186 | x = x[perm_idx] 187 | 188 | embed = self.encoder_embed(x) 189 | 190 | h0 = Variable(torch.zeros(L, N, H).type_as(embed.data)) 191 | if self.encoder_type == 'lstm': 192 | c0 = Variable(torch.zeros(L, N, H).type_as(embed.data)) 193 | 194 | if self.variational_embedding_dropout > 0. and not isTest: 195 | varDrop = torch.Tensor(N, D).fill_(self.variational_embedding_dropout).bernoulli_().to(device) 196 | embed = (embed / (1. - self.variational_embedding_dropout)) * varDrop.unsqueeze(1) 197 | 198 | if self.taking_context: 199 | embed = pack_padded_sequence(embed, seq_lengths.data.cpu().numpy(), batch_first=True) 200 | 201 | if self.encoder_type == 'lstm': 202 | out, (hn, _) = self.encoder_rnn(embed, (h0, c0)) 203 | elif self.encoder_type == 'gru': 204 | out, hn = self.encoder_rnn(embed, h0) 205 | 206 | hn = hn.transpose(1,0).contiguous() 207 | hn = hn.view(hn.shape[0], -1) 208 | 209 | # Pull out the hidden state for the last non-null value in each input 210 | if self.taking_context: 211 | idx_out = None 212 | out, _ = pad_packed_sequence(out, batch_first=True) 213 | out = out[iperm_idx] 214 | 215 | if out.shape[1] < T_in: 216 | mask = mask[:, :(out.shape[1]-T_in)] #The packing truncate the original length so we need to change mask to fit it 217 | 218 | hn = hn[iperm_idx] 219 | else: 220 | idx = idx.view(N, 1, 1).expand(N, 1, H_full) 221 | idx_out = out.gather(1, idx).view(N, H_full) 222 | 223 | out = None 224 | hn = None 225 | 226 | 227 | return idx_out, out, hn, mask 228 | 229 | def decoder(self, encoded, dims, h0=None, c0=None): 230 | 231 | #if self.taking_context: 232 | # return self.decoder_linear(encoded) 233 | 234 | V_in, V_out, D, H, H_full, L, N, T_in, T_out = dims 235 | 236 | if self.decoder_type == 'linear': 237 | # (N x H) x (H x T_out*V_out) -> (N x T_out*V_out) -> N x T_out x V_out 238 | return self.decoder_linear(encoded).view(N, T_out, V_out), (None, None) 239 | 240 | encoded_repeat = encoded.view(N, 1, H).expand(N, T_out, H) 241 | if not h0: 242 | h0 = Variable(torch.zeros(L, N, H).type_as(encoded.data)) 243 | 244 | if self.decoder_type == 'lstm': 245 | if not c0: 246 | c0 = Variable(torch.zeros(L, N, H).type_as(encoded.data)) 247 | rnn_output, (ht, ct) = self.decoder_rnn(encoded_repeat, (h0, c0)) 248 | elif self.decoder_type == 'gru': 249 | ct = None 250 | rnn_output, ht = self.decoder_rnn(encoded_repeat, h0) 251 | 252 | rnn_output_2d = rnn_output.contiguous().view(N * T_out, H) 253 | linear_output = self.decoder_linear(rnn_output_2d) 254 | if self.output_batchnorm: 255 | linear_output = self.output_bn(linear_output) 256 | output_shaped = linear_output.view(N, T_out, V_out) 257 | return output_shaped, (ht, ct) 258 | 259 | def attention_decoder(self, context, last_vector, mask): 260 | context_keys = self.context2key(context) 261 | 262 | out = [] 263 | self.scores = [] 264 | for i in range(self.num_modules): 265 | # 
vanilla dot-product attention in the key space
266 | query = self.last_vector2key[i](last_vector)
267 | scores = (context_keys * query.unsqueeze(1)).sum(2) #NxLxd -> NxL
268 | 
269 | # softmax
270 | scores = torch.exp(scores - scores.max(1, keepdim=True)[0]) * mask #mask helps to eliminate padding words
271 | scores = scores / scores.sum(1, keepdim=True) #NxL
272 | self.scores.append(scores)
273 | 
274 | control = (context * scores.unsqueeze(2)).sum(1) #Nxd
275 | 
276 | coefficients = self.decoders_att[i](control).unsqueeze(1) #Nxd -> Nx2d -> Nx1x2d
277 | out.append(coefficients)
278 | self.scores = torch.cat([t.unsqueeze(0) for t in self.scores], 0)
279 | 
280 | if len(out) == 0: return None
281 | if len(out) == 1: return out[0]
282 | return torch.cat(out, 1) #N x num_module x 2d
283 | 
284 | def forward(self, x, isTest=False):
285 | if self.debug_every <= -2:
286 | pdb.set_trace()
287 | encoded, whole_context, last_vector, mask = self.encoder(x, isTest=isTest)
288 | 
289 | if self.taking_context and not self.use_attention:
290 | #whole_context = self.decoder(whole_context, None)
291 | return (whole_context, last_vector, mask)
292 | 
293 | if self.use_attention: #make sure taking_context is True as well if we want to use this.
294 | film_pre_mod = self.attention_decoder(whole_context, last_vector, mask)
295 | else:
296 | film_pre_mod, _ = self.decoder(encoded, self.get_dims(x=x))
297 | film = self.modify_output(film_pre_mod, gamma_option=self.gamma_option,
298 | gamma_shift=self.gamma_baseline)
299 | return film
300 | 
301 | def modify_output(self, out, gamma_option='linear', gamma_scale=1, gamma_shift=0,
302 | beta_option='linear', beta_scale=1, beta_shift=0):
303 | gamma_func = self.func_list[gamma_option]
304 | beta_func = self.func_list[beta_option]
305 | 
306 | gs = []
307 | bs = []
308 | for i in range(self.module_num_layers):
309 | gs.append(slice(i * (2 * self.module_dim), i * (2 * self.module_dim) + self.module_dim))
310 | bs.append(slice(i * (2 * self.module_dim) + self.module_dim, (i + 1) * (2 * self.module_dim)))
311 | 
312 | if gamma_func is not None:
313 | for i in range(self.module_num_layers):
314 | out[:,:,gs[i]] = gamma_func(out[:,:,gs[i]])
315 | if gamma_scale != 1:
316 | for i in range(self.module_num_layers):
317 | out[:,:,gs[i]] = out[:,:,gs[i]] * gamma_scale
318 | if gamma_shift != 0:
319 | for i in range(self.module_num_layers):
320 | out[:,:,gs[i]] = out[:,:,gs[i]] + gamma_shift
321 | if beta_func is not None:
322 | for i in range(self.module_num_layers):
323 | out[:,:,bs[i]] = beta_func(out[:,:,bs[i]])
324 | 
325 | if beta_scale != 1:
326 | for i in range(self.module_num_layers):
327 | out[:,:,bs[i]] = out[:,:,bs[i]] * beta_scale
328 | if beta_shift != 0:
329 | for i in range(self.module_num_layers):
330 | out[:,:,bs[i]] = out[:,:,bs[i]] + beta_shift
331 | return out
332 | 
333 | def init_rnn(rnn_type, hidden_dim1, hidden_dim2, rnn_num_layers,
334 | dropout=0, bidirectional=False):
335 | if rnn_type == 'gru':
336 | return nn.GRU(hidden_dim1, hidden_dim2, rnn_num_layers, dropout=dropout,
337 | batch_first=True, bidirectional=bidirectional)
338 | elif rnn_type == 'lstm':
339 | return nn.LSTM(hidden_dim1, hidden_dim2, rnn_num_layers, dropout=dropout,
340 | batch_first=True, bidirectional=bidirectional)
341 | elif rnn_type == 'linear':
342 | return None
343 | else:
344 | print('RNN type ' + str(rnn_type) + ' not yet implemented.')
345 | raise(NotImplementedError)
346 | 
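`modify_output` above assumes the decoder emits, for every module, `module_num_layers` blocks laid out as `module_dim` gammas followed by `module_dim` betas; with the default `gamma_option='linear'` it only shifts the gammas by `gamma_baseline` so that they start near 1. A toy illustration of that slicing with made-up sizes:

```
import torch

module_dim, module_num_layers, num_modules, N = 128, 1, 4, 2
cond_feat_size = 2 * module_dim * module_num_layers

# One row of FiLM parameters per module, as produced by the linear decoder.
film = torch.randn(N, num_modules, cond_feat_size)

for i in range(module_num_layers):
    gammas = film[:, :, i * 2 * module_dim : i * 2 * module_dim + module_dim]
    betas  = film[:, :, i * 2 * module_dim + module_dim : (i + 1) * 2 * module_dim]
    assert gammas.shape == (N, num_modules, module_dim)
    assert betas.shape == (N, num_modules, module_dim)
```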
-------------------------------------------------------------------------------- /vr/models/filmed_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import math 4 | import pprint 5 | from termcolor import colored 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | import torchvision.models 11 | 12 | from vr.models.layers import init_modules, GlobalAveragePool, Flatten 13 | from vr.models.layers import build_classifier, build_stem 14 | import vr.programs 15 | 16 | 17 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 18 | 19 | 20 | class FiLM(nn.Module): 21 | """ 22 | A Feature-wise Linear Modulation Layer from 23 | 'FiLM: Visual Reasoning with a General Conditioning Layer' 24 | """ 25 | def forward(self, x, gammas, betas): 26 | gammas = gammas.unsqueeze(2).unsqueeze(3).expand_as(x) 27 | betas = betas.unsqueeze(2).unsqueeze(3).expand_as(x) 28 | return (gammas * x) + betas 29 | 30 | 31 | class FiLMedNet(nn.Module): 32 | def __init__(self, vocab, feature_dim=(1024, 14, 14), 33 | stem_num_layers=2, 34 | stem_batchnorm=False, 35 | stem_kernel_size=3, 36 | stem_subsample_layers=None, 37 | stem_stride=1, 38 | stem_padding=None, 39 | stem_dim=64, 40 | num_modules=4, 41 | module_num_layers=1, 42 | module_dim=128, 43 | module_residual=True, 44 | module_intermediate_batchnorm=False, 45 | module_batchnorm=False, 46 | module_batchnorm_affine=False, 47 | module_dropout=0, 48 | module_input_proj=1, 49 | module_kernel_size=3, 50 | classifier_proj_dim=512, 51 | classifier_downsample='maxpool2', 52 | classifier_fc_layers=(1024,), 53 | classifier_batchnorm=False, 54 | classifier_dropout=0, 55 | condition_method='bn-film', 56 | condition_pattern=[], 57 | use_gamma=True, 58 | use_beta=True, 59 | use_coords=1, 60 | debug_every=float('inf'), 61 | print_verbose_every=float('inf'), 62 | verbose=True, 63 | ): 64 | super(FiLMedNet, self).__init__() 65 | 66 | num_answers = len(vocab['answer_idx_to_token']) 67 | 68 | self.stem_times = [] 69 | self.module_times = [] 70 | self.classifier_times = [] 71 | self.timing = False 72 | 73 | self.num_modules = num_modules 74 | self.module_num_layers = module_num_layers 75 | self.module_batchnorm = module_batchnorm 76 | self.module_dim = module_dim 77 | self.condition_method = condition_method 78 | self.use_gamma = use_gamma 79 | self.use_beta = use_beta 80 | self.use_coords_freq = use_coords 81 | self.debug_every = debug_every 82 | self.print_verbose_every = print_verbose_every 83 | 84 | # Initialize helper variables 85 | self.stem_use_coords = (stem_stride == 1) and (self.use_coords_freq > 0) 86 | self.condition_pattern = condition_pattern 87 | if len(condition_pattern) == 0: 88 | self.condition_pattern = [] 89 | for i in range(self.module_num_layers * self.num_modules): 90 | self.condition_pattern.append(self.condition_method != 'concat') 91 | else: 92 | self.condition_pattern = [i > 0 for i in self.condition_pattern] 93 | self.extra_channel_freq = self.use_coords_freq 94 | self.block = FiLMedResBlock 95 | self.num_cond_maps = 2 * self.module_dim if self.condition_method == 'concat' else 0 96 | self.fwd_count = 0 97 | self.num_extra_channels = 2 if self.use_coords_freq > 0 else 0 98 | if self.debug_every <= -1: 99 | self.print_verbose_every = 1 100 | 101 | # Initialize stem 102 | stem_feature_dim = feature_dim[0] + self.stem_use_coords * self.num_extra_channels 103 | self.stem = build_stem( 104 | 
stem_feature_dim, stem_dim, module_dim, 105 | num_layers=stem_num_layers, with_batchnorm=stem_batchnorm, 106 | kernel_size=stem_kernel_size, stride=stem_stride, padding=stem_padding, 107 | subsample_layers=stem_subsample_layers) 108 | tmp = self.stem(Variable(torch.zeros([1, feature_dim[0], feature_dim[1], feature_dim[2]]))) 109 | module_H = tmp.size(2) 110 | module_W = tmp.size(3) 111 | 112 | self.stem_coords = coord_map((feature_dim[1], feature_dim[2])) 113 | self.coords = coord_map((module_H, module_W)) 114 | self.default_weight = torch.ones(1, 1, self.module_dim).to(device) 115 | self.default_bias = torch.zeros(1, 1, self.module_dim).to(device) 116 | 117 | # Initialize FiLMed network body 118 | self.function_modules = {} 119 | self.vocab = vocab 120 | for fn_num in range(self.num_modules): 121 | with_cond = self.condition_pattern[self.module_num_layers * fn_num: 122 | self.module_num_layers * (fn_num + 1)] 123 | mod = self.block(module_dim, with_residual=module_residual, 124 | with_intermediate_batchnorm=module_intermediate_batchnorm, with_batchnorm=module_batchnorm, 125 | with_cond=with_cond, 126 | dropout=module_dropout, 127 | num_extra_channels=self.num_extra_channels, 128 | extra_channel_freq=self.extra_channel_freq, 129 | with_input_proj=module_input_proj, 130 | num_cond_maps=self.num_cond_maps, 131 | kernel_size=module_kernel_size, 132 | batchnorm_affine=module_batchnorm_affine, 133 | num_layers=self.module_num_layers, 134 | condition_method=condition_method, 135 | debug_every=self.debug_every) 136 | self.add_module(str(fn_num), mod) 137 | self.function_modules[fn_num] = mod 138 | 139 | # Initialize output classifier 140 | self.classifier = build_classifier(module_dim + self.num_extra_channels, module_H, module_W, 141 | num_answers, classifier_fc_layers, classifier_proj_dim, 142 | classifier_downsample, with_batchnorm=classifier_batchnorm, 143 | dropout=classifier_dropout) 144 | 145 | init_modules(self.modules()) 146 | 147 | def forward(self, x, film, save_activations=False): 148 | # Initialize forward pass and externally viewable activations 149 | self.fwd_count += 1 150 | if save_activations: 151 | self.feats = None 152 | self.module_outputs = [] 153 | self.cf_input = None 154 | 155 | if self.debug_every <= -2: 156 | pdb.set_trace() 157 | 158 | # Prepare FiLM layers 159 | gammas = None 160 | betas = None 161 | if self.condition_method == 'concat': 162 | # Use parameters usually used to condition via FiLM instead to condition via concatenation 163 | cond_params = film[:,:,:2*self.module_dim] 164 | cond_maps = cond_params.unsqueeze(3).unsqueeze(4).expand(cond_params.size() + x.size()[-2:]) 165 | else: 166 | gammas, betas = torch.split(film[:,:,:2*self.module_dim], self.module_dim, dim=-1) 167 | if not self.use_gamma: 168 | gammas = self.default_weight.expand_as(gammas) 169 | if not self.use_beta: 170 | betas = self.default_bias.expand_as(betas) 171 | 172 | # Propagate up image features CNN 173 | stem_batch_coords = None 174 | batch_coods = None 175 | if self.use_coords_freq > 0: 176 | stem_batch_coords = self.stem_coords.unsqueeze(0).expand( 177 | torch.Size((x.size(0), *self.stem_coords.size()))) 178 | batch_coords = self.coords.unsqueeze(0).expand( 179 | torch.Size((x.size(0), *self.coords.size()))) 180 | if self.stem_use_coords: 181 | x = torch.cat([x, stem_batch_coords], 1) 182 | feats = self.stem(x) 183 | if save_activations: 184 | self.feats = feats 185 | N, _, H, W = feats.size() 186 | 187 | # Propagate up the network from low-to-high numbered blocks 188 | module_inputs = 
torch.zeros(feats.size()).unsqueeze(1).expand( 189 | N, self.num_modules, self.module_dim, H, W).to(device) 190 | module_inputs[:,0] = feats 191 | for fn_num in range(self.num_modules): 192 | if self.condition_method == 'concat': 193 | layer_output = self.function_modules[fn_num](module_inputs[:,fn_num], 194 | extra_channels=batch_coords, cond_maps=cond_maps[:,fn_num]) 195 | else: 196 | layer_output = self.function_modules[fn_num](module_inputs[:,fn_num], 197 | gammas[:,fn_num,:], betas[:,fn_num,:], batch_coords) 198 | 199 | # Store for future computation 200 | if save_activations: 201 | self.module_outputs.append(layer_output) 202 | if fn_num == (self.num_modules - 1): 203 | final_module_output = layer_output 204 | else: 205 | module_inputs_updated = module_inputs.clone() 206 | module_inputs_updated[:,fn_num+1] = module_inputs_updated[:,fn_num+1] + layer_output 207 | module_inputs = module_inputs_updated 208 | 209 | if self.debug_every <= -2: 210 | pdb.set_trace() 211 | 212 | # Run the final classifier over the resultant, post-modulated features. 213 | if self.use_coords_freq > 0: 214 | final_module_output = torch.cat([final_module_output, batch_coords], 1) 215 | if save_activations: 216 | self.cf_input = final_module_output 217 | out = self.classifier(final_module_output) 218 | 219 | if ((self.fwd_count % self.debug_every) == 0) or (self.debug_every <= -1): 220 | pdb.set_trace() 221 | return out 222 | 223 | 224 | class FiLMedResBlock(nn.Module): 225 | def __init__(self, in_dim, out_dim=None, with_residual=True, with_intermediate_batchnorm=False, with_batchnorm=True, 226 | with_cond=[False], dropout=0, num_extra_channels=0, extra_channel_freq=1, 227 | with_input_proj=0, num_cond_maps=0, kernel_size=3, batchnorm_affine=False, 228 | num_layers=1, condition_method='bn-film', debug_every=float('inf')): 229 | if out_dim is None: 230 | out_dim = in_dim 231 | super(FiLMedResBlock, self).__init__() 232 | self.with_residual = with_residual 233 | self.with_intermediate_batchnorm = with_intermediate_batchnorm 234 | self.with_batchnorm = with_batchnorm 235 | self.with_cond = with_cond 236 | self.dropout = dropout 237 | self.extra_channel_freq = 0 if num_extra_channels == 0 else extra_channel_freq 238 | self.with_input_proj = with_input_proj # Kernel size of input projection 239 | self.num_cond_maps = num_cond_maps 240 | self.kernel_size = kernel_size 241 | self.batchnorm_affine = batchnorm_affine 242 | self.num_layers = num_layers 243 | self.condition_method = condition_method 244 | self.debug_every = debug_every 245 | 246 | if self.kernel_size % 2 == 0: 247 | raise(NotImplementedError) 248 | if self.num_layers >= 2: 249 | raise(NotImplementedError) 250 | 251 | if self.condition_method == 'block-input-film' and self.with_cond[0]: 252 | self.film = FiLM() 253 | if self.with_input_proj: 254 | self.input_proj = nn.Conv2d(in_dim + (num_extra_channels if self.extra_channel_freq >= 1 else 0), 255 | in_dim, kernel_size=self.with_input_proj, padding=self.with_input_proj // 2) 256 | 257 | self.conv1 = nn.Conv2d(in_dim + self.num_cond_maps + 258 | (num_extra_channels if self.extra_channel_freq >= 2 else 0), 259 | out_dim, kernel_size=self.kernel_size, 260 | padding=self.kernel_size // 2) 261 | if self.condition_method == 'conv-film' and self.with_cond[0]: 262 | self.film = FiLM() 263 | if self.with_intermediate_batchnorm: 264 | self.bn0 = nn.BatchNorm2d(in_dim, affine=((not self.with_cond[0]) or self.batchnorm_affine)) 265 | if self.with_batchnorm: 266 | self.bn1 = nn.BatchNorm2d(out_dim, affine=((not 
self.with_cond[0]) or self.batchnorm_affine)) 267 | if self.condition_method == 'bn-film' and self.with_cond[0]: 268 | self.film = FiLM() 269 | if dropout > 0: 270 | self.drop = nn.Dropout2d(p=self.dropout) 271 | if ((self.condition_method == 'relu-film' or self.condition_method == 'block-output-film') 272 | and self.with_cond[0]): 273 | self.film = FiLM() 274 | 275 | init_modules(self.modules()) 276 | 277 | def forward(self, x, gammas=None, betas=None, extra_channels=None, cond_maps=None): 278 | if self.debug_every <= -2: 279 | pdb.set_trace() 280 | 281 | if self.condition_method == 'block-input-film' and self.with_cond[0]: 282 | x = self.film(x, gammas, betas) 283 | 284 | # ResBlock input projection 285 | if self.with_input_proj: 286 | if extra_channels is not None and self.extra_channel_freq >= 1: 287 | x = torch.cat([x, extra_channels], 1) 288 | x = self.input_proj(x) 289 | if self.with_intermediate_batchnorm: 290 | x = self.bn0(x) 291 | x = F.relu(x) 292 | out = x 293 | 294 | # ResBlock body 295 | if cond_maps is not None: 296 | out = torch.cat([out, cond_maps], 1) 297 | if extra_channels is not None and self.extra_channel_freq >= 2: 298 | out = torch.cat([out, extra_channels], 1) 299 | out = self.conv1(out) 300 | if self.condition_method == 'conv-film' and self.with_cond[0]: 301 | out = self.film(out, gammas, betas) 302 | if self.with_batchnorm: 303 | out = self.bn1(out) 304 | if self.condition_method == 'bn-film' and self.with_cond[0]: 305 | out = self.film(out, gammas, betas) 306 | if self.dropout > 0: 307 | out = self.drop(out) 308 | out = F.relu(out) 309 | if self.condition_method == 'relu-film' and self.with_cond[0]: 310 | out = self.film(out, gammas, betas) 311 | 312 | # ResBlock remainder 313 | if self.with_residual: 314 | out = x + out 315 | if self.condition_method == 'block-output-film' and self.with_cond[0]: 316 | out = self.film(out, gammas, betas) 317 | return out 318 | 319 | 320 | class ConcatFiLMedResBlock(nn.Module): 321 | def __init__(self, num_input, in_dim, out_dim=None, with_residual=True, with_intermediate_batchnorm=False, with_batchnorm=True, 322 | with_cond=[False], dropout=0, num_extra_channels=0, extra_channel_freq=1, 323 | with_input_proj=0, num_cond_maps=0, kernel_size=3, batchnorm_affine=False, 324 | num_layers=1, condition_method='bn-film', debug_every=float('inf')): 325 | super(ConcatFiLMedResBlock, self).__init__() 326 | self.proj = nn.Conv2d(num_input * in_dim, in_dim, kernel_size=1, padding=0) 327 | self.tfilmedResBlock = FiLMedResBlock(in_dim=in_dim, out_dim=out_dim, with_residual=with_residual, 328 | with_intermediate_batchnorm=with_intermediate_batchnorm, with_batchnorm=with_batchnorm, 329 | with_cond=with_cond, dropout=dropout, num_extra_channels=num_extra_channels, extra_channel_freq=extra_channel_freq, 330 | with_input_proj=with_input_proj, num_cond_maps=num_cond_maps, kernel_size=kernel_size, batchnorm_affine=batchnorm_affine, 331 | num_layers=num_layers, condition_method=condition_method, debug_every=debug_every) 332 | 333 | def forward(self, x, gammas=None, betas=None, extra_channels=None, cond_maps=None): 334 | out = torch.cat(x, 1) # Concatentate along depth 335 | out = F.relu(self.proj(out)) 336 | out = self.tfilmedResBlock(out, gammas=gammas, betas=betas, extra_channels=extra_channels, cond_maps=cond_maps) 337 | return out 338 | 339 | 340 | def coord_map(shape, start=-1, end=1): 341 | """ 342 | Gives, a 2d shape tuple, returns two mxn coordinate maps, 343 | Ranging min-max in the x and y directions, respectively. 
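`FiLMedResBlock` above applies the FiLM modulation `(gammas * x) + betas` at the point selected by `condition_method` (after batchnorm by default). A hedged construction sketch with small, made-up sizes:

```
import torch

from vr.models.filmed_net import FiLMedResBlock

N, C, H, W = 2, 128, 8, 8
block = FiLMedResBlock(C, with_batchnorm=True, with_cond=[True],
                       with_input_proj=1, condition_method='bn-film')

x = torch.randn(N, C, H, W)
gammas, betas = torch.randn(N, C), torch.randn(N, C)   # one scale/shift per channel
out = block(x, gammas, betas)
assert out.shape == (N, C, H, W)
```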
344 | """ 345 | m, n = shape 346 | x_coord_row = torch.linspace(start, end, steps=n).to(device) 347 | y_coord_row = torch.linspace(start, end, steps=m).to(device) 348 | x_coords = x_coord_row.unsqueeze(0).expand(torch.Size((m, n))).unsqueeze(0) 349 | y_coords = y_coord_row.unsqueeze(1).expand(torch.Size((m, n))).unsqueeze(0) 350 | return Variable(torch.cat([x_coords, y_coords], 0)) 351 | -------------------------------------------------------------------------------- /vr/models/hetero_net.py: -------------------------------------------------------------------------------- 1 | """Heterogenous ModuleNet as done originally in Hu et al 2 | 3 | TODO: 4 | - batchnorm? 5 | """ 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.autograd import Variable 10 | 11 | from vr.models.layers import build_stem 12 | from vr.models.module_net import ModuleNet 13 | 14 | 15 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | class Find(nn.Module): 19 | # Input: 20 | # image_feat_grid: [N, D_im, H, W] 21 | # text_param: [N, D_txt] 22 | # Output: 23 | # image_att: [N, 1, H, W] 24 | def __init__(self, image_dim, text_dim, map_dim=500): 25 | super().__init__() 26 | self.conv1 = nn.Conv2d(image_dim, map_dim, 1) 27 | self.embed = nn.Embedding(text_dim, map_dim) 28 | self.conv2 = nn.Conv2d(map_dim, 1, 1) 29 | self.map_dim = map_dim 30 | 31 | def forward(self, text, images): 32 | image_mapped = self.conv1(images) 33 | text_mapped = self.embed(text).view(-1, self.map_dim, 1, 1) 34 | 35 | mult_norm = F.normalize(image_mapped * text_mapped, p=2, dim=1) 36 | return self.conv2(mult_norm) 37 | 38 | 39 | class Transform(nn.Module): 40 | # Input: 41 | # image_att: [N, 1, H, W] 42 | # text: [N, D_txt] 43 | # Output: 44 | # image_att: [N, 1, H, W] 45 | def __init__(self, text_dim, map_dim=500, kernel_size=3): 46 | super().__init__() 47 | if kernel_size % 2 == 0: 48 | raise NotImplementedError() 49 | self.conv1 = nn.Conv2d(1, map_dim, kernel_size, padding=kernel_size // 2) 50 | self.embed = nn.Embedding(text_dim, map_dim) 51 | self.conv2 = nn.Conv2d(map_dim, 1, 1) 52 | self.map_dim = map_dim 53 | 54 | def forward(self, text, image_att): 55 | image_att_mapped = self.conv1(image_att) 56 | text_mapped = self.embed(text).view(-1, self.map_dim, 1, 1) 57 | 58 | mult_norm = F.normalize(image_att_mapped * text_mapped, p=2, dim=1) 59 | return self.conv2(mult_norm) 60 | 61 | 62 | class And(nn.Module): 63 | # Input: 64 | # att_grid_0: [N, 1, H, W] 65 | # att_grid_1: [N, 1, H, W] 66 | # Output: 67 | # att_grid_and: [N, 1, H, W] 68 | def forward(self, att1, att2): 69 | return torch.min(att1, att2) 70 | 71 | 72 | class Answer(nn.Module): 73 | # Input: 74 | # att_grid: [N, 1, H, W] 75 | # Output: 76 | # answer_scores: [N, self.num_answers] 77 | def __init__(self, num_answers): 78 | super().__init__() 79 | self.linear = nn.Linear(3, num_answers) 80 | 81 | def forward(self, att): 82 | att_min = att.min(dim=-1)[0].min(dim=-1)[0] 83 | att_max = att.max(dim=-1)[0].max(dim=-1)[0] 84 | att_mean = att.mean(dim=-1).mean(dim=-1) 85 | att_reduced = torch.cat((att_min, att_mean, att_max), dim=1) 86 | 87 | return self.linear(att_reduced) 88 | 89 | 90 | class HeteroModuleNet(ModuleNet): 91 | def __init__(self, 92 | vocab, 93 | feature_dim, 94 | stem_num_layers, 95 | stem_kernel_size, 96 | stem_stride, 97 | stem_padding, 98 | stem_batchnorm, 99 | module_dim, 100 | module_batchnorm, 101 | verbose=True): 102 | super(ModuleNet, self).__init__() 103 | 104 | 
self.program_idx_to_token = vocab['program_idx_to_token'] 105 | self.answer_to_idx = vocab['answer_idx_to_token'] 106 | self.text_token_to_idx = vocab['text_token_to_idx'] 107 | self.program_token_to_module_text = vocab['program_token_to_module_text'] 108 | self.name_to_module = { 109 | 'and': And(), 110 | 'answer': lambda x: x, 111 | 'find': Find(module_dim, len(self.text_token_to_idx)), 112 | 'transform': Transform(len(self.text_token_to_idx)), 113 | } 114 | self.name_to_num_inputs = { 115 | 'and': 2, 116 | 'answer': 1, 117 | 'find': 1, 118 | 'transform': 1, 119 | } 120 | 121 | input_C, input_H, input_W = feature_dim 122 | self.stem = build_stem(input_C, 123 | module_dim, 124 | num_layers=stem_num_layers, 125 | kernel_size=stem_kernel_size, 126 | stride=stem_stride, 127 | padding=stem_padding, 128 | with_batchnorm=stem_batchnorm) 129 | 130 | self.classifier = Answer(len(self.answer_to_idx)) 131 | 132 | if verbose: 133 | print('Here is my stem:') 134 | print(self.stem) 135 | print('Here is my classifier:') 136 | print(self.classifier) 137 | 138 | for name, module in self.name_to_module.items(): 139 | if name != 'answer': 140 | self.add_module(name, module) 141 | 142 | self.save_module_outputs = False 143 | 144 | def _forward_modules_ints_helper(self, feats, program, i, j): 145 | if j >= program.size(1): 146 | raise IndexError('malformed program, reached index', j) 147 | 148 | fn_idx = program.data[i, j] 149 | fn_str = self.program_idx_to_token[fn_idx] 150 | 151 | if fn_str == '': 152 | return self._forward_modules_ints_helper(feats, program, i, j + 1) 153 | elif fn_str in ['', '']: 154 | raise IndexError('reached area out of program ', fn_str) 155 | 156 | j += 1 157 | if fn_str == 'scene': 158 | output = feats[i].unsqueeze(0) 159 | else: 160 | module_name, text_token = self.program_token_to_module_text[fn_str] 161 | module = self.name_to_module[module_name] 162 | num_inputs = self.name_to_num_inputs[module_name] 163 | module_inputs = [] 164 | 165 | if text_token is not None: 166 | # very ugly 167 | input_text = torch.LongTensor([self.text_token_to_idx[text_token]]).unsqueeze(0) 168 | if program.is_cuda: 169 | input_text = input_text.to(device) 170 | module_inputs.append(Variable(input_text)) 171 | 172 | for _ in range(num_inputs): 173 | module_input, j = self._forward_modules_ints_helper(feats, program, i, j) 174 | module_inputs.append(module_input) 175 | 176 | output = module(*module_inputs) 177 | 178 | return output, j 179 | -------------------------------------------------------------------------------- /vr/models/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019-present, Mila 4 | # Copyright 2017-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 
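`HeteroModuleNet._forward_modules_ints_helper` above walks a prefix-serialised program recursively: it reads the function token at position `j`, recursively consumes one sub-program per argument, and returns the module output together with the next unread position. A stripped-down sketch of the same traversal over plain Python lists; the token names and arity table are illustrative:

```
ARITY = {'and': 2, 'find_red': 0, 'find_square': 0}   # toy arity table

def eval_prefix(program, j=0):
    """Evaluate a prefix-serialised program, returning (value, next_index)."""
    fn = program[j]
    j += 1
    args = []
    for _ in range(ARITY[fn]):
        value, j = eval_prefix(program, j)
        args.append(value)
    if fn == 'and':
        return min(args), j                  # stands in for torch.min on attention maps
    return {'find_red': 0.2, 'find_square': 0.7}[fn], j

value, _ = eval_prefix(['and', 'find_red', 'find_square'])
assert value == 0.2
```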
9 | 10 | import math 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.nn.init import kaiming_normal_, kaiming_uniform_ 16 | 17 | 18 | class SequentialSaveActivations(nn.Sequential): 19 | 20 | def forward(self, input_): 21 | self.outputs = [input_] 22 | for module in self._modules.values(): 23 | input_ = module(input_) 24 | self.outputs.append(input_) 25 | return input_ 26 | 27 | 28 | class SimpleVisualBlock(nn.Module): 29 | def __init__(self, in_dim, out_dim=None, kernel_size=3): 30 | if out_dim is None: 31 | out_dim = in_dim 32 | super(SimpleVisualBlock, self).__init__() 33 | if kernel_size % 2 == 0: 34 | raise NotImplementedError() 35 | self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2) 36 | 37 | def forward(self, x): 38 | out = F.relu(self.conv(x)) 39 | return out 40 | 41 | 42 | 43 | 44 | class ResidualBlock(nn.Module): 45 | def __init__(self, in_dim, out_dim=None, kernel_size=3, with_residual=True, with_batchnorm=True): 46 | if out_dim is None: 47 | out_dim = in_dim 48 | super(ResidualBlock, self).__init__() 49 | if kernel_size % 2 == 0: 50 | raise NotImplementedError() 51 | self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2) 52 | self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2) 53 | self.with_batchnorm = with_batchnorm 54 | if with_batchnorm: 55 | self.bn1 = nn.BatchNorm2d(out_dim) 56 | self.bn2 = nn.BatchNorm2d(out_dim) 57 | self.with_residual = with_residual 58 | if in_dim == out_dim or not with_residual: 59 | self.proj = None 60 | else: 61 | self.proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) 62 | 63 | def forward(self, x): 64 | if self.with_batchnorm: 65 | out = F.relu(self.bn1(self.conv1(x))) 66 | out = self.bn2(self.conv2(out)) 67 | else: 68 | out = self.conv2(F.relu(self.conv1(x))) 69 | res = x if self.proj is None else self.proj(x) 70 | if self.with_residual: 71 | out = F.relu(res + out) 72 | else: 73 | out = F.relu(out) 74 | return out 75 | 76 | 77 | 78 | class ConcatBlock(nn.Module): 79 | def __init__(self, dim, kernel_size, with_residual=True, with_batchnorm=True, use_simple=False): 80 | super(ConcatBlock, self).__init__() 81 | self.proj = nn.Conv2d(2 * dim, dim, kernel_size=1, padding=0) 82 | if use_simple: 83 | self.vis_block = SimpleVisualBlock(dim, kernel_size=kernel_size) 84 | else: 85 | self.vis_block = ResidualBlock(dim, kernel_size=kernel_size, 86 | with_residual=with_residual,with_batchnorm=with_batchnorm) 87 | 88 | def forward(self, x, y): 89 | out = torch.cat([x, y], 1) # Concatentate along depth 90 | out = F.relu(self.proj(out)) 91 | out = self.vis_block(out) 92 | return out 93 | 94 | 95 | class GlobalAveragePool(nn.Module): 96 | def forward(self, x): 97 | N, C = x.size(0), x.size(1) 98 | return x.view(N, C, -1).mean(2).squeeze(2) 99 | 100 | 101 | class Flatten(nn.Module): 102 | def forward(self, x): 103 | return x.view(x.size(0), -1) 104 | 105 | 106 | def build_stem(feature_dim, 107 | stem_dim, 108 | module_dim, 109 | num_layers=2, 110 | with_batchnorm=True, 111 | kernel_size=[3], 112 | stride=[1], 113 | padding=None, 114 | subsample_layers=None, 115 | acceptEvenKernel=False): 116 | layers = [] 117 | prev_dim = feature_dim 118 | 119 | if len(kernel_size) == 1: 120 | kernel_size = num_layers * kernel_size 121 | if len(stride) == 1: 122 | stride = num_layers * stride 123 | if padding == None: 124 | padding = num_layers * [None] 125 | if len(padding) == 1: 126 | padding = num_layers * 
padding 127 | if subsample_layers is None: 128 | subsample_layers = [] 129 | 130 | for i, cur_kernel_size, cur_stride, cur_padding in zip(range(num_layers), kernel_size, stride, padding): 131 | curr_out = module_dim if (i == (num_layers-1) ) else stem_dim 132 | if cur_padding is None: # Calculate default padding when None provided 133 | if cur_kernel_size % 2 == 0 and not acceptEvenKernel: 134 | raise(NotImplementedError) 135 | cur_padding = cur_kernel_size // 2 136 | layers.append(nn.Conv2d(prev_dim, curr_out, 137 | kernel_size=cur_kernel_size, stride=cur_stride, padding=cur_padding, 138 | bias=not with_batchnorm)) 139 | if with_batchnorm: 140 | layers.append(nn.BatchNorm2d(curr_out)) 141 | layers.append(nn.ReLU(inplace=True)) 142 | if i in subsample_layers: 143 | layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) 144 | prev_dim = curr_out 145 | return SequentialSaveActivations(*layers) 146 | 147 | 148 | def build_classifier(module_C, module_H, module_W, num_answers, 149 | fc_dims=[], proj_dim=None, downsample=None, 150 | with_batchnorm=True, dropout=[]): 151 | layers = [] 152 | prev_dim = module_C * module_H * module_W 153 | cur_dim = module_C 154 | if proj_dim is not None and proj_dim > 0: 155 | layers.append(nn.Conv2d(module_C, proj_dim, kernel_size=1, bias=not with_batchnorm)) 156 | if with_batchnorm: 157 | layers.append(nn.BatchNorm2d(proj_dim)) 158 | layers.append(nn.ReLU(inplace=True)) 159 | prev_dim = proj_dim * module_H * module_W 160 | cur_dim = proj_dim 161 | if downsample is not None: 162 | if 'maxpool' in downsample or 'avgpool' in downsample: 163 | pool = nn.MaxPool2d if 'maxpool' in downsample else nn.AvgPool2d 164 | if 'full' in downsample: 165 | if module_H != module_W: 166 | assert(NotImplementedError) 167 | pool_size = module_H 168 | else: 169 | pool_size = int(downsample[-1]) 170 | # Note: Potentially sub-optimal padding for non-perfectly aligned pooling 171 | padding = 0 if ((module_H % pool_size == 0) and (module_W % pool_size == 0)) else 1 172 | layers.append(pool(kernel_size=pool_size, stride=pool_size, padding=padding)) 173 | prev_dim = cur_dim * math.ceil(module_H / pool_size) * math.ceil(module_W / pool_size) 174 | if downsample == 'aggressive': 175 | raise ValueError() 176 | layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) 177 | layers.append(nn.AvgPool2d(kernel_size=module_H // 2, stride=module_W // 2)) 178 | prev_dim = proj_dim 179 | fc_dims = [] # No FC layers here 180 | layers.append(Flatten()) 181 | 182 | if isinstance(dropout, float): 183 | dropout = [dropout] * len(fc_dims) 184 | elif not dropout: 185 | dropout = [0] * len(fc_dims) 186 | 187 | for next_dim, next_dropout in zip(fc_dims, dropout): 188 | layers.append(nn.Linear(prev_dim, next_dim, bias=not with_batchnorm)) 189 | if with_batchnorm: 190 | layers.append(nn.BatchNorm1d(next_dim)) 191 | layers.append(nn.ReLU(inplace=True)) 192 | if next_dropout > 0: 193 | layers.append(nn.Dropout(p=next_dropout)) 194 | prev_dim = next_dim 195 | layers.append(nn.Linear(prev_dim, num_answers)) 196 | return nn.Sequential(*layers) 197 | 198 | 199 | def init_modules(modules, init='uniform'): 200 | if init.lower() == 'normal': 201 | init_params = kaiming_normal_ 202 | elif init.lower() == 'uniform': 203 | init_params = kaiming_uniform_ 204 | else: 205 | return 206 | for m in modules: 207 | if isinstance(m, (nn.Conv2d, nn.Linear)): 208 | init_params(m.weight) 209 | -------------------------------------------------------------------------------- /vr/models/module_net.py: 
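`build_stem` and `build_classifier` above are the shared building blocks that every model in this repo uses to turn raw images into a `module_dim` feature map and a feature map into answer logits. A hedged usage sketch with small, made-up sizes (the 3 x 64 x 64 input shape matches the ConvLSTM default above):

```
import torch

from vr.models.layers import build_stem, build_classifier

feature_dim = (3, 64, 64)
stem = build_stem(feature_dim[0], stem_dim=64, module_dim=64,
                  num_layers=2, with_batchnorm=True,
                  kernel_size=[3], stride=[1], subsample_layers=[0, 1])

x = torch.randn(2, *feature_dim)
feats = stem(x)                      # 2 x 64 x 16 x 16 after two stride-2 max-pools

classifier = build_classifier(feats.size(1), feats.size(2), feats.size(3),
                              num_answers=2, fc_dims=[1024], proj_dim=512,
                              downsample='maxpool2', with_batchnorm=True, dropout=0)
scores = classifier(feats)
assert scores.shape == (2, 2)
```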
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019-present, Mila 4 | # Copyright 2017-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | 10 | import math 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | import torchvision.models 16 | 17 | from vr.models.layers import init_modules, ResidualBlock, SimpleVisualBlock, GlobalAveragePool, Flatten 18 | from vr.models.layers import build_classifier, build_stem, ConcatBlock 19 | import vr.programs 20 | 21 | from torch.nn.init import kaiming_normal, kaiming_uniform, xavier_uniform, xavier_normal, constant 22 | 23 | from vr.models.filmed_net import FiLM, FiLMedResBlock, ConcatFiLMedResBlock, coord_map 24 | 25 | class ModuleNet(nn.Module): 26 | def __init__(self, vocab, feature_dim, 27 | use_film, 28 | use_simple_block, 29 | sharing_patterns, 30 | stem_num_layers, 31 | stem_batchnorm, 32 | stem_subsample_layers, 33 | stem_kernel_size, 34 | stem_stride, 35 | stem_padding, 36 | stem_dim, 37 | module_dim, 38 | module_kernel_size, 39 | module_input_proj, 40 | module_residual=True, 41 | module_batchnorm=False, 42 | classifier_proj_dim=512, 43 | classifier_downsample='maxpool2', 44 | classifier_fc_layers=(1024,), 45 | classifier_batchnorm=False, 46 | classifier_dropout=0, 47 | verbose=True): 48 | super(ModuleNet, self).__init__() 49 | 50 | self.module_dim = module_dim 51 | 52 | # should be 0 or 1 to indicate the use of film block or not (0 would bring you back to the original EE model) 53 | self.use_film = use_film 54 | # should be 0 or 1 to indicate if we are using ResNets or a simple 3x3 conv followed by ReLU 55 | self.use_simple_block = use_simple_block 56 | 57 | # this should be a list of two elements (either 0 or 1). 
It's only active if self.use_film == 1 58 | # The first element of 1 indicates the sharing of CNN weights in the film blocks, 0 otheriwse 59 | # The second element of 1 indicate the sharing of film coefficient in the film blocks, 0 otherwise 60 | # so [1,0] would be sharing the CNN weights while having different film coefficients for different modules in the program 61 | self.sharing_patterns = sharing_patterns 62 | 63 | self.stem = build_stem(feature_dim[0], stem_dim, module_dim, 64 | num_layers=stem_num_layers, 65 | subsample_layers=stem_subsample_layers, 66 | kernel_size=stem_kernel_size, 67 | padding=stem_padding, 68 | with_batchnorm=stem_batchnorm) 69 | tmp = self.stem(Variable(torch.zeros([1, feature_dim[0], feature_dim[1], feature_dim[2]]))) 70 | module_H = tmp.size(2) 71 | module_W = tmp.size(3) 72 | 73 | self.coords = coord_map((module_H, module_W)) 74 | 75 | if verbose: 76 | print('Here is my stem:') 77 | print(self.stem) 78 | 79 | num_answers = len(vocab['answer_idx_to_token']) 80 | self.classifier = build_classifier(module_dim, module_H, module_W, num_answers, 81 | classifier_fc_layers, 82 | classifier_proj_dim, 83 | classifier_downsample, 84 | with_batchnorm=classifier_batchnorm, 85 | dropout=classifier_dropout) 86 | if verbose: 87 | print('Here is my classifier:') 88 | print(self.classifier) 89 | self.stem_times = [] 90 | self.module_times = [] 91 | self.classifier_times = [] 92 | self.timing = False 93 | 94 | self.function_modules = {} 95 | self.function_modules_num_inputs = {} 96 | self.fn_str_2_filmId = {} 97 | self.vocab = vocab 98 | for fn_str in vocab['program_token_to_idx']: 99 | num_inputs = vocab['program_token_arity'][fn_str] 100 | self.function_modules_num_inputs[fn_str] = num_inputs 101 | 102 | if self.use_film: 103 | if self.sharing_patterns[1] == 1: 104 | self.fn_str_2_filmId[fn_str] = 0 105 | else: 106 | self.fn_str_2_filmId[fn_str] = len(self.fn_str_2_filmId) 107 | 108 | if fn_str == 'scene' or num_inputs == 1: 109 | if self.use_film: 110 | if self.sharing_patterns[0] == 1: 111 | mod = None 112 | else: 113 | mod = FiLMedResBlock(module_dim, with_residual=module_residual, 114 | with_intermediate_batchnorm=False, with_batchnorm=False, 115 | with_cond=[True, True], 116 | num_extra_channels=2, # was 2 for original film, 117 | extra_channel_freq=1, 118 | with_input_proj=module_input_proj, 119 | num_cond_maps=0, 120 | kernel_size=module_kernel_size, 121 | batchnorm_affine=False, 122 | num_layers=1, 123 | condition_method='bn-film', 124 | debug_every=float('inf')) 125 | else: 126 | if self.use_simple_block: 127 | mod = SimpleVisualBlock(module_dim, kernel_size=module_kernel_size) 128 | else: 129 | mod = ResidualBlock( 130 | module_dim, 131 | kernel_size=module_kernel_size, 132 | with_residual=module_residual, 133 | with_batchnorm=module_batchnorm) 134 | elif num_inputs == 2: 135 | if self.use_film: 136 | if self.sharing_patterns[0] == 1: 137 | mod = None 138 | else: 139 | mod = ConcatFiLMedResBlock(2, module_dim, with_residual=module_residual, 140 | with_intermediate_batchnorm=False, with_batchnorm=False, 141 | with_cond=[True, True], 142 | num_extra_channels=2, #was 2 for original film, 143 | extra_channel_freq=1, 144 | with_input_proj=module_input_proj, 145 | num_cond_maps=0, 146 | kernel_size=module_kernel_size, 147 | batchnorm_affine=False, 148 | num_layers=1, 149 | condition_method='bn-film', 150 | debug_every=float('inf')) 151 | else: 152 | mod = ConcatBlock( 153 | module_dim, 154 | kernel_size=module_kernel_size, 155 | with_residual=module_residual, 156 | 
with_batchnorm=module_batchnorm) 157 | else: 158 | raise Exception('Not implemented!') 159 | 160 | if mod is not None: 161 | self.add_module(fn_str, mod) 162 | self.function_modules[fn_str] = mod 163 | 164 | if self.use_film and self.sharing_patterns[0] == 1: 165 | mod = ConcatFiLMedResBlock(2, module_dim, with_residual=module_residual, 166 | with_intermediate_batchnorm=False, with_batchnorm=False, 167 | with_cond=[True, True], 168 | num_extra_channels=2, #was 2 for original film, 169 | extra_channel_freq=1, 170 | with_input_proj=module_input_proj, 171 | num_cond_maps=0, 172 | kernel_size=module_kernel_size, 173 | batchnorm_affine=False, 174 | num_layers=1, 175 | condition_method='bn-film', 176 | debug_every=float('inf')) 177 | self.add_module('shared_film', mod) 178 | self.function_modules['shared_film'] = mod 179 | 180 | self.declare_film_coefficients() 181 | 182 | self.save_module_outputs = False 183 | 184 | def declare_film_coefficients(self): 185 | if self.use_film: 186 | self.gammas = nn.Parameter(torch.Tensor(1, len(self.fn_str_2_filmId), self.module_dim)) 187 | xavier_uniform(self.gammas) 188 | self.betas = nn.Parameter(torch.Tensor(1, len(self.fn_str_2_filmId), self.module_dim)) 189 | xavier_uniform(self.betas) 190 | 191 | else: 192 | self.gammas = None 193 | self.betas = None 194 | 195 | def expand_answer_vocab(self, answer_to_idx, std=0.01, init_b=-50): 196 | # TODO: This is really gross, dipping into private internals of Sequential 197 | final_linear_key = str(len(self.classifier._modules) - 1) 198 | final_linear = self.classifier._modules[final_linear_key] 199 | 200 | old_weight = final_linear.weight.data 201 | old_bias = final_linear.bias.data 202 | old_N, D = old_weight.size() 203 | new_N = 1 + max(answer_to_idx.values()) 204 | new_weight = old_weight.new(new_N, D).normal_().mul_(std) 205 | new_bias = old_bias.new(new_N).fill_(init_b) 206 | new_weight[:old_N].copy_(old_weight) 207 | new_bias[:old_N].copy_(old_bias) 208 | 209 | final_linear.weight.data = new_weight 210 | final_linear.bias.data = new_bias 211 | 212 | def _forward_modules_json(self, feats, program): 213 | def gen_hook(i, j): 214 | def hook(grad): 215 | self.all_module_grad_outputs[i][j] = grad.data.cpu().clone() 216 | return hook 217 | 218 | self.all_module_outputs = [] 219 | self.all_module_grad_outputs = [] 220 | # We can't easily handle minibatching of modules, so just do a loop 221 | N = feats.size(0) 222 | final_module_outputs = [] 223 | for i in range(N): 224 | if self.save_module_outputs: 225 | self.all_module_outputs.append([]) 226 | self.all_module_grad_outputs.append([None] * len(program[i])) 227 | module_outputs = [] 228 | for j, f in enumerate(program[i]): 229 | f_str = vr.programs.function_to_str(f) 230 | module = self.function_modules[f_str] 231 | if f_str == 'scene': 232 | module_inputs = [feats[i:i+1]] 233 | else: 234 | module_inputs = [module_outputs[j] for j in f['inputs']] 235 | module_outputs.append(module(*module_inputs)) 236 | if self.save_module_outputs: 237 | self.all_module_outputs[-1].append(module_outputs[-1].data.cpu().clone()) 238 | module_outputs[-1].register_hook(gen_hook(i, j)) 239 | final_module_outputs.append(module_outputs[-1]) 240 | final_module_outputs = torch.cat(final_module_outputs, 0) 241 | return final_module_outputs 242 | 243 | def _forward_modules_ints_helper(self, feats, program, i, j): 244 | used_fn_j = True 245 | if j < program.size(1): 246 | fn_idx = program.data[i, j] 247 | fn_str = self.vocab['program_idx_to_token'][fn_idx.item()] 248 | else: 249 | used_fn_j = 
False 250 | fn_str = 'scene' 251 | if fn_str == '<NULL>': 252 | used_fn_j = False 253 | fn_str = 'scene' 254 | elif fn_str == '<START>': 255 | used_fn_j = False 256 | return self._forward_modules_ints_helper(feats, program, i, j + 1) 257 | if used_fn_j: 258 | self.used_fns[i, j] = 1 259 | j += 1 260 | 261 | num_inputs = self.function_modules_num_inputs[fn_str] 262 | if fn_str == 'scene': num_inputs = 1 263 | 264 | if self.use_film: 265 | assert fn_str in self.fn_str_2_filmId 266 | midx = self.fn_str_2_filmId[fn_str] 267 | 268 | if self.sharing_patterns[0] == 1: 269 | query_id = 'shared_film' 270 | else: 271 | query_id = fn_str 272 | assert query_id in self.function_modules 273 | module = self.function_modules[query_id] 274 | else: 275 | midx = -1 276 | module = self.function_modules[fn_str] 277 | 278 | if fn_str == 'scene': 279 | module_inputs = [feats[i:i+1]] 280 | else: 281 | #num_inputs = self.function_modules_num_inputs[fn_str] 282 | module_inputs = [] 283 | while len(module_inputs) < num_inputs: 284 | cur_input, j = self._forward_modules_ints_helper(feats, program, i, j) 285 | module_inputs.append(cur_input) 286 | 287 | if self.use_film: 288 | igammas = self.gammas[:,midx,:] + 1 289 | ibetas = self.betas[:,midx,:] 290 | bcoords = self.coords.unsqueeze(0) 291 | if len(module_inputs) == 1: 292 | if self.sharing_patterns[0] == 1: 293 | module_inputs = [module_inputs[0], module_inputs[0]] 294 | else: 295 | module_inputs = module_inputs[0] 296 | module_output = module(module_inputs, igammas, ibetas, bcoords) 297 | else: 298 | module_output = module(*module_inputs) 299 | return module_output, j 300 | 301 | def _forward_modules_ints(self, feats, program): 302 | """ 303 | feats: FloatTensor of shape (N, C, H, W) giving features for each image 304 | program: LongTensor of shape (N, L) giving a prefix-encoded program for 305 | each image. 306 | """ 307 | N = feats.size(0) 308 | final_module_outputs = [] 309 | self.used_fns = torch.Tensor(program.size()).fill_(0) 310 | for i in range(N): 311 | cur_output, _ = self._forward_modules_ints_helper(feats, program, i, 0) 312 | final_module_outputs.append(cur_output) 313 | self.used_fns = self.used_fns.type_as(program.data).float() 314 | final_module_outputs = torch.cat(final_module_outputs, 0) 315 | return final_module_outputs 316 | 317 | def forward(self, x, program, save_activations=False): 318 | N = x.size(0) 319 | assert N == len(program) 320 | 321 | feats = self.stem(x) 322 | 323 | if type(program) is list or type(program) is tuple: 324 | final_module_outputs = self._forward_modules_json(feats, program) 325 | elif type(program) is torch.Tensor and program.dim() == 2: 326 | final_module_outputs = self._forward_modules_ints(feats, program) 327 | elif torch.is_tensor(program) and program.dim() == 3: 328 | final_module_outputs = self._forward_modules_probs(feats, program) 329 | else: 330 | raise ValueError('Unrecognized program format') 331 | 332 | # After running modules for each input, concatenate the outputs from the 
334 | out = self.classifier(final_module_outputs) 335 | return out 336 | -------------------------------------------------------------------------------- /vr/models/relation_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import itertools 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | 11 | from vr.models.layers import init_modules, GlobalAveragePool, Flatten 12 | from vr.models.layers import build_classifier, build_stem 13 | 14 | 15 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | class RelationNet(nn.Module): 19 | def __init__(self, 20 | vocab, 21 | feature_dim=(3, 64, 64), 22 | stem_num_layers=2, 23 | stem_batchnorm=True, 24 | stem_kernel_size=3, 25 | stem_stride=1, 26 | stem_padding=None, 27 | stem_dim=24, 28 | module_num_layers=1, 29 | module_dim=128, 30 | classifier_fc_layers=(1024,), 31 | classifier_batchnorm=False, 32 | classifier_dropout=0, 33 | rnn_hidden_dim=128, 34 | # unused 35 | stem_subsample_layers=[], 36 | module_input_proj=None, 37 | module_residual=None, 38 | module_kernel_size=None, 39 | module_batchnorm=None, 40 | classifier_proj_dim=None, 41 | classifier_downsample=None, 42 | debug_every=float('inf'), 43 | print_verbose_every=float('inf'), 44 | verbose=True): 45 | super().__init__() 46 | 47 | # initialize stem 48 | self.stem = build_stem(feature_dim[0], 49 | stem_dim, 50 | stem_dim, 51 | num_layers=stem_num_layers, 52 | with_batchnorm=stem_batchnorm, 53 | kernel_size=stem_kernel_size, 54 | stride=stem_stride, 55 | padding=stem_padding, 56 | subsample_layers=stem_subsample_layers) 57 | tmp = self.stem(Variable(torch.zeros([1, feature_dim[0], feature_dim[1], feature_dim[2]]))) 58 | module_H = tmp.size(2) 59 | module_W = tmp.size(3) 60 | 61 | # initialize coordinates to be appended to "objects" 62 | # can be switched to using torch.meshgrid after 0.4.1 63 | x = torch.linspace(-1, 1, steps=module_W) 64 | y = torch.linspace(-1, 1, steps=module_H) 65 | xv = x.unsqueeze(1).repeat(1, module_H) 66 | yv = y.unsqueeze(0).repeat(module_W, 1) 67 | coords = torch.stack([xv,yv], dim=2).view(-1, 2) 68 | self.coords = Variable(coords.to(device)) 69 | 70 | # initialize relation model 71 | # (output of stem + 2 coordinates) * 2 objects + question vector 72 | relation_modules = [nn.Linear((stem_dim + 2)*2 + rnn_hidden_dim, module_dim)] 73 | for _ in range(module_num_layers - 1): 74 | relation_modules.append(nn.Linear(module_dim, module_dim)) 75 | self.relation = nn.Sequential(*relation_modules) 76 | 77 | # initialize classifier (f_theta) 78 | num_answers = len(vocab['answer_idx_to_token']) 79 | self.classifier = build_classifier(module_dim, 80 | 1, 81 | 1, 82 | num_answers, 83 | classifier_fc_layers, 84 | classifier_proj_dim, 85 | classifier_downsample, 86 | classifier_batchnorm, 87 | classifier_dropout) 88 | 89 | init_modules(self.modules()) 90 | 91 | def forward(self, image, question): 92 | # convert image to features (aka objects) 93 | features = self.stem(image) 94 | N, F, H, W = features.size() 95 | features_flat = features.view(N, F, H*W).permute(0,2,1) # N x H*W x F 96 | 97 | # conctenate coordinates to features 98 | batch_coords = self.coords.unsqueeze(0).repeat(N, 1, 1) # N x H*W x 2 99 | features_coords = torch.cat([features_flat, batch_coords], dim=2) # N x H*W x F+2 100 | 101 | # make matrix of all possible pairs of 2 features 102 | x_i = 
torch.unsqueeze(features_coords, 1) # N x 1 x H*W x F+2 103 | x_i = x_i.repeat(1, H*W, 1,1) # N x H*W x H*W x F+2 104 | x_j = torch.unsqueeze(features_coords, 2) # N x H*W x 1 x F+2 105 | x_j = x_j.repeat(1, 1, H*W, 1) # N x H*W x H*W x F+2 106 | feature_pairs = torch.cat([x_i, x_j], dim=3) # N x H*W x H*W x 2*(F+2) 107 | feature_pairs = feature_pairs.view(N, (H*W)**2, 2*(F+2)) # N x (H*W)^2 x 2*(F+2) 108 | 109 | # concatenate question to feature pair 110 | _, ques, _ = question # (N x Q) 111 | ques = ques.unsqueeze(1).repeat(1, feature_pairs.size(1), 1) # N x (H*W)^2 x Q 112 | feature_pairs_ques = torch.cat([feature_pairs, ques], dim=2) # N x (H*W)^2 x 2*(F+2)+Q 113 | 114 | # pass through model (g_theta) 115 | relations = self.relation(feature_pairs_ques) # N x (H*W)^2 x module_dim 116 | 117 | # sum across relations 118 | relations = torch.sum(relations, dim=1) # N x module_dim 119 | 120 | # pass through classifier (f_theta) 121 | out = self.classifier(relations) 122 | 123 | return out 124 | -------------------------------------------------------------------------------- /vr/models/seq2seq.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019-present, Mila 4 | # Copyright 2017-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | 10 | import torch 11 | import torch.cuda 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | 16 | from vr.embedding import expand_embedding_vocab 17 | 18 | 19 | class Seq2Seq(nn.Module): 20 | def __init__(self, 21 | encoder_vocab_size=100, 22 | decoder_vocab_size=100, 23 | wordvec_dim=300, 24 | hidden_dim=256, 25 | rnn_num_layers=2, 26 | rnn_dropout=0, 27 | null_token=0, 28 | start_token=1, 29 | end_token=2, 30 | encoder_embed=None 31 | ): 32 | super(Seq2Seq, self).__init__() 33 | self.encoder_embed = nn.Embedding(encoder_vocab_size, wordvec_dim) 34 | self.encoder_rnn = nn.LSTM(wordvec_dim, hidden_dim, rnn_num_layers, 35 | dropout=rnn_dropout, batch_first=True) 36 | self.decoder_embed = nn.Embedding(decoder_vocab_size, wordvec_dim) 37 | self.decoder_rnn = nn.LSTM(wordvec_dim + hidden_dim, hidden_dim, rnn_num_layers, 38 | dropout=rnn_dropout, batch_first=True) 39 | self.decoder_rnn_new = nn.LSTM(hidden_dim, hidden_dim, rnn_num_layers, 40 | dropout=rnn_dropout, batch_first=True) 41 | self.decoder_linear = nn.Linear(hidden_dim, decoder_vocab_size) 42 | self.NULL = null_token 43 | self.START = start_token 44 | self.END = end_token 45 | self.multinomial_outputs = None 46 | 47 | def expand_encoder_vocab(self, token_to_idx, word2vec=None, std=0.01): 48 | expand_embedding_vocab(self.encoder_embed, token_to_idx, 49 | word2vec=word2vec, std=std) 50 | 51 | def get_dims(self, x=None, y=None): 52 | V_in = self.encoder_embed.num_embeddings 53 | V_out = self.decoder_embed.num_embeddings 54 | D = self.encoder_embed.embedding_dim 55 | H = self.encoder_rnn.hidden_size 56 | L = self.encoder_rnn.num_layers 57 | 58 | N = x.size(0) if x is not None else None 59 | N = y.size(0) if N is None and y is not None else N 60 | T_in = x.size(1) if x is not None else None 61 | T_out = y.size(1) if y is not None else None 62 | return V_in, V_out, D, H, L, N, T_in, T_out 63 | 64 | def before_rnn(self, x, replace=0): 65 | # TODO: Use PackedSequence instead of manually plucking out the last 66 | # non-NULL entry of each 
sequence; it is cleaner and more efficient. 67 | N, T = x.size() 68 | idx = torch.LongTensor(N).fill_(T - 1) 69 | 70 | # Find the last non-null element in each sequence. Is there a clean 71 | # way to do this? 72 | x_cpu = x.cpu() 73 | for i in range(N): 74 | for t in range(T - 1): 75 | if x_cpu.data[i, t] != self.NULL and x_cpu.data[i, t + 1] == self.NULL: 76 | idx[i] = t 77 | break 78 | idx = idx.type_as(x.data) 79 | x[x.data == self.NULL] = replace 80 | return x, Variable(idx) 81 | 82 | def encoder(self, x): 83 | V_in, V_out, D, H, L, N, T_in, T_out = self.get_dims(x=x) 84 | x, idx = self.before_rnn(x) 85 | embed = self.encoder_embed(x) 86 | h0 = Variable(torch.zeros(L, N, H).type_as(embed.data)) 87 | c0 = Variable(torch.zeros(L, N, H).type_as(embed.data)) 88 | 89 | out, _ = self.encoder_rnn(embed, (h0, c0)) 90 | 91 | # Pull out the hidden state for the last non-null value in each input 92 | idx = idx.view(N, 1, 1).expand(N, 1, H) 93 | return out.gather(1, idx).view(N, H) 94 | 95 | def decoder(self, encoded, y, h0=None, c0=None): 96 | V_in, V_out, D, H, L, N, T_in, T_out = self.get_dims(y=y) 97 | 98 | if T_out > 1: 99 | y, _ = self.before_rnn(y) 100 | y_embed = self.decoder_embed(y) 101 | encoded_repeat = encoded.view(N, 1, H).expand(N, T_out, H) 102 | rnn_input = torch.cat([encoded_repeat, y_embed], 2) 103 | if h0 is None: 104 | h0 = Variable(torch.zeros(L, N, H).type_as(encoded.data)) 105 | if c0 is None: 106 | c0 = Variable(torch.zeros(L, N, H).type_as(encoded.data)) 107 | rnn_output, (ht, ct) = self.decoder_rnn(rnn_input, (h0, c0)) 108 | 109 | rnn_output_2d = rnn_output.contiguous().view(N * T_out, H) 110 | output_logprobs = self.decoder_linear(rnn_output_2d).view(N, T_out, V_out) 111 | 112 | return output_logprobs, ht, ct 113 | 114 | def compute_loss(self, output_logprobs, y): 115 | """ 116 | Compute loss. We assume that the first element of the output sequence y is 117 | a start token, and that each element of y is left-aligned and right-padded 118 | with self.NULL out to T_out. We want the output_logprobs to predict the 119 | sequence y, shifted by one timestep so that y[0] is fed to the network and 120 | then y[1] is predicted. We also don't want to compute loss for padded 121 | timesteps. 
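(Toy example, not from the codebase: with y = [<START>, f1, f2, <NULL>], the masks pick out the predictions made at the first two timesteps and compare them against f1 and f2; the padded <NULL> position contributes nothing to the loss.)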
122 | 123 | Inputs: 124 | - output_logprobs: Variable of shape (N, T_out, V_out) 125 | - y: LongTensor Variable of shape (N, T_out) 126 | """ 127 | self.multinomial_outputs = None 128 | V_in, V_out, D, H, L, N, T_in, T_out = self.get_dims(y=y) 129 | mask = y.data != self.NULL 130 | y_mask = Variable(torch.Tensor(N, T_out).fill_(0).type_as(mask)) 131 | y_mask[:, 1:] = mask[:, 1:] 132 | y_masked = y[y_mask] 133 | out_mask = Variable(torch.Tensor(N, T_out).fill_(0).type_as(mask)) 134 | out_mask[:, :-1] = mask[:, 1:] 135 | out_mask = out_mask.view(N, T_out, 1).expand(N, T_out, V_out) 136 | out_masked = output_logprobs[out_mask].view(-1, V_out) 137 | loss = F.cross_entropy(out_masked, y_masked) 138 | return loss 139 | 140 | def forward(self, x, x_lengths, y, y_lengths): 141 | encoded = self.encoder(x) 142 | 143 | V_in, V_out, D, H, L, N, T_in, T_out = self.get_dims(x=x) 144 | T_out = 15 145 | encoded_repeat = encoded.view(N, 1, H).expand(N, T_out, H) 146 | h0 = Variable(torch.zeros(L, N, H).type_as(encoded.data)) 147 | c0 = Variable(torch.zeros(L, N, H).type_as(encoded.data)) 148 | rnn_output, (ht, ct) = self.decoder_rnn_new(encoded_repeat, (h0, c0)) 149 | 150 | output_logprobs, _, _ = self.decoder(encoded, y) 151 | loss = self.compute_loss(output_logprobs, y) 152 | return loss 153 | 154 | def sample(self, x, x_lengths, max_length=50): 155 | # TODO: Handle sampling for minibatch inputs 156 | # TODO: Beam search? 157 | self.multinomial_outputs = None 158 | assert x.size(0) == 1, "Sampling minibatches not implemented" 159 | encoded = self.encoder(x) 160 | y = [self.START] 161 | h0, c0 = None, None 162 | while True: 163 | cur_y = Variable(torch.LongTensor([y[-1]]).type_as(x.data).view(1, 1)) 164 | logprobs, h0, c0 = self.decoder(encoded, cur_y, h0=h0, c0=c0) 165 | _, next_y = logprobs.data.max(2, keepdim=True) 166 | y.append(next_y[0, 0, 0]) 167 | if len(y) >= max_length or y[-1] == self.END: 168 | break 169 | return y 170 | 171 | def reinforce_sample(self, x, x_lengths, max_length=30, temperature=1.0, argmax=False): 172 | N, T = x.size(0), max_length 173 | encoded = self.encoder(x) 174 | y = torch.LongTensor(N, T).fill_(self.NULL) 175 | done = torch.ByteTensor(N).fill_(0) 176 | cur_input = Variable(x.data.new(N, 1).fill_(self.START)) 177 | h, c = None, None 178 | self.multinomial_outputs = [] 179 | self.multinomial_probs = [] 180 | for t in range(T): 181 | # logprobs is N x 1 x V 182 | logprobs, h, c = self.decoder(encoded, cur_input, h0=h, c0=c) 183 | logprobs = logprobs / temperature 184 | probs = F.softmax(logprobs.view(N, -1), dim=1) # Now N x V 185 | if argmax: 186 | _, cur_output = probs.max(1, keepdim=True) 187 | else: 188 | cur_output = probs.multinomial() # Now N x 1 189 | self.multinomial_outputs.append(cur_output) 190 | self.multinomial_probs.append(probs) 191 | cur_output_data = cur_output.data.cpu() 192 | not_done = logical_not(done) 193 | y[:, t][not_done] = cur_output_data[not_done] 194 | done = logical_or(done, cur_output_data.cpu() == self.END) 195 | cur_input = cur_output 196 | if done.sum() == N: 197 | break 198 | return Variable(y.type_as(x.data)) 199 | 200 | def reinforce_backward(self, reward, output_mask=None): 201 | """ 202 | If output_mask is not None, then it should be a FloatTensor of shape (N, T) 203 | giving a multiplier to the output. 
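The multiplier is applied to the gradient of the sampled distribution at each timestep (via a backward hook on the stored probabilities), so a zero entry in output_mask removes that timestep's contribution for that sequence.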
204 | """ 205 | assert self.multinomial_outputs is not None, 'Must call reinforce_sample first' 206 | grad_output = [] 207 | 208 | def gen_hook(mask): 209 | def hook(grad): 210 | return grad * mask.contiguous().view(-1, 1).expand_as(grad) 211 | return hook 212 | 213 | if output_mask is not None: 214 | for t, probs in enumerate(self.multinomial_probs): 215 | mask = Variable(output_mask[:, t]) 216 | probs.register_hook(gen_hook(mask)) 217 | 218 | for sampled_output in self.multinomial_outputs: 219 | sampled_output.reinforce(reward) 220 | grad_output.append(None) 221 | torch.autograd.backward(self.multinomial_outputs, grad_output, retain_variables=True) 222 | 223 | 224 | def logical_and(x, y): 225 | return x * y 226 | 227 | def logical_or(x, y): 228 | return (x + y).clamp_(0, 1) 229 | 230 | def logical_not(x): 231 | return x == 0 232 | -------------------------------------------------------------------------------- /vr/models/seq2seq_att.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019-present, Mila 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | import math 9 | 10 | import torch 11 | import torch.cuda 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from torch.nn.utils.rnn import (pack_padded_sequence, 16 | pad_packed_sequence) 17 | 18 | from vr.embedding import expand_embedding_vocab 19 | 20 | 21 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 22 | 23 | 24 | class Attn(nn.Module): 25 | def __init__(self, hidden_size): 26 | super(Attn, self).__init__() 27 | self.hidden_size = hidden_size 28 | self.attn = nn.Linear(self.hidden_size * 2, hidden_size) 29 | self.v = nn.Parameter(torch.rand(hidden_size)) 30 | stdv = 1. 
/ math.sqrt(self.v.size(0)) 31 | self.v.data.normal_(mean=0, std=stdv) 32 | 33 | def forward(self, encoder_outputs, hidden): 34 | ''' 35 | :param hidden: 36 | previous hidden state of the decoder, in shape (layers*directions, H) 37 | :param encoder_outputs: 38 | encoder outputs from Encoder, in shape (B,T,H) 39 | :return 40 | attention energies in shape (B,T) 41 | ''' 42 | seq_len = encoder_outputs.size(1) 43 | H = hidden.repeat(seq_len, 1, 1).transpose(0,1) 44 | attn_energies = self.score(H, encoder_outputs) # B*1*T 45 | return F.softmax(attn_energies, dim=2) 46 | 47 | def score(self, hidden, encoder_outputs): 48 | energy = F.tanh(self.attn(torch.cat([hidden, encoder_outputs], 2))) # [B*T*2H]->[B*T*H] 49 | energy = energy.transpose(2,1) # [B*H*T] 50 | v = self.v.repeat(encoder_outputs.data.shape[0],1).unsqueeze(1) #[B*1*H] 51 | energy = torch.bmm(v, energy) # [B*1*T] 52 | return energy 53 | 54 | 55 | class Seq2SeqAtt(nn.Module): 56 | def __init__(self, 57 | null_token=0, 58 | start_token=1, 59 | end_token=2, 60 | encoder_vocab_size=100, 61 | decoder_vocab_size=100, 62 | wordvec_dim=300, 63 | hidden_dim=256, 64 | rnn_num_layers=2, 65 | rnn_dropout=0, 66 | ): 67 | super().__init__() 68 | self.encoder_embed = nn.Embedding(encoder_vocab_size, wordvec_dim) 69 | self.encoder_rnn = nn.LSTM(wordvec_dim, hidden_dim, rnn_num_layers, 70 | dropout=rnn_dropout, batch_first=True) 71 | self.decoder_embed = nn.Embedding(decoder_vocab_size, wordvec_dim) 72 | self.decoder_rnn = nn.LSTM(wordvec_dim + hidden_dim, hidden_dim, rnn_num_layers, 73 | dropout=rnn_dropout, batch_first=True) 74 | self.decoder_linear = nn.Linear(hidden_dim, decoder_vocab_size) 75 | self.decoder_attn = Attn(hidden_dim) 76 | self.rnn_num_layers = rnn_num_layers 77 | self.NULL = null_token 78 | self.START = start_token 79 | self.END = end_token 80 | self.multinomial_outputs = None 81 | 82 | def expand_encoder_vocab(self, token_to_idx, word2vec=None, std=0.01): 83 | expand_embedding_vocab(self.encoder_embed, token_to_idx, 84 | word2vec=word2vec, std=std) 85 | 86 | def get_dims(self, x=None, y=None): 87 | V_in = self.encoder_embed.num_embeddings 88 | V_out = self.decoder_embed.num_embeddings 89 | D = self.encoder_embed.embedding_dim 90 | H = self.encoder_rnn.hidden_size 91 | L = self.encoder_rnn.num_layers 92 | 93 | N = x.size(0) if x is not None else None 94 | N = y.size(0) if N is None and y is not None else N 95 | T_in = x.size(1) if x is not None else None 96 | T_out = y.size(1) if y is not None else None 97 | return V_in, V_out, D, H, L, N, T_in, T_out 98 | 99 | def encoder(self, x): 100 | x, x_lengths, inverse_index = sort_for_rnn(x, null=self.NULL) 101 | embed = self.encoder_embed(x) 102 | packed = pack_padded_sequence(embed, x_lengths, batch_first=True) 103 | out_packed, hidden = self.encoder_rnn(packed) 104 | out, _ = pad_packed_sequence(out_packed, batch_first=True) 105 | 106 | out = out[inverse_index] 107 | hidden = [h[:,inverse_index] for h in hidden] 108 | 109 | return out, hidden 110 | 111 | def decoder(self, word_inputs, encoder_outputs, prev_hidden): 112 | hn, cn = prev_hidden 113 | word_embedded = self.decoder_embed(word_inputs).unsqueeze(1) # batch x 1 x embed 114 | 115 | attn_weights = self.decoder_attn(encoder_outputs, hn[-1]) 116 | context = attn_weights.bmm(encoder_outputs) # batch x 1 x hidden 117 | 118 | rnn_input = torch.cat((word_embedded, context), 2) 119 | output, hidden = self.decoder_rnn(rnn_input, prev_hidden) 120 | 121 | output = output.squeeze(1) # batch x hidden 122 | output = self.decoder_linear(output) 
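# output: unnormalized scores over the decoder vocabulary (batch x V_out); hidden: the updated LSTM (h, c) state fed back in at the next step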
123 | 124 | return output, hidden 125 | 126 | def compute_loss(self, output_logprobs, y): 127 | """ 128 | Compute loss. We assume that the first element of the output sequence y is 129 | a start token, and that each element of y is left-aligned and right-padded 130 | with self.NULL out to T_out. We want the output_logprobs to predict the 131 | sequence y, shifted by one timestep so that y[0] is fed to the network and 132 | then y[1] is predicted. We also don't want to compute loss for padded 133 | timesteps. 134 | 135 | Inputs: 136 | - output_logprobs: Variable of shape (N, T_out, V_out) 137 | - y: LongTensor Variable of shape (N, T_out) 138 | """ 139 | self.multinomial_outputs = None 140 | V_in, V_out, D, H, L, N, T_in, T_out = self.get_dims(y=y) 141 | mask = y.data != self.NULL 142 | y_mask = Variable(torch.Tensor(N, T_out).fill_(0).type_as(mask)) 143 | y_mask[:, 1:] = mask[:, 1:] 144 | y_masked = y[y_mask] 145 | out_mask = Variable(torch.Tensor(N, T_out).fill_(0).type_as(mask)) 146 | out_mask[:, :-1] = mask[:, 1:] 147 | out_mask = out_mask.view(N, T_out, 1).expand(N, T_out, V_out) 148 | out_masked = output_logprobs[out_mask].view(-1, V_out) 149 | loss = F.cross_entropy(out_masked, y_masked) 150 | return loss 151 | 152 | def forward(self, x, y): 153 | max_target_length = y.size(1) 154 | 155 | encoder_outputs, encoder_hidden = self.encoder(x) 156 | decoder_inputs = y 157 | decoder_hidden = encoder_hidden 158 | decoder_outputs = [] 159 | for t in range(max_target_length): 160 | decoder_out, decoder_hidden = self.decoder( 161 | decoder_inputs[:,t], encoder_outputs, decoder_hidden) 162 | decoder_outputs.append(decoder_out) 163 | 164 | decoder_outputs = torch.stack(decoder_outputs, dim=1) 165 | loss = self.compute_loss(decoder_outputs, y) 166 | return loss 167 | 168 | def sample(self, x, max_length=50): 169 | # TODO: Handle sampling for minibatch inputs 170 | # TODO: Beam search? 
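# Greedy decoding: starting from <START>, repeatedly feed the previous argmax token back into the attention decoder until <END> is emitted or max_length is reached.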
171 | self.multinomial_outputs = None 172 | assert x.size(0) == 1, "Sampling minibatches not implemented" 173 | 174 | encoder_outputs, encoder_hidden = self.encoder(x) 175 | decoder_hidden = encoder_hidden 176 | sampled_output = [self.START] 177 | for t in range(max_length): 178 | decoder_input = Variable(torch.cuda.LongTensor([sampled_output[-1]])) 179 | decoder_out, decoder_hidden = self.decoder( 180 | decoder_input, encoder_outputs, decoder_hidden) 181 | _, argmax = decoder_out.data.max(1) 182 | output = argmax[0] 183 | sampled_output.append(output) 184 | if output == self.END: 185 | break 186 | 187 | return sampled_output 188 | 189 | def reinforce_sample(self, x, max_length=30, temperature=1.0, argmax=False): 190 | N, T = x.size(0), max_length 191 | encoder_outputs, encoder_hidden = self.encoder(x) 192 | y = torch.LongTensor(N, T).fill_(self.NULL) 193 | done = torch.ByteTensor(N).fill_(0) 194 | cur_input = Variable(x.data.new(N, 1).fill_(self.START)) 195 | decoder_hidden = encoder_hidden 196 | self.multinomial_outputs = [] 197 | self.multinomial_probs = [] 198 | for t in range(T): 199 | # logprobs is N x 1 x V 200 | logprobs, decoder_hidden = self.decoder(cur_input, encoder_outputs, decoder_hidden) 201 | logprobs = logprobs / temperature 202 | probs = F.softmax(logprobs.view(N, -1), dim=1) # Now N x V 203 | if argmax: 204 | _, cur_output = probs.max(1, keepdim=True) 205 | else: 206 | cur_output = probs.multinomial() # Now N x 1 207 | self.multinomial_outputs.append(cur_output) 208 | self.multinomial_probs.append(probs) 209 | cur_output_data = cur_output.data.cpu() 210 | not_done = logical_not(done) 211 | y[:, t][not_done] = cur_output_data[not_done] 212 | done = logical_or(done, cur_output_data.cpu() == self.END) 213 | cur_input = cur_output 214 | if done.sum() == N: 215 | break 216 | return Variable(y.type_as(x.data)) 217 | 218 | def reinforce_backward(self, reward, output_mask=None): 219 | """ 220 | If output_mask is not None, then it should be a FloatTensor of shape (N, T) 221 | giving a multiplier to the output. 
222 | """ 223 | assert self.multinomial_outputs is not None, 'Must call reinforce_sample first' 224 | grad_output = [] 225 | 226 | def gen_hook(mask): 227 | def hook(grad): 228 | return grad * mask.contiguous().view(-1, 1).expand_as(grad) 229 | return hook 230 | 231 | if output_mask is not None: 232 | for t, probs in enumerate(self.multinomial_probs): 233 | mask = Variable(output_mask[:, t]) 234 | probs.register_hook(gen_hook(mask)) 235 | 236 | for sampled_output in self.multinomial_outputs: 237 | sampled_output.reinforce(reward) 238 | grad_output.append(None) 239 | torch.autograd.backward(self.multinomial_outputs, grad_output, retain_variables=True) 240 | 241 | 242 | def logical_or(x, y): 243 | return (x + y).clamp_(0, 1) 244 | 245 | def logical_not(x): 246 | return x == 0 247 | 248 | def sort_for_rnn(x, null=0): 249 | lengths = torch.sum(x != null, dim=1).long() 250 | sorted_lengths, sorted_idx = torch.sort(lengths, dim=0, descending=True) 251 | sorted_lengths = sorted_lengths.data.tolist() # remove for pytorch 0.4+ 252 | # ugly 253 | inverse_sorted_idx = torch.LongTensor(sorted_idx.shape).fill_(0).to(device) 254 | for i, v in enumerate(sorted_idx): 255 | inverse_sorted_idx[v.data] = i 256 | 257 | return x[sorted_idx], sorted_lengths, inverse_sorted_idx 258 | 259 | -------------------------------------------------------------------------------- /vr/models/shnmn.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | from vr.models.layers import init_modules, ResidualBlock, SimpleVisualBlock, GlobalAveragePool, Flatten 9 | from vr.models.layers import build_classifier, build_stem, ConcatBlock 10 | import vr.programs 11 | 12 | from torch.nn.init import kaiming_normal, kaiming_uniform, xavier_uniform, xavier_normal, constant, uniform 13 | 14 | from vr.models.filmed_net import FiLM, FiLMedResBlock, ConcatFiLMedResBlock, coord_map 15 | from functools import partial 16 | 17 | 18 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 19 | 20 | 21 | def _random_tau(num_modules): 22 | tau_0 = torch.zeros(num_modules, num_modules+1) 23 | tau_1 = torch.zeros(num_modules, num_modules+1) 24 | xavier_uniform(tau_0) 25 | xavier_uniform(tau_1) 26 | return tau_0, tau_1 27 | 28 | 29 | def _chain_tau(): 30 | tau_0 = torch.zeros(3, 4) 31 | tau_1 = torch.zeros(3, 4) 32 | tau_0[0][1] = tau_1[0][0] = 100 #1st block - lhs inp img, rhs inp sentinel 33 | tau_0[1][2] = tau_1[1][0] = 100 #2nd block - lhs inp 1st block, rhs inp sentinel 34 | tau_0[2][3] = tau_1[2][0] = 100 #3rd block - lhs inp 2nd block, rhs inp sentinel 35 | return tau_0, tau_1 36 | 37 | def _chain_with_shortcuts_tau(): 38 | tau_0 = torch.zeros(3, 4) 39 | tau_1 = torch.zeros(3, 4) 40 | tau_0[0][1] = tau_1[0][0] = 100 #1st block - lhs inp img, rhs inp sentinel 41 | tau_0[1][2] = tau_1[1][1] = 100 #2nd block - lhs inp 1st block, rhs img 42 | tau_0[2][3] = tau_1[2][1] = 100 #3rd block - lhs inp 2nd block, rhs img 43 | return tau_0, tau_1 44 | 45 | 46 | def _tree_tau(): 47 | tau_0 = torch.zeros(3, 4) 48 | tau_1 = torch.zeros(3, 4) 49 | tau_0[0][1] = tau_1[0][0] = 100 #1st block - lhs inp img, rhs inp sentinel 50 | tau_0[1][1] = tau_1[1][0] = 100 #2st block - lhs inp img, rhs inp sentinel 51 | tau_0[2][2] = tau_1[2][3] = 100 #3rd block - lhs inp 1st block, rhs inp 2nd block 52 | return tau_0, tau_1 53 | 54 | 55 | def correct_alpha_init_xyr(alpha): 
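# Hard-coded question attention: alpha has one row per module and one column per question
# token (apparently an X-R-Y triple, judging from these init names). A value of 100 makes the
# row-wise softmax applied in _shnmn_func effectively one-hot, so this 'xyr' variant routes
# token X to module 0, token Y to module 1 and the relation token R to module 2; the _rxy and
# _xry variants below permute that assignment.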
56 | alpha.zero_() 57 | alpha[0][0] = 100 58 | alpha[1][2] = 100 59 | alpha[2][1] = 100 60 | 61 | return alpha 62 | 63 | def correct_alpha_init_rxy(alpha, use_stopwords=True): 64 | alpha.zero_() 65 | alpha[0][1] = 100 66 | alpha[1][0] = 100 67 | alpha[2][2] = 100 68 | 69 | return alpha 70 | 71 | def correct_alpha_init_xry(alpha, use_stopwords=True): 72 | alpha.zero_() 73 | alpha[0][0] = 100 74 | alpha[1][1] = 100 75 | alpha[2][2] = 100 76 | 77 | return alpha 78 | 79 | def _shnmn_func(question, img, num_modules, alpha, tau_0, tau_1, func): 80 | sentinel = torch.zeros_like(img) # B x 1 x C x H x W 81 | h_prev = torch.cat([sentinel, img], dim=1) # B x 2 x C x H x W 82 | 83 | for i in range(num_modules): 84 | alpha_curr = F.softmax(alpha[i], dim=0) 85 | tau_0_curr = F.softmax(tau_0[i, :(i+2)], dim=0) 86 | tau_1_curr = F.softmax(tau_1[i, :(i+2)], dim=0) 87 | 88 | question_rep = torch.sum(alpha_curr.view(1,-1,1)*question, dim=1) #(B,D) 89 | # B x C x H x W 90 | lhs_rep = torch.sum(tau_0_curr.view(1, (i+2), 1, 1, 1)*h_prev, dim=1) 91 | # B x C x H x W 92 | rhs_rep = torch.sum(tau_1_curr.view(1, (i+2), 1, 1, 1)*h_prev, dim=1) 93 | h_i = func(question_rep, lhs_rep, rhs_rep) # B x C x H x W 94 | 95 | h_prev = torch.cat([h_prev, h_i.unsqueeze(1)], dim=1) 96 | 97 | return h_prev 98 | 99 | 100 | class FindModule(nn.Module): 101 | def __init__(self, dim, kernel_size): 102 | super().__init__() 103 | self.dim = dim 104 | self.kernel_size = kernel_size 105 | self.conv_1 = nn.Conv2d(2*dim, dim, kernel_size=1, padding=0) 106 | self.conv_2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding = kernel_size // 2) 107 | 108 | def forward(self, question_rep, lhs_rep, rhs_rep): 109 | out = F.relu(self.conv_1(torch.cat([lhs_rep, rhs_rep], 1))) # concat along depth 110 | question_rep = question_rep.view(-1, self.dim, 1, 1) 111 | return F.relu(self.conv_2(out*question_rep)) 112 | 113 | class ResidualFunc: 114 | def __init__(self, dim, kernel_size): 115 | self.dim = dim 116 | self.kernel_size = kernel_size 117 | 118 | def __call__(self, question_rep, lhs_rep, rhs_rep): 119 | cnn_weight_dim = self.dim * self.dim * self.kernel_size * self.kernel_size 120 | cnn_bias_dim = self.dim 121 | proj_cnn_weight_dim = 2 * self.dim * self.dim 122 | proj_cnn_bias_dim = self.dim 123 | if (question_rep.size(1) != 124 | proj_cnn_weight_dim + proj_cnn_bias_dim 125 | + 2 * (cnn_weight_dim + cnn_bias_dim)): 126 | raise ValueError 127 | 128 | # pick out CNN and projection CNN weights/biases 129 | cnn1_weight = question_rep[:,:cnn_weight_dim] 130 | cnn2_weight = question_rep[:,cnn_weight_dim:2 * cnn_weight_dim] 131 | cnn1_bias = question_rep[:,2 * cnn_weight_dim:(2 * cnn_weight_dim) + cnn_bias_dim] 132 | cnn2_bias = question_rep[:,(2 * cnn_weight_dim) + cnn_bias_dim:2 * (cnn_weight_dim + cnn_bias_dim)] 133 | proj_weight = question_rep[:, 2 * (cnn_weight_dim + cnn_bias_dim) : 134 | 2 * (cnn_weight_dim + cnn_bias_dim) + proj_cnn_weight_dim] 135 | proj_bias = question_rep[:, 2*(cnn_weight_dim + cnn_bias_dim) + proj_cnn_weight_dim:] 136 | 137 | cnn_out_total = [] 138 | bs = question_rep.size(0) 139 | 140 | for i in range(bs): 141 | cnn1_weight_curr = cnn1_weight[i].view(self.dim, self.dim, self.kernel_size, self.kernel_size) 142 | cnn1_bias_curr = cnn1_bias[i] 143 | cnn2_weight_curr = cnn2_weight[i].view(self.dim, self.dim, self.kernel_size, self.kernel_size) 144 | cnn2_bias_curr = cnn2_bias[i] 145 | 146 | proj_weight_curr = proj_weight[i].view(self.dim, 2*self.dim, 1, 1) 147 | proj_bias_curr = proj_bias[i] 148 | 149 | cnn_inp = 
F.relu(F.conv2d(torch.cat([lhs_rep[[i]], rhs_rep[[i]]], 1), 150 | proj_weight_curr, 151 | bias=proj_bias_curr, padding=0)) 152 | 153 | cnn1_out = F.relu(F.conv2d(cnn_inp, cnn1_weight_curr, bias=cnn1_bias_curr, padding=self.kernel_size // 2)) 154 | cnn2_out = F.conv2d(cnn1_out, cnn2_weight_curr, bias=cnn2_bias_curr,padding=self.kernel_size // 2) 155 | 156 | cnn_out_total.append(F.relu(cnn_inp + cnn2_out) ) 157 | 158 | return torch.cat(cnn_out_total) 159 | 160 | 161 | 162 | class ConvFunc: 163 | def __init__(self, dim, kernel_size): 164 | self.dim = dim 165 | self.kernel_size = kernel_size 166 | 167 | def __call__(self, question_rep, lhs_rep, rhs_rep): 168 | cnn_weight_dim = self.dim*self.dim*self.kernel_size*self.kernel_size 169 | cnn_bias_dim = self.dim 170 | proj_cnn_weight_dim = 2*self.dim*self.dim 171 | proj_cnn_bias_dim = self.dim 172 | if (question_rep.size(1) != 173 | proj_cnn_weight_dim + proj_cnn_bias_dim 174 | + cnn_weight_dim + cnn_bias_dim): 175 | raise ValueError 176 | 177 | # pick out CNN and projection CNN weights/biases 178 | cnn_weight = question_rep[:, : cnn_weight_dim] 179 | cnn_bias = question_rep[:, cnn_weight_dim : cnn_weight_dim + cnn_bias_dim] 180 | proj_weight = question_rep[:, cnn_weight_dim+cnn_bias_dim : 181 | cnn_weight_dim+cnn_bias_dim+proj_cnn_weight_dim] 182 | proj_bias = question_rep[:, cnn_weight_dim+cnn_bias_dim+proj_cnn_weight_dim:] 183 | 184 | cnn_out_total = [] 185 | bs = question_rep.size(0) 186 | 187 | for i in range(bs): 188 | cnn_weight_curr = cnn_weight[i].view(self.dim, self.dim, self.kernel_size, self.kernel_size) 189 | cnn_bias_curr = cnn_bias[i] 190 | proj_weight_curr = proj_weight[i].view(self.dim, 2*self.dim, 1, 1) 191 | proj_bias_curr = proj_bias[i] 192 | 193 | cnn_inp = F.conv2d(torch.cat([lhs_rep[[i]], rhs_rep[[i]]], 1), 194 | proj_weight_curr, 195 | bias=proj_bias_curr, padding=0) 196 | cnn_out_total.append(F.relu(F.conv2d( 197 | cnn_inp, cnn_weight_curr, bias=cnn_bias_curr, padding=self.kernel_size // 2))) 198 | 199 | return torch.cat(cnn_out_total) 200 | 201 | INITS = {'xavier_uniform' : xavier_uniform, 202 | 'constant' : constant, 203 | 'uniform' : uniform, 204 | 'correct' : correct_alpha_init_xyr, 205 | 'correct_xry' : correct_alpha_init_xry, 206 | 'correct_rxy' : correct_alpha_init_rxy} 207 | 208 | class SHNMN(nn.Module): 209 | def __init__(self, 210 | vocab, 211 | feature_dim, 212 | module_dim, 213 | module_kernel_size, 214 | stem_dim, 215 | stem_num_layers, 216 | stem_subsample_layers, 217 | stem_kernel_size, 218 | stem_padding, 219 | stem_batchnorm, 220 | classifier_fc_layers, 221 | classifier_proj_dim, 222 | classifier_downsample,classifier_batchnorm, 223 | num_modules, 224 | hard_code_alpha=False, 225 | hard_code_tau=False, 226 | tau_init='random', 227 | alpha_init='xavier_uniform', 228 | model_type ='soft', 229 | model_bernoulli=0.5, 230 | use_module = 'conv', 231 | use_stopwords = True, 232 | **kwargs): 233 | 234 | super().__init__() 235 | self.num_modules = num_modules 236 | # alphas and taus from Overleaf Doc. 
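# alpha (num_modules x num_question_tokens) is each module's soft attention over the question
# tokens; tau_0 and tau_1 (num_modules x (num_modules + 1)) are each module's soft attention
# over its candidate left/right inputs: a zero sentinel, the stemmed image, and the outputs of
# earlier modules (see _shnmn_func). Entries around 100 make those softmaxes effectively hard,
# which is how the chain/tree layouts above are pinned down when hard_code_tau is set.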
237 | self.hard_code_alpha = hard_code_alpha 238 | self.hard_code_tau = hard_code_tau 239 | 240 | num_question_tokens = 3 241 | 242 | if alpha_init.startswith('correct'): 243 | print('using correct initialization') 244 | alpha = INITS[alpha_init](torch.Tensor(num_modules, num_question_tokens)) 245 | elif alpha_init == 'constant': 246 | alpha = INITS[alpha_init](torch.Tensor(num_modules, num_question_tokens), 1) 247 | else: 248 | alpha = INITS[alpha_init](torch.Tensor(num_modules, num_question_tokens)) 249 | print('initial alpha ') 250 | print(alpha) 251 | 252 | 253 | if hard_code_alpha: 254 | assert(alpha_init.startswith('correct')) 255 | 256 | self.alpha = Variable(alpha) 257 | self.alpha = self.alpha.to(device) 258 | else: 259 | self.alpha = nn.Parameter(alpha) 260 | 261 | 262 | # create taus 263 | if tau_init == 'tree': 264 | tau_0, tau_1 = _tree_tau() 265 | print("initializing with tree.") 266 | elif tau_init == 'chain': 267 | tau_0, tau_1 = _chain_tau() 268 | print("initializing with chain") 269 | elif tau_init == 'chain_with_shortcuts': 270 | tau_0, tau_1 = _chain_with_shortcuts_tau() 271 | print("initializing with chain and shortcuts") 272 | 273 | else: 274 | tau_0, tau_1 = _random_tau(num_modules) 275 | 276 | if hard_code_tau: 277 | assert(tau_init in ['chain', 'tree', 'chain_with_shortcuts']) 278 | self.tau_0 = Variable(tau_0) 279 | self.tau_1 = Variable(tau_1) 280 | self.tau_0 = self.tau_0.to(device) 281 | self.tau_1 = self.tau_1.to(device) 282 | else: 283 | self.tau_0 = nn.Parameter(tau_0) 284 | self.tau_1 = nn.Parameter(tau_1) 285 | 286 | 287 | 288 | if use_module == 'conv': 289 | embedding_dim_1 = module_dim + (module_dim*module_dim*module_kernel_size*module_kernel_size) 290 | embedding_dim_2 = module_dim + (2*module_dim*module_dim) 291 | 292 | question_embeddings_1 = nn.Embedding(len(vocab['question_idx_to_token']),embedding_dim_1) 293 | question_embeddings_2 = nn.Embedding(len(vocab['question_idx_to_token']),embedding_dim_2) 294 | 295 | stdv_1 = 1. / math.sqrt(module_dim*module_kernel_size*module_kernel_size) 296 | stdv_2 = 1. / math.sqrt(2*module_dim) 297 | 298 | question_embeddings_1.weight.data.uniform_(-stdv_1, stdv_1) 299 | question_embeddings_2.weight.data.uniform_(-stdv_2, stdv_2) 300 | self.question_embeddings = nn.Embedding(len(vocab['question_idx_to_token']), embedding_dim_1+embedding_dim_2) 301 | self.question_embeddings.weight.data = torch.cat([question_embeddings_1.weight.data, 302 | question_embeddings_2.weight.data],dim=-1) 303 | 304 | self.func = ConvFunc(module_dim, module_kernel_size) 305 | 306 | elif use_module == 'residual': 307 | embedding_dim_1 = module_dim + (module_dim*module_dim*module_kernel_size*module_kernel_size) 308 | embedding_dim_2 = module_dim + (2*module_dim*module_dim) 309 | 310 | question_embeddings_a = nn.Embedding(len(vocab['question_idx_to_token']),embedding_dim_1) 311 | question_embeddings_b = nn.Embedding(len(vocab['question_idx_to_token']),embedding_dim_1) 312 | question_embeddings_2 = nn.Embedding(len(vocab['question_idx_to_token']),embedding_dim_2) 313 | 314 | stdv_1 = 1. / math.sqrt(module_dim*module_kernel_size*module_kernel_size) 315 | stdv_2 = 1. 
/ math.sqrt(2*module_dim) 316 | 317 | question_embeddings_a.weight.data.uniform_(-stdv_1, stdv_1) 318 | question_embeddings_b.weight.data.uniform_(-stdv_1, stdv_1) 319 | question_embeddings_2.weight.data.uniform_(-stdv_2, stdv_2) 320 | self.question_embeddings = nn.Embedding(len(vocab['question_idx_to_token']), 2*embedding_dim_1+embedding_dim_2) 321 | self.question_embeddings.weight.data = torch.cat([question_embeddings_a.weight.data, question_embeddings_b.weight.data, 322 | question_embeddings_2.weight.data],dim=-1) 323 | self.func = ResidualFunc(module_dim, module_kernel_size) 324 | 325 | else: 326 | self.question_embeddings = nn.Embedding(len(vocab['question_idx_to_token']), module_dim) 327 | self.func = FindModule(module_dim, module_kernel_size) 328 | 329 | 330 | # stem for processing the image into a 3D tensor 331 | self.stem = build_stem(feature_dim[0], stem_dim, module_dim, 332 | num_layers=stem_num_layers, 333 | subsample_layers=stem_subsample_layers, 334 | kernel_size=stem_kernel_size, 335 | padding=stem_padding, 336 | with_batchnorm=stem_batchnorm) 337 | 338 | tmp = self.stem(Variable(torch.zeros([1, feature_dim[0], feature_dim[1], feature_dim[2]]))) 339 | module_H = tmp.size(2) 340 | module_W = tmp.size(3) 341 | num_answers = len(vocab['answer_idx_to_token']) 342 | self.classifier = build_classifier(module_dim, module_H, module_W, num_answers, 343 | classifier_fc_layers, 344 | classifier_proj_dim, 345 | classifier_downsample, 346 | with_batchnorm=classifier_batchnorm) 347 | 348 | self.model_type = model_type 349 | self.use_module = use_module 350 | p = model_bernoulli 351 | tree_odds = -numpy.log((1 - p) / p) 352 | self.tree_odds = nn.Parameter(torch.Tensor([tree_odds])) 353 | 354 | 355 | def forward_hard(self, image, question): 356 | question = self.question_embeddings(question) 357 | stemmed_img = self.stem(image).unsqueeze(1) # B x 1 x C x H x W 358 | 359 | chain_tau_0, chain_tau_1 = _chain_tau() 360 | chain_tau_0 = chain_tau_0.to(device) 361 | chain_tau_1 = chain_tau_1.to(device) 362 | h_chain = _shnmn_func(question, stemmed_img, 363 | self.num_modules, self.alpha, 364 | Variable(chain_tau_0), Variable(chain_tau_1), self.func) 365 | h_final_chain = h_chain[:, -1, :, :, :] 366 | tree_tau_0, tree_tau_1 = _tree_tau() 367 | tree_tau_0 = tree_tau_0.to(device) 368 | tree_tau_1 = tree_tau_1.to(device) 369 | h_tree = _shnmn_func(question, stemmed_img, 370 | self.num_modules, self.alpha, 371 | Variable(tree_tau_0), Variable(tree_tau_1), self.func) 372 | h_final_tree = h_tree[:, -1, :, :, :] 373 | 374 | p_tree = torch.sigmoid(self.tree_odds[0]) 375 | self.tree_scores = self.classifier(h_final_tree) 376 | self.chain_scores = self.classifier(h_final_chain) 377 | output_probs_tree = F.softmax(self.tree_scores, dim=1) 378 | output_probs_chain = F.softmax(self.chain_scores, dim=1) 379 | probs_mixture = p_tree * output_probs_tree + (1.0 - p_tree) * output_probs_chain 380 | eps = 1e-6 381 | probs_mixture = (1 - eps) * probs_mixture + eps 382 | return torch.log(probs_mixture) 383 | 384 | 385 | def forward_soft(self, image, question): 386 | question = self.question_embeddings(question) 387 | stemmed_img = self.stem(image).unsqueeze(1) # B x 1 x C x H x W 388 | 389 | self.h = _shnmn_func(question, stemmed_img, self.num_modules, 390 | self.alpha, self.tau_0, self.tau_1, self.func) 391 | h_final = self.h[:, -1, :, :, :] 392 | return self.classifier(h_final) 393 | 394 | def forward(self, image, question): 395 | if self.model_type == 'hard': 396 | return self.forward_hard(image, question) 397 | 
else: 398 | return self.forward_soft(image, question) 399 | -------------------------------------------------------------------------------- /vr/models/simple_module_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019-present, Mila 4 | # Copyright 2017-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | 10 | import math 11 | import torch 12 | import sys 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.autograd import Variable 16 | import torchvision.models 17 | 18 | from vr.models.layers import init_modules, ResidualBlock, SimpleVisualBlock, GlobalAveragePool, Flatten 19 | from vr.models.layers import build_classifier, build_stem, ConcatBlock 20 | import vr.programs 21 | 22 | from torch.nn.init import kaiming_normal, kaiming_uniform, xavier_uniform, xavier_normal, constant 23 | 24 | from vr.models.filmed_net import FiLM, FiLMedResBlock, ConcatFiLMedResBlock,coord_map 25 | from functools import partial 26 | 27 | # helper functions 28 | 29 | # === Definition of modules for NMN === # 30 | def shape_module(shape): 31 | return "Shape[{}]".format(shape) 32 | 33 | def binary_shape_module(shape): 34 | return "Shape2[{}]".format(shape) 35 | 36 | def relation_module(relation): 37 | return "Relate[{}]".format(relation) 38 | 39 | def unary_relation_module(relation): 40 | return "Relate1[{}]".format(relation) 41 | 42 | 43 | def forward_chain(image_tensor, vocab, function_modules, item_list, film_params): 44 | gammas, betas, coords = None, None, None 45 | if film_params is not None: 46 | gammas, betas, coords = film_params 47 | 48 | h_cur = image_tensor 49 | for input_ in item_list: 50 | h_next = [] 51 | for j in range(input_.shape[0]): 52 | 53 | if gammas is not None: 54 | item_idx = int(input_[j]) 55 | mod = function_modules['film'] 56 | h_next.append(mod(h_cur[[j]], gammas[:, item_idx, :], betas[:, item_idx, :], coords)) 57 | else: 58 | module_name = vocab['program_idx_to_token'][int(input_[j])] 59 | mod = function_modules[module_name] 60 | h_next.append(mod(h_cur[[j]])) 61 | 62 | h_cur = torch.cat(h_next) 63 | 64 | return h_cur 65 | 66 | 67 | def forward_chain1(image, question, stem, vocab, function_modules, binary_function_modules, film_params=None): 68 | lhs = question[:, 0] 69 | rhs = question[:, 2] 70 | rel = question[:, 1] 71 | 72 | item_list = [lhs, rel, rhs] 73 | return forward_chain(stem(image), vocab, function_modules, item_list, film_params) 74 | 75 | 76 | def forward_chain2(image, question, stem, vocab, function_modules, binary_function_modules, film_params=None): 77 | lhs = question[:, 0] 78 | rhs = question[:, 2] 79 | rel = question[:, 1] 80 | 81 | item_list = [lhs, rhs, rel] 82 | return forward_chain(stem(image), vocab, function_modules, item_list, film_params) 83 | 84 | 85 | def forward_chain3(image, question, stem, vocab, function_modules, binary_function_modules, film_params=None): 86 | lhs = question[:, 0] 87 | rhs = question[:, 2] 88 | rel = question[:, 1] 89 | 90 | item_list = [rel, lhs, rhs] 91 | return forward_chain(stem(image), vocab, function_modules, item_list, film_params) 92 | 93 | 94 | def forward_tree(image, question, stem, vocab, unary_function_modules, binary_function_modules, film_params=None): 95 | h_cur = stem(image) 96 | h_out = [] 97 | 98 | gammas, betas, coords = None, None, None 99 | if film_params is not 
None: 100 | gammas, betas, coords = film_params 101 | 102 | lhs = question[:, 0] 103 | rhs = question[:, 2] 104 | rel = question[:, 1] 105 | 106 | for j in range(question.shape[0]): 107 | 108 | lhs_idx = int(lhs[j]) 109 | rel_idx = int(rel[j]) 110 | rhs_idx = int(rhs[j]) 111 | 112 | lhs = shape_module(vocab['question_idx_to_token'][lhs_idx]) 113 | rel = relation_module(vocab['question_idx_to_token'][rel_idx]) 114 | rhs = shape_module(vocab['question_idx_to_token'][rhs_idx]) 115 | 116 | if gammas is not None: 117 | rel_lhs = unary_function_modules['film'](h_cur[[j]], gammas[:, lhs_idx, :], betas[:, lhs_idx, :], coords) 118 | rel_rhs = unary_function_modules['film'](h_cur[[j]], gammas[:, rhs_idx, :], betas[:, rhs_idx, :], coords) 119 | 120 | h_out.append(binary_function_modules['film']([rel_lhs, rel_rhs], gammas[:, rel_idx, :], betas[:, rel_idx, :], coords )) 121 | else: 122 | rel_lhs = unary_function_modules[lhs](h_cur[[j]]) 123 | rel_rhs = unary_function_modules[rhs](h_cur[[j]]) 124 | 125 | h_out.append(binary_function_modules[rel](rel_lhs, rel_rhs)) 126 | 127 | h_out = torch.cat(h_out) 128 | return h_out 129 | 130 | 131 | FUNC_DICT = {'chain1' : forward_chain1, 'chain2' : forward_chain2, 'chain3' : forward_chain3, 'tree' : forward_tree} 132 | 133 | 134 | class SimpleModuleNet(nn.Module): 135 | def __init__(self, vocab, feature_dim, 136 | stem_num_layers, 137 | stem_batchnorm, 138 | stem_subsample_layers, 139 | stem_kernel_size, 140 | stem_stride, 141 | stem_padding, 142 | stem_dim, 143 | module_dim, 144 | module_kernel_size, 145 | module_input_proj, 146 | forward_func, 147 | use_color, 148 | module_residual=True, 149 | module_batchnorm=False, 150 | classifier_proj_dim=512, 151 | classifier_downsample='maxpool2', 152 | classifier_fc_layers=(1024,), 153 | classifier_batchnorm=False, 154 | classifier_dropout=0, 155 | use_film=False, 156 | verbose=True): 157 | super().__init__() 158 | 159 | self.module_dim = module_dim 160 | self.func = FUNC_DICT[forward_func] 161 | self.use_color = use_color 162 | 163 | self.stem = build_stem(feature_dim[0], stem_dim, module_dim, 164 | num_layers=stem_num_layers, 165 | subsample_layers=stem_subsample_layers, 166 | kernel_size=stem_kernel_size, 167 | padding=stem_padding, 168 | with_batchnorm=stem_batchnorm) 169 | tmp = self.stem(Variable(torch.zeros([1, feature_dim[0], feature_dim[1], feature_dim[2]]))) 170 | module_H = tmp.size(2) 171 | module_W = tmp.size(3) 172 | 173 | self.coords = coord_map((module_H, module_W)).unsqueeze(0) 174 | 175 | if verbose: 176 | print('Here is my stem:') 177 | print(self.stem) 178 | 179 | num_answers = len(vocab['answer_idx_to_token']) 180 | self.classifier = build_classifier(module_dim, module_H, module_W, num_answers, 181 | classifier_fc_layers, 182 | classifier_proj_dim, 183 | classifier_downsample, 184 | with_batchnorm=classifier_batchnorm, 185 | dropout=classifier_dropout) 186 | if verbose: 187 | print('Here is my classifier:') 188 | print(self.classifier) 189 | 190 | self.unary_function_modules = {} 191 | self.binary_function_modules = {} 192 | self.vocab = vocab 193 | self.use_film = use_film 194 | 195 | if self.use_film: 196 | unary_mod = FiLMedResBlock(module_dim, with_residual=module_residual, 197 | with_intermediate_batchnorm=False, with_batchnorm=False, 198 | with_cond=[True, True], 199 | num_extra_channels=2, # was 2 for original film, 200 | extra_channel_freq=1, 201 | with_input_proj=module_input_proj, 202 | num_cond_maps=0, 203 | kernel_size=module_kernel_size, 204 | batchnorm_affine=False, 205 | num_layers=1, 
206 | condition_method='bn-film', 207 | debug_every=float('inf')) 208 | binary_mod = ConcatFiLMedResBlock(2, module_dim, with_residual=module_residual, 209 | with_intermediate_batchnorm=False, with_batchnorm=False, 210 | with_cond=[True, True], 211 | num_extra_channels=2, #was 2 for original film, 212 | extra_channel_freq=1, 213 | with_input_proj=module_input_proj, 214 | num_cond_maps=0, 215 | kernel_size=module_kernel_size, 216 | batchnorm_affine=False, 217 | num_layers=1, 218 | condition_method='bn-film', 219 | debug_every=float('inf')) 220 | 221 | self.unary_function_modules['film'] = unary_mod 222 | self.binary_function_modules['film'] = binary_mod 223 | self.add_module('film_unary', unary_mod) 224 | self.add_module('film_binary', binary_mod) 225 | 226 | 227 | else: 228 | for fn_str in vocab['program_token_to_idx']: 229 | arity = self.vocab['program_token_arity'][fn_str] 230 | if arity == 2 and forward_func == 'tree': 231 | binary_mod = ConcatBlock( 232 | module_dim, 233 | kernel_size=module_kernel_size, 234 | with_residual=module_residual, 235 | with_batchnorm=module_batchnorm, 236 | use_simple=False) 237 | 238 | self.add_module(fn_str, binary_mod) 239 | self.binary_function_modules[fn_str] = binary_mod 240 | 241 | else: 242 | mod = ResidualBlock( 243 | module_dim, 244 | kernel_size=module_kernel_size, 245 | with_residual=module_residual, 246 | with_batchnorm=module_batchnorm) 247 | 248 | self.add_module(fn_str, mod) 249 | self.unary_function_modules[fn_str] = mod 250 | 251 | self.declare_film_coefficients() 252 | 253 | def declare_film_coefficients(self): 254 | num_coeff = 1+len(self.vocab['question_idx_to_token']) 255 | if self.use_film: 256 | self.gammas = nn.Parameter(torch.Tensor(1, num_coeff, self.module_dim)) 257 | xavier_uniform(self.gammas) 258 | self.betas = nn.Parameter(torch.Tensor(1, num_coeff, self.module_dim)) 259 | xavier_uniform(self.betas) 260 | 261 | else: 262 | self.gammas = None 263 | self.betas = None 264 | 265 | def forward(self, image, question): 266 | return self.classifier(self.func(image, question, self.stem, self.vocab, self.unary_function_modules, self.binary_function_modules, [self.gammas, self.betas, self.coords])) 267 | -------------------------------------------------------------------------------- /vr/plotting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from matplotlib import pyplot 4 | import pandas 5 | import scipy.stats as stats 6 | 7 | 8 | def load_log(root, file_, data_train, data_val, args): 9 | slurmid = file_[:-8] 10 | path = os.path.join(root, file_) 11 | log = json.load(open(path)) 12 | 13 | args[root][slurmid] = log['args'] 14 | 15 | for t, train_loss in zip(log['train_losses_ts'], log['train_losses']): 16 | data_train['root'].append(root) 17 | data_train['slurmid'].append(slurmid) 18 | data_train['step'].append(t) 19 | data_train['train_loss'].append(train_loss) 20 | 21 | assert len(log['val_accs_ts']) == len(log['val_accs']) 22 | assert len(log['val_accs_ts']) == len(log['train_accs']) 23 | for t, val_acc, train_acc in zip(log['val_accs_ts'], log['val_accs'], log['train_accs']): 24 | data_val['root'].append(root) 25 | data_val['slurmid'].append(slurmid) 26 | data_val['step'].append(t) 27 | data_val['val_acc'].append(val_acc) 28 | data_val['train_acc'].append(train_acc) 29 | 30 | 31 | def load_logs(root, data_train, data_val, args): 32 | for root, dirs, files in os.walk(root): 33 | for file_ in files: 34 | if file_.endswith('pt.json'): 35 | load_log(root, file_, 
data_train, data_val, args) 36 | 37 | 38 | def plot_average(df, train_quantity='train_acc', val_quantity='val_acc', window=1, plot_interval=False): 39 | for root, df_root in df.groupby('root'): 40 | min_progress = min([df_slurmid['step'].max() for _, df_slurmid in df_root.groupby('slurmid')]) 41 | df_root = df_root[df_root['step'] <= min_progress] 42 | df_agg = df_root.groupby(['step']).agg(['mean', 'std']) 43 | 44 | # Plot train 45 | train_values = df_agg[train_quantity]['mean'] 46 | train_values = train_values.rolling(window).mean() 47 | train_lines = pyplot.plot(df_agg.index, 48 | train_values, 49 | label=root + ' train', 50 | linestyle='dotted') 51 | 52 | # Plot validation 53 | n_seeds = len(df_root['slurmid'].unique()) 54 | if val_quantity: 55 | val_values = df_agg[val_quantity]['mean'] 56 | val_std = df_agg[val_quantity]['std'] 57 | val_values = val_values.rolling(window).mean() 58 | val_std = val_std.rolling(window).mean() 59 | width = val_std * stats.t.ppf(0.975, n_seeds - 1) / (n_seeds ** 0.5) 60 | pyplot.plot(df_agg.index, 61 | val_values, 62 | label=root + " val", 63 | color=train_lines[0].get_color()) 64 | if plot_interval: 65 | pyplot.fill_between(df_agg.index, 66 | val_values - width, val_values + width, 67 | color=train_lines[0].get_color(), 68 | alpha=0.5) 69 | 70 | # Count number of successes 71 | n_train_successes = 0 72 | n_val_successes = 0 73 | for slurmid, df_slurmid in df_root.groupby('slurmid'): 74 | slurmid_values = df_slurmid[train_quantity].rolling(window).mean() 75 | if slurmid_values.iloc[-1] > 0.99: 76 | n_train_successes += 1 77 | if val_quantity: 78 | slurmid_values = df_slurmid[val_quantity].rolling(window).mean() 79 | if slurmid_values.iloc[-1] > 0.99: 80 | n_val_successes += 1 81 | success_report = "{} out of {}".format(n_train_successes, n_seeds) 82 | 83 | # Print 84 | to_print = ["{} ({} steps)".format(root, str(min_progress)), 85 | success_report, "({:.1f})".format(100 * train_values.iloc[-1])] 86 | if val_quantity: 87 | to_print.append("{} out of {}".format(n_val_successes, n_seeds)) 88 | to_print.append("({:.1f}+-{:.1f})".format(100 * val_values.iloc[-1], 100 * width.iloc[-1])) 89 | print(*to_print) 90 | 91 | pyplot.legend() 92 | 93 | 94 | def plot_all_runs(df, train_quantity='train_acc', val_quantity='val_acc', color=None, window=1): 95 | kwargs = {} 96 | if color: 97 | kwargs['color'] = color 98 | for (root, slurmid), df_run in df.groupby(['root', 'slurmid']): 99 | path = root + ' ' + slurmid 100 | train_lines = pyplot.plot(df_run['step'], 101 | df_run[train_quantity].rolling(window).mean(), 102 | label=path + ' train', 103 | 104 | linestyle='dotted', 105 | **kwargs) 106 | if val_quantity: 107 | pyplot.plot(df_run['step'], 108 | df_run[val_quantity].rolling(window).mean(), 109 | label=path + ' val', 110 | color=train_lines[0].get_color()) 111 | to_print = [path, df_run['step'].iloc[-1], df_run[train_quantity].iloc[-1]] 112 | if val_quantity: 113 | to_print.append(df_run[val_quantity].iloc[-1].mean()) 114 | print(*to_print) 115 | -------------------------------------------------------------------------------- /vr/preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2019-present, Mila 4 | # Copyright 2017-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | 10 | """ 11 | Utilities for preprocessing sequence data. 
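A typical round trip (toy example, not taken from the dataset): tokenize('is there a red square') returns ['<START>', 'is', ..., 'square', '<END>'], encode() maps those tokens to indices using a vocab built by build_vocab(), and decode() maps indices back to tokens, stopping at '<END>'.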
12 | 
13 | Special tokens that are in all dictionaries:
14 | 
15 | <NULL>: Extra parts of the sequence that we should ignore
16 | <START>: Goes at the start of a sequence
17 | <END>: Goes at the end of a sequence, before <NULL> tokens
18 | <UNK>: Out-of-vocabulary words
19 | """
20 | 
21 | SPECIAL_TOKENS = {
22 |   '<NULL>': 0,
23 |   '<START>': 1,
24 |   '<END>': 2,
25 |   '<UNK>': 3,
26 | }
27 | 
28 | 
29 | def tokenize(s, delim=' ',
30 |              add_start_token=True, add_end_token=True,
31 |              punct_to_keep=None, punct_to_remove=None):
32 |   """
33 |   Tokenize a sequence, converting a string s into a list of (string) tokens by
34 |   splitting on the specified delimiter. Optionally keep or remove certain
35 |   punctuation marks and add start and end tokens.
36 |   """
37 |   if punct_to_keep is not None:
38 |     for p in punct_to_keep:
39 |       s = s.replace(p, '%s%s' % (delim, p))
40 | 
41 |   if punct_to_remove is not None:
42 |     for p in punct_to_remove:
43 |       s = s.replace(p, '')
44 | 
45 |   tokens = s.split(delim)
46 |   if add_start_token:
47 |     tokens.insert(0, '<START>')
48 |   if add_end_token:
49 |     tokens.append('<END>')
50 |   return tokens
51 | 
52 | 
53 | def build_vocab(sequences, min_token_count=1, delim=' ',
54 |                 punct_to_keep=None, punct_to_remove=None):
55 |   token_to_count = {}
56 |   tokenize_kwargs = {
57 |     'delim': delim,
58 |     'punct_to_keep': punct_to_keep,
59 |     'punct_to_remove': punct_to_remove,
60 |   }
61 |   for seq in sequences:
62 |     seq_tokens = tokenize(seq, delim=delim, punct_to_keep=punct_to_keep,
63 |                           punct_to_remove=punct_to_remove,
64 |                           add_start_token=False, add_end_token=False)
65 |     for token in seq_tokens:
66 |       if token not in token_to_count:
67 |         token_to_count[token] = 0
68 |       token_to_count[token] += 1
69 | 
70 |   token_to_idx = {}
71 |   for token, idx in SPECIAL_TOKENS.items():
72 |     token_to_idx[token] = idx
73 |   for token, count in sorted(token_to_count.items()):
74 |     if count >= min_token_count:
75 |       token_to_idx[token] = len(token_to_idx)
76 | 
77 |   return token_to_idx
78 | 
79 | 
80 | def encode(seq_tokens, token_to_idx, allow_unk=False):
81 |   seq_idx = []
82 |   for token in seq_tokens:
83 |     if token not in token_to_idx:
84 |       if allow_unk:
85 |         token = '<UNK>'
86 |       else:
87 |         raise KeyError('Token "%s" not in vocab' % token)
88 |     seq_idx.append(token_to_idx[token])
89 |   return seq_idx
90 | 
91 | 
92 | def decode(seq_idx, idx_to_token, delim=None, stop_at_end=True):
93 |   tokens = []
94 |   for idx in seq_idx:
95 |     tokens.append(idx_to_token[idx])
96 |     if stop_at_end and tokens[-1] == '<END>':
97 |       break
98 |   if delim is None:
99 |     return tokens
100 |   else:
101 |     return delim.join(tokens)
102 | 
--------------------------------------------------------------------------------
/vr/programs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | # Copyright 2019-present, Mila
4 | # Copyright 2017-present, Facebook, Inc.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | 
10 | """
11 | Utilities for working with and converting between the various data structures
12 | used to represent programs.
13 | """
14 | 
15 | 
16 | class ProgramConverter(object):
17 | 
18 |   def __init__(self, vocab=None):
19 |     """
20 |     `vocab` is necessary only for prefix_to_list, because in this case
21 |     we need to know the arity of the tokens.
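
    Throughout this class a program node is a dict with a 'function' name and
    its 'value_inputs'; the flat list form additionally stores 'inputs' as
    indices of child nodes, while the tree form nests child dicts under
    'inputs'. For example (illustrative function name only):

        {'function': 'relate', 'value_inputs': ['left_of'], 'inputs': [0]}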
22 |     """
23 |     self._vocab = vocab
24 | 
25 |   def is_chain(self, program_list):
26 |     visited = [False for fn in program_list]
27 |     cur_idx = len(program_list) - 1
28 |     while True:
29 |       visited[cur_idx] = True
30 |       inputs = program_list[cur_idx]['inputs']
31 |       if len(inputs) == 0:
32 |         break
33 |       elif len(inputs) == 1:
34 |         cur_idx = inputs[0]
35 |       elif len(inputs) > 1:
36 |         return False
37 |     return all(visited)
38 | 
39 |   def list_to_tree(self, program_list):
40 |     def build_subtree(cur):
41 |       return {
42 |         'function': cur['function'],
43 |         'value_inputs': [x for x in cur['value_inputs']],
44 |         'inputs': [build_subtree(program_list[i]) for i in cur['inputs']],
45 |       }
46 |     return build_subtree(program_list[-1])
47 | 
48 |   def tree_to_prefix(self, program_tree):
49 |     output = []
50 |     def helper(cur):
51 |       output.append({
52 |         'function': cur['function'],
53 |         'value_inputs': [x for x in cur['value_inputs']],
54 |       })
55 |       for node in cur['inputs']:
56 |         helper(node)
57 |     helper(program_tree)
58 |     return output
59 | 
60 |   def list_to_prefix(self, program_list):
61 |     return self.tree_to_prefix(self.list_to_tree(program_list))
62 | 
63 |   def tree_to_postfix(self, program_tree):
64 |     output = []
65 |     def helper(cur):
66 |       for node in cur['inputs']:
67 |         helper(node)
68 |       output.append({
69 |         'function': cur['function'],
70 |         'value_inputs': [x for x in cur['value_inputs']],
71 |       })
72 |     helper(program_tree)
73 |     return output
74 | 
75 |   def tree_to_list(self, program_tree):
76 |     # First count nodes
77 |     def count_nodes(cur):
78 |       return 1 + sum(count_nodes(x) for x in cur['inputs'])
79 |     num_nodes = count_nodes(program_tree)
80 |     output = [None] * num_nodes
81 |     def helper(cur, idx):
82 |       output[idx] = {
83 |         'function': cur['function'],
84 |         'value_inputs': [x for x in cur['value_inputs']],
85 |         'inputs': [],
86 |       }
87 |       next_idx = idx - 1
88 |       for node in reversed(cur['inputs']):
89 |         output[idx]['inputs'].insert(0, next_idx)
90 |         next_idx = helper(node, next_idx)
91 |       return next_idx
92 |     helper(program_tree, num_nodes - 1)
93 |     return output
94 | 
95 | 
96 |   def prefix_to_tree(self, program_prefix):
97 |     program_prefix = [x for x in program_prefix]
98 |     def helper():
99 |       cur = program_prefix.pop(0)
100 |       return {
101 |         'function': cur['function'],
102 |         'value_inputs': [x for x in cur['value_inputs']],
103 |         'inputs': [helper() for _ in range(self.get_num_inputs(cur))],
104 |       }
105 |     return helper()
106 | 
107 | 
108 |   def prefix_to_list(self, program_prefix):
109 |     return self.tree_to_list(self.prefix_to_tree(program_prefix))
110 | 
111 | 
112 |   def list_to_postfix(self, program_list):
113 |     return self.tree_to_postfix(self.list_to_tree(program_list))
114 | 
115 | 
116 |   def postfix_to_tree(self, program_postfix):
117 |     program_postfix = [x for x in program_postfix]
118 |     def helper():
119 |       cur = program_postfix.pop()
120 |       return {
121 |         'function': cur['function'],
122 |         'value_inputs': [x for x in cur['value_inputs']],
123 |         'inputs': [helper() for _ in range(self.get_num_inputs(cur))][::-1],
124 |       }
125 |     return helper()
126 | 
127 | 
128 |   def postfix_to_list(self, program_postfix):
129 |     return self.tree_to_list(self.postfix_to_tree(program_postfix))
130 | 
131 |   def get_num_inputs(self, f):
132 |     f = function_to_str(f)
133 |     # This is a little hacky; it would be better to look up from metadata.json
134 |     # if type(f) is str:
135 |     #   f = str_to_function(f)
136 |     #   name = f['function']
137 |     return self._vocab['program_token_arity'][f]
138 | 
139 | 
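# A rough usage sketch (illustrative; `vocab` is assumed to expose the
# 'program_token_arity' mapping that the prefix/postfix parsers rely on):
#
#   converter = ProgramConverter(vocab)
#   prefix = converter.list_to_prefix(program_list)
#   tree = converter.prefix_to_tree(prefix)
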
140 | def function_to_str(f):
141 |   value_str = ''
142 |   if f['value_inputs']:
143 |     value_str = '[%s]' % ','.join(f['value_inputs'])
144 |   return '%s%s' % (f['function'], value_str)
145 | 
146 | def str_to_function(s):
147 |   if '[' not in s:
148 |     return {
149 |       'function': s,
150 |       'value_inputs': [],
151 |     }
152 |   name, value_str = s.replace(']', '').split('[')
153 |   return {
154 |     'function': name,
155 |     'value_inputs': value_str.split(','),
156 |   }
157 | 
158 | def list_to_str(program_list):
159 |   return ' '.join(function_to_str(f) for f in program_list)
160 | 
--------------------------------------------------------------------------------
/vr/treeGenerator.py:
--------------------------------------------------------------------------------
1 | 
2 | class TreeGenerator:
3 |   def __init__(self):
4 |     pass
5 | 
6 |   def gen(self, tree_type='complete_binary'):
7 |     if tree_type.startswith('complete_binary'):
8 |       depth = tree_type[len('complete_binary'):]
9 |       if depth == '': depth = 3
10 |       else: depth = int(depth)
11 |       return completeBinaryTree(depth)
12 |     elif tree_type.startswith('chainTree'):
13 |       depth = tree_type[len('chainTree'):]
14 |       if depth == '': depth = 8
15 |       else: depth = int(depth)
16 |       return chainTree(depth)
17 |     elif tree_type.startswith('pairChainTree'):
18 |       depth = tree_type[len('pairChainTree'):]
19 |       if depth == '': depth = 8
20 |       else: depth = int(depth)
21 |       return pairChainTree(depth)
22 |     else:
23 |       raise NotImplementedError(tree_type)
24 | 
25 |   def genHeap(self, tree_type='complete_binary'):
26 |     if tree_type.startswith('complete_binary'):
27 |       depth = tree_type[len('complete_binary'):]
28 |       if depth == '': depth = 3
29 |       else: depth = int(depth)
30 |       return heapCompleteBinaryTree(depth)
31 |     elif tree_type.startswith('pairChainTree'):
32 |       depth = tree_type[len('pairChainTree'):]
33 |       if depth == '': depth = 7
34 |       else: depth = int(depth)
35 |       return heapPairChainTree(depth)
36 |     else:
37 |       raise NotImplementedError(tree_type)
38 | 
39 | def heapCompleteBinaryTree(depth=3):
40 |   childrens = []
41 |   num = 2 ** depth
42 |   for i in range(num-1):
43 |     childrens.append([i*2+1,i*2+2])
44 |   for _ in range(num):
45 |     childrens.append([])
46 |   return childrens
47 | 
48 | def heapPairChainTree(depth=8):
49 |   children = [[1,2]]
50 |   for i in range(3,2*depth+1):
51 |     children.append([i])
52 |   children.append([])
53 |   children.append([])
54 |   return children
55 | 
56 | def completeBinaryTree(depth=3):
57 |   arities = []
58 |   def gen(idepth=0):
59 |     if idepth == depth:
60 |       arities.append(0)
61 |       return
62 |     else:
63 |       arities.append(2)
64 |       gen(idepth+1)
65 |       gen(idepth+1)
66 |   gen(0)
67 |   return arities
68 | 
69 | def chainTree(depth=8):
70 |   if depth == 0: return [0]
71 |   arities = []
72 |   while (len(arities) < depth-1): arities.append(1)
73 |   arities.append(0)
74 |   return arities
75 | 
76 | def pairChainTree(depth=8):
77 |   if depth < 2: raise Exception('Depth has to be at least 2')
78 |   half = int(depth / 2)
79 |   above = [1] * half
80 |   above[-1] = 2
81 |   below = [1] * half
82 |   below[-1] = 0
83 |   return above + below + below
84 | 
--------------------------------------------------------------------------------
/vr/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | # Copyright 2019-present, Mila
4 | # Copyright 2017-present, Facebook, Inc.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
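
# A rough end-to-end sketch (illustrative; the paths are placeholders for files
# produced by preprocessing and training):
#
#   vocab = load_vocab('vocab.json')
#   model, kwargs = load_execution_engine('model.pt')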
9 | 
10 | import inspect
11 | import json
12 | import torch
13 | 
14 | from vr.models import (ModuleNet,
15 |                         SHNMN,
16 |                         Seq2Seq,
17 |                         Seq2SeqAtt,
18 |                         LstmModel,
19 |                         CnnLstmModel,
20 |                         CnnLstmSaModel,
21 |                         FiLMedNet,
22 |                         FiLMGen,
23 |                         MAC,
24 |                         RelationNet)
25 | 
26 | def invert_dict(d):
27 |   return {v: k for k, v in d.items()}
28 | 
29 | 
30 | def load_vocab(path):
31 |   with open(path, 'r') as f:
32 |     vocab = json.load(f)
33 |     vocab['question_idx_to_token'] = invert_dict(vocab['question_token_to_idx'])
34 |     vocab['program_idx_to_token'] = invert_dict(vocab['program_token_to_idx'])
35 |     vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx'])
36 |   # Sanity check: make sure <NULL>, <START>, and <END> are consistent
37 |   assert vocab['question_token_to_idx']['<NULL>'] == 0
38 |   assert vocab['question_token_to_idx']['<START>'] == 1
39 |   assert vocab['question_token_to_idx']['<END>'] == 2
40 |   assert vocab['program_token_to_idx']['<NULL>'] == 0
41 |   assert vocab['program_token_to_idx']['<START>'] == 1
42 |   assert vocab['program_token_to_idx']['<END>'] == 2
43 |   return vocab
44 | 
45 | 
46 | def load_cpu(path):
47 |   """
48 |   Loads a torch checkpoint, remapping all Tensors to CPU.
49 |   """
50 |   return torch.load(path, map_location=lambda storage, loc: storage)
51 | 
52 | 
53 | def load_program_generator(path):
54 |   checkpoint = load_cpu(path)
55 |   model_type = checkpoint['args']['model_type']
56 |   kwargs = checkpoint['program_generator_kwargs']
57 |   state = checkpoint['program_generator_state']
58 |   if model_type in ['FiLM', 'MAC', 'RelNet']:
59 |     kwargs = get_updated_args(kwargs, FiLMGen)
60 |     model = FiLMGen(**kwargs)
61 |   elif model_type == 'PG+EE':
62 |     if kwargs.get('rnn_attention'):  # kwargs is a plain dict loaded from the checkpoint
63 |       model = Seq2SeqAtt(**kwargs)
64 |     else:
65 |       model = Seq2Seq(**kwargs)
66 |   else:
67 |     model = None
68 |   if model is not None:
69 |     model.load_state_dict(state)
70 |   return model, kwargs
71 | 
72 | 
73 | def load_execution_engine(path, verbose=True):
74 |   checkpoint = load_cpu(path)
75 |   model_type = checkpoint['args']['model_type']
76 |   kwargs = checkpoint['execution_engine_kwargs']
77 |   state = checkpoint['execution_engine_state']
78 |   kwargs['verbose'] = verbose
79 |   if model_type == 'FiLM':
80 |     kwargs = get_updated_args(kwargs, FiLMedNet)
81 |     model = FiLMedNet(**kwargs)
82 |   elif model_type == 'EE':
83 |     model = ModuleNet(**kwargs)
84 |   elif model_type == 'MAC':
85 |     kwargs.setdefault('write_unit', 'original')
86 |     kwargs.setdefault('read_connect', 'last')
87 |     kwargs.setdefault('noisy_controls', False)
88 |     kwargs.pop('sharing_params_patterns', None)
89 |     model = MAC(**kwargs)
90 |   elif model_type == 'RelNet':
91 |     model = RelationNet(**kwargs)
92 |   elif model_type == 'SHNMN':
93 |     model = SHNMN(**kwargs)
94 |   else:
95 |     raise ValueError('Unknown model type: %s' % model_type)
96 |   cur_state = model.state_dict()
97 |   model.load_state_dict(state)
98 |   return model, kwargs
99 | 
100 | 
101 | def load_baseline(path):
102 |   model_cls_dict = {
103 |     'LSTM': LstmModel,
104 |     'CNN+LSTM': CnnLstmModel,
105 |     'CNN+LSTM+SA': CnnLstmSaModel,
106 |   }
107 |   checkpoint = load_cpu(path)
108 |   baseline_type = checkpoint['baseline_type']
109 |   kwargs = checkpoint['baseline_kwargs']
110 |   state = checkpoint['baseline_state']
111 | 
112 |   model = model_cls_dict[baseline_type](**kwargs)
113 |   model.load_state_dict(state)
114 |   return model, kwargs
115 | 
116 | 
117 | def get_updated_args(kwargs, object_class):
118 |   """
119 |   Returns kwargs with renamed args or arg values, and with deprecated or unused args deleted.
120 |   Useful for loading older, trained models.
121 |   If using this function is necessary, use it immediately before initializing the object.
122 |   """
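  # For example (illustrative): a checkpoint saved with condition_method='cbn'
  # is remapped below to 'bn-film' via arg_value_updates before the model
  # class is constructed.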
123 |   # Update arg values
124 |   for arg in arg_value_updates:
125 |     if arg in kwargs and kwargs[arg] in arg_value_updates[arg]:
126 |       kwargs[arg] = arg_value_updates[arg][kwargs[arg]]
127 | 
128 |   # Delete deprecated, unused args
129 |   valid_args = inspect.getargspec(object_class.__init__)[0]
130 |   new_kwargs = {valid_arg: kwargs[valid_arg] for valid_arg in valid_args if valid_arg in kwargs}
131 |   return new_kwargs
132 | 
133 | class EMA():
134 |   def __init__(self, mu):
135 |     self.mu = mu
136 |     self.shadow = {}
137 | 
138 |   def register(self, cat, name, val):
139 |     self.shadow[cat + '-' + name] = val.clone()
140 | 
141 |   def __call__(self, cat, name, x):
142 |     name = cat + '-' + name
143 |     assert name in self.shadow
144 |     new_average = self.mu * x + (1.0 - self.mu) * self.shadow[name]
145 |     self.shadow[name] = new_average.clone()
146 |     return new_average
147 | 
148 | arg_value_updates = {
149 |   'condition_method': {
150 |     'block-input-fac': 'block-input-film',
151 |     'block-output-fac': 'block-output-film',
152 |     'cbn': 'bn-film',
153 |     'conv-fac': 'conv-film',
154 |     'relu-fac': 'relu-film',
155 |   },
156 |   'module_input_proj': {
157 |     True: 1,
158 |   },
159 | }
160 | 
--------------------------------------------------------------------------------