├── logs
│   └── .gitkeep
├── snap
│   └── .gitkeep
├── connectivity
│   └── .gitkeep
├── img_features
│   └── .gitkeep
├── .gitignore
├── run
│   ├── train_agent.bash
│   └── test_agent.bash
├── LICENSE
├── r2r_src
│   ├── vlnbert
│   │   ├── vlnbert_init.py
│   │   ├── vlnbert_OSCAR.py
│   │   └── vlnbert_PREVALENT.py
│   ├── model_OSCAR.py
│   ├── param.py
│   ├── model_PREVALENT.py
│   ├── eval.py
│   ├── train.py
│   ├── env.py
│   ├── utils.py
│   └── agent.py
├── README.md
└── recurrent-vln-bert.yml
/logs/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/snap/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/connectivity/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/img_features/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .ftpignore
2 | .ftpconfig
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | connectivity/*
10 | !connectivity/.gitkeep
11 |
12 | data/R2R_test.json
13 | data/R2R_train.json
14 | data/R2R_val_seen.json
15 | data/R2R_val_unseen.json
16 | data/prevalent/*
17 | !data/prevalent/.gitkeep
18 |
19 | img_features/*
20 | !img_features/.gitkeep
21 |
22 | snap/*
23 | !snap/.gitkeep
24 |
25 | logs/*
26 | !logs/.gitkeep
27 |
--------------------------------------------------------------------------------
/run/train_agent.bash:
--------------------------------------------------------------------------------
1 | flag="--ADAPT
2 | --vlnbert prevalent
3 |
4 | --aug data/prevalent/prevalent_aug.json
5 | --test_only 0
6 |
7 | --train auglistener
8 |
9 | --maxAction 15
10 | --batchSize 8
11 | --feedback sample
12 | --lr 1e-5
13 | --iters 300000
14 | --optim adamW
15 |
16 | --mlWeight 0.20
17 | --maxInput 80
18 | --angleFeatSize 128
19 | --featdropout 0.4
20 | --dropout 0.5"
21 |
22 | # w CLIP visual feature
23 | PREFIX=ADAPT_CLIP CUDA_VISIBLE_DEVICES=$1 python r2r_src/train.py $flag --features clip
24 |
25 | # w/o CLIP visual feature
26 | #PREFIX=ADAPT CUDA_VISIBLE_DEVICES=$1 python r2r_src/train.py $flag --features places365
--------------------------------------------------------------------------------
/run/test_agent.bash:
--------------------------------------------------------------------------------
1 | flag="--ADAPT
2 | --vlnbert prevalent
3 |
4 | --submit 0
5 | --test_only 0
6 |
7 | --train validlistener
8 |
9 | --maxAction 15
10 | --batchSize 8
11 | --feedback sample
12 | --lr 1e-5
13 | --iters 300000
14 | --optim adamW
15 |
16 | --mlWeight 0.20
17 | --maxInput 80
18 | --angleFeatSize 128
19 | --featdropout 0.4
20 | --dropout 0.5"
21 |
22 | # w CLIP visual feature
23 | PREFIX=ADAPT_CLIP_test CUDA_VISIBLE_DEVICES=$1 python r2r_src/train.py $flag --features clip --load snap/ADAPT_CLIP/state_dict/best_val_unseen
24 | # w/o CLIP visual feature
25 | #PREFIX=ADAPT_test CUDA_VISIBLE_DEVICES=$1 python r2r_src/train.py $flag --features places365 --load snap/ADAPT/state_dict/best_val_unseen
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2021 Yicong Hong, Qi Wu, Yuankai Qi,
4 | Cristian Rodriguez-Opazo, Stephen Gould
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/r2r_src/vlnbert/vlnbert_init.py:
--------------------------------------------------------------------------------
1 |
2 | #from transformers.pytorch_transformers import (BertConfig, BertTokenizer)
3 | from pytorch_transformers import (BertConfig, BertTokenizer)
4 |
5 | def get_tokenizer(args):
6 | if args.vlnbert == 'oscar':
7 | tokenizer_class = BertTokenizer
8 | model_name_or_path = 'Oscar/pretrained_models/base-no-labels/ep_67_588997'
9 | tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=True)
10 | elif args.vlnbert == 'prevalent':
11 | tokenizer_class = BertTokenizer
12 | tokenizer = tokenizer_class.from_pretrained('bert-base-uncased')
13 | return tokenizer
14 |
15 | def get_vlnbert_models(args, config=None):
16 | config_class = BertConfig
17 |
18 | if args.vlnbert == 'oscar':
19 | from vlnbert.vlnbert_OSCAR import VLNBert
20 | model_class = VLNBert
21 | model_name_or_path = 'Oscar/pretrained_models/base-no-labels/ep_67_588997'
22 | vis_config = config_class.from_pretrained(model_name_or_path, num_labels=2, finetuning_task='vln-r2r')
23 |
24 | vis_config.model_type = 'visual'
25 | vis_config.finetuning_task = 'vln-r2r'
26 | vis_config.hidden_dropout_prob = 0.3
27 | vis_config.hidden_size = 768
28 | vis_config.img_feature_dim = 2176
29 | vis_config.num_attention_heads = 12
30 | vis_config.num_hidden_layers = 12
31 | visual_model = model_class.from_pretrained(model_name_or_path, from_tf=False, config=vis_config)
32 |
33 | elif args.vlnbert == 'prevalent':
34 | from vlnbert.vlnbert_PREVALENT import VLNBert
35 | model_class = VLNBert
36 | model_name_or_path = 'Prevalent/pretrained_model/pytorch_model.bin'
37 | vis_config = config_class.from_pretrained('bert-base-uncased')
38 | vis_config.img_feature_dim = 2176
39 | vis_config.img_feature_type = ""
40 | vis_config.vl_layers = 4
41 | vis_config.la_layers = 9
42 |
43 | visual_model = model_class.from_pretrained(model_name_or_path, config=vis_config)
44 |
45 | return visual_model
46 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ADAPT
2 |
3 | PyTorch implementation of the paper ["ADAPT: Vision-Language Navigation with Modality-Aligned Action Prompts"](https://arxiv.org/abs/2205.15509) (CVPR 2022).
4 |
5 | ## Prerequisites
6 |
7 | ### Installation
8 | The environment installation of ADAPT follows that in [Recurrent-VLN-BERT](https://github.com/YicongHong/Recurrent-VLN-BERT).
9 |
10 | 1. Install the [Matterport3D Simulator](https://github.com/peteanderson80/Matterport3DSimulator). Note that this code uses the [old version (v0.1)](https://github.com/peteanderson80/Matterport3DSimulator/tree/v0.1) of the simulator.
11 |
12 | 2. Install the required Python packages; the exact versions used in the environment can be found [here](https://github.com/YicongHong/Recurrent-VLN-BERT/blob/main/recurrent-vln-bert.yml) (an example setup command follows this list).
13 |
14 | 3. Install [Pytorch-Transformers](https://github.com/huggingface/transformers) at [this version](https://github.com/huggingface/transformers/tree/067923d3267325f525f4e46f357360c191ba562e).
15 |
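The same package list is included in this repository as `recurrent-vln-bert.yml`. A minimal setup sketch, assuming Anaconda/Miniconda is installed and the provided environment file is used unchanged:

```
# create and activate the conda environment defined by the yml
# (the yml names the environment "recurrent-vln-bert")
conda env create -f recurrent-vln-bert.yml
conda activate recurrent-vln-bert
```
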
16 | ### Data Preparation
17 | Please follow the instructions below to prepare the data in the listed directories (a sketch of the resulting layout follows the list):
18 |
19 | * MP3D navigability graphs: ```connectivity```
20 | * Download the [connectivity maps](https://github.com/peteanderson80/Matterport3DSimulator/tree/master/connectivity).
21 | * MP3D image features: ```img_features```
22 | * Download the [Scene features](https://www.dropbox.com/s/85tpa6tc3enl5ud/ResNet-152-places365.zip?dl=1) (ResNet-152-Places365).
23 | * R2R data with added action prompts: ```data```
24 | * Download the [R2R data](https://drive.google.com/file/d/1dvWONxBDfNiG420Ggttjkje5Qu0pFJ_d/view?usp=sharing).
25 | * Augmentation data with added action prompts: ```data```
26 | * Download the [augmentation data](https://drive.google.com/file/d/1C9Ckhr6XASDveGvnRvZ3oIzqTeU43JAq/view?usp=sharing).
27 | * Text sub-prompt feature: ```data```
28 | * Download the [text sub-prompt feature](https://drive.google.com/file/d/127XonQJ2hqriljfSm8J-RLhV_PYQYWJh/view?usp=sharing).
29 |
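The expected layout after downloading (a sketch inferred from `.gitignore` and the training scripts; the exact file names of the prompt data may differ):

```
connectivity/                     # MP3D navigability graphs
img_features/
  ResNet-152-places365.tsv        # scene features (--features places365)
  CLIP-ViT-B-32-views.tsv         # CLIP features expected by run/*.bash (--features clip)
data/
  R2R_train.json
  R2R_val_seen.json
  R2R_val_unseen.json
  R2R_test.json
  prevalent/
    prevalent_aug.json            # augmentation data passed via --aug
```
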
30 | ## R2R Navigation
31 |
32 | ### Two-phase Training
33 | In the first stage, run the following command until the performance on Val Unseen converges:
34 | ```
35 | PREFIX=baseline python r2r_src/train.py --vlnbert prevalent --aug data/prevalent/prevalent_aug.json --batchSize 16 --lr 1e-5
36 | ```
37 | In the second stage, run the following command, loading the best Val Unseen model from the first stage:
38 | ```
39 | PREFIX=ADAPT python r2r_src/train.py --vlnbert prevalent --aug data/prevalent/prevalent_aug.json --batchSize 16 --lr 1e-6 --ADAPT --load snap/baseline/state_dict/best_val_unseen --finetune
40 | ```
41 |
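Alternatively, the shipped `run/train_agent.bash` and `run/test_agent.bash` wrap similar commands with the `--ADAPT` flag and CLIP visual features (`--features clip`); the single argument is the GPU id exported as `CUDA_VISIBLE_DEVICES`:

```
bash run/train_agent.bash 0    # train with the preset flags in the script
bash run/test_agent.bash 0     # evaluate, loading snap/ADAPT_CLIP/state_dict/best_val_unseen
```
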
42 | ## Acknowledgements
43 | The implementation relies on resources from [Recurrent-VLN-BERT](https://github.com/YicongHong/Recurrent-VLN-BERT). We thank the original authors for open-sourcing their code.
44 |
--------------------------------------------------------------------------------
/r2r_src/model_OSCAR.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from param import args
5 |
6 | from vlnbert.vlnbert_init import get_vlnbert_models
7 |
8 |
9 | class VLNBERT(nn.Module):
10 | def __init__(self, feature_size=2048+128):
11 | super(VLNBERT, self).__init__()
12 |         print('\nInitializing the VLN-BERT model ...')
13 | self.vln_bert = get_vlnbert_models(args, config=None) # initialize the VLN-BERT
14 | self.vln_bert.config.directions = 4 # a preset random number
15 |
16 | hidden_size = self.vln_bert.config.hidden_size
17 | layer_norm_eps = self.vln_bert.config.layer_norm_eps
18 |
19 | self.action_state_project = nn.Sequential(
20 | nn.Linear(hidden_size+args.angle_feat_size, hidden_size), nn.Tanh())
21 | self.action_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)
22 |
23 | self.drop_env = nn.Dropout(p=args.featdropout)
24 | self.img_projection = nn.Linear(feature_size, hidden_size, bias=True)
25 | self.cand_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)
26 |
27 | def forward(self, mode, sentence, token_type_ids=None,
28 | attention_mask=None, lang_mask=None, vis_mask=None,
29 | position_ids=None, action_feats=None, pano_feats=None, cand_feats=None):
30 |
31 | if mode == 'language':
32 | encoded_sentence = self.vln_bert(mode, sentence, position_ids=position_ids,
33 | token_type_ids=token_type_ids, attention_mask=attention_mask)
34 |
35 | return encoded_sentence
36 |
37 | elif mode == 'visual':
38 |
39 | state_action_embed = torch.cat((sentence[:,0,:], action_feats), 1)
40 | state_with_action = self.action_state_project(state_action_embed)
41 | state_with_action = self.action_LayerNorm(state_with_action)
42 | state_feats = torch.cat((state_with_action.unsqueeze(1), sentence[:,1:,:]), dim=1)
43 |
44 | cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size])
45 |
46 | cand_feats_embed = self.img_projection(cand_feats) # [2176 * 768] projection
47 | cand_feats_embed = self.cand_LayerNorm(cand_feats_embed)
48 |
49 | # logit is the attention scores over the candidate features
50 | h_t, logit = self.vln_bert(mode, state_feats,
51 | attention_mask=attention_mask, img_feats=cand_feats_embed)
52 |
53 | return h_t, logit
54 |
55 | else:
56 |             raise NotImplementedError(mode)
57 |
58 |
59 | class BertLayerNorm(nn.Module):
60 | def __init__(self, hidden_size, eps=1e-12):
61 | """Construct a layernorm module in the TF style (epsilon inside the square root).
62 | """
63 | super(BertLayerNorm, self).__init__()
64 | self.weight = nn.Parameter(torch.ones(hidden_size))
65 | self.bias = nn.Parameter(torch.zeros(hidden_size))
66 | self.variance_epsilon = eps
67 |
68 | def forward(self, x):
69 | u = x.mean(-1, keepdim=True)
70 | s = (x - u).pow(2).mean(-1, keepdim=True)
71 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
72 | return self.weight * x + self.bias
73 |
74 |
75 | class Critic(nn.Module):
76 | def __init__(self):
77 | super(Critic, self).__init__()
78 | self.state2value = nn.Sequential(
79 | nn.Linear(768, 512),
80 | nn.ReLU(),
81 | nn.Dropout(args.dropout),
82 | nn.Linear(512, 1),
83 | )
84 |
85 | def forward(self, state):
86 | return self.state2value(state).squeeze()
87 |
--------------------------------------------------------------------------------
/r2r_src/param.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 |
5 | class Param:
6 | def __init__(self):
7 | self.parser = argparse.ArgumentParser(description="")
8 |
9 | # General
10 | self.parser.add_argument('--test_only', type=int, default=0, help='fast mode for testing')
11 |
12 | self.parser.add_argument('--iters', type=int, default=300000, help='training iterations')
13 | self.parser.add_argument('--name', type=str, default='default', help='experiment id')
14 | self.parser.add_argument('--vlnbert', type=str, default='oscar', help='oscar or prevalent')
15 | self.parser.add_argument('--train', type=str, default='listener')
16 | self.parser.add_argument('--description', type=str, default='no description\n')
17 |
18 | # Data preparation
19 | self.parser.add_argument('--maxInput', type=int, default=80, help="max input instruction")
20 | self.parser.add_argument('--maxAction', type=int, default=15, help='Max Action sequence')
21 | self.parser.add_argument('--batchSize', type=int, default=8)
22 | self.parser.add_argument('--ignoreid', type=int, default=-100)
23 | self.parser.add_argument('--feature_size', type=int, default=2048)
24 | self.parser.add_argument("--loadOptim",action="store_const", default=False, const=True)
25 |
26 | # Load the model from
27 | self.parser.add_argument("--load", default=None, help='path of the trained model')
28 |
29 | # Augmented Paths from
30 | self.parser.add_argument("--aug", default=None)
31 |
32 | # Listener Model Config
33 | self.parser.add_argument("--zeroInit", dest='zero_init', action='store_const', default=False, const=True)
34 | self.parser.add_argument("--mlWeight", dest='ml_weight', type=float, default=0.20)
35 | self.parser.add_argument("--teacherWeight", dest='teacher_weight', type=float, default=1.)
36 | self.parser.add_argument("--features", type=str, default='places365')
37 |
38 | # Dropout Param
39 | self.parser.add_argument('--dropout', type=float, default=0.5)
40 | self.parser.add_argument('--featdropout', type=float, default=0.3)
41 |
42 |         # Submission configuration
43 | self.parser.add_argument("--submit", type=int, default=0)
44 |
45 | # Training Configurations
46 | self.parser.add_argument('--optim', type=str, default='rms') # rms, adam
47 | self.parser.add_argument('--lr', type=float, default=0.00001, help="the learning rate")
48 | self.parser.add_argument('--decay', dest='weight_decay', type=float, default=0.)
49 | self.parser.add_argument('--feedback', type=str, default='sample',
50 | help='How to choose next position, one of ``teacher``, ``sample`` and ``argmax``')
51 | self.parser.add_argument('--teacher', type=str, default='final',
52 | help="How to get supervision. one of ``next`` and ``final`` ")
53 | self.parser.add_argument('--epsilon', type=float, default=0.1)
54 |
55 | # Model hyper params:
56 | self.parser.add_argument("--angleFeatSize", dest="angle_feat_size", type=int, default=4)
57 |
58 | # A2C
59 | self.parser.add_argument("--gamma", default=0.9, type=float)
60 | self.parser.add_argument("--normalize", dest="normalize_loss", default="total", type=str, help='batch or total')
61 |
62 | # ADAPT
63 | self.parser.add_argument("--ADAPT", action='store_const', default=False, const=True, help='use ADAPT model')
64 | self.parser.add_argument("--AlignLossWeight", dest='align_loss_weight', default=0.01, type=float, help="modality alignment loss weight")
65 | self.parser.add_argument("--ConsistencyLossWeight", dest='consistency_loss_weight', default=0.0001, type=float, help="sequential consistency loss weight")
66 | self.parser.add_argument('--prompt_set_size', type=int, default=60, help="prompt set size")
67 | self.parser.add_argument("--temperature", default=0.5, type=float, help="hyperparameter for contrastive loss")
68 | self.parser.add_argument("--finetune", action='store_const', default=False, const=True, help='if in finetune stage')
69 |
70 | self.args = self.parser.parse_args()
71 |
72 | if self.args.optim == 'rms':
73 | print("Optimizer: Using RMSProp")
74 | self.args.optimizer = torch.optim.RMSprop
75 | elif self.args.optim == 'adam':
76 | print("Optimizer: Using Adam")
77 | self.args.optimizer = torch.optim.Adam
78 | elif self.args.optim == 'adamW':
79 | print("Optimizer: Using AdamW")
80 | self.args.optimizer = torch.optim.AdamW
81 | elif self.args.optim == 'sgd':
82 | print("Optimizer: sgd")
83 | self.args.optimizer = torch.optim.SGD
84 | else:
85 | assert False
86 |
87 | param = Param()
88 | args = param.args
89 |
90 | args.description = args.name
91 | args.IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv'
92 | args.log_dir = 'snap/%s' % args.name
93 |
94 | if not os.path.exists(args.log_dir):
95 | os.makedirs(args.log_dir)
96 | DEBUG_FILE = open(os.path.join('snap', args.name, "debug.log"), 'w')
97 |
--------------------------------------------------------------------------------
/recurrent-vln-bert.yml:
--------------------------------------------------------------------------------
1 | name: recurrent-vln-bert
2 | channels:
3 | - bioconda
4 | - menpo
5 | - pytorch
6 | - conda-forge
7 | - anaconda
8 | - defaults
9 | dependencies:
10 | - _libgcc_mutex=0.1=main
11 | - bcolz=1.2.1=py36h04863e7_0
12 | - blas=1.0=mkl
13 | - blosc=1.16.3=hd408876_0
14 | - bokeh=2.0.1=py36_0
15 | - bzip2=1.0.8=h7b6447c_0
16 | - ca-certificates=2020.7.22=0
17 | - certifi=2020.6.20=py36_0
18 | - cffi=1.13.1=py36h2e261b9_0
19 | - click=7.1.1=py_0
20 | - cloudpickle=1.3.0=py_0
21 | - cmake=3.14.0=h52cb24c_0
22 | - cudatoolkit=10.2.89=hfd86e86_1
23 | - cytoolz=0.10.1=py36h7b6447c_0
24 | - dask=2.14.0=py_0
25 | - dask-core=2.14.0=py_0
26 | - distributed=2.14.0=py36_0
27 | - docutils=0.15.2=py36_0
28 | - expat=2.2.6=he6710b0_0
29 | - ffmpeg=4.0=hcdf2ecd_0
30 | - freetype=2.9.1=h8a8886c_1
31 | - fsspec=0.7.1=py_0
32 | - hdf5=1.10.2=hba1933b_1
33 | - heapdict=1.0.1=py_0
34 | - idna=2.10=py_0
35 | - intel-openmp=2019.4=243
36 | - jasper=2.0.14=h07fcdf6_1
37 | - jinja2=2.11.1=py_0
38 | - jsoncpp=1.8.4=hfd86e86_0
39 | - krb5=1.17.1=h173b8e3_0
40 | - libcurl=7.69.1=h20c2e04_0
41 | - libedit=3.1.20181209=hc058e9b_0
42 | - libgcc-ng=9.1.0=hdf63c60_0
43 | - libgfortran-ng=7.3.0=hdf63c60_0
44 | - libopencv=3.4.2=hb342d67_1
45 | - libopus=1.3=h7b6447c_0
46 | - libpng=1.6.37=hbc83047_0
47 | - libssh2=1.9.0=h1ba5d50_1
48 | - libstdcxx-ng=9.1.0=hdf63c60_0
49 | - libtiff=4.0.10=h2733197_2
50 | - libvpx=1.7.0=h439df22_0
51 | - locket=0.2.0=py36_1
52 | - lz4-c=1.8.1.2=h14c3975_0
53 | - lzo=2.10=h1bfc0ba_1
54 | - markupsafe=1.1.1=py36h7b6447c_0
55 | - mkl=2019.4=243
56 | - mkl-service=2.3.0=py36he904b0f_0
57 | - mkl_fft=1.0.14=py36ha843d7b_0
58 | - mkl_random=1.1.0=py36hd6b4f25_0
59 | - msgpack-python=1.0.0=py36hfd86e86_1
60 | - ncurses=6.1=he6710b0_1
61 | - ninja=1.9.0=py36hfd86e86_0
62 | - numexpr=2.7.1=py36h423224d_0
63 | - olefile=0.46=py36_0
64 | - opencv=3.4.2=py36h6fd60c2_1
65 | - openjdk=8.0.152=h7b6447c_3
66 | - openssl=1.1.1g=h7b6447c_0
67 | - packaging=20.3=py_0
68 | - pandas=1.1.1=py36he6710b0_0
69 | - partd=1.1.0=py_0
70 | - pillow=6.2.1=py36h34e0f95_0
71 | - pip=19.3.1=py36_0
72 | - psutil=5.7.0=py36h7b6447c_0
73 | - py-opencv=3.4.2=py36hb342d67_1
74 | - pybind11=2.4.2=py36hfd86e86_0
75 | - pycparser=2.19=py36_0
76 | - pyopenssl=19.1.0=py_1
77 | - pyparsing=2.4.7=py_0
78 | - pytables=3.4.4=py36ha205bf6_0
79 | - python=3.6.9=h265db76_0
80 | - python-dateutil=2.8.1=py_0
81 | - pytz=2020.1=py_0
82 | - pyyaml=5.3.1=py36h7b6447c_0
83 | - readline=7.0=h7b6447c_5
84 | - requests=2.24.0=py_0
85 | - rhash=1.3.8=h1ba5d50_0
86 | - setuptools=41.6.0=py36_0
87 | - six=1.15.0=py_0
88 | - sortedcontainers=2.1.0=py36_0
89 | - sqlite=3.30.1=h7b6447c_0
90 | - tblib=1.6.0=py_0
91 | - tk=8.6.8=hbc83047_0
92 | - toolz=0.10.0=py_0
93 | - tornado=6.0.4=py36h7b6447c_1
94 | - typing_extensions=3.7.4.1=py36_0
95 | - urllib3=1.25.10=py_0
96 | - wheel=0.33.6=py36_0
97 | - xz=5.2.4=h14c3975_4
98 | - yaml=0.1.7=h96e3832_1
99 | - zict=2.0.0=py_0
100 | - zlib=1.2.11=h7b6447c_3
101 | - zstd=1.3.7=h0b5b093_0
102 | - java-jdk=7.0.91=1
103 | - tqdm=4.7.2=py36_0
104 | - boto3=1.13.14=pyh9f0ad1d_0
105 | - botocore=1.16.14=pyh9f0ad1d_0
106 | - brotlipy=0.7.0=py36h8c4c3a4_1000
107 | - cairo=1.14.12=h80bd089_1005
108 | - chardet=3.0.4=py36h9f0ad1d_1006
109 | - cryptography=2.9.2=py36h45558ae_0
110 | - fontconfig=2.13.1=h2176d3f_1000
111 | - freeglut=3.0.0=hf484d3e_1005
112 | - gettext=0.19.8.1=h9745a5d_1001
113 | - glew=2.1.0=he1b5a44_0
114 | - glib=2.56.2=had28632_1001
115 | - graphite2=1.3.13=hf484d3e_1000
116 | - harfbuzz=1.9.0=he243708_1001
117 | - htop=2.2.0=hf8c457e_1000
118 | - icu=58.2=hf484d3e_1000
119 | - jmespath=0.10.0=pyh9f0ad1d_0
120 | - libblas=3.8.0=14_mkl
121 | - libcblas=3.8.0=14_mkl
122 | - libglu=9.0.0=hf484d3e_1000
123 | - libiconv=1.15=h14c3975_1004
124 | - liblapack=3.8.0=14_mkl
125 | - libuuid=2.32.1=h14c3975_1000
126 | - libxcb=1.13=h14c3975_1002
127 | - libxml2=2.9.8=h143f9aa_1005
128 | - numpy=1.18.4=py36h7314795_0
129 | - pcre=8.41=hf484d3e_1003
130 | - pixman=0.34.0=h14c3975_1003
131 | - pthread-stubs=0.4=h14c3975_1001
132 | - pysocks=1.7.1=py36h9f0ad1d_1
133 | - python_abi=3.6=1_cp36m
134 | - pytorch-pretrained-bert=0.6.2=py36_0
135 | - regex=2020.5.14=py36h8c4c3a4_0
136 | - s3transfer=0.3.3=py36h9f0ad1d_1
137 | - xorg-fixesproto=5.0=h14c3975_1002
138 | - xorg-inputproto=2.3.2=h14c3975_1002
139 | - xorg-kbproto=1.0.7=h14c3975_1002
140 | - xorg-libice=1.0.9=h14c3975_1004
141 | - xorg-libsm=1.2.3=h4937e3b_1000
142 | - xorg-libx11=1.6.9=h516909a_0
143 | - xorg-libxau=1.0.8=h14c3975_1006
144 | - xorg-libxdmcp=1.1.2=h14c3975_1007
145 | - xorg-libxext=1.3.3=h14c3975_1004
146 | - xorg-libxfixes=5.0.3=h14c3975_1004
147 | - xorg-libxi=1.7.9=h14c3975_1002
148 | - xorg-libxrender=0.9.10=h14c3975_1002
149 | - xorg-renderproto=0.11.1=h14c3975_1002
150 | - xorg-xextproto=7.3.0=h14c3975_1002
151 | - xorg-xproto=7.0.31=h14c3975_1007
152 | - jpeg=9b=h024ee3a_2
153 | - libffi=3.2.1=hd88cf55_4
154 | - snappy=1.1.7=hbae5bb6_3
155 | - osmesa=12.2.2.dev=0
156 | - pytorch=1.6.0=py3.6_cuda10.2.89_cudnn7.6.5_0
157 | - torchvision=0.7.0=py36_cu102
158 |
--------------------------------------------------------------------------------
/r2r_src/model_PREVALENT.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
7 | from param import args
8 |
9 | from vlnbert.vlnbert_init import get_vlnbert_models
10 |
11 | class VLNBERT(nn.Module):
12 | def __init__(self, feature_size=2048+128):
13 | super(VLNBERT, self).__init__()
 14 |         print('\nInitializing the VLN-BERT model ...')
15 |
16 | self.vln_bert = get_vlnbert_models(args, config=None) # initialize the VLN-BERT
17 | self.vln_bert.config.directions = 4 # a preset random number
18 |
19 | hidden_size = self.vln_bert.config.hidden_size
20 | layer_norm_eps = self.vln_bert.config.layer_norm_eps
21 |
22 | self.action_state_project = nn.Sequential(
23 | nn.Linear(hidden_size+args.angle_feat_size, hidden_size), nn.Tanh())
24 | self.action_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)
25 |
26 | self.drop_env = nn.Dropout(p=args.featdropout)
27 | self.img_projection = nn.Linear(feature_size, hidden_size, bias=True)
28 | self.cand_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)
29 |
30 | self.vis_lang_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)
31 | self.state_proj = nn.Linear(hidden_size*2, hidden_size, bias=True)
32 | self.state_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)
33 |
34 | def forward(self, mode, sentence, token_type_ids=None,
35 | attention_mask=None, lang_mask=None, vis_mask=None,
36 | position_ids=None, action_feats=None, pano_feats=None, cand_feats=None, prompt_mask_set=None,
37 | txt_sub_prompt_set=None, img_sub_prompt_set=None):
38 |
39 | if mode == 'language':
40 | init_state, encoded_sentence = self.vln_bert(mode, sentence, attention_mask=attention_mask, lang_mask=lang_mask,)
41 |
42 | return init_state, encoded_sentence
43 |
44 | elif mode == 'visual':
45 |
46 | state_action_embed = torch.cat((sentence[:,0,:], action_feats), 1)
47 | state_with_action = self.action_state_project(state_action_embed)
48 | state_with_action = self.action_LayerNorm(state_with_action)
49 | state_feats = torch.cat((state_with_action.unsqueeze(1), sentence[:,1:,:]), dim=1)
50 |
51 | cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size])
52 | if txt_sub_prompt_set is not None and img_sub_prompt_set is not None:
53 | h_t, logit, attended_language, attended_visual, attended_txt_prompt, attended_img_prompt, txt_prompt, img_prompt = self.vln_bert(
54 | mode,
55 | state_feats,
56 | attention_mask=attention_mask,
57 | lang_mask=lang_mask,
58 | vis_mask=vis_mask,
59 | img_feats=cand_feats,
60 | txt_sub_prompt_set=txt_sub_prompt_set,
61 | img_sub_prompt_set=img_sub_prompt_set,
62 | prompt_mask_set=prompt_mask_set
63 | )
64 | else:
65 | # logit is the attention scores over the candidate features
66 | h_t, logit, attended_language, attended_visual = self.vln_bert(mode, state_feats,
67 | attention_mask=attention_mask, lang_mask=lang_mask, vis_mask=vis_mask, img_feats=cand_feats)
68 |
69 | # update agent's state, unify history, language and vision by elementwise product
70 | vis_lang_feat = self.vis_lang_LayerNorm(attended_language * attended_visual)
71 | state_output = torch.cat((h_t, vis_lang_feat), dim=-1)
72 | state_proj = self.state_proj(state_output)
73 | state_proj = self.state_LayerNorm(state_proj)
74 |
75 | if txt_sub_prompt_set is not None and img_sub_prompt_set is not None:
76 | return state_proj, logit, attended_language, attended_visual, attended_txt_prompt, attended_img_prompt, txt_prompt, img_prompt
77 | else:
78 | return state_proj, logit
79 |
80 | else:
 81 |             raise NotImplementedError(mode)
82 |
83 |
84 | class BertLayerNorm(nn.Module):
85 | def __init__(self, hidden_size, eps=1e-12):
86 | """Construct a layernorm module in the TF style (epsilon inside the square root).
87 | """
88 | super(BertLayerNorm, self).__init__()
89 | self.weight = nn.Parameter(torch.ones(hidden_size))
90 | self.bias = nn.Parameter(torch.zeros(hidden_size))
91 | self.variance_epsilon = eps
92 |
93 | def forward(self, x):
94 | u = x.mean(-1, keepdim=True)
95 | s = (x - u).pow(2).mean(-1, keepdim=True)
96 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
97 | return self.weight * x + self.bias
98 |
99 |
100 | class Critic(nn.Module):
101 | def __init__(self):
102 | super(Critic, self).__init__()
103 | self.state2value = nn.Sequential(
104 | nn.Linear(768, 512),
105 | nn.ReLU(),
106 | nn.Dropout(args.dropout),
107 | nn.Linear(512, 1),
108 | )
109 |
110 | def forward(self, state):
111 | return self.state2value(state).squeeze()
112 |
--------------------------------------------------------------------------------
/r2r_src/eval.py:
--------------------------------------------------------------------------------
1 | ''' Evaluation of agent trajectories '''
2 |
3 | import json
4 | import os
5 | import sys
6 | from collections import defaultdict
7 | import networkx as nx
8 | import numpy as np
9 | import pprint
10 | pp = pprint.PrettyPrinter(indent=4)
11 |
12 | from env import R2RBatch
13 | from utils import load_datasets, load_nav_graphs, ndtw_graphload, DTW
14 | from agent import BaseAgent
15 |
16 |
17 | class Evaluation(object):
18 | ''' Results submission format: [{'instr_id': string, 'trajectory':[(viewpoint_id, heading_rads, elevation_rads),] } ] '''
19 |
20 | def __init__(self, splits, scans, tok):
21 | self.error_margin = 3.0
22 | self.splits = splits
23 | self.tok = tok
24 | self.gt = {}
25 | self.instr_ids = []
26 | self.scans = []
27 | for split in splits:
28 | for item in load_datasets([split]):
29 | if scans is not None and item['scan'] not in scans:
30 | continue
31 | self.gt[str(item['path_id'])] = item
32 | self.scans.append(item['scan'])
33 | self.instr_ids += ['%s_%d' % (item['path_id'], i) for i in range(len(item['instructions']))]
34 | self.scans = set(self.scans)
35 | self.instr_ids = set(self.instr_ids)
36 | self.graphs = load_nav_graphs(self.scans)
37 | self.distances = {}
38 | for scan,G in self.graphs.items(): # compute all shortest paths
39 | self.distances[scan] = dict(nx.all_pairs_dijkstra_path_length(G))
40 |
41 | def _get_nearest(self, scan, goal_id, path):
42 | near_id = path[0][0]
43 | near_d = self.distances[scan][near_id][goal_id]
44 | for item in path:
45 | d = self.distances[scan][item[0]][goal_id]
46 | if d < near_d:
47 | near_id = item[0]
48 | near_d = d
49 | return near_id
50 |
51 | def _score_item(self, instr_id, path):
52 | ''' Calculate error based on the final position in trajectory, and also
53 | the closest position (oracle stopping rule).
54 | The path contains [view_id, angle, vofv] '''
55 | gt = self.gt[instr_id.split('_')[-2]]
56 | start = gt['path'][0]
57 | assert start == path[0][0], 'Result trajectories should include the start position'
58 | goal = gt['path'][-1]
59 | final_position = path[-1][0] # the first of [view_id, angle, vofv]
60 | nearest_position = self._get_nearest(gt['scan'], goal, path)
61 | self.scores['nav_errors'].append(self.distances[gt['scan']][final_position][goal])
62 | self.scores['oracle_errors'].append(self.distances[gt['scan']][nearest_position][goal])
63 | self.scores['trajectory_steps'].append(len(path)-1)
64 | distance = 0 # length of the path in meters
65 | prev = path[0]
66 | for curr in path[1:]:
67 | distance += self.distances[gt['scan']][prev[0]][curr[0]]
68 | prev = curr
69 | self.scores['trajectory_lengths'].append(distance)
70 | self.scores['shortest_lengths'].append(
71 | self.distances[gt['scan']][start][goal]
72 | )
73 |
74 | def score(self, output_file):
75 | ''' Evaluate each agent trajectory based on how close it got to the goal location '''
76 | self.scores = defaultdict(list)
77 | instr_ids = set(self.instr_ids)
78 | if type(output_file) is str:
79 | with open(output_file) as f:
80 | results = json.load(f)
81 | else:
82 | results = output_file
83 |
84 | print('result length', len(results))
85 | for item in results:
86 | # Check against expected ids
87 | if item['instr_id'] in instr_ids:
88 | instr_ids.remove(item['instr_id'])
89 | self._score_item(item['instr_id'], item['trajectory'])
90 |
91 | if 'train' not in self.splits: # Exclude the training from this. (Because training eval may be partial)
92 | assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\
93 | % (len(instr_ids), len(self.instr_ids), ",".join(self.splits), output_file)
94 | assert len(self.scores['nav_errors']) == len(self.instr_ids)
95 | score_summary = {
96 | 'nav_error': np.average(self.scores['nav_errors']),
97 | 'oracle_error': np.average(self.scores['oracle_errors']),
98 | 'steps': np.average(self.scores['trajectory_steps']),
99 | 'lengths': np.average(self.scores['trajectory_lengths'])
100 | }
101 | num_successes = len([i for i in self.scores['nav_errors'] if i < self.error_margin])
102 | score_summary['success_rate'] = float(num_successes)/float(len(self.scores['nav_errors']))
103 | oracle_successes = len([i for i in self.scores['oracle_errors'] if i < self.error_margin])
104 | score_summary['oracle_rate'] = float(oracle_successes)/float(len(self.scores['oracle_errors']))
105 |
106 | spl = [float(error < self.error_margin) * l / max(l, p, 0.01)
107 | for error, p, l in
108 | zip(self.scores['nav_errors'], self.scores['trajectory_lengths'], self.scores['shortest_lengths'])
109 | ]
110 | score_summary['spl'] = np.average(spl)
111 |
112 | return score_summary, self.scores
113 |
--------------------------------------------------------------------------------
/r2r_src/train.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | import os
5 | import time
6 | import json
7 | import random
8 | import numpy as np
9 | from collections import defaultdict
10 |
11 | from utils import read_vocab, write_vocab, build_vocab, padding_idx, timeSince, read_img_features, print_progress
12 | import utils
13 | from env import R2RBatch
14 | from agent import Seq2SeqAgent
15 | from eval import Evaluation
16 | from param import args
17 |
18 | import warnings
19 | warnings.filterwarnings("ignore")
20 | from tensorboardX import SummaryWriter
21 |
22 | from vlnbert.vlnbert_init import get_tokenizer
23 |
24 | # log_dir = 'snap/%s' % args.name
25 | # if not os.path.exists(log_dir):
26 | # os.makedirs(log_dir)
27 |
28 | IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv'
29 | PLACE365_FEATURES = 'img_features/ResNet-152-places365.tsv'
30 | CLIP_FEATURES = 'img_features/CLIP-ViT-B-32-views.tsv'
31 |
32 | if args.features == 'imagenet':
33 | features = IMAGENET_FEATURES
34 | elif args.features == 'places365':
35 | features = PLACE365_FEATURES
36 | elif args.features == 'clip':
37 | features = CLIP_FEATURES
38 |
39 | prefix = os.environ.get('PREFIX')
40 | args.name = prefix
41 | log_dir = 'snap/%s' % prefix
42 | if not os.path.exists(log_dir):
43 | os.makedirs(log_dir)
44 |
45 | feedback_method = args.feedback # teacher or sample
46 |
47 | print(args); print('')
48 |
49 |
50 | ''' train the listener '''
51 | def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None, feat_dict=None):
52 | writer = SummaryWriter(log_dir=log_dir)
53 | listner = Seq2SeqAgent(train_env, "", tok, args.maxAction, feat_dict)
54 |
55 | record_file = open('./logs/' + args.name + '.txt', 'a')
56 | record_file.write(str(args) + '\n\n')
57 | record_file.close()
58 |
59 |
60 | if args.load is not None:
61 | if args.aug is None:
62 | start_iter = listner.load(os.path.join(args.load))
 63 |             print("\nLOAD the model from {}, iteration {}".format(args.load, start_iter))
64 | else:
65 | load_iter = listner.load(os.path.join(args.load))
 66 |             print("\nLOAD the model from {}, iteration {}".format(args.load, load_iter))
67 | start_iter = 0
68 | start = time.time()
69 | print('\nListener training starts, start iteration: %s' % str(start_iter))
70 |
71 | best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":"", 'update':False}}
72 |
73 | idx = 0
74 | while idx < start_iter + n_iters:
75 | if args.finetune:
76 | log_every = 100
77 | listner.logs = defaultdict(list)
78 | interval = min(log_every, n_iters-idx)
79 | iter = idx + interval
80 |
81 | # Train for log_every interval
82 | if aug_env is None:
83 | listner.env = train_env
84 | listner.train(interval, feedback=feedback_method) # Train interval iters
85 | else:
86 | jdx_length = len(range(interval // 2))
87 | for jdx in range(interval // 2):
88 | # Train with GT data
89 | listner.env = train_env
90 | args.ml_weight = 0.2
91 | listner.train(1, feedback=feedback_method)
92 |
93 | # Train with Augmented data
94 | listner.env = aug_env
95 | args.ml_weight = 0.2
96 | listner.train(1, feedback=feedback_method)
97 |
98 | print_progress(jdx, jdx_length, prefix='Progress:', suffix='Complete', bar_length=50)
99 |
100 | # Log the training stats to tensorboard
101 | total = max(sum(listner.logs['total']), 1)
102 | length = max(len(listner.logs['critic_loss']), 1)
103 | critic_loss = sum(listner.logs['critic_loss']) / total
104 | RL_loss = sum(listner.logs['RL_loss']) / max(len(listner.logs['RL_loss']), 1)
105 | IL_loss = sum(listner.logs['IL_loss']) / max(len(listner.logs['IL_loss']), 1)
106 | entropy = sum(listner.logs['entropy']) / total
107 | writer.add_scalar("loss/critic", critic_loss, idx)
108 | writer.add_scalar("policy_entropy", entropy, idx)
109 | writer.add_scalar("loss/RL_loss", RL_loss, idx)
110 | writer.add_scalar("loss/IL_loss", IL_loss, idx)
111 | writer.add_scalar("total_actions", total, idx)
112 | writer.add_scalar("max_length", length, idx)
113 | # print("total_actions", total, ", max_length", length)
114 |
115 | if args.ADAPT:
116 | consistency_loss = sum(listner.logs['consistency_loss']) / max(len(listner.logs['consistency_loss']), 1)
117 | writer.add_scalar("loss/consistency_loss", consistency_loss, idx)
118 | alignment_loss = sum(listner.logs['alignment_loss']) / max(len(listner.logs['alignment_loss']), 1)
119 | writer.add_scalar("loss/alignment_loss", alignment_loss, idx)
120 |
121 | # Run validation
122 | loss_str = "iter {}".format(iter)
123 | for env_name, (env, evaluator) in val_envs.items():
124 | listner.env = env
125 |
126 | # Get validation distance from goal under test evaluation conditions
127 | listner.test(use_dropout=False, feedback='argmax', iters=None)
128 | result = listner.get_results()
129 | score_summary, _ = evaluator.score(result)
130 | loss_str += ", %s " % env_name
131 | for metric, val in score_summary.items():
132 | if metric in ['spl']:
133 | writer.add_scalar("spl/%s" % env_name, val, idx)
134 | if env_name in best_val:
135 | if val > best_val[env_name]['spl']:
136 | best_val[env_name]['spl'] = val
137 | best_val[env_name]['update'] = True
138 | elif (val == best_val[env_name]['spl']) and (score_summary['success_rate'] > best_val[env_name]['sr']):
139 | best_val[env_name]['spl'] = val
140 | best_val[env_name]['update'] = True
141 | loss_str += ', %s: %.4f' % (metric, val)
142 |
143 | record_file = open('./logs/' + args.name + '.txt', 'a')
144 | record_file.write(loss_str + '\n')
145 | record_file.close()
146 |
147 | for env_name in best_val:
148 | if best_val[env_name]['update']:
149 | best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
150 | best_val[env_name]['update'] = False
151 | listner.save(idx, os.path.join("snap", args.name, "state_dict", "best_%s" % (env_name)))
152 | else:
153 | listner.save(idx, os.path.join("snap", args.name, "state_dict", "latest_dict"))
154 |
155 | print(('%s (%d %d%%) %s' % (timeSince(start, float(iter)/n_iters),
156 | iter, float(iter)/n_iters*100, loss_str)))
157 |
158 | if iter % 1000 == 0:
159 | print("BEST RESULT TILL NOW")
160 | for env_name in best_val:
161 | print(env_name, best_val[env_name]['state'])
162 |
163 | record_file = open('./logs/' + args.name + '.txt', 'a')
164 | record_file.write('BEST RESULT TILL NOW: ' + env_name + ' | ' + best_val[env_name]['state'] + '\n')
165 | record_file.close()
166 |
167 | idx += interval
168 |
169 | listner.save(idx, os.path.join("snap", args.name, "state_dict", "LAST_iter%d" % (idx)))
170 |
171 |
172 | def valid(train_env, tok, val_envs={},
173 | feat_dict=None):
174 | agent = Seq2SeqAgent(train_env, "", tok, args.maxAction, feat_dict)
175 |
176 | print("Loaded the listener model at iter %d from %s" % (agent.load(args.load), args.load))
177 |
178 | for env_name, (env, evaluator) in val_envs.items():
179 | agent.logs = defaultdict(list)
180 | agent.env = env
181 |
182 | iters = None
183 | agent.test(use_dropout=False, feedback='argmax', iters=iters)
184 | result = agent.get_results()
185 |
186 | if env_name != '':
187 | score_summary, _ = evaluator.score(result)
188 | loss_str = "Env name: %s" % env_name
189 | for metric,val in score_summary.items():
190 | loss_str += ', %s: %.4f' % (metric, val)
191 | print(loss_str)
192 |
193 | if args.submit:
194 | json.dump(
195 | result,
196 | open(os.path.join(log_dir, "submit_%s.json" % env_name), 'w'),
197 | sort_keys=True, indent=4, separators=(',', ': ')
198 | )
199 |
200 | def setup():
201 | torch.manual_seed(1)
202 | torch.cuda.manual_seed(1)
203 | random.seed(0)
204 | np.random.seed(0)
205 |
206 | def train_val(test_only=False):
207 | ''' Train on the training set, and validate on seen and unseen splits. '''
208 | setup()
209 | tok = get_tokenizer(args)
210 |
211 | feat_dict = read_img_features(features, test_only=test_only)
212 | if args.ADAPT and args.features == 'places365':
213 | feat_dict_clip = read_img_features(CLIP_FEATURES, test_only=test_only)
214 |
215 | if test_only:
216 | featurized_scans = None
217 | val_env_names = ['val_train_seen']
218 | else:
219 | featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
220 | val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
221 |
222 | if args.ADAPT:
223 | val_env_names = ['val_seen', 'val_unseen']
224 |
225 | train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
226 | from collections import OrderedDict
227 |
228 | if args.submit:
229 | val_env_names.append('test')
230 | else:
231 | pass
232 |
233 | val_envs = OrderedDict(
234 | ((split,
235 | (R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok),
236 | Evaluation([split], featurized_scans, tok))
237 | )
238 | for split in val_env_names
239 | )
240 | )
241 |
242 | if args.train == 'listener':
243 | train(train_env, tok, args.iters, val_envs=val_envs)
244 | elif args.train == 'validlistener':
245 | if args.ADAPT and args.features == 'places365':
246 | valid(train_env, tok, val_envs=val_envs,
247 | feat_dict=feat_dict_clip)
248 | else:
249 | valid(train_env, tok, val_envs=val_envs,
250 | feat_dict=feat_dict)
251 | else:
252 | assert False
253 |
254 | def train_val_augment(test_only=False):
255 | """
256 | Train the listener with the augmented data
257 | """
258 | setup()
259 |
260 | # Create a batch training environment that will also preprocess text
261 | tok_bert = get_tokenizer(args)
262 |
263 | # Load the env img features
264 | feat_dict = read_img_features(features, test_only=test_only)
265 | if args.ADAPT and args.features == 'places365':
266 | feat_dict_clip = read_img_features(CLIP_FEATURES, test_only=test_only)
267 |
268 | if test_only:
269 | featurized_scans = None
270 | val_env_names = ['val_train_seen']
271 | else:
272 | featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
273 | val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
274 |
275 | if args.ADAPT:
276 | val_env_names = ['val_seen', 'val_unseen']
277 |
278 | # Load the augmentation data
279 | aug_path = args.aug
280 | # Create the training environment
281 | train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok_bert)
282 | aug_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=[aug_path], tokenizer=tok_bert, name='aug')
283 |
284 | # Setup the validation data
285 | val_envs = {split: (R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok_bert),
286 | Evaluation([split], featurized_scans, tok_bert))
287 | for split in val_env_names}
288 | if args.ADAPT and args.features == 'places365':
289 | # Start training
290 | train(train_env, tok_bert, args.iters, val_envs=val_envs, aug_env=aug_env,
291 | feat_dict=feat_dict_clip)
292 | else:
293 | # Start training
294 | train(train_env, tok_bert, args.iters, val_envs=val_envs, aug_env=aug_env,
295 | feat_dict=feat_dict)
296 |
297 |
298 | if __name__ == "__main__":
299 | if args.train in ['listener', 'validlistener']:
300 | train_val(test_only=args.test_only)
301 | elif args.train == 'auglistener':
302 | train_val_augment(test_only=args.test_only)
303 | else:
304 | assert False
305 |
--------------------------------------------------------------------------------
/r2r_src/vlnbert/vlnbert_OSCAR.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import, division, print_function, unicode_literals
3 | import logging
4 | import math
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from torch.nn import CrossEntropyLoss, MSELoss
9 |
10 | from transformers.pytorch_transformers.modeling_bert import (BertEmbeddings,
11 | BertSelfAttention, BertAttention, BertEncoder, BertLayer,
12 | BertSelfOutput, BertIntermediate, BertOutput,
13 | BertPooler, BertLayerNorm, BertPreTrainedModel,
14 | BertPredictionHeadTransform)
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 | class CaptionBertSelfAttention(BertSelfAttention):
19 | """
20 | Modified from BertSelfAttention to add support for output_hidden_states.
21 | """
22 | def __init__(self, config):
23 | super(CaptionBertSelfAttention, self).__init__(config)
24 | self.config = config
25 |
26 | def forward(self, mode, hidden_states, attention_mask, head_mask=None,
27 | history_state=None):
28 | if history_state is not None:
29 | x_states = torch.cat([history_state, hidden_states], dim=1)
30 | mixed_query_layer = self.query(hidden_states)
31 | mixed_key_layer = self.key(x_states)
32 | mixed_value_layer = self.value(x_states)
33 | else:
34 | mixed_query_layer = self.query(hidden_states)
35 | mixed_key_layer = self.key(hidden_states)
36 | mixed_value_layer = self.value(hidden_states)
37 |
38 | if mode == 'visual':
39 | mixed_query_layer = mixed_query_layer[:, [0]+list(range(-self.config.directions, 0)), :]
40 |
 41 |         ''' language features only provide Keys and Values '''
42 | query_layer = self.transpose_for_scores(mixed_query_layer)
43 | key_layer = self.transpose_for_scores(mixed_key_layer)
44 | value_layer = self.transpose_for_scores(mixed_value_layer)
45 |
46 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
47 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
 48 |         # Apply the attention mask (precomputed for all layers in BertModel forward() function)
49 | attention_scores = attention_scores + attention_mask
50 |
51 | # Normalize the attention scores to probabilities.
52 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
53 |
54 | # This is actually dropping out entire tokens to attend to, which might
55 | # seem a bit unusual, but is taken from the original Transformer paper.
56 | attention_probs = self.dropout(attention_probs)
57 |
58 | # Mask heads if we want to
59 | if head_mask is not None:
60 | attention_probs = attention_probs * head_mask
61 |
62 | context_layer = torch.matmul(attention_probs, value_layer)
63 |
64 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
65 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
66 | context_layer = context_layer.view(*new_context_layer_shape)
67 |
68 | outputs = (context_layer, attention_scores)
69 |
70 | return outputs
71 |
72 |
73 | class CaptionBertAttention(BertAttention):
74 | """
75 | Modified from BertAttention to add support for output_hidden_states.
76 | """
77 | def __init__(self, config):
78 | super(CaptionBertAttention, self).__init__(config)
79 | self.self = CaptionBertSelfAttention(config)
80 | self.output = BertSelfOutput(config)
81 | self.config = config
82 |
83 | def forward(self, mode, input_tensor, attention_mask, head_mask=None,
84 | history_state=None):
85 | ''' transformer processing '''
86 | self_outputs = self.self(mode, input_tensor, attention_mask, head_mask, history_state)
87 |
 88 |         ''' attention output projection with residual connection '''
89 | if mode == 'visual':
90 | attention_output = self.output(self_outputs[0], input_tensor[:, [0]+list(range(-self.config.directions, 0)), :])
91 | if mode == 'language':
92 | attention_output = self.output(self_outputs[0], input_tensor)
93 |
94 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
95 |
96 | return outputs
97 |
98 |
99 | class CaptionBertLayer(BertLayer):
100 | """
101 | Modified from BertLayer to add support for output_hidden_states.
102 | """
103 | def __init__(self, config):
104 | super(CaptionBertLayer, self).__init__(config)
105 | self.attention = CaptionBertAttention(config)
106 | self.intermediate = BertIntermediate(config)
107 | self.output = BertOutput(config)
108 |
109 | def forward(self, mode, hidden_states, attention_mask, head_mask=None,
110 | history_state=None):
111 |
112 | attention_outputs = self.attention(mode, hidden_states, attention_mask,
113 | head_mask, history_state)
114 |
115 |         ''' feed-forward network with residual '''
116 | attention_output = attention_outputs[0]
117 | intermediate_output = self.intermediate(attention_output)
118 | layer_output = self.output(intermediate_output, attention_output)
119 | outputs = (layer_output,) + attention_outputs[1:]
120 |
121 | return outputs
122 |
123 |
124 | class CaptionBertEncoder(BertEncoder):
125 | """
126 | Modified from BertEncoder to add support for output_hidden_states.
127 | """
128 | def __init__(self, config):
129 | super(CaptionBertEncoder, self).__init__(config)
130 | self.output_attentions = config.output_attentions
131 | self.output_hidden_states = config.output_hidden_states
132 | # 12 Bert layers
133 | self.layer = nn.ModuleList([CaptionBertLayer(config) for _ in range(config.num_hidden_layers)])
134 | self.config = config
135 |
136 | def forward(self, mode, hidden_states, attention_mask, head_mask=None,
137 | encoder_history_states=None):
138 |
139 | if mode == 'visual':
140 | for i, layer_module in enumerate(self.layer):
141 | history_state = None if encoder_history_states is None else encoder_history_states[i]
142 |
143 | layer_outputs = layer_module(mode,
144 | hidden_states, attention_mask, head_mask[i],
145 | history_state)
146 |
147 | concat_layer_outputs = torch.cat((layer_outputs[0][:,0:1,:], hidden_states[:,1:-self.config.directions,:], layer_outputs[0][:,1:self.config.directions+1,:]), 1)
148 | hidden_states = concat_layer_outputs
149 |
150 | if i == self.config.num_hidden_layers - 1:
151 | state_attention_score = layer_outputs[1][:, :, 0, :]
152 | lang_attention_score = layer_outputs[1][:, :, -self.config.directions:, 1:-self.config.directions]
153 | vis_attention_score = layer_outputs[1][:, :, :, :]
154 |
155 | outputs = (hidden_states, state_attention_score, lang_attention_score, vis_attention_score)
156 |
157 | elif mode == 'language':
158 | for i, layer_module in enumerate(self.layer):
159 | history_state = None if encoder_history_states is None else encoder_history_states[i] # default None
160 |
161 | layer_outputs = layer_module(mode,
162 | hidden_states, attention_mask, head_mask[i],
163 | history_state)
164 | hidden_states = layer_outputs[0]
165 |
166 | if i == self.config.num_hidden_layers - 1:
167 | slang_attention_score = layer_outputs[1]
168 |
169 | outputs = (hidden_states, slang_attention_score)
170 |
171 | return outputs
172 |
173 |
174 | class BertImgModel(BertPreTrainedModel):
175 | """ Expand from BertModel to handle image region features as input
176 | """
177 | def __init__(self, config):
178 | super(BertImgModel, self).__init__(config)
179 | self.embeddings = BertEmbeddings(config)
180 | self.encoder = CaptionBertEncoder(config)
181 | self.pooler = BertPooler(config)
182 |
183 | self.img_dim = config.img_feature_dim
184 | logger.info('BertImgModel Image Dimension: {}'.format(self.img_dim))
185 |
186 | self.apply(self.init_weights)
187 |
188 | def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None,
189 | position_ids=None, img_feats=None):
190 |
191 | if attention_mask.dim() == 2:
192 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
193 | elif attention_mask.dim() == 3:
194 | extended_attention_mask = attention_mask.unsqueeze(1)
195 | else:
196 | raise NotImplementedError
197 |
198 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
199 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
200 |
201 | head_mask = [None] * self.config.num_hidden_layers
202 |
203 | if mode == 'visual':
204 | language_features = input_ids
205 | concat_embedding_output = torch.cat((language_features, img_feats), 1)
206 | elif mode == 'language':
207 | embedding_output = self.embeddings(input_ids, position_ids=position_ids,
208 | token_type_ids=token_type_ids)
209 | concat_embedding_output = embedding_output
210 |
211 | ''' pass to the Transformer layers '''
212 | encoder_outputs = self.encoder(mode, concat_embedding_output,
213 | extended_attention_mask, head_mask=head_mask)
214 |
215 | sequence_output = encoder_outputs[0]
216 | pooled_output = self.pooler(sequence_output) # We "pool" the model by simply taking the hidden state corresponding to the first token
217 |
218 | # add hidden_states and attentions if they are here
219 | outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
220 |
221 | return outputs
222 |
223 |
224 | class VLNBert(BertPreTrainedModel):
225 | """
226 | Modified from BertForMultipleChoice to support oscar training.
227 | """
228 | def __init__(self, config):
229 | super(VLNBert, self).__init__(config)
230 | self.config = config
231 | self.bert = BertImgModel(config)
232 |
233 | self.vis_lang_LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
234 | self.state_proj = nn.Linear(config.hidden_size*2, config.hidden_size, bias=True)
235 | self.state_LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
236 |
237 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
238 | self.apply(self.init_weights)
239 |
240 | def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None,
241 | position_ids=None, img_feats=None):
242 |
243 | outputs = self.bert(mode, input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
244 | attention_mask=attention_mask, img_feats=img_feats)
245 |
246 | sequence_output = outputs[0]
247 | sequence_output = self.dropout(sequence_output)
248 |
249 | pooled_output = outputs[1]
250 |
251 | if mode == 'language':
252 | return sequence_output
253 |
254 | elif mode == 'visual':
255 | # attention scores with respect to agent's state
256 | language_attentions = outputs[2][:, :, 1:-self.config.directions]
257 | visual_attentions = outputs[2][:, :, -self.config.directions:]
258 |
259 | language_attention_scores = language_attentions.mean(dim=1) # mean over the 12 heads
260 | visual_attention_scores = visual_attentions.mean(dim=1)
261 |
262 | # weighted_feat
263 | language_attention_probs = nn.Softmax(dim=-1)(language_attention_scores.clone()).unsqueeze(-1)
264 | visual_attention_probs = nn.Softmax(dim=-1)(visual_attention_scores.clone()).unsqueeze(-1)
265 |
266 | language_seq = sequence_output[:, 1:-self.config.directions, :]
267 | visual_seq = sequence_output[:, -self.config.directions:, :]
268 |
269 | # residual weighting, final attention to weight the raw inputs
270 | attended_language = (language_attention_probs * input_ids[:, 1:, :]).sum(1)
271 | attended_visual = (visual_attention_probs * img_feats).sum(1)
272 |
273 | # update agent's state, unify history, language and vision by elementwise product
274 | vis_lang_feat = self.vis_lang_LayerNorm(attended_language * attended_visual)
275 | state_output = torch.cat((pooled_output, vis_lang_feat), dim=-1)
276 | state_proj = self.state_proj(state_output)
277 | state_proj = self.state_LayerNorm(state_proj)
278 |
279 | return state_proj, visual_attention_scores
280 |
--------------------------------------------------------------------------------
/r2r_src/env.py:
--------------------------------------------------------------------------------
1 | ''' Batched Room-to-Room navigation environment '''
2 |
3 | import sys
4 | sys.path.append('buildpy36')
5 | sys.path.append('Matterport_Simulator/build/')
6 | import MatterSim
7 | import csv
8 | import numpy as np
9 | import math
10 | import base64
11 | import utils
12 | import json
13 | import os
14 | import random
15 | import networkx as nx
16 | from param import args
17 |
18 | from utils import load_datasets, load_nav_graphs, pad_instr_tokens
19 |
20 | csv.field_size_limit(sys.maxsize)
21 |
22 |
23 | class EnvBatch():
24 | ''' A simple wrapper for a batch of MatterSim environments,
25 | using discretized viewpoints and pretrained features '''
26 |
27 | def __init__(self, feature_store=None, batch_size=100):
28 | """
29 | 1. Load the pretrained image features
30 | 2. Init the Simulator.
31 | :param feature_store: Path to the feature file, or a dict of features that have already been loaded.
32 | :param batch_size: Used to create the simulator list.
33 | """
34 | if feature_store:
35 | if type(feature_store) is dict: # A dict can be passed directly to avoid re-reading the features
36 | self.features = feature_store
37 | self.image_w = 640
38 | self.image_h = 480
39 | self.vfov = 60
40 | self.feature_size = next(iter(self.features.values())).shape[-1]
41 | print('The feature size is %d' % self.feature_size)
42 | else:
43 | print(' Image features not provided - in testing mode')
44 | self.features = None
45 | self.image_w = 640
46 | self.image_h = 480
47 | self.vfov = 60
48 | self.sims = []
49 | for i in range(batch_size):
50 | sim = MatterSim.Simulator()
51 | sim.setRenderingEnabled(False)
52 | sim.setDiscretizedViewingAngles(True) # Set increment/decrement to 30 degrees (otherwise given in radians)
53 | sim.setCameraResolution(self.image_w, self.image_h)
54 | sim.setCameraVFOV(math.radians(self.vfov))
55 | sim.init()
56 | self.sims.append(sim)
57 |
58 | def _make_id(self, scanId, viewpointId):
59 | return scanId + '_' + viewpointId
60 |
61 | def newEpisodes(self, scanIds, viewpointIds, headings):
62 | for i, (scanId, viewpointId, heading) in enumerate(zip(scanIds, viewpointIds, headings)):
63 | self.sims[i].newEpisode(scanId, viewpointId, heading, 0)
64 |
65 | def getStates(self):
66 | """
67 | Get the list of states augmented with precomputed image features. The rgb field will be empty.
68 | Agent's current view [0-35] (set only when viewing angles are discretized)
69 | [0-11] looking down, [12-23] looking at horizon, [24-35] looking up
70 | :return: [ ((36, feature_size), sim_state) ] * batch_size
71 | """
72 | feature_states = []
73 | for i, sim in enumerate(self.sims):
74 | state = sim.getState()
75 |
76 | long_id = self._make_id(state.scanId, state.location.viewpointId)
77 | if self.features:
78 | feature = self.features[long_id]
79 | feature_states.append((feature, state))
80 | else:
81 | feature_states.append((None, state))
82 | return feature_states
83 |
84 | def makeActions(self, actions):
85 | ''' Take an action using the full state dependent action interface (with batched input).
86 | Every action element should be an (index, heading, elevation) tuple. '''
87 | for i, (index, heading, elevation) in enumerate(actions):
88 | self.sims[i].makeAction(index, heading, elevation)
89 |
90 |
91 | class R2RBatch():
92 | ''' Implements the Room to Room navigation task, using discretized viewpoints and pretrained features '''
93 |
94 | def __init__(self, feature_store, batch_size=100, seed=10, splits=['train'], tokenizer=None,
95 | name=None):
96 | self.env = EnvBatch(feature_store=feature_store, batch_size=batch_size)
97 | if feature_store:
98 | self.feature_size = self.env.feature_size
99 | else:
100 | self.feature_size = 2048
101 | self.data = []
102 | if tokenizer:
103 | self.tok = tokenizer
104 | scans = []
105 | for split in splits:
106 | for i_item, item in enumerate(load_datasets([split])):
107 | if args.test_only and i_item == 64:
108 | break
109 | if args.ADAPT:
110 | if "/" in split:
111 | try:
112 | new_item = dict(item)
113 | new_item['instr_id'] = item['path_id']
114 | new_item['instructions'] = item['instructions'][0]
115 |
116 | new_item['instr_encoding'] = item['instr_enc']
117 |
118 | new_item['img_sub_prompt'] = item['positive_trajectory'][str(new_item['instr_id'])]
119 | new_item['txt_sub_prompt'] = item['positive_obj_index'][str(new_item['instr_id'])]
120 | new_item['prompt_num'] = item['positive_trajectory_length'][
121 | str(new_item['instr_id'])]
122 | if new_item['instr_encoding'] is not None: # Skip items with a missing instruction encoding
123 | self.data.append(new_item)
124 | scans.append(item['scan'])
125 | except:
126 | continue
127 | else:
128 | # Split multiple instructions into separate entries
129 | for j, instr in enumerate(item['instructions']):
130 | try:
131 | new_item = dict(item)
132 | new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
133 | new_item['instructions'] = instr
134 | ''' BERT tokenizer '''
135 | instr_tokens = tokenizer.tokenize(instr)
136 | padded_instr_tokens, num_words = pad_instr_tokens(instr_tokens, args.maxInput)
137 | new_item['instr_encoding'] = tokenizer.convert_tokens_to_ids(padded_instr_tokens)
138 |
139 | new_item['img_sub_prompt'] = item['positive_trajectory'][new_item['instr_id']]
140 | new_item['txt_sub_prompt'] = item['positive_obj_index'][
141 | str(new_item['instr_id'])]
142 | new_item['prompt_num'] = item['positive_trajectory_length'][
143 | new_item['instr_id']]
144 | if new_item['instr_encoding'] is not None: # Skip items with a missing instruction encoding
145 | self.data.append(new_item)
146 | scans.append(item['scan'])
147 | except:
148 | continue
149 |
150 | else:
151 | if "/" in split:
152 | try:
153 | new_item = dict(item)
154 | new_item['instr_id'] = item['path_id']
155 | new_item['instructions'] = item['instructions'][0]
156 | new_item['instr_encoding'] = item['instr_enc']
157 | if new_item['instr_encoding'] is not None: # Skip items with a missing instruction encoding
158 | self.data.append(new_item)
159 | scans.append(item['scan'])
160 | except:
161 | continue
162 | else:
163 | # Split multiple instructions into separate entries
164 | for j, instr in enumerate(item['instructions']):
165 | try:
166 | new_item = dict(item)
167 | new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
168 | new_item['instructions'] = instr
169 |
170 | ''' BERT tokenizer '''
171 | instr_tokens = tokenizer.tokenize(instr)
172 | padded_instr_tokens, num_words = pad_instr_tokens(instr_tokens, args.maxInput)
173 | new_item['instr_encoding'] = tokenizer.convert_tokens_to_ids(padded_instr_tokens)
174 |
175 | if new_item['instr_encoding'] is not None: # Skip items with a missing instruction encoding
176 | self.data.append(new_item)
177 | scans.append(item['scan'])
178 | except:
179 | continue
180 |
181 | if name is None:
182 | self.name = splits[0] if len(splits) > 0 else "FAKE"
183 | else:
184 | self.name = name
185 |
186 | self.scans = set(scans)
187 | self.splits = splits
188 | self.seed = seed
189 | random.seed(self.seed)
190 | random.shuffle(self.data)
191 |
192 | self.ix = 0
193 | self.batch_size = batch_size
194 | self._load_nav_graphs()
195 |
196 | self.angle_feature = utils.get_all_point_angle_feature()
197 | self.sim = utils.new_simulator()
198 | self.buffered_state_dict = {}
199 |
200 | # In the supervised setup the "fake data" is simply the real data
201 | self.fake_data = self.data
202 | print('R2RBatch loaded with %d instructions, using splits: %s' % (len(self.data), ",".join(splits)))
203 |
204 | def size(self):
205 | return len(self.data)
206 |
207 | def _load_nav_graphs(self):
208 | """
209 | Load the connectivity graph for each scan in self.scans.
210 | Store the graphs {scan_id: graph} in self.graphs
211 | Store the shortest paths {scan_id: {view_id_x: {view_id_y: [path]} } } in self.paths
212 | Store the distances in self.distances (same nested structure as self.paths)
213 | These graphs are useful for reasoning about shortest paths
214 | :return: None
215 | """
216 | print('Loading navigation graphs for %d scans' % len(self.scans))
217 | self.graphs = load_nav_graphs(self.scans)
218 | self.paths = {}
219 | for scan, G in self.graphs.items(): # compute all shortest paths
220 | self.paths[scan] = dict(nx.all_pairs_dijkstra_path(G))
221 | self.distances = {}
222 | for scan, G in self.graphs.items(): # compute all shortest paths
223 | self.distances[scan] = dict(nx.all_pairs_dijkstra_path_length(G))
224 |
225 | def _next_minibatch(self, tile_one=False, batch_size=None, **kwargs):
226 | """
227 | Store the minibatch in 'self.batch'
228 | :param tile_one: Repeat a single item batch_size times
229 | :return: None
230 | """
231 | if batch_size is None:
232 | batch_size = self.batch_size
233 | if tile_one:
234 | batch = [self.data[self.ix]] * batch_size
235 | self.ix += 1
236 | if self.ix >= len(self.data):
237 | random.shuffle(self.data)
238 | self.ix -= len(self.data)
239 | else:
240 | batch = self.data[self.ix: self.ix+batch_size]
241 | if len(batch) < batch_size:
242 | random.shuffle(self.data)
243 | self.ix = batch_size - len(batch)
244 | batch += self.data[:self.ix]
245 | else:
246 | self.ix += batch_size
247 | self.batch = batch
248 |
249 | def reset_epoch(self, shuffle=False):
250 | ''' Reset the data index to the beginning of the epoch. Primarily for testing.
251 | You must still call reset() for a new episode. '''
252 | if shuffle:
253 | random.shuffle(self.data)
254 | self.ix = 0
255 |
256 | def _shortest_path_action(self, state, goalViewpointId):
257 | ''' Determine next action on the shortest path to goal, for supervised training. '''
258 | if state.location.viewpointId == goalViewpointId:
259 | return goalViewpointId # Just stop here
260 | path = self.paths[state.scanId][state.location.viewpointId][goalViewpointId]
261 | nextViewpointId = path[1]
262 | return nextViewpointId
263 |
264 | def make_candidate(self, feature, scanId, viewpointId, viewId):
265 | def _loc_distance(loc):
266 | return np.sqrt(loc.rel_heading ** 2 + loc.rel_elevation ** 2)
267 | base_heading = (viewId % 12) * math.radians(30)
268 | adj_dict = {}
269 | long_id = "%s_%s" % (scanId, viewpointId)
270 | if long_id not in self.buffered_state_dict:
271 | for ix in range(36):
272 | if ix == 0:
273 | self.sim.newEpisode(scanId, viewpointId, 0, math.radians(-30))
274 | elif ix % 12 == 0:
275 | self.sim.makeAction(0, 1.0, 1.0)
276 | else:
277 | self.sim.makeAction(0, 1.0, 0)
278 |
279 | state = self.sim.getState()
280 | assert state.viewIndex == ix
281 |
282 | # Heading and elevation for the viewpoint center
283 | heading = state.heading - base_heading
284 | elevation = state.elevation
285 |
286 | visual_feat = feature[ix]
287 |
288 | # get adjacent locations
289 | for j, loc in enumerate(state.navigableLocations[1:]):
290 | # if a loc is visible from multiple views, use the closest
291 | # view (in angular distance) as its representation
292 | distance = _loc_distance(loc)
293 |
294 | # Heading and elevation for the loc
295 | loc_heading = heading + loc.rel_heading
296 | loc_elevation = elevation + loc.rel_elevation
297 | angle_feat = utils.angle_feature(loc_heading, loc_elevation)
298 | if (loc.viewpointId not in adj_dict or
299 | distance < adj_dict[loc.viewpointId]['distance']):
300 | adj_dict[loc.viewpointId] = {
301 | 'heading': loc_heading,
302 | 'elevation': loc_elevation,
303 | "normalized_heading": state.heading + loc.rel_heading,
304 | 'scanId':scanId,
305 | 'viewpointId': loc.viewpointId, # Next viewpoint id
306 | 'pointId': ix,
307 | 'distance': distance,
308 | 'idx': j + 1,
309 | 'feature': np.concatenate((visual_feat, angle_feat), -1)
310 | }
311 | candidate = list(adj_dict.values())
312 | self.buffered_state_dict[long_id] = [
313 | {key: c[key]
314 | for key in
315 | ['normalized_heading', 'elevation', 'scanId', 'viewpointId',
316 | 'pointId', 'idx']}
317 | for c in candidate
318 | ]
319 | return candidate
320 | else:
321 | candidate = self.buffered_state_dict[long_id]
322 | candidate_new = []
323 | for c in candidate:
324 | c_new = c.copy()
325 | ix = c_new['pointId']
326 | normalized_heading = c_new['normalized_heading']
327 | visual_feat = feature[ix]
328 | loc_heading = normalized_heading - base_heading
329 | c_new['heading'] = loc_heading
330 | angle_feat = utils.angle_feature(c_new['heading'], c_new['elevation'])
331 | c_new['feature'] = np.concatenate((visual_feat, angle_feat), -1)
332 | c_new.pop('normalized_heading')
333 | candidate_new.append(c_new)
334 | return candidate_new
335 |
336 | def _get_obs(self):
337 | obs = []
338 | for i, (feature, state) in enumerate(self.env.getStates()):
339 | item = self.batch[i]
340 | base_view_id = state.viewIndex
341 |
342 | if feature is None:
343 | feature = np.zeros((36, 2048))
344 |
345 | # Full features
346 | candidate = self.make_candidate(feature, state.scanId, state.location.viewpointId, state.viewIndex)
347 | # [visual_feature, angle_feature] for views
348 | feature = np.concatenate((feature, self.angle_feature[base_view_id]), -1)
349 | if args.ADAPT:
350 | obs.append({
351 | 'instr_id': item['instr_id'],
352 | 'scan': state.scanId,
353 | 'viewpoint': state.location.viewpointId,
354 | 'viewIndex': state.viewIndex,
355 | 'heading': state.heading,
356 | 'elevation': state.elevation,
357 | 'feature': feature,
358 | 'candidate': candidate,
359 | 'navigableLocations': state.navigableLocations,
360 | 'instructions': item['instructions'],
361 | 'teacher': self._shortest_path_action(state, item['path'][-1]),
362 | 'gt_path': item['path'],
363 | 'path_id': item['path_id'],
364 | 'img_sub_prompt': item['img_sub_prompt'],
365 | 'txt_sub_prompt': item['txt_sub_prompt'],
366 | 'prompt_num': item['prompt_num']
367 | })
368 | else:
369 | obs.append({
370 | 'instr_id' : item['instr_id'],
371 | 'scan' : state.scanId,
372 | 'viewpoint' : state.location.viewpointId,
373 | 'viewIndex' : state.viewIndex,
374 | 'heading' : state.heading,
375 | 'elevation' : state.elevation,
376 | 'feature' : feature,
377 | 'candidate': candidate,
378 | 'navigableLocations' : state.navigableLocations,
379 | 'instructions' : item['instructions'],
380 | 'teacher' : self._shortest_path_action(state, item['path'][-1]),
381 | 'gt_path' : item['path'],
382 | 'path_id' : item['path_id']
383 | })
384 | if 'instr_encoding' in item:
385 | obs[-1]['instr_encoding'] = item['instr_encoding']
386 | # A2C reward. The negative distance between the state and the final state
387 | obs[-1]['distance'] = self.distances[state.scanId][state.location.viewpointId][item['path'][-1]]
388 | return obs
389 |
390 | def reset(self, batch=None, inject=False, **kwargs):
391 | ''' Load a new minibatch / episodes. '''
392 | if batch is None: # Allow the user to explicitly define the batch
393 | self._next_minibatch(**kwargs)
394 | else:
395 | if inject: # Inject the batch into the next minibatch
396 | self._next_minibatch(**kwargs)
397 | self.batch[:len(batch)] = batch
398 | else: # Else set the batch to the current batch
399 | self.batch = batch
400 | scanIds = [item['scan'] for item in self.batch]
401 | viewpointIds = [item['path'][0] for item in self.batch]
402 | headings = [item['heading'] for item in self.batch]
403 | self.env.newEpisodes(scanIds, viewpointIds, headings)
404 | return self._get_obs()
405 |
406 | def step(self, actions):
407 | ''' Take action (same interface as makeActions) '''
408 | self.env.makeActions(actions)
409 | return self._get_obs()
410 |
411 | def get_statistics(self):
412 | stats = {}
413 | length = 0
414 | path = 0
415 | for datum in self.data:
416 | length += len(self.tok.split_sentence(datum['instructions']))
417 | path += self.distances[datum['scan']][datum['path'][0]][datum['path'][-1]]
418 | stats['length'] = length / len(self.data)
419 | stats['path'] = path / len(self.data)
420 | return stats
421 |
--------------------------------------------------------------------------------
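
A rough usage sketch for the environment above (the feature file path and the tokenizer choice are assumptions for illustration): R2RBatch serves minibatches of episodes and returns one observation dict per episode, so a loop typically resets the environment, reads the observations, and steps with (index, heading, elevation) actions.

    from pytorch_transformers import BertTokenizer
    from utils import read_img_features

    tok = BertTokenizer.from_pretrained('bert-base-uncased')
    feat_dict = read_img_features('img_features/CLIP-ViT-B-32-views.tsv')   # hypothetical feature file
    train_env = R2RBatch(feat_dict, batch_size=8, splits=['train'], tokenizer=tok)
    obs = train_env.reset()                         # one dict per episode: 'feature', 'candidate', 'teacher', ...
    obs = train_env.step([(0, 0, 0)] * len(obs))    # e.g. keep every agent at its current viewpoint
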
/r2r_src/vlnbert/vlnbert_PREVALENT.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import, division, print_function, unicode_literals
3 |
4 | import json
5 | import logging
6 | import math
7 | import os
8 | import sys
9 | from io import open
10 |
11 | import torch
12 | from torch import nn
13 | from torch.nn import CrossEntropyLoss, MSELoss
14 |
15 | #from transformers.pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig
16 | from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig
17 | import pdb
18 | from param import args
19 | import torch.nn.functional as F
20 | import numpy as np
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 | def gelu(x):
25 | """Implementation of the gelu activation function.
26 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
27 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
28 | Also see https://arxiv.org/abs/1606.08415
29 | """
30 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
31 |
32 |
33 | def swish(x):
34 | return x * torch.sigmoid(x)
35 |
36 |
37 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
38 |
39 |
40 | try:
41 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
42 | except (ImportError, AttributeError) as e:
43 | logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
44 | BertLayerNorm = torch.nn.LayerNorm
45 |
46 | class BertEmbeddings(nn.Module):
47 | """Construct the embeddings from word, position and token_type embeddings.
48 | """
49 | def __init__(self, config):
50 | super(BertEmbeddings, self).__init__()
51 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
52 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
53 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
54 |
55 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
56 | # any TensorFlow checkpoint file
57 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
58 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
59 |
60 | def forward(self, input_ids, token_type_ids=None, position_ids=None):
61 | seq_length = input_ids.size(1)
62 | if position_ids is None:
63 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
64 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
65 | if token_type_ids is None:
66 | token_type_ids = torch.zeros_like(input_ids)
67 |
68 | words_embeddings = self.word_embeddings(input_ids)
69 | position_embeddings = self.position_embeddings(position_ids)
70 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
71 |
72 | embeddings = words_embeddings + position_embeddings + token_type_embeddings
73 | embeddings = self.LayerNorm(embeddings)
74 | embeddings = self.dropout(embeddings)
75 | return embeddings
76 |
77 |
78 | class BertSelfAttention(nn.Module):
79 | def __init__(self, config):
80 | super(BertSelfAttention, self).__init__()
81 | if config.hidden_size % config.num_attention_heads != 0:
82 | raise ValueError(
83 | "The hidden size (%d) is not a multiple of the number of attention "
84 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
85 | self.output_attentions = True
86 |
87 | self.num_attention_heads = config.num_attention_heads
88 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
89 | self.all_head_size = self.num_attention_heads * self.attention_head_size
90 |
91 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
92 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
93 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
94 |
95 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
96 |
97 | def transpose_for_scores(self, x):
98 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
99 | x = x.view(*new_x_shape)
100 | return x.permute(0, 2, 1, 3)
101 |
102 | def forward(self, hidden_states, attention_mask, head_mask=None):
103 | mixed_query_layer = self.query(hidden_states)
104 | mixed_key_layer = self.key(hidden_states)
105 | mixed_value_layer = self.value(hidden_states)
106 |
107 | query_layer = self.transpose_for_scores(mixed_query_layer)
108 | key_layer = self.transpose_for_scores(mixed_key_layer)
109 | value_layer = self.transpose_for_scores(mixed_value_layer)
110 |
111 | # Take the dot product between "query" and "key" to get the raw attention scores.
112 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
113 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
114 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
115 | attention_scores = attention_scores + attention_mask
116 |
117 | # Normalize the attention scores to probabilities.
118 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
119 |
120 | # This is actually dropping out entire tokens to attend to, which might
121 | # seem a bit unusual, but is taken from the original Transformer paper.
122 | attention_probs = self.dropout(attention_probs)
123 |
124 | # Mask heads if we want to
125 | if head_mask is not None:
126 | attention_probs = attention_probs * head_mask
127 |
128 | context_layer = torch.matmul(attention_probs, value_layer)
129 |
130 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
131 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
132 | context_layer = context_layer.view(*new_context_layer_shape)
133 |
134 | outputs = (context_layer, attention_scores) if self.output_attentions else (context_layer,)
135 | return outputs
136 |
137 |
138 | class BertSelfOutput(nn.Module):
139 | def __init__(self, config):
140 | super(BertSelfOutput, self).__init__()
141 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
142 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
143 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
144 |
145 | def forward(self, hidden_states, input_tensor):
146 | hidden_states = self.dense(hidden_states)
147 | hidden_states = self.dropout(hidden_states)
148 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
149 | return hidden_states
150 |
151 |
152 | class BertAttention(nn.Module):
153 | def __init__(self, config):
154 | super(BertAttention, self).__init__()
155 | self.self = BertSelfAttention(config)
156 | self.output = BertSelfOutput(config)
157 |
158 | def forward(self, input_tensor, attention_mask, head_mask=None):
159 | self_outputs = self.self(input_tensor, attention_mask, head_mask)
160 | attention_output = self.output(self_outputs[0], input_tensor)
161 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
162 | return outputs
163 |
164 |
165 | class BertIntermediate(nn.Module):
166 | def __init__(self, config):
167 | super(BertIntermediate, self).__init__()
168 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
169 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
170 | self.intermediate_act_fn = ACT2FN[config.hidden_act]
171 | else:
172 | self.intermediate_act_fn = config.hidden_act
173 |
174 | def forward(self, hidden_states):
175 | hidden_states = self.dense(hidden_states)
176 | hidden_states = self.intermediate_act_fn(hidden_states)
177 | return hidden_states
178 |
179 |
180 | class BertOutput(nn.Module):
181 | def __init__(self, config):
182 | super(BertOutput, self).__init__()
183 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
184 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
185 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
186 |
187 | def forward(self, hidden_states, input_tensor):
188 | hidden_states = self.dense(hidden_states)
189 | hidden_states = self.dropout(hidden_states)
190 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
191 | return hidden_states
192 |
193 |
194 | class BertLayer(nn.Module):
195 | def __init__(self, config):
196 | super(BertLayer, self).__init__()
197 | self.attention = BertAttention(config)
198 | self.intermediate = BertIntermediate(config)
199 | self.output = BertOutput(config)
200 |
201 | def forward(self, hidden_states, attention_mask, head_mask=None):
202 | attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
203 | attention_output = attention_outputs[0]
204 | intermediate_output = self.intermediate(attention_output)
205 | layer_output = self.output(intermediate_output, attention_output)
206 | outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
207 | return outputs
208 |
209 |
210 | class BertPooler(nn.Module):
211 | def __init__(self, config):
212 | super(BertPooler, self).__init__()
213 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
214 | self.activation = nn.Tanh()
215 |
216 | def forward(self, hidden_states):
217 | # We "pool" the model by simply taking the hidden state corresponding
218 | # to the first token.
219 | first_token_tensor = hidden_states[:, 0]
220 | pooled_output = self.dense(first_token_tensor)
221 | pooled_output = self.activation(pooled_output)
222 | return pooled_output
223 |
224 |
225 | class BertXAttention(nn.Module):
226 | def __init__(self, config, ctx_dim=None):
227 | super().__init__()
228 | self.att = BertOutAttention(config, ctx_dim=ctx_dim)
229 | self.output = BertSelfOutput(config)
230 |
231 | def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None):
232 | output, attention_scores = self.att(input_tensor, ctx_tensor, ctx_att_mask)
233 | attention_output = self.output(output, input_tensor)
234 | return attention_output, attention_scores
235 |
236 |
237 | class BertOutAttention(nn.Module):
238 | def __init__(self, config, ctx_dim=None):
239 | super().__init__()
240 | if config.hidden_size % config.num_attention_heads != 0:
241 | raise ValueError(
242 | "The hidden size (%d) is not a multiple of the number of attention "
243 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
244 | self.num_attention_heads = config.num_attention_heads
245 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
246 | self.all_head_size = self.num_attention_heads * self.attention_head_size
247 |
248 | # visual_dim = 2048
249 | if ctx_dim is None:
250 | ctx_dim =config.hidden_size
251 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
252 | self.key = nn.Linear(ctx_dim, self.all_head_size)
253 | self.value = nn.Linear(ctx_dim, self.all_head_size)
254 |
255 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
256 |
257 | def transpose_for_scores(self, x):
258 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
259 | x = x.view(*new_x_shape)
260 | return x.permute(0, 2, 1, 3)
261 |
262 | def forward(self, hidden_states, context, attention_mask=None):
263 | mixed_query_layer = self.query(hidden_states)
264 | mixed_key_layer = self.key(context)
265 | mixed_value_layer = self.value(context)
266 |
267 | query_layer = self.transpose_for_scores(mixed_query_layer)
268 | key_layer = self.transpose_for_scores(mixed_key_layer)
269 | value_layer = self.transpose_for_scores(mixed_value_layer)
270 |
271 | # Take the dot product between "query" and "key" to get the raw attention scores.
272 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
273 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
274 |
275 | # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
276 | if attention_mask is not None:
277 | attention_scores = attention_scores + attention_mask
278 |
279 | # Normalize the attention scores to probabilities.
280 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
281 |
282 | # This is actually dropping out entire tokens to attend to, which might
283 | # seem a bit unusual, but is taken from the original Transformer paper.
284 | attention_probs = self.dropout(attention_probs)
285 |
286 | context_layer = torch.matmul(attention_probs, value_layer)
287 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
288 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
289 | context_layer = context_layer.view(*new_context_layer_shape)
290 | return context_layer, attention_scores
291 |
292 |
293 | class LXRTXLayer(nn.Module):
294 | def __init__(self, config):
295 | super().__init__()
296 | self.config = config
297 | # Lang self-att and FFN layer
298 | self.lang_self_att = BertAttention(config)
299 | self.lang_inter = BertIntermediate(config)
300 | self.lang_output = BertOutput(config)
301 | # Visn self-att and FFN layer
302 | self.visn_self_att = BertAttention(config)
303 | self.visn_inter = BertIntermediate(config)
304 | self.visn_output = BertOutput(config)
305 | # The cross attention layer
306 | self.visual_attention = BertXAttention(config)
307 |
308 | def cross_att(self, lang_input, lang_attention_mask, visn_input, visn_attention_mask):
309 | ''' Cross Attention -- cross for vision not for language '''
310 | visn_att_output, attention_scores = self.visual_attention(visn_input, lang_input, ctx_att_mask=lang_attention_mask)
311 | return visn_att_output, attention_scores
312 |
313 | def self_att(self, visn_input, visn_attention_mask):
314 | ''' Self Attention -- on visual features with language clues '''
315 | visn_att_output = self.visn_self_att(visn_input, visn_attention_mask)
316 | return visn_att_output
317 |
318 | def output_fc(self, visn_input):
319 | ''' Feed forward '''
320 | visn_inter_output = self.visn_inter(visn_input)
321 | visn_output = self.visn_output(visn_inter_output, visn_input)
322 | return visn_output
323 |
324 | def forward(self, lang_feats, lang_attention_mask,
325 | visn_feats, visn_attention_mask, tdx):
326 |
327 | ''' visual self-attention with state '''
328 | visn_att_output = torch.cat((lang_feats[:, 0:1, :], visn_feats), dim=1)
329 | state_vis_mask = torch.cat((lang_attention_mask[:,:,:,0:1], visn_attention_mask), dim=-1)
330 |
331 | ''' state and vision attend to language '''
332 | visn_att_output, cross_attention_scores = self.cross_att(lang_feats[:, 1:, :], lang_attention_mask[:, :, :, 1:], visn_att_output, state_vis_mask)
333 |
334 | language_attention_scores = cross_attention_scores[:, :, 0, :]
335 |
336 | state_visn_att_output = self.self_att(visn_att_output, state_vis_mask)
337 | state_visn_output = self.output_fc(state_visn_att_output[0])
338 |
339 | visn_att_output = state_visn_output[:, 1:, :]
340 | lang_att_output = torch.cat((state_visn_output[:, 0:1, :], lang_feats[:,1:,:]), dim=1)
341 |
342 | visual_attention_scores = state_visn_att_output[1][:, :, 0, 1:]
343 |
344 | return lang_att_output, visn_att_output, language_attention_scores, visual_attention_scores
345 |
346 |
347 | class VisionEncoder(nn.Module):
348 | def __init__(self, vision_size, config):
349 | super().__init__()
350 | feat_dim = vision_size
351 |
352 | # Object feature encoding
353 | self.visn_fc = nn.Linear(feat_dim, config.hidden_size)
354 | self.visn_layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12)
355 |
356 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
357 |
358 | if args.features == 'clip':
359 | self.visn_fc_pre = nn.Linear(512 + 128, feat_dim)
360 |
361 | def forward(self, visn_input):
362 | feats = visn_input
363 |
364 | if args.features == 'clip':
365 | feats = self.visn_fc_pre(feats)
366 |
367 | x = self.visn_fc(feats)
368 | x = self.visn_layer_norm(x)
369 |
370 | output = self.dropout(x)
371 | return output
372 |
373 |
374 | class VLNBert(BertPreTrainedModel):
375 | def __init__(self, config):
376 | super(VLNBert, self).__init__(config)
377 | self.embeddings = BertEmbeddings(config)
378 | self.pooler = BertPooler(config)
379 |
380 | self.img_dim = config.img_feature_dim # 2176
381 | logger.info('VLNBert Image Dimension: {}'.format(self.img_dim))
382 | self.img_feature_type = config.img_feature_type # ''
383 | self.vl_layers = config.vl_layers # 4
384 | self.la_layers = config.la_layers # 9
385 | self.lalayer = nn.ModuleList(
386 | [BertLayer(config) for _ in range(self.la_layers)])
387 | self.addlayer = nn.ModuleList(
388 | [LXRTXLayer(config) for _ in range(self.vl_layers)])
389 | self.vision_encoder = VisionEncoder(self.config.img_feature_dim, self.config)
390 |
391 | if args.ADAPT:
392 | self.txt_enc = nn.Sequential(nn.Linear(512, 768),
393 | nn.Dropout(0.1))
394 | self.img_enc = nn.Sequential(nn.Linear(512, 768),
395 | nn.Dropout(0.1))
396 | self.action_enc = nn.Sequential(nn.Linear(768*2, 768),
397 | nn.Dropout(0.1))
398 |
399 | self.apply(self.init_weights)
400 |
401 | def forward(self, mode, input_ids, token_type_ids=None,
402 | attention_mask=None, lang_mask=None, vis_mask=None, position_ids=None, head_mask=None,
403 | img_feats=None, prompt_mask_set=None,
404 | txt_sub_prompt_set=None, img_sub_prompt_set=None):
405 |
406 | attention_mask = lang_mask
407 |
408 | if token_type_ids is None:
409 | token_type_ids = torch.zeros_like(input_ids)
410 |
411 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
412 |
413 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
414 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
415 |
416 | head_mask = [None] * self.config.num_hidden_layers
417 |
418 | if mode == 'language':
419 | ''' LXMERT language branch (in VLN this is only performed at initialization) '''
420 | embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
421 | text_embeds = embedding_output
422 |
423 | for layer_module in self.lalayer:
424 | temp_output = layer_module(text_embeds, extended_attention_mask)
425 | text_embeds = temp_output[0]
426 |
427 | sequence_output = text_embeds
428 | pooled_output = self.pooler(sequence_output)
429 |
430 | return pooled_output, sequence_output
431 |
432 | elif mode == 'visual':
433 | ''' LXMERT visual branch (no language processing during navigation) '''
434 | text_embeds = input_ids
435 |
436 | text_mask = extended_attention_mask
437 |
438 | img_embedding_output = self.vision_encoder(img_feats)
439 | img_seq_len = img_feats.shape[1]
440 | batch_size = text_embeds.size(0)
441 |
442 | img_seq_mask = vis_mask
443 |
444 | extended_img_mask = img_seq_mask.unsqueeze(1).unsqueeze(2)
445 | extended_img_mask = extended_img_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
446 | extended_img_mask = (1.0 - extended_img_mask) * -10000.0
447 | img_mask = extended_img_mask
448 |
449 | lang_output = text_embeds
450 | visn_output = img_embedding_output
451 | prompt_feat = None  # stays None when the ADAPT sub-prompts are not provided (checked below)
452 | if txt_sub_prompt_set is not None and img_sub_prompt_set is not None:
453 | extended_prompt_mask = prompt_mask_set.unsqueeze(1).unsqueeze(2)
454 | extended_prompt_mask = extended_prompt_mask.to(dtype=next(self.parameters()).dtype)
455 | extended_prompt_mask = (1.0 - extended_prompt_mask) * -10000.0
456 | prompt_mask = extended_prompt_mask.cuda()
457 |
458 | text_mask = torch.cat([text_mask, prompt_mask], dim=3)
459 | txt_sub_prompt_feat = self.txt_enc(txt_sub_prompt_set)
460 | img_sub_prompt_feat = self.img_enc(img_sub_prompt_set)
461 | prompt_feat = self.action_enc(torch.cat([txt_sub_prompt_feat, img_sub_prompt_feat], dim=-1))
462 | lang_output = torch.cat([lang_output, prompt_feat], dim=1)
463 |
464 |
465 | for tdx, layer_module in enumerate(self.addlayer):
466 | lang_output, visn_output, language_attention_scores, visual_attention_scores = layer_module(lang_output, text_mask, visn_output, img_mask, tdx)
467 |
468 | if prompt_feat is not None and prompt_mask_set is not None:
469 | # take the prompt slice of the scores before truncating, so it indexes the prompt tokens
470 | prompt_attention_scores = language_attention_scores[:, :, language_attention_scores.size(2) - prompt_feat.size(1):]
471 | language_attention_scores = language_attention_scores[:, :, :language_attention_scores.size(2) - prompt_feat.size(1)]
472 |
473 |
474 | prompt_scores = prompt_attention_scores.mean(dim=1)
475 | prompt_probs = nn.Softmax(dim=-1)(prompt_scores.clone()).unsqueeze(-1)
476 | attended_txt_sub_prompt_feat = (prompt_probs * txt_sub_prompt_feat).sum(1)
477 | attended_img_sub_prompt_feat = (prompt_probs * img_sub_prompt_feat).sum(1)
478 |
479 |
480 | sequence_output = lang_output
481 | pooled_output = self.pooler(sequence_output)
482 |
483 | language_state_scores = language_attention_scores.mean(dim=1)
484 | visual_action_scores = visual_attention_scores.mean(dim=1)
485 |
486 | # weighted_feat
487 | language_attention_probs = nn.Softmax(dim=-1)(language_state_scores.clone()).unsqueeze(-1)
488 | visual_attention_probs = nn.Softmax(dim=-1)(visual_action_scores.clone()).unsqueeze(-1)
489 |
490 | attended_language = (language_attention_probs * text_embeds[:, 1:, :]).sum(1)
491 | attended_visual = (visual_attention_probs * img_embedding_output).sum(1)
492 |
493 | if prompt_feat is not None and prompt_mask_set is not None:
494 | return pooled_output, visual_action_scores, attended_language, attended_visual, attended_txt_sub_prompt_feat, attended_img_sub_prompt_feat, txt_sub_prompt_feat, img_sub_prompt_feat
495 | else:
496 | return pooled_output, visual_action_scores, attended_language, attended_visual
497 |
--------------------------------------------------------------------------------
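
The PREVALENT variant above follows the same two-mode protocol; a minimal sketch with placeholder variable names (in the agent, the first token of the language sequence carries the recurrent state, and the sub-prompt inputs are only used when --ADAPT is set):

    init_state, lang_seq = vln_bert('language', instr_ids, lang_mask=lang_mask)

    outputs = vln_bert('visual', lang_seq, lang_mask=lang_mask, vis_mask=vis_mask,
                       img_feats=cand_feats,
                       prompt_mask_set=prompt_mask,            # ADAPT-only inputs
                       txt_sub_prompt_set=txt_prompts,
                       img_sub_prompt_set=img_prompts)
    (state, action_scores, att_lang, att_vis,
     att_txt_prompt, att_img_prompt, txt_prompt_feat, img_prompt_feat) = outputs
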
/r2r_src/utils.py:
--------------------------------------------------------------------------------
1 | ''' Utils for io, language, connectivity graphs etc '''
2 |
3 | import os
4 | import sys
5 | import re
6 | sys.path.append('Matterport_Simulator/build/')
7 | import MatterSim
8 | import string
9 | import json
10 | import time
11 | import math
12 | from collections import Counter, defaultdict
13 | import numpy as np
14 | import networkx as nx
15 | from param import args
16 | from numpy.linalg import norm
17 |
18 |
19 | # padding, unknown word, end of sentence
20 | base_vocab = ['<PAD>', '<UNK>', '<EOS>']
21 | padding_idx = base_vocab.index('<PAD>')
22 |
23 | def load_nav_graphs(scans):
24 | ''' Load connectivity graph for each scan '''
25 |
26 | def distance(pose1, pose2):
27 | ''' Euclidean distance between two graph poses '''
28 | return ((pose1['pose'][3]-pose2['pose'][3])**2\
29 | + (pose1['pose'][7]-pose2['pose'][7])**2\
30 | + (pose1['pose'][11]-pose2['pose'][11])**2)**0.5
31 |
32 | graphs = {}
33 | for scan in scans:
34 | with open('connectivity/%s_connectivity.json' % scan) as f:
35 | G = nx.Graph()
36 | positions = {}
37 | data = json.load(f)
38 | for i,item in enumerate(data):
39 | if item['included']:
40 | for j,conn in enumerate(item['unobstructed']):
41 | if conn and data[j]['included']:
42 | positions[item['image_id']] = np.array([item['pose'][3],
43 | item['pose'][7], item['pose'][11]]);
44 | assert data[j]['unobstructed'][i], 'Graph should be undirected'
45 | G.add_edge(item['image_id'],data[j]['image_id'],weight=distance(item,data[j]))
46 | nx.set_node_attributes(G, values=positions, name='position')
47 | graphs[scan] = G
48 | return graphs
49 |
50 |
51 | def load_datasets(splits):
52 | """
53 |
54 | :param splits: A list of splits.
55 | If a split is given as "something@5000", a random sample of 5000 items is drawn from that split.
56 | :return: The concatenated list of items from all requested splits.
57 | """
58 | import random
59 | data = []
60 | old_state = random.getstate()
61 | for split in splits:
62 | # Optionally use only a part of the dataset (the "split@number" syntax)
63 | components = split.split("@")
64 | number = -1
65 | if len(components) > 1:
66 | split, number = components[0], int(components[1])
67 |
68 | # Load Json
69 | # if split in ['train', 'val_seen', 'val_unseen', 'test',
70 | # 'val_unseen_half1', 'val_unseen_half2', 'val_seen_half1', 'val_seen_half2']: # Add two halves for sanity check
71 | if args.ADAPT:
72 | if "/" not in split:
73 | with open('data/R2R_%s_ADAPT.json' % split) as f:
74 | new_data = json.load(f)
75 | else:
76 | print('\nLoading prevalent data for pretraining...')
77 | with open('data/R2R_aug_ADAPT.json') as f:
78 | new_data = json.load(f)
79 | else:
80 | if "/" not in split:
81 | with open('data/R2R_%s.json' % split) as f:
82 | new_data = json.load(f)
83 | else:
84 | print('\nLoading prevalent data for pretraining...')
85 | with open(split) as f:
86 | new_data = json.load(f)
87 |
88 | # Partition
89 | if number > 0:
90 | random.seed(0) # Make the data deterministic, additive
91 | random.shuffle(new_data)
92 | new_data = new_data[:number]
93 |
94 | # Join
95 | data += new_data
96 | random.setstate(old_state) # Recover the state of the random generator
97 | return data
98 |
99 |
100 | def pad_instr_tokens(instr_tokens, maxlength=20):
101 |
102 | if len(instr_tokens) <= 2: #assert len(raw_instr_tokens) > 2
103 | return None
104 |
105 | if len(instr_tokens) > maxlength - 2: # -2 for [CLS] and [SEP]
106 | instr_tokens = instr_tokens[:(maxlength-2)]
107 |
108 | instr_tokens = ['[CLS]'] + instr_tokens + ['[SEP]']
109 | num_words = len(instr_tokens) # - 1 # include [SEP]
110 | instr_tokens += ['[PAD]'] * (maxlength-len(instr_tokens))
111 |
112 | assert len(instr_tokens) == maxlength
113 |
114 | return instr_tokens, num_words
115 |
116 |
117 | class Tokenizer(object):
118 | ''' Class to tokenize and encode a sentence. '''
119 | SENTENCE_SPLIT_REGEX = re.compile(r'(\W+)') # Split on any non-alphanumeric character
120 |
121 | def __init__(self, vocab=None, encoding_length=20):
122 | self.encoding_length = encoding_length
123 | self.vocab = vocab
124 | self.word_to_index = {}
125 | self.index_to_word = {}
126 | if vocab:
127 | for i,word in enumerate(vocab):
128 | self.word_to_index[word] = i
129 | new_w2i = defaultdict(lambda: self.word_to_index['<UNK>'])
130 | new_w2i.update(self.word_to_index)
131 | self.word_to_index = new_w2i
132 | for key, value in self.word_to_index.items():
133 | self.index_to_word[value] = key
134 | old = self.vocab_size()
135 | self.add_word('<BOS>')
136 | assert self.vocab_size() == old+1
137 | print("OLD_VOCAB_SIZE", old)
138 | print("VOCAB_SIZE", self.vocab_size())
139 | print("VOCAB", len(vocab))
140 |
141 | def finalize(self):
142 | """
143 | This is used for debugging
144 | """
145 | self.word_to_index = dict(self.word_to_index) # To avoid using mis-typing tokens
146 |
147 | def add_word(self, word):
148 | assert word not in self.word_to_index
149 | self.word_to_index[word] = self.vocab_size() # vocab_size() is the index assigned to the new word
150 | self.index_to_word[self.vocab_size()] = word
151 |
152 | @staticmethod
153 | def split_sentence(sentence):
154 | ''' Break sentence into a list of words and punctuation '''
155 | toks = []
156 | for word in [s.strip().lower() for s in Tokenizer.SENTENCE_SPLIT_REGEX.split(sentence.strip()) if len(s.strip()) > 0]:
157 | # Break up any words containing punctuation only, e.g. '!?', unless it is multiple full stops e.g. '..'
158 | if all(c in string.punctuation for c in word) and not all(c in '.' for c in word):
159 | toks += list(word)
160 | else:
161 | toks.append(word)
162 | return toks
163 |
164 | def vocab_size(self):
165 | return len(self.index_to_word)
166 |
167 | def encode_sentence(self, sentence, max_length=None):
168 | if max_length is None:
169 | max_length = self.encoding_length
170 | if len(self.word_to_index) == 0:
171 | sys.exit('Tokenizer has no vocab')
172 |
173 | encoding = [self.word_to_index['<BOS>']]
174 | for word in self.split_sentence(sentence):
175 | encoding.append(self.word_to_index[word]) # Default Dict
176 | encoding.append(self.word_to_index['<EOS>'])
177 |
178 | if len(encoding) <= 2:
179 | return None
180 | #assert len(encoding) > 2
181 |
182 | if len(encoding) < max_length:
183 | encoding += [self.word_to_index['<PAD>']] * (max_length-len(encoding)) # Padding
184 | elif len(encoding) > max_length:
185 | encoding[max_length - 1] = self.word_to_index['<EOS>'] # Cut the length with EOS
186 |
187 | return np.array(encoding[:max_length])
188 |
189 | def decode_sentence(self, encoding, length=None):
190 | sentence = []
191 | if length is not None:
192 | encoding = encoding[:length]
193 | for ix in encoding:
194 | if ix == self.word_to_index['<PAD>']:
195 | break
196 | else:
197 | sentence.append(self.index_to_word[ix])
198 | return " ".join(sentence)
199 |
200 | def shrink(self, inst):
201 | """
202 | :param inst: The instruction as a list of token ids
203 | :return: The instruction with any leading <BOS> and trailing <EOS> removed.
204 | If there is no <EOS>, an empty list is returned.
205 | """
206 | if len(inst) == 0:
207 | return inst
208 | end = np.argmax(np.array(inst) == self.word_to_index['<EOS>']) # If no <EOS>, return an empty list
209 | if len(inst) > 1 and inst[0] == self.word_to_index['<BOS>']:
210 | start = 1
211 | else:
212 | start = 0
213 | # print(inst, start, end)
214 | return inst[start: end]
215 |
216 |
217 | def build_vocab(splits=['train'], min_count=5, start_vocab=base_vocab):
218 | ''' Build a vocab, starting with base vocab containing a few useful tokens. '''
219 | count = Counter()
220 | t = Tokenizer()
221 | data = load_datasets(splits)
222 | for item in data:
223 | for instr in item['instructions']:
224 | count.update(t.split_sentence(instr))
225 | vocab = list(start_vocab)
226 | for word,num in count.most_common():
227 | if num >= min_count:
228 | vocab.append(word)
229 | else:
230 | break
231 | return vocab
232 |
233 |
234 | def write_vocab(vocab, path):
235 | print('Writing vocab of size %d to %s' % (len(vocab),path))
236 | with open(path, 'w') as f:
237 | for word in vocab:
238 | f.write("%s\n" % word)
239 |
240 |
241 | def read_vocab(path):
242 | with open(path) as f:
243 | vocab = [word.strip() for word in f.readlines()]
244 | return vocab
245 |
246 |
247 | def asMinutes(s):
248 | m = math.floor(s / 60)
249 | s -= m * 60
250 | return '%dm %ds' % (m, s)
251 |
252 |
253 | def timeSince(since, percent):
254 | now = time.time()
255 | s = now - since
256 | es = s / (percent)
257 | rs = es - s
258 | return '%s (- %s)' % (asMinutes(s), asMinutes(rs))
259 |
260 | def read_img_features(feature_store, test_only=False):
261 | import csv
262 | import base64
263 | from tqdm import tqdm
264 |
265 | print("Start loading the image feature ... (~50 seconds)")
266 | start = time.time()
267 |
268 | if "detectfeat" in args.features:
269 | views = int(args.features[10:])
270 | else:
271 | views = 36
272 |
273 | args.views = views
274 |
275 | tsv_fieldnames = ['scanId', 'viewpointId', 'image_w', 'image_h', 'vfov', 'features']
276 |
277 | if not test_only:
278 | features = {}
279 | with open(feature_store, "r") as tsv_in_file: # Open the tsv file.
280 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=tsv_fieldnames)
281 | for item in reader:
282 | long_id = item['scanId'] + "_" + item['viewpointId']
283 | features[long_id] = np.frombuffer(base64.b64decode(item['features'].encode('ascii')),
284 | dtype=np.float32).reshape((views, -1)) # Feature of long_id is (36, 2048)
285 | else:
286 | features = None
287 |
288 | print("Finish Loading the image feature from %s in %0.4f seconds" % (feature_store, time.time() - start))
289 | return features
290 |
291 | def read_candidates(candidates_store):
292 | import csv
293 | import base64
294 | from collections import defaultdict
295 | print("Start loading the candidate feature")
296 |
297 | start = time.time()
298 |
299 | TSV_FIELDNAMES = ['scanId', 'viewpointId', 'heading', 'elevation', 'next', 'pointId', 'idx', 'feature']
300 | candidates = defaultdict(lambda: list())
301 | items = 0
302 | with open(candidates_store, "r") as tsv_in_file: # Open the tsv file.
303 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=TSV_FIELDNAMES)
304 | for item in reader:
305 | long_id = item['scanId'] + "_" + item['viewpointId']
306 | candidates[long_id].append(
307 | {'heading': float(item['heading']),
308 | 'elevation': float(item['elevation']),
309 | 'scanId': item['scanId'],
310 | 'viewpointId': item['next'],
311 | 'pointId': int(item['pointId']),
312 | 'idx': int(item['idx']) + 1, # Because of a bug in the precompute code, the +1 here is important
313 | 'feature': np.frombuffer(
314 | base64.b64decode(item['feature'].encode('ascii')),
315 | dtype=np.float32)
316 | }
317 | )
318 | items += 1
319 |
320 | for long_id in candidates:
321 | assert (len(candidates[long_id])) != 0
322 |
323 | assert sum(len(candidate) for candidate in candidates.values()) == items
324 |
325 | # candidate = candidates[long_id]
326 | # print(candidate)
327 | print("Finish Loading the candidates from %s in %0.4f seconds" % (candidates_store, time.time() - start))
328 | candidates = dict(candidates)
329 | return candidates
330 |
331 | def add_exploration(paths):
332 | explore = json.load(open("data/exploration.json", 'r'))
333 | inst2explore = {path['instr_id']: path['trajectory'] for path in explore}
334 | for path in paths:
335 | path['trajectory'] = inst2explore[path['instr_id']] + path['trajectory']
336 | return paths
337 |
338 | def angle_feature(heading, elevation):
339 |
340 | import math
341 | # twopi = math.pi * 2
342 | # heading = (heading + twopi) % twopi # From 0 ~ 2pi
343 | # It will be the same
344 | return np.array([math.sin(heading), math.cos(heading),
345 | math.sin(elevation), math.cos(elevation)] * (args.angle_feat_size // 4),
346 | dtype=np.float32)
347 |
348 | def new_simulator():
349 | import MatterSim
350 | # Simulator image parameters
351 | WIDTH = 640
352 | HEIGHT = 480
353 | VFOV = 60
354 |
355 | sim = MatterSim.Simulator()
356 | sim.setRenderingEnabled(False)
357 | sim.setCameraResolution(WIDTH, HEIGHT)
358 | sim.setCameraVFOV(math.radians(VFOV))
359 | sim.setDiscretizedViewingAngles(True)
360 | sim.init()
361 |
362 | return sim
363 |
364 | def get_point_angle_feature(baseViewId=0):
365 | sim = new_simulator()
366 |
367 | feature = np.empty((36, args.angle_feat_size), np.float32)
368 | base_heading = (baseViewId % 12) * math.radians(30)
369 | for ix in range(36):
370 | if ix == 0:
371 | sim.newEpisode('ZMojNkEp431', '2f4d90acd4024c269fb0efe49a8ac540', 0, math.radians(-30))
372 | elif ix % 12 == 0:
373 | sim.makeAction(0, 1.0, 1.0)
374 | else:
375 | sim.makeAction(0, 1.0, 0)
376 |
377 | state = sim.getState()
378 | assert state.viewIndex == ix
379 |
380 | heading = state.heading - base_heading
381 |
382 | feature[ix, :] = angle_feature(heading, state.elevation)
383 | return feature
384 |
385 | def get_all_point_angle_feature():
386 | return [get_point_angle_feature(baseViewId) for baseViewId in range(36)]
387 |
388 | def add_idx(inst):
389 | toks = Tokenizer.split_sentence(inst)
390 | return " ".join([str(idx)+tok for idx, tok in enumerate(toks)])
391 |
392 | import signal
393 | class GracefulKiller:
394 | kill_now = False
395 | def __init__(self):
396 | signal.signal(signal.SIGINT, self.exit_gracefully)
397 | signal.signal(signal.SIGTERM, self.exit_gracefully)
398 |
399 | def exit_gracefully(self,signum, frame):
400 | self.kill_now = True
401 |
402 | from collections import OrderedDict
403 |
404 | class Timer:
405 | def __init__(self):
406 | self.cul = OrderedDict()
407 | self.start = {}
408 | self.iter = 0
409 |
410 | def reset(self):
411 | self.cul = OrderedDict()
412 | self.start = {}
413 | self.iter = 0
414 |
415 | def tic(self, key):
416 | self.start[key] = time.time()
417 |
418 | def toc(self, key):
419 | delta = time.time() - self.start[key]
420 | if key not in self.cul:
421 | self.cul[key] = delta
422 | else:
423 | self.cul[key] += delta
424 |
425 | def step(self):
426 | self.iter += 1
427 |
428 | def show(self):
429 | total = sum(self.cul.values())
430 | for key in self.cul:
431 | print("%s, total time %0.2f, avg time %0.2f, part of %0.2f" %
432 | (key, self.cul[key], self.cul[key]*1./self.iter, self.cul[key]*1./total))
433 | print(total / self.iter)
434 |
435 |
436 | stop_word_list = [
437 | ",", ".", "and", "?", "!"
438 | ]
439 |
440 |
441 | def stop_words_location(inst, mask=False):
442 | toks = Tokenizer.split_sentence(inst)
443 | sws = [i for i, tok in enumerate(toks) if tok in stop_word_list] # The index of the stop words
444 | if len(sws) == 0 or sws[-1] != (len(toks)-1): # Add the index of the last token
445 | sws.append(len(toks)-1)
446 | sws = [x for x, y in zip(sws[:-1], sws[1:]) if x+1 != y] + [sws[-1]] # Filter the adjacent stop word
447 | sws_mask = np.ones(len(toks), np.int32) # Create the mask
448 | sws_mask[sws] = 0
449 | return sws_mask if mask else sws
450 |
451 | def get_segments(inst, mask=False):
452 | toks = Tokenizer.split_sentence(inst)
453 | sws = [i for i, tok in enumerate(toks) if tok in stop_word_list] # The index of the stop words
454 | sws = [-1] + sws + [len(toks)] # Add the <start> and <end> positions
455 | segments = [toks[sws[i]+1:sws[i+1]] for i in range(len(sws)-1)] # Slice the segments from the tokens
456 | segments = list(filter(lambda x: len(x)>0, segments)) # remove the consecutive stop words
457 | return segments
458 |
459 | def clever_pad_sequence(sequences, batch_first=True, padding_value=0):
460 | max_size = sequences[0].size()
461 | max_len, trailing_dims = max_size[0], max_size[1:]
462 | max_len = max(seq.size()[0] for seq in sequences)
463 | if batch_first:
464 | out_dims = (len(sequences), max_len) + trailing_dims
465 | else:
466 | out_dims = (max_len, len(sequences)) + trailing_dims
467 | if padding_value is not None:
468 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
469 | for i, tensor in enumerate(sequences):
470 | length = tensor.size(0)
471 | # use index notation to prevent duplicate references to the tensor
472 | if batch_first:
473 | out_tensor[i, :length, ...] = tensor
474 | else:
475 | out_tensor[:length, i, ...] = tensor
476 |
477 | return out_tensor
478 |
479 | import torch
480 | def length2mask(length, size=None):
481 | batch_size = len(length)
482 | size = int(max(length)) if size is None else size
483 | mask = (torch.arange(size, dtype=torch.int64).unsqueeze(0).repeat(batch_size, 1)
484 | > (torch.LongTensor(length) - 1).unsqueeze(1)).cuda()
485 | return mask
486 |
487 | def average_length(path2inst):
488 | length = []
489 |
490 | for name in path2inst:
491 | datum = path2inst[name]
492 | length.append(len(datum))
493 | return sum(length) / len(length)
494 |
495 | def tile_batch(tensor, multiplier):
496 | _, *s = tensor.size()
497 | tensor = tensor.unsqueeze(1).expand(-1, multiplier, *(-1,) * len(s)).contiguous().view(-1, *s)
498 | return tensor
499 |
500 | def viewpoint_drop_mask(viewpoint, seed=None, drop_func=None):
501 | local_seed = hash(viewpoint) ^ seed
502 | torch.random.manual_seed(local_seed)
503 | drop_mask = drop_func(torch.ones(2048).cuda())
504 | return drop_mask
505 |
506 |
507 | class FloydGraph:
508 | def __init__(self):
509 | self._dis = defaultdict(lambda :defaultdict(lambda: 95959595))
510 | self._point = defaultdict(lambda :defaultdict(lambda: ""))
511 | self._visited = set()
512 |
513 | def distance(self, x, y):
514 | if x == y:
515 | return 0
516 | else:
517 | return self._dis[x][y]
518 |
519 | def add_edge(self, x, y, dis):
520 | if dis < self._dis[x][y]:
521 | self._dis[x][y] = dis
522 | self._dis[y][x] = dis
523 | self._point[x][y] = ""
524 | self._point[y][x] = ""
525 |
526 | def update(self, k):
527 | for x in self._dis:
528 | for y in self._dis:
529 | if x != y:
530 | if self._dis[x][k] + self._dis[k][y] < self._dis[x][y]:
531 | self._dis[x][y] = self._dis[x][k] + self._dis[k][y]
532 | self._dis[y][x] = self._dis[x][y]
533 | self._point[x][y] = k
534 | self._point[y][x] = k
535 | self._visited.add(k)
536 |
537 | def visited(self, k):
538 | return (k in self._visited)
539 |
540 | def path(self, x, y):
541 | """
542 | :param x: start
543 | :param y: end
544 | :return: the path from x to y [v1, v2, ..., v_n, y]
545 | """
546 | if x == y:
547 | return []
548 | if self._point[x][y] == "": # Direct edge
549 | return [y]
550 | else:
551 | k = self._point[x][y]
552 | # print(x, y, k)
553 | # for x1 in (x, k, y):
554 | # for x2 in (x, k, y):
555 | # print(x1, x2, "%.4f" % self._dis[x1][x2])
556 | return self.path(x, k) + self.path(k, y)
557 |
558 | def print_progress(iteration, total, prefix='', suffix='', decimals=1, bar_length=100):
559 | """
560 | Call in a loop to create terminal progress bar
561 | @params:
562 | iteration - Required : current iteration (Int)
563 | total - Required : total iterations (Int)
564 | prefix - Optional : prefix string (Str)
565 | suffix - Optional : suffix string (Str)
566 | decimals - Optional : positive number of decimals in percent complete (Int)
567 | bar_length - Optional : character length of bar (Int)
568 | """
569 | str_format = "{0:." + str(decimals) + "f}"
570 | percents = str_format.format(100 * (iteration / float(total)))
571 | filled_length = int(round(bar_length * iteration / float(total)))
572 | bar = '█' * filled_length + '-' * (bar_length - filled_length)
573 |
574 | sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)),
575 |
576 | if iteration == total:
577 | sys.stdout.write('\n')
578 | sys.stdout.flush()
579 |
580 | def ndtw_initialize():
581 | ndtw_criterion = {}
582 | scan_gts_dir = 'data/id_paths.json'
583 | with open(scan_gts_dir) as f_:
584 | scan_gts = json.load(f_)
585 | all_scan_ids = []
586 | for key in scan_gts:
587 | path_scan_id = scan_gts[key][0]
588 | # print('path_scan_id', path_scan_id)
589 | if path_scan_id not in all_scan_ids:
590 | all_scan_ids.append(path_scan_id)
591 | ndtw_graph = ndtw_graphload(path_scan_id)
592 | ndtw_criterion[path_scan_id] = DTW(ndtw_graph)
593 | return ndtw_criterion
594 |
595 | def ndtw_graphload(scan):
596 | """Loads a networkx graph for a given scan.
597 | Args:
598 | connections_file: A string with the path to the .json file with the
599 | connectivity information.
600 | Returns:
601 | A networkx graph.
602 | """
603 | connections_file = 'connectivity/{}_connectivity.json'.format(scan)
604 | with open(connections_file) as f:
605 | lines = json.load(f)
606 | nodes = np.array([x['image_id'] for x in lines])
607 | matrix = np.array([x['unobstructed'] for x in lines])
608 | mask = np.array([x['included'] for x in lines])
609 |
610 | matrix = matrix[mask][:, mask]
611 | nodes = nodes[mask]
612 |
613 | pos2d = {x['image_id']: np.array(x['pose'])[[3, 7]] for x in lines}
614 | pos3d = {x['image_id']: np.array(x['pose'])[[3, 7, 11]] for x in lines}
615 |
616 | graph = nx.from_numpy_matrix(matrix)
617 | graph = nx.relabel.relabel_nodes(graph, dict(enumerate(nodes)))
618 | nx.set_node_attributes(graph, pos2d, 'pos2d')
619 | nx.set_node_attributes(graph, pos3d, 'pos3d')
620 |
621 | weight2d = {(u, v): norm(pos2d[u] - pos2d[v]) for u, v in graph.edges}
622 | weight3d = {(u, v): norm(pos3d[u] - pos3d[v]) for u, v in graph.edges}
623 | nx.set_edge_attributes(graph, weight2d, 'weight2d')
624 | nx.set_edge_attributes(graph, weight3d, 'weight3d')
625 |
626 | return graph
627 |
628 | class DTW(object):
629 | """Dynamic Time Warping (DTW) evaluation metrics.
630 | Python doctest:
631 | >>> graph = nx.grid_graph([3, 4])
632 | >>> prediction = [(0, 0), (1, 0), (2, 0), (3, 0)]
633 | >>> reference = [(0, 0), (1, 0), (2, 1), (3, 2)]
634 | >>> dtw = DTW(graph)
635 | >>> assert np.isclose(dtw(prediction, reference, 'dtw'), 3.0)
636 | >>> assert np.isclose(dtw(prediction, reference, 'ndtw'), 0.77880078307140488)
637 | >>> assert np.isclose(dtw(prediction, reference, 'sdtw'), 0.77880078307140488)
638 | >>> assert np.isclose(dtw(prediction[:2], reference, 'sdtw'), 0.0)
639 | """
640 |
641 | def __init__(self, graph, weight='weight', threshold=3.0):
642 | """Initializes a DTW object.
643 | Args:
644 | graph: networkx graph for the environment.
645 | weight: networkx edge weight key (str).
646 | threshold: distance threshold $d_{th}$ (float).
647 | """
648 | self.graph = graph
649 | self.weight = weight
650 | self.threshold = threshold
651 | self.distance = dict(
652 | nx.all_pairs_dijkstra_path_length(self.graph, weight=self.weight))
653 |
654 | def __call__(self, prediction, reference, metric='sdtw'):
655 | """Computes DTW metrics.
656 | Args:
657 | prediction: list of nodes (str), path predicted by agent.
658 | reference: list of nodes (str), the ground truth path.
659 | metric: one of ['ndtw', 'sdtw', 'dtw'].
660 | Returns:
661 |       the requested metric (DTW, nDTW, or SDTW) between the prediction and the reference path (float).
662 | """
663 | assert metric in ['ndtw', 'sdtw', 'dtw']
664 |
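        # Dynamic programming over prediction x reference:
        #   dtw[i][j] = d(p_i, r_j) + min(dtw[i-1][j], dtw[i][j-1], dtw[i-1][j-1])
        # nDTW = exp(-DTW / (threshold * |reference|)); SDTW additionally multiplies nDTW
        # by a success indicator (last predicted node within `threshold` of the goal).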
665 | dtw_matrix = np.inf * np.ones((len(prediction) + 1, len(reference) + 1))
666 | dtw_matrix[0][0] = 0
667 | for i in range(1, len(prediction)+1):
668 | for j in range(1, len(reference)+1):
669 | best_previous_cost = min(
670 | dtw_matrix[i-1][j], dtw_matrix[i][j-1], dtw_matrix[i-1][j-1])
671 | cost = self.distance[prediction[i-1]][reference[j-1]]
672 | dtw_matrix[i][j] = cost + best_previous_cost
673 | dtw = dtw_matrix[len(prediction)][len(reference)]
674 |
675 | if metric == 'dtw':
676 | return dtw
677 |
678 | ndtw = np.exp(-dtw/(self.threshold * len(reference)))
679 | if metric == 'ndtw':
680 | return ndtw
681 |
682 | success = self.distance[prediction[-1]][reference[-1]] <= self.threshold
683 | return success * ndtw
684 |
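# Minimal usage sketch (for illustration only; the scan id below is just an example and
# its connectivity file must exist under connectivity/):
#     graph = ndtw_graphload('17DRP5sb8fy')
#     ndtw_score = DTW(graph)(predicted_path, gt_path, metric='ndtw')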
--------------------------------------------------------------------------------
/r2r_src/agent.py:
--------------------------------------------------------------------------------
1 |
2 | import json
3 | import os
4 | import sys
5 | import numpy as np
6 | import random
7 | import math
8 | import time
9 |
10 | import torch
11 | import torch.nn as nn
12 | from torch.autograd import Variable
13 | from torch import optim
14 | import torch.nn.functional as F
15 |
16 | from env import R2RBatch
17 | import utils
18 | from utils import padding_idx, print_progress
19 | import model_OSCAR, model_PREVALENT
20 | import param
21 | from param import args
22 | from collections import defaultdict
23 |
24 |
25 | class BaseAgent(object):
26 | ''' Base class for an R2R agent to generate and save trajectories. '''
27 |
28 | def __init__(self, env, results_path):
29 | self.env = env
30 | self.results_path = results_path
31 | random.seed(1)
32 | self.results = {}
33 | self.losses = [] # For learning agents
34 |
35 | def write_results(self):
36 | output = [{'instr_id':k, 'trajectory': v} for k,v in self.results.items()]
37 | with open(self.results_path, 'w') as f:
38 | json.dump(output, f)
39 |
40 | def get_results(self):
41 | output = [{'instr_id': k, 'trajectory': v} for k, v in self.results.items()]
42 | return output
43 |
44 | def rollout(self, **args):
45 | ''' Return a list of dicts containing instr_id:'xx', path:[(viewpointId, heading_rad, elevation_rad)] '''
46 | raise NotImplementedError
47 |
48 | @staticmethod
49 | def get_agent(name):
50 | return globals()[name+"Agent"]
51 |
52 | def test(self, iters=None, **kwargs):
53 |         self.env.reset_epoch(shuffle=(iters is not None))   # If iters is not None, shuffle the env batch
54 | self.losses = []
55 | self.results = {}
56 | # We rely on env showing the entire batch before repeating anything
57 | looped = False
58 | self.loss = 0
59 | if iters is not None:
60 |             # Run only the first 'iters' batches each time (the env was shuffled above)
61 | for i in range(iters):
62 | for traj in self.rollout(**kwargs):
63 | self.loss = 0
64 | self.results[traj['instr_id']] = traj['path']
65 | else: # Do a full round
66 | while True:
67 | for traj in self.rollout(**kwargs):
68 | if traj['instr_id'] in self.results:
69 | looped = True
70 | else:
71 | self.loss = 0
72 | self.results[traj['instr_id']] = traj['path']
73 | if looped:
74 | break
75 |
76 |
77 | class Seq2SeqAgent(BaseAgent):
78 | ''' An agent based on an LSTM seq2seq model with attention. '''
79 |
80 | # For now, the agent can't pick which forward move to make - just the one in the middle
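    # Each tuple is (navigable-location index, heading change, elevation change) passed to
    # the simulator's makeAction(); a change of +/-1 is one discrete turn/look step.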
81 | env_actions = {
82 | 'left': (0,-1, 0), # left
83 | 'right': (0, 1, 0), # right
84 | 'up': (0, 0, 1), # up
85 | 'down': (0, 0,-1), # down
86 | 'forward': (1, 0, 0), # forward
87 |       '<end>': (0, 0, 0), # <end>
88 |       '<start>': (0, 0, 0), # <start>
89 |       '<ignore>': (0, 0, 0) # <ignore>
90 | }
91 |
92 | def __init__(self, env, results_path, tok, episode_len=20, feat_dict=None):
93 | super(Seq2SeqAgent, self).__init__(env, results_path)
94 | self.tok = tok
95 | self.episode_len = episode_len
96 | self.feature_size = self.env.feature_size
97 | self.feat_dict = feat_dict
98 |
99 | # Models
100 | if args.vlnbert == 'oscar':
101 | self.vln_bert = model_OSCAR.VLNBERT(feature_size=self.feature_size + args.angle_feat_size).cuda()
102 | self.critic = model_OSCAR.Critic().cuda()
103 | elif args.vlnbert == 'prevalent':
104 | if args.features == 'clip':
105 | self.vln_bert = model_PREVALENT.VLNBERT(feature_size=2048 + args.angle_feat_size).cuda()
106 | else:
107 | self.vln_bert = model_PREVALENT.VLNBERT(feature_size=self.feature_size + args.angle_feat_size).cuda()
108 | self.critic = model_PREVALENT.Critic().cuda()
109 | self.models = (self.vln_bert, self.critic)
110 |
111 | # Optimizers
112 | self.vln_bert_optimizer = args.optimizer(self.vln_bert.parameters(), lr=args.lr)
113 | self.critic_optimizer = args.optimizer(self.critic.parameters(), lr=args.lr)
114 | self.optimizers = (self.vln_bert_optimizer, self.critic_optimizer)
115 |
116 | # Evaluations
117 | self.losses = []
118 | self.criterion = nn.CrossEntropyLoss(ignore_index=args.ignoreid, size_average=False)
119 | self.ndtw_criterion = utils.ndtw_initialize()
120 |
121 | if args.ADAPT:
122 | self.align_loss = nn.CrossEntropyLoss(ignore_index=-1000, size_average=True)
123 | self.consistency_loss = nn.MSELoss(size_average=True)
124 |
125 | txt_subprompt_feat_file = "data/R2R_txt_subprompt_feature.json"
126 | with open(txt_subprompt_feat_file) as f:
127 | self.txt_subprompt_feat = json.load(f)
128 | for k, _ in self.txt_subprompt_feat.items():
129 | self.txt_subprompt_feat[k] = np.array(self.txt_subprompt_feat[k])
130 |
131 | # Logs
132 | sys.stdout.flush()
133 | self.logs = defaultdict(list)
134 |
135 | def _sort_batch(self, obs):
136 | seq_tensor = np.array([ob['instr_encoding'] for ob in obs])
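        # np.argmax over (seq == padding_idx) finds the first pad token, i.e. the length;
        # rows with no padding yield 0 and are corrected to the full width right after.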
137 | seq_lengths = np.argmax(seq_tensor == padding_idx, axis=1)
138 | seq_lengths[seq_lengths == 0] = seq_tensor.shape[1]
139 |
140 | seq_tensor = torch.from_numpy(seq_tensor)
141 | seq_lengths = torch.from_numpy(seq_lengths)
142 |
143 | # Sort sequences by lengths
144 | seq_lengths, perm_idx = seq_lengths.sort(0, True) # True -> descending
145 | sorted_tensor = seq_tensor[perm_idx]
146 | mask = (sorted_tensor != padding_idx)
147 |
148 | token_type_ids = torch.zeros_like(mask)
149 |
150 | return Variable(sorted_tensor, requires_grad=False).long().cuda(), \
151 | mask.long().cuda(), token_type_ids.long().cuda(), \
152 | list(seq_lengths), list(perm_idx)
153 |
154 | def _feature_variable(self, obs):
155 | ''' Extract precomputed features into variable. '''
156 | features = np.empty((len(obs), args.views, self.feature_size + args.angle_feat_size), dtype=np.float32)
157 | for i, ob in enumerate(obs):
158 | features[i, :, :] = ob['feature'] # Image feat
159 | return Variable(torch.from_numpy(features), requires_grad=False).cuda()
160 |
161 | def _candidate_variable(self, obs):
162 | candidate_leng = [len(ob['candidate']) + 1 for ob in obs] # +1 is for the end
163 | candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
164 | # Note: The candidate_feat at len(ob['candidate']) is the feature for the END
165 | # which is zero in my implementation
166 | for i, ob in enumerate(obs):
167 | for j, cc in enumerate(ob['candidate']):
168 | candidate_feat[i, j, :] = cc['feature']
169 |
170 | return torch.from_numpy(candidate_feat).cuda(), candidate_leng
171 |
172 | def get_input_feat(self, obs):
173 | input_a_t = np.zeros((len(obs), args.angle_feat_size), np.float32)
174 | for i, ob in enumerate(obs):
175 | input_a_t[i] = utils.angle_feature(ob['heading'], ob['elevation'])
176 | input_a_t = torch.from_numpy(input_a_t).cuda()
177 | # f_t = self._feature_variable(obs) # Pano image features from obs
178 | candidate_feat, candidate_leng = self._candidate_variable(obs)
179 |
180 | return input_a_t, candidate_feat, candidate_leng
181 |
182 | def _teacher_action(self, obs, ended):
183 | """
184 | Extract teacher actions into variable.
185 | :param obs: The observation.
186 | :param ended: Whether the action seq is ended
187 | :return:
188 | """
189 | a = np.zeros(len(obs), dtype=np.int64)
190 | for i, ob in enumerate(obs):
191 | if ended[i]: # Just ignore this index
192 | a[i] = args.ignoreid
193 | else:
194 | for k, candidate in enumerate(ob['candidate']):
195 | if candidate['viewpointId'] == ob['teacher']: # Next view point
196 | a[i] = k
197 | break
198 | else: # Stop here
199 | assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
200 | a[i] = len(ob['candidate'])
201 | return torch.from_numpy(a).cuda()
202 |
203 | def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None):
204 | """
205 | Interface between Panoramic view and Egocentric view
206 |     It converts the panoramic-view action a_t into the equivalent egocentric-view actions for the simulator
207 | """
208 | def take_action(i, idx, name):
209 | if type(name) is int: # Go to the next view
210 | self.env.env.sims[idx].makeAction(name, 0, 0)
211 | else: # Adjust
212 | self.env.env.sims[idx].makeAction(*self.env_actions[name])
213 |
214 | if perm_idx is None:
215 | perm_idx = range(len(perm_obs))
216 |
217 | for i, idx in enumerate(perm_idx):
218 | action = a_t[i]
219 |             if action != -1:            # -1 means <stop>/ignored, i.e. no movement
220 | select_candidate = perm_obs[i]['candidate'][action]
221 | src_point = perm_obs[i]['viewIndex']
222 | trg_point = select_candidate['pointId']
223 | src_level = (src_point ) // 12 # The point idx started from 0
224 | trg_level = (trg_point ) // 12
225 | while src_level < trg_level: # Tune up
226 | take_action(i, idx, 'up')
227 | src_level += 1
228 | while src_level > trg_level: # Tune down
229 | take_action(i, idx, 'down')
230 | src_level -= 1
231 | while self.env.env.sims[idx].getState().viewIndex != trg_point: # Turn right until the target
232 | take_action(i, idx, 'right')
233 | assert select_candidate['viewpointId'] == \
234 | self.env.env.sims[idx].getState().navigableLocations[select_candidate['idx']].viewpointId
235 | take_action(i, idx, select_candidate['idx'])
236 |
237 | state = self.env.env.sims[idx].getState()
238 | if traj is not None:
239 | traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
240 |
241 | def rollout(self, train_ml=None, train_rl=True, reset=True):
242 | """
243 |         :param train_ml:    The weight for the maximum-likelihood (imitation) loss, or None to disable it
244 |         :param train_rl:    whether to use RL in training
245 | :param reset: Reset the environment
246 |
247 | :return:
248 | """
249 | if self.feedback == 'teacher' or self.feedback == 'argmax':
250 | train_rl = False
251 |
252 | if reset: # Reset env
253 | obs = np.array(self.env.reset())
254 | else:
255 | obs = np.array(self.env._get_obs())
256 |
257 | batch_size = len(obs)
258 |
259 | # Language input
260 | sentence, language_attention_mask, token_type_ids, \
261 | seq_lengths, perm_idx = self._sort_batch(obs)
262 | perm_obs = obs[perm_idx]
263 |
264 | if args.ADAPT:
265 | img_sub_prompt_set = torch.zeros((batch_size, args.prompt_set_size, 512)).cuda()
266 | txt_sub_prompt_set = torch.zeros((batch_size, args.prompt_set_size, 512)).cuda()
267 | prompt_mask_set = torch.zeros((batch_size, args.prompt_set_size))
268 |
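            # Assemble ADAPT's action prompts: image sub-prompt ids look like
            # '<scan>_<viewpoint>_<view index>', and the feature is self.feat_dict['<scan>_<viewpoint>']
            # at that view index; text sub-prompt features are indexed by the last two
            # '_'-separated fields of the text sub-prompt id. Slots beyond prompt_num stay masked out.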
269 | for bs_ind in range(batch_size):
270 | img_sub_prompt = perm_obs[bs_ind]['img_sub_prompt']
271 | txt_sub_prompt = perm_obs[bs_ind]['txt_sub_prompt']
272 | for sub_prompt_ind in range(len(img_sub_prompt)):
273 | scan_viewpoint_view = img_sub_prompt[sub_prompt_ind].split('_')
274 | scan_viewpoint = scan_viewpoint_view[0] + '_' + scan_viewpoint_view[1]
275 |
276 | img_sub_prompt_set[bs_ind, sub_prompt_ind, :] = torch.from_numpy(
277 | self.feat_dict[scan_viewpoint][int(scan_viewpoint_view[2])]).cuda().float()
278 |                     txt_prompt_parts = txt_sub_prompt[sub_prompt_ind].split('_')
279 |                     txt_sub_prompt_set[bs_ind, sub_prompt_ind, :] = torch.from_numpy(
280 |                         self.txt_subprompt_feat[txt_prompt_parts[-2]][int(txt_prompt_parts[-1])]
281 |                     ).cuda().float()
282 | prompt_num = perm_obs[bs_ind]['prompt_num']
283 | prompt_mask_set[bs_ind, :prompt_num] = 1
284 |
285 | ''' Language BERT '''
286 | language_inputs = {'mode': 'language',
287 | 'sentence': sentence,
288 | 'attention_mask': language_attention_mask,
289 | 'lang_mask': language_attention_mask,
290 | 'token_type_ids': token_type_ids}
291 | if args.vlnbert == 'oscar':
292 | language_features = self.vln_bert(**language_inputs)
293 | elif args.vlnbert == 'prevalent':
294 | h_t, language_features = self.vln_bert(**language_inputs)
295 |
296 | # Record starting point
297 | traj = [{
298 | 'instr_id': ob['instr_id'],
299 | 'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
300 | } for ob in perm_obs]
301 |
302 | # Init the reward shaping
303 | last_dist = np.zeros(batch_size, np.float32)
304 | last_ndtw = np.zeros(batch_size, np.float32)
305 | for i, ob in enumerate(perm_obs): # The init distance from the view point to the target
306 | last_dist[i] = ob['distance']
307 | path_act = [vp[0] for vp in traj[i]['path']]
308 | last_ndtw[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw')
309 |
310 |         # Initialize the tracking state
311 |         ended = np.array([False] * batch_size)      # Indices match the permutation of the model, not the env
312 |
313 | # Init the logs
314 | rewards = []
315 | hidden_states = []
316 | policy_log_probs = []
317 | masks = []
318 | entropys = []
319 | ml_loss = 0.
320 |
321 | if args.ADAPT:
322 | align_loss = 0
323 | consistency_loss = 0
324 |
325 | for t in range(self.episode_len):
326 |
327 | input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
328 |
329 | # the first [CLS] token, initialized by the language BERT, serves
330 | # as the agent's state passing through time steps
331 | if (t >= 1) or (args.vlnbert=='prevalent'):
332 | language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1)
333 |
334 | visual_temp_mask = (utils.length2mask(candidate_leng) == 0).long()
335 | visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)
336 |
337 | self.vln_bert.vln_bert.config.directions = max(candidate_leng)
338 |
339 | ''' Visual BERT '''
340 | visual_inputs = {'mode': 'visual',
341 | 'sentence': language_features,
342 | 'attention_mask': visual_attention_mask,
343 | 'lang_mask': language_attention_mask,
344 | 'vis_mask': visual_temp_mask,
345 | 'token_type_ids': token_type_ids,
346 | 'action_feats': input_a_t,
347 | # 'pano_feats': f_t,
348 | 'cand_feats': candidate_feat}
349 |
350 | if args.ADAPT:
351 | visual_inputs = {'mode': 'visual',
352 | 'sentence': language_features,
353 | 'attention_mask': visual_attention_mask,
354 | 'lang_mask': language_attention_mask,
355 | 'vis_mask': visual_temp_mask,
356 | 'token_type_ids': token_type_ids,
357 | 'action_feats': input_a_t,
358 | # 'pano_feats': f_t,
359 | 'cand_feats': candidate_feat,
360 | 'txt_sub_prompt_set': txt_sub_prompt_set,
361 | 'img_sub_prompt_set': img_sub_prompt_set,
362 | 'prompt_mask_set': prompt_mask_set
363 | }
364 |
365 | h_t, logit, attended_lan, attended_vis, attended_txt_pro, attended_img_pro, txt_pro, img_pro = self.vln_bert(
366 | **visual_inputs)
367 | # alignment loss
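                # Contrastive alignment over the batch (teacher-forcing steps only): for each
                # prompt slot, matching text/image sub-prompts should be the most similar pair,
                # so the labels are the diagonal indices; samples with fewer prompts than the
                # slot index are dropped via the alignment loss's ignore_index (-1000).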
368 | if self.feedback == 'teacher':
369 | for loss_ind in range(args.prompt_set_size):
370 | sim_labels = np.arange(args.batchSize)
371 | sim_labels = torch.from_numpy(sim_labels).long().cuda()
372 | for bs_ind in range(batch_size):
373 | if loss_ind + 1 > perm_obs[bs_ind]['prompt_num']:
374 | sim_labels[bs_ind] = -1000
375 | sim_matrix = torch.mm(txt_pro[:, loss_ind, :], img_pro[:, loss_ind, :].transpose(0, 1))
376 | sim_matrix /= args.temperature
377 | loss_I = self.align_loss(sim_matrix, sim_labels)
378 | loss_V = self.align_loss(sim_matrix.transpose(0, 1), sim_labels)
379 | align_loss += (loss_I + loss_V) / 2
380 |
381 | # consistency loss
382 | consistency_loss += self.consistency_loss(attended_lan, attended_txt_pro)
383 | consistency_loss += self.consistency_loss(attended_vis, attended_img_pro)
384 | else:
385 | h_t, logit = self.vln_bert(**visual_inputs)
386 | hidden_states.append(h_t)
387 |
388 | # Mask outputs where agent can't move forward
389 | # Here the logit is [b, max_candidate]
390 | candidate_mask = utils.length2mask(candidate_leng)
391 | logit.masked_fill_(candidate_mask, -float('inf'))
392 |
393 | # Supervised training
394 | target = self._teacher_action(perm_obs, ended)
395 | ml_loss += self.criterion(logit, target)
396 |
397 | # Determine next model inputs
398 | if self.feedback == 'teacher':
399 | a_t = target # teacher forcing
400 | elif self.feedback == 'argmax':
401 | _, a_t = logit.max(1) # student forcing - argmax
402 | a_t = a_t.detach()
403 | log_probs = F.log_softmax(logit, 1) # Calculate the log_prob here
404 | policy_log_probs.append(log_probs.gather(1, a_t.unsqueeze(1))) # Gather the log_prob for each batch
405 | elif self.feedback == 'sample':
406 | probs = F.softmax(logit, 1) # sampling an action from model
407 | c = torch.distributions.Categorical(probs)
408 | self.logs['entropy'].append(c.entropy().sum().item()) # For log
409 | entropys.append(c.entropy()) # For optimization
410 | a_t = c.sample().detach()
411 | policy_log_probs.append(c.log_prob(a_t))
412 | else:
413 | print(self.feedback)
414 | sys.exit('Invalid feedback option')
415 | # Prepare environment action
416 | # NOTE: Env action is in the perm_obs space
417 | cpu_a_t = a_t.cpu().numpy()
418 | for i, next_id in enumerate(cpu_a_t):
419 |                 if next_id == (candidate_leng[i]-1) or next_id == args.ignoreid or ended[i]:    # The last action is <end>
420 |                     cpu_a_t[i] = -1             # Change the <end> and ignore actions to -1
421 |
422 | # Make action and get the new state
423 | self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj)
424 | obs = np.array(self.env._get_obs())
425 |             perm_obs = obs[perm_idx]                    # Permute the obs to match the model's batch order
426 |
427 | if train_rl:
428 | # Calculate the mask and reward
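                # Reward shaping: stopping within 3m of the goal gives +2 plus 2*nDTW, stopping
                # elsewhere gives -2; non-stop actions get +/-1 plus the step change in nDTW,
                # with an extra penalty for moving away from a goal that was already within 1m.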
429 | dist = np.zeros(batch_size, np.float32)
430 | ndtw_score = np.zeros(batch_size, np.float32)
431 | reward = np.zeros(batch_size, np.float32)
432 | mask = np.ones(batch_size, np.float32)
433 | for i, ob in enumerate(perm_obs):
434 | dist[i] = ob['distance']
435 | path_act = [vp[0] for vp in traj[i]['path']]
436 | ndtw_score[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw')
437 |
438 | if ended[i]:
439 | reward[i] = 0.0
440 | mask[i] = 0.0
441 | else:
442 | action_idx = cpu_a_t[i]
443 | # Target reward
444 | if action_idx == -1: # If the action now is end
445 | if dist[i] < 3.0: # Correct
446 | reward[i] = 2.0 + ndtw_score[i] * 2.0
447 | else: # Incorrect
448 | reward[i] = -2.0
449 | else: # The action is not end
450 | # Path fidelity rewards (distance & nDTW)
451 | reward[i] = - (dist[i] - last_dist[i])
452 | ndtw_reward = ndtw_score[i] - last_ndtw[i]
453 | if reward[i] > 0.0: # Quantification
454 | reward[i] = 1.0 + ndtw_reward
455 | elif reward[i] < 0.0:
456 | reward[i] = -1.0 + ndtw_reward
457 | else:
458 | raise NameError("The action doesn't change the move")
459 | # Miss the target penalty
460 | if (last_dist[i] <= 1.0) and (dist[i]-last_dist[i] > 0.0):
461 | reward[i] -= (1.0 - last_dist[i]) * 2.0
462 | rewards.append(reward)
463 | masks.append(mask)
464 | last_dist[:] = dist
465 | last_ndtw[:] = ndtw_score
466 |
467 | # Update the finished actions
468 | # -1 means ended or ignored (already ended)
469 | ended[:] = np.logical_or(ended, (cpu_a_t == -1))
470 |
471 | # Early exit if all ended
472 | if ended.all():
473 | break
474 |
475 | if train_rl:
476 | # Last action in A2C
477 | input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
478 |
479 | language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1)
480 |
481 | visual_temp_mask = (utils.length2mask(candidate_leng) == 0).long()
482 | visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)
483 |
484 | self.vln_bert.vln_bert.config.directions = max(candidate_leng)
485 | ''' Visual BERT '''
486 | visual_inputs = {'mode': 'visual',
487 | 'sentence': language_features,
488 | 'attention_mask': visual_attention_mask,
489 | 'lang_mask': language_attention_mask,
490 | 'vis_mask': visual_temp_mask,
491 | 'token_type_ids': token_type_ids,
492 | 'action_feats': input_a_t,
493 | # 'pano_feats': f_t,
494 | 'cand_feats': candidate_feat}
495 |
496 | if args.ADAPT:
497 | visual_inputs = {'mode': 'visual',
498 | 'sentence': language_features,
499 | 'attention_mask': visual_attention_mask,
500 | 'lang_mask': language_attention_mask,
501 | 'vis_mask': visual_temp_mask,
502 | 'token_type_ids': token_type_ids,
503 | 'action_feats': input_a_t,
504 | # 'pano_feats': f_t,
505 | 'cand_feats': candidate_feat,
506 | 'txt_sub_prompt_set': txt_sub_prompt_set,
507 | 'img_sub_prompt_set': img_sub_prompt_set,
508 | 'prompt_mask_set': prompt_mask_set
509 | }
510 | last_h_, _, _, _, _, _, _, _ = self.vln_bert(**visual_inputs)
511 | else:
512 | last_h_, _ = self.vln_bert(**visual_inputs)
513 |
514 | rl_loss = 0.
515 |
516 | # NOW, A2C!!!
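        # Advantage actor-critic: bootstrap the return with the critic's value for episodes
        # that have not ended, then walk backwards in time accumulating the policy-gradient
        # term, a 0.5 * L2 value loss and, for sampled rollouts, a small entropy bonus,
        # all masked once an episode has finished.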
517 | # Calculate the final discounted reward
518 |             last_value__ = self.critic(last_h_).detach()        # The value estimate of the last state; detached for safety
519 |             discount_reward = np.zeros(batch_size, np.float32)  # The initial reward is zero
520 | for i in range(batch_size):
521 | if not ended[i]: # If the action is not ended, use the value function as the last reward
522 | discount_reward[i] = last_value__[i]
523 |
524 | length = len(rewards)
525 | total = 0
526 | for t in range(length-1, -1, -1):
527 | discount_reward = discount_reward * args.gamma + rewards[t] # If it ended, the reward will be 0
528 | mask_ = Variable(torch.from_numpy(masks[t]), requires_grad=False).cuda()
529 | clip_reward = discount_reward.copy()
530 | r_ = Variable(torch.from_numpy(clip_reward), requires_grad=False).cuda()
531 | v_ = self.critic(hidden_states[t])
532 | a_ = (r_ - v_).detach()
533 |
534 | rl_loss += (-policy_log_probs[t] * a_ * mask_).sum()
535 | rl_loss += (((r_ - v_) ** 2) * mask_).sum() * 0.5 # 1/2 L2 loss
536 | if self.feedback == 'sample':
537 | rl_loss += (- 0.01 * entropys[t] * mask_).sum()
538 | self.logs['critic_loss'].append((((r_ - v_) ** 2) * mask_).sum().item())
539 |
540 | total = total + np.sum(masks[t])
541 | self.logs['total'].append(total)
542 |
543 | # Normalize the loss function
544 | if args.normalize_loss == 'total':
545 | rl_loss /= total
546 | elif args.normalize_loss == 'batch':
547 | rl_loss /= batch_size
548 | else:
549 | assert args.normalize_loss == 'none'
550 |
551 | self.loss += rl_loss
552 | self.logs['RL_loss'].append(rl_loss.item())
553 |
554 | if train_ml is not None:
555 | self.loss += ml_loss * train_ml / batch_size
556 | self.logs['IL_loss'].append((ml_loss * train_ml / batch_size).item())
557 |
558 | if args.ADAPT and self.feedback != 'argmax':
559 | self.loss += consistency_loss * args.consistency_loss_weight
560 | self.logs['consistency_loss'].append((consistency_loss * args.consistency_loss_weight).item())
561 |
562 | if self.feedback == 'teacher':
563 | self.loss += align_loss * args.align_loss_weight
564 | self.logs['alignment_loss'].append((align_loss * args.align_loss_weight).item())
565 |
566 |
567 |         if type(self.loss) is int:  # Safety check: self.loss stays the int 0 if no loss terms were added
568 | self.losses.append(0.)
569 | else:
570 |             self.losses.append(self.loss.item() / self.episode_len)    # Normalized by episode_len for logging only
571 |
572 | return traj
573 |
574 | def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
575 | ''' Evaluate once on each instruction in the current environment '''
576 | self.feedback = feedback
577 | if use_dropout:
578 | self.vln_bert.train()
579 | self.critic.train()
580 | else:
581 | self.vln_bert.eval()
582 | self.critic.eval()
583 | super(Seq2SeqAgent, self).test(iters)
584 |
585 | def zero_grad(self):
586 | self.loss = 0.
587 | self.losses = []
588 | for model, optimizer in zip(self.models, self.optimizers):
589 | model.train()
590 | optimizer.zero_grad()
591 |
592 | def accumulate_gradient(self, feedback='teacher', **kwargs):
593 | if feedback == 'teacher':
594 | self.feedback = 'teacher'
595 | self.rollout(train_ml=args.teacher_weight, train_rl=False, **kwargs)
596 | elif feedback == 'sample':
597 | self.feedback = 'teacher'
598 | self.rollout(train_ml=args.ml_weight, train_rl=False, **kwargs)
599 | self.feedback = 'sample'
600 | self.rollout(train_ml=None, train_rl=True, **kwargs)
601 | else:
602 | assert False
603 |
604 | def optim_step(self):
605 | self.loss.backward()
606 |
607 | torch.nn.utils.clip_grad_norm(self.vln_bert.parameters(), 40.)
608 |
609 | self.vln_bert_optimizer.step()
610 | self.critic_optimizer.step()
611 |
612 | def train(self, n_iters, feedback='teacher', **kwargs):
613 | ''' Train for a given number of iterations '''
614 | self.feedback = feedback
615 |
616 | self.vln_bert.train()
617 | self.critic.train()
618 |
619 | self.losses = []
620 | for iter in range(1, n_iters + 1):
621 |
622 | self.vln_bert_optimizer.zero_grad()
623 | self.critic_optimizer.zero_grad()
624 |
625 | self.loss = 0
626 |
627 | if feedback == 'teacher':
628 | self.feedback = 'teacher'
629 | self.rollout(train_ml=args.teacher_weight, train_rl=False, **kwargs)
630 | elif feedback == 'sample': # agents in IL and RL separately
631 | if args.ml_weight != 0:
632 | self.feedback = 'teacher'
633 | self.rollout(train_ml=args.ml_weight, train_rl=False, **kwargs)
634 | self.feedback = 'sample'
635 | self.rollout(train_ml=None, train_rl=True, **kwargs)
636 | else:
637 | assert False
638 |
639 | self.loss.backward()
640 |
641 | torch.nn.utils.clip_grad_norm(self.vln_bert.parameters(), 40.)
642 |
643 | self.vln_bert_optimizer.step()
644 | self.critic_optimizer.step()
645 |
646 | if args.aug is None:
647 | print_progress(iter, n_iters+1, prefix='Progress:', suffix='Complete', bar_length=50)
648 |
649 | def save(self, epoch, path):
650 | ''' Snapshot models '''
651 | the_dir, _ = os.path.split(path)
652 | os.makedirs(the_dir, exist_ok=True)
653 | states = {}
654 | def create_state(name, model, optimizer):
655 | states[name] = {
656 | 'epoch': epoch + 1,
657 | 'state_dict': model.state_dict(),
658 | 'optimizer': optimizer.state_dict(),
659 | }
660 | all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
661 | ("critic", self.critic, self.critic_optimizer)]
662 | for param in all_tuple:
663 | create_state(*param)
664 | torch.save(states, path)
665 |
666 | def load(self, path):
667 | ''' Loads parameters (but not training state) '''
668 | states = torch.load(path)
669 |
670 | def recover_state(name, model, optimizer):
671 | state = model.state_dict()
672 | model_keys = set(state.keys())
673 | load_keys = set(states[name]['state_dict'].keys())
674 | if model_keys != load_keys:
675 |                 print("NOTICE: DIFFERENT KEYS IN THE LISTENER")
676 | state.update(states[name]['state_dict'])
677 | model.load_state_dict(state)
678 | if args.loadOptim:
679 | optimizer.load_state_dict(states[name]['optimizer'])
680 | all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
681 | ("critic", self.critic, self.critic_optimizer)]
682 | for param in all_tuple:
683 | recover_state(*param)
684 | return states['vln_bert']['epoch'] - 1
685 |
--------------------------------------------------------------------------------