├── logs └── .gitkeep ├── snap └── .gitkeep ├── connectivity └── .gitkeep ├── img_features └── .gitkeep ├── .gitignore ├── run ├── train_agent.bash └── test_agent.bash ├── LICENSE ├── r2r_src ├── vlnbert │ ├── vlnbert_init.py │ ├── vlnbert_OSCAR.py │ └── vlnbert_PREVALENT.py ├── model_OSCAR.py ├── param.py ├── model_PREVALENT.py ├── eval.py ├── train.py ├── env.py ├── utils.py └── agent.py ├── README.md └── recurrent-vln-bert.yml /logs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /snap/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /connectivity/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /img_features/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ftpignore 2 | .ftpconfig 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | connectivity/* 10 | !connectivity/.gitkeep 11 | 12 | data/R2R_test.json 13 | data/R2R_train.json 14 | data/R2R_val_seen.json 15 | data/R2R_val_unseen.json 16 | data/prevalent/* 17 | !data/prevalent/.gitkeep 18 | 19 | img_features/* 20 | !img_features/.gitkeep 21 | 22 | snap/* 23 | !snap/.gitkeep 24 | 25 | logs/* 26 | !logs/.gitkeep 27 | -------------------------------------------------------------------------------- /run/train_agent.bash: -------------------------------------------------------------------------------- 1 | flag="--ADAPT 2 | --vlnbert prevalent 3 | 4 | --aug data/prevalent/prevalent_aug.json 5 | --test_only 0 6 | 7 | --train auglistener 8 | 9 | --maxAction 15 10 | --batchSize 8 11 | --feedback sample 12 | --lr 1e-5 13 | --iters 300000 14 | --optim adamW 15 | 16 | --mlWeight 0.20 17 | --maxInput 80 18 | --angleFeatSize 128 19 | --featdropout 0.4 20 | --dropout 0.5" 21 | 22 | # w CLIP visual feature 23 | PREFIX=ADAPT_CLIP CUDA_VISIBLE_DEVICES=$1 python r2r_src/train.py $flag --features clip 24 | 25 | # w/o CLIP visual feature 26 | #PREFIX=ADAPT CUDA_VISIBLE_DEVICES=$1 python r2r_src/train.py $flag --features places365 -------------------------------------------------------------------------------- /run/test_agent.bash: -------------------------------------------------------------------------------- 1 | flag="--ADAPT 2 | --vlnbert prevalent 3 | 4 | --submit 0 5 | --test_only 0 6 | 7 | --train validlistener 8 | 9 | --maxAction 15 10 | --batchSize 8 11 | --feedback sample 12 | --lr 1e-5 13 | --iters 300000 14 | --optim adamW 15 | 16 | --mlWeight 0.20 17 | --maxInput 80 18 | --angleFeatSize 128 19 | --featdropout 0.4 20 | --dropout 0.5" 21 | 22 | # w CLIP visual feature 23 | PREFIX=ADAPT_CLIP_test CUDA_VISIBLE_DEVICES=$1 python r2r_src/train.py $flag --features clip --load snap/ADAPT_CLIP/state_dict/best_val_unseen 24 | # w/o CLIP visual feature 25 | #PREFIX=ADAPT_test CUDA_VISIBLE_DEVICES=$1 python r2r_src/train.py $flag --features places365 --load snap/ADAPT/state_dict/best_val_unseen -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2021 Yicong Hong, Qi Wu, Yuankai Qi, 4 | Cristian Rodriguez-Opazo, Stephen Gould 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /r2r_src/vlnbert/vlnbert_init.py: -------------------------------------------------------------------------------- 1 | 2 | #from transformers.pytorch_transformers import (BertConfig, BertTokenizer) 3 | from pytorch_transformers import (BertConfig, BertTokenizer) 4 | 5 | def get_tokenizer(args): 6 | if args.vlnbert == 'oscar': 7 | tokenizer_class = BertTokenizer 8 | model_name_or_path = 'Oscar/pretrained_models/base-no-labels/ep_67_588997' 9 | tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=True) 10 | elif args.vlnbert == 'prevalent': 11 | tokenizer_class = BertTokenizer 12 | tokenizer = tokenizer_class.from_pretrained('bert-base-uncased') 13 | return tokenizer 14 | 15 | def get_vlnbert_models(args, config=None): 16 | config_class = BertConfig 17 | 18 | if args.vlnbert == 'oscar': 19 | from vlnbert.vlnbert_OSCAR import VLNBert 20 | model_class = VLNBert 21 | model_name_or_path = 'Oscar/pretrained_models/base-no-labels/ep_67_588997' 22 | vis_config = config_class.from_pretrained(model_name_or_path, num_labels=2, finetuning_task='vln-r2r') 23 | 24 | vis_config.model_type = 'visual' 25 | vis_config.finetuning_task = 'vln-r2r' 26 | vis_config.hidden_dropout_prob = 0.3 27 | vis_config.hidden_size = 768 28 | vis_config.img_feature_dim = 2176 29 | vis_config.num_attention_heads = 12 30 | vis_config.num_hidden_layers = 12 31 | visual_model = model_class.from_pretrained(model_name_or_path, from_tf=False, config=vis_config) 32 | 33 | elif args.vlnbert == 'prevalent': 34 | from vlnbert.vlnbert_PREVALENT import VLNBert 35 | model_class = VLNBert 36 | model_name_or_path = 'Prevalent/pretrained_model/pytorch_model.bin' 37 | vis_config = config_class.from_pretrained('bert-base-uncased') 38 | vis_config.img_feature_dim = 2176 39 | vis_config.img_feature_type = "" 40 | vis_config.vl_layers = 4 41 | vis_config.la_layers = 9 42 | 43 | visual_model = model_class.from_pretrained(model_name_or_path, config=vis_config) 44 | 45 | return visual_model 46 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # ADAPT 2 | 3 | PyTorch implementation of the paper ["ADAPT: Vision-Language Navigation with Modality-Aligned Action Prompts"](https://arxiv.org/abs/2205.15509) (CVPR 2022). 4 | 5 | ## Prerequisites 6 | 7 | ### Installation 8 | The environment installation of ADAPT follows that in [Recurrent-VLN-BERT](https://github.com/YicongHong/Recurrent-VLN-BERT). 9 |
10 | 1. Install the [Matterport3D Simulator](https://github.com/peteanderson80/Matterport3DSimulator). Note that this code uses the [old version (v0.1)](https://github.com/peteanderson80/Matterport3DSimulator/tree/v0.1) of the simulator. 11 | 
12 | 2. Install the required Python packages; the exact versions used in the environment can be found [here](https://github.com/YicongHong/Recurrent-VLN-BERT/blob/main/recurrent-vln-bert.yml). 13 | 
14 | 3. Install [this version](https://github.com/huggingface/transformers/tree/067923d3267325f525f4e46f357360c191ba562e) of [Pytorch-Transformers](https://github.com/huggingface/transformers). 15 | 16 | ### Data Preparation 17 | Please follow the instructions below to prepare the data in the following directories: 18 | 
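The individual downloads are listed next. Once they are in place, a quick layout check such as the one below can confirm that everything sits where `r2r_src/train.py` and the run scripts expect it. This snippet is our own sketch rather than part of the repository; the paths are taken from the `.gitignore` and the feature-file constants in `train.py`.

```
# Hypothetical sanity check -- not part of the ADAPT repository.
# Paths follow .gitignore and the feature-file constants in r2r_src/train.py.
import os

expected = [
    'connectivity',                            # MP3D navigability graphs
    'img_features/ResNet-152-places365.tsv',   # ResNet-152-Places365 scene features
    'data/R2R_train.json',
    'data/R2R_val_seen.json',
    'data/R2R_val_unseen.json',
    'data/prevalent/prevalent_aug.json',       # augmentation data with action prompts
]
# run/*.bash default to --features clip, which additionally expects
# 'img_features/CLIP-ViT-B-32-views.tsv'.

for path in expected:
    print(('OK      ' if os.path.exists(path) else 'MISSING ') + path)
```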
19 | * MP3D navigability graphs: ```connectivity``` 20 | * Download the [connectivity maps](https://github.com/peteanderson80/Matterport3DSimulator/tree/master/connectivity). 21 | * MP3D image features: ```img_features``` 22 | * Download the [Scene features](https://www.dropbox.com/s/85tpa6tc3enl5ud/ResNet-152-places365.zip?dl=1) (ResNet-152-Places365). 23 | * R2R data with added action prompts: ```data``` 24 | * Download the [R2R data](https://drive.google.com/file/d/1dvWONxBDfNiG420Ggttjkje5Qu0pFJ_d/view?usp=sharing). 25 | * Augmentation data with added action prompts: ```data``` 26 | * Download the [augmentation data](https://drive.google.com/file/d/1C9Ckhr6XASDveGvnRvZ3oIzqTeU43JAq/view?usp=sharing). 27 | * Text sub-prompt feature: ```data``` 28 | * Download the [text sub-prompt feature](https://drive.google.com/file/d/127XonQJ2hqriljfSm8J-RLhV_PYQYWJh/view?usp=sharing). 29 | 30 | ## R2R Navigation 31 | 32 | ### Two-phase Training 33 | In the first stage, run the following script until the performance converges on Val Unseen: 
34 | ``` 35 | PREFIX=baseline python r2r_src/train.py --vlnbert prevalent --aug data/prevalent/prevalent_aug.json --batchSize 16 --lr 1e-5 36 | ``` 37 | In the second stage, run the following script, which loads the best Val Unseen model from the first stage: 
38 | ``` 39 | PREFIX=ADAPT python r2r_src/train.py --vlnbert prevalent --aug data/prevalent/prevalent_aug.json --batchSize 16 --lr 1e-6 --ADAPT --load snap/baseline/state_dict/best_val_unseen --finetune 40 | ``` 41 | 42 | ## Acknowledgements 43 | The implementation relies on resources from [Recurrent-VLN-BERT](https://github.com/YicongHong/Recurrent-VLN-BERT). We thank the original authors for their open-sourcing. 44 | -------------------------------------------------------------------------------- /r2r_src/model_OSCAR.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from param import args 5 | 6 | from vlnbert.vlnbert_init import get_vlnbert_models 7 | 8 | 9 | class VLNBERT(nn.Module): 10 | def __init__(self, feature_size=2048+128): 11 | super(VLNBERT, self).__init__() 12 | print('\nInitalizing the VLN-BERT model ...') 13 | self.vln_bert = get_vlnbert_models(args, config=None) # initialize the VLN-BERT 14 | self.vln_bert.config.directions = 4 # a preset random number 15 | 16 | hidden_size = self.vln_bert.config.hidden_size 17 | layer_norm_eps = self.vln_bert.config.layer_norm_eps 18 | 19 | self.action_state_project = nn.Sequential( 20 | nn.Linear(hidden_size+args.angle_feat_size, hidden_size), nn.Tanh()) 21 | self.action_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps) 22 | 23 | self.drop_env = nn.Dropout(p=args.featdropout) 24 | self.img_projection = nn.Linear(feature_size, hidden_size, bias=True) 25 | self.cand_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps) 26 | 27 | def forward(self, mode, sentence, token_type_ids=None, 28 | attention_mask=None, lang_mask=None, vis_mask=None, 29 | position_ids=None, action_feats=None, pano_feats=None, cand_feats=None): 30 | 31 | if mode == 'language': 32 | encoded_sentence = self.vln_bert(mode, sentence, position_ids=position_ids, 33 | token_type_ids=token_type_ids, attention_mask=attention_mask) 34 | 35 | return encoded_sentence 36 | 37 | elif mode == 'visual': 38 | 39 | state_action_embed = torch.cat((sentence[:,0,:], action_feats), 1) 40 | state_with_action = self.action_state_project(state_action_embed) 41 | state_with_action = self.action_LayerNorm(state_with_action) 42 | state_feats = torch.cat((state_with_action.unsqueeze(1), sentence[:,1:,:]), dim=1) 43 | 44 | cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size]) 45 | 46 | cand_feats_embed = self.img_projection(cand_feats) # [2176 * 768] projection 47 | cand_feats_embed = self.cand_LayerNorm(cand_feats_embed) 48 | 49 | # logit is the attention scores over the candidate features 50 | h_t, logit = self.vln_bert(mode, state_feats, 51 | attention_mask=attention_mask, img_feats=cand_feats_embed) 52 | 53 | return h_t, logit 54 | 55 | else: 56 | ModuleNotFoundError 57 | 58 | 59 | class BertLayerNorm(nn.Module): 60 | def __init__(self, hidden_size, eps=1e-12): 61 | """Construct a layernorm module in the TF style (epsilon inside the square root). 
62 | """ 63 | super(BertLayerNorm, self).__init__() 64 | self.weight = nn.Parameter(torch.ones(hidden_size)) 65 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 66 | self.variance_epsilon = eps 67 | 68 | def forward(self, x): 69 | u = x.mean(-1, keepdim=True) 70 | s = (x - u).pow(2).mean(-1, keepdim=True) 71 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 72 | return self.weight * x + self.bias 73 | 74 | 75 | class Critic(nn.Module): 76 | def __init__(self): 77 | super(Critic, self).__init__() 78 | self.state2value = nn.Sequential( 79 | nn.Linear(768, 512), 80 | nn.ReLU(), 81 | nn.Dropout(args.dropout), 82 | nn.Linear(512, 1), 83 | ) 84 | 85 | def forward(self, state): 86 | return self.state2value(state).squeeze() 87 | -------------------------------------------------------------------------------- /r2r_src/param.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | 5 | class Param: 6 | def __init__(self): 7 | self.parser = argparse.ArgumentParser(description="") 8 | 9 | # General 10 | self.parser.add_argument('--test_only', type=int, default=0, help='fast mode for testing') 11 | 12 | self.parser.add_argument('--iters', type=int, default=300000, help='training iterations') 13 | self.parser.add_argument('--name', type=str, default='default', help='experiment id') 14 | self.parser.add_argument('--vlnbert', type=str, default='oscar', help='oscar or prevalent') 15 | self.parser.add_argument('--train', type=str, default='listener') 16 | self.parser.add_argument('--description', type=str, default='no description\n') 17 | 18 | # Data preparation 19 | self.parser.add_argument('--maxInput', type=int, default=80, help="max input instruction") 20 | self.parser.add_argument('--maxAction', type=int, default=15, help='Max Action sequence') 21 | self.parser.add_argument('--batchSize', type=int, default=8) 22 | self.parser.add_argument('--ignoreid', type=int, default=-100) 23 | self.parser.add_argument('--feature_size', type=int, default=2048) 24 | self.parser.add_argument("--loadOptim",action="store_const", default=False, const=True) 25 | 26 | # Load the model from 27 | self.parser.add_argument("--load", default=None, help='path of the trained model') 28 | 29 | # Augmented Paths from 30 | self.parser.add_argument("--aug", default=None) 31 | 32 | # Listener Model Config 33 | self.parser.add_argument("--zeroInit", dest='zero_init', action='store_const', default=False, const=True) 34 | self.parser.add_argument("--mlWeight", dest='ml_weight', type=float, default=0.20) 35 | self.parser.add_argument("--teacherWeight", dest='teacher_weight', type=float, default=1.) 36 | self.parser.add_argument("--features", type=str, default='places365') 37 | 38 | # Dropout Param 39 | self.parser.add_argument('--dropout', type=float, default=0.5) 40 | self.parser.add_argument('--featdropout', type=float, default=0.3) 41 | 42 | # Submision configuration 43 | self.parser.add_argument("--submit", type=int, default=0) 44 | 45 | # Training Configurations 46 | self.parser.add_argument('--optim', type=str, default='rms') # rms, adam 47 | self.parser.add_argument('--lr', type=float, default=0.00001, help="the learning rate") 48 | self.parser.add_argument('--decay', dest='weight_decay', type=float, default=0.) 
49 | self.parser.add_argument('--feedback', type=str, default='sample', 50 | help='How to choose next position, one of ``teacher``, ``sample`` and ``argmax``') 51 | self.parser.add_argument('--teacher', type=str, default='final', 52 | help="How to get supervision. one of ``next`` and ``final`` ") 53 | self.parser.add_argument('--epsilon', type=float, default=0.1) 54 | 55 | # Model hyper params: 56 | self.parser.add_argument("--angleFeatSize", dest="angle_feat_size", type=int, default=4) 57 | 58 | # A2C 59 | self.parser.add_argument("--gamma", default=0.9, type=float) 60 | self.parser.add_argument("--normalize", dest="normalize_loss", default="total", type=str, help='batch or total') 61 | 62 | # ADAPT 63 | self.parser.add_argument("--ADAPT", action='store_const', default=False, const=True, help='use ADAPT model') 64 | self.parser.add_argument("--AlignLossWeight", dest='align_loss_weight', default=0.01, type=float, help="modality alignment loss weight") 65 | self.parser.add_argument("--ConsistencyLossWeight", dest='consistency_loss_weight', default=0.0001, type=float, help="sequential consistency loss weight") 66 | self.parser.add_argument('--prompt_set_size', type=int, default=60, help="prompt set size") 67 | self.parser.add_argument("--temperature", default=0.5, type=float, help="hyperparameter for contrastive loss") 68 | self.parser.add_argument("--finetune", action='store_const', default=False, const=True, help='if in finetune stage') 69 | 70 | self.args = self.parser.parse_args() 71 | 72 | if self.args.optim == 'rms': 73 | print("Optimizer: Using RMSProp") 74 | self.args.optimizer = torch.optim.RMSprop 75 | elif self.args.optim == 'adam': 76 | print("Optimizer: Using Adam") 77 | self.args.optimizer = torch.optim.Adam 78 | elif self.args.optim == 'adamW': 79 | print("Optimizer: Using AdamW") 80 | self.args.optimizer = torch.optim.AdamW 81 | elif self.args.optim == 'sgd': 82 | print("Optimizer: sgd") 83 | self.args.optimizer = torch.optim.SGD 84 | else: 85 | assert False 86 | 87 | param = Param() 88 | args = param.args 89 | 90 | args.description = args.name 91 | args.IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv' 92 | args.log_dir = 'snap/%s' % args.name 93 | 94 | if not os.path.exists(args.log_dir): 95 | os.makedirs(args.log_dir) 96 | DEBUG_FILE = open(os.path.join('snap', args.name, "debug.log"), 'w') 97 | -------------------------------------------------------------------------------- /recurrent-vln-bert.yml: -------------------------------------------------------------------------------- 1 | name: recurrent-vln-bert 2 | channels: 3 | - bioconda 4 | - menpo 5 | - pytorch 6 | - conda-forge 7 | - anaconda 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=main 11 | - bcolz=1.2.1=py36h04863e7_0 12 | - blas=1.0=mkl 13 | - blosc=1.16.3=hd408876_0 14 | - bokeh=2.0.1=py36_0 15 | - bzip2=1.0.8=h7b6447c_0 16 | - ca-certificates=2020.7.22=0 17 | - certifi=2020.6.20=py36_0 18 | - cffi=1.13.1=py36h2e261b9_0 19 | - click=7.1.1=py_0 20 | - cloudpickle=1.3.0=py_0 21 | - cmake=3.14.0=h52cb24c_0 22 | - cudatoolkit=10.2.89=hfd86e86_1 23 | - cytoolz=0.10.1=py36h7b6447c_0 24 | - dask=2.14.0=py_0 25 | - dask-core=2.14.0=py_0 26 | - distributed=2.14.0=py36_0 27 | - docutils=0.15.2=py36_0 28 | - expat=2.2.6=he6710b0_0 29 | - ffmpeg=4.0=hcdf2ecd_0 30 | - freetype=2.9.1=h8a8886c_1 31 | - fsspec=0.7.1=py_0 32 | - hdf5=1.10.2=hba1933b_1 33 | - heapdict=1.0.1=py_0 34 | - idna=2.10=py_0 35 | - intel-openmp=2019.4=243 36 | - jasper=2.0.14=h07fcdf6_1 37 | - jinja2=2.11.1=py_0 38 | - 
jsoncpp=1.8.4=hfd86e86_0 39 | - krb5=1.17.1=h173b8e3_0 40 | - libcurl=7.69.1=h20c2e04_0 41 | - libedit=3.1.20181209=hc058e9b_0 42 | - libgcc-ng=9.1.0=hdf63c60_0 43 | - libgfortran-ng=7.3.0=hdf63c60_0 44 | - libopencv=3.4.2=hb342d67_1 45 | - libopus=1.3=h7b6447c_0 46 | - libpng=1.6.37=hbc83047_0 47 | - libssh2=1.9.0=h1ba5d50_1 48 | - libstdcxx-ng=9.1.0=hdf63c60_0 49 | - libtiff=4.0.10=h2733197_2 50 | - libvpx=1.7.0=h439df22_0 51 | - locket=0.2.0=py36_1 52 | - lz4-c=1.8.1.2=h14c3975_0 53 | - lzo=2.10=h1bfc0ba_1 54 | - markupsafe=1.1.1=py36h7b6447c_0 55 | - mkl=2019.4=243 56 | - mkl-service=2.3.0=py36he904b0f_0 57 | - mkl_fft=1.0.14=py36ha843d7b_0 58 | - mkl_random=1.1.0=py36hd6b4f25_0 59 | - msgpack-python=1.0.0=py36hfd86e86_1 60 | - ncurses=6.1=he6710b0_1 61 | - ninja=1.9.0=py36hfd86e86_0 62 | - numexpr=2.7.1=py36h423224d_0 63 | - olefile=0.46=py36_0 64 | - opencv=3.4.2=py36h6fd60c2_1 65 | - openjdk=8.0.152=h7b6447c_3 66 | - openssl=1.1.1g=h7b6447c_0 67 | - packaging=20.3=py_0 68 | - pandas=1.1.1=py36he6710b0_0 69 | - partd=1.1.0=py_0 70 | - pillow=6.2.1=py36h34e0f95_0 71 | - pip=19.3.1=py36_0 72 | - psutil=5.7.0=py36h7b6447c_0 73 | - py-opencv=3.4.2=py36hb342d67_1 74 | - pybind11=2.4.2=py36hfd86e86_0 75 | - pycparser=2.19=py36_0 76 | - pyopenssl=19.1.0=py_1 77 | - pyparsing=2.4.7=py_0 78 | - pytables=3.4.4=py36ha205bf6_0 79 | - python=3.6.9=h265db76_0 80 | - python-dateutil=2.8.1=py_0 81 | - pytz=2020.1=py_0 82 | - pyyaml=5.3.1=py36h7b6447c_0 83 | - readline=7.0=h7b6447c_5 84 | - requests=2.24.0=py_0 85 | - rhash=1.3.8=h1ba5d50_0 86 | - setuptools=41.6.0=py36_0 87 | - six=1.15.0=py_0 88 | - sortedcontainers=2.1.0=py36_0 89 | - sqlite=3.30.1=h7b6447c_0 90 | - tblib=1.6.0=py_0 91 | - tk=8.6.8=hbc83047_0 92 | - toolz=0.10.0=py_0 93 | - tornado=6.0.4=py36h7b6447c_1 94 | - typing_extensions=3.7.4.1=py36_0 95 | - urllib3=1.25.10=py_0 96 | - wheel=0.33.6=py36_0 97 | - xz=5.2.4=h14c3975_4 98 | - yaml=0.1.7=h96e3832_1 99 | - zict=2.0.0=py_0 100 | - zlib=1.2.11=h7b6447c_3 101 | - zstd=1.3.7=h0b5b093_0 102 | - java-jdk=7.0.91=1 103 | - tqdm=4.7.2=py36_0 104 | - boto3=1.13.14=pyh9f0ad1d_0 105 | - botocore=1.16.14=pyh9f0ad1d_0 106 | - brotlipy=0.7.0=py36h8c4c3a4_1000 107 | - cairo=1.14.12=h80bd089_1005 108 | - chardet=3.0.4=py36h9f0ad1d_1006 109 | - cryptography=2.9.2=py36h45558ae_0 110 | - fontconfig=2.13.1=h2176d3f_1000 111 | - freeglut=3.0.0=hf484d3e_1005 112 | - gettext=0.19.8.1=h9745a5d_1001 113 | - glew=2.1.0=he1b5a44_0 114 | - glib=2.56.2=had28632_1001 115 | - graphite2=1.3.13=hf484d3e_1000 116 | - harfbuzz=1.9.0=he243708_1001 117 | - htop=2.2.0=hf8c457e_1000 118 | - icu=58.2=hf484d3e_1000 119 | - jmespath=0.10.0=pyh9f0ad1d_0 120 | - libblas=3.8.0=14_mkl 121 | - libcblas=3.8.0=14_mkl 122 | - libglu=9.0.0=hf484d3e_1000 123 | - libiconv=1.15=h14c3975_1004 124 | - liblapack=3.8.0=14_mkl 125 | - libuuid=2.32.1=h14c3975_1000 126 | - libxcb=1.13=h14c3975_1002 127 | - libxml2=2.9.8=h143f9aa_1005 128 | - numpy=1.18.4=py36h7314795_0 129 | - pcre=8.41=hf484d3e_1003 130 | - pixman=0.34.0=h14c3975_1003 131 | - pthread-stubs=0.4=h14c3975_1001 132 | - pysocks=1.7.1=py36h9f0ad1d_1 133 | - python_abi=3.6=1_cp36m 134 | - pytorch-pretrained-bert=0.6.2=py36_0 135 | - regex=2020.5.14=py36h8c4c3a4_0 136 | - s3transfer=0.3.3=py36h9f0ad1d_1 137 | - xorg-fixesproto=5.0=h14c3975_1002 138 | - xorg-inputproto=2.3.2=h14c3975_1002 139 | - xorg-kbproto=1.0.7=h14c3975_1002 140 | - xorg-libice=1.0.9=h14c3975_1004 141 | - xorg-libsm=1.2.3=h4937e3b_1000 142 | - xorg-libx11=1.6.9=h516909a_0 143 | - xorg-libxau=1.0.8=h14c3975_1006 
144 | - xorg-libxdmcp=1.1.2=h14c3975_1007 145 | - xorg-libxext=1.3.3=h14c3975_1004 146 | - xorg-libxfixes=5.0.3=h14c3975_1004 147 | - xorg-libxi=1.7.9=h14c3975_1002 148 | - xorg-libxrender=0.9.10=h14c3975_1002 149 | - xorg-renderproto=0.11.1=h14c3975_1002 150 | - xorg-xextproto=7.3.0=h14c3975_1002 151 | - xorg-xproto=7.0.31=h14c3975_1007 152 | - jpeg=9b=h024ee3a_2 153 | - libffi=3.2.1=hd88cf55_4 154 | - snappy=1.1.7=hbae5bb6_3 155 | - osmesa=12.2.2.dev=0 156 | - pytorch=1.6.0=py3.6_cuda10.2.89_cudnn7.6.5_0 157 | - torchvision=0.7.0=py36_cu102 158 | -------------------------------------------------------------------------------- /r2r_src/model_PREVALENT.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 7 | from param import args 8 | 9 | from vlnbert.vlnbert_init import get_vlnbert_models 10 | 11 | class VLNBERT(nn.Module): 12 | def __init__(self, feature_size=2048+128): 13 | super(VLNBERT, self).__init__() 14 | print('\nInitalizing the VLN-BERT model ...') 15 | 16 | self.vln_bert = get_vlnbert_models(args, config=None) # initialize the VLN-BERT 17 | self.vln_bert.config.directions = 4 # a preset random number 18 | 19 | hidden_size = self.vln_bert.config.hidden_size 20 | layer_norm_eps = self.vln_bert.config.layer_norm_eps 21 | 22 | self.action_state_project = nn.Sequential( 23 | nn.Linear(hidden_size+args.angle_feat_size, hidden_size), nn.Tanh()) 24 | self.action_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps) 25 | 26 | self.drop_env = nn.Dropout(p=args.featdropout) 27 | self.img_projection = nn.Linear(feature_size, hidden_size, bias=True) 28 | self.cand_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps) 29 | 30 | self.vis_lang_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps) 31 | self.state_proj = nn.Linear(hidden_size*2, hidden_size, bias=True) 32 | self.state_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps) 33 | 34 | def forward(self, mode, sentence, token_type_ids=None, 35 | attention_mask=None, lang_mask=None, vis_mask=None, 36 | position_ids=None, action_feats=None, pano_feats=None, cand_feats=None, prompt_mask_set=None, 37 | txt_sub_prompt_set=None, img_sub_prompt_set=None): 38 | 39 | if mode == 'language': 40 | init_state, encoded_sentence = self.vln_bert(mode, sentence, attention_mask=attention_mask, lang_mask=lang_mask,) 41 | 42 | return init_state, encoded_sentence 43 | 44 | elif mode == 'visual': 45 | 46 | state_action_embed = torch.cat((sentence[:,0,:], action_feats), 1) 47 | state_with_action = self.action_state_project(state_action_embed) 48 | state_with_action = self.action_LayerNorm(state_with_action) 49 | state_feats = torch.cat((state_with_action.unsqueeze(1), sentence[:,1:,:]), dim=1) 50 | 51 | cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size]) 52 | if txt_sub_prompt_set is not None and img_sub_prompt_set is not None: 53 | h_t, logit, attended_language, attended_visual, attended_txt_prompt, attended_img_prompt, txt_prompt, img_prompt = self.vln_bert( 54 | mode, 55 | state_feats, 56 | attention_mask=attention_mask, 57 | lang_mask=lang_mask, 58 | vis_mask=vis_mask, 59 | img_feats=cand_feats, 60 | txt_sub_prompt_set=txt_sub_prompt_set, 61 | img_sub_prompt_set=img_sub_prompt_set, 62 | prompt_mask_set=prompt_mask_set 63 | ) 64 | else: 65 | # logit is the 
attention scores over the candidate features 66 | h_t, logit, attended_language, attended_visual = self.vln_bert(mode, state_feats, 67 | attention_mask=attention_mask, lang_mask=lang_mask, vis_mask=vis_mask, img_feats=cand_feats) 68 | 69 | # update agent's state, unify history, language and vision by elementwise product 70 | vis_lang_feat = self.vis_lang_LayerNorm(attended_language * attended_visual) 71 | state_output = torch.cat((h_t, vis_lang_feat), dim=-1) 72 | state_proj = self.state_proj(state_output) 73 | state_proj = self.state_LayerNorm(state_proj) 74 | 75 | if txt_sub_prompt_set is not None and img_sub_prompt_set is not None: 76 | return state_proj, logit, attended_language, attended_visual, attended_txt_prompt, attended_img_prompt, txt_prompt, img_prompt 77 | else: 78 | return state_proj, logit 79 | 80 | else: 81 | ModuleNotFoundError 82 | 83 | 84 | class BertLayerNorm(nn.Module): 85 | def __init__(self, hidden_size, eps=1e-12): 86 | """Construct a layernorm module in the TF style (epsilon inside the square root). 87 | """ 88 | super(BertLayerNorm, self).__init__() 89 | self.weight = nn.Parameter(torch.ones(hidden_size)) 90 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 91 | self.variance_epsilon = eps 92 | 93 | def forward(self, x): 94 | u = x.mean(-1, keepdim=True) 95 | s = (x - u).pow(2).mean(-1, keepdim=True) 96 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 97 | return self.weight * x + self.bias 98 | 99 | 100 | class Critic(nn.Module): 101 | def __init__(self): 102 | super(Critic, self).__init__() 103 | self.state2value = nn.Sequential( 104 | nn.Linear(768, 512), 105 | nn.ReLU(), 106 | nn.Dropout(args.dropout), 107 | nn.Linear(512, 1), 108 | ) 109 | 110 | def forward(self, state): 111 | return self.state2value(state).squeeze() 112 | -------------------------------------------------------------------------------- /r2r_src/eval.py: -------------------------------------------------------------------------------- 1 | ''' Evaluation of agent trajectories ''' 2 | 3 | import json 4 | import os 5 | import sys 6 | from collections import defaultdict 7 | import networkx as nx 8 | import numpy as np 9 | import pprint 10 | pp = pprint.PrettyPrinter(indent=4) 11 | 12 | from env import R2RBatch 13 | from utils import load_datasets, load_nav_graphs, ndtw_graphload, DTW 14 | from agent import BaseAgent 15 | 16 | 17 | class Evaluation(object): 18 | ''' Results submission format: [{'instr_id': string, 'trajectory':[(viewpoint_id, heading_rads, elevation_rads),] } ] ''' 19 | 20 | def __init__(self, splits, scans, tok): 21 | self.error_margin = 3.0 22 | self.splits = splits 23 | self.tok = tok 24 | self.gt = {} 25 | self.instr_ids = [] 26 | self.scans = [] 27 | for split in splits: 28 | for item in load_datasets([split]): 29 | if scans is not None and item['scan'] not in scans: 30 | continue 31 | self.gt[str(item['path_id'])] = item 32 | self.scans.append(item['scan']) 33 | self.instr_ids += ['%s_%d' % (item['path_id'], i) for i in range(len(item['instructions']))] 34 | self.scans = set(self.scans) 35 | self.instr_ids = set(self.instr_ids) 36 | self.graphs = load_nav_graphs(self.scans) 37 | self.distances = {} 38 | for scan,G in self.graphs.items(): # compute all shortest paths 39 | self.distances[scan] = dict(nx.all_pairs_dijkstra_path_length(G)) 40 | 41 | def _get_nearest(self, scan, goal_id, path): 42 | near_id = path[0][0] 43 | near_d = self.distances[scan][near_id][goal_id] 44 | for item in path: 45 | d = self.distances[scan][item[0]][goal_id] 46 | if d < near_d: 47 | near_id 
= item[0] 48 | near_d = d 49 | return near_id 50 | 51 | def _score_item(self, instr_id, path): 52 | ''' Calculate error based on the final position in trajectory, and also 53 | the closest position (oracle stopping rule). 54 | The path contains [view_id, angle, vofv] ''' 55 | gt = self.gt[instr_id.split('_')[-2]] 56 | start = gt['path'][0] 57 | assert start == path[0][0], 'Result trajectories should include the start position' 58 | goal = gt['path'][-1] 59 | final_position = path[-1][0] # the first of [view_id, angle, vofv] 60 | nearest_position = self._get_nearest(gt['scan'], goal, path) 61 | self.scores['nav_errors'].append(self.distances[gt['scan']][final_position][goal]) 62 | self.scores['oracle_errors'].append(self.distances[gt['scan']][nearest_position][goal]) 63 | self.scores['trajectory_steps'].append(len(path)-1) 64 | distance = 0 # length of the path in meters 65 | prev = path[0] 66 | for curr in path[1:]: 67 | distance += self.distances[gt['scan']][prev[0]][curr[0]] 68 | prev = curr 69 | self.scores['trajectory_lengths'].append(distance) 70 | self.scores['shortest_lengths'].append( 71 | self.distances[gt['scan']][start][goal] 72 | ) 73 | 74 | def score(self, output_file): 75 | ''' Evaluate each agent trajectory based on how close it got to the goal location ''' 76 | self.scores = defaultdict(list) 77 | instr_ids = set(self.instr_ids) 78 | if type(output_file) is str: 79 | with open(output_file) as f: 80 | results = json.load(f) 81 | else: 82 | results = output_file 83 | 84 | print('result length', len(results)) 85 | for item in results: 86 | # Check against expected ids 87 | if item['instr_id'] in instr_ids: 88 | instr_ids.remove(item['instr_id']) 89 | self._score_item(item['instr_id'], item['trajectory']) 90 | 91 | if 'train' not in self.splits: # Exclude the training from this. 
(Because training eval may be partial) 92 | assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\ 93 | % (len(instr_ids), len(self.instr_ids), ",".join(self.splits), output_file) 94 | assert len(self.scores['nav_errors']) == len(self.instr_ids) 95 | score_summary = { 96 | 'nav_error': np.average(self.scores['nav_errors']), 97 | 'oracle_error': np.average(self.scores['oracle_errors']), 98 | 'steps': np.average(self.scores['trajectory_steps']), 99 | 'lengths': np.average(self.scores['trajectory_lengths']) 100 | } 101 | num_successes = len([i for i in self.scores['nav_errors'] if i < self.error_margin]) 102 | score_summary['success_rate'] = float(num_successes)/float(len(self.scores['nav_errors'])) 103 | oracle_successes = len([i for i in self.scores['oracle_errors'] if i < self.error_margin]) 104 | score_summary['oracle_rate'] = float(oracle_successes)/float(len(self.scores['oracle_errors'])) 105 | 106 | spl = [float(error < self.error_margin) * l / max(l, p, 0.01) 107 | for error, p, l in 108 | zip(self.scores['nav_errors'], self.scores['trajectory_lengths'], self.scores['shortest_lengths']) 109 | ] 110 | score_summary['spl'] = np.average(spl) 111 | 112 | return score_summary, self.scores 113 | -------------------------------------------------------------------------------- /r2r_src/train.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | import os 5 | import time 6 | import json 7 | import random 8 | import numpy as np 9 | from collections import defaultdict 10 | 11 | from utils import read_vocab, write_vocab, build_vocab, padding_idx, timeSince, read_img_features, print_progress 12 | import utils 13 | from env import R2RBatch 14 | from agent import Seq2SeqAgent 15 | from eval import Evaluation 16 | from param import args 17 | 18 | import warnings 19 | warnings.filterwarnings("ignore") 20 | from tensorboardX import SummaryWriter 21 | 22 | from vlnbert.vlnbert_init import get_tokenizer 23 | 24 | # log_dir = 'snap/%s' % args.name 25 | # if not os.path.exists(log_dir): 26 | # os.makedirs(log_dir) 27 | 28 | IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv' 29 | PLACE365_FEATURES = 'img_features/ResNet-152-places365.tsv' 30 | CLIP_FEATURES = 'img_features/CLIP-ViT-B-32-views.tsv' 31 | 32 | if args.features == 'imagenet': 33 | features = IMAGENET_FEATURES 34 | elif args.features == 'places365': 35 | features = PLACE365_FEATURES 36 | elif args.features == 'clip': 37 | features = CLIP_FEATURES 38 | 39 | prefix = os.environ.get('PREFIX') 40 | args.name = prefix 41 | log_dir = 'snap/%s' % prefix 42 | if not os.path.exists(log_dir): 43 | os.makedirs(log_dir) 44 | 45 | feedback_method = args.feedback # teacher or sample 46 | 47 | print(args); print('') 48 | 49 | 50 | ''' train the listener ''' 51 | def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None, feat_dict=None): 52 | writer = SummaryWriter(log_dir=log_dir) 53 | listner = Seq2SeqAgent(train_env, "", tok, args.maxAction, feat_dict) 54 | 55 | record_file = open('./logs/' + args.name + '.txt', 'a') 56 | record_file.write(str(args) + '\n\n') 57 | record_file.close() 58 | 59 | 60 | if args.load is not None: 61 | if args.aug is None: 62 | start_iter = listner.load(os.path.join(args.load)) 63 | print("\nLOAD the model from {}, iteration ".format(args.load, start_iter)) 64 | else: 65 | load_iter = listner.load(os.path.join(args.load)) 66 | print("\nLOAD the model from {}, iteration ".format(args.load, load_iter)) 67 | start_iter = 
0 68 | start = time.time() 69 | print('\nListener training starts, start iteration: %s' % str(start_iter)) 70 | 71 | best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":"", 'update':False}} 72 | 73 | idx = 0 74 | while idx < start_iter + n_iters: 75 | if args.finetune: 76 | log_every = 100 77 | listner.logs = defaultdict(list) 78 | interval = min(log_every, n_iters-idx) 79 | iter = idx + interval 80 | 81 | # Train for log_every interval 82 | if aug_env is None: 83 | listner.env = train_env 84 | listner.train(interval, feedback=feedback_method) # Train interval iters 85 | else: 86 | jdx_length = len(range(interval // 2)) 87 | for jdx in range(interval // 2): 88 | # Train with GT data 89 | listner.env = train_env 90 | args.ml_weight = 0.2 91 | listner.train(1, feedback=feedback_method) 92 | 93 | # Train with Augmented data 94 | listner.env = aug_env 95 | args.ml_weight = 0.2 96 | listner.train(1, feedback=feedback_method) 97 | 98 | print_progress(jdx, jdx_length, prefix='Progress:', suffix='Complete', bar_length=50) 99 | 100 | # Log the training stats to tensorboard 101 | total = max(sum(listner.logs['total']), 1) 102 | length = max(len(listner.logs['critic_loss']), 1) 103 | critic_loss = sum(listner.logs['critic_loss']) / total 104 | RL_loss = sum(listner.logs['RL_loss']) / max(len(listner.logs['RL_loss']), 1) 105 | IL_loss = sum(listner.logs['IL_loss']) / max(len(listner.logs['IL_loss']), 1) 106 | entropy = sum(listner.logs['entropy']) / total 107 | writer.add_scalar("loss/critic", critic_loss, idx) 108 | writer.add_scalar("policy_entropy", entropy, idx) 109 | writer.add_scalar("loss/RL_loss", RL_loss, idx) 110 | writer.add_scalar("loss/IL_loss", IL_loss, idx) 111 | writer.add_scalar("total_actions", total, idx) 112 | writer.add_scalar("max_length", length, idx) 113 | # print("total_actions", total, ", max_length", length) 114 | 115 | if args.ADAPT: 116 | consistency_loss = sum(listner.logs['consistency_loss']) / max(len(listner.logs['consistency_loss']), 1) 117 | writer.add_scalar("loss/consistency_loss", consistency_loss, idx) 118 | alignment_loss = sum(listner.logs['alignment_loss']) / max(len(listner.logs['alignment_loss']), 1) 119 | writer.add_scalar("loss/alignment_loss", alignment_loss, idx) 120 | 121 | # Run validation 122 | loss_str = "iter {}".format(iter) 123 | for env_name, (env, evaluator) in val_envs.items(): 124 | listner.env = env 125 | 126 | # Get validation distance from goal under test evaluation conditions 127 | listner.test(use_dropout=False, feedback='argmax', iters=None) 128 | result = listner.get_results() 129 | score_summary, _ = evaluator.score(result) 130 | loss_str += ", %s " % env_name 131 | for metric, val in score_summary.items(): 132 | if metric in ['spl']: 133 | writer.add_scalar("spl/%s" % env_name, val, idx) 134 | if env_name in best_val: 135 | if val > best_val[env_name]['spl']: 136 | best_val[env_name]['spl'] = val 137 | best_val[env_name]['update'] = True 138 | elif (val == best_val[env_name]['spl']) and (score_summary['success_rate'] > best_val[env_name]['sr']): 139 | best_val[env_name]['spl'] = val 140 | best_val[env_name]['update'] = True 141 | loss_str += ', %s: %.4f' % (metric, val) 142 | 143 | record_file = open('./logs/' + args.name + '.txt', 'a') 144 | record_file.write(loss_str + '\n') 145 | record_file.close() 146 | 147 | for env_name in best_val: 148 | if best_val[env_name]['update']: 149 | best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str) 150 | best_val[env_name]['update'] = False 151 | listner.save(idx, os.path.join("snap", 
args.name, "state_dict", "best_%s" % (env_name))) 152 | else: 153 | listner.save(idx, os.path.join("snap", args.name, "state_dict", "latest_dict")) 154 | 155 | print(('%s (%d %d%%) %s' % (timeSince(start, float(iter)/n_iters), 156 | iter, float(iter)/n_iters*100, loss_str))) 157 | 158 | if iter % 1000 == 0: 159 | print("BEST RESULT TILL NOW") 160 | for env_name in best_val: 161 | print(env_name, best_val[env_name]['state']) 162 | 163 | record_file = open('./logs/' + args.name + '.txt', 'a') 164 | record_file.write('BEST RESULT TILL NOW: ' + env_name + ' | ' + best_val[env_name]['state'] + '\n') 165 | record_file.close() 166 | 167 | idx += interval 168 | 169 | listner.save(idx, os.path.join("snap", args.name, "state_dict", "LAST_iter%d" % (idx))) 170 | 171 | 172 | def valid(train_env, tok, val_envs={}, 173 | feat_dict=None): 174 | agent = Seq2SeqAgent(train_env, "", tok, args.maxAction, feat_dict) 175 | 176 | print("Loaded the listener model at iter %d from %s" % (agent.load(args.load), args.load)) 177 | 178 | for env_name, (env, evaluator) in val_envs.items(): 179 | agent.logs = defaultdict(list) 180 | agent.env = env 181 | 182 | iters = None 183 | agent.test(use_dropout=False, feedback='argmax', iters=iters) 184 | result = agent.get_results() 185 | 186 | if env_name != '': 187 | score_summary, _ = evaluator.score(result) 188 | loss_str = "Env name: %s" % env_name 189 | for metric,val in score_summary.items(): 190 | loss_str += ', %s: %.4f' % (metric, val) 191 | print(loss_str) 192 | 193 | if args.submit: 194 | json.dump( 195 | result, 196 | open(os.path.join(log_dir, "submit_%s.json" % env_name), 'w'), 197 | sort_keys=True, indent=4, separators=(',', ': ') 198 | ) 199 | 200 | def setup(): 201 | torch.manual_seed(1) 202 | torch.cuda.manual_seed(1) 203 | random.seed(0) 204 | np.random.seed(0) 205 | 206 | def train_val(test_only=False): 207 | ''' Train on the training set, and validate on seen and unseen splits. 
''' 208 | setup() 209 | tok = get_tokenizer(args) 210 | 211 | feat_dict = read_img_features(features, test_only=test_only) 212 | if args.ADAPT and args.features == 'places365': 213 | feat_dict_clip = read_img_features(CLIP_FEATURES, test_only=test_only) 214 | 215 | if test_only: 216 | featurized_scans = None 217 | val_env_names = ['val_train_seen'] 218 | else: 219 | featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())]) 220 | val_env_names = ['val_train_seen', 'val_seen', 'val_unseen'] 221 | 222 | if args.ADAPT: 223 | val_env_names = ['val_seen', 'val_unseen'] 224 | 225 | train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok) 226 | from collections import OrderedDict 227 | 228 | if args.submit: 229 | val_env_names.append('test') 230 | else: 231 | pass 232 | 233 | val_envs = OrderedDict( 234 | ((split, 235 | (R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok), 236 | Evaluation([split], featurized_scans, tok)) 237 | ) 238 | for split in val_env_names 239 | ) 240 | ) 241 | 242 | if args.train == 'listener': 243 | train(train_env, tok, args.iters, val_envs=val_envs) 244 | elif args.train == 'validlistener': 245 | if args.ADAPT and args.features == 'places365': 246 | valid(train_env, tok, val_envs=val_envs, 247 | feat_dict=feat_dict_clip) 248 | else: 249 | valid(train_env, tok, val_envs=val_envs, 250 | feat_dict=feat_dict) 251 | else: 252 | assert False 253 | 254 | def train_val_augment(test_only=False): 255 | """ 256 | Train the listener with the augmented data 257 | """ 258 | setup() 259 | 260 | # Create a batch training environment that will also preprocess text 261 | tok_bert = get_tokenizer(args) 262 | 263 | # Load the env img features 264 | feat_dict = read_img_features(features, test_only=test_only) 265 | if args.ADAPT and args.features == 'places365': 266 | feat_dict_clip = read_img_features(CLIP_FEATURES, test_only=test_only) 267 | 268 | if test_only: 269 | featurized_scans = None 270 | val_env_names = ['val_train_seen'] 271 | else: 272 | featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())]) 273 | val_env_names = ['val_train_seen', 'val_seen', 'val_unseen'] 274 | 275 | if args.ADAPT: 276 | val_env_names = ['val_seen', 'val_unseen'] 277 | 278 | # Load the augmentation data 279 | aug_path = args.aug 280 | # Create the training environment 281 | train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok_bert) 282 | aug_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=[aug_path], tokenizer=tok_bert, name='aug') 283 | 284 | # Setup the validation data 285 | val_envs = {split: (R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok_bert), 286 | Evaluation([split], featurized_scans, tok_bert)) 287 | for split in val_env_names} 288 | if args.ADAPT and args.features == 'places365': 289 | # Start training 290 | train(train_env, tok_bert, args.iters, val_envs=val_envs, aug_env=aug_env, 291 | feat_dict=feat_dict_clip) 292 | else: 293 | # Start training 294 | train(train_env, tok_bert, args.iters, val_envs=val_envs, aug_env=aug_env, 295 | feat_dict=feat_dict) 296 | 297 | 298 | if __name__ == "__main__": 299 | if args.train in ['listener', 'validlistener']: 300 | train_val(test_only=args.test_only) 301 | elif args.train == 'auglistener': 302 | train_val_augment(test_only=args.test_only) 303 | else: 304 | assert False 305 | -------------------------------------------------------------------------------- 
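For reference, the SPL number that drives the `best_val_unseen` checkpoint selection in `train.py` above is computed at the end of `eval.py`. Stripped of the surrounding bookkeeping, it reduces to the short, self-contained computation below; this is a minimal sketch, and the helper name `compute_spl` and the example figures are ours, not part of the repository.

```
# Standalone sketch of the SPL (Success weighted by Path Length) metric,
# mirroring Evaluation.score in r2r_src/eval.py; compute_spl is our own name.
def compute_spl(nav_errors, trajectory_lengths, shortest_lengths, error_margin=3.0):
    """Mean over episodes of: success indicator * shortest / max(shortest, taken)."""
    spl = [
        float(error < error_margin) * l / max(l, p, 0.01)
        for error, p, l in zip(nav_errors, trajectory_lengths, shortest_lengths)
    ]
    return sum(spl) / len(spl)

# Illustrative numbers: a successful 9 m trajectory against an 8 m shortest path
# scores 8/9, while a failed episode (nav error >= 3 m) contributes 0.
print(compute_spl(nav_errors=[1.5, 6.0],
                  trajectory_lengths=[9.0, 12.0],
                  shortest_lengths=[8.0, 10.0]))   # ~0.444
```

`train.py` logs this value to TensorBoard as `spl/<split>` and keeps a new checkpoint whenever it improves on Val Unseen, breaking ties by success rate.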
/r2r_src/vlnbert/vlnbert_OSCAR.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import, division, print_function, unicode_literals 3 | import logging 4 | import math 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from torch.nn import CrossEntropyLoss, MSELoss 9 | 10 | from transformers.pytorch_transformers.modeling_bert import (BertEmbeddings, 11 | BertSelfAttention, BertAttention, BertEncoder, BertLayer, 12 | BertSelfOutput, BertIntermediate, BertOutput, 13 | BertPooler, BertLayerNorm, BertPreTrainedModel, 14 | BertPredictionHeadTransform) 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | class CaptionBertSelfAttention(BertSelfAttention): 19 | """ 20 | Modified from BertSelfAttention to add support for output_hidden_states. 21 | """ 22 | def __init__(self, config): 23 | super(CaptionBertSelfAttention, self).__init__(config) 24 | self.config = config 25 | 26 | def forward(self, mode, hidden_states, attention_mask, head_mask=None, 27 | history_state=None): 28 | if history_state is not None: 29 | x_states = torch.cat([history_state, hidden_states], dim=1) 30 | mixed_query_layer = self.query(hidden_states) 31 | mixed_key_layer = self.key(x_states) 32 | mixed_value_layer = self.value(x_states) 33 | else: 34 | mixed_query_layer = self.query(hidden_states) 35 | mixed_key_layer = self.key(hidden_states) 36 | mixed_value_layer = self.value(hidden_states) 37 | 38 | if mode == 'visual': 39 | mixed_query_layer = mixed_query_layer[:, [0]+list(range(-self.config.directions, 0)), :] 40 | 41 | ''' language feature only provide Keys and Values ''' 42 | query_layer = self.transpose_for_scores(mixed_query_layer) 43 | key_layer = self.transpose_for_scores(mixed_key_layer) 44 | value_layer = self.transpose_for_scores(mixed_value_layer) 45 | 46 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 47 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 48 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 49 | attention_scores = attention_scores + attention_mask 50 | 51 | # Normalize the attention scores to probabilities. 52 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 53 | 54 | # This is actually dropping out entire tokens to attend to, which might 55 | # seem a bit unusual, but is taken from the original Transformer paper. 56 | attention_probs = self.dropout(attention_probs) 57 | 58 | # Mask heads if we want to 59 | if head_mask is not None: 60 | attention_probs = attention_probs * head_mask 61 | 62 | context_layer = torch.matmul(attention_probs, value_layer) 63 | 64 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 65 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 66 | context_layer = context_layer.view(*new_context_layer_shape) 67 | 68 | outputs = (context_layer, attention_scores) 69 | 70 | return outputs 71 | 72 | 73 | class CaptionBertAttention(BertAttention): 74 | """ 75 | Modified from BertAttention to add support for output_hidden_states. 
76 | """ 77 | def __init__(self, config): 78 | super(CaptionBertAttention, self).__init__(config) 79 | self.self = CaptionBertSelfAttention(config) 80 | self.output = BertSelfOutput(config) 81 | self.config = config 82 | 83 | def forward(self, mode, input_tensor, attention_mask, head_mask=None, 84 | history_state=None): 85 | ''' transformer processing ''' 86 | self_outputs = self.self(mode, input_tensor, attention_mask, head_mask, history_state) 87 | 88 | ''' feed-forward network with residule ''' 89 | if mode == 'visual': 90 | attention_output = self.output(self_outputs[0], input_tensor[:, [0]+list(range(-self.config.directions, 0)), :]) 91 | if mode == 'language': 92 | attention_output = self.output(self_outputs[0], input_tensor) 93 | 94 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 95 | 96 | return outputs 97 | 98 | 99 | class CaptionBertLayer(BertLayer): 100 | """ 101 | Modified from BertLayer to add support for output_hidden_states. 102 | """ 103 | def __init__(self, config): 104 | super(CaptionBertLayer, self).__init__(config) 105 | self.attention = CaptionBertAttention(config) 106 | self.intermediate = BertIntermediate(config) 107 | self.output = BertOutput(config) 108 | 109 | def forward(self, mode, hidden_states, attention_mask, head_mask=None, 110 | history_state=None): 111 | 112 | attention_outputs = self.attention(mode, hidden_states, attention_mask, 113 | head_mask, history_state) 114 | 115 | ''' feed-forward network with residule ''' 116 | attention_output = attention_outputs[0] 117 | intermediate_output = self.intermediate(attention_output) 118 | layer_output = self.output(intermediate_output, attention_output) 119 | outputs = (layer_output,) + attention_outputs[1:] 120 | 121 | return outputs 122 | 123 | 124 | class CaptionBertEncoder(BertEncoder): 125 | """ 126 | Modified from BertEncoder to add support for output_hidden_states. 
127 | """ 128 | def __init__(self, config): 129 | super(CaptionBertEncoder, self).__init__(config) 130 | self.output_attentions = config.output_attentions 131 | self.output_hidden_states = config.output_hidden_states 132 | # 12 Bert layers 133 | self.layer = nn.ModuleList([CaptionBertLayer(config) for _ in range(config.num_hidden_layers)]) 134 | self.config = config 135 | 136 | def forward(self, mode, hidden_states, attention_mask, head_mask=None, 137 | encoder_history_states=None): 138 | 139 | if mode == 'visual': 140 | for i, layer_module in enumerate(self.layer): 141 | history_state = None if encoder_history_states is None else encoder_history_states[i] 142 | 143 | layer_outputs = layer_module(mode, 144 | hidden_states, attention_mask, head_mask[i], 145 | history_state) 146 | 147 | concat_layer_outputs = torch.cat((layer_outputs[0][:,0:1,:], hidden_states[:,1:-self.config.directions,:], layer_outputs[0][:,1:self.config.directions+1,:]), 1) 148 | hidden_states = concat_layer_outputs 149 | 150 | if i == self.config.num_hidden_layers - 1: 151 | state_attention_score = layer_outputs[1][:, :, 0, :] 152 | lang_attention_score = layer_outputs[1][:, :, -self.config.directions:, 1:-self.config.directions] 153 | vis_attention_score = layer_outputs[1][:, :, :, :] 154 | 155 | outputs = (hidden_states, state_attention_score, lang_attention_score, vis_attention_score) 156 | 157 | elif mode == 'language': 158 | for i, layer_module in enumerate(self.layer): 159 | history_state = None if encoder_history_states is None else encoder_history_states[i] # default None 160 | 161 | layer_outputs = layer_module(mode, 162 | hidden_states, attention_mask, head_mask[i], 163 | history_state) 164 | hidden_states = layer_outputs[0] 165 | 166 | if i == self.config.num_hidden_layers - 1: 167 | slang_attention_score = layer_outputs[1] 168 | 169 | outputs = (hidden_states, slang_attention_score) 170 | 171 | return outputs 172 | 173 | 174 | class BertImgModel(BertPreTrainedModel): 175 | """ Expand from BertModel to handle image region features as input 176 | """ 177 | def __init__(self, config): 178 | super(BertImgModel, self).__init__(config) 179 | self.embeddings = BertEmbeddings(config) 180 | self.encoder = CaptionBertEncoder(config) 181 | self.pooler = BertPooler(config) 182 | 183 | self.img_dim = config.img_feature_dim 184 | logger.info('BertImgModel Image Dimension: {}'.format(self.img_dim)) 185 | 186 | self.apply(self.init_weights) 187 | 188 | def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None, 189 | position_ids=None, img_feats=None): 190 | 191 | if attention_mask.dim() == 2: 192 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 193 | elif attention_mask.dim() == 3: 194 | extended_attention_mask = attention_mask.unsqueeze(1) 195 | else: 196 | raise NotImplementedError 197 | 198 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 199 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 200 | 201 | head_mask = [None] * self.config.num_hidden_layers 202 | 203 | if mode == 'visual': 204 | language_features = input_ids 205 | concat_embedding_output = torch.cat((language_features, img_feats), 1) 206 | elif mode == 'language': 207 | embedding_output = self.embeddings(input_ids, position_ids=position_ids, 208 | token_type_ids=token_type_ids) 209 | concat_embedding_output = embedding_output 210 | 211 | ''' pass to the Transformer layers ''' 212 | encoder_outputs = self.encoder(mode, 
concat_embedding_output, 213 | extended_attention_mask, head_mask=head_mask) 214 | 215 | sequence_output = encoder_outputs[0] 216 | pooled_output = self.pooler(sequence_output) # We "pool" the model by simply taking the hidden state corresponding to the first token 217 | 218 | # add hidden_states and attentions if they are here 219 | outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] 220 | 221 | return outputs 222 | 223 | 224 | class VLNBert(BertPreTrainedModel): 225 | """ 226 | Modified from BertForMultipleChoice to support oscar training. 227 | """ 228 | def __init__(self, config): 229 | super(VLNBert, self).__init__(config) 230 | self.config = config 231 | self.bert = BertImgModel(config) 232 | 233 | self.vis_lang_LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 234 | self.state_proj = nn.Linear(config.hidden_size*2, config.hidden_size, bias=True) 235 | self.state_LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 236 | 237 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 238 | self.apply(self.init_weights) 239 | 240 | def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None, 241 | position_ids=None, img_feats=None): 242 | 243 | outputs = self.bert(mode, input_ids, position_ids=position_ids, token_type_ids=token_type_ids, 244 | attention_mask=attention_mask, img_feats=img_feats) 245 | 246 | sequence_output = outputs[0] 247 | sequence_output = self.dropout(sequence_output) 248 | 249 | pooled_output = outputs[1] 250 | 251 | if mode == 'language': 252 | return sequence_output 253 | 254 | elif mode == 'visual': 255 | # attention scores with respect to agent's state 256 | language_attentions = outputs[2][:, :, 1:-self.config.directions] 257 | visual_attentions = outputs[2][:, :, -self.config.directions:] 258 | 259 | language_attention_scores = language_attentions.mean(dim=1) # mean over the 12 heads 260 | visual_attention_scores = visual_attentions.mean(dim=1) 261 | 262 | # weighted_feat 263 | language_attention_probs = nn.Softmax(dim=-1)(language_attention_scores.clone()).unsqueeze(-1) 264 | visual_attention_probs = nn.Softmax(dim=-1)(visual_attention_scores.clone()).unsqueeze(-1) 265 | 266 | language_seq = sequence_output[:, 1:-self.config.directions, :] 267 | visual_seq = sequence_output[:, -self.config.directions:, :] 268 | 269 | # residual weighting, final attention to weight the raw inputs 270 | attended_language = (language_attention_probs * input_ids[:, 1:, :]).sum(1) 271 | attended_visual = (visual_attention_probs * img_feats).sum(1) 272 | 273 | # update agent's state, unify history, language and vision by elementwise product 274 | vis_lang_feat = self.vis_lang_LayerNorm(attended_language * attended_visual) 275 | state_output = torch.cat((pooled_output, vis_lang_feat), dim=-1) 276 | state_proj = self.state_proj(state_output) 277 | state_proj = self.state_LayerNorm(state_proj) 278 | 279 | return state_proj, visual_attention_scores 280 | -------------------------------------------------------------------------------- /r2r_src/env.py: -------------------------------------------------------------------------------- 1 | ''' Batched Room-to-Room navigation environment ''' 2 | 3 | import sys 4 | sys.path.append('buildpy36') 5 | sys.path.append('Matterport_Simulator/build/') 6 | import MatterSim 7 | import csv 8 | import numpy as np 9 | import math 10 | import base64 11 | import utils 12 | import json 13 | import os 14 | import random 15 | import networkx as nx 16 | from param import args 17 | 18 | from 
utils import load_datasets, load_nav_graphs, pad_instr_tokens 19 | 20 | csv.field_size_limit(sys.maxsize) 21 | 22 | 23 | class EnvBatch(): 24 | ''' A simple wrapper for a batch of MatterSim environments, 25 | using discretized viewpoints and pretrained features ''' 26 | 27 | def __init__(self, feature_store=None, batch_size=100): 28 | """ 29 | 1. Load pretrained image feature 30 | 2. Init the Simulator. 31 | :param feature_store: The name of file stored the feature. 32 | :param batch_size: Used to create the simulator list. 33 | """ 34 | if feature_store: 35 | if type(feature_store) is dict: # A silly way to avoid multiple reading 36 | self.features = feature_store 37 | self.image_w = 640 38 | self.image_h = 480 39 | self.vfov = 60 40 | self.feature_size = next(iter(self.features.values())).shape[-1] 41 | print('The feature size is %d' % self.feature_size) 42 | else: 43 | print(' Image features not provided - in testing mode') 44 | self.features = None 45 | self.image_w = 640 46 | self.image_h = 480 47 | self.vfov = 60 48 | self.sims = [] 49 | for i in range(batch_size): 50 | sim = MatterSim.Simulator() 51 | sim.setRenderingEnabled(False) 52 | sim.setDiscretizedViewingAngles(True) # Set increment/decrement to 30 degree. (otherwise by radians) 53 | sim.setCameraResolution(self.image_w, self.image_h) 54 | sim.setCameraVFOV(math.radians(self.vfov)) 55 | sim.init() 56 | self.sims.append(sim) 57 | 58 | def _make_id(self, scanId, viewpointId): 59 | return scanId + '_' + viewpointId 60 | 61 | def newEpisodes(self, scanIds, viewpointIds, headings): 62 | for i, (scanId, viewpointId, heading) in enumerate(zip(scanIds, viewpointIds, headings)): 63 | self.sims[i].newEpisode(scanId, viewpointId, heading, 0) 64 | 65 | def getStates(self): 66 | """ 67 | Get list of states augmented with precomputed image features. rgb field will be empty. 68 | Agent's current view [0-35] (set only when viewing angles are discretized) 69 | [0-11] looking down, [12-23] looking at horizon, [24-35] looking up 70 | :return: [ ((30, 2048), sim_state) ] * batch_size 71 | """ 72 | feature_states = [] 73 | for i, sim in enumerate(self.sims): 74 | state = sim.getState() 75 | 76 | long_id = self._make_id(state.scanId, state.location.viewpointId) 77 | if self.features: 78 | feature = self.features[long_id] 79 | feature_states.append((feature, state)) 80 | else: 81 | feature_states.append((None, state)) 82 | return feature_states 83 | 84 | def makeActions(self, actions): 85 | ''' Take an action using the full state dependent action interface (with batched input). 86 | Every action element should be an (index, heading, elevation) tuple. 
''' 87 | for i, (index, heading, elevation) in enumerate(actions): 88 | self.sims[i].makeAction(index, heading, elevation) 89 | 90 | 91 | class R2RBatch(): 92 | ''' Implements the Room to Room navigation task, using discretized viewpoints and pretrained features ''' 93 | 94 | def __init__(self, feature_store, batch_size=100, seed=10, splits=['train'], tokenizer=None, 95 | name=None): 96 | self.env = EnvBatch(feature_store=feature_store, batch_size=batch_size) 97 | if feature_store: 98 | self.feature_size = self.env.feature_size 99 | else: 100 | self.feature_size = 2048 101 | self.data = [] 102 | if tokenizer: 103 | self.tok = tokenizer 104 | scans = [] 105 | for split in splits: 106 | for i_item, item in enumerate(load_datasets([split])): 107 | if args.test_only and i_item == 64: 108 | break 109 | if args.ADAPT: 110 | if "/" in split: 111 | try: 112 | new_item = dict(item) 113 | new_item['instr_id'] = item['path_id'] 114 | new_item['instructions'] = item['instructions'][0] 115 | 116 | new_item['instr_encoding'] = item['instr_enc'] 117 | 118 | new_item['img_sub_prompt'] = item['positive_trajectory'][str(new_item['instr_id'])] 119 | new_item['txt_sub_prompt'] = item['positive_obj_index'][str(new_item['instr_id'])] 120 | new_item['prompt_num'] = item['positive_trajectory_length'][ 121 | str(new_item['instr_id'])] 122 | if new_item['instr_encoding'] is not None: # Filter the wrong data 123 | self.data.append(new_item) 124 | scans.append(item['scan']) 125 | except: 126 | continue 127 | else: 128 | # Split multiple instructions into separate entries 129 | for j, instr in enumerate(item['instructions']): 130 | try: 131 | new_item = dict(item) 132 | new_item['instr_id'] = '%s_%d' % (item['path_id'], j) 133 | new_item['instructions'] = instr 134 | ''' BERT tokenizer ''' 135 | instr_tokens = tokenizer.tokenize(instr) 136 | padded_instr_tokens, num_words = pad_instr_tokens(instr_tokens, args.maxInput) 137 | new_item['instr_encoding'] = tokenizer.convert_tokens_to_ids(padded_instr_tokens) 138 | 139 | new_item['img_sub_prompt'] = item['positive_trajectory'][new_item['instr_id']] 140 | new_item['txt_sub_prompt'] = item['positive_obj_index'][ 141 | str(new_item['instr_id'])] 142 | new_item['prompt_num'] = item['positive_trajectory_length'][ 143 | new_item['instr_id']] 144 | if new_item['instr_encoding'] is not None: # Filter the wrong data 145 | self.data.append(new_item) 146 | scans.append(item['scan']) 147 | except: 148 | continue 149 | 150 | else: 151 | if "/" in split: 152 | try: 153 | new_item = dict(item) 154 | new_item['instr_id'] = item['path_id'] 155 | new_item['instructions'] = item['instructions'][0] 156 | new_item['instr_encoding'] = item['instr_enc'] 157 | if new_item['instr_encoding'] is not None: # Filter the wrong data 158 | self.data.append(new_item) 159 | scans.append(item['scan']) 160 | except: 161 | continue 162 | else: 163 | # Split multiple instructions into separate entries 164 | for j, instr in enumerate(item['instructions']): 165 | try: 166 | new_item = dict(item) 167 | new_item['instr_id'] = '%s_%d' % (item['path_id'], j) 168 | new_item['instructions'] = instr 169 | 170 | ''' BERT tokenizer ''' 171 | instr_tokens = tokenizer.tokenize(instr) 172 | padded_instr_tokens, num_words = pad_instr_tokens(instr_tokens, args.maxInput) 173 | new_item['instr_encoding'] = tokenizer.convert_tokens_to_ids(padded_instr_tokens) 174 | 175 | if new_item['instr_encoding'] is not None: # Filter the wrong data 176 | self.data.append(new_item) 177 | scans.append(item['scan']) 178 | except: 179 | 
continue 180 | 181 | if name is None: 182 | self.name = splits[0] if len(splits) > 0 else "FAKE" 183 | else: 184 | self.name = name 185 | 186 | self.scans = set(scans) 187 | self.splits = splits 188 | self.seed = seed 189 | random.seed(self.seed) 190 | random.shuffle(self.data) 191 | 192 | self.ix = 0 193 | self.batch_size = batch_size 194 | self._load_nav_graphs() 195 | 196 | self.angle_feature = utils.get_all_point_angle_feature() 197 | self.sim = utils.new_simulator() 198 | self.buffered_state_dict = {} 199 | 200 | # It means that the fake data is equals to data in the supervised setup 201 | self.fake_data = self.data 202 | print('R2RBatch loaded with %d instructions, using splits: %s' % (len(self.data), ",".join(splits))) 203 | 204 | def size(self): 205 | return len(self.data) 206 | 207 | def _load_nav_graphs(self): 208 | """ 209 | load graph from self.scan, 210 | Store the graph {scan_id: graph} in self.graphs 211 | Store the shortest path {scan_id: {view_id_x: {view_id_y: [path]} } } in self.paths 212 | Store the distances in self.distances. (Structure see above) 213 | Load connectivity graph for each scan, useful for reasoning about shortest paths 214 | :return: None 215 | """ 216 | print('Loading navigation graphs for %d scans' % len(self.scans)) 217 | self.graphs = load_nav_graphs(self.scans) 218 | self.paths = {} 219 | for scan, G in self.graphs.items(): # compute all shortest paths 220 | self.paths[scan] = dict(nx.all_pairs_dijkstra_path(G)) 221 | self.distances = {} 222 | for scan, G in self.graphs.items(): # compute all shortest paths 223 | self.distances[scan] = dict(nx.all_pairs_dijkstra_path_length(G)) 224 | 225 | def _next_minibatch(self, tile_one=False, batch_size=None, **kwargs): 226 | """ 227 | Store the minibach in 'self.batch' 228 | :param tile_one: Tile the one into batch_size 229 | :return: None 230 | """ 231 | if batch_size is None: 232 | batch_size = self.batch_size 233 | if tile_one: 234 | batch = [self.data[self.ix]] * batch_size 235 | self.ix += 1 236 | if self.ix >= len(self.data): 237 | random.shuffle(self.data) 238 | self.ix -= len(self.data) 239 | else: 240 | batch = self.data[self.ix: self.ix+batch_size] 241 | if len(batch) < batch_size: 242 | random.shuffle(self.data) 243 | self.ix = batch_size - len(batch) 244 | batch += self.data[:self.ix] 245 | else: 246 | self.ix += batch_size 247 | self.batch = batch 248 | 249 | def reset_epoch(self, shuffle=False): 250 | ''' Reset the data index to beginning of epoch. Primarily for testing. 251 | You must still call reset() for a new episode. ''' 252 | if shuffle: 253 | random.shuffle(self.data) 254 | self.ix = 0 255 | 256 | def _shortest_path_action(self, state, goalViewpointId): 257 | ''' Determine next action on the shortest path to goal, for supervised training. 
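Returns the next viewpointId along the precomputed all-pairs Dijkstra path (self.paths)
towards goalViewpointId, or goalViewpointId itself when the agent is already there,
which the teacher signal treats as STOP.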
''' 258 | if state.location.viewpointId == goalViewpointId: 259 | return goalViewpointId # Just stop here 260 | path = self.paths[state.scanId][state.location.viewpointId][goalViewpointId] 261 | nextViewpointId = path[1] 262 | return nextViewpointId 263 | 264 | def make_candidate(self, feature, scanId, viewpointId, viewId): 265 | def _loc_distance(loc): 266 | return np.sqrt(loc.rel_heading ** 2 + loc.rel_elevation ** 2) 267 | base_heading = (viewId % 12) * math.radians(30) 268 | adj_dict = {} 269 | long_id = "%s_%s" % (scanId, viewpointId) 270 | if long_id not in self.buffered_state_dict: 271 | for ix in range(36): 272 | if ix == 0: 273 | self.sim.newEpisode(scanId, viewpointId, 0, math.radians(-30)) 274 | elif ix % 12 == 0: 275 | self.sim.makeAction(0, 1.0, 1.0) 276 | else: 277 | self.sim.makeAction(0, 1.0, 0) 278 | 279 | state = self.sim.getState() 280 | assert state.viewIndex == ix 281 | 282 | # Heading and elevation for the viewpoint center 283 | heading = state.heading - base_heading 284 | elevation = state.elevation 285 | 286 | visual_feat = feature[ix] 287 | 288 | # get adjacent locations 289 | for j, loc in enumerate(state.navigableLocations[1:]): 290 | # if a loc is visible from multiple view, use the closest 291 | # view (in angular distance) as its representation 292 | distance = _loc_distance(loc) 293 | 294 | # Heading and elevation for for the loc 295 | loc_heading = heading + loc.rel_heading 296 | loc_elevation = elevation + loc.rel_elevation 297 | angle_feat = utils.angle_feature(loc_heading, loc_elevation) 298 | if (loc.viewpointId not in adj_dict or 299 | distance < adj_dict[loc.viewpointId]['distance']): 300 | adj_dict[loc.viewpointId] = { 301 | 'heading': loc_heading, 302 | 'elevation': loc_elevation, 303 | "normalized_heading": state.heading + loc.rel_heading, 304 | 'scanId':scanId, 305 | 'viewpointId': loc.viewpointId, # Next viewpoint id 306 | 'pointId': ix, 307 | 'distance': distance, 308 | 'idx': j + 1, 309 | 'feature': np.concatenate((visual_feat, angle_feat), -1) 310 | } 311 | candidate = list(adj_dict.values()) 312 | self.buffered_state_dict[long_id] = [ 313 | {key: c[key] 314 | for key in 315 | ['normalized_heading', 'elevation', 'scanId', 'viewpointId', 316 | 'pointId', 'idx']} 317 | for c in candidate 318 | ] 319 | return candidate 320 | else: 321 | candidate = self.buffered_state_dict[long_id] 322 | candidate_new = [] 323 | for c in candidate: 324 | c_new = c.copy() 325 | ix = c_new['pointId'] 326 | normalized_heading = c_new['normalized_heading'] 327 | visual_feat = feature[ix] 328 | loc_heading = normalized_heading - base_heading 329 | c_new['heading'] = loc_heading 330 | angle_feat = utils.angle_feature(c_new['heading'], c_new['elevation']) 331 | c_new['feature'] = np.concatenate((visual_feat, angle_feat), -1) 332 | c_new.pop('normalized_heading') 333 | candidate_new.append(c_new) 334 | return candidate_new 335 | 336 | def _get_obs(self): 337 | obs = [] 338 | for i, (feature, state) in enumerate(self.env.getStates()): 339 | item = self.batch[i] 340 | base_view_id = state.viewIndex 341 | 342 | if feature is None: 343 | feature = np.zeros((36, 2048)) 344 | 345 | # Full features 346 | candidate = self.make_candidate(feature, state.scanId, state.location.viewpointId, state.viewIndex) 347 | # [visual_feature, angle_feature] for views 348 | feature = np.concatenate((feature, self.angle_feature[base_view_id]), -1) 349 | if args.ADAPT: 350 | obs.append({ 351 | 'instr_id': item['instr_id'], 352 | 'scan': state.scanId, 353 | 'viewpoint': state.location.viewpointId, 
354 | 'viewIndex': state.viewIndex, 355 | 'heading': state.heading, 356 | 'elevation': state.elevation, 357 | 'feature': feature, 358 | 'candidate': candidate, 359 | 'navigableLocations': state.navigableLocations, 360 | 'instructions': item['instructions'], 361 | 'teacher': self._shortest_path_action(state, item['path'][-1]), 362 | 'gt_path': item['path'], 363 | 'path_id': item['path_id'], 364 | 'img_sub_prompt': item['img_sub_prompt'], 365 | 'txt_sub_prompt': item['txt_sub_prompt'], 366 | 'prompt_num': item['prompt_num'] 367 | }) 368 | else: 369 | obs.append({ 370 | 'instr_id' : item['instr_id'], 371 | 'scan' : state.scanId, 372 | 'viewpoint' : state.location.viewpointId, 373 | 'viewIndex' : state.viewIndex, 374 | 'heading' : state.heading, 375 | 'elevation' : state.elevation, 376 | 'feature' : feature, 377 | 'candidate': candidate, 378 | 'navigableLocations' : state.navigableLocations, 379 | 'instructions' : item['instructions'], 380 | 'teacher' : self._shortest_path_action(state, item['path'][-1]), 381 | 'gt_path' : item['path'], 382 | 'path_id' : item['path_id'] 383 | }) 384 | if 'instr_encoding' in item: 385 | obs[-1]['instr_encoding'] = item['instr_encoding'] 386 | # A2C reward. The negative distance between the state and the final state 387 | obs[-1]['distance'] = self.distances[state.scanId][state.location.viewpointId][item['path'][-1]] 388 | return obs 389 | 390 | def reset(self, batch=None, inject=False, **kwargs): 391 | ''' Load a new minibatch / episodes. ''' 392 | if batch is None: # Allow the user to explicitly define the batch 393 | self._next_minibatch(**kwargs) 394 | else: 395 | if inject: # Inject the batch into the next minibatch 396 | self._next_minibatch(**kwargs) 397 | self.batch[:len(batch)] = batch 398 | else: # Else set the batch to the current batch 399 | self.batch = batch 400 | scanIds = [item['scan'] for item in self.batch] 401 | viewpointIds = [item['path'][0] for item in self.batch] 402 | headings = [item['heading'] for item in self.batch] 403 | self.env.newEpisodes(scanIds, viewpointIds, headings) 404 | return self._get_obs() 405 | 406 | def step(self, actions): 407 | ''' Take action (same interface as makeActions) ''' 408 | self.env.makeActions(actions) 409 | return self._get_obs() 410 | 411 | def get_statistics(self): 412 | stats = {} 413 | length = 0 414 | path = 0 415 | for datum in self.data: 416 | length += len(self.tok.split_sentence(datum['instructions'])) 417 | path += self.distances[datum['scan']][datum['path'][0]][datum['path'][-1]] 418 | stats['length'] = length / len(self.data) 419 | stats['path'] = path / len(self.data) 420 | return stats 421 | -------------------------------------------------------------------------------- /r2r_src/vlnbert/vlnbert_PREVALENT.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import, division, print_function, unicode_literals 3 | 4 | import json 5 | import logging 6 | import math 7 | import os 8 | import sys 9 | from io import open 10 | 11 | import torch 12 | from torch import nn 13 | from torch.nn import CrossEntropyLoss, MSELoss 14 | 15 | #from transformers.pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig 16 | from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig 17 | import pdb 18 | from param import args 19 | import torch.nn.functional as F 20 | import numpy as np 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | def gelu(x): 25 | """Implementation of the gelu activation 
function. 26 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 27 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 28 | Also see https://arxiv.org/abs/1606.08415 29 | """ 30 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 31 | 32 | 33 | def swish(x): 34 | return x * torch.sigmoid(x) 35 | 36 | 37 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 38 | 39 | 40 | try: 41 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 42 | except (ImportError, AttributeError) as e: 43 | logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .") 44 | BertLayerNorm = torch.nn.LayerNorm 45 | 46 | class BertEmbeddings(nn.Module): 47 | """Construct the embeddings from word, position and token_type embeddings. 48 | """ 49 | def __init__(self, config): 50 | super(BertEmbeddings, self).__init__() 51 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) 52 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 53 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 54 | 55 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 56 | # any TensorFlow checkpoint file 57 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 58 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 59 | 60 | def forward(self, input_ids, token_type_ids=None, position_ids=None): 61 | seq_length = input_ids.size(1) 62 | if position_ids is None: 63 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 64 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 65 | if token_type_ids is None: 66 | token_type_ids = torch.zeros_like(input_ids) 67 | 68 | words_embeddings = self.word_embeddings(input_ids) 69 | position_embeddings = self.position_embeddings(position_ids) 70 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 71 | 72 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 73 | embeddings = self.LayerNorm(embeddings) 74 | embeddings = self.dropout(embeddings) 75 | return embeddings 76 | 77 | 78 | class BertSelfAttention(nn.Module): 79 | def __init__(self, config): 80 | super(BertSelfAttention, self).__init__() 81 | if config.hidden_size % config.num_attention_heads != 0: 82 | raise ValueError( 83 | "The hidden size (%d) is not a multiple of the number of attention " 84 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 85 | self.output_attentions = True 86 | 87 | self.num_attention_heads = config.num_attention_heads 88 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 89 | self.all_head_size = self.num_attention_heads * self.attention_head_size 90 | 91 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 92 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 93 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 94 | 95 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 96 | 97 | def transpose_for_scores(self, x): 98 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 99 | x = x.view(*new_x_shape) 100 | return x.permute(0, 2, 1, 3) 101 | 102 | def forward(self, hidden_states, attention_mask, head_mask=None): 103 | mixed_query_layer = 
self.query(hidden_states) 104 | mixed_key_layer = self.key(hidden_states) 105 | mixed_value_layer = self.value(hidden_states) 106 | 107 | query_layer = self.transpose_for_scores(mixed_query_layer) 108 | key_layer = self.transpose_for_scores(mixed_key_layer) 109 | value_layer = self.transpose_for_scores(mixed_value_layer) 110 | 111 | # Take the dot product between "query" and "key" to get the raw attention scores. 112 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 113 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 114 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 115 | attention_scores = attention_scores + attention_mask 116 | 117 | # Normalize the attention scores to probabilities. 118 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 119 | 120 | # This is actually dropping out entire tokens to attend to, which might 121 | # seem a bit unusual, but is taken from the original Transformer paper. 122 | attention_probs = self.dropout(attention_probs) 123 | 124 | # Mask heads if we want to 125 | if head_mask is not None: 126 | attention_probs = attention_probs * head_mask 127 | 128 | context_layer = torch.matmul(attention_probs, value_layer) 129 | 130 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 131 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 132 | context_layer = context_layer.view(*new_context_layer_shape) 133 | 134 | outputs = (context_layer, attention_scores) if self.output_attentions else (context_layer,) 135 | return outputs 136 | 137 | 138 | class BertSelfOutput(nn.Module): 139 | def __init__(self, config): 140 | super(BertSelfOutput, self).__init__() 141 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 142 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 143 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 144 | 145 | def forward(self, hidden_states, input_tensor): 146 | hidden_states = self.dense(hidden_states) 147 | hidden_states = self.dropout(hidden_states) 148 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 149 | return hidden_states 150 | 151 | 152 | class BertAttention(nn.Module): 153 | def __init__(self, config): 154 | super(BertAttention, self).__init__() 155 | self.self = BertSelfAttention(config) 156 | self.output = BertSelfOutput(config) 157 | 158 | def forward(self, input_tensor, attention_mask, head_mask=None): 159 | self_outputs = self.self(input_tensor, attention_mask, head_mask) 160 | attention_output = self.output(self_outputs[0], input_tensor) 161 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 162 | return outputs 163 | 164 | 165 | class BertIntermediate(nn.Module): 166 | def __init__(self, config): 167 | super(BertIntermediate, self).__init__() 168 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 169 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 170 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 171 | else: 172 | self.intermediate_act_fn = config.hidden_act 173 | 174 | def forward(self, hidden_states): 175 | hidden_states = self.dense(hidden_states) 176 | hidden_states = self.intermediate_act_fn(hidden_states) 177 | return hidden_states 178 | 179 | 180 | class BertOutput(nn.Module): 181 | def __init__(self, config): 182 | super(BertOutput, self).__init__() 183 | self.dense = 
nn.Linear(config.intermediate_size, config.hidden_size) 184 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 185 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 186 | 187 | def forward(self, hidden_states, input_tensor): 188 | hidden_states = self.dense(hidden_states) 189 | hidden_states = self.dropout(hidden_states) 190 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 191 | return hidden_states 192 | 193 | 194 | class BertLayer(nn.Module): 195 | def __init__(self, config): 196 | super(BertLayer, self).__init__() 197 | self.attention = BertAttention(config) 198 | self.intermediate = BertIntermediate(config) 199 | self.output = BertOutput(config) 200 | 201 | def forward(self, hidden_states, attention_mask, head_mask=None): 202 | attention_outputs = self.attention(hidden_states, attention_mask, head_mask) 203 | attention_output = attention_outputs[0] 204 | intermediate_output = self.intermediate(attention_output) 205 | layer_output = self.output(intermediate_output, attention_output) 206 | outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them 207 | return outputs 208 | 209 | 210 | class BertPooler(nn.Module): 211 | def __init__(self, config): 212 | super(BertPooler, self).__init__() 213 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 214 | self.activation = nn.Tanh() 215 | 216 | def forward(self, hidden_states): 217 | # We "pool" the model by simply taking the hidden state corresponding 218 | # to the first token. 219 | first_token_tensor = hidden_states[:, 0] 220 | pooled_output = self.dense(first_token_tensor) 221 | pooled_output = self.activation(pooled_output) 222 | return pooled_output 223 | 224 | 225 | class BertXAttention(nn.Module): 226 | def __init__(self, config, ctx_dim=None): 227 | super().__init__() 228 | self.att = BertOutAttention(config, ctx_dim=ctx_dim) 229 | self.output = BertSelfOutput(config) 230 | 231 | def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None): 232 | output, attention_scores = self.att(input_tensor, ctx_tensor, ctx_att_mask) 233 | attention_output = self.output(output, input_tensor) 234 | return attention_output, attention_scores 235 | 236 | 237 | class BertOutAttention(nn.Module): 238 | def __init__(self, config, ctx_dim=None): 239 | super().__init__() 240 | if config.hidden_size % config.num_attention_heads != 0: 241 | raise ValueError( 242 | "The hidden size (%d) is not a multiple of the number of attention " 243 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 244 | self.num_attention_heads = config.num_attention_heads 245 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 246 | self.all_head_size = self.num_attention_heads * self.attention_head_size 247 | 248 | # visual_dim = 2048 249 | if ctx_dim is None: 250 | ctx_dim =config.hidden_size 251 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 252 | self.key = nn.Linear(ctx_dim, self.all_head_size) 253 | self.value = nn.Linear(ctx_dim, self.all_head_size) 254 | 255 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 256 | 257 | def transpose_for_scores(self, x): 258 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 259 | x = x.view(*new_x_shape) 260 | return x.permute(0, 2, 1, 3) 261 | 262 | def forward(self, hidden_states, context, attention_mask=None): 263 | mixed_query_layer = self.query(hidden_states) 264 | mixed_key_layer = self.key(context) 265 | mixed_value_layer = 
self.value(context) 266 | 267 | query_layer = self.transpose_for_scores(mixed_query_layer) 268 | key_layer = self.transpose_for_scores(mixed_key_layer) 269 | value_layer = self.transpose_for_scores(mixed_value_layer) 270 | 271 | # Take the dot product between "query" and "key" to get the raw attention scores. 272 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 273 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 274 | 275 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 276 | if attention_mask is not None: 277 | attention_scores = attention_scores + attention_mask 278 | 279 | # Normalize the attention scores to probabilities. 280 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 281 | 282 | # This is actually dropping out entire tokens to attend to, which might 283 | # seem a bit unusual, but is taken from the original Transformer paper. 284 | attention_probs = self.dropout(attention_probs) 285 | 286 | context_layer = torch.matmul(attention_probs, value_layer) 287 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 288 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 289 | context_layer = context_layer.view(*new_context_layer_shape) 290 | return context_layer, attention_scores 291 | 292 | 293 | class LXRTXLayer(nn.Module): 294 | def __init__(self, config): 295 | super().__init__() 296 | self.config = config 297 | # Lang self-att and FFN layer 298 | self.lang_self_att = BertAttention(config) 299 | self.lang_inter = BertIntermediate(config) 300 | self.lang_output = BertOutput(config) 301 | # Visn self-att and FFN layer 302 | self.visn_self_att = BertAttention(config) 303 | self.visn_inter = BertIntermediate(config) 304 | self.visn_output = BertOutput(config) 305 | # The cross attention layer 306 | self.visual_attention = BertXAttention(config) 307 | 308 | def cross_att(self, lang_input, lang_attention_mask, visn_input, visn_attention_mask): 309 | ''' Cross Attention -- cross for vision not for language ''' 310 | visn_att_output, attention_scores = self.visual_attention(visn_input, lang_input, ctx_att_mask=lang_attention_mask) 311 | return visn_att_output, attention_scores 312 | 313 | def self_att(self, visn_input, visn_attention_mask): 314 | ''' Self Attention -- on visual features with language clues ''' 315 | visn_att_output = self.visn_self_att(visn_input, visn_attention_mask) 316 | return visn_att_output 317 | 318 | def output_fc(self, visn_input): 319 | ''' Feed forward ''' 320 | visn_inter_output = self.visn_inter(visn_input) 321 | visn_output = self.visn_output(visn_inter_output, visn_input) 322 | return visn_output 323 | 324 | def forward(self, lang_feats, lang_attention_mask, 325 | visn_feats, visn_attention_mask, tdx): 326 | 327 | ''' visual self-attention with state ''' 328 | visn_att_output = torch.cat((lang_feats[:, 0:1, :], visn_feats), dim=1) 329 | state_vis_mask = torch.cat((lang_attention_mask[:,:,:,0:1], visn_attention_mask), dim=-1) 330 | 331 | ''' state and vision attend to language ''' 332 | visn_att_output, cross_attention_scores = self.cross_att(lang_feats[:, 1:, :], lang_attention_mask[:, :, :, 1:], visn_att_output, state_vis_mask) 333 | 334 | language_attention_scores = cross_attention_scores[:, :, 0, :] 335 | 336 | state_visn_att_output = self.self_att(visn_att_output, state_vis_mask) 337 | state_visn_output = self.output_fc(state_visn_att_output[0]) 338 | 339 | visn_att_output = state_visn_output[:, 1:, :] 340 | 
lang_att_output = torch.cat((state_visn_output[:, 0:1, :], lang_feats[:,1:,:]), dim=1) 341 | 342 | visual_attention_scores = state_visn_att_output[1][:, :, 0, 1:] 343 | 344 | return lang_att_output, visn_att_output, language_attention_scores, visual_attention_scores 345 | 346 | 347 | class VisionEncoder(nn.Module): 348 | def __init__(self, vision_size, config): 349 | super().__init__() 350 | feat_dim = vision_size 351 | 352 | # Object feature encoding 353 | self.visn_fc = nn.Linear(feat_dim, config.hidden_size) 354 | self.visn_layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12) 355 | 356 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 357 | 358 | if args.features == 'clip': 359 | self.visn_fc_pre = nn.Linear(512 + 128, feat_dim) 360 | 361 | def forward(self, visn_input): 362 | feats = visn_input 363 | 364 | if args.features == 'clip': 365 | feats = self.visn_fc_pre(feats) 366 | 367 | x = self.visn_fc(feats) 368 | x = self.visn_layer_norm(x) 369 | 370 | output = self.dropout(x) 371 | return output 372 | 373 | 374 | class VLNBert(BertPreTrainedModel): 375 | def __init__(self, config): 376 | super(VLNBert, self).__init__(config) 377 | self.embeddings = BertEmbeddings(config) 378 | self.pooler = BertPooler(config) 379 | 380 | self.img_dim = config.img_feature_dim # 2176 381 | logger.info('VLNBert Image Dimension: {}'.format(self.img_dim)) 382 | self.img_feature_type = config.img_feature_type # '' 383 | self.vl_layers = config.vl_layers # 4 384 | self.la_layers = config.la_layers # 9 385 | self.lalayer = nn.ModuleList( 386 | [BertLayer(config) for _ in range(self.la_layers)]) 387 | self.addlayer = nn.ModuleList( 388 | [LXRTXLayer(config) for _ in range(self.vl_layers)]) 389 | self.vision_encoder = VisionEncoder(self.config.img_feature_dim, self.config) 390 | 391 | if args.ADAPT: 392 | self.txt_enc = nn.Sequential(nn.Linear(512, 768), 393 | nn.Dropout(0.1)) 394 | self.img_enc = nn.Sequential(nn.Linear(512, 768), 395 | nn.Dropout(0.1)) 396 | self.action_enc = nn.Sequential(nn.Linear(768*2, 768), 397 | nn.Dropout(0.1)) 398 | 399 | self.apply(self.init_weights) 400 | 401 | def forward(self, mode, input_ids, token_type_ids=None, 402 | attention_mask=None, lang_mask=None, vis_mask=None, position_ids=None, head_mask=None, 403 | img_feats=None, prompt_mask_set=None, 404 | txt_sub_prompt_set=None, img_sub_prompt_set=None): 405 | 406 | attention_mask = lang_mask 407 | 408 | if token_type_ids is None: 409 | token_type_ids = torch.zeros_like(input_ids) 410 | 411 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 412 | 413 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 414 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 415 | 416 | head_mask = [None] * self.config.num_hidden_layers 417 | 418 | if mode == 'language': 419 | ''' LXMERT language branch (in VLN only perform this at initialization) ''' 420 | embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids) 421 | text_embeds = embedding_output 422 | 423 | for layer_module in self.lalayer: 424 | temp_output = layer_module(text_embeds, extended_attention_mask) 425 | text_embeds = temp_output[0] 426 | 427 | sequence_output = text_embeds 428 | pooled_output = self.pooler(sequence_output) 429 | 430 | return pooled_output, sequence_output 431 | 432 | elif mode == 'visual': 433 | ''' LXMERT visual branch (no language processing during navigation) ''' 434 | text_embeds = input_ids 435 | 436 | 
text_mask = extended_attention_mask 437 | 438 | img_embedding_output = self.vision_encoder(img_feats) 439 | img_seq_len = img_feats.shape[1] 440 | batch_size = text_embeds.size(0) 441 | 442 | img_seq_mask = vis_mask 443 | 444 | extended_img_mask = img_seq_mask.unsqueeze(1).unsqueeze(2) 445 | extended_img_mask = extended_img_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 446 | extended_img_mask = (1.0 - extended_img_mask) * -10000.0 447 | img_mask = extended_img_mask 448 | 449 | lang_output = text_embeds 450 | visn_output = img_embedding_output 451 | 452 | if txt_sub_prompt_set is not None and img_sub_prompt_set is not None: 453 | extended_prompt_mask = prompt_mask_set.unsqueeze(1).unsqueeze(2) 454 | extended_prompt_mask = extended_prompt_mask.to(dtype=next(self.parameters()).dtype) 455 | extended_prompt_mask = (1.0 - extended_prompt_mask) * -10000.0 456 | prompt_mask = extended_prompt_mask.cuda() 457 | 458 | text_mask = torch.cat([text_mask, prompt_mask], dim=3) 459 | txt_sub_prompt_feat = self.txt_enc(txt_sub_prompt_set) 460 | img_sub_prompt_feat = self.img_enc(img_sub_prompt_set) 461 | prompt_feat = self.action_enc(torch.cat([txt_sub_prompt_feat, img_sub_prompt_feat], dim=-1)) 462 | lang_output = torch.cat([lang_output, prompt_feat], dim=1) 463 | 464 | 465 | for tdx, layer_module in enumerate(self.addlayer): 466 | lang_output, visn_output, language_attention_scores, visual_attention_scores = layer_module(lang_output, text_mask, visn_output, img_mask, tdx) 467 | 468 | if prompt_feat is not None and prompt_mask_set is not None: 469 | language_attention_scores = language_attention_scores[:, :, 470 | :language_attention_scores.size(2) - prompt_feat.size(1)] 471 | prompt_attention_scores = language_attention_scores[:, :, 472 | language_attention_scores.size(2) - prompt_feat.size(1):] 473 | 474 | prompt_scores = prompt_attention_scores.mean(dim=1) 475 | prompt_probs = nn.Softmax(dim=-1)(prompt_scores.clone()).unsqueeze(-1) 476 | attended_txt_sub_prompt_feat = (prompt_probs * txt_sub_prompt_feat).sum(1) 477 | attended_img_sub_prompt_feat = (prompt_probs * img_sub_prompt_feat).sum(1) 478 | 479 | 480 | sequence_output = lang_output 481 | pooled_output = self.pooler(sequence_output) 482 | 483 | language_state_scores = language_attention_scores.mean(dim=1) 484 | visual_action_scores = visual_attention_scores.mean(dim=1) 485 | 486 | # weighted_feat 487 | language_attention_probs = nn.Softmax(dim=-1)(language_state_scores.clone()).unsqueeze(-1) 488 | visual_attention_probs = nn.Softmax(dim=-1)(visual_action_scores.clone()).unsqueeze(-1) 489 | 490 | attended_language = (language_attention_probs * text_embeds[:, 1:, :]).sum(1) 491 | attended_visual = (visual_attention_probs * img_embedding_output).sum(1) 492 | 493 | if prompt_feat is not None and prompt_mask_set is not None: 494 | return pooled_output, visual_action_scores, attended_language, attended_visual, attended_txt_sub_prompt_feat, attended_img_sub_prompt_feat, txt_sub_prompt_feat, img_sub_prompt_feat 495 | else: 496 | return pooled_output, visual_action_scores, attended_language, attended_visual 497 | -------------------------------------------------------------------------------- /r2r_src/utils.py: -------------------------------------------------------------------------------- 1 | ''' Utils for io, language, connectivity graphs etc ''' 2 | 3 | import os 4 | import sys 5 | import re 6 | sys.path.append('Matterport_Simulator/build/') 7 | import MatterSim 8 | import string 9 | import json 10 | import time 11 | import math 12 | 
from collections import Counter, defaultdict 13 | import numpy as np 14 | import networkx as nx 15 | from param import args 16 | from numpy.linalg import norm 17 | 18 | 19 | # padding, unknown word, end of sentence 20 | base_vocab = ['', '', ''] 21 | padding_idx = base_vocab.index('') 22 | 23 | def load_nav_graphs(scans): 24 | ''' Load connectivity graph for each scan ''' 25 | 26 | def distance(pose1, pose2): 27 | ''' Euclidean distance between two graph poses ''' 28 | return ((pose1['pose'][3]-pose2['pose'][3])**2\ 29 | + (pose1['pose'][7]-pose2['pose'][7])**2\ 30 | + (pose1['pose'][11]-pose2['pose'][11])**2)**0.5 31 | 32 | graphs = {} 33 | for scan in scans: 34 | with open('connectivity/%s_connectivity.json' % scan) as f: 35 | G = nx.Graph() 36 | positions = {} 37 | data = json.load(f) 38 | for i,item in enumerate(data): 39 | if item['included']: 40 | for j,conn in enumerate(item['unobstructed']): 41 | if conn and data[j]['included']: 42 | positions[item['image_id']] = np.array([item['pose'][3], 43 | item['pose'][7], item['pose'][11]]); 44 | assert data[j]['unobstructed'][i], 'Graph should be undirected' 45 | G.add_edge(item['image_id'],data[j]['image_id'],weight=distance(item,data[j])) 46 | nx.set_node_attributes(G, values=positions, name='position') 47 | graphs[scan] = G 48 | return graphs 49 | 50 | 51 | def load_datasets(splits): 52 | """ 53 | 54 | :param splits: A list of split. 55 | if the split is "something@5000", it will use a random 5000 data from the data 56 | :return: 57 | """ 58 | import random 59 | data = [] 60 | old_state = random.getstate() 61 | for split in splits: 62 | # It only needs some part of the dataset? 63 | components = split.split("@") 64 | number = -1 65 | if len(components) > 1: 66 | split, number = components[0], int(components[1]) 67 | 68 | # Load Json 69 | # if split in ['train', 'val_seen', 'val_unseen', 'test', 70 | # 'val_unseen_half1', 'val_unseen_half2', 'val_seen_half1', 'val_seen_half2']: # Add two halves for sanity check 71 | if args.ADAPT: 72 | if "/" not in split: 73 | with open('data/R2R_%s_ADAPT.json' % split) as f: 74 | new_data = json.load(f) 75 | else: 76 | print('\nLoading prevalent data for pretraining...') 77 | with open('data/R2R_aug_ADAPT.json') as f: 78 | new_data = json.load(f) 79 | else: 80 | if "/" not in split: 81 | with open('data/R2R_%s.json' % split) as f: 82 | new_data = json.load(f) 83 | else: 84 | print('\nLoading prevalent data for pretraining...') 85 | with open(split) as f: 86 | new_data = json.load(f) 87 | 88 | # Partition 89 | if number > 0: 90 | random.seed(0) # Make the data deterministic, additive 91 | random.shuffle(new_data) 92 | new_data = new_data[:number] 93 | 94 | # Join 95 | data += new_data 96 | random.setstate(old_state) # Recover the state of the random generator 97 | return data 98 | 99 | 100 | def pad_instr_tokens(instr_tokens, maxlength=20): 101 | 102 | if len(instr_tokens) <= 2: #assert len(raw_instr_tokens) > 2 103 | return None 104 | 105 | if len(instr_tokens) > maxlength - 2: # -2 for [CLS] and [SEP] 106 | instr_tokens = instr_tokens[:(maxlength-2)] 107 | 108 | instr_tokens = ['[CLS]'] + instr_tokens + ['[SEP]'] 109 | num_words = len(instr_tokens) # - 1 # include [SEP] 110 | instr_tokens += ['[PAD]'] * (maxlength-len(instr_tokens)) 111 | 112 | assert len(instr_tokens) == maxlength 113 | 114 | return instr_tokens, num_words 115 | 116 | 117 | class Tokenizer(object): 118 | ''' Class to tokenize and encode a sentence. 
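Rough usage sketch (assuming a word list built with build_vocab below):
    tok = Tokenizer(vocab=vocab, encoding_length=80)
    ids = tok.encode_sentence('walk past the sofa and stop at the door')
    text = tok.decode_sentence(ids)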
''' 119 | SENTENCE_SPLIT_REGEX = re.compile(r'(\W+)') # Split on any non-alphanumeric character 120 | 121 | def __init__(self, vocab=None, encoding_length=20): 122 | self.encoding_length = encoding_length 123 | self.vocab = vocab 124 | self.word_to_index = {} 125 | self.index_to_word = {} 126 | if vocab: 127 | for i,word in enumerate(vocab): 128 | self.word_to_index[word] = i 129 | new_w2i = defaultdict(lambda: self.word_to_index['']) 130 | new_w2i.update(self.word_to_index) 131 | self.word_to_index = new_w2i 132 | for key, value in self.word_to_index.items(): 133 | self.index_to_word[value] = key 134 | old = self.vocab_size() 135 | self.add_word('') 136 | assert self.vocab_size() == old+1 137 | print("OLD_VOCAB_SIZE", old) 138 | print("VOCAB_SIZE", self.vocab_size()) 139 | print("VOACB", len(vocab)) 140 | 141 | def finalize(self): 142 | """ 143 | This is used for debug 144 | """ 145 | self.word_to_index = dict(self.word_to_index) # To avoid using mis-typing tokens 146 | 147 | def add_word(self, word): 148 | assert word not in self.word_to_index 149 | self.word_to_index[word] = self.vocab_size() # vocab_size() is the 150 | self.index_to_word[self.vocab_size()] = word 151 | 152 | @staticmethod 153 | def split_sentence(sentence): 154 | ''' Break sentence into a list of words and punctuation ''' 155 | toks = [] 156 | for word in [s.strip().lower() for s in Tokenizer.SENTENCE_SPLIT_REGEX.split(sentence.strip()) if len(s.strip()) > 0]: 157 | # Break up any words containing punctuation only, e.g. '!?', unless it is multiple full stops e.g. '..' 158 | if all(c in string.punctuation for c in word) and not all(c in '.' for c in word): 159 | toks += list(word) 160 | else: 161 | toks.append(word) 162 | return toks 163 | 164 | def vocab_size(self): 165 | return len(self.index_to_word) 166 | 167 | def encode_sentence(self, sentence, max_length=None): 168 | if max_length is None: 169 | max_length = self.encoding_length 170 | if len(self.word_to_index) == 0: 171 | sys.exit('Tokenizer has no vocab') 172 | 173 | encoding = [self.word_to_index['']] 174 | for word in self.split_sentence(sentence): 175 | encoding.append(self.word_to_index[word]) # Default Dict 176 | encoding.append(self.word_to_index['']) 177 | 178 | if len(encoding) <= 2: 179 | return None 180 | #assert len(encoding) > 2 181 | 182 | if len(encoding) < max_length: 183 | encoding += [self.word_to_index['']] * (max_length-len(encoding)) # Padding 184 | elif len(encoding) > max_length: 185 | encoding[max_length - 1] = self.word_to_index[''] # Cut the length with EOS 186 | 187 | return np.array(encoding[:max_length]) 188 | 189 | def decode_sentence(self, encoding, length=None): 190 | sentence = [] 191 | if length is not None: 192 | encoding = encoding[:length] 193 | for ix in encoding: 194 | if ix == self.word_to_index['']: 195 | break 196 | else: 197 | sentence.append(self.index_to_word[ix]) 198 | return " ".join(sentence) 199 | 200 | def shrink(self, inst): 201 | """ 202 | :param inst: The id inst 203 | :return: Remove the potential and 204 | If no return empty list 205 | """ 206 | if len(inst) == 0: 207 | return inst 208 | end = np.argmax(np.array(inst) == self.word_to_index['']) # If no , return empty string 209 | if len(inst) > 1 and inst[0] == self.word_to_index['']: 210 | start = 1 211 | else: 212 | start = 0 213 | # print(inst, start, end) 214 | return inst[start: end] 215 | 216 | 217 | def build_vocab(splits=['train'], min_count=5, start_vocab=base_vocab): 218 | ''' Build a vocab, starting with base vocab containing a few useful tokens. 
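Words occurring fewer than min_count times across the given splits are discarded.
Illustrative usage (the output path is only an example):
    vocab = build_vocab(['train']); write_vocab(vocab, 'data/train_vocab.txt')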
''' 219 | count = Counter() 220 | t = Tokenizer() 221 | data = load_datasets(splits) 222 | for item in data: 223 | for instr in item['instructions']: 224 | count.update(t.split_sentence(instr)) 225 | vocab = list(start_vocab) 226 | for word,num in count.most_common(): 227 | if num >= min_count: 228 | vocab.append(word) 229 | else: 230 | break 231 | return vocab 232 | 233 | 234 | def write_vocab(vocab, path): 235 | print('Writing vocab of size %d to %s' % (len(vocab),path)) 236 | with open(path, 'w') as f: 237 | for word in vocab: 238 | f.write("%s\n" % word) 239 | 240 | 241 | def read_vocab(path): 242 | with open(path) as f: 243 | vocab = [word.strip() for word in f.readlines()] 244 | return vocab 245 | 246 | 247 | def asMinutes(s): 248 | m = math.floor(s / 60) 249 | s -= m * 60 250 | return '%dm %ds' % (m, s) 251 | 252 | 253 | def timeSince(since, percent): 254 | now = time.time() 255 | s = now - since 256 | es = s / (percent) 257 | rs = es - s 258 | return '%s (- %s)' % (asMinutes(s), asMinutes(rs)) 259 | 260 | def read_img_features(feature_store, test_only=False): 261 | import csv 262 | import base64 263 | from tqdm import tqdm 264 | 265 | print("Start loading the image feature ... (~50 seconds)") 266 | start = time.time() 267 | 268 | if "detectfeat" in args.features: 269 | views = int(args.features[10:]) 270 | else: 271 | views = 36 272 | 273 | args.views = views 274 | 275 | tsv_fieldnames = ['scanId', 'viewpointId', 'image_w', 'image_h', 'vfov', 'features'] 276 | 277 | if not test_only: 278 | features = {} 279 | with open(feature_store, "r") as tsv_in_file: # Open the tsv file. 280 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=tsv_fieldnames) 281 | for item in reader: 282 | long_id = item['scanId'] + "_" + item['viewpointId'] 283 | features[long_id] = np.frombuffer(base64.decodestring(item['features'].encode('ascii')), 284 | dtype=np.float32).reshape((views, -1)) # Feature of long_id is (36, 2048) 285 | else: 286 | features = None 287 | 288 | print("Finish Loading the image feature from %s in %0.4f seconds" % (feature_store, time.time() - start)) 289 | return features 290 | 291 | def read_candidates(candidates_store): 292 | import csv 293 | import base64 294 | from collections import defaultdict 295 | print("Start loading the candidate feature") 296 | 297 | start = time.time() 298 | 299 | TSV_FIELDNAMES = ['scanId', 'viewpointId', 'heading', 'elevation', 'next', 'pointId', 'idx', 'feature'] 300 | candidates = defaultdict(lambda: list()) 301 | items = 0 302 | with open(candidates_store, "r") as tsv_in_file: # Open the tsv file. 
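# Each TSV row describes one candidate viewpoint using the fields in TSV_FIELDNAMES above;
# 'feature' is a base64-encoded float32 vector and the numeric fields arrive as strings,
# hence the explicit float()/int() casts and the np.frombuffer decoding below.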
303 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=TSV_FIELDNAMES) 304 | for item in reader: 305 | long_id = item['scanId'] + "_" + item['viewpointId'] 306 | candidates[long_id].append( 307 | {'heading': float(item['heading']), 308 | 'elevation': float(item['elevation']), 309 | 'scanId': item['scanId'], 310 | 'viewpointId': item['next'], 311 | 'pointId': int(item['pointId']), 312 | 'idx': int(item['idx']) + 1, # Because a bug in the precompute code, here +1 is important 313 | 'feature': np.frombuffer( 314 | base64.decodestring(item['feature'].encode('ascii')), 315 | dtype=np.float32) 316 | } 317 | ) 318 | items += 1 319 | 320 | for long_id in candidates: 321 | assert (len(candidates[long_id])) != 0 322 | 323 | assert sum(len(candidate) for candidate in candidates.values()) == items 324 | 325 | # candidate = candidates[long_id] 326 | # print(candidate) 327 | print("Finish Loading the candidates from %s in %0.4f seconds" % (candidates_store, time.time() - start)) 328 | candidates = dict(candidates) 329 | return candidates 330 | 331 | def add_exploration(paths): 332 | explore = json.load(open("data/exploration.json", 'r')) 333 | inst2explore = {path['instr_id']: path['trajectory'] for path in explore} 334 | for path in paths: 335 | path['trajectory'] = inst2explore[path['instr_id']] + path['trajectory'] 336 | return paths 337 | 338 | def angle_feature(heading, elevation): 339 | 340 | import math 341 | # twopi = math.pi * 2 342 | # heading = (heading + twopi) % twopi # From 0 ~ 2pi 343 | # It will be the same 344 | return np.array([math.sin(heading), math.cos(heading), 345 | math.sin(elevation), math.cos(elevation)] * (args.angle_feat_size // 4), 346 | dtype=np.float32) 347 | 348 | def new_simulator(): 349 | import MatterSim 350 | # Simulator image parameters 351 | WIDTH = 640 352 | HEIGHT = 480 353 | VFOV = 60 354 | 355 | sim = MatterSim.Simulator() 356 | sim.setRenderingEnabled(False) 357 | sim.setCameraResolution(WIDTH, HEIGHT) 358 | sim.setCameraVFOV(math.radians(VFOV)) 359 | sim.setDiscretizedViewingAngles(True) 360 | sim.init() 361 | 362 | return sim 363 | 364 | def get_point_angle_feature(baseViewId=0): 365 | sim = new_simulator() 366 | 367 | feature = np.empty((36, args.angle_feat_size), np.float32) 368 | base_heading = (baseViewId % 12) * math.radians(30) 369 | for ix in range(36): 370 | if ix == 0: 371 | sim.newEpisode('ZMojNkEp431', '2f4d90acd4024c269fb0efe49a8ac540', 0, math.radians(-30)) 372 | elif ix % 12 == 0: 373 | sim.makeAction(0, 1.0, 1.0) 374 | else: 375 | sim.makeAction(0, 1.0, 0) 376 | 377 | state = sim.getState() 378 | assert state.viewIndex == ix 379 | 380 | heading = state.heading - base_heading 381 | 382 | feature[ix, :] = angle_feature(heading, state.elevation) 383 | return feature 384 | 385 | def get_all_point_angle_feature(): 386 | return [get_point_angle_feature(baseViewId) for baseViewId in range(36)] 387 | 388 | def add_idx(inst): 389 | toks = Tokenizer.split_sentence(inst) 390 | return " ".join([str(idx)+tok for idx, tok in enumerate(toks)]) 391 | 392 | import signal 393 | class GracefulKiller: 394 | kill_now = False 395 | def __init__(self): 396 | signal.signal(signal.SIGINT, self.exit_gracefully) 397 | signal.signal(signal.SIGTERM, self.exit_gracefully) 398 | 399 | def exit_gracefully(self,signum, frame): 400 | self.kill_now = True 401 | 402 | from collections import OrderedDict 403 | 404 | class Timer: 405 | def __init__(self): 406 | self.cul = OrderedDict() 407 | self.start = {} 408 | self.iter = 0 409 | 410 | def reset(self): 411 | 
self.cul = OrderedDict() 412 | self.start = {} 413 | self.iter = 0 414 | 415 | def tic(self, key): 416 | self.start[key] = time.time() 417 | 418 | def toc(self, key): 419 | delta = time.time() - self.start[key] 420 | if key not in self.cul: 421 | self.cul[key] = delta 422 | else: 423 | self.cul[key] += delta 424 | 425 | def step(self): 426 | self.iter += 1 427 | 428 | def show(self): 429 | total = sum(self.cul.values()) 430 | for key in self.cul: 431 | print("%s, total time %0.2f, avg time %0.2f, part of %0.2f" % 432 | (key, self.cul[key], self.cul[key]*1./self.iter, self.cul[key]*1./total)) 433 | print(total / self.iter) 434 | 435 | 436 | stop_word_list = [ 437 | ",", ".", "and", "?", "!" 438 | ] 439 | 440 | 441 | def stop_words_location(inst, mask=False): 442 | toks = Tokenizer.split_sentence(inst) 443 | sws = [i for i, tok in enumerate(toks) if tok in stop_word_list] # The index of the stop words 444 | if len(sws) == 0 or sws[-1] != (len(toks)-1): # Add the index of the last token 445 | sws.append(len(toks)-1) 446 | sws = [x for x, y in zip(sws[:-1], sws[1:]) if x+1 != y] + [sws[-1]] # Filter the adjacent stop word 447 | sws_mask = np.ones(len(toks), np.int32) # Create the mask 448 | sws_mask[sws] = 0 449 | return sws_mask if mask else sws 450 | 451 | def get_segments(inst, mask=False): 452 | toks = Tokenizer.split_sentence(inst) 453 | sws = [i for i, tok in enumerate(toks) if tok in stop_word_list] # The index of the stop words 454 | sws = [-1] + sws + [len(toks)] # Add the and positions 455 | segments = [toks[sws[i]+1:sws[i+1]] for i in range(len(sws)-1)] # Slice the segments from the tokens 456 | segments = list(filter(lambda x: len(x)>0, segments)) # remove the consecutive stop words 457 | return segments 458 | 459 | def clever_pad_sequence(sequences, batch_first=True, padding_value=0): 460 | max_size = sequences[0].size() 461 | max_len, trailing_dims = max_size[0], max_size[1:] 462 | max_len = max(seq.size()[0] for seq in sequences) 463 | if batch_first: 464 | out_dims = (len(sequences), max_len) + trailing_dims 465 | else: 466 | out_dims = (max_len, len(sequences)) + trailing_dims 467 | if padding_value is not None: 468 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) 469 | for i, tensor in enumerate(sequences): 470 | length = tensor.size(0) 471 | # use index notation to prevent duplicate references to the tensor 472 | if batch_first: 473 | out_tensor[i, :length, ...] = tensor 474 | else: 475 | out_tensor[:length, i, ...] 
= tensor 476 | 477 | return out_tensor 478 | 479 | import torch 480 | def length2mask(length, size=None): 481 | batch_size = len(length) 482 | size = int(max(length)) if size is None else size 483 | mask = (torch.arange(size, dtype=torch.int64).unsqueeze(0).repeat(batch_size, 1) 484 | > (torch.LongTensor(length) - 1).unsqueeze(1)).cuda() 485 | return mask 486 | 487 | def average_length(path2inst): 488 | length = [] 489 | 490 | for name in path2inst: 491 | datum = path2inst[name] 492 | length.append(len(datum)) 493 | return sum(length) / len(length) 494 | 495 | def tile_batch(tensor, multiplier): 496 | _, *s = tensor.size() 497 | tensor = tensor.unsqueeze(1).expand(-1, multiplier, *(-1,) * len(s)).contiguous().view(-1, *s) 498 | return tensor 499 | 500 | def viewpoint_drop_mask(viewpoint, seed=None, drop_func=None): 501 | local_seed = hash(viewpoint) ^ seed 502 | torch.random.manual_seed(local_seed) 503 | drop_mask = drop_func(torch.ones(2048).cuda()) 504 | return drop_mask 505 | 506 | 507 | class FloydGraph: 508 | def __init__(self): 509 | self._dis = defaultdict(lambda :defaultdict(lambda: 95959595)) 510 | self._point = defaultdict(lambda :defaultdict(lambda: "")) 511 | self._visited = set() 512 | 513 | def distance(self, x, y): 514 | if x == y: 515 | return 0 516 | else: 517 | return self._dis[x][y] 518 | 519 | def add_edge(self, x, y, dis): 520 | if dis < self._dis[x][y]: 521 | self._dis[x][y] = dis 522 | self._dis[y][x] = dis 523 | self._point[x][y] = "" 524 | self._point[y][x] = "" 525 | 526 | def update(self, k): 527 | for x in self._dis: 528 | for y in self._dis: 529 | if x != y: 530 | if self._dis[x][k] + self._dis[k][y] < self._dis[x][y]: 531 | self._dis[x][y] = self._dis[x][k] + self._dis[k][y] 532 | self._dis[y][x] = self._dis[x][y] 533 | self._point[x][y] = k 534 | self._point[y][x] = k 535 | self._visited.add(k) 536 | 537 | def visited(self, k): 538 | return (k in self._visited) 539 | 540 | def path(self, x, y): 541 | """ 542 | :param x: start 543 | :param y: end 544 | :return: the path from x to y [v1, v2, ..., v_n, y] 545 | """ 546 | if x == y: 547 | return [] 548 | if self._point[x][y] == "": # Direct edge 549 | return [y] 550 | else: 551 | k = self._point[x][y] 552 | # print(x, y, k) 553 | # for x1 in (x, k, y): 554 | # for x2 in (x, k, y): 555 | # print(x1, x2, "%.4f" % self._dis[x1][x2]) 556 | return self.path(x, k) + self.path(k, y) 557 | 558 | def print_progress(iteration, total, prefix='', suffix='', decimals=1, bar_length=100): 559 | """ 560 | Call in a loop to create terminal progress bar 561 | @params: 562 | iteration - Required : current iteration (Int) 563 | total - Required : total iterations (Int) 564 | prefix - Optional : prefix string (Str) 565 | suffix - Optional : suffix string (Str) 566 | decimals - Optional : positive number of decimals in percent complete (Int) 567 | bar_length - Optional : character length of bar (Int) 568 | """ 569 | str_format = "{0:." 
+ str(decimals) + "f}" 570 | percents = str_format.format(100 * (iteration / float(total))) 571 | filled_length = int(round(bar_length * iteration / float(total))) 572 | bar = '█' * filled_length + '-' * (bar_length - filled_length) 573 | 574 | sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), 575 | 576 | if iteration == total: 577 | sys.stdout.write('\n') 578 | sys.stdout.flush() 579 | 580 | def ndtw_initialize(): 581 | ndtw_criterion = {} 582 | scan_gts_dir = 'data/id_paths.json' 583 | with open(scan_gts_dir) as f_: 584 | scan_gts = json.load(f_) 585 | all_scan_ids = [] 586 | for key in scan_gts: 587 | path_scan_id = scan_gts[key][0] 588 | # print('path_scan_id', path_scan_id) 589 | if path_scan_id not in all_scan_ids: 590 | all_scan_ids.append(path_scan_id) 591 | ndtw_graph = ndtw_graphload(path_scan_id) 592 | ndtw_criterion[path_scan_id] = DTW(ndtw_graph) 593 | return ndtw_criterion 594 | 595 | def ndtw_graphload(scan): 596 | """Loads a networkx graph for a given scan. 597 | Args: 598 | connections_file: A string with the path to the .json file with the 599 | connectivity information. 600 | Returns: 601 | A networkx graph. 602 | """ 603 | connections_file = 'connectivity/{}_connectivity.json'.format(scan) 604 | with open(connections_file) as f: 605 | lines = json.load(f) 606 | nodes = np.array([x['image_id'] for x in lines]) 607 | matrix = np.array([x['unobstructed'] for x in lines]) 608 | mask = np.array([x['included'] for x in lines]) 609 | 610 | matrix = matrix[mask][:, mask] 611 | nodes = nodes[mask] 612 | 613 | pos2d = {x['image_id']: np.array(x['pose'])[[3, 7]] for x in lines} 614 | pos3d = {x['image_id']: np.array(x['pose'])[[3, 7, 11]] for x in lines} 615 | 616 | graph = nx.from_numpy_matrix(matrix) 617 | graph = nx.relabel.relabel_nodes(graph, dict(enumerate(nodes))) 618 | nx.set_node_attributes(graph, pos2d, 'pos2d') 619 | nx.set_node_attributes(graph, pos3d, 'pos3d') 620 | 621 | weight2d = {(u, v): norm(pos2d[u] - pos2d[v]) for u, v in graph.edges} 622 | weight3d = {(u, v): norm(pos3d[u] - pos3d[v]) for u, v in graph.edges} 623 | nx.set_edge_attributes(graph, weight2d, 'weight2d') 624 | nx.set_edge_attributes(graph, weight3d, 'weight3d') 625 | 626 | return graph 627 | 628 | class DTW(object): 629 | """Dynamic Time Warping (DTW) evaluation metrics. 630 | Python doctest: 631 | >>> graph = nx.grid_graph([3, 4]) 632 | >>> prediction = [(0, 0), (1, 0), (2, 0), (3, 0)] 633 | >>> reference = [(0, 0), (1, 0), (2, 1), (3, 2)] 634 | >>> dtw = DTW(graph) 635 | >>> assert np.isclose(dtw(prediction, reference, 'dtw'), 3.0) 636 | >>> assert np.isclose(dtw(prediction, reference, 'ndtw'), 0.77880078307140488) 637 | >>> assert np.isclose(dtw(prediction, reference, 'sdtw'), 0.77880078307140488) 638 | >>> assert np.isclose(dtw(prediction[:2], reference, 'sdtw'), 0.0) 639 | """ 640 | 641 | def __init__(self, graph, weight='weight', threshold=3.0): 642 | """Initializes a DTW object. 643 | Args: 644 | graph: networkx graph for the environment. 645 | weight: networkx edge weight key (str). 646 | threshold: distance threshold $d_{th}$ (float). 647 | """ 648 | self.graph = graph 649 | self.weight = weight 650 | self.threshold = threshold 651 | self.distance = dict( 652 | nx.all_pairs_dijkstra_path_length(self.graph, weight=self.weight)) 653 | 654 | def __call__(self, prediction, reference, metric='sdtw'): 655 | """Computes DTW metrics. 656 | Args: 657 | prediction: list of nodes (str), path predicted by agent. 658 | reference: list of nodes (str), the ground truth path. 
659 | metric: one of ['ndtw', 'sdtw', 'dtw']. 660 | Returns: 661 | the DTW between the prediction and reference path (float). 662 | """ 663 | assert metric in ['ndtw', 'sdtw', 'dtw'] 664 | 665 | dtw_matrix = np.inf * np.ones((len(prediction) + 1, len(reference) + 1)) 666 | dtw_matrix[0][0] = 0 667 | for i in range(1, len(prediction)+1): 668 | for j in range(1, len(reference)+1): 669 | best_previous_cost = min( 670 | dtw_matrix[i-1][j], dtw_matrix[i][j-1], dtw_matrix[i-1][j-1]) 671 | cost = self.distance[prediction[i-1]][reference[j-1]] 672 | dtw_matrix[i][j] = cost + best_previous_cost 673 | dtw = dtw_matrix[len(prediction)][len(reference)] 674 | 675 | if metric == 'dtw': 676 | return dtw 677 | 678 | ndtw = np.exp(-dtw/(self.threshold * len(reference))) 679 | if metric == 'ndtw': 680 | return ndtw 681 | 682 | success = self.distance[prediction[-1]][reference[-1]] <= self.threshold 683 | return success * ndtw 684 | -------------------------------------------------------------------------------- /r2r_src/agent.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | import os 4 | import sys 5 | import numpy as np 6 | import random 7 | import math 8 | import time 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.autograd import Variable 13 | from torch import optim 14 | import torch.nn.functional as F 15 | 16 | from env import R2RBatch 17 | import utils 18 | from utils import padding_idx, print_progress 19 | import model_OSCAR, model_PREVALENT 20 | import param 21 | from param import args 22 | from collections import defaultdict 23 | 24 | 25 | class BaseAgent(object): 26 | ''' Base class for an R2R agent to generate and save trajectories. ''' 27 | 28 | def __init__(self, env, results_path): 29 | self.env = env 30 | self.results_path = results_path 31 | random.seed(1) 32 | self.results = {} 33 | self.losses = [] # For learning agents 34 | 35 | def write_results(self): 36 | output = [{'instr_id':k, 'trajectory': v} for k,v in self.results.items()] 37 | with open(self.results_path, 'w') as f: 38 | json.dump(output, f) 39 | 40 | def get_results(self): 41 | output = [{'instr_id': k, 'trajectory': v} for k, v in self.results.items()] 42 | return output 43 | 44 | def rollout(self, **args): 45 | ''' Return a list of dicts containing instr_id:'xx', path:[(viewpointId, heading_rad, elevation_rad)] ''' 46 | raise NotImplementedError 47 | 48 | @staticmethod 49 | def get_agent(name): 50 | return globals()[name+"Agent"] 51 | 52 | def test(self, iters=None, **kwargs): 53 | self.env.reset_epoch(shuffle=(iters is not None)) # If iters is not none, shuffle the env batch 54 | self.losses = [] 55 | self.results = {} 56 | # We rely on env showing the entire batch before repeating anything 57 | looped = False 58 | self.loss = 0 59 | if iters is not None: 60 | # For each time, it will run the first 'iters' iterations. (It was shuffled before) 61 | for i in range(iters): 62 | for traj in self.rollout(**kwargs): 63 | self.loss = 0 64 | self.results[traj['instr_id']] = traj['path'] 65 | else: # Do a full round 66 | while True: 67 | for traj in self.rollout(**kwargs): 68 | if traj['instr_id'] in self.results: 69 | looped = True 70 | else: 71 | self.loss = 0 72 | self.results[traj['instr_id']] = traj['path'] 73 | if looped: 74 | break 75 | 76 | 77 | class Seq2SeqAgent(BaseAgent): 78 | ''' An agent based on an LSTM seq2seq model with attention. 
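Note: despite the legacy name, in this codebase the agent wraps the recurrent VLN-BERT
models (model_OSCAR.VLNBERT / model_PREVALENT.VLNBERT, built in __init__ below) rather
than an LSTM encoder-decoder.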
''' 79 | 80 | # For now, the agent can't pick which forward move to make - just the one in the middle 81 | env_actions = { 82 | 'left': (0,-1, 0), # left 83 | 'right': (0, 1, 0), # right 84 | 'up': (0, 0, 1), # up 85 | 'down': (0, 0,-1), # down 86 | 'forward': (1, 0, 0), # forward 87 | '<end>': (0, 0, 0), # <end> 88 | '<start>': (0, 0, 0), # <start> 89 | '<ignore>': (0, 0, 0) # <ignore> 90 | } 91 | 92 | def __init__(self, env, results_path, tok, episode_len=20, feat_dict=None): 93 | super(Seq2SeqAgent, self).__init__(env, results_path) 94 | self.tok = tok 95 | self.episode_len = episode_len 96 | self.feature_size = self.env.feature_size 97 | self.feat_dict = feat_dict 98 | 99 | # Models 100 | if args.vlnbert == 'oscar': 101 | self.vln_bert = model_OSCAR.VLNBERT(feature_size=self.feature_size + args.angle_feat_size).cuda() 102 | self.critic = model_OSCAR.Critic().cuda() 103 | elif args.vlnbert == 'prevalent': 104 | if args.features == 'clip': 105 | self.vln_bert = model_PREVALENT.VLNBERT(feature_size=2048 + args.angle_feat_size).cuda() 106 | else: 107 | self.vln_bert = model_PREVALENT.VLNBERT(feature_size=self.feature_size + args.angle_feat_size).cuda() 108 | self.critic = model_PREVALENT.Critic().cuda() 109 | self.models = (self.vln_bert, self.critic) 110 | 111 | # Optimizers 112 | self.vln_bert_optimizer = args.optimizer(self.vln_bert.parameters(), lr=args.lr) 113 | self.critic_optimizer = args.optimizer(self.critic.parameters(), lr=args.lr) 114 | self.optimizers = (self.vln_bert_optimizer, self.critic_optimizer) 115 | 116 | # Evaluations 117 | self.losses = [] 118 | self.criterion = nn.CrossEntropyLoss(ignore_index=args.ignoreid, size_average=False) 119 | self.ndtw_criterion = utils.ndtw_initialize() 120 | 121 | if args.ADAPT: 122 | self.align_loss = nn.CrossEntropyLoss(ignore_index=-1000, size_average=True) 123 | self.consistency_loss = nn.MSELoss(size_average=True) 124 | 125 | txt_subprompt_feat_file = "data/R2R_txt_subprompt_feature.json" 126 | with open(txt_subprompt_feat_file) as f: 127 | self.txt_subprompt_feat = json.load(f) 128 | for k, _ in self.txt_subprompt_feat.items(): 129 | self.txt_subprompt_feat[k] = np.array(self.txt_subprompt_feat[k]) 130 | 131 | # Logs 132 | sys.stdout.flush() 133 | self.logs = defaultdict(list) 134 | 135 | def _sort_batch(self, obs): 136 | seq_tensor = np.array([ob['instr_encoding'] for ob in obs]) 137 | seq_lengths = np.argmax(seq_tensor == padding_idx, axis=1) 138 | seq_lengths[seq_lengths == 0] = seq_tensor.shape[1] 139 | 140 | seq_tensor = torch.from_numpy(seq_tensor) 141 | seq_lengths = torch.from_numpy(seq_lengths) 142 | 143 | # Sort sequences by lengths 144 | seq_lengths, perm_idx = seq_lengths.sort(0, True) # True -> descending 145 | sorted_tensor = seq_tensor[perm_idx] 146 | mask = (sorted_tensor != padding_idx) 147 | 148 | token_type_ids = torch.zeros_like(mask) 149 | 150 | return Variable(sorted_tensor, requires_grad=False).long().cuda(), \ 151 | mask.long().cuda(), token_type_ids.long().cuda(), \ 152 | list(seq_lengths), list(perm_idx) 153 | 154 | def _feature_variable(self, obs): 155 | ''' Extract precomputed features into variable.
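Returns a cuda Variable of shape (batch, args.views, feature_size + angle_feat_size), one row per discretised panoramic viewing direction.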
''' 156 | features = np.empty((len(obs), args.views, self.feature_size + args.angle_feat_size), dtype=np.float32) 157 | for i, ob in enumerate(obs): 158 | features[i, :, :] = ob['feature'] # Image feat 159 | return Variable(torch.from_numpy(features), requires_grad=False).cuda() 160 | 161 | def _candidate_variable(self, obs): 162 | candidate_leng = [len(ob['candidate']) + 1 for ob in obs] # +1 is for the end 163 | candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32) 164 | # Note: The candidate_feat at len(ob['candidate']) is the feature for the END 165 | # which is zero in my implementation 166 | for i, ob in enumerate(obs): 167 | for j, cc in enumerate(ob['candidate']): 168 | candidate_feat[i, j, :] = cc['feature'] 169 | 170 | return torch.from_numpy(candidate_feat).cuda(), candidate_leng 171 | 172 | def get_input_feat(self, obs): 173 | input_a_t = np.zeros((len(obs), args.angle_feat_size), np.float32) 174 | for i, ob in enumerate(obs): 175 | input_a_t[i] = utils.angle_feature(ob['heading'], ob['elevation']) 176 | input_a_t = torch.from_numpy(input_a_t).cuda() 177 | # f_t = self._feature_variable(obs) # Pano image features from obs 178 | candidate_feat, candidate_leng = self._candidate_variable(obs) 179 | 180 | return input_a_t, candidate_feat, candidate_leng 181 | 182 | def _teacher_action(self, obs, ended): 183 | """ 184 | Extract teacher actions into variable. 185 | :param obs: The observation. 186 | :param ended: Whether the action seq is ended 187 | :return: 188 | """ 189 | a = np.zeros(len(obs), dtype=np.int64) 190 | for i, ob in enumerate(obs): 191 | if ended[i]: # Just ignore this index 192 | a[i] = args.ignoreid 193 | else: 194 | for k, candidate in enumerate(ob['candidate']): 195 | if candidate['viewpointId'] == ob['teacher']: # Next view point 196 | a[i] = k 197 | break 198 | else: # Stop here 199 | assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE" 200 | a[i] = len(ob['candidate']) 201 | return torch.from_numpy(a).cuda() 202 | 203 | def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None): 204 | """ 205 | Interface between Panoramic view and Egocentric view 206 | It will convert the action panoramic view action a_t to equivalent egocentric view actions for the simulator 207 | """ 208 | def take_action(i, idx, name): 209 | if type(name) is int: # Go to the next view 210 | self.env.env.sims[idx].makeAction(name, 0, 0) 211 | else: # Adjust 212 | self.env.env.sims[idx].makeAction(*self.env_actions[name]) 213 | 214 | if perm_idx is None: 215 | perm_idx = range(len(perm_obs)) 216 | 217 | for i, idx in enumerate(perm_idx): 218 | action = a_t[i] 219 | if action != -1: # -1 is the action 220 | select_candidate = perm_obs[i]['candidate'][action] 221 | src_point = perm_obs[i]['viewIndex'] 222 | trg_point = select_candidate['pointId'] 223 | src_level = (src_point ) // 12 # The point idx started from 0 224 | trg_level = (trg_point ) // 12 225 | while src_level < trg_level: # Tune up 226 | take_action(i, idx, 'up') 227 | src_level += 1 228 | while src_level > trg_level: # Tune down 229 | take_action(i, idx, 'down') 230 | src_level -= 1 231 | while self.env.env.sims[idx].getState().viewIndex != trg_point: # Turn right until the target 232 | take_action(i, idx, 'right') 233 | assert select_candidate['viewpointId'] == \ 234 | self.env.env.sims[idx].getState().navigableLocations[select_candidate['idx']].viewpointId 235 | take_action(i, idx, select_candidate['idx']) 236 | 237 | state = 
self.env.env.sims[idx].getState() 238 | if traj is not None: 239 | traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation)) 240 | 241 | def rollout(self, train_ml=None, train_rl=True, reset=True): 242 | """ 243 | :param train_ml: The weight to train with maximum likelihood 244 | :param train_rl: whether use RL in training 245 | :param reset: Reset the environment 246 | 247 | :return: 248 | """ 249 | if self.feedback == 'teacher' or self.feedback == 'argmax': 250 | train_rl = False 251 | 252 | if reset: # Reset env 253 | obs = np.array(self.env.reset()) 254 | else: 255 | obs = np.array(self.env._get_obs()) 256 | 257 | batch_size = len(obs) 258 | 259 | # Language input 260 | sentence, language_attention_mask, token_type_ids, \ 261 | seq_lengths, perm_idx = self._sort_batch(obs) 262 | perm_obs = obs[perm_idx] 263 | 264 | if args.ADAPT: 265 | img_sub_prompt_set = torch.zeros((batch_size, args.prompt_set_size, 512)).cuda() 266 | txt_sub_prompt_set = torch.zeros((batch_size, args.prompt_set_size, 512)).cuda() 267 | prompt_mask_set = torch.zeros((batch_size, args.prompt_set_size)) 268 | 269 | for bs_ind in range(batch_size): 270 | img_sub_prompt = perm_obs[bs_ind]['img_sub_prompt'] 271 | txt_sub_prompt = perm_obs[bs_ind]['txt_sub_prompt'] 272 | for sub_prompt_ind in range(len(img_sub_prompt)): 273 | scan_viewpoint_view = img_sub_prompt[sub_prompt_ind].split('_') 274 | scan_viewpoint = scan_viewpoint_view[0] + '_' + scan_viewpoint_view[1] 275 | 276 | img_sub_prompt_set[bs_ind, sub_prompt_ind, :] = torch.from_numpy( 277 | self.feat_dict[scan_viewpoint][int(scan_viewpoint_view[2])]).cuda().float() 278 | txt_sub_prompt_set[bs_ind, sub_prompt_ind, :] = torch.from_numpy(self.txt_subprompt_feat[ 279 | txt_sub_prompt[ 280 | sub_prompt_ind].split( 281 | '_')[-2]][int(txt_sub_prompt[sub_prompt_ind].split('_')[-1])]).cuda().float() 282 | prompt_num = perm_obs[bs_ind]['prompt_num'] 283 | prompt_mask_set[bs_ind, :prompt_num] = 1 284 | 285 | ''' Language BERT ''' 286 | language_inputs = {'mode': 'language', 287 | 'sentence': sentence, 288 | 'attention_mask': language_attention_mask, 289 | 'lang_mask': language_attention_mask, 290 | 'token_type_ids': token_type_ids} 291 | if args.vlnbert == 'oscar': 292 | language_features = self.vln_bert(**language_inputs) 293 | elif args.vlnbert == 'prevalent': 294 | h_t, language_features = self.vln_bert(**language_inputs) 295 | 296 | # Record starting point 297 | traj = [{ 298 | 'instr_id': ob['instr_id'], 299 | 'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])], 300 | } for ob in perm_obs] 301 | 302 | # Init the reward shaping 303 | last_dist = np.zeros(batch_size, np.float32) 304 | last_ndtw = np.zeros(batch_size, np.float32) 305 | for i, ob in enumerate(perm_obs): # The init distance from the view point to the target 306 | last_dist[i] = ob['distance'] 307 | path_act = [vp[0] for vp in traj[i]['path']] 308 | last_ndtw[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw') 309 | 310 | # Initialization the tracking state 311 | ended = np.array([False] * batch_size) # Indices match permuation of the model, not env 312 | 313 | # Init the logs 314 | rewards = [] 315 | hidden_states = [] 316 | policy_log_probs = [] 317 | masks = [] 318 | entropys = [] 319 | ml_loss = 0. 
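# Main decoding loop (below): at each step the current heading/elevation encoding and the candidate
# features are passed to the visual BERT together with the language features, whose first token
# carries the recurrent agent state. With --ADAPT, the text/image sub-prompt sets are fed in as
# extra inputs; a contrastive alignment loss between the projected text and image sub-prompts is
# accumulated under teacher forcing, and an MSE consistency loss between the attended
# language/vision features and the attended prompts is added on top of the usual IL/RL objectives.
# The chosen action follows self.feedback (teacher / argmax / sample); with train_rl, the shaped
# reward below combines goal-distance progress with changes in nDTW.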
320 | 321 | if args.ADAPT: 322 | align_loss = 0 323 | consistency_loss = 0 324 | 325 | for t in range(self.episode_len): 326 | 327 | input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs) 328 | 329 | # the first [CLS] token, initialized by the language BERT, serves 330 | # as the agent's state passing through time steps 331 | if (t >= 1) or (args.vlnbert=='prevalent'): 332 | language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1) 333 | 334 | visual_temp_mask = (utils.length2mask(candidate_leng) == 0).long() 335 | visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1) 336 | 337 | self.vln_bert.vln_bert.config.directions = max(candidate_leng) 338 | 339 | ''' Visual BERT ''' 340 | visual_inputs = {'mode': 'visual', 341 | 'sentence': language_features, 342 | 'attention_mask': visual_attention_mask, 343 | 'lang_mask': language_attention_mask, 344 | 'vis_mask': visual_temp_mask, 345 | 'token_type_ids': token_type_ids, 346 | 'action_feats': input_a_t, 347 | # 'pano_feats': f_t, 348 | 'cand_feats': candidate_feat} 349 | 350 | if args.ADAPT: 351 | visual_inputs = {'mode': 'visual', 352 | 'sentence': language_features, 353 | 'attention_mask': visual_attention_mask, 354 | 'lang_mask': language_attention_mask, 355 | 'vis_mask': visual_temp_mask, 356 | 'token_type_ids': token_type_ids, 357 | 'action_feats': input_a_t, 358 | # 'pano_feats': f_t, 359 | 'cand_feats': candidate_feat, 360 | 'txt_sub_prompt_set': txt_sub_prompt_set, 361 | 'img_sub_prompt_set': img_sub_prompt_set, 362 | 'prompt_mask_set': prompt_mask_set 363 | } 364 | 365 | h_t, logit, attended_lan, attended_vis, attended_txt_pro, attended_img_pro, txt_pro, img_pro = self.vln_bert( 366 | **visual_inputs) 367 | # alignment loss 368 | if self.feedback == 'teacher': 369 | for loss_ind in range(args.prompt_set_size): 370 | sim_labels = np.arange(args.batchSize) 371 | sim_labels = torch.from_numpy(sim_labels).long().cuda() 372 | for bs_ind in range(batch_size): 373 | if loss_ind + 1 > perm_obs[bs_ind]['prompt_num']: 374 | sim_labels[bs_ind] = -1000 375 | sim_matrix = torch.mm(txt_pro[:, loss_ind, :], img_pro[:, loss_ind, :].transpose(0, 1)) 376 | sim_matrix /= args.temperature 377 | loss_I = self.align_loss(sim_matrix, sim_labels) 378 | loss_V = self.align_loss(sim_matrix.transpose(0, 1), sim_labels) 379 | align_loss += (loss_I + loss_V) / 2 380 | 381 | # consistency loss 382 | consistency_loss += self.consistency_loss(attended_lan, attended_txt_pro) 383 | consistency_loss += self.consistency_loss(attended_vis, attended_img_pro) 384 | else: 385 | h_t, logit = self.vln_bert(**visual_inputs) 386 | hidden_states.append(h_t) 387 | 388 | # Mask outputs where agent can't move forward 389 | # Here the logit is [b, max_candidate] 390 | candidate_mask = utils.length2mask(candidate_leng) 391 | logit.masked_fill_(candidate_mask, -float('inf')) 392 | 393 | # Supervised training 394 | target = self._teacher_action(perm_obs, ended) 395 | ml_loss += self.criterion(logit, target) 396 | 397 | # Determine next model inputs 398 | if self.feedback == 'teacher': 399 | a_t = target # teacher forcing 400 | elif self.feedback == 'argmax': 401 | _, a_t = logit.max(1) # student forcing - argmax 402 | a_t = a_t.detach() 403 | log_probs = F.log_softmax(logit, 1) # Calculate the log_prob here 404 | policy_log_probs.append(log_probs.gather(1, a_t.unsqueeze(1))) # Gather the log_prob for each batch 405 | elif self.feedback == 'sample': 406 | probs = F.softmax(logit, 1) # sampling an action from model 
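# A categorical distribution over the candidate directions is built from the masked logits; its
# entropy is logged and later added (scaled by 0.01) as an exploration bonus in the A2C loss,
# while log_prob(a_t) supplies the policy-gradient term.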
407 | c = torch.distributions.Categorical(probs) 408 | self.logs['entropy'].append(c.entropy().sum().item()) # For log 409 | entropys.append(c.entropy()) # For optimization 410 | a_t = c.sample().detach() 411 | policy_log_probs.append(c.log_prob(a_t)) 412 | else: 413 | print(self.feedback) 414 | sys.exit('Invalid feedback option') 415 | # Prepare environment action 416 | # NOTE: Env action is in the perm_obs space 417 | cpu_a_t = a_t.cpu().numpy() 418 | for i, next_id in enumerate(cpu_a_t): 419 | if next_id == (candidate_leng[i]-1) or next_id == args.ignoreid or ended[i]: # The last action is 420 | cpu_a_t[i] = -1 # Change the and ignore action to -1 421 | 422 | # Make action and get the new state 423 | self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj) 424 | obs = np.array(self.env._get_obs()) 425 | perm_obs = obs[perm_idx] # Perm the obs for the resu 426 | 427 | if train_rl: 428 | # Calculate the mask and reward 429 | dist = np.zeros(batch_size, np.float32) 430 | ndtw_score = np.zeros(batch_size, np.float32) 431 | reward = np.zeros(batch_size, np.float32) 432 | mask = np.ones(batch_size, np.float32) 433 | for i, ob in enumerate(perm_obs): 434 | dist[i] = ob['distance'] 435 | path_act = [vp[0] for vp in traj[i]['path']] 436 | ndtw_score[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw') 437 | 438 | if ended[i]: 439 | reward[i] = 0.0 440 | mask[i] = 0.0 441 | else: 442 | action_idx = cpu_a_t[i] 443 | # Target reward 444 | if action_idx == -1: # If the action now is end 445 | if dist[i] < 3.0: # Correct 446 | reward[i] = 2.0 + ndtw_score[i] * 2.0 447 | else: # Incorrect 448 | reward[i] = -2.0 449 | else: # The action is not end 450 | # Path fidelity rewards (distance & nDTW) 451 | reward[i] = - (dist[i] - last_dist[i]) 452 | ndtw_reward = ndtw_score[i] - last_ndtw[i] 453 | if reward[i] > 0.0: # Quantification 454 | reward[i] = 1.0 + ndtw_reward 455 | elif reward[i] < 0.0: 456 | reward[i] = -1.0 + ndtw_reward 457 | else: 458 | raise NameError("The action doesn't change the move") 459 | # Miss the target penalty 460 | if (last_dist[i] <= 1.0) and (dist[i]-last_dist[i] > 0.0): 461 | reward[i] -= (1.0 - last_dist[i]) * 2.0 462 | rewards.append(reward) 463 | masks.append(mask) 464 | last_dist[:] = dist 465 | last_ndtw[:] = ndtw_score 466 | 467 | # Update the finished actions 468 | # -1 means ended or ignored (already ended) 469 | ended[:] = np.logical_or(ended, (cpu_a_t == -1)) 470 | 471 | # Early exit if all ended 472 | if ended.all(): 473 | break 474 | 475 | if train_rl: 476 | # Last action in A2C 477 | input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs) 478 | 479 | language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1) 480 | 481 | visual_temp_mask = (utils.length2mask(candidate_leng) == 0).long() 482 | visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1) 483 | 484 | self.vln_bert.vln_bert.config.directions = max(candidate_leng) 485 | ''' Visual BERT ''' 486 | visual_inputs = {'mode': 'visual', 487 | 'sentence': language_features, 488 | 'attention_mask': visual_attention_mask, 489 | 'lang_mask': language_attention_mask, 490 | 'vis_mask': visual_temp_mask, 491 | 'token_type_ids': token_type_ids, 492 | 'action_feats': input_a_t, 493 | # 'pano_feats': f_t, 494 | 'cand_feats': candidate_feat} 495 | 496 | if args.ADAPT: 497 | visual_inputs = {'mode': 'visual', 498 | 'sentence': language_features, 499 | 'attention_mask': visual_attention_mask, 500 | 'lang_mask': 
language_attention_mask, 501 | 'vis_mask': visual_temp_mask, 502 | 'token_type_ids': token_type_ids, 503 | 'action_feats': input_a_t, 504 | # 'pano_feats': f_t, 505 | 'cand_feats': candidate_feat, 506 | 'txt_sub_prompt_set': txt_sub_prompt_set, 507 | 'img_sub_prompt_set': img_sub_prompt_set, 508 | 'prompt_mask_set': prompt_mask_set 509 | } 510 | last_h_, _, _, _, _, _, _, _ = self.vln_bert(**visual_inputs) 511 | else: 512 | last_h_, _ = self.vln_bert(**visual_inputs) 513 | 514 | rl_loss = 0. 515 | 516 | # NOW, A2C!!! 517 | # Calculate the final discounted reward 518 | last_value__ = self.critic(last_h_).detach() # The value esti of the last state, remove the grad for safety 519 | discount_reward = np.zeros(batch_size, np.float32) # The inital reward is zero 520 | for i in range(batch_size): 521 | if not ended[i]: # If the action is not ended, use the value function as the last reward 522 | discount_reward[i] = last_value__[i] 523 | 524 | length = len(rewards) 525 | total = 0 526 | for t in range(length-1, -1, -1): 527 | discount_reward = discount_reward * args.gamma + rewards[t] # If it ended, the reward will be 0 528 | mask_ = Variable(torch.from_numpy(masks[t]), requires_grad=False).cuda() 529 | clip_reward = discount_reward.copy() 530 | r_ = Variable(torch.from_numpy(clip_reward), requires_grad=False).cuda() 531 | v_ = self.critic(hidden_states[t]) 532 | a_ = (r_ - v_).detach() 533 | 534 | rl_loss += (-policy_log_probs[t] * a_ * mask_).sum() 535 | rl_loss += (((r_ - v_) ** 2) * mask_).sum() * 0.5 # 1/2 L2 loss 536 | if self.feedback == 'sample': 537 | rl_loss += (- 0.01 * entropys[t] * mask_).sum() 538 | self.logs['critic_loss'].append((((r_ - v_) ** 2) * mask_).sum().item()) 539 | 540 | total = total + np.sum(masks[t]) 541 | self.logs['total'].append(total) 542 | 543 | # Normalize the loss function 544 | if args.normalize_loss == 'total': 545 | rl_loss /= total 546 | elif args.normalize_loss == 'batch': 547 | rl_loss /= batch_size 548 | else: 549 | assert args.normalize_loss == 'none' 550 | 551 | self.loss += rl_loss 552 | self.logs['RL_loss'].append(rl_loss.item()) 553 | 554 | if train_ml is not None: 555 | self.loss += ml_loss * train_ml / batch_size 556 | self.logs['IL_loss'].append((ml_loss * train_ml / batch_size).item()) 557 | 558 | if args.ADAPT and self.feedback != 'argmax': 559 | self.loss += consistency_loss * args.consistency_loss_weight 560 | self.logs['consistency_loss'].append((consistency_loss * args.consistency_loss_weight).item()) 561 | 562 | if self.feedback == 'teacher': 563 | self.loss += align_loss * args.align_loss_weight 564 | self.logs['alignment_loss'].append((align_loss * args.align_loss_weight).item()) 565 | 566 | 567 | if type(self.loss) is int: # For safety, it will be activated if no losses are added 568 | self.losses.append(0.) 569 | else: 570 | self.losses.append(self.loss.item() / self.episode_len) # This argument is useless. 571 | 572 | return traj 573 | 574 | def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None): 575 | ''' Evaluate once on each instruction in the current environment ''' 576 | self.feedback = feedback 577 | if use_dropout: 578 | self.vln_bert.train() 579 | self.critic.train() 580 | else: 581 | self.vln_bert.eval() 582 | self.critic.eval() 583 | super(Seq2SeqAgent, self).test(iters) 584 | 585 | def zero_grad(self): 586 | self.loss = 0. 
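# zero_grad(), accumulate_gradient() and optim_step() below form the manual accumulation path:
# IL and RL rollouts add into self.loss before a single gradient-clipped optimiser step is taken.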
587 | self.losses = [] 588 | for model, optimizer in zip(self.models, self.optimizers): 589 | model.train() 590 | optimizer.zero_grad() 591 | 592 | def accumulate_gradient(self, feedback='teacher', **kwargs): 593 | if feedback == 'teacher': 594 | self.feedback = 'teacher' 595 | self.rollout(train_ml=args.teacher_weight, train_rl=False, **kwargs) 596 | elif feedback == 'sample': 597 | self.feedback = 'teacher' 598 | self.rollout(train_ml=args.ml_weight, train_rl=False, **kwargs) 599 | self.feedback = 'sample' 600 | self.rollout(train_ml=None, train_rl=True, **kwargs) 601 | else: 602 | assert False 603 | 604 | def optim_step(self): 605 | self.loss.backward() 606 | 607 | torch.nn.utils.clip_grad_norm(self.vln_bert.parameters(), 40.) 608 | 609 | self.vln_bert_optimizer.step() 610 | self.critic_optimizer.step() 611 | 612 | def train(self, n_iters, feedback='teacher', **kwargs): 613 | ''' Train for a given number of iterations ''' 614 | self.feedback = feedback 615 | 616 | self.vln_bert.train() 617 | self.critic.train() 618 | 619 | self.losses = [] 620 | for iter in range(1, n_iters + 1): 621 | 622 | self.vln_bert_optimizer.zero_grad() 623 | self.critic_optimizer.zero_grad() 624 | 625 | self.loss = 0 626 | 627 | if feedback == 'teacher': 628 | self.feedback = 'teacher' 629 | self.rollout(train_ml=args.teacher_weight, train_rl=False, **kwargs) 630 | elif feedback == 'sample': # agents in IL and RL separately 631 | if args.ml_weight != 0: 632 | self.feedback = 'teacher' 633 | self.rollout(train_ml=args.ml_weight, train_rl=False, **kwargs) 634 | self.feedback = 'sample' 635 | self.rollout(train_ml=None, train_rl=True, **kwargs) 636 | else: 637 | assert False 638 | 639 | self.loss.backward() 640 | 641 | torch.nn.utils.clip_grad_norm(self.vln_bert.parameters(), 40.) 642 | 643 | self.vln_bert_optimizer.step() 644 | self.critic_optimizer.step() 645 | 646 | if args.aug is None: 647 | print_progress(iter, n_iters+1, prefix='Progress:', suffix='Complete', bar_length=50) 648 | 649 | def save(self, epoch, path): 650 | ''' Snapshot models ''' 651 | the_dir, _ = os.path.split(path) 652 | os.makedirs(the_dir, exist_ok=True) 653 | states = {} 654 | def create_state(name, model, optimizer): 655 | states[name] = { 656 | 'epoch': epoch + 1, 657 | 'state_dict': model.state_dict(), 658 | 'optimizer': optimizer.state_dict(), 659 | } 660 | all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer), 661 | ("critic", self.critic, self.critic_optimizer)] 662 | for param in all_tuple: 663 | create_state(*param) 664 | torch.save(states, path) 665 | 666 | def load(self, path): 667 | ''' Loads parameters (but not training state) ''' 668 | states = torch.load(path) 669 | 670 | def recover_state(name, model, optimizer): 671 | state = model.state_dict() 672 | model_keys = set(state.keys()) 673 | load_keys = set(states[name]['state_dict'].keys()) 674 | if model_keys != load_keys: 675 | print("NOTICE: DIFFERENT KEYS IN THE LISTEREN") 676 | state.update(states[name]['state_dict']) 677 | model.load_state_dict(state) 678 | if args.loadOptim: 679 | optimizer.load_state_dict(states[name]['optimizer']) 680 | all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer), 681 | ("critic", self.critic, self.critic_optimizer)] 682 | for param in all_tuple: 683 | recover_state(*param) 684 | return states['vln_bert']['epoch'] - 1 685 | --------------------------------------------------------------------------------
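For reference, the step-wise reward shaping inside Seq2SeqAgent.rollout() can be distilled into a small standalone function. The sketch below is illustrative only and is not part of the repository; the helper name shaped_reward and its arguments are hypothetical, but the branches mirror the logic in rollout(), with the nDTW scores supplied by utils.DTW.

def shaped_reward(stopped, dist, last_dist, ndtw_score, last_ndtw, success_radius=3.0):
    """Reward for one agent step, condensing the rules used in rollout()."""
    if stopped:                                      # the agent chose the <end> action
        if dist < success_radius:                    # stopped within 3m of the goal
            return 2.0 + ndtw_score * 2.0
        return -2.0
    reward = -(dist - last_dist)                     # progress towards the goal
    ndtw_reward = ndtw_score - last_ndtw             # change in path fidelity
    if reward > 0.0:
        reward = 1.0 + ndtw_reward                   # quantised positive progress
    elif reward < 0.0:
        reward = -1.0 + ndtw_reward                  # quantised negative progress
    else:
        raise NameError("The action doesn't change the move")
    if last_dist <= 1.0 and dist - last_dist > 0.0:  # walked away from a near goal
        reward -= (1.0 - last_dist) * 2.0
    return reward

# Example: a step that cuts the goal distance from 7.2m to 5.4m and raises nDTW by 0.05
# earns 1.0 + 0.05 = 1.05; stopping 2.1m from the goal with nDTW 0.8 earns 2.0 + 0.8 * 2.0 = 3.6.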