├── img ├── FiLM-layer.png ├── FiLM-model.png ├── stats │ ├── Betas.png │ ├── Gammas.png │ ├── Beta SDs.png │ ├── Gamma SDs.png │ ├── Beta Means.png │ ├── Gamma Means.png │ ├── Betas: Layer 1.png │ ├── Betas: Layer 2.png │ ├── Betas: Layer 3.png │ ├── Betas: Layer 4.png │ ├── Gammas: Layer 1.png │ ├── Gammas: Layer 2.png │ ├── Gammas: Layer 3.png │ └── Gammas: Layer 4.png ├── CLEVR_val_000010.png ├── CLEVR_val_000011.png ├── CLEVR_val_000012.png ├── CLEVR_val_000013.png ├── CLEVR_val_000014.png ├── CLEVR_val_000015.png ├── CLEVR_val_000016.png ├── CLEVR_val_000017.png ├── CLEVR_val_000018.png ├── CLEVR_val_000019.png └── best-model-curves.png ├── requirements.txt ├── vr ├── __init__.py ├── models │ ├── __init__.py │ ├── layers.py │ ├── module_net.py │ ├── film_gen.py │ ├── baselines.py │ ├── seq2seq.py │ └── filmed_net.py ├── embedding.py ├── preprocess.py ├── utils.py ├── programs.py └── data.py ├── scripts ├── train │ ├── film.sh │ ├── film_pixels.sh │ ├── film_cogent.sh │ └── film_humans.sh ├── preprocess_human.sh ├── extract_features.py ├── preprocess_questions.py ├── run_model.py └── train_model.py ├── CLEVR_eval_with_q_type.py ├── .gitignore ├── README.md └── LICENSE /img/FiLM-layer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/FiLM-layer.png -------------------------------------------------------------------------------- /img/FiLM-model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/FiLM-model.png -------------------------------------------------------------------------------- /img/stats/Betas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/stats/Betas.png -------------------------------------------------------------------------------- /img/stats/Gammas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/stats/Gammas.png -------------------------------------------------------------------------------- /img/stats/Beta SDs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/stats/Beta SDs.png -------------------------------------------------------------------------------- /img/stats/Gamma SDs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/stats/Gamma SDs.png -------------------------------------------------------------------------------- /img/CLEVR_val_000010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/CLEVR_val_000010.png -------------------------------------------------------------------------------- /img/CLEVR_val_000011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/CLEVR_val_000011.png -------------------------------------------------------------------------------- /img/CLEVR_val_000012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/CLEVR_val_000012.png 
-------------------------------------------------------------------------------- /img/CLEVR_val_000013.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/CLEVR_val_000013.png -------------------------------------------------------------------------------- /img/CLEVR_val_000014.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/CLEVR_val_000014.png -------------------------------------------------------------------------------- /img/CLEVR_val_000015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/CLEVR_val_000015.png -------------------------------------------------------------------------------- /img/CLEVR_val_000016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/CLEVR_val_000016.png -------------------------------------------------------------------------------- /img/CLEVR_val_000017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/CLEVR_val_000017.png -------------------------------------------------------------------------------- /img/CLEVR_val_000018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/CLEVR_val_000018.png -------------------------------------------------------------------------------- /img/CLEVR_val_000019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/CLEVR_val_000019.png -------------------------------------------------------------------------------- /img/best-model-curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/best-model-curves.png -------------------------------------------------------------------------------- /img/stats/Beta Means.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/stats/Beta Means.png -------------------------------------------------------------------------------- /img/stats/Gamma Means.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/stats/Gamma Means.png -------------------------------------------------------------------------------- /img/stats/Betas: Layer 1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/stats/Betas: Layer 1.png -------------------------------------------------------------------------------- /img/stats/Betas: Layer 2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/stats/Betas: Layer 2.png -------------------------------------------------------------------------------- /img/stats/Betas: Layer 3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/stats/Betas: Layer 3.png 
-------------------------------------------------------------------------------- /img/stats/Betas: Layer 4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/stats/Betas: Layer 4.png -------------------------------------------------------------------------------- /img/stats/Gammas: Layer 1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/stats/Gammas: Layer 1.png -------------------------------------------------------------------------------- /img/stats/Gammas: Layer 2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/stats/Gammas: Layer 2.png -------------------------------------------------------------------------------- /img/stats/Gammas: Layer 3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/stats/Gammas: Layer 3.png -------------------------------------------------------------------------------- /img/stats/Gammas: Layer 4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanjperez/film/HEAD/img/stats/Gammas: Layer 4.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | http://download.pytorch.org/whl/cu80/torch-0.1.11.post5-cp35-cp35m-linux_x86_64.whl 2 | numpy 3 | Pillow 4 | scipy 5 | torchvision 6 | h5py 7 | tqdm 8 | -------------------------------------------------------------------------------- /vr/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /vr/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2017-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 9 | from vr.models.module_net import ModuleNet 10 | from vr.models.filmed_net import FiLMedNet 11 | from vr.models.seq2seq import Seq2Seq 12 | from vr.models.film_gen import FiLMGen 13 | from vr.models.baselines import LstmModel, CnnLstmModel, CnnLstmSaModel 14 | -------------------------------------------------------------------------------- /scripts/train/film.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | checkpoint_path="data/film.pt" 4 | log_path="data/film.log" 5 | python scripts/train_model.py \ 6 | --checkpoint_path $checkpoint_path \ 7 | --model_type FiLM \ 8 | --num_iterations 20000000 \ 9 | --print_verbose_every 20000000 \ 10 | --checkpoint_every 11000 \ 11 | --record_loss_every 100 \ 12 | --num_val_samples 149991 \ 13 | --optimizer Adam \ 14 | --learning_rate 3e-4 \ 15 | --batch_size 64 \ 16 | --use_coords 1 \ 17 | --module_stem_batchnorm 1 \ 18 | --module_stem_num_layers 1 \ 19 | --module_batchnorm 1 \ 20 | --classifier_batchnorm 1 \ 21 | --bidirectional 0 \ 22 | --decoder_type linear \ 23 | --encoder_type gru \ 24 | --weight_decay 1e-5 \ 25 | --rnn_num_layers 1 \ 26 | --rnn_wordvec_dim 200 \ 27 | --rnn_hidden_dim 4096 \ 28 | --rnn_output_batchnorm 0 \ 29 | --classifier_downsample maxpoolfull \ 30 | --classifier_proj_dim 512 \ 31 | --classifier_fc_dims 1024 \ 32 | --module_input_proj 1 \ 33 | --module_residual 1 \ 34 | --module_dim 128 \ 35 | --module_dropout 0e-2 \ 36 | --module_stem_kernel_size 3 \ 37 | --module_kernel_size 3 \ 38 | --module_batchnorm_affine 0 \ 39 | --module_num_layers 1 \ 40 | --num_modules 4 \ 41 | --condition_pattern 1,1,1,1 \ 42 | --gamma_option linear \ 43 | --gamma_baseline 1 \ 44 | --use_gamma 1 \ 45 | --use_beta 1 \ 46 | --condition_method bn-film \ 47 | --program_generator_parameter_efficient 1 \ 48 | | tee $log_path 49 | -------------------------------------------------------------------------------- /scripts/preprocess_human.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | unk_threshold=${@} 8 | data_dir="data" 9 | if [ ! 
-d "$data_dir/human_preprocessed" ]; then mkdir "$data_dir/human_preprocessed"; fi 10 | 11 | python scripts/preprocess_questions.py \ 12 | --input_questions_json "$data_dir/CLEVR-Humans/CLEVR-Humans-train.json" \ 13 | --input_vocab_json "$data_dir/vocab.json" \ 14 | --output_h5_file "$data_dir/human_preprocessed/train_human_questions_ut$unk_threshold.h5" \ 15 | --output_vocab_json "$data_dir/human_preprocessed/human_vocab_ut$unk_threshold.json" \ 16 | --expand_vocab 1 \ 17 | --unk_threshold $unk_threshold \ 18 | --encode_unk 1 \ 19 | 20 | python scripts/preprocess_questions.py \ 21 | --input_questions_json "$data_dir/CLEVR-Humans/CLEVR-Humans-val.json" \ 22 | --input_vocab_json "$data_dir/human_preprocessed/human_vocab_ut$unk_threshold.json" \ 23 | --output_h5_file "$data_dir/human_preprocessed/val_human_questions_ut$unk_threshold.h5" \ 24 | --encode_unk 1 25 | 26 | python scripts/preprocess_questions.py \ 27 | --input_questions_json "$data_dir/CLEVR-Humans/CLEVR-Humans-test.json" \ 28 | --input_vocab_json "$data_dir/human_preprocessed/human_vocab_ut$unk_threshold.json" \ 29 | --output_h5_file "$data_dir/human_preprocessed/test_human_questions_ut$unk_threshold.h5" \ 30 | --encode_unk 1 31 | -------------------------------------------------------------------------------- /scripts/train/film_pixels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | checkpoint_path="data/film_pixels.pt" 4 | log_path="data/film_pixels.log" 5 | python scripts/train_model.py \ 6 | --checkpoint_path $checkpoint_path \ 7 | --model_type FiLM \ 8 | --num_iterations 20000000 \ 9 | --print_verbose_every 20000000 \ 10 | --checkpoint_every 11000 \ 11 | --record_loss_every 100 \ 12 | --num_val_samples 149991 \ 13 | --optimizer Adam \ 14 | --learning_rate 3e-4 \ 15 | --batch_size 64 \ 16 | --use_coords 1 \ 17 | --module_batchnorm 1 \ 18 | --classifier_batchnorm 1 \ 19 | --bidirectional 0 \ 20 | --decoder_type linear \ 21 | --encoder_type gru \ 22 | --weight_decay 1e-5 \ 23 | --rnn_num_layers 1 \ 24 | --rnn_wordvec_dim 200 \ 25 | --rnn_hidden_dim 4096 \ 26 | --rnn_output_batchnorm 0 \ 27 | --classifier_downsample maxpoolfull \ 28 | --classifier_proj_dim 512 \ 29 | --classifier_fc_dims 1024 \ 30 | --module_input_proj 1 \ 31 | --module_residual 1 \ 32 | --module_dim 128 \ 33 | --module_dropout 0e-2 \ 34 | --module_kernel_size 3 \ 35 | --module_batchnorm_affine 0 \ 36 | --module_num_layers 1 \ 37 | --num_modules 4 \ 38 | --condition_pattern 1,1,1,1 \ 39 | --gamma_option linear \ 40 | --gamma_baseline 1 \ 41 | --use_gamma 1 \ 42 | --use_beta 1 \ 43 | --train_features_h5 data/train_raw.h5 \ 44 | --val_features_h5 data/val_raw.h5 \ 45 | --feature_dim 3,224,224 \ 46 | --module_stem_batchnorm 1 \ 47 | --module_stem_num_layers 4 \ 48 | --module_stem_kernel_size 4 \ 49 | --module_stem_stride 2 \ 50 | --module_stem_padding 1 \ 51 | | tee $log_path 52 | -------------------------------------------------------------------------------- /CLEVR_eval_with_q_type.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import json 4 | from collections import defaultdict 5 | import numpy as np 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--questions_file', required=True) 10 | parser.add_argument('--answers_file', required=True) 11 | 12 | 13 | def main(args): 14 | # Load true answers from questions file 15 | true_answers = [] 16 | with open(args.questions_file, 'r') 
as f: 17 | questions = json.load(f)['questions'] 18 | for q in questions: 19 | true_answers.append(q['answer']) 20 | 21 | correct_by_q_type = defaultdict(list) 22 | 23 | # Load predicted answers 24 | predicted_answers = [] 25 | with open(args.answers_file, 'r') as f: 26 | for line in f: 27 | predicted_answers.append(line.strip()) 28 | 29 | num_true, num_pred = len(true_answers), len(predicted_answers) 30 | assert num_true == num_pred, 'Expected %d answers but got %d' % ( 31 | num_true, num_pred) 32 | 33 | for i, (true_answer, predicted_answer) in enumerate(zip(true_answers, predicted_answers)): 34 | correct = 1 if true_answer == predicted_answer else 0 35 | correct_by_q_type['Overall'].append(correct) 36 | q_type = questions[i]['program'][-1]['function'] 37 | correct_by_q_type[q_type].append(correct) 38 | 39 | for q_type, vals in sorted(correct_by_q_type.items()): 40 | vals = np.asarray(vals) 41 | print(q_type, '%d / %d = %.2f' % (vals.sum(), vals.shape[0], 100.0 * vals.mean())) 42 | 43 | 44 | if __name__ == '__main__': 45 | args = parser.parse_args() 46 | main(args) -------------------------------------------------------------------------------- /scripts/train/film_cogent.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | checkpoint_path="data/film_cogent.pt" 4 | log_path="data/film_cogent.log" 5 | python scripts/train_model.py \ 6 | --checkpoint_path $checkpoint_path \ 7 | --model_type FiLM \ 8 | --num_iterations 20000000 \ 9 | --print_verbose_every 20000000 \ 10 | --checkpoint_every 11000 \ 11 | --record_loss_every 100 \ 12 | --optimizer Adam \ 13 | --learning_rate 3e-4 \ 14 | --batch_size 64 \ 15 | --use_coords 1 \ 16 | --module_stem_batchnorm 1 \ 17 | --module_stem_num_layers 1 \ 18 | --module_batchnorm 1 \ 19 | --classifier_batchnorm 1 \ 20 | --bidirectional 0 \ 21 | --decoder_type linear \ 22 | --encoder_type gru \ 23 | --weight_decay 1e-5 \ 24 | --rnn_num_layers 1 \ 25 | --rnn_wordvec_dim 200 \ 26 | --rnn_hidden_dim 4096 \ 27 | --rnn_output_batchnorm 0 \ 28 | --classifier_downsample maxpoolfull \ 29 | --classifier_proj_dim 512 \ 30 | --classifier_fc_dims 1024 \ 31 | --module_input_proj 1 \ 32 | --module_residual 1 \ 33 | --module_dim 128 \ 34 | --module_dropout 0e-2 \ 35 | --module_stem_kernel_size 3 \ 36 | --module_kernel_size 3 \ 37 | --module_batchnorm_affine 0 \ 38 | --module_num_layers 1 \ 39 | --num_modules 4 \ 40 | --condition_pattern 1,1,1,1 \ 41 | --gamma_option linear \ 42 | --gamma_baseline 1 \ 43 | --use_gamma 1 \ 44 | --use_beta 1 \ 45 | --condition_method bn-film \ 46 | --program_generator_parameter_efficient 1 \ 47 | --vocab_json data/vocabA.json \ 48 | --train_features_h5 data/trainA_features.h5 \ 49 | --train_question_h5 data/trainA_questions.h5 \ 50 | --val_features_h5 data/valA_features.h5 \ 51 | --val_question_h5 data/valA_questions.h5 \ 52 | | tee $log_path 53 | -------------------------------------------------------------------------------- /vr/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Utilities for dealing with embeddings. 
9 | """ 10 | 11 | 12 | def convert_pretrained_wordvecs(vocab, word2vec): 13 | N = len(vocab['question_idx_to_token']) 14 | D = word2vec['vecs'].size(1) 15 | embed = torch.nn.Embedding(N, D) 16 | print(type(embed.weight)) 17 | word2vec_word_to_idx = {w: i for i, w in enumerate(word2vec['words'])} 18 | print(type(word2vec['vecs'])) 19 | for idx, word in vocab['question_idx_to_token'].items(): 20 | word2vec_idx = word2vec_word_to_idx.get(word, None) 21 | if word2vec_idx is not None: 22 | embed.weight.data[idx] = word2vec['vecs'][word2vec_idx] 23 | return embed 24 | 25 | 26 | def expand_embedding_vocab(embed, token_to_idx, word2vec=None, std=0.01): 27 | old_weight = embed.weight.data 28 | old_N, D = old_weight.size() 29 | new_N = 1 + max(idx for idx in token_to_idx.values()) 30 | new_weight = old_weight.new(new_N, D).normal_().mul_(std) 31 | new_weight[:old_N].copy_(old_weight) 32 | 33 | if word2vec is not None: 34 | num_found = 0 35 | assert D == word2vec['vecs'].size(1), 'Word vector dimension mismatch' 36 | word2vec_token_to_idx = {w: i for i, w in enumerate(word2vec['words'])} 37 | for token, idx in token_to_idx.items(): 38 | word2vec_idx = word2vec_token_to_idx.get(token, None) 39 | if idx >= old_N and word2vec_idx is not None: 40 | vec = word2vec['vecs'][word2vec_idx] 41 | new_weight[idx].copy_(vec) 42 | num_found += 1 43 | embed.num_embeddings = new_N 44 | embed.weight.data = new_weight 45 | return embed 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Data files 2 | data 3 | 4 | # Experiment files 5 | exp 6 | scripts/dev 7 | 8 | # Image files 9 | img/cst 10 | 11 | # Editor files 12 | *.DS_Store 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | env/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # celery beat schedule file 90 | celerybeat-schedule 91 | 92 | # SageMath parsed files 93 | *.sage.py 94 | 95 | # dotenv 96 | .env 97 | 98 | # virtualenv 99 | .venv 100 | venv/ 101 | ENV/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | -------------------------------------------------------------------------------- /scripts/train/film_humans.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | unk_threshold=5 3 | checkpoint_path="data/film_humans.pt" 4 | log_path="data/film_humans.log" 5 | python scripts/train_model.py \ 6 | --checkpoint_path $checkpoint_path \ 7 | --model_type FiLM \ 8 | --num_iterations 20000000 \ 9 | --print_verbose_every 20000000 \ 10 | --checkpoint_every 278 \ 11 | --record_loss_every 278 \ 12 | --num_val_samples 149991 \ 13 | --batch_size 64 \ 14 | --use_coords 1 \ 15 | --module_stem_batchnorm 1 \ 16 | --module_stem_num_layers 1 \ 17 | --module_batchnorm 1 \ 18 | --classifier_batchnorm 1 \ 19 | --bidirectional 0 \ 20 | --decoder_type linear \ 21 | --encoder_type gru \ 22 | --rnn_num_layers 1 \ 23 | --rnn_wordvec_dim 200 \ 24 | --rnn_hidden_dim 4096 \ 25 | --rnn_output_batchnorm 0 \ 26 | --classifier_downsample maxpoolfull \ 27 | --classifier_proj_dim 512 \ 28 | --classifier_fc_dims 1024 \ 29 | --module_input_proj 1 \ 30 | --module_residual 1 \ 31 | --module_dim 128 \ 32 | --module_dropout 0e-2 \ 33 | --module_stem_kernel_size 3 \ 34 | --module_kernel_size 3 \ 35 | --module_batchnorm_affine 0 \ 36 | --module_num_layers 1 \ 37 | --num_modules 4 \ 38 | --condition_pattern 1,1,1,1 \ 39 | --gamma_option linear \ 40 | --gamma_baseline 1 \ 41 | --use_gamma 1 \ 42 | --use_beta 1 \ 43 | --program_generator_start_from models/best.pt \ 44 | --execution_engine_start_from models/best.pt \ 45 | --optimizer Adam \ 46 | --learning_rate 3e-4 \ 47 | --weight_decay 1e-5 \ 48 | --train_program_generator 1 \ 49 | --train_execution_engine 0 \ 50 | --set_execution_engine_eval 0 \ 51 | --train_question_h5 "data/human_preprocessed/train_human_questions_ut$unk_threshold.h5" \ 52 | --val_question_h5 "data/human_preprocessed/val_human_questions_ut$unk_threshold.h5" \ 53 | --vocab_json "data/human_preprocessed/human_vocab_ut$unk_threshold.json" \ 54 | | tee $log_path 55 | -------------------------------------------------------------------------------- /vr/preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2017-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 9 | """ 10 | Utilities for preprocessing sequence data. 11 | 12 | Special tokens that are in all dictionaries: 13 | 14 | <NULL>: Extra parts of the sequence that we should ignore 15 | <START>: Goes at the start of a sequence 16 | <END>: Goes at the end of a sequence, before <NULL> tokens 17 | <UNK>: Out-of-vocabulary words 18 | """ 19 | 20 | SPECIAL_TOKENS = { 21 | '<NULL>': 0, 22 | '<START>': 1, 23 | '<END>': 2, 24 | '<UNK>': 3, 25 | } 26 | 27 | 28 | def tokenize(s, delim=' ', 29 | add_start_token=True, add_end_token=True, 30 | punct_to_keep=None, punct_to_remove=None): 31 | """ 32 | Tokenize a sequence, converting a string s into a list of (string) tokens by 33 | splitting on the specified delimiter. Optionally keep or remove certain 34 | punctuation marks and add start and end tokens. 35 | """ 36 | if punct_to_keep is not None: 37 | for p in punct_to_keep: 38 | s = s.replace(p, '%s%s' % (delim, p)) 39 | 40 | if punct_to_remove is not None: 41 | for p in punct_to_remove: 42 | s = s.replace(p, '') 43 | 44 | tokens = s.split(delim) 45 | if add_start_token: 46 | tokens.insert(0, '<START>') 47 | if add_end_token: 48 | tokens.append('<END>') 49 | return tokens 50 | 51 | 52 | def build_vocab(sequences, min_token_count=1, delim=' ', 53 | punct_to_keep=None, punct_to_remove=None): 54 | token_to_count = {} 55 | tokenize_kwargs = { 56 | 'delim': delim, 57 | 'punct_to_keep': punct_to_keep, 58 | 'punct_to_remove': punct_to_remove, 59 | } 60 | for seq in sequences: 61 | seq_tokens = tokenize(seq, delim=delim, punct_to_keep=punct_to_keep, 62 | punct_to_remove=punct_to_remove, 63 | add_start_token=False, add_end_token=False) 64 | for token in seq_tokens: 65 | if token not in token_to_count: 66 | token_to_count[token] = 0 67 | token_to_count[token] += 1 68 | 69 | token_to_idx = {} 70 | for token, idx in SPECIAL_TOKENS.items(): 71 | token_to_idx[token] = idx 72 | for token, count in sorted(token_to_count.items()): 73 | if count >= min_token_count: 74 | token_to_idx[token] = len(token_to_idx) 75 | 76 | return token_to_idx 77 | 78 | 79 | def encode(seq_tokens, token_to_idx, allow_unk=False): 80 | seq_idx = [] 81 | for token in seq_tokens: 82 | if token not in token_to_idx: 83 | if allow_unk: 84 | token = '<UNK>' 85 | else: 86 | raise KeyError('Token "%s" not in vocab' % token) 87 | seq_idx.append(token_to_idx[token]) 88 | return seq_idx 89 | 90 | 91 | def decode(seq_idx, idx_to_token, delim=None, stop_at_end=True): 92 | tokens = [] 93 | for idx in seq_idx: 94 | tokens.append(idx_to_token[idx]) 95 | if stop_at_end and tokens[-1] == '<END>': 96 | break 97 | if delim is None: 98 | return tokens 99 | else: 100 | return delim.join(tokens) 101 | -------------------------------------------------------------------------------- /scripts/extract_features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | import argparse, os, json 8 | import h5py 9 | import numpy as np 10 | from scipy.misc import imread, imresize 11 | from tqdm import tqdm 12 | 13 | import torch 14 | import torchvision 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--input_image_dir', required=True) 19 | parser.add_argument('--max_images', default=None, type=int) 20 | parser.add_argument('--output_h5_file', required=True) 21 | 22 | parser.add_argument('--image_height', default=224, type=int) 23 | parser.add_argument('--image_width', default=224, type=int) 24 | 25 | parser.add_argument('--model', default='resnet101') 26 | parser.add_argument('--model_stage', default=3, type=int) 27 | parser.add_argument('--batch_size', default=128, type=int) 28 | 29 | 30 | def build_model(args): 31 | if args.model.lower() == 'none': 32 | return None 33 | if not hasattr(torchvision.models, args.model): 34 | raise ValueError('Invalid model "%s"' % args.model) 35 | if not 'resnet' in args.model: 36 | raise ValueError('Feature extraction only supports ResNets') 37 | cnn = getattr(torchvision.models, args.model)(pretrained=True) 38 | layers = [ 39 | cnn.conv1, 40 | cnn.bn1, 41 | cnn.relu, 42 | cnn.maxpool, 43 | ] 44 | for i in range(args.model_stage): 45 | name = 'layer%d' % (i + 1) 46 | layers.append(getattr(cnn, name)) 47 | model = torch.nn.Sequential(*layers) 48 | model.cuda() 49 | model.eval() 50 | return model 51 | 52 | 53 | def run_batch(cur_batch, model): 54 | if model is None: 55 | image_batch = np.concatenate(cur_batch, 0).astype(np.float32) 56 | return image_batch / 255. # Scale pixel values to [0, 1] 57 | 58 | mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1) 59 | std = np.array([0.229, 0.224, 0.224]).reshape(1, 3, 1, 1) 60 | 61 | image_batch = np.concatenate(cur_batch, 0).astype(np.float32) 62 | image_batch = (image_batch / 255.0 - mean) / std 63 | image_batch = torch.FloatTensor(image_batch).cuda() 64 | image_batch = torch.autograd.Variable(image_batch, volatile=True) 65 | 66 | feats = model(image_batch) 67 | feats = feats.data.cpu().clone().numpy() 68 | 69 | return feats 70 | 71 | 72 | def main(args): 73 | input_paths = [] 74 | idx_set = set() 75 | for fn in os.listdir(args.input_image_dir): 76 | if not fn.endswith('.png'): continue 77 | idx = int(os.path.splitext(fn)[0].split('_')[-1]) 78 | input_paths.append((os.path.join(args.input_image_dir, fn), idx)) 79 | idx_set.add(idx) 80 | input_paths.sort(key=lambda x: x[1]) 81 | assert len(idx_set) == len(input_paths) 82 | assert min(idx_set) == 0 and max(idx_set) == len(idx_set) - 1 83 | if args.max_images is not None: 84 | input_paths = input_paths[:args.max_images] 85 | print(input_paths[0]) 86 | print(input_paths[-1]) 87 | 88 | model = build_model(args) 89 | 90 | img_size = (args.image_height, args.image_width) 91 | with h5py.File(args.output_h5_file, 'w') as f: 92 | feat_dset = None 93 | i0 = 0 94 | cur_batch = [] 95 | for i, (path, idx) in tqdm(enumerate(input_paths)): 96 | img = imread(path, mode='RGB') 97 | img = imresize(img, img_size, interp='bicubic') 98 | img = img.transpose(2, 0, 1)[None] 99 | cur_batch.append(img) 100 | if len(cur_batch) == args.batch_size: 101 | feats = run_batch(cur_batch, model) 102 | if feat_dset is None: 103 | N = len(input_paths) 104 | _, C, H, W = feats.shape 105 | feat_dset = f.create_dataset('features', (N, C, H, W), 106 | dtype=np.float32) 107 | i1 = i0 + len(cur_batch) 108 | feat_dset[i0:i1] = feats 109 | i0 = i1 110 | cur_batch = [] 111 | if len(cur_batch) > 0: 112 | feats = run_batch(cur_batch, 
model) 113 | i1 = i0 + len(cur_batch) 114 | feat_dset[i0:i1] = feats 115 | return 116 | 117 | 118 | if __name__ == '__main__': 119 | args = parser.parse_args() 120 | main(args) 121 | -------------------------------------------------------------------------------- /vr/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2017-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import inspect 10 | import ipdb as pdb 11 | import json 12 | import torch 13 | 14 | from vr.models import ModuleNet, Seq2Seq, LstmModel, CnnLstmModel, CnnLstmSaModel 15 | from vr.models import FiLMedNet 16 | from vr.models import FiLMGen 17 | 18 | def invert_dict(d): 19 | return {v: k for k, v in d.items()} 20 | 21 | 22 | def load_vocab(path): 23 | with open(path, 'r') as f: 24 | vocab = json.load(f) 25 | vocab['question_idx_to_token'] = invert_dict(vocab['question_token_to_idx']) 26 | vocab['program_idx_to_token'] = invert_dict(vocab['program_token_to_idx']) 27 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx']) 28 | # Sanity check: make sure <NULL>, <START>, and <END> are consistent 29 | assert vocab['question_token_to_idx']['<NULL>'] == 0 30 | assert vocab['question_token_to_idx']['<START>'] == 1 31 | assert vocab['question_token_to_idx']['<END>'] == 2 32 | assert vocab['program_token_to_idx']['<NULL>'] == 0 33 | assert vocab['program_token_to_idx']['<START>'] == 1 34 | assert vocab['program_token_to_idx']['<END>'] == 2 35 | return vocab 36 | 37 | 38 | def load_cpu(path): 39 | """ 40 | Loads a torch checkpoint, remapping all Tensors to CPU 41 | """ 42 | return torch.load(path, map_location=lambda storage, loc: storage) 43 | 44 | 45 | def load_program_generator(path, model_type='PG+EE'): 46 | checkpoint = load_cpu(path) 47 | kwargs = checkpoint['program_generator_kwargs'] 48 | state = checkpoint['program_generator_state'] 49 | if model_type == 'FiLM': 50 | print('Loading FiLMGen from ' + path) 51 | kwargs = get_updated_args(kwargs, FiLMGen) 52 | model = FiLMGen(**kwargs) 53 | else: 54 | print('Loading PG from ' + path) 55 | model = Seq2Seq(**kwargs) 56 | model.load_state_dict(state) 57 | return model, kwargs 58 | 59 | 60 | def load_execution_engine(path, verbose=True, model_type='PG+EE'): 61 | checkpoint = load_cpu(path) 62 | kwargs = checkpoint['execution_engine_kwargs'] 63 | state = checkpoint['execution_engine_state'] 64 | kwargs['verbose'] = verbose 65 | if model_type == 'FiLM': 66 | print('Loading FiLMedNet from ' + path) 67 | kwargs = get_updated_args(kwargs, FiLMedNet) 68 | model = FiLMedNet(**kwargs) 69 | else: 70 | print('Loading EE from ' + path) 71 | model = ModuleNet(**kwargs) 72 | cur_state = model.state_dict() 73 | model.load_state_dict(state) 74 | return model, kwargs 75 | 76 | 77 | def load_baseline(path): 78 | model_cls_dict = { 79 | 'LSTM': LstmModel, 80 | 'CNN+LSTM': CnnLstmModel, 81 | 'CNN+LSTM+SA': CnnLstmSaModel, 82 | } 83 | checkpoint = load_cpu(path) 84 | baseline_type = checkpoint['baseline_type'] 85 | kwargs = checkpoint['baseline_kwargs'] 86 | state = checkpoint['baseline_state'] 87 | 88 | model = model_cls_dict[baseline_type](**kwargs) 89 | model.load_state_dict(state) 90 | return model, kwargs 91 | 92 | 93 | def get_updated_args(kwargs, object_class): 94 | """ 95 | Returns kwargs with args or arg values renamed, and with deprecated, unused args deleted. 96 | Useful for loading older, trained models.
97 | If using this function is neccessary, use immediately before initializing object. 98 | """ 99 | # Update arg values 100 | for arg in arg_value_updates: 101 | if arg in kwargs and kwargs[arg] in arg_value_updates[arg]: 102 | kwargs[arg] = arg_value_updates[arg][kwargs[arg]] 103 | 104 | # Delete deprecated, unused args 105 | valid_args = inspect.getargspec(object_class.__init__)[0] 106 | new_kwargs = {valid_arg: kwargs[valid_arg] for valid_arg in valid_args if valid_arg in kwargs} 107 | return new_kwargs 108 | 109 | 110 | arg_value_updates = { 111 | 'condition_method': { 112 | 'block-input-fac': 'block-input-film', 113 | 'block-output-fac': 'block-output-film', 114 | 'cbn': 'bn-film', 115 | 'conv-fac': 'conv-film', 116 | 'relu-fac': 'relu-film', 117 | }, 118 | 'module_input_proj': { 119 | True: 1, 120 | }, 121 | } 122 | -------------------------------------------------------------------------------- /vr/programs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2017-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | """ 10 | Utilities for working with and converting between the various data structures 11 | used to represent programs. 12 | """ 13 | 14 | 15 | def is_chain(program_list): 16 | visited = [False for fn in program_list] 17 | cur_idx = len(program_list) - 1 18 | while True: 19 | visited[cur_idx] = True 20 | inputs = program_list[cur_idx]['inputs'] 21 | if len(inputs) == 0: 22 | break 23 | elif len(inputs) == 1: 24 | cur_idx = inputs[0] 25 | elif len(inputs) > 1: 26 | return False 27 | return all(visited) 28 | 29 | 30 | def list_to_tree(program_list): 31 | def build_subtree(cur): 32 | return { 33 | 'function': cur['function'], 34 | 'value_inputs': [x for x in cur['value_inputs']], 35 | 'inputs': [build_subtree(program_list[i]) for i in cur['inputs']], 36 | } 37 | return build_subtree(program_list[-1]) 38 | 39 | 40 | def tree_to_prefix(program_tree): 41 | output = [] 42 | def helper(cur): 43 | output.append({ 44 | 'function': cur['function'], 45 | 'value_inputs': [x for x in cur['value_inputs']], 46 | }) 47 | for node in cur['inputs']: 48 | helper(node) 49 | helper(program_tree) 50 | return output 51 | 52 | 53 | def list_to_prefix(program_list): 54 | return tree_to_prefix(list_to_tree(program_list)) 55 | 56 | 57 | def tree_to_postfix(program_tree): 58 | output = [] 59 | def helper(cur): 60 | for node in cur['inputs']: 61 | helper(node) 62 | output.append({ 63 | 'function': cur['function'], 64 | 'value_inputs': [x for x in cur['value_inputs']], 65 | }) 66 | helper(program_tree) 67 | return output 68 | 69 | 70 | def tree_to_list(program_tree): 71 | # First count nodes 72 | def count_nodes(cur): 73 | return 1 + sum(count_nodes(x) for x in cur['inputs']) 74 | num_nodes = count_nodes(program_tree) 75 | output = [None] * num_nodes 76 | def helper(cur, idx): 77 | output[idx] = { 78 | 'function': cur['function'], 79 | 'value_inputs': [x for x in cur['value_inputs']], 80 | 'inputs': [], 81 | } 82 | next_idx = idx - 1 83 | for node in reversed(cur['inputs']): 84 | output[idx]['inputs'].insert(0, next_idx) 85 | next_idx = helper(node, next_idx) 86 | return next_idx 87 | helper(program_tree, num_nodes - 1) 88 | return output 89 | 90 | 91 | def prefix_to_tree(program_prefix): 92 | program_prefix = [x for x in program_prefix] 93 | def helper(): 94 | cur = 
program_prefix.pop(0) 95 | return { 96 | 'function': cur['function'], 97 | 'value_inputs': [x for x in cur['value_inputs']], 98 | 'inputs': [helper() for _ in range(get_num_inputs(cur))], 99 | } 100 | return helper() 101 | 102 | 103 | def prefix_to_list(program_prefix): 104 | return tree_to_list(prefix_to_tree(program_prefix)) 105 | 106 | 107 | def list_to_postfix(program_list): 108 | return tree_to_postfix(list_to_tree(program_list)) 109 | 110 | 111 | def postfix_to_tree(program_postfix): 112 | program_postfix = [x for x in program_postfix] 113 | def helper(): 114 | cur = program_postfix.pop() 115 | return { 116 | 'function': cur['function'], 117 | 'value_inputs': [x for x in cur['value_inputs']], 118 | 'inputs': [helper() for _ in range(get_num_inputs(cur))][::-1], 119 | } 120 | return helper() 121 | 122 | 123 | def postfix_to_list(program_postfix): 124 | return tree_to_list(postfix_to_tree(program_postfix)) 125 | 126 | 127 | def function_to_str(f): 128 | value_str = '' 129 | if f['value_inputs']: 130 | value_str = '[%s]' % ','.join(f['value_inputs']) 131 | return '%s%s' % (f['function'], value_str) 132 | 133 | 134 | def str_to_function(s): 135 | if '[' not in s: 136 | return { 137 | 'function': s, 138 | 'value_inputs': [], 139 | } 140 | name, value_str = s.replace(']', '').split('[') 141 | return { 142 | 'function': name, 143 | 'value_inputs': value_str.split(','), 144 | } 145 | 146 | 147 | def list_to_str(program_list): 148 | return ' '.join(function_to_str(f) for f in program_list) 149 | 150 | 151 | def get_num_inputs(f): 152 | # This is a litle hacky; it would be better to look up from metadata.json 153 | if type(f) is str: 154 | f = str_to_function(f) 155 | name = f['function'] 156 | if name == 'scene': 157 | return 0 158 | if 'equal' in name or name in ['union', 'intersect', 'less_than', 'greater_than']: 159 | return 2 160 | return 1 161 | -------------------------------------------------------------------------------- /vr/models/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2017-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 9 | import math 10 | import ipdb as pdb 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.nn.init import kaiming_normal, kaiming_uniform 15 | 16 | 17 | class ResidualBlock(nn.Module): 18 | def __init__(self, in_dim, out_dim=None, with_residual=True, with_batchnorm=True): 19 | if out_dim is None: 20 | out_dim = in_dim 21 | super(ResidualBlock, self).__init__() 22 | self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1) 23 | self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1) 24 | self.with_batchnorm = with_batchnorm 25 | if with_batchnorm: 26 | self.bn1 = nn.BatchNorm2d(out_dim) 27 | self.bn2 = nn.BatchNorm2d(out_dim) 28 | self.with_residual = with_residual 29 | if in_dim == out_dim or not with_residual: 30 | self.proj = None 31 | else: 32 | self.proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) 33 | 34 | def forward(self, x): 35 | if self.with_batchnorm: 36 | out = F.relu(self.bn1(self.conv1(x))) 37 | out = self.bn2(self.conv2(out)) 38 | else: 39 | out = self.conv2(F.relu(self.conv1(x))) 40 | res = x if self.proj is None else self.proj(x) 41 | if self.with_residual: 42 | out = F.relu(res + out) 43 | else: 44 | out = F.relu(out) 45 | return out 46 | 47 | 48 | class ConcatBlock(nn.Module): 49 | def __init__(self, dim, with_residual=True, with_batchnorm=True): 50 | super(ConcatBlock, self).__init__() 51 | self.proj = nn.Conv2d(2 * dim, dim, kernel_size=1, padding=0) 52 | self.res_block = ResidualBlock(dim, with_residual=with_residual, 53 | with_batchnorm=with_batchnorm) 54 | 55 | def forward(self, x, y): 56 | out = torch.cat([x, y], 1) # Concatentate along depth 57 | out = F.relu(self.proj(out)) 58 | out = self.res_block(out) 59 | return out 60 | 61 | 62 | class GlobalAveragePool(nn.Module): 63 | def forward(self, x): 64 | N, C = x.size(0), x.size(1) 65 | return x.view(N, C, -1).mean(2).squeeze(2) 66 | 67 | 68 | class Flatten(nn.Module): 69 | def forward(self, x): 70 | return x.view(x.size(0), -1) 71 | 72 | 73 | def build_stem(feature_dim, module_dim, num_layers=2, with_batchnorm=True, 74 | kernel_size=3, stride=1, padding=None): 75 | layers = [] 76 | prev_dim = feature_dim 77 | if padding is None: # Calculate default padding when None provided 78 | if kernel_size % 2 == 0: 79 | raise(NotImplementedError) 80 | padding = kernel_size // 2 81 | for i in range(num_layers): 82 | layers.append(nn.Conv2d(prev_dim, module_dim, kernel_size=kernel_size, stride=stride, 83 | padding=padding)) 84 | if with_batchnorm: 85 | layers.append(nn.BatchNorm2d(module_dim)) 86 | layers.append(nn.ReLU(inplace=True)) 87 | prev_dim = module_dim 88 | return nn.Sequential(*layers) 89 | 90 | 91 | def build_classifier(module_C, module_H, module_W, num_answers, 92 | fc_dims=[], proj_dim=None, downsample='maxpool2', 93 | with_batchnorm=True, dropout=0): 94 | layers = [] 95 | prev_dim = module_C * module_H * module_W 96 | if proj_dim is not None and proj_dim > 0: 97 | layers.append(nn.Conv2d(module_C, proj_dim, kernel_size=1)) 98 | if with_batchnorm: 99 | layers.append(nn.BatchNorm2d(proj_dim)) 100 | layers.append(nn.ReLU(inplace=True)) 101 | prev_dim = proj_dim * module_H * module_W 102 | if 'maxpool' in downsample or 'avgpool' in downsample: 103 | pool = nn.MaxPool2d if 'maxpool' in downsample else nn.AvgPool2d 104 | if 'full' in downsample: 105 | if module_H != module_W: 106 | assert(NotImplementedError) 107 | pool_size = module_H 108 | else: 109 | pool_size = int(downsample[-1]) 110 | # Note: Potentially sub-optimal padding for non-perfectly 
aligned pooling 111 | padding = 0 if ((module_H % pool_size == 0) and (module_W % pool_size == 0)) else 1 112 | layers.append(pool(kernel_size=pool_size, stride=pool_size, padding=padding)) 113 | prev_dim = proj_dim * math.ceil(module_H / pool_size) * math.ceil(module_W / pool_size) 114 | if downsample == 'aggressive': 115 | layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) 116 | layers.append(nn.AvgPool2d(kernel_size=module_H // 2, stride=module_W // 2)) 117 | prev_dim = proj_dim 118 | fc_dims = [] # No FC layers here 119 | layers.append(Flatten()) 120 | for next_dim in fc_dims: 121 | layers.append(nn.Linear(prev_dim, next_dim)) 122 | if with_batchnorm: 123 | layers.append(nn.BatchNorm1d(next_dim)) 124 | layers.append(nn.ReLU(inplace=True)) 125 | if dropout > 0: 126 | layers.append(nn.Dropout(p=dropout)) 127 | prev_dim = next_dim 128 | layers.append(nn.Linear(prev_dim, num_answers)) 129 | return nn.Sequential(*layers) 130 | 131 | 132 | def init_modules(modules, init='uniform'): 133 | if init.lower() == 'normal': 134 | init_params = kaiming_normal 135 | elif init.lower() == 'uniform': 136 | init_params = kaiming_uniform 137 | else: 138 | return 139 | for m in modules: 140 | if isinstance(m, (nn.Conv2d, nn.Linear)): 141 | init_params(m.weight) 142 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FiLM: Visual Reasoning with a General Conditioning Layer 2 | 3 | ## Ethan Perez, Florian Strub, Harm de Vries, Vincent Dumoulin, Aaron Courville 4 | 5 | This code implements a Feature-wise Linear Modulation approach to Visual Reasoning - answering multi-step questions on images. This codebase reproduces results from the AAAI 2018 paper "FiLM: Visual Reasoning with a General Conditioning Layer" (citation [here](https://github.com/ethanjperez/film#film)), which extends prior work "Learning Visual Reasoning Without Strong Priors" presented at ICML's MLSLP workshop. Please see the [retrospective paper](https://ml-retrospectives.github.io/neurips2019/accepted_retrospectives/2019/film/) (citation [here](https://github.com/ethanjperez/film#retrospective-for-film)) for an honest reflection on FiLM after the work that followed, including when (and when not) to use FiLM and tips and tricks for effectively training a network with FiLM layers. 6 | 7 | ### Code Outline 8 | 9 | This code is a fork of the code for "Inferring and Executing Programs for Visual Reasoning" available [here](https://github.com/facebookresearch/clevr-iep). 10 | 11 | Our FiLM Generator is located in [vr/models/film_gen.py](https://github.com/ethanjperez/film/blob/master/vr/models/film_gen.py), and our FiLMed Network and FiLM layer implementation is located in [vr/models/filmed_net.py](https://github.com/ethanjperez/film/blob/master/vr/models/filmed_net.py). 12 | 13 | We inserted a new model mode, "FiLM", which integrates into the forked code for [CLEVR baselines](https://arxiv.org/abs/1612.06890) and the [Program Generator + Execution Engine model](https://arxiv.org/abs/1705.03633). Throughout the code, our FiLM Generator acts in place of the "program generator": it generates the FiLM parameters consumed by the FiLMed Network, i.e. the "execution engine." In this sense, FiLM parameters can loosely be thought of as a "soft program", but we keep the original naming in the code to integrate better with the forked models. A minimal sketch of the FiLM operation itself is shown below.
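To make the layer concrete before moving on to setup: below is a minimal, self-contained sketch of the FiLM operation described above. It is written against current PyTorch APIs rather than the torch 0.1.11 wheel pinned in requirements.txt, and the class names, dimensions, and toy block structure are illustrative only, not the repo's actual modules (those live in the two files linked above). Each FiLM layer scales and shifts every feature map of its input with per-channel coefficients (gamma, beta) predicted from the question encoding:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FiLM(nn.Module):
    """Feature-wise Linear Modulation: y = gamma * x + beta, per feature map."""
    def forward(self, x, gammas, betas):
        # x: (N, C, H, W); gammas, betas: (N, C), predicted from the question
        gammas = gammas.unsqueeze(2).unsqueeze(3).expand_as(x)
        betas = betas.unsqueeze(2).unsqueeze(3).expand_as(x)
        return gammas * x + betas


class ToyFiLMedResBlock(nn.Module):
    """One illustrative FiLMed residual block (names and layout are a sketch)."""
    def __init__(self, dim):
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        # affine=False: the usual learned BN scale/shift is replaced by FiLM's gamma/beta
        self.bn = nn.BatchNorm2d(dim, affine=False)
        self.film = FiLM()

    def forward(self, x, gammas, betas):
        out = F.relu(self.conv1(x))
        out = self.bn(self.conv2(out))
        out = F.relu(self.film(out, gammas, betas))
        return x + out  # residual connection around the modulated features


# Shape check with random stand-ins for question-conditioned parameters:
# feats = torch.randn(2, 128, 14, 14)
# gammas, betas = torch.randn(2, 128), torch.randn(2, 128)
# out = ToyFiLMedResBlock(128)(feats, gammas, betas)  # -> (2, 128, 14, 14)
```

In the full model, the FiLM Generator (a GRU over the question; see `--encoder_type gru` in the training scripts) predicts one (gamma, beta) pair per FiLM layer, and, as described in the paper, the `--gamma_baseline 1` flag in the training scripts offsets the predicted gammas so modulation happens around 1 rather than 0.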
14 | 15 | ### Setup and Training 16 | 17 | Because of this integration, setup instructions for the FiLM model are nearly the same as for "Inferring and Executing Programs for Visual Reasoning." We will post more detailed, step-by-step instructions for using our code soon. For now, the guidelines below should give substantial direction to those interested. 18 | 19 | First, follow the virtual environment setup [instructions](https://github.com/facebookresearch/clevr-iep#setup). 20 | 21 | Second, follow the CLEVR data preprocessing [instructions](https://github.com/facebookresearch/clevr-iep/blob/master/TRAINING.md#preprocessing-clevr). 22 | 23 | Lastly, model training details are similar at a high level (though adapted for FiLM and our repo) to [these](https://github.com/facebookresearch/clevr-iep/blob/master/TRAINING.md#training-on-clevr) for the Program Generator + Execution Engine model, though our model uses a single stage of training rather than a 3-step training procedure. 24 | 25 | The script below has the hyperparameters and settings to reproduce FiLM CLEVR results: 26 | ```bash 27 | sh scripts/train/film.sh 28 | ``` 29 | 30 | 31 | For CLEVR-Humans, data preprocessing instructions are [here](https://github.com/facebookresearch/clevr-iep/blob/master/TRAINING.md#preprocessing-clevr-humans). 32 | The script below has the hyperparameters and settings to reproduce FiLM CLEVR-Humans results: 33 | ```bash 34 | sh scripts/train/film_humans.sh 35 | ``` 36 | 37 | 38 | Training a CLEVR-CoGenT model is very similar to training a normal CLEVR model. Training a model from raw pixels requires modifying the preprocessing step to use the pixel-preprocessing scripts included in the repo. The scripts to reproduce our results are also located in the scripts/train/ folder. 39 | 40 | We tried not to break the existing models from the CLEVR codebase with our modifications, but we haven't tested their code after our changes. For those models, we recommend using the CLEVR and "Inferring and Executing Programs for Visual Reasoning" code directly. 41 | 42 | Training a solid FiLM CLEVR model should only take ~12 hours on a good GPU (see the training curves in the paper appendix). 43 | 44 | ### Running models 45 | 46 | We added an interactive command line tool, invoked with the command below. It's actually super enjoyable to play around with trained models, and it's great for gaining intuition about what various trained models have or have not learned and how they tackle reasoning questions. 47 | ```bash 48 | python run_model.py --program_generator <FiLM Generator checkpoint> --execution_engine <FiLMed Network checkpoint> 49 | ``` 50 | 51 | By default, the command runs on [this CLEVR image](https://github.com/ethanjperez/film/blob/master/img/CLEVR_val_000017.png) in our repo, but you may choose a different image via command line flag to test on any CLEVR image. 52 | 53 | The CLEVR vocab is enforced by default, but for CLEVR-Humans models, for example, you may append the command line flag '--enforce_clevr_vocab 0' to ask any string of characters you please. 54 | 55 | In addition, an easy way to experiment with zero-shot FiLM manipulation is to run a trained model with run_model.py with the debug command line flag on, so you can manipulate the FiLM parameters that modulate the FiLMed network during the forward computation. For example, '--debug_every -1' will stop the program after the model generates FiLM parameters but before the FiLMed network carries out its forward pass using FiLM layers. An example invocation combining these flags is sketched below.
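For concreteness, a hypothetical session might look like the following. It assumes a model trained with scripts/train/film.sh, which writes its checkpoint to data/film.pt (the training scripts above save the FiLM Generator and FiLMed Network into the same checkpoint file, so the same path is passed to both flags here); substitute your own checkpoint path, and note that only the flags described in this section are shown:

```bash
# Hypothetical example -- "data/film.pt" is a placeholder for your trained checkpoint.
python run_model.py \
  --program_generator data/film.pt \
  --execution_engine data/film.pt \
  --enforce_clevr_vocab 0 \
  --debug_every -1
```

With '--enforce_clevr_vocab 0' you can type free-form (e.g. CLEVR-Humans-style) questions, and '--debug_every -1' stops execution right after the FiLM parameters are generated, so you can edit the gammas and betas before the FiLMed network runs its forward pass.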
56 | 57 | Thanks for stopping by, and we hope you enjoy playing around with FiLM! 58 | 59 | ### Bibtex 60 | 61 | #### FiLM 62 | ```bibtex 63 | @InProceedings{perez2018film, 64 | title={FiLM: Visual Reasoning with a General Conditioning Layer}, 65 | author={Ethan Perez and Florian Strub and Harm de Vries and Vincent Dumoulin and Aaron C. Courville}, 66 | booktitle={AAAI}, 67 | year={2018} 68 | } 69 | ``` 70 | 71 | #### Retrospective for FiLM 72 | ```bibtex 73 | @misc{perez2019retrospective, 74 | author = {Perez, Ethan}, 75 | title = {{Retrospective for: "FiLM: Visual Reasoning with a General Conditioning Layer"}}, 76 | year = {2019}, 77 | howpublished = {\url{https://ml-retrospectives.github.io/published_retrospectives/2019/film/}}, 78 | } 79 | ``` 80 | -------------------------------------------------------------------------------- /scripts/preprocess_questions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2017-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import sys 10 | import os 11 | sys.path.insert(0, os.path.abspath('.')) 12 | 13 | import argparse 14 | 15 | import json 16 | import os 17 | 18 | import h5py 19 | import numpy as np 20 | 21 | import vr.programs 22 | from vr.preprocess import tokenize, encode, build_vocab 23 | 24 | 25 | """ 26 | Preprocessing script for CLEVR question files. 27 | """ 28 | 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--mode', default='prefix', 32 | choices=['chain', 'prefix', 'postfix']) 33 | parser.add_argument('--input_questions_json', required=True) 34 | parser.add_argument('--input_vocab_json', default='') 35 | parser.add_argument('--expand_vocab', default=0, type=int) 36 | parser.add_argument('--unk_threshold', default=1, type=int) 37 | parser.add_argument('--encode_unk', default=0, type=int) 38 | 39 | parser.add_argument('--output_h5_file', required=True) 40 | parser.add_argument('--output_vocab_json', default='') 41 | 42 | 43 | def program_to_str(program, mode): 44 | if mode == 'chain': 45 | if not vr.programs.is_chain(program): 46 | return None 47 | return vr.programs.list_to_str(program) 48 | elif mode == 'prefix': 49 | program_prefix = vr.programs.list_to_prefix(program) 50 | return vr.programs.list_to_str(program_prefix) 51 | elif mode == 'postfix': 52 | program_postfix = vr.programs.list_to_postfix(program) 53 | return vr.programs.list_to_str(program_postfix) 54 | return None 55 | 56 | 57 | def main(args): 58 | if (args.input_vocab_json == '') and (args.output_vocab_json == ''): 59 | print('Must give one of --input_vocab_json or --output_vocab_json') 60 | return 61 | 62 | print('Loading data') 63 | with open(args.input_questions_json, 'r') as f: 64 | questions = json.load(f)['questions'] 65 | 66 | # Either create the vocab or load it from disk 67 | if args.input_vocab_json == '' or args.expand_vocab == 1: 68 | print('Building vocab') 69 | if 'answer' in questions[0]: 70 | answer_token_to_idx = build_vocab( 71 | (q['answer'] for q in questions) 72 | ) 73 | question_token_to_idx = build_vocab( 74 | (q['question'] for q in questions), 75 | min_token_count=args.unk_threshold, 76 | punct_to_keep=[';', ','], punct_to_remove=['?', '.'] 77 | ) 78 | all_program_strs = [] 79 | for q in questions: 80 | if 'program' not in q: continue 81 | program_str = program_to_str(q['program'], args.mode) 82 | if
program_str is not None: 83 | all_program_strs.append(program_str) 84 | program_token_to_idx = build_vocab(all_program_strs) 85 | vocab = { 86 | 'question_token_to_idx': question_token_to_idx, 87 | 'program_token_to_idx': program_token_to_idx, 88 | 'answer_token_to_idx': answer_token_to_idx, 89 | } 90 | 91 | if args.input_vocab_json != '': 92 | print('Loading vocab') 93 | if args.expand_vocab == 1: 94 | new_vocab = vocab 95 | with open(args.input_vocab_json, 'r') as f: 96 | vocab = json.load(f) 97 | if args.expand_vocab == 1: 98 | num_new_words = 0 99 | for word in new_vocab['question_token_to_idx']: 100 | if word not in vocab['question_token_to_idx']: 101 | print('Found new word %s' % word) 102 | idx = len(vocab['question_token_to_idx']) 103 | vocab['question_token_to_idx'][word] = idx 104 | num_new_words += 1 105 | print('Found %d new words' % num_new_words) 106 | 107 | if args.output_vocab_json != '': 108 | with open(args.output_vocab_json, 'w') as f: 109 | json.dump(vocab, f) 110 | 111 | # Encode all questions and programs 112 | print('Encoding data') 113 | questions_encoded = [] 114 | programs_encoded = [] 115 | question_families = [] 116 | orig_idxs = [] 117 | image_idxs = [] 118 | answers = [] 119 | types = [] 120 | for orig_idx, q in enumerate(questions): 121 | question = q['question'] 122 | if 'program' in q: 123 | types += [q['program'][-1]['function']] 124 | 125 | orig_idxs.append(orig_idx) 126 | image_idxs.append(q['image_index']) 127 | if 'question_family_index' in q: 128 | question_families.append(q['question_family_index']) 129 | question_tokens = tokenize(question, 130 | punct_to_keep=[';', ','], 131 | punct_to_remove=['?', '.']) 132 | question_encoded = encode(question_tokens, 133 | vocab['question_token_to_idx'], 134 | allow_unk=args.encode_unk == 1) 135 | questions_encoded.append(question_encoded) 136 | 137 | if 'program' in q: 138 | program = q['program'] 139 | program_str = program_to_str(program, args.mode) 140 | program_tokens = tokenize(program_str) 141 | program_encoded = encode(program_tokens, vocab['program_token_to_idx']) 142 | programs_encoded.append(program_encoded) 143 | 144 | if 'answer' in q: 145 | answers.append(vocab['answer_token_to_idx'][q['answer']]) 146 | 147 | # Pad encoded questions and programs 148 | max_question_length = max(len(x) for x in questions_encoded) 149 | for qe in questions_encoded: 150 | while len(qe) < max_question_length: 151 | qe.append(vocab['question_token_to_idx']['<NULL>']) 152 | 153 | if len(programs_encoded) > 0: 154 | max_program_length = max(len(x) for x in programs_encoded) 155 | for pe in programs_encoded: 156 | while len(pe) < max_program_length: 157 | pe.append(vocab['program_token_to_idx']['<NULL>']) 158 | 159 | # Create h5 file 160 | print('Writing output') 161 | questions_encoded = np.asarray(questions_encoded, dtype=np.int32) 162 | programs_encoded = np.asarray(programs_encoded, dtype=np.int32) 163 | print(questions_encoded.shape) 164 | print(programs_encoded.shape) 165 | 166 | mapping = {} 167 | for i, t in enumerate(set(types)): 168 | mapping[t] = i 169 | 170 | print(mapping) 171 | 172 | types_coded = [] 173 | for t in types: 174 | types_coded += [mapping[t]] 175 | 176 | with h5py.File(args.output_h5_file, 'w') as f: 177 | f.create_dataset('questions', data=questions_encoded) 178 | f.create_dataset('image_idxs', data=np.asarray(image_idxs)) 179 | f.create_dataset('orig_idxs', data=np.asarray(orig_idxs)) 180 | 181 | if len(programs_encoded) > 0: 182 | f.create_dataset('programs', data=programs_encoded) 183 | if
len(question_families) > 0: 184 | f.create_dataset('question_families', data=np.asarray(question_families)) 185 | if len(answers) > 0: 186 | f.create_dataset('answers', data=np.asarray(answers)) 187 | if len(types) > 0: 188 | f.create_dataset('types', data=np.asarray(types_coded)) 189 | 190 | 191 | if __name__ == '__main__': 192 | args = parser.parse_args() 193 | main(args) 194 | -------------------------------------------------------------------------------- /vr/data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2017-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import numpy as np 10 | import h5py 11 | import torch 12 | from torch.utils.data import Dataset, DataLoader 13 | from torch.utils.data.dataloader import default_collate 14 | 15 | import vr.programs 16 | 17 | 18 | def _dataset_to_tensor(dset, mask=None): 19 | arr = np.asarray(dset, dtype=np.int64) 20 | if mask is not None: 21 | arr = arr[mask] 22 | tensor = torch.LongTensor(arr) 23 | return tensor 24 | 25 | 26 | class ClevrDataset(Dataset): 27 | def __init__(self, question_h5, feature_h5, vocab, mode='prefix', 28 | image_h5=None, max_samples=None, question_families=None, 29 | image_idx_start_from=None): 30 | mode_choices = ['prefix', 'postfix'] 31 | if mode not in mode_choices: 32 | raise ValueError('Invalid mode "%s"' % mode) 33 | self.image_h5 = image_h5 34 | self.vocab = vocab 35 | self.feature_h5 = feature_h5 36 | self.mode = mode 37 | self.max_samples = max_samples 38 | 39 | mask = None 40 | if question_families is not None: 41 | # Use only the specified families 42 | all_families = np.asarray(question_h5['question_families']) 43 | N = all_families.shape[0] 44 | print(question_families) 45 | target_families = np.asarray(question_families)[:, None] 46 | mask = (all_families == target_families).any(axis=0) 47 | if image_idx_start_from is not None: 48 | all_image_idxs = np.asarray(question_h5['image_idxs']) 49 | mask = all_image_idxs >= image_idx_start_from 50 | 51 | # Data from the question file is small, so read it all into memory 52 | print('Reading question data into memory') 53 | self.all_types = None 54 | if 'types' in question_h5: 55 | self.all_types = _dataset_to_tensor(question_h5['types'], mask) 56 | self.all_question_families = None 57 | if 'question_families' in question_h5: 58 | self.all_question_families = _dataset_to_tensor(question_h5['question_families'], mask) 59 | self.all_questions = _dataset_to_tensor(question_h5['questions'], mask) 60 | self.all_image_idxs = _dataset_to_tensor(question_h5['image_idxs'], mask) 61 | self.all_programs = None 62 | if 'programs' in question_h5: 63 | self.all_programs = _dataset_to_tensor(question_h5['programs'], mask) 64 | self.all_answers = None 65 | if 'answers' in question_h5: 66 | self.all_answers = _dataset_to_tensor(question_h5['answers'], mask) 67 | 68 | def __getitem__(self, index): 69 | if self.all_question_families is not None: 70 | question_family = self.all_question_families[index] 71 | q_type = None if self.all_types is None else self.all_types[index] 72 | question = self.all_questions[index] 73 | image_idx = self.all_image_idxs[index] 74 | answer = None 75 | if self.all_answers is not None: 76 | answer = self.all_answers[index] 77 | program_seq = None 78 | if self.all_programs is not None: 79 | program_seq = self.all_programs[index] 80 | 
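# Note: the question, program, and answer tensors were loaded fully into memory in
# __init__, but the image and feature arrays below are read lazily from their HDF5
# files on every __getitem__ call, since the full feature set is typically too
# large to hold in RAM.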
81 | image = None 82 | if self.image_h5 is not None: 83 | image = self.image_h5['images'][image_idx] 84 | image = torch.FloatTensor(np.asarray(image, dtype=np.float32)) 85 | 86 | feats = self.feature_h5['features'][image_idx] 87 | feats = torch.FloatTensor(np.asarray(feats, dtype=np.float32)) 88 | 89 | program_json = None 90 | if program_seq is not None: 91 | program_json_seq = [] 92 | for fn_idx in program_seq: 93 | fn_str = self.vocab['program_idx_to_token'][fn_idx] 94 | if fn_str == '' or fn_str == '': continue 95 | fn = vr.programs.str_to_function(fn_str) 96 | program_json_seq.append(fn) 97 | if self.mode == 'prefix': 98 | program_json = vr.programs.prefix_to_list(program_json_seq) 99 | elif self.mode == 'postfix': 100 | program_json = vr.programs.postfix_to_list(program_json_seq) 101 | 102 | if q_type is None: 103 | return (question, image, feats, answer, program_seq, program_json) 104 | return ([question, q_type], image, feats, answer, program_seq, program_json) 105 | 106 | def __len__(self): 107 | if self.max_samples is None: 108 | return self.all_questions.size(0) 109 | else: 110 | return min(self.max_samples, self.all_questions.size(0)) 111 | 112 | 113 | class ClevrDataLoader(DataLoader): 114 | def __init__(self, **kwargs): 115 | if 'question_h5' not in kwargs: 116 | raise ValueError('Must give question_h5') 117 | if 'feature_h5' not in kwargs: 118 | raise ValueError('Must give feature_h5') 119 | if 'vocab' not in kwargs: 120 | raise ValueError('Must give vocab') 121 | 122 | feature_h5_path = kwargs.pop('feature_h5') 123 | print('Reading features from', feature_h5_path) 124 | self.feature_h5 = h5py.File(feature_h5_path, 'r') 125 | 126 | self.image_h5 = None 127 | if 'image_h5' in kwargs: 128 | image_h5_path = kwargs.pop('image_h5') 129 | print('Reading images from ', image_h5_path) 130 | self.image_h5 = h5py.File(image_h5_path, 'r') 131 | 132 | vocab = kwargs.pop('vocab') 133 | mode = kwargs.pop('mode', 'prefix') 134 | 135 | question_families = kwargs.pop('question_families', None) 136 | max_samples = kwargs.pop('max_samples', None) 137 | question_h5_path = kwargs.pop('question_h5') 138 | image_idx_start_from = kwargs.pop('image_idx_start_from', None) 139 | print('Reading questions from ', question_h5_path) 140 | with h5py.File(question_h5_path, 'r') as question_h5: 141 | self.dataset = ClevrDataset(question_h5, self.feature_h5, vocab, mode, 142 | image_h5=self.image_h5, 143 | max_samples=max_samples, 144 | question_families=question_families, 145 | image_idx_start_from=image_idx_start_from) 146 | kwargs['collate_fn'] = clevr_collate 147 | super(ClevrDataLoader, self).__init__(self.dataset, **kwargs) 148 | 149 | def close(self): 150 | if self.image_h5 is not None: 151 | self.image_h5.close() 152 | if self.feature_h5 is not None: 153 | self.feature_h5.close() 154 | 155 | def __enter__(self): 156 | return self 157 | 158 | def __exit__(self, exc_type, exc_value, traceback): 159 | self.close() 160 | 161 | 162 | def clevr_collate(batch): 163 | transposed = list(zip(*batch)) 164 | question_batch = default_collate(transposed[0]) 165 | image_batch = transposed[1] 166 | if any(img is not None for img in image_batch): 167 | image_batch = default_collate(image_batch) 168 | feat_batch = transposed[2] 169 | if any(f is not None for f in feat_batch): 170 | feat_batch = default_collate(feat_batch) 171 | answer_batch = transposed[3] 172 | if transposed[3][0] is not None: 173 | answer_batch = default_collate(transposed[3]) 174 | program_seq_batch = transposed[4] 175 | if transposed[4][0] is not 
None: 176 | program_seq_batch = default_collate(transposed[4]) 177 | program_struct_batch = transposed[5] 178 | return [question_batch, image_batch, feat_batch, answer_batch, program_seq_batch, program_struct_batch] 179 | -------------------------------------------------------------------------------- /vr/models/module_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2017-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import math 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.autograd import Variable 14 | import torchvision.models 15 | 16 | from vr.models.layers import init_modules, ResidualBlock, GlobalAveragePool, Flatten 17 | from vr.models.layers import build_classifier, build_stem, ConcatBlock 18 | import vr.programs 19 | 20 | 21 | class ModuleNet(nn.Module): 22 | def __init__(self, vocab, feature_dim=(1024, 14, 14), 23 | stem_num_layers=2, 24 | stem_batchnorm=False, 25 | module_dim=128, 26 | module_residual=True, 27 | module_batchnorm=False, 28 | classifier_proj_dim=512, 29 | classifier_downsample='maxpool2', 30 | classifier_fc_layers=(1024,), 31 | classifier_batchnorm=False, 32 | classifier_dropout=0, 33 | verbose=True): 34 | super(ModuleNet, self).__init__() 35 | 36 | 37 | self.stem = build_stem(feature_dim[0], module_dim, 38 | num_layers=stem_num_layers, 39 | with_batchnorm=stem_batchnorm) 40 | if verbose: 41 | print('Here is my stem:') 42 | print(self.stem) 43 | 44 | num_answers = len(vocab['answer_idx_to_token']) 45 | module_H, module_W = feature_dim[1], feature_dim[2] 46 | self.classifier = build_classifier(module_dim, module_H, module_W, num_answers, 47 | classifier_fc_layers, 48 | classifier_proj_dim, 49 | classifier_downsample, 50 | with_batchnorm=classifier_batchnorm, 51 | dropout=classifier_dropout) 52 | if verbose: 53 | print('Here is my classifier:') 54 | print(self.classifier) 55 | self.stem_times = [] 56 | self.module_times = [] 57 | self.classifier_times = [] 58 | self.timing = False 59 | 60 | self.function_modules = {} 61 | self.function_modules_num_inputs = {} 62 | self.vocab = vocab 63 | for fn_str in vocab['program_token_to_idx']: 64 | num_inputs = vr.programs.get_num_inputs(fn_str) 65 | self.function_modules_num_inputs[fn_str] = num_inputs 66 | if fn_str == 'scene' or num_inputs == 1: 67 | mod = ResidualBlock(module_dim, 68 | with_residual=module_residual, 69 | with_batchnorm=module_batchnorm) 70 | elif num_inputs == 2: 71 | mod = ConcatBlock(module_dim, 72 | with_residual=module_residual, 73 | with_batchnorm=module_batchnorm) 74 | self.add_module(fn_str, mod) 75 | self.function_modules[fn_str] = mod 76 | 77 | self.save_module_outputs = False 78 | 79 | def expand_answer_vocab(self, answer_to_idx, std=0.01, init_b=-50): 80 | # TODO: This is really gross, dipping into private internals of Sequential 81 | final_linear_key = str(len(self.classifier._modules) - 1) 82 | final_linear = self.classifier._modules[final_linear_key] 83 | 84 | old_weight = final_linear.weight.data 85 | old_bias = final_linear.bias.data 86 | old_N, D = old_weight.size() 87 | new_N = 1 + max(answer_to_idx.values()) 88 | new_weight = old_weight.new(new_N, D).normal_().mul_(std) 89 | new_bias = old_bias.new(new_N).fill_(init_b) 90 | new_weight[:old_N].copy_(old_weight) 91 | new_bias[:old_N].copy_(old_bias) 92 | 93 | 
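# Rows for newly added answers are initialized with small random weights
# (std, default 0.01) and a strongly negative bias (init_b, default -50), so
# previously unseen answers start with near-zero predicted probability until
# they are trained.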
final_linear.weight.data = new_weight 94 | final_linear.bias.data = new_bias 95 | 96 | def _forward_modules_json(self, feats, program): 97 | def gen_hook(i, j): 98 | def hook(grad): 99 | self.all_module_grad_outputs[i][j] = grad.data.cpu().clone() 100 | return hook 101 | 102 | self.all_module_outputs = [] 103 | self.all_module_grad_outputs = [] 104 | # We can't easily handle minibatching of modules, so just do a loop 105 | N = feats.size(0) 106 | final_module_outputs = [] 107 | for i in range(N): 108 | if self.save_module_outputs: 109 | self.all_module_outputs.append([]) 110 | self.all_module_grad_outputs.append([None] * len(program[i])) 111 | module_outputs = [] 112 | for j, f in enumerate(program[i]): 113 | f_str = vr.programs.function_to_str(f) 114 | module = self.function_modules[f_str] 115 | if f_str == 'scene': 116 | module_inputs = [feats[i:i+1]] 117 | else: 118 | module_inputs = [module_outputs[j] for j in f['inputs']] 119 | module_outputs.append(module(*module_inputs)) 120 | if self.save_module_outputs: 121 | self.all_module_outputs[-1].append(module_outputs[-1].data.cpu().clone()) 122 | module_outputs[-1].register_hook(gen_hook(i, j)) 123 | final_module_outputs.append(module_outputs[-1]) 124 | final_module_outputs = torch.cat(final_module_outputs, 0) 125 | return final_module_outputs 126 | 127 | def _forward_modules_ints_helper(self, feats, program, i, j): 128 | used_fn_j = True 129 | if j < program.size(1): 130 | fn_idx = program.data[i, j] 131 | fn_str = self.vocab['program_idx_to_token'][fn_idx] 132 | else: 133 | used_fn_j = False 134 | fn_str = 'scene' 135 | if fn_str == '': 136 | used_fn_j = False 137 | fn_str = 'scene' 138 | elif fn_str == '': 139 | used_fn_j = False 140 | return self._forward_modules_ints_helper(feats, program, i, j + 1) 141 | if used_fn_j: 142 | self.used_fns[i, j] = 1 143 | j += 1 144 | module = self.function_modules[fn_str] 145 | if fn_str == 'scene': 146 | module_inputs = [feats[i:i+1]] 147 | else: 148 | num_inputs = self.function_modules_num_inputs[fn_str] 149 | module_inputs = [] 150 | while len(module_inputs) < num_inputs: 151 | cur_input, j = self._forward_modules_ints_helper(feats, program, i, j) 152 | module_inputs.append(cur_input) 153 | module_output = module(*module_inputs) 154 | return module_output, j 155 | 156 | def _forward_modules_ints(self, feats, program): 157 | """ 158 | feats: FloatTensor of shape (N, C, H, W) giving features for each image 159 | program: LongTensor of shape (N, L) giving a prefix-encoded program for 160 | each image. 
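    For example (illustrative), a prefix-encoded program such as
    [query_color, unique, filter_shape[cube], scene] is executed recursively by
    _forward_modules_ints_helper: each function consumes the outputs of the
    sub-programs that follow it in the sequence.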
161 | """ 162 | N = feats.size(0) 163 | final_module_outputs = [] 164 | self.used_fns = torch.Tensor(program.size()).fill_(0) 165 | for i in range(N): 166 | cur_output, _ = self._forward_modules_ints_helper(feats, program, i, 0) 167 | final_module_outputs.append(cur_output) 168 | self.used_fns = self.used_fns.type_as(program.data).float() 169 | final_module_outputs = torch.cat(final_module_outputs, 0) 170 | return final_module_outputs 171 | 172 | def forward(self, x, program): 173 | N = x.size(0) 174 | assert N == len(program) 175 | 176 | feats = self.stem(x) 177 | 178 | if type(program) is list or type(program) is tuple: 179 | final_module_outputs = self._forward_modules_json(feats, program) 180 | elif type(program) is Variable and program.dim() == 2: 181 | final_module_outputs = self._forward_modules_ints(feats, program) 182 | elif torch.is_tensor(program) and program.dim() == 3: 183 | final_module_outputs = self._forward_modules_probs(feats, program) 184 | else: 185 | raise ValueError('Unrecognized program format') 186 | 187 | # After running modules for each input, concatenat the outputs from the 188 | # final module and run the classifier. 189 | out = self.classifier(final_module_outputs) 190 | return out 191 | -------------------------------------------------------------------------------- /vr/models/film_gen.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import ipdb as pdb 4 | import torch 5 | import torch.cuda 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | 10 | from vr.embedding import expand_embedding_vocab 11 | from vr.models.layers import init_modules 12 | 13 | 14 | class FiLMGen(nn.Module): 15 | def __init__(self, 16 | null_token=0, 17 | start_token=1, 18 | end_token=2, 19 | encoder_embed=None, 20 | encoder_vocab_size=100, 21 | decoder_vocab_size=100, 22 | wordvec_dim=200, 23 | hidden_dim=512, 24 | rnn_num_layers=1, 25 | rnn_dropout=0, 26 | output_batchnorm=False, 27 | bidirectional=False, 28 | encoder_type='gru', 29 | decoder_type='linear', 30 | gamma_option='linear', 31 | gamma_baseline=1, 32 | num_modules=4, 33 | module_num_layers=1, 34 | module_dim=128, 35 | parameter_efficient=False, 36 | debug_every=float('inf'), 37 | ): 38 | super(FiLMGen, self).__init__() 39 | self.encoder_type = encoder_type 40 | self.decoder_type = decoder_type 41 | self.output_batchnorm = output_batchnorm 42 | self.bidirectional = bidirectional 43 | self.num_dir = 2 if self.bidirectional else 1 44 | self.gamma_option = gamma_option 45 | self.gamma_baseline = gamma_baseline 46 | self.num_modules = num_modules 47 | self.module_num_layers = module_num_layers 48 | self.module_dim = module_dim 49 | self.debug_every = debug_every 50 | self.NULL = null_token 51 | self.START = start_token 52 | self.END = end_token 53 | if self.bidirectional: 54 | if decoder_type != 'linear': 55 | raise(NotImplementedError) 56 | hidden_dim = (int) (hidden_dim / self.num_dir) 57 | 58 | self.func_list = { 59 | 'linear': None, 60 | 'sigmoid': F.sigmoid, 61 | 'tanh': F.tanh, 62 | 'exp': torch.exp, 63 | } 64 | 65 | self.cond_feat_size = 2 * self.module_dim * self.module_num_layers # FiLM params per ResBlock 66 | if not parameter_efficient: # parameter_efficient=False only used to load older trained models 67 | self.cond_feat_size = 4 * self.module_dim + 2 * self.num_modules 68 | 69 | self.encoder_embed = nn.Embedding(encoder_vocab_size, wordvec_dim) 70 | self.encoder_rnn = init_rnn(self.encoder_type, 
wordvec_dim, hidden_dim, rnn_num_layers, 71 | dropout=rnn_dropout, bidirectional=self.bidirectional) 72 | self.decoder_rnn = init_rnn(self.decoder_type, hidden_dim, hidden_dim, rnn_num_layers, 73 | dropout=rnn_dropout, bidirectional=self.bidirectional) 74 | self.decoder_linear = nn.Linear( 75 | hidden_dim * self.num_dir, self.num_modules * self.cond_feat_size) 76 | if self.output_batchnorm: 77 | self.output_bn = nn.BatchNorm1d(self.cond_feat_size, affine=True) 78 | 79 | init_modules(self.modules()) 80 | 81 | def expand_encoder_vocab(self, token_to_idx, word2vec=None, std=0.01): 82 | expand_embedding_vocab(self.encoder_embed, token_to_idx, 83 | word2vec=word2vec, std=std) 84 | 85 | def get_dims(self, x=None): 86 | V_in = self.encoder_embed.num_embeddings 87 | V_out = self.cond_feat_size 88 | D = self.encoder_embed.embedding_dim 89 | H = self.encoder_rnn.hidden_size 90 | H_full = self.encoder_rnn.hidden_size * self.num_dir 91 | L = self.encoder_rnn.num_layers * self.num_dir 92 | 93 | N = x.size(0) if x is not None else None 94 | T_in = x.size(1) if x is not None else None 95 | T_out = self.num_modules 96 | return V_in, V_out, D, H, H_full, L, N, T_in, T_out 97 | 98 | def before_rnn(self, x, replace=0): 99 | N, T = x.size() 100 | idx = torch.LongTensor(N).fill_(T - 1) 101 | 102 | # Find the last non-null element in each sequence. 103 | x_cpu = x.cpu() 104 | for i in range(N): 105 | for t in range(T - 1): 106 | if x_cpu.data[i, t] != self.NULL and x_cpu.data[i, t + 1] == self.NULL: 107 | idx[i] = t 108 | break 109 | idx = idx.type_as(x.data) 110 | x[x.data == self.NULL] = replace 111 | return x, Variable(idx) 112 | 113 | def encoder(self, x): 114 | V_in, V_out, D, H, H_full, L, N, T_in, T_out = self.get_dims(x=x) 115 | x, idx = self.before_rnn(x) # Tokenized word sequences (questions), end index 116 | embed = self.encoder_embed(x) 117 | h0 = Variable(torch.zeros(L, N, H).type_as(embed.data)) 118 | 119 | if self.encoder_type == 'lstm': 120 | c0 = Variable(torch.zeros(L, N, H).type_as(embed.data)) 121 | out, _ = self.encoder_rnn(embed, (h0, c0)) 122 | elif self.encoder_type == 'gru': 123 | out, _ = self.encoder_rnn(embed, h0) 124 | 125 | # Pull out the hidden state for the last non-null value in each input 126 | idx = idx.view(N, 1, 1).expand(N, 1, H_full) 127 | return out.gather(1, idx).view(N, H_full) 128 | 129 | def decoder(self, encoded, dims, h0=None, c0=None): 130 | V_in, V_out, D, H, H_full, L, N, T_in, T_out = dims 131 | 132 | if self.decoder_type == 'linear': 133 | # (N x H) x (H x T_out*V_out) -> (N x T_out*V_out) -> N x T_out x V_out 134 | return self.decoder_linear(encoded).view(N, T_out, V_out), (None, None) 135 | 136 | encoded_repeat = encoded.view(N, 1, H).expand(N, T_out, H) 137 | if not h0: 138 | h0 = Variable(torch.zeros(L, N, H).type_as(encoded.data)) 139 | 140 | if self.decoder_type == 'lstm': 141 | if not c0: 142 | c0 = Variable(torch.zeros(L, N, H).type_as(encoded.data)) 143 | rnn_output, (ht, ct) = self.decoder_rnn(encoded_repeat, (h0, c0)) 144 | elif self.decoder_type == 'gru': 145 | ct = None 146 | rnn_output, ht = self.decoder_rnn(encoded_repeat, h0) 147 | 148 | rnn_output_2d = rnn_output.contiguous().view(N * T_out, H) 149 | linear_output = self.decoder_linear(rnn_output_2d) 150 | if self.output_batchnorm: 151 | linear_output = self.output_bn(linear_output) 152 | output_shaped = linear_output.view(N, T_out, V_out) 153 | return output_shaped, (ht, ct) 154 | 155 | def forward(self, x): 156 | if self.debug_every <= -2: 157 | pdb.set_trace() 158 | encoded = self.encoder(x) 
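    # The encoder has collapsed the question into a single hidden vector; the decoder
    # below expands it into num_modules x cond_feat_size FiLM parameters, and
    # modify_output() then shifts the gammas by gamma_baseline (default 1) so that an
    # all-zero prediction corresponds to an identity modulation (gamma ~ 1, beta = 0).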
159 | film_pre_mod, _ = self.decoder(encoded, self.get_dims(x=x)) 160 | film = self.modify_output(film_pre_mod, gamma_option=self.gamma_option, 161 | gamma_shift=self.gamma_baseline) 162 | return film 163 | 164 | def modify_output(self, out, gamma_option='linear', gamma_scale=1, gamma_shift=0, 165 | beta_option='linear', beta_scale=1, beta_shift=0): 166 | gamma_func = self.func_list[gamma_option] 167 | beta_func = self.func_list[beta_option] 168 | 169 | gs = [] 170 | bs = [] 171 | for i in range(self.module_num_layers): 172 | gs.append(slice(i * (2 * self.module_dim), i * (2 * self.module_dim) + self.module_dim)) 173 | bs.append(slice(i * (2 * self.module_dim) + self.module_dim, (i + 1) * (2 * self.module_dim))) 174 | 175 | if gamma_func is not None: 176 | for i in range(self.module_num_layers): 177 | out[:,:,gs[i]] = gamma_func(out[:,:,gs[i]]) 178 | if gamma_scale != 1: 179 | for i in range(self.module_num_layers): 180 | out[:,:,gs[i]] = out[:,:,gs[i]] * gamma_scale 181 | if gamma_shift != 0: 182 | for i in range(self.module_num_layers): 183 | out[:,:,gs[i]] = out[:,:,gs[i]] + gamma_shift 184 | if beta_func is not None: 185 | for i in range(self.module_num_layers): 186 | out[:,:,bs[i]] = beta_func(out[:,:,bs[i]]) 187 | out[:,:,b2] = beta_func(out[:,:,b2]) 188 | if beta_scale != 1: 189 | for i in range(self.module_num_layers): 190 | out[:,:,bs[i]] = out[:,:,bs[i]] * beta_scale 191 | if beta_shift != 0: 192 | for i in range(self.module_num_layers): 193 | out[:,:,bs[i]] = out[:,:,bs[i]] + beta_shift 194 | return out 195 | 196 | def init_rnn(rnn_type, hidden_dim1, hidden_dim2, rnn_num_layers, 197 | dropout=0, bidirectional=False): 198 | if rnn_type == 'gru': 199 | return nn.GRU(hidden_dim1, hidden_dim2, rnn_num_layers, dropout=dropout, 200 | batch_first=True, bidirectional=bidirectional) 201 | elif rnn_type == 'lstm': 202 | return nn.LSTM(hidden_dim1, hidden_dim2, rnn_num_layers, dropout=dropout, 203 | batch_first=True, bidirectional=bidirectional) 204 | elif rnn_type == 'linear': 205 | return None 206 | else: 207 | print('RNN type ' + str(rnn_type) + ' not yet implemented.') 208 | raise(NotImplementedError) 209 | -------------------------------------------------------------------------------- /vr/models/baselines.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2017-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
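# Non-FiLM baselines used for comparison: an LSTM-only model (LstmModel), a
# CNN+LSTM model (CnnLstmModel), and a CNN+LSTM model with Stacked Attention
# (CnnLstmSaModel).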
8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | 14 | from vr.models.layers import init_modules, ResidualBlock 15 | from vr.embedding import expand_embedding_vocab 16 | 17 | 18 | class StackedAttention(nn.Module): 19 | def __init__(self, input_dim, hidden_dim): 20 | super(StackedAttention, self).__init__() 21 | self.Wv = nn.Conv2d(input_dim, hidden_dim, kernel_size=1, padding=0) 22 | self.Wu = nn.Linear(input_dim, hidden_dim) 23 | self.Wp = nn.Conv2d(hidden_dim, 1, kernel_size=1, padding=0) 24 | self.hidden_dim = hidden_dim 25 | self.attention_maps = None 26 | init_modules(self.modules(), init='normal') 27 | 28 | def forward(self, v, u): 29 | """ 30 | Input: 31 | - v: N x D x H x W 32 | - u: N x D 33 | 34 | Returns: 35 | - next_u: N x D 36 | """ 37 | N, K = v.size(0), self.hidden_dim 38 | D, H, W = v.size(1), v.size(2), v.size(3) 39 | v_proj = self.Wv(v) # N x K x H x W 40 | u_proj = self.Wu(u) # N x K 41 | u_proj_expand = u_proj.view(N, K, 1, 1).expand(N, K, H, W) 42 | h = F.tanh(v_proj + u_proj_expand) 43 | p = F.softmax(self.Wp(h).view(N, H * W)).view(N, 1, H, W) 44 | self.attention_maps = p.data.clone() 45 | 46 | v_tilde = (p.expand_as(v) * v).sum(2).sum(3).view(N, D) 47 | next_u = u + v_tilde 48 | return next_u 49 | 50 | 51 | class LstmEncoder(nn.Module): 52 | def __init__(self, token_to_idx, wordvec_dim=300, 53 | rnn_dim=256, rnn_num_layers=2, rnn_dropout=0): 54 | super(LstmEncoder, self).__init__() 55 | self.token_to_idx = token_to_idx 56 | self.NULL = token_to_idx[''] 57 | self.START = token_to_idx[''] 58 | self.END = token_to_idx[''] 59 | 60 | self.embed = nn.Embedding(len(token_to_idx), wordvec_dim) 61 | self.rnn = nn.LSTM(wordvec_dim, rnn_dim, rnn_num_layers, 62 | dropout=rnn_dropout, batch_first=True) 63 | 64 | def expand_vocab(self, token_to_idx, word2vec=None, std=0.01): 65 | expand_embedding_vocab(self.embed, token_to_idx, 66 | word2vec=word2vec, std=std) 67 | 68 | def forward(self, x): 69 | N, T = x.size() 70 | idx = torch.LongTensor(N).fill_(T - 1) 71 | 72 | # Find the last non-null element in each sequence 73 | x_cpu = x.data.cpu() 74 | for i in range(N): 75 | for t in range(T - 1): 76 | if x_cpu[i, t] != self.NULL and x_cpu[i, t + 1] == self.NULL: 77 | idx[i] = t 78 | break 79 | idx = idx.type_as(x.data).long() 80 | idx = Variable(idx, requires_grad=False) 81 | 82 | hs, _ = self.rnn(self.embed(x)) 83 | idx = idx.view(N, 1, 1).expand(N, 1, hs.size(2)) 84 | H = hs.size(2) 85 | return hs.gather(1, idx).view(N, H) 86 | 87 | 88 | def build_cnn(feat_dim=(1024, 14, 14), 89 | res_block_dim=128, 90 | num_res_blocks=0, 91 | proj_dim=512, 92 | pooling='maxpool2'): 93 | C, H, W = feat_dim 94 | layers = [] 95 | if num_res_blocks > 0: 96 | layers.append(nn.Conv2d(C, res_block_dim, kernel_size=3, padding=1)) 97 | layers.append(nn.ReLU(inplace=True)) 98 | C = res_block_dim 99 | for _ in range(num_res_blocks): 100 | layers.append(ResidualBlock(C)) 101 | if proj_dim > 0: 102 | layers.append(nn.Conv2d(C, proj_dim, kernel_size=1, padding=0)) 103 | layers.append(nn.ReLU(inplace=True)) 104 | C = proj_dim 105 | if pooling == 'maxpool2': 106 | layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) 107 | H, W = H // 2, W // 2 108 | return nn.Sequential(*layers), (C, H, W) 109 | 110 | 111 | def build_mlp(input_dim, hidden_dims, output_dim, 112 | use_batchnorm=False, dropout=0): 113 | layers = [] 114 | D = input_dim 115 | if dropout > 0: 116 | layers.append(nn.Dropout(p=dropout)) 117 | if use_batchnorm: 118 | 
layers.append(nn.BatchNorm1d(input_dim)) 119 | for dim in hidden_dims: 120 | layers.append(nn.Linear(D, dim)) 121 | if use_batchnorm: 122 | layers.append(nn.BatchNorm1d(dim)) 123 | if dropout > 0: 124 | layers.append(nn.Dropout(p=dropout)) 125 | layers.append(nn.ReLU(inplace=True)) 126 | D = dim 127 | layers.append(nn.Linear(D, output_dim)) 128 | return nn.Sequential(*layers) 129 | 130 | 131 | class LstmModel(nn.Module): 132 | def __init__(self, vocab, 133 | rnn_wordvec_dim=300, rnn_dim=256, rnn_num_layers=2, rnn_dropout=0, 134 | fc_use_batchnorm=False, fc_dropout=0, fc_dims=(1024,)): 135 | super(LstmModel, self).__init__() 136 | rnn_kwargs = { 137 | 'token_to_idx': vocab['question_token_to_idx'], 138 | 'wordvec_dim': rnn_wordvec_dim, 139 | 'rnn_dim': rnn_dim, 140 | 'rnn_num_layers': rnn_num_layers, 141 | 'rnn_dropout': rnn_dropout, 142 | } 143 | self.rnn = LstmEncoder(**rnn_kwargs) 144 | 145 | classifier_kwargs = { 146 | 'input_dim': rnn_dim, 147 | 'hidden_dims': fc_dims, 148 | 'output_dim': len(vocab['answer_token_to_idx']), 149 | 'use_batchnorm': fc_use_batchnorm, 150 | 'dropout': fc_dropout, 151 | } 152 | self.classifier = build_mlp(**classifier_kwargs) 153 | 154 | def forward(self, questions, feats): 155 | q_feats = self.rnn(questions) 156 | scores = self.classifier(q_feats) 157 | return scores 158 | 159 | 160 | class CnnLstmModel(nn.Module): 161 | def __init__(self, vocab, 162 | rnn_wordvec_dim=300, rnn_dim=256, rnn_num_layers=2, rnn_dropout=0, 163 | cnn_feat_dim=(1024,14,14), 164 | cnn_res_block_dim=128, cnn_num_res_blocks=0, 165 | cnn_proj_dim=512, cnn_pooling='maxpool2', 166 | fc_dims=(1024,), fc_use_batchnorm=False, fc_dropout=0): 167 | super(CnnLstmModel, self).__init__() 168 | rnn_kwargs = { 169 | 'token_to_idx': vocab['question_token_to_idx'], 170 | 'wordvec_dim': rnn_wordvec_dim, 171 | 'rnn_dim': rnn_dim, 172 | 'rnn_num_layers': rnn_num_layers, 173 | 'rnn_dropout': rnn_dropout, 174 | } 175 | self.rnn = LstmEncoder(**rnn_kwargs) 176 | 177 | cnn_kwargs = { 178 | 'feat_dim': cnn_feat_dim, 179 | 'res_block_dim': cnn_res_block_dim, 180 | 'num_res_blocks': cnn_num_res_blocks, 181 | 'proj_dim': cnn_proj_dim, 182 | 'pooling': cnn_pooling, 183 | } 184 | self.cnn, (C, H, W) = build_cnn(**cnn_kwargs) 185 | 186 | classifier_kwargs = { 187 | 'input_dim': C * H * W + rnn_dim, 188 | 'hidden_dims': fc_dims, 189 | 'output_dim': len(vocab['answer_token_to_idx']), 190 | 'use_batchnorm': fc_use_batchnorm, 191 | 'dropout': fc_dropout, 192 | } 193 | self.classifier = build_mlp(**classifier_kwargs) 194 | 195 | def forward(self, questions, feats): 196 | N = questions.size(0) 197 | assert N == feats.size(0) 198 | q_feats = self.rnn(questions) 199 | img_feats = self.cnn(feats) 200 | cat_feats = torch.cat([q_feats, img_feats.view(N, -1)], 1) 201 | scores = self.classifier(cat_feats) 202 | return scores 203 | 204 | 205 | class CnnLstmSaModel(nn.Module): 206 | def __init__(self, vocab, 207 | rnn_wordvec_dim=300, rnn_dim=256, rnn_num_layers=2, rnn_dropout=0, 208 | cnn_feat_dim=(1024,14,14), 209 | stacked_attn_dim=512, num_stacked_attn=2, 210 | fc_use_batchnorm=False, fc_dropout=0, fc_dims=(1024,)): 211 | super(CnnLstmSaModel, self).__init__() 212 | rnn_kwargs = { 213 | 'token_to_idx': vocab['question_token_to_idx'], 214 | 'wordvec_dim': rnn_wordvec_dim, 215 | 'rnn_dim': rnn_dim, 216 | 'rnn_num_layers': rnn_num_layers, 217 | 'rnn_dropout': rnn_dropout, 218 | } 219 | self.rnn = LstmEncoder(**rnn_kwargs) 220 | 221 | C, H, W = cnn_feat_dim 222 | self.image_proj = nn.Conv2d(C, rnn_dim, kernel_size=1, 
padding=0) 223 | self.stacked_attns = [] 224 | for i in range(num_stacked_attn): 225 | sa = StackedAttention(rnn_dim, stacked_attn_dim) 226 | self.stacked_attns.append(sa) 227 | self.add_module('stacked-attn-%d' % i, sa) 228 | 229 | classifier_args = { 230 | 'input_dim': rnn_dim, 231 | 'hidden_dims': fc_dims, 232 | 'output_dim': len(vocab['answer_token_to_idx']), 233 | 'use_batchnorm': fc_use_batchnorm, 234 | 'dropout': fc_dropout, 235 | } 236 | self.classifier = build_mlp(**classifier_args) 237 | init_modules(self.modules(), init='normal') 238 | 239 | def forward(self, questions, feats): 240 | u = self.rnn(questions) 241 | v = self.image_proj(feats) 242 | 243 | for sa in self.stacked_attns: 244 | u = sa(v, u) 245 | 246 | scores = self.classifier(u) 247 | return scores 248 | -------------------------------------------------------------------------------- /vr/models/seq2seq.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2017-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import torch 10 | import torch.cuda 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.autograd import Variable 14 | 15 | from vr.embedding import expand_embedding_vocab 16 | 17 | class Seq2Seq(nn.Module): 18 | def __init__(self, 19 | encoder_vocab_size=100, 20 | decoder_vocab_size=100, 21 | wordvec_dim=300, 22 | hidden_dim=256, 23 | rnn_num_layers=2, 24 | rnn_dropout=0, 25 | null_token=0, 26 | start_token=1, 27 | end_token=2, 28 | encoder_embed=None 29 | ): 30 | super(Seq2Seq, self).__init__() 31 | self.encoder_embed = nn.Embedding(encoder_vocab_size, wordvec_dim) 32 | self.encoder_rnn = nn.LSTM(wordvec_dim, hidden_dim, rnn_num_layers, 33 | dropout=rnn_dropout, batch_first=True) 34 | self.decoder_embed = nn.Embedding(decoder_vocab_size, wordvec_dim) 35 | self.decoder_rnn = nn.LSTM(wordvec_dim + hidden_dim, hidden_dim, rnn_num_layers, 36 | dropout=rnn_dropout, batch_first=True) 37 | self.decoder_rnn_new = nn.LSTM(hidden_dim, hidden_dim, rnn_num_layers, 38 | dropout=rnn_dropout, batch_first=True) 39 | self.decoder_linear = nn.Linear(hidden_dim, decoder_vocab_size) 40 | self.NULL = null_token 41 | self.START = start_token 42 | self.END = end_token 43 | self.multinomial_outputs = None 44 | 45 | def expand_encoder_vocab(self, token_to_idx, word2vec=None, std=0.01): 46 | expand_embedding_vocab(self.encoder_embed, token_to_idx, 47 | word2vec=word2vec, std=std) 48 | 49 | def get_dims(self, x=None, y=None): 50 | V_in = self.encoder_embed.num_embeddings 51 | V_out = self.decoder_embed.num_embeddings 52 | D = self.encoder_embed.embedding_dim 53 | H = self.encoder_rnn.hidden_size 54 | L = self.encoder_rnn.num_layers 55 | 56 | N = x.size(0) if x is not None else None 57 | N = y.size(0) if N is None and y is not None else N 58 | T_in = x.size(1) if x is not None else None 59 | T_out = y.size(1) if y is not None else None 60 | return V_in, V_out, D, H, L, N, T_in, T_out 61 | 62 | def before_rnn(self, x, replace=0): 63 | # TODO: Use PackedSequence instead of manually plucking out the last 64 | # non-NULL entry of each sequence; it is cleaner and more efficient. 65 | N, T = x.size() 66 | idx = torch.LongTensor(N).fill_(T - 1) 67 | 68 | # Find the last non-null element in each sequence. Is there a clean 69 | # way to do this? 
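    # A vectorized sketch (not used by this code path; assumes left-aligned sequences
    # with at least one non-NULL token each):
    #   lengths = (x.data != self.NULL).long().sum(dim=1)  # non-NULL tokens per row
    #   idx = lengths - 1                                   # last non-NULL position
    # Newer PyTorch versions could instead pass the embeddings through
    # nn.utils.rnn.pack_padded_sequence(embed, lengths, batch_first=True), as the
    # TODO above suggests.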
70 | x_cpu = x.cpu() 71 | for i in range(N): 72 | for t in range(T - 1): 73 | if x_cpu.data[i, t] != self.NULL and x_cpu.data[i, t + 1] == self.NULL: 74 | idx[i] = t 75 | break 76 | idx = idx.type_as(x.data) 77 | x[x.data == self.NULL] = replace 78 | return x, Variable(idx) 79 | 80 | def encoder(self, x): 81 | V_in, V_out, D, H, L, N, T_in, T_out = self.get_dims(x=x) 82 | x, idx = self.before_rnn(x) 83 | embed = self.encoder_embed(x) 84 | h0 = Variable(torch.zeros(L, N, H).type_as(embed.data)) 85 | c0 = Variable(torch.zeros(L, N, H).type_as(embed.data)) 86 | 87 | out, _ = self.encoder_rnn(embed, (h0, c0)) 88 | 89 | # Pull out the hidden state for the last non-null value in each input 90 | idx = idx.view(N, 1, 1).expand(N, 1, H) 91 | return out.gather(1, idx).view(N, H) 92 | 93 | def decoder(self, encoded, y, h0=None, c0=None): 94 | V_in, V_out, D, H, L, N, T_in, T_out = self.get_dims(y=y) 95 | 96 | if T_out > 1: 97 | y, _ = self.before_rnn(y) 98 | y_embed = self.decoder_embed(y) 99 | encoded_repeat = encoded.view(N, 1, H).expand(N, T_out, H) 100 | rnn_input = torch.cat([encoded_repeat, y_embed], 2) 101 | if not h0: 102 | h0 = Variable(torch.zeros(L, N, H).type_as(encoded.data)) 103 | if not c0: 104 | c0 = Variable(torch.zeros(L, N, H).type_as(encoded.data)) 105 | rnn_output, (ht, ct) = self.decoder_rnn(rnn_input, (h0, c0)) 106 | 107 | rnn_output_2d = rnn_output.contiguous().view(N * T_out, H) 108 | output_logprobs = self.decoder_linear(rnn_output_2d).view(N, T_out, V_out) 109 | 110 | return output_logprobs, ht, ct 111 | 112 | def compute_loss(self, output_logprobs, y): 113 | """ 114 | Compute loss. We assume that the first element of the output sequence y is 115 | a start token, and that each element of y is left-aligned and right-padded 116 | with self.NULL out to T_out. We want the output_logprobs to predict the 117 | sequence y, shifted by one timestep so that y[0] is fed to the network and 118 | then y[1] is predicted. We also don't want to compute loss for padded 119 | timesteps. 120 | 121 | Inputs: 122 | - output_logprobs: Variable of shape (N, T_out, V_out) 123 | - y: LongTensor Variable of shape (N, T_out) 124 | """ 125 | self.multinomial_outputs = None 126 | V_in, V_out, D, H, L, N, T_in, T_out = self.get_dims(y=y) 127 | mask = y.data != self.NULL 128 | y_mask = Variable(torch.Tensor(N, T_out).fill_(0).type_as(mask)) 129 | y_mask[:, 1:] = mask[:, 1:] 130 | y_masked = y[y_mask] 131 | out_mask = Variable(torch.Tensor(N, T_out).fill_(0).type_as(mask)) 132 | out_mask[:, :-1] = mask[:, 1:] 133 | out_mask = out_mask.view(N, T_out, 1).expand(N, T_out, V_out) 134 | out_masked = output_logprobs[out_mask].view(-1, V_out) 135 | loss = F.cross_entropy(out_masked, y_masked) 136 | return loss 137 | 138 | def forward(self, x, y): 139 | encoded = self.encoder(x) 140 | 141 | V_in, V_out, D, H, L, N, T_in, T_out = self.get_dims(x=x) 142 | T_out = 15 143 | encoded_repeat = encoded.view(N, 1, H).expand(N, T_out, H) 144 | h0 = Variable(torch.zeros(L, N, H).type_as(encoded.data)) 145 | c0 = Variable(torch.zeros(L, N, H).type_as(encoded.data)) 146 | rnn_output, (ht, ct) = self.decoder_rnn_new(encoded_repeat, (h0, c0)) 147 | 148 | output_logprobs, _, _ = self.decoder(encoded, y) 149 | loss = self.compute_loss(output_logprobs, y) 150 | return loss 151 | 152 | def sample(self, x, max_length=50): 153 | # TODO: Handle sampling for minibatch inputs 154 | # TODO: Beam search? 
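    # Usage sketch (hypothetical names): for a trained model `m` and a 1 x T encoded
    # question `q`, `m.sample(q)` greedily decodes one token at a time, feeding each
    # prediction back in as input, until <END> or max_length is reached.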
155 | self.multinomial_outputs = None 156 | assert x.size(0) == 1, "Sampling minibatches not implemented" 157 | encoded = self.encoder(x) 158 | y = [self.START] 159 | h0, c0 = None, None 160 | while True: 161 | cur_y = Variable(torch.LongTensor([y[-1]]).type_as(x.data).view(1, 1)) 162 | logprobs, h0, c0 = self.decoder(encoded, cur_y, h0=h0, c0=c0) 163 | _, next_y = logprobs.data.max(2) 164 | y.append(next_y[0, 0, 0]) 165 | if len(y) >= max_length or y[-1] == self.END: 166 | break 167 | return y 168 | 169 | def reinforce_sample(self, x, max_length=30, temperature=1.0, argmax=False): 170 | N, T = x.size(0), max_length 171 | encoded = self.encoder(x) 172 | y = torch.LongTensor(N, T).fill_(self.NULL) 173 | done = torch.ByteTensor(N).fill_(0) 174 | cur_input = Variable(x.data.new(N, 1).fill_(self.START)) 175 | h, c = None, None 176 | self.multinomial_outputs = [] 177 | self.multinomial_probs = [] 178 | for t in range(T): 179 | # logprobs is N x 1 x V 180 | logprobs, h, c = self.decoder(encoded, cur_input, h0=h, c0=c) 181 | logprobs = logprobs / temperature 182 | probs = F.softmax(logprobs.view(N, -1)) # Now N x V 183 | if argmax: 184 | _, cur_output = probs.max(1) 185 | else: 186 | cur_output = probs.multinomial() # Now N x 1 187 | self.multinomial_outputs.append(cur_output) 188 | self.multinomial_probs.append(probs) 189 | cur_output_data = cur_output.data.cpu() 190 | not_done = logical_not(done) 191 | y[:, t][not_done] = cur_output_data[not_done] 192 | done = logical_or(done, cur_output_data.cpu() == self.END) 193 | cur_input = cur_output 194 | if done.sum() == N: 195 | break 196 | return Variable(y.type_as(x.data)) 197 | 198 | def reinforce_backward(self, reward, output_mask=None): 199 | """ 200 | If output_mask is not None, then it should be a FloatTensor of shape (N, T) 201 | giving a multiplier to the output. 
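    Typical flow (a sketch): call reinforce_sample to draw a program for each
    input, compute a scalar reward per example externally (e.g. whether the
    executed program yields the correct answer), then call
    reinforce_backward(reward) to reinforce every sampled token with that reward.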
202 | """ 203 | assert self.multinomial_outputs is not None, 'Must call reinforce_sample first' 204 | grad_output = [] 205 | 206 | def gen_hook(mask): 207 | def hook(grad): 208 | return grad * mask.contiguous().view(-1, 1).expand_as(grad) 209 | return hook 210 | 211 | if output_mask is not None: 212 | for t, probs in enumerate(self.multinomial_probs): 213 | mask = Variable(output_mask[:, t]) 214 | probs.register_hook(gen_hook(mask)) 215 | 216 | for sampled_output in self.multinomial_outputs: 217 | sampled_output.reinforce(reward) 218 | grad_output.append(None) 219 | torch.autograd.backward(self.multinomial_outputs, grad_output, retain_variables=True) 220 | 221 | 222 | def logical_and(x, y): 223 | return x * y 224 | 225 | def logical_or(x, y): 226 | return (x + y).clamp_(0, 1) 227 | 228 | def logical_not(x): 229 | return x == 0 230 | -------------------------------------------------------------------------------- /vr/models/filmed_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import math 4 | import ipdb as pdb 5 | import pprint 6 | from termcolor import colored 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | import torchvision.models 12 | 13 | from vr.models.layers import init_modules, GlobalAveragePool, Flatten 14 | from vr.models.layers import build_classifier, build_stem 15 | import vr.programs 16 | 17 | 18 | class FiLM(nn.Module): 19 | """ 20 | A Feature-wise Linear Modulation Layer from 21 | 'FiLM: Visual Reasoning with a General Conditioning Layer' 22 | """ 23 | def forward(self, x, gammas, betas): 24 | gammas = gammas.unsqueeze(2).unsqueeze(3).expand_as(x) 25 | betas = betas.unsqueeze(2).unsqueeze(3).expand_as(x) 26 | return (gammas * x) + betas 27 | 28 | 29 | class FiLMedNet(nn.Module): 30 | def __init__(self, vocab, feature_dim=(1024, 14, 14), 31 | stem_num_layers=2, 32 | stem_batchnorm=False, 33 | stem_kernel_size=3, 34 | stem_stride=1, 35 | stem_padding=None, 36 | num_modules=4, 37 | module_num_layers=1, 38 | module_dim=128, 39 | module_residual=True, 40 | module_batchnorm=False, 41 | module_batchnorm_affine=False, 42 | module_dropout=0, 43 | module_input_proj=1, 44 | module_kernel_size=3, 45 | classifier_proj_dim=512, 46 | classifier_downsample='maxpool2', 47 | classifier_fc_layers=(1024,), 48 | classifier_batchnorm=False, 49 | classifier_dropout=0, 50 | condition_method='bn-film', 51 | condition_pattern=[], 52 | use_gamma=True, 53 | use_beta=True, 54 | use_coords=1, 55 | debug_every=float('inf'), 56 | print_verbose_every=float('inf'), 57 | verbose=True, 58 | ): 59 | super(FiLMedNet, self).__init__() 60 | 61 | num_answers = len(vocab['answer_idx_to_token']) 62 | 63 | self.stem_times = [] 64 | self.module_times = [] 65 | self.classifier_times = [] 66 | self.timing = False 67 | 68 | self.num_modules = num_modules 69 | self.module_num_layers = module_num_layers 70 | self.module_batchnorm = module_batchnorm 71 | self.module_dim = module_dim 72 | self.condition_method = condition_method 73 | self.use_gamma = use_gamma 74 | self.use_beta = use_beta 75 | self.use_coords_freq = use_coords 76 | self.debug_every = debug_every 77 | self.print_verbose_every = print_verbose_every 78 | 79 | # Initialize helper variables 80 | self.stem_use_coords = (stem_stride == 1) and (self.use_coords_freq > 0) 81 | self.condition_pattern = condition_pattern 82 | if len(condition_pattern) == 0: 83 | self.condition_pattern = [] 84 | for i in 
range(self.module_num_layers * self.num_modules): 85 | self.condition_pattern.append(self.condition_method != 'concat') 86 | else: 87 | self.condition_pattern = [i > 0 for i in self.condition_pattern] 88 | self.extra_channel_freq = self.use_coords_freq 89 | self.block = FiLMedResBlock 90 | self.num_cond_maps = 2 * self.module_dim if self.condition_method == 'concat' else 0 91 | self.fwd_count = 0 92 | self.num_extra_channels = 2 if self.use_coords_freq > 0 else 0 93 | if self.debug_every <= -1: 94 | self.print_verbose_every = 1 95 | module_H = feature_dim[1] // (stem_stride ** stem_num_layers) # Rough calc: work for main cases 96 | module_W = feature_dim[2] // (stem_stride ** stem_num_layers) # Rough calc: work for main cases 97 | self.coords = coord_map((module_H, module_W)) 98 | self.default_weight = Variable(torch.ones(1, 1, self.module_dim)).type(torch.cuda.FloatTensor) 99 | self.default_bias = Variable(torch.zeros(1, 1, self.module_dim)).type(torch.cuda.FloatTensor) 100 | 101 | # Initialize stem 102 | stem_feature_dim = feature_dim[0] + self.stem_use_coords * self.num_extra_channels 103 | self.stem = build_stem(stem_feature_dim, module_dim, 104 | num_layers=stem_num_layers, with_batchnorm=stem_batchnorm, 105 | kernel_size=stem_kernel_size, stride=stem_stride, padding=stem_padding) 106 | 107 | # Initialize FiLMed network body 108 | self.function_modules = {} 109 | self.vocab = vocab 110 | for fn_num in range(self.num_modules): 111 | with_cond = self.condition_pattern[self.module_num_layers * fn_num: 112 | self.module_num_layers * (fn_num + 1)] 113 | mod = self.block(module_dim, with_residual=module_residual, with_batchnorm=module_batchnorm, 114 | with_cond=with_cond, 115 | dropout=module_dropout, 116 | num_extra_channels=self.num_extra_channels, 117 | extra_channel_freq=self.extra_channel_freq, 118 | with_input_proj=module_input_proj, 119 | num_cond_maps=self.num_cond_maps, 120 | kernel_size=module_kernel_size, 121 | batchnorm_affine=module_batchnorm_affine, 122 | num_layers=self.module_num_layers, 123 | condition_method=condition_method, 124 | debug_every=self.debug_every) 125 | self.add_module(str(fn_num), mod) 126 | self.function_modules[fn_num] = mod 127 | 128 | # Initialize output classifier 129 | self.classifier = build_classifier(module_dim + self.num_extra_channels, module_H, module_W, 130 | num_answers, classifier_fc_layers, classifier_proj_dim, 131 | classifier_downsample, with_batchnorm=classifier_batchnorm, 132 | dropout=classifier_dropout) 133 | 134 | init_modules(self.modules()) 135 | 136 | def forward(self, x, film, save_activations=False): 137 | # Initialize forward pass and externally viewable activations 138 | self.fwd_count += 1 139 | if save_activations: 140 | self.feats = None 141 | self.module_outputs = [] 142 | self.cf_input = None 143 | 144 | if self.debug_every <= -2: 145 | pdb.set_trace() 146 | 147 | # Prepare FiLM layers 148 | gammas = None 149 | betas = None 150 | if self.condition_method == 'concat': 151 | # Use parameters usually used to condition via FiLM instead to condition via concatenation 152 | cond_params = film[:,:,:2*self.module_dim] 153 | cond_maps = cond_params.unsqueeze(3).unsqueeze(4).expand(cond_params.size() + x.size()[-2:]) 154 | else: 155 | gammas, betas = torch.split(film[:,:,:2*self.module_dim], self.module_dim, dim=-1) 156 | if not self.use_gamma: 157 | gammas = self.default_weight.expand_as(gammas) 158 | if not self.use_beta: 159 | betas = self.default_bias.expand_as(betas) 160 | 161 | # Propagate up image features CNN 162 | 
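    # When use_coords_freq > 0, two fixed coordinate channels (x- and y-ramps over
    # [-1, 1] from coord_map) are built for the batch and concatenated to the stem
    # input (if stem_use_coords), passed to each FiLMed block as extra_channels, and
    # appended to the final module output before the classifier, giving the
    # convolutions access to absolute spatial position.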
batch_coords = None 163 | if self.use_coords_freq > 0: 164 | batch_coords = self.coords.unsqueeze(0).expand(torch.Size((x.size(0), *self.coords.size()))) 165 | if self.stem_use_coords: 166 | x = torch.cat([x, batch_coords], 1) 167 | feats = self.stem(x) 168 | if save_activations: 169 | self.feats = feats 170 | N, _, H, W = feats.size() 171 | 172 | # Propagate up the network from low-to-high numbered blocks 173 | module_inputs = Variable(torch.zeros(feats.size()).unsqueeze(1).expand( 174 | N, self.num_modules, self.module_dim, H, W)).type(torch.cuda.FloatTensor) 175 | module_inputs[:,0] = feats 176 | for fn_num in range(self.num_modules): 177 | if self.condition_method == 'concat': 178 | layer_output = self.function_modules[fn_num](module_inputs[:,fn_num], 179 | extra_channels=batch_coords, cond_maps=cond_maps[:,fn_num]) 180 | else: 181 | layer_output = self.function_modules[fn_num](module_inputs[:,fn_num], 182 | gammas[:,fn_num,:], betas[:,fn_num,:], batch_coords) 183 | 184 | # Store for future computation 185 | if save_activations: 186 | self.module_outputs.append(layer_output) 187 | if fn_num == (self.num_modules - 1): 188 | final_module_output = layer_output 189 | else: 190 | module_inputs_updated = module_inputs.clone() 191 | module_inputs_updated[:,fn_num+1] = module_inputs_updated[:,fn_num+1] + layer_output 192 | module_inputs = module_inputs_updated 193 | 194 | if self.debug_every <= -2: 195 | pdb.set_trace() 196 | 197 | # Run the final classifier over the resultant, post-modulated features. 198 | if self.use_coords_freq > 0: 199 | final_module_output = torch.cat([final_module_output, batch_coords], 1) 200 | if save_activations: 201 | self.cf_input = final_module_output 202 | out = self.classifier(final_module_output) 203 | 204 | if ((self.fwd_count % self.debug_every) == 0) or (self.debug_every <= -1): 205 | pdb.set_trace() 206 | return out 207 | 208 | 209 | class FiLMedResBlock(nn.Module): 210 | def __init__(self, in_dim, out_dim=None, with_residual=True, with_batchnorm=True, 211 | with_cond=[False], dropout=0, num_extra_channels=0, extra_channel_freq=1, 212 | with_input_proj=0, num_cond_maps=0, kernel_size=3, batchnorm_affine=False, 213 | num_layers=1, condition_method='bn-film', debug_every=float('inf')): 214 | if out_dim is None: 215 | out_dim = in_dim 216 | super(FiLMedResBlock, self).__init__() 217 | self.with_residual = with_residual 218 | self.with_batchnorm = with_batchnorm 219 | self.with_cond = with_cond 220 | self.dropout = dropout 221 | self.extra_channel_freq = 0 if num_extra_channels == 0 else extra_channel_freq 222 | self.with_input_proj = with_input_proj # Kernel size of input projection 223 | self.num_cond_maps = num_cond_maps 224 | self.kernel_size = kernel_size 225 | self.batchnorm_affine = batchnorm_affine 226 | self.num_layers = num_layers 227 | self.condition_method = condition_method 228 | self.debug_every = debug_every 229 | 230 | if self.with_input_proj % 2 == 0: 231 | raise(NotImplementedError) 232 | if self.kernel_size % 2 == 0: 233 | raise(NotImplementedError) 234 | if self.num_layers >= 2: 235 | raise(NotImplementedError) 236 | 237 | if self.condition_method == 'block-input-film' and self.with_cond[0]: 238 | self.film = FiLM() 239 | if self.with_input_proj: 240 | self.input_proj = nn.Conv2d(in_dim + (num_extra_channels if self.extra_channel_freq >= 1 else 0), 241 | in_dim, kernel_size=self.with_input_proj, padding=self.with_input_proj // 2) 242 | 243 | self.conv1 = nn.Conv2d(in_dim + self.num_cond_maps + 244 | (num_extra_channels if 
self.extra_channel_freq >= 2 else 0), 245 | out_dim, kernel_size=self.kernel_size, 246 | padding=self.kernel_size // 2) 247 | if self.condition_method == 'conv-film' and self.with_cond[0]: 248 | self.film = FiLM() 249 | if self.with_batchnorm: 250 | self.bn1 = nn.BatchNorm2d(out_dim, affine=((not self.with_cond[0]) or self.batchnorm_affine)) 251 | if self.condition_method == 'bn-film' and self.with_cond[0]: 252 | self.film = FiLM() 253 | if dropout > 0: 254 | self.drop = nn.Dropout2d(p=self.dropout) 255 | if ((self.condition_method == 'relu-film' or self.condition_method == 'block-output-film') 256 | and self.with_cond[0]): 257 | self.film = FiLM() 258 | 259 | init_modules(self.modules()) 260 | 261 | def forward(self, x, gammas=None, betas=None, extra_channels=None, cond_maps=None): 262 | if self.debug_every <= -2: 263 | pdb.set_trace() 264 | 265 | if self.condition_method == 'block-input-film' and self.with_cond[0]: 266 | x = self.film(x, gammas, betas) 267 | 268 | # ResBlock input projection 269 | if self.with_input_proj: 270 | if extra_channels is not None and self.extra_channel_freq >= 1: 271 | x = torch.cat([x, extra_channels], 1) 272 | x = F.relu(self.input_proj(x)) 273 | out = x 274 | 275 | # ResBlock body 276 | if cond_maps is not None: 277 | out = torch.cat([out, cond_maps], 1) 278 | if extra_channels is not None and self.extra_channel_freq >= 2: 279 | out = torch.cat([out, extra_channels], 1) 280 | out = self.conv1(out) 281 | if self.condition_method == 'conv-film' and self.with_cond[0]: 282 | out = self.film(out, gammas, betas) 283 | if self.with_batchnorm: 284 | out = self.bn1(out) 285 | if self.condition_method == 'bn-film' and self.with_cond[0]: 286 | out = self.film(out, gammas, betas) 287 | if self.dropout > 0: 288 | out = self.drop(out) 289 | out = F.relu(out) 290 | if self.condition_method == 'relu-film' and self.with_cond[0]: 291 | out = self.film(out, gammas, betas) 292 | 293 | # ResBlock remainder 294 | if self.with_residual: 295 | out = x + out 296 | if self.condition_method == 'block-output-film' and self.with_cond[0]: 297 | out = self.film(out, gammas, betas) 298 | return out 299 | 300 | 301 | def coord_map(shape, start=-1, end=1): 302 | """ 303 | Gives, a 2d shape tuple, returns two mxn coordinate maps, 304 | Ranging min-max in the x and y directions, respectively. 305 | """ 306 | m, n = shape 307 | x_coord_row = torch.linspace(start, end, steps=n).type(torch.cuda.FloatTensor) 308 | y_coord_row = torch.linspace(start, end, steps=m).type(torch.cuda.FloatTensor) 309 | x_coords = x_coord_row.unsqueeze(0).expand(torch.Size((m, n))).unsqueeze(0) 310 | y_coords = y_coord_row.unsqueeze(1).expand(torch.Size((m, n))).unsqueeze(0) 311 | return Variable(torch.cat([x_coords, y_coords], 0)) 312 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. 
Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. 
Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 
137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. 
To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. 
To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | 409 | -------------------------------------------------------------------------------- /scripts/run_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
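# Example invocations (a hedged sketch, not part of the original file: the flag names and
# the default checkpoint/image paths come from the argparse definitions later in this
# script, but your checkpoint locations may differ):
#
#   # Answer one question about one image:
#   python scripts/run_model.py \
#     --program_generator models/best.pt --execution_engine models/best.pt \
#     --image img/CLEVR_val_000017.png \
#     --question "What color is the large sphere?"
#
#   # Interactive mode: omit --question and type questions at the ">>> " prompt:
#   python scripts/run_model.py --image img/CLEVR_val_000017.png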
6 | 7 | import argparse 8 | import ipdb as pdb 9 | import json 10 | import random 11 | import shutil 12 | from termcolor import colored 13 | import time 14 | from tqdm import tqdm 15 | import sys 16 | import os 17 | sys.path.insert(0, os.path.abspath('.')) 18 | 19 | import torch 20 | from torch.autograd import Variable 21 | import torch.nn.functional as F 22 | import torchvision 23 | import numpy as np 24 | import h5py 25 | from scipy.misc import imread, imresize, imsave 26 | 27 | import vr.utils as utils 28 | import vr.programs 29 | from vr.data import ClevrDataset, ClevrDataLoader 30 | from vr.preprocess import tokenize, encode 31 | 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--program_generator', default='models/best.pt') 35 | parser.add_argument('--execution_engine', default='models/best.pt') 36 | parser.add_argument('--baseline_model', default=None) 37 | parser.add_argument('--model_type', default='FiLM') 38 | parser.add_argument('--debug_every', default=float('inf'), type=float) 39 | parser.add_argument('--use_gpu', default=1, type=int) 40 | 41 | # For running on a preprocessed dataset 42 | parser.add_argument('--input_question_h5', default=None) 43 | parser.add_argument('--input_features_h5', default=None) 44 | 45 | # This will override the vocab stored in the checkpoint; 46 | # we need this to run CLEVR models on human data 47 | parser.add_argument('--vocab_json', default=None) 48 | 49 | # For running on a single example 50 | parser.add_argument('--question', default=None) 51 | parser.add_argument('--image', default='img/CLEVR_val_000017.png') 52 | parser.add_argument('--cnn_model', default='resnet101') 53 | parser.add_argument('--cnn_model_stage', default=3, type=int) 54 | parser.add_argument('--image_width', default=224, type=int) 55 | parser.add_argument('--image_height', default=224, type=int) 56 | parser.add_argument('--enforce_clevr_vocab', default=1, type=int) 57 | 58 | parser.add_argument('--batch_size', default=64, type=int) 59 | parser.add_argument('--num_samples', default=None, type=int) 60 | parser.add_argument('--num_last_words_shuffled', default=0, type=int) # -1 for all shuffled 61 | parser.add_argument('--family_split_file', default=None) 62 | 63 | parser.add_argument('--sample_argmax', type=int, default=1) 64 | parser.add_argument('--temperature', default=1.0, type=float) 65 | 66 | # FiLM models only 67 | parser.add_argument('--gamma_option', default='linear', 68 | choices=['linear', 'sigmoid', 'tanh', 'exp', 'relu', 'softplus']) 69 | parser.add_argument('--gamma_scale', default=1, type=float) 70 | parser.add_argument('--gamma_shift', default=0, type=float) 71 | parser.add_argument('--gammas_from', default=None) # Load gammas from file 72 | parser.add_argument('--beta_option', default='linear', 73 | choices=['linear', 'sigmoid', 'tanh', 'exp', 'relu', 'softplus']) 74 | parser.add_argument('--beta_scale', default=1, type=float) 75 | parser.add_argument('--beta_shift', default=0, type=float) 76 | parser.add_argument('--betas_from', default=None) # Load betas from file 77 | 78 | # If this is passed, then save all predictions to this file 79 | parser.add_argument('--output_h5', default=None) 80 | parser.add_argument('--output_preds', default=None) 81 | parser.add_argument('--output_viz_dir', default='img/') 82 | parser.add_argument('--output_program_stats_dir', default=None) 83 | 84 | grads = {} 85 | programs = {} # NOTE: Useful for zero-shot program manipulation when in debug mode 86 | 87 | def main(args): 88 | if args.debug_every <= 1: 89 | 
pdb.set_trace() 90 | model = None 91 | if args.baseline_model is not None: 92 | print('Loading baseline model from ', args.baseline_model) 93 | model, _ = utils.load_baseline(args.baseline_model) 94 | if args.vocab_json is not None: 95 | new_vocab = utils.load_vocab(args.vocab_json) 96 | model.rnn.expand_vocab(new_vocab['question_token_to_idx']) 97 | elif args.program_generator is not None and args.execution_engine is not None: 98 | pg, _ = utils.load_program_generator(args.program_generator, args.model_type) 99 | ee, _ = utils.load_execution_engine( 100 | args.execution_engine, verbose=False, model_type=args.model_type) 101 | if args.vocab_json is not None: 102 | new_vocab = utils.load_vocab(args.vocab_json) 103 | pg.expand_encoder_vocab(new_vocab['question_token_to_idx']) 104 | model = (pg, ee) 105 | else: 106 | print('Must give either --baseline_model or --program_generator and --execution_engine') 107 | return 108 | 109 | dtype = torch.FloatTensor 110 | if args.use_gpu == 1: 111 | dtype = torch.cuda.FloatTensor 112 | if args.question is not None and args.image is not None: 113 | run_single_example(args, model, dtype, args.question) 114 | # Interactive mode 115 | elif args.image is not None and args.input_question_h5 is None and args.input_features_h5 is None: 116 | feats_var = extract_image_features(args, dtype) 117 | print(colored('Ask me something!', 'cyan')) 118 | while True: 119 | # Get user question 120 | question_raw = input(">>> ") 121 | run_single_example(args, model, dtype, question_raw, feats_var) 122 | else: 123 | vocab = load_vocab(args) 124 | loader_kwargs = { 125 | 'question_h5': args.input_question_h5, 126 | 'feature_h5': args.input_features_h5, 127 | 'vocab': vocab, 128 | 'batch_size': args.batch_size, 129 | } 130 | if args.num_samples is not None and args.num_samples > 0: 131 | loader_kwargs['max_samples'] = args.num_samples 132 | if args.family_split_file is not None: 133 | with open(args.family_split_file, 'r') as f: 134 | loader_kwargs['question_families'] = json.load(f) 135 | with ClevrDataLoader(**loader_kwargs) as loader: 136 | run_batch(args, model, dtype, loader) 137 | 138 | 139 | def extract_image_features(args, dtype): 140 | # Build the CNN to use for feature extraction 141 | print('Extracting image features...') 142 | cnn = build_cnn(args, dtype) 143 | 144 | # Load and preprocess the image 145 | img_size = (args.image_height, args.image_width) 146 | img = imread(args.image, mode='RGB') 147 | img = imresize(img, img_size, interp='bicubic') 148 | img = img.transpose(2, 0, 1)[None] 149 | mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1) 150 | std = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1) 151 | img = (img.astype(np.float32) / 255.0 - mean) / std 152 | 153 | # Use CNN to extract features for the image 154 | img_var = Variable(torch.FloatTensor(img).type(dtype), volatile=False, requires_grad=True) 155 | feats_var = cnn(img_var) 156 | return feats_var 157 | 158 | 159 | def run_single_example(args, model, dtype, question_raw, feats_var=None): 160 | interactive = feats_var is not None 161 | if not interactive: 162 | feats_var = extract_image_features(args, dtype) 163 | 164 | # Tokenize the question 165 | vocab = load_vocab(args) 166 | question_tokens = tokenize(question_raw, 167 | punct_to_keep=[';', ','], 168 | punct_to_remove=['?', '.']) 169 | if args.enforce_clevr_vocab == 1: 170 | for word in question_tokens: 171 | if word not in vocab['question_token_to_idx']: 172 | print(colored('No one taught me what "%s" means :( Try me again!'
% (word), 'magenta')) 173 | return 174 | question_encoded = encode(question_tokens, 175 | vocab['question_token_to_idx'], 176 | allow_unk=True) 177 | question_encoded = torch.LongTensor(question_encoded).view(1, -1) 178 | question_encoded = question_encoded.type(dtype).long() 179 | question_var = Variable(question_encoded, volatile=False) 180 | 181 | # Run the model 182 | scores = None 183 | predicted_program = None 184 | if type(model) is tuple: 185 | pg, ee = model 186 | pg.type(dtype) 187 | pg.eval() 188 | ee.type(dtype) 189 | ee.eval() 190 | if args.model_type == 'FiLM': 191 | predicted_program = pg(question_var) 192 | else: 193 | predicted_program = pg.reinforce_sample( 194 | question_var, 195 | temperature=args.temperature, 196 | argmax=(args.sample_argmax == 1)) 197 | programs[question_raw] = predicted_program 198 | if args.debug_every <= -1: 199 | pdb.set_trace() 200 | scores = ee(feats_var, predicted_program, save_activations=True) 201 | else: 202 | model.type(dtype) 203 | scores = model(question_var, feats_var) 204 | 205 | # Print results 206 | predicted_probs = scores.data.cpu() 207 | _, predicted_answer_idx = predicted_probs[0].max(dim=0) 208 | predicted_probs = F.softmax(Variable(predicted_probs[0])).data 209 | predicted_answer = vocab['answer_idx_to_token'][predicted_answer_idx[0]] 210 | 211 | answers_to_probs = {} 212 | for i in range(len(vocab['answer_idx_to_token'])): 213 | answers_to_probs[vocab['answer_idx_to_token'][i]] = predicted_probs[i] 214 | answers_to_probs_sorted = sorted(answers_to_probs.items(), key=lambda x: x[1]) 215 | answers_to_probs_sorted.reverse() 216 | for i in range(len(answers_to_probs_sorted)): 217 | if answers_to_probs_sorted[i][1] >= 1e-3 and args.debug_every < float('inf'): 218 | print("%s: %.1f%%" % (answers_to_probs_sorted[i][0].capitalize(), 219 | 100 * answers_to_probs_sorted[i][1])) 220 | 221 | if not interactive: 222 | print(colored('Question: "%s"' % question_raw, 'cyan')) 223 | print(colored(str(predicted_answer).capitalize(), 'magenta')) 224 | 225 | if interactive: 226 | return 227 | 228 | # Visualize Gradients w.r.t. output 229 | cf_conv = ee.classifier[0](ee.cf_input) 230 | cf_bn = ee.classifier[1](cf_conv) 231 | pre_pool = ee.classifier[2](cf_bn) 232 | pooled = ee.classifier[3](pre_pool) 233 | 234 | pre_pool_max_per_c = pre_pool.max(2)[0].max(3)[0].expand_as(pre_pool) 235 | pre_pool_masked = (pre_pool_max_per_c == pre_pool).float() * pre_pool 236 | pool_feat_locs = (pre_pool_masked > 0).float().sum(1) 237 | if args.debug_every <= 1: 238 | pdb.set_trace() 239 | 240 | if args.output_viz_dir != 'NA': 241 | viz_dir = args.output_viz_dir + question_raw + ' ' + predicted_answer 242 | if not os.path.isdir(viz_dir): 243 | os.mkdir(viz_dir) 244 | args.viz_dir = viz_dir 245 | print('Saving visualizations to ' + args.viz_dir) 246 | 247 | # Backprop w.r.t. sum of output scores - What affected prediction most? 248 | ee.feats.register_hook(save_grad('stem')) 249 | for i in range(ee.num_modules): 250 | ee.module_outputs[i].register_hook(save_grad('m' + str(i))) 251 | scores_sum = scores.sum() 252 | scores_sum.backward() 253 | 254 | # Visualizations! 
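# What the visualize() calls below are expected to save (based on the visualize()
# definition further down in this file): each feature or gradient tensor is reduced to a
# single spatial map (roughly its per-location norm across the batch and channel
# dimensions), rescaled to [0, 1], upsampled to the input image size, and written under
# args.viz_dir as the alpha channel of an RGBA copy of the input image: one overlay for
# the ResNet features, the conv stem and its gradient, every ResBlock output and its
# gradient, the pre-pool classifier features, and the max-pool "winner" locations.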
255 | visualize(feats_var, args, 'resnet101') 256 | visualize(ee.feats, args, 'conv-stem') 257 | visualize(grads['stem'], args, 'grad-conv-stem') 258 | for i in range(ee.num_modules): 259 | visualize(ee.module_outputs[i], args, 'resblock' + str(i)) 260 | visualize(grads['m' + str(i)], args, 'grad-resblock' + str(i)) 261 | visualize(pre_pool, args, 'pre-pool') 262 | visualize(pool_feat_locs, args, 'pool-feature-locations') 263 | 264 | if (predicted_program is not None) and (args.model_type != 'FiLM'): 265 | print() 266 | print('Predicted program:') 267 | program = predicted_program.data.cpu()[0] 268 | num_inputs = 1 269 | for fn_idx in program: 270 | fn_str = vocab['program_idx_to_token'][fn_idx] 271 | num_inputs += vr.programs.get_num_inputs(fn_str) - 1 272 | print(fn_str) 273 | if num_inputs == 0: 274 | break 275 | 276 | 277 | def run_our_model_batch(args, pg, ee, loader, dtype): 278 | pg.type(dtype) 279 | pg.eval() 280 | ee.type(dtype) 281 | ee.eval() 282 | 283 | all_scores, all_programs = [], [] 284 | all_probs = [] 285 | all_preds = [] 286 | num_correct, num_samples = 0, 0 287 | 288 | loaded_gammas = None 289 | loaded_betas = None 290 | if args.gammas_from: 291 | print('Loading ') 292 | loaded_gammas = torch.load(args.gammas_from) 293 | if args.betas_from: 294 | print('Betas loaded!') 295 | loaded_betas = torch.load(args.betas_from) 296 | 297 | q_types = [] 298 | film_params = [] 299 | 300 | if args.num_last_words_shuffled == -1: 301 | print('All words of each question shuffled.') 302 | elif args.num_last_words_shuffled > 0: 303 | print('Last %d words of each question shuffled.' % args.num_last_words_shuffled) 304 | start = time.time() 305 | for batch in tqdm(loader): 306 | assert(not pg.training) 307 | assert(not ee.training) 308 | questions, images, feats, answers, programs, program_lists = batch 309 | 310 | if args.num_last_words_shuffled != 0: 311 | for i, question in enumerate(questions): 312 | # Search for token to find question length 313 | q_end = get_index(question.numpy().tolist(), index=2, default=len(question)) 314 | if args.num_last_words_shuffled > 0: 315 | q_end -= args.num_last_words_shuffled # Leave last few words unshuffled 316 | if q_end < 2: 317 | q_end = 2 318 | question = question[1:q_end] 319 | random.shuffle(question) 320 | questions[i][1:q_end] = question 321 | 322 | if isinstance(questions, list): 323 | questions_var = Variable(questions[0].type(dtype).long(), volatile=True) 324 | q_types += [questions[1].cpu().numpy()] 325 | else: 326 | questions_var = Variable(questions.type(dtype).long(), volatile=True) 327 | feats_var = Variable(feats.type(dtype), volatile=True) 328 | if args.model_type == 'FiLM': 329 | programs_pred = pg(questions_var) 330 | # Examine effect of various conditioning modifications at test time! 
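# A rough sketch of the assumed behaviour of pg.modify_output below (hedged: the method
# itself is not defined in this script). With D = pg.module_dim, the generator output
# packs gammas in [:, :, :D] and betas in [:, :, D:2*D] (the same layout the
# --gammas_from / --betas_from overrides below rely on), and the flags are expected to
# act roughly as
#   gammas = gamma_scale * f(gammas, gamma_option) + gamma_shift
#   betas  = beta_scale  * f(betas,  beta_option)  + beta_shift
# where f applies the chosen nonlinearity (linear, sigmoid, tanh, exp, relu, softplus),
# making it easy to probe how sensitive the FiLMed network is to perturbed conditioning.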
331 | programs_pred = pg.modify_output(programs_pred, gamma_option=args.gamma_option, 332 | gamma_scale=args.gamma_scale, gamma_shift=args.gamma_shift, 333 | beta_option=args.beta_option, beta_scale=args.beta_scale, 334 | beta_shift=args.beta_shift) 335 | if args.gammas_from: 336 | programs_pred[:,:,:pg.module_dim] = loaded_gammas.expand_as( 337 | programs_pred[:,:,:pg.module_dim]) 338 | if args.betas_from: 339 | programs_pred[:,:,pg.module_dim:2*pg.module_dim] = loaded_betas.expand_as( 340 | programs_pred[:,:,pg.module_dim:2*pg.module_dim]) 341 | else: 342 | programs_pred = pg.reinforce_sample( 343 | questions_var, 344 | temperature=args.temperature, 345 | argmax=(args.sample_argmax == 1)) 346 | 347 | film_params += [programs_pred.cpu().data.numpy()] 348 | scores = ee(feats_var, programs_pred, save_activations=True) 349 | probs = F.softmax(scores) 350 | 351 | _, preds = scores.data.cpu().max(1) 352 | all_programs.append(programs_pred.data.cpu().clone()) 353 | all_scores.append(scores.data.cpu().clone()) 354 | all_probs.append(probs.data.cpu().clone()) 355 | all_preds.append(preds.cpu().clone()) 356 | if answers[0] is not None: 357 | num_correct += (preds == answers).sum() 358 | num_samples += preds.size(0) 359 | 360 | acc = float(num_correct) / num_samples 361 | print('Got %d / %d = %.2f correct' % (num_correct, num_samples, 100 * acc)) 362 | print('%.2fs to evaluate' % (time.time() - start)) 363 | all_programs = torch.cat(all_programs, 0) 364 | all_scores = torch.cat(all_scores, 0) 365 | all_probs = torch.cat(all_probs, 0) 366 | all_preds = torch.cat(all_preds, 0).squeeze().numpy() 367 | if args.output_h5 is not None: 368 | print('Writing output to "%s"' % args.output_h5) 369 | with h5py.File(args.output_h5, 'w') as fout: 370 | fout.create_dataset('scores', data=all_scores.numpy()) 371 | fout.create_dataset('probs', data=all_probs.numpy()) 372 | fout.create_dataset('predicted_programs', data=all_programs.numpy()) 373 | 374 | # Save FiLM params 375 | np.save('film_params', np.vstack(film_params)) 376 | if isinstance(questions, list): 377 | np.save('q_types', np.vstack(q_types)) 378 | 379 | # Save FiLM param stats 380 | if args.output_program_stats_dir: 381 | if not os.path.isdir(args.output_program_stats_dir): 382 | os.mkdir(args.output_program_stats_dir) 383 | gammas = all_programs[:,:,:pg.module_dim] 384 | betas = all_programs[:,:,pg.module_dim:2*pg.module_dim] 385 | gamma_means = gammas.mean(0) 386 | torch.save(gamma_means, os.path.join(args.output_program_stats_dir, 'gamma_means')) 387 | beta_means = betas.mean(0) 388 | torch.save(beta_means, os.path.join(args.output_program_stats_dir, 'beta_means')) 389 | gamma_medians = gammas.median(0)[0] 390 | torch.save(gamma_medians, os.path.join(args.output_program_stats_dir, 'gamma_medians')) 391 | beta_medians = betas.median(0)[0] 392 | torch.save(beta_medians, os.path.join(args.output_program_stats_dir, 'beta_medians')) 393 | 394 | # Note: Takes O(10GB) space 395 | torch.save(gammas, os.path.join(args.output_program_stats_dir, 'gammas')) 396 | torch.save(betas, os.path.join(args.output_program_stats_dir, 'betas')) 397 | 398 | if args.output_preds is not None: 399 | vocab = load_vocab(args) 400 | all_preds_strings = [] 401 | for i in range(len(all_preds)): 402 | all_preds_strings.append(vocab['answer_idx_to_token'][all_preds[i]]) 403 | save_to_file(all_preds_strings, args.output_preds) 404 | 405 | if args.debug_every <= 1: 406 | pdb.set_trace() 407 | return 408 | 409 | 410 | def visualize(features, args, file_name=None): 411 | """ 412 |
Converts a 4d map of features to alpha attention weights, 413 | According to their 2-Norm across dimensions 0 and 1. 414 | Then saves the input RGB image as an RGBA image using an upsampling of this attention map. 415 | """ 416 | save_file = os.path.join(args.viz_dir, file_name) 417 | img_path = args.image 418 | 419 | # Scale map to [0, 1] 420 | f_map = (features ** 2).mean(0).mean(1).squeeze().sqrt() 421 | f_map_shifted = f_map - f_map.min().expand_as(f_map) 422 | f_map_scaled = f_map_shifted / f_map_shifted.max().expand_as(f_map_shifted) 423 | 424 | if save_file is None: 425 | print(f_map_scaled) 426 | else: 427 | # Read original image 428 | img = imread(img_path, mode='RGB') 429 | orig_img_size = img.shape 430 | 431 | # Convert to image format 432 | alpha = (255 * f_map_scaled).round() 433 | alpha4d = alpha.unsqueeze(0).unsqueeze(0) 434 | alpha_upsampled = torch.nn.functional.upsample_bilinear( 435 | alpha4d, size=torch.Size(orig_img_size)).squeeze(0).transpose(1, 0).transpose(1, 2) 436 | alpha_upsampled_np = alpha_upsampled.cpu().data.numpy() 437 | 438 | # Create and save visualization 439 | imga = np.concatenate([img, alpha_upsampled_np], axis=2) 440 | if save_file[-4:] != '.png': save_file += '.png' 441 | imsave(save_file, imga) 442 | 443 | return f_map_scaled 444 | 445 | 446 | def build_cnn(args, dtype): 447 | if not hasattr(torchvision.models, args.cnn_model): 448 | raise ValueError('Invalid model "%s"' % args.cnn_model) 449 | if not 'resnet' in args.cnn_model: 450 | raise ValueError('Feature extraction only supports ResNets') 451 | whole_cnn = getattr(torchvision.models, args.cnn_model)(pretrained=True) 452 | layers = [ 453 | whole_cnn.conv1, 454 | whole_cnn.bn1, 455 | whole_cnn.relu, 456 | whole_cnn.maxpool, 457 | ] 458 | for i in range(args.cnn_model_stage): 459 | name = 'layer%d' % (i + 1) 460 | layers.append(getattr(whole_cnn, name)) 461 | cnn = torch.nn.Sequential(*layers) 462 | cnn.type(dtype) 463 | cnn.eval() 464 | return cnn 465 | 466 | 467 | def run_batch(args, model, dtype, loader): 468 | if type(model) is tuple: 469 | pg, ee = model 470 | run_our_model_batch(args, pg, ee, loader, dtype) 471 | else: 472 | run_baseline_batch(args, model, loader, dtype) 473 | 474 | 475 | def run_baseline_batch(args, model, loader, dtype): 476 | model.type(dtype) 477 | model.eval() 478 | 479 | all_scores, all_probs = [], [] 480 | num_correct, num_samples = 0, 0 481 | for batch in loader: 482 | questions, images, feats, answers, programs, program_lists = batch 483 | 484 | questions_var = Variable(questions.type(dtype).long(), volatile=True) 485 | feats_var = Variable(feats.type(dtype), volatile=True) 486 | scores = model(questions_var, feats_var) 487 | probs = F.softmax(scores) 488 | 489 | _, preds = scores.data.cpu().max(1) 490 | all_scores.append(scores.data.cpu().clone()) 491 | all_probs.append(probs.data.cpu().clone()) 492 | 493 | num_correct += (preds == answers).sum() 494 | num_samples += preds.size(0) 495 | print('Ran %d samples' % num_samples) 496 | 497 | acc = float(num_correct) / num_samples 498 | print('Got %d / %d = %.2f correct' % (num_correct, num_samples, 100 * acc)) 499 | 500 | all_scores = torch.cat(all_scores, 0) 501 | all_probs = torch.cat(all_probs, 0) 502 | if args.output_h5 is not None: 503 | print('Writing output to %s' % args.output_h5) 504 | with h5py.File(args.output_h5, 'w') as fout: 505 | fout.create_dataset('scores', data=all_scores.numpy()) 506 | fout.create_dataset('probs', data=all_probs.numpy()) 507 | 508 | 509 | def load_vocab(args): 510 | path = None 511 | 
if args.baseline_model is not None: 512 | path = args.baseline_model 513 | elif args.program_generator is not None: 514 | path = args.program_generator 515 | elif args.execution_engine is not None: 516 | path = args.execution_engine 517 | return utils.load_cpu(path)['vocab'] 518 | 519 | 520 | def save_grad(name): 521 | def hook(grad): 522 | grads[name] = grad 523 | return hook 524 | 525 | 526 | def save_to_file(text, filename): 527 | with open(filename, mode='wt', encoding='utf-8') as myfile: 528 | myfile.write('\n'.join(text)) 529 | myfile.write('\n') 530 | 531 | 532 | def get_index(l, index, default=-1): 533 | try: 534 | return l.index(index) 535 | except ValueError: 536 | return default 537 | 538 | 539 | if __name__ == '__main__': 540 | args = parser.parse_args() 541 | main(args) 542 | -------------------------------------------------------------------------------- /scripts/train_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2017-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import sys 10 | import os 11 | sys.path.insert(0, os.path.abspath('.')) 12 | 13 | import argparse 14 | import ipdb as pdb 15 | import json 16 | import random 17 | import shutil 18 | from termcolor import colored 19 | import time 20 | 21 | import torch 22 | torch.backends.cudnn.enabled = True 23 | from torch.autograd import Variable 24 | import torch.nn.functional as F 25 | import numpy as np 26 | import h5py 27 | 28 | import vr.utils as utils 29 | import vr.preprocess 30 | from vr.data import ClevrDataset, ClevrDataLoader 31 | from vr.models import ModuleNet, Seq2Seq, LstmModel, CnnLstmModel, CnnLstmSaModel 32 | from vr.models import FiLMedNet 33 | from vr.models import FiLMGen 34 | 35 | parser = argparse.ArgumentParser() 36 | 37 | # Input data 38 | parser.add_argument('--train_question_h5', default='data/train_questions.h5') 39 | parser.add_argument('--train_features_h5', default='data/train_features.h5') 40 | parser.add_argument('--val_question_h5', default='data/val_questions.h5') 41 | parser.add_argument('--val_features_h5', default='data/val_features.h5') 42 | parser.add_argument('--feature_dim', default='1024,14,14') 43 | parser.add_argument('--vocab_json', default='data/vocab.json') 44 | 45 | parser.add_argument('--loader_num_workers', type=int, default=1) 46 | parser.add_argument('--use_local_copies', default=0, type=int) 47 | parser.add_argument('--cleanup_local_copies', default=1, type=int) 48 | 49 | parser.add_argument('--family_split_file', default=None) 50 | parser.add_argument('--num_train_samples', default=None, type=int) 51 | parser.add_argument('--num_val_samples', default=None, type=int) 52 | parser.add_argument('--shuffle_train_data', default=1, type=int) 53 | 54 | # What type of model to use and which parts to train 55 | parser.add_argument('--model_type', default='PG', 56 | choices=['FiLM', 'PG', 'EE', 'PG+EE', 'LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']) 57 | parser.add_argument('--train_program_generator', default=1, type=int) 58 | parser.add_argument('--train_execution_engine', default=1, type=int) 59 | parser.add_argument('--baseline_train_only_rnn', default=0, type=int) 60 | 61 | # Start from an existing checkpoint 62 | parser.add_argument('--program_generator_start_from', default=None) 63 | parser.add_argument('--execution_engine_start_from', default=None) 64 | 
parser.add_argument('--baseline_start_from', default=None) 65 | 66 | # RNN options 67 | parser.add_argument('--rnn_wordvec_dim', default=300, type=int) 68 | parser.add_argument('--rnn_hidden_dim', default=256, type=int) 69 | parser.add_argument('--rnn_num_layers', default=2, type=int) 70 | parser.add_argument('--rnn_dropout', default=0, type=float) 71 | 72 | # Module net / FiLMedNet options 73 | parser.add_argument('--module_stem_num_layers', default=2, type=int) 74 | parser.add_argument('--module_stem_batchnorm', default=0, type=int) 75 | parser.add_argument('--module_dim', default=128, type=int) 76 | parser.add_argument('--module_residual', default=1, type=int) 77 | parser.add_argument('--module_batchnorm', default=0, type=int) 78 | 79 | # FiLM only options 80 | parser.add_argument('--set_execution_engine_eval', default=0, type=int) 81 | parser.add_argument('--program_generator_parameter_efficient', default=1, type=int) 82 | parser.add_argument('--rnn_output_batchnorm', default=0, type=int) 83 | parser.add_argument('--bidirectional', default=0, type=int) 84 | parser.add_argument('--encoder_type', default='gru', type=str, 85 | choices=['linear', 'gru', 'lstm']) 86 | parser.add_argument('--decoder_type', default='linear', type=str, 87 | choices=['linear', 'gru', 'lstm']) 88 | parser.add_argument('--gamma_option', default='linear', 89 | choices=['linear', 'sigmoid', 'tanh', 'exp']) 90 | parser.add_argument('--gamma_baseline', default=1, type=float) 91 | parser.add_argument('--num_modules', default=4, type=int) 92 | parser.add_argument('--module_stem_kernel_size', default=3, type=int) 93 | parser.add_argument('--module_stem_stride', default=1, type=int) 94 | parser.add_argument('--module_stem_padding', default=None, type=int) 95 | parser.add_argument('--module_num_layers', default=1, type=int) # Only mnl=1 currently implemented 96 | parser.add_argument('--module_batchnorm_affine', default=0, type=int) # 1 overrides other factors 97 | parser.add_argument('--module_dropout', default=5e-2, type=float) 98 | parser.add_argument('--module_input_proj', default=1, type=int) # Inp conv kernel size (0 for None) 99 | parser.add_argument('--module_kernel_size', default=3, type=int) 100 | parser.add_argument('--condition_method', default='bn-film', type=str, 101 | choices=['block-input-film', 'block-output-film', 'bn-film', 'concat', 'conv-film', 'relu-film']) 102 | parser.add_argument('--condition_pattern', default='', type=str) # List of 0/1's (len = # FiLMs) 103 | parser.add_argument('--use_gamma', default=1, type=int) 104 | parser.add_argument('--use_beta', default=1, type=int) 105 | parser.add_argument('--use_coords', default=1, type=int) # 0: none, 1: low usage, 2: high usage 106 | parser.add_argument('--grad_clip', default=0, type=float) # <= 0 for no grad clipping 107 | parser.add_argument('--debug_every', default=float('inf'), type=float) # inf for no pdb 108 | parser.add_argument('--print_verbose_every', default=float('inf'), type=float) # inf for min print 109 | 110 | # CNN options (for baselines) 111 | parser.add_argument('--cnn_res_block_dim', default=128, type=int) 112 | parser.add_argument('--cnn_num_res_blocks', default=0, type=int) 113 | parser.add_argument('--cnn_proj_dim', default=512, type=int) 114 | parser.add_argument('--cnn_pooling', default='maxpool2', 115 | choices=['none', 'maxpool2']) 116 | 117 | # Stacked-Attention options 118 | parser.add_argument('--stacked_attn_dim', default=512, type=int) 119 | parser.add_argument('--num_stacked_attn', default=2, type=int) 120 | 121 | # 
Classifier options 122 | parser.add_argument('--classifier_proj_dim', default=512, type=int) 123 | parser.add_argument('--classifier_downsample', default='maxpool2', 124 | choices=['maxpool2', 'maxpool3', 'maxpool4', 'maxpool5', 'maxpool7', 'maxpoolfull', 'none', 125 | 'avgpool2', 'avgpool3', 'avgpool4', 'avgpool5', 'avgpool7', 'avgpoolfull', 'aggressive']) 126 | parser.add_argument('--classifier_fc_dims', default='1024') 127 | parser.add_argument('--classifier_batchnorm', default=0, type=int) 128 | parser.add_argument('--classifier_dropout', default=0, type=float) 129 | 130 | # Optimization options 131 | parser.add_argument('--batch_size', default=64, type=int) 132 | parser.add_argument('--num_iterations', default=100000, type=int) 133 | parser.add_argument('--optimizer', default='Adam', 134 | choices=['Adadelta', 'Adagrad', 'Adam', 'Adamax', 'ASGD', 'RMSprop', 'SGD']) 135 | parser.add_argument('--learning_rate', default=5e-4, type=float) 136 | parser.add_argument('--reward_decay', default=0.9, type=float) 137 | parser.add_argument('--weight_decay', default=0, type=float) 138 | 139 | # Output options 140 | parser.add_argument('--checkpoint_path', default='data/checkpoint.pt') 141 | parser.add_argument('--randomize_checkpoint_path', type=int, default=0) 142 | parser.add_argument('--avoid_checkpoint_override', default=0, type=int) 143 | parser.add_argument('--record_loss_every', default=1, type=int) 144 | parser.add_argument('--checkpoint_every', default=10000, type=int) 145 | parser.add_argument('--time', default=0, type=int) 146 | 147 | 148 | def main(args): 149 | if args.randomize_checkpoint_path == 1: 150 | name, ext = os.path.splitext(args.checkpoint_path) 151 | num = random.randint(1, 1000000) 152 | args.checkpoint_path = '%s_%06d%s' % (name, num, ext) 153 | print('Will save checkpoints to %s' % args.checkpoint_path) 154 | 155 | vocab = utils.load_vocab(args.vocab_json) 156 | 157 | if args.use_local_copies == 1: 158 | shutil.copy(args.train_question_h5, '/tmp/train_questions.h5') 159 | shutil.copy(args.train_features_h5, '/tmp/train_features.h5') 160 | shutil.copy(args.val_question_h5, '/tmp/val_questions.h5') 161 | shutil.copy(args.val_features_h5, '/tmp/val_features.h5') 162 | args.train_question_h5 = '/tmp/train_questions.h5' 163 | args.train_features_h5 = '/tmp/train_features.h5' 164 | args.val_question_h5 = '/tmp/val_questions.h5' 165 | args.val_features_h5 = '/tmp/val_features.h5' 166 | 167 | question_families = None 168 | if args.family_split_file is not None: 169 | with open(args.family_split_file, 'r') as f: 170 | question_families = json.load(f) 171 | 172 | train_loader_kwargs = { 173 | 'question_h5': args.train_question_h5, 174 | 'feature_h5': args.train_features_h5, 175 | 'vocab': vocab, 176 | 'batch_size': args.batch_size, 177 | 'shuffle': args.shuffle_train_data == 1, 178 | 'question_families': question_families, 179 | 'max_samples': args.num_train_samples, 180 | 'num_workers': args.loader_num_workers, 181 | } 182 | val_loader_kwargs = { 183 | 'question_h5': args.val_question_h5, 184 | 'feature_h5': args.val_features_h5, 185 | 'vocab': vocab, 186 | 'batch_size': args.batch_size, 187 | 'question_families': question_families, 188 | 'max_samples': args.num_val_samples, 189 | 'num_workers': args.loader_num_workers, 190 | } 191 | 192 | with ClevrDataLoader(**train_loader_kwargs) as train_loader, \ 193 | ClevrDataLoader(**val_loader_kwargs) as val_loader: 194 | train_loop(args, train_loader, val_loader) 195 | 196 | if args.use_local_copies == 1 and args.cleanup_local_copies == 
1: 197 | os.remove('/tmp/train_questions.h5') 198 | os.remove('/tmp/train_features.h5') 199 | os.remove('/tmp/val_questions.h5') 200 | os.remove('/tmp/val_features.h5') 201 | 202 | 203 | def train_loop(args, train_loader, val_loader): 204 | vocab = utils.load_vocab(args.vocab_json) 205 | program_generator, pg_kwargs, pg_optimizer = None, None, None 206 | execution_engine, ee_kwargs, ee_optimizer = None, None, None 207 | baseline_model, baseline_kwargs, baseline_optimizer = None, None, None 208 | baseline_type = None 209 | 210 | pg_best_state, ee_best_state, baseline_best_state = None, None, None 211 | 212 | # Set up model 213 | optim_method = getattr(torch.optim, args.optimizer) 214 | if args.model_type in ['FiLM', 'PG', 'PG+EE']: 215 | program_generator, pg_kwargs = get_program_generator(args) 216 | pg_optimizer = optim_method(program_generator.parameters(), 217 | lr=args.learning_rate, 218 | weight_decay=args.weight_decay) 219 | print('Here is the conditioning network:') 220 | print(program_generator) 221 | if args.model_type in ['FiLM', 'EE', 'PG+EE']: 222 | execution_engine, ee_kwargs = get_execution_engine(args) 223 | ee_optimizer = optim_method(execution_engine.parameters(), 224 | lr=args.learning_rate, 225 | weight_decay=args.weight_decay) 226 | print('Here is the conditioned network:') 227 | print(execution_engine) 228 | if args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']: 229 | baseline_model, baseline_kwargs = get_baseline_model(args) 230 | params = baseline_model.parameters() 231 | if args.baseline_train_only_rnn == 1: 232 | params = baseline_model.rnn.parameters() 233 | baseline_optimizer = optim_method(params, 234 | lr=args.learning_rate, 235 | weight_decay=args.weight_decay) 236 | print('Here is the baseline model') 237 | print(baseline_model) 238 | baseline_type = args.model_type 239 | loss_fn = torch.nn.CrossEntropyLoss().cuda() 240 | 241 | stats = { 242 | 'train_losses': [], 'train_rewards': [], 'train_losses_ts': [], 243 | 'train_accs': [], 'val_accs': [], 'val_accs_ts': [], 244 | 'best_val_acc': -1, 'model_t': 0, 245 | } 246 | t, epoch, reward_moving_average = 0, 0, 0 247 | 248 | set_mode('train', [program_generator, execution_engine, baseline_model]) 249 | 250 | print('train_loader has %d samples' % len(train_loader.dataset)) 251 | print('val_loader has %d samples' % len(val_loader.dataset)) 252 | 253 | num_checkpoints = 0 254 | epoch_start_time = 0.0 255 | epoch_total_time = 0.0 256 | train_pass_total_time = 0.0 257 | val_pass_total_time = 0.0 258 | running_loss = 0.0 259 | while t < args.num_iterations: 260 | if (epoch > 0) and (args.time == 1): 261 | epoch_time = time.time() - epoch_start_time 262 | epoch_total_time += epoch_time 263 | print(colored('EPOCH PASS AVG TIME: ' + str(epoch_total_time / epoch), 'white')) 264 | print(colored('Epoch Pass Time : ' + str(epoch_time), 'white')) 265 | epoch_start_time = time.time() 266 | 267 | epoch += 1 268 | print('Starting epoch %d' % epoch) 269 | for batch in train_loader: 270 | t += 1 271 | questions, _, feats, answers, programs, _ = batch 272 | if isinstance(questions, list): 273 | questions = questions[0] 274 | questions_var = Variable(questions.cuda()) 275 | feats_var = Variable(feats.cuda()) 276 | answers_var = Variable(answers.cuda()) 277 | if programs[0] is not None: 278 | programs_var = Variable(programs.cuda()) 279 | 280 | reward = None 281 | if args.model_type == 'PG': 282 | # Train program generator with ground-truth programs 283 | pg_optimizer.zero_grad() 284 | loss = program_generator(questions_var, 
programs_var) 285 | loss.backward() 286 | pg_optimizer.step() 287 | elif args.model_type == 'EE': 288 | # Train execution engine with ground-truth programs 289 | ee_optimizer.zero_grad() 290 | scores = execution_engine(feats_var, programs_var) 291 | loss = loss_fn(scores, answers_var) 292 | loss.backward() 293 | ee_optimizer.step() 294 | elif args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']: 295 | baseline_optimizer.zero_grad() 296 | baseline_model.zero_grad() 297 | scores = baseline_model(questions_var, feats_var) 298 | loss = loss_fn(scores, answers_var) 299 | loss.backward() 300 | baseline_optimizer.step() 301 | elif args.model_type == 'PG+EE': 302 | programs_pred = program_generator.reinforce_sample(questions_var) 303 | scores = execution_engine(feats_var, programs_pred) 304 | 305 | loss = loss_fn(scores, answers_var) 306 | _, preds = scores.data.cpu().max(1) 307 | raw_reward = (preds == answers).float() 308 | reward_moving_average *= args.reward_decay 309 | reward_moving_average += (1.0 - args.reward_decay) * raw_reward.mean() 310 | centered_reward = raw_reward - reward_moving_average 311 | 312 | if args.train_execution_engine == 1: 313 | ee_optimizer.zero_grad() 314 | loss.backward() 315 | ee_optimizer.step() 316 | 317 | if args.train_program_generator == 1: 318 | pg_optimizer.zero_grad() 319 | program_generator.reinforce_backward(centered_reward.cuda()) 320 | pg_optimizer.step() 321 | elif args.model_type == 'FiLM': 322 | if args.set_execution_engine_eval == 1: 323 | set_mode('eval', [execution_engine]) 324 | programs_pred = program_generator(questions_var) 325 | scores = execution_engine(feats_var, programs_pred) 326 | loss = loss_fn(scores, answers_var) 327 | 328 | pg_optimizer.zero_grad() 329 | ee_optimizer.zero_grad() 330 | if args.debug_every <= -2: 331 | pdb.set_trace() 332 | loss.backward() 333 | if args.debug_every < float('inf'): 334 | check_grad_num_nans(execution_engine, 'FiLMedNet') 335 | check_grad_num_nans(program_generator, 'FiLMGen') 336 | 337 | if args.train_program_generator == 1: 338 | if args.grad_clip > 0: 339 | torch.nn.utils.clip_grad_norm(program_generator.parameters(), args.grad_clip) 340 | pg_optimizer.step() 341 | if args.train_execution_engine == 1: 342 | if args.grad_clip > 0: 343 | torch.nn.utils.clip_grad_norm(execution_engine.parameters(), args.grad_clip) 344 | ee_optimizer.step() 345 | 346 | if t % args.record_loss_every == 0: 347 | running_loss += loss.data[0] 348 | avg_loss = running_loss / args.record_loss_every 349 | print(t, avg_loss) 350 | stats['train_losses'].append(avg_loss) 351 | stats['train_losses_ts'].append(t) 352 | if reward is not None: 353 | stats['train_rewards'].append(reward) 354 | running_loss = 0.0 355 | else: 356 | running_loss += loss.data[0] 357 | 358 | if t % args.checkpoint_every == 0: 359 | num_checkpoints += 1 360 | print('Checking training accuracy ... 
') 361 | start = time.time() 362 | train_acc = check_accuracy(args, program_generator, execution_engine, 363 | baseline_model, train_loader) 364 | if args.time == 1: 365 | train_pass_time = (time.time() - start) 366 | train_pass_total_time += train_pass_time 367 | print(colored('TRAIN PASS AVG TIME: ' + str(train_pass_total_time / num_checkpoints), 'red')) 368 | print(colored('Train Pass Time : ' + str(train_pass_time), 'red')) 369 | print('train accuracy is', train_acc) 370 | print('Checking validation accuracy ...') 371 | start = time.time() 372 | val_acc = check_accuracy(args, program_generator, execution_engine, 373 | baseline_model, val_loader) 374 | if args.time == 1: 375 | val_pass_time = (time.time() - start) 376 | val_pass_total_time += val_pass_time 377 | print(colored('VAL PASS AVG TIME: ' + str(val_pass_total_time / num_checkpoints), 'cyan')) 378 | print(colored('Val Pass Time : ' + str(val_pass_time), 'cyan')) 379 | print('val accuracy is ', val_acc) 380 | stats['train_accs'].append(train_acc) 381 | stats['val_accs'].append(val_acc) 382 | stats['val_accs_ts'].append(t) 383 | 384 | if val_acc > stats['best_val_acc']: 385 | stats['best_val_acc'] = val_acc 386 | stats['model_t'] = t 387 | best_pg_state = get_state(program_generator) 388 | best_ee_state = get_state(execution_engine) 389 | best_baseline_state = get_state(baseline_model) 390 | 391 | checkpoint = { 392 | 'args': args.__dict__, 393 | 'program_generator_kwargs': pg_kwargs, 394 | 'program_generator_state': best_pg_state, 395 | 'execution_engine_kwargs': ee_kwargs, 396 | 'execution_engine_state': best_ee_state, 397 | 'baseline_kwargs': baseline_kwargs, 398 | 'baseline_state': best_baseline_state, 399 | 'baseline_type': baseline_type, 400 | 'vocab': vocab 401 | } 402 | for k, v in stats.items(): 403 | checkpoint[k] = v 404 | print('Saving checkpoint to %s' % args.checkpoint_path) 405 | torch.save(checkpoint, args.checkpoint_path) 406 | del checkpoint['program_generator_state'] 407 | del checkpoint['execution_engine_state'] 408 | del checkpoint['baseline_state'] 409 | with open(args.checkpoint_path + '.json', 'w') as f: 410 | json.dump(checkpoint, f) 411 | 412 | if t == args.num_iterations: 413 | break 414 | 415 | 416 | def parse_int_list(s): 417 | if s == '': return () 418 | return tuple(int(n) for n in s.split(',')) 419 | 420 | 421 | def get_state(m): 422 | if m is None: 423 | return None 424 | state = {} 425 | for k, v in m.state_dict().items(): 426 | state[k] = v.clone() 427 | return state 428 | 429 | 430 | def get_program_generator(args): 431 | vocab = utils.load_vocab(args.vocab_json) 432 | if args.program_generator_start_from is not None: 433 | pg, kwargs = utils.load_program_generator( 434 | args.program_generator_start_from, model_type=args.model_type) 435 | cur_vocab_size = pg.encoder_embed.weight.size(0) 436 | if cur_vocab_size != len(vocab['question_token_to_idx']): 437 | print('Expanding vocabulary of program generator') 438 | pg.expand_encoder_vocab(vocab['question_token_to_idx']) 439 | kwargs['encoder_vocab_size'] = len(vocab['question_token_to_idx']) 440 | else: 441 | kwargs = { 442 | 'encoder_vocab_size': len(vocab['question_token_to_idx']), 443 | 'decoder_vocab_size': len(vocab['program_token_to_idx']), 444 | 'wordvec_dim': args.rnn_wordvec_dim, 445 | 'hidden_dim': args.rnn_hidden_dim, 446 | 'rnn_num_layers': args.rnn_num_layers, 447 | 'rnn_dropout': args.rnn_dropout, 448 | } 449 | if args.model_type == 'FiLM': 450 | kwargs['parameter_efficient'] = args.program_generator_parameter_efficient == 1 451 | 
kwargs['output_batchnorm'] = args.rnn_output_batchnorm == 1 452 | kwargs['bidirectional'] = args.bidirectional == 1 453 | kwargs['encoder_type'] = args.encoder_type 454 | kwargs['decoder_type'] = args.decoder_type 455 | kwargs['gamma_option'] = args.gamma_option 456 | kwargs['gamma_baseline'] = args.gamma_baseline 457 | kwargs['num_modules'] = args.num_modules 458 | kwargs['module_num_layers'] = args.module_num_layers 459 | kwargs['module_dim'] = args.module_dim 460 | kwargs['debug_every'] = args.debug_every 461 | pg = FiLMGen(**kwargs) 462 | else: 463 | pg = Seq2Seq(**kwargs) 464 | pg.cuda() 465 | pg.train() 466 | return pg, kwargs 467 | 468 | 469 | def get_execution_engine(args): 470 | vocab = utils.load_vocab(args.vocab_json) 471 | if args.execution_engine_start_from is not None: 472 | ee, kwargs = utils.load_execution_engine( 473 | args.execution_engine_start_from, model_type=args.model_type) 474 | else: 475 | kwargs = { 476 | 'vocab': vocab, 477 | 'feature_dim': parse_int_list(args.feature_dim), 478 | 'stem_batchnorm': args.module_stem_batchnorm == 1, 479 | 'stem_num_layers': args.module_stem_num_layers, 480 | 'module_dim': args.module_dim, 481 | 'module_residual': args.module_residual == 1, 482 | 'module_batchnorm': args.module_batchnorm == 1, 483 | 'classifier_proj_dim': args.classifier_proj_dim, 484 | 'classifier_downsample': args.classifier_downsample, 485 | 'classifier_fc_layers': parse_int_list(args.classifier_fc_dims), 486 | 'classifier_batchnorm': args.classifier_batchnorm == 1, 487 | 'classifier_dropout': args.classifier_dropout, 488 | } 489 | if args.model_type == 'FiLM': 490 | kwargs['num_modules'] = args.num_modules 491 | kwargs['stem_kernel_size'] = args.module_stem_kernel_size 492 | kwargs['stem_stride'] = args.module_stem_stride 493 | kwargs['stem_padding'] = args.module_stem_padding 494 | kwargs['module_num_layers'] = args.module_num_layers 495 | kwargs['module_batchnorm_affine'] = args.module_batchnorm_affine == 1 496 | kwargs['module_dropout'] = args.module_dropout 497 | kwargs['module_input_proj'] = args.module_input_proj 498 | kwargs['module_kernel_size'] = args.module_kernel_size 499 | kwargs['use_gamma'] = args.use_gamma == 1 500 | kwargs['use_beta'] = args.use_beta == 1 501 | kwargs['use_coords'] = args.use_coords 502 | kwargs['debug_every'] = args.debug_every 503 | kwargs['print_verbose_every'] = args.print_verbose_every 504 | kwargs['condition_method'] = args.condition_method 505 | kwargs['condition_pattern'] = parse_int_list(args.condition_pattern) 506 | ee = FiLMedNet(**kwargs) 507 | else: 508 | ee = ModuleNet(**kwargs) 509 | ee.cuda() 510 | ee.train() 511 | return ee, kwargs 512 | 513 | 514 | def get_baseline_model(args): 515 | vocab = utils.load_vocab(args.vocab_json) 516 | if args.baseline_start_from is not None: 517 | model, kwargs = utils.load_baseline(args.baseline_start_from) 518 | elif args.model_type == 'LSTM': 519 | kwargs = { 520 | 'vocab': vocab, 521 | 'rnn_wordvec_dim': args.rnn_wordvec_dim, 522 | 'rnn_dim': args.rnn_hidden_dim, 523 | 'rnn_num_layers': args.rnn_num_layers, 524 | 'rnn_dropout': args.rnn_dropout, 525 | 'fc_dims': parse_int_list(args.classifier_fc_dims), 526 | 'fc_use_batchnorm': args.classifier_batchnorm == 1, 527 | 'fc_dropout': args.classifier_dropout, 528 | } 529 | model = LstmModel(**kwargs) 530 | elif args.model_type == 'CNN+LSTM': 531 | kwargs = { 532 | 'vocab': vocab, 533 | 'rnn_wordvec_dim': args.rnn_wordvec_dim, 534 | 'rnn_dim': args.rnn_hidden_dim, 535 | 'rnn_num_layers': args.rnn_num_layers, 536 | 'rnn_dropout': 
def get_baseline_model(args):
  vocab = utils.load_vocab(args.vocab_json)
  if args.baseline_start_from is not None:
    model, kwargs = utils.load_baseline(args.baseline_start_from)
  elif args.model_type == 'LSTM':
    kwargs = {
      'vocab': vocab,
      'rnn_wordvec_dim': args.rnn_wordvec_dim,
      'rnn_dim': args.rnn_hidden_dim,
      'rnn_num_layers': args.rnn_num_layers,
      'rnn_dropout': args.rnn_dropout,
      'fc_dims': parse_int_list(args.classifier_fc_dims),
      'fc_use_batchnorm': args.classifier_batchnorm == 1,
      'fc_dropout': args.classifier_dropout,
    }
    model = LstmModel(**kwargs)
  elif args.model_type == 'CNN+LSTM':
    kwargs = {
      'vocab': vocab,
      'rnn_wordvec_dim': args.rnn_wordvec_dim,
      'rnn_dim': args.rnn_hidden_dim,
      'rnn_num_layers': args.rnn_num_layers,
      'rnn_dropout': args.rnn_dropout,
      'cnn_feat_dim': parse_int_list(args.feature_dim),
      'cnn_num_res_blocks': args.cnn_num_res_blocks,
      'cnn_res_block_dim': args.cnn_res_block_dim,
      'cnn_proj_dim': args.cnn_proj_dim,
      'cnn_pooling': args.cnn_pooling,
      'fc_dims': parse_int_list(args.classifier_fc_dims),
      'fc_use_batchnorm': args.classifier_batchnorm == 1,
      'fc_dropout': args.classifier_dropout,
    }
    model = CnnLstmModel(**kwargs)
  elif args.model_type == 'CNN+LSTM+SA':
    kwargs = {
      'vocab': vocab,
      'rnn_wordvec_dim': args.rnn_wordvec_dim,
      'rnn_dim': args.rnn_hidden_dim,
      'rnn_num_layers': args.rnn_num_layers,
      'rnn_dropout': args.rnn_dropout,
      'cnn_feat_dim': parse_int_list(args.feature_dim),
      'stacked_attn_dim': args.stacked_attn_dim,
      'num_stacked_attn': args.num_stacked_attn,
      'fc_dims': parse_int_list(args.classifier_fc_dims),
      'fc_use_batchnorm': args.classifier_batchnorm == 1,
      'fc_dropout': args.classifier_dropout,
    }
    model = CnnLstmSaModel(**kwargs)
  if model.rnn.token_to_idx != vocab['question_token_to_idx']:
    # Make sure new vocab is superset of old
    for k, v in model.rnn.token_to_idx.items():
      assert k in vocab['question_token_to_idx']
      assert vocab['question_token_to_idx'][k] == v
    for token, idx in vocab['question_token_to_idx'].items():
      model.rnn.token_to_idx[token] = idx
    kwargs['vocab'] = vocab
    model.rnn.expand_vocab(vocab['question_token_to_idx'])
  model.cuda()
  model.train()
  return model, kwargs


def set_mode(mode, models):
  assert mode in ['train', 'eval']
  for m in models:
    if m is None: continue
    if mode == 'train': m.train()
    if mode == 'eval': m.eval()
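# check_accuracy evaluates in eval mode over a data loader. For the PG model
# type, accuracy is exact match between predicted and ground-truth program
# token sequences; for all other model types it is answer accuracy. Evaluation
# stops early once --num_val_samples examples have been scored.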
def check_accuracy(args, program_generator, execution_engine, baseline_model, loader):
  set_mode('eval', [program_generator, execution_engine, baseline_model])
  num_correct, num_samples = 0, 0
  for batch in loader:
    questions, _, feats, answers, programs, _ = batch
    if isinstance(questions, list):
      questions = questions[0]

    questions_var = Variable(questions.cuda(), volatile=True)
    feats_var = Variable(feats.cuda(), volatile=True)
    answers_var = Variable(answers.cuda(), volatile=True)
    if programs[0] is not None:
      programs_var = Variable(programs.cuda(), volatile=True)

    scores = None  # Use this for everything but PG
    if args.model_type == 'PG':
      vocab = utils.load_vocab(args.vocab_json)
      for i in range(questions.size(0)):
        program_pred = program_generator.sample(Variable(questions[i:i+1].cuda(), volatile=True))
        program_pred_str = vr.preprocess.decode(program_pred, vocab['program_idx_to_token'])
        program_str = vr.preprocess.decode(programs[i], vocab['program_idx_to_token'])
        if program_pred_str == program_str:
          num_correct += 1
        num_samples += 1
    elif args.model_type == 'EE':
      scores = execution_engine(feats_var, programs_var)
    elif args.model_type == 'PG+EE':
      programs_pred = program_generator.reinforce_sample(
          questions_var, argmax=True)
      scores = execution_engine(feats_var, programs_pred)
    elif args.model_type == 'FiLM':
      programs_pred = program_generator(questions_var)
      scores = execution_engine(feats_var, programs_pred)
    elif args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']:
      scores = baseline_model(questions_var, feats_var)

    if scores is not None:
      _, preds = scores.data.cpu().max(1)
      num_correct += (preds == answers).sum()
      num_samples += preds.size(0)

    if args.num_val_samples is not None and num_samples >= args.num_val_samples:
      break

  set_mode('train', [program_generator, execution_engine, baseline_model])
  acc = float(num_correct) / num_samples
  return acc


def check_grad_num_nans(model, model_name='model'):
  grads = [p.grad for p in model.parameters() if p.grad is not None]
  num_nans = [np.sum(np.isnan(grad.data.cpu().numpy())) for grad in grads]
  nan_checks = [num_nan == 0 for num_nan in num_nans]
  if False in nan_checks:
    print('Nans in ' + model_name + ' gradient!')
    print(num_nans)
    pdb.set_trace()
    raise Exception('Nans in ' + model_name + ' gradient!')


if __name__ == '__main__':
  args = parser.parse_args()
  main(args)
--------------------------------------------------------------------------------