├── .gitignore ├── LICENSE ├── README.md ├── assets ├── model.jpg ├── model_with_pretrain.log ├── model_without_pretrain.log └── overall_results.jpg ├── configs ├── ablation │ ├── cvdn.yaml │ ├── fgr2r.yaml │ ├── reverie.yaml │ ├── scanqa.yaml │ └── soon.yaml ├── held_out │ ├── held_out_cvdn.yaml │ ├── held_out_reverie.yaml │ └── held_out_soon.yaml └── multi.yaml ├── models ├── __init__.py ├── detr_transformer.py ├── graph_utils.py ├── image_embedding.py ├── modified_lm.py ├── nav_model.py ├── ops.py └── vln_bert.py ├── requirements.txt ├── scripts ├── ablation │ ├── from_scratch.sh │ └── single_task.sh ├── data_tools │ ├── extract_features_coco.py │ ├── extract_features_mp3d.py │ ├── extract_features_scanqa.py │ └── reformat_scanqa.py ├── evaluation │ ├── eval_cvdn.sh │ ├── eval_r2r.sh │ ├── eval_reverie.sh │ ├── eval_scanqa.sh │ └── eval_soon.sh ├── held_out │ ├── held_out_cvdn.sh │ ├── held_out_reverie.sh │ └── held_out_soon.sh ├── multi_w_pretrain.sh ├── multi_wo_pretrain.sh └── pretrain.sh ├── tasks ├── __init__.py ├── agents │ ├── __init__.py │ ├── base_agent.py │ ├── cvdn.py │ ├── eqa.py │ ├── llava.py │ ├── mp3d_agent.py │ ├── r2r.py │ ├── reverie.py │ ├── scanqa.py │ └── soon.py ├── datasets │ ├── __init__.py │ ├── base_dataset.py │ ├── coco_caption.py │ ├── cvdn.py │ ├── eqa.py │ ├── llava.py │ ├── mp3d_dataset.py │ ├── mp3d_envs.py │ ├── r2r.py │ ├── r2r_aug.py │ ├── reverie.py │ ├── reverie_aug.py │ ├── scanqa.py │ └── soon.py ├── feature_db.py └── loaders.py ├── tools ├── __init__.py ├── common_utils.py ├── distributed.py ├── evaluation │ ├── __init__.py │ ├── bleu │ │ ├── __init__.py │ │ ├── bleu.py │ │ └── bleu_scorer.py │ ├── cider │ │ ├── __init__.py │ │ ├── cider.py │ │ └── cider_scorer.py │ ├── meteor │ │ ├── __init__.py │ │ ├── data │ │ │ └── paraphrase-en.gz │ │ ├── meteor-1.5.jar │ │ ├── meteor.py │ │ └── test_meteor.py │ ├── rouge │ │ ├── __init__.py │ │ └── rouge.py │ ├── stanford-corenlp-3.4.1.jar │ └── tokenizer.py ├── optims.py ├── parser.py └── trie.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | **__pycache__** 2 | **output** 3 | data 4 | data/** 5 | **egg-info** 6 | *.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 zd11024 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaVi-Lab/NaviLLM/e069f46ed98affb221d58715a785613622e11145/assets/model.jpg -------------------------------------------------------------------------------- /assets/overall_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaVi-Lab/NaviLLM/e069f46ed98affb221d58715a785613622e11145/assets/overall_results.jpg -------------------------------------------------------------------------------- /configs/ablation/cvdn.yaml: -------------------------------------------------------------------------------- 1 | Feature: 2 | # param 3 | object_feature_type: "" # "ade20k" 4 | angle_feat_size: 4 5 | max_objects: 70 6 | # feature 7 | image_feat_size: 1024 8 | feature_database: 9 | "mp3d": "eva_features/mp3d_EVA02-CLIP-L-14-336.hdf5" 10 | "scan_qa": "eva_features/scanqa_EVA02-CLIP-L-14-336.hdf5" 11 | "coco": "eva_features/coco_EVA02-CLIP-L-14-336.hdf5" 12 | 13 | # object 14 | obj_feat_size: 768 15 | object_database: 16 | "reverie": "obj_features/reverie_obj_feat" 17 | "soon": "obj_features/soon_obj_feat" 18 | 19 | Dataset: 20 | CVDN: 21 | DIR: "CVDN" 22 | SPLIT: { 23 | "train": "train.json", 24 | "val_seen": "val_seen.json", 25 | "val_unseen": "val_unseen.json", 26 | "test": "test_cleaned.json" 27 | } 28 | 29 | 30 | Multi: 31 | SOURCE: ['CVDN'] 32 | Ratio: [1] 33 | LOSS_COEF: { 34 | } 35 | 36 | 37 | Model: 38 | num_l_layers: 9 39 | num_pano_layers: 2 40 | num_x_layers: 4 41 | graph_sprels: True 42 | fusion: "dynamic" 43 | enc_full_graph: True 44 | expert_policy: "spl" 45 | 46 | Optim: 47 | val_max_action_len: { 48 | "R2R": 15, 49 | "REVERIE": 15, 50 | "CVDN": 30, # from VLN-SIG 51 | "SOON": 20, # from DUET 52 | "EQA": 15, 53 | } 54 | train_max_action_len: { 55 | "R2R": 15, 56 | "REVERIE": 15, 57 | "CVDN": 15, 58 | "SOON": 15, 59 | "EQA": 15, 60 | "R2R_AUG": 15, 61 | "REVERIE_AUG": 15 62 | } -------------------------------------------------------------------------------- /configs/ablation/fgr2r.yaml: -------------------------------------------------------------------------------- 1 | Feature: 2 | # param 3 | object_feature_type: "" # "ade20k" 4 | angle_feat_size: 4 5 | max_objects: 70 6 | # feature 7 | image_feat_size: 1024 8 | feature_database: 9 | "mp3d": "eva_features/mp3d_EVA02-CLIP-L-14-336.hdf5" 10 | "scan_qa": "eva_features/scanqa_EVA02-CLIP-L-14-336.hdf5" 11 | "coco": "eva_features/coco_EVA02-CLIP-L-14-336.hdf5" 12 | 13 | # object 14 | obj_feat_size: 768 15 | object_database: 16 | "reverie": "obj_features/reverie_obj_feat" 17 | "soon": "obj_features/soon_obj_feat" 18 | 19 | # task 20 | Dataset: 21 | R2R: 22 | DIR: "R2R" 23 | SPLIT: { 24 | "train": "FGR2R_train.json", 25 | "val_seen": "R2R_val_seen_enc.json", 26 | "val_unseen": "R2R_val_unseen_enc.json", 27 | "test": "R2R_test_enc.json" 28 | } 29 | 30 | 31 | Multi: 32 | SOURCE: ['R2R'] 33 | Ratio: [1] 34 | LOSS_COEF: { 35 | } 36 | 37 | 38 | Model: 39 | num_l_layers: 9 40 | num_pano_layers: 2 41 | num_x_layers: 4 42 | graph_sprels: True 43 | fusion: "dynamic" 44 | enc_full_graph: True 45 | expert_policy: "spl" 46 | 47 | Optim: 48 | val_max_action_len: { 49 | "R2R": 15, 50 | "REVERIE": 15, 51 | "CVDN": 30, # from VLN-SIG 52 | "SOON": 20, # from DUET 53 | "EQA": 15, 54 | } 55 | train_max_action_len: { 56 | "R2R": 15, 57 | "REVERIE": 15, 58 | "CVDN": 15, 59 | 
"SOON": 15, 60 | "EQA": 15, 61 | "R2R_AUG": 15, 62 | "REVERIE_AUG": 15 63 | } -------------------------------------------------------------------------------- /configs/ablation/reverie.yaml: -------------------------------------------------------------------------------- 1 | Feature: 2 | # param 3 | object_feature_type: "" # "ade20k" 4 | angle_feat_size: 4 5 | max_objects: 70 6 | # feature 7 | image_feat_size: 1024 8 | feature_database: 9 | "mp3d": "eva_features/mp3d_EVA02-CLIP-L-14-336.hdf5" 10 | "scan_qa": "eva_features/scanqa_EVA02-CLIP-L-14-336.hdf5" 11 | "coco": "eva_features/coco_EVA02-CLIP-L-14-336.hdf5" 12 | 13 | # object 14 | obj_feat_size: 768 15 | object_database: 16 | "reverie": "obj_features/reverie_obj_feat" 17 | "soon": "obj_features/soon_obj_feat" 18 | 19 | # task 20 | Dataset: 21 | REVERIE: 22 | DIR: "REVERIE" 23 | bbox_file: "BBoxes.json" 24 | SPLIT: { 25 | "train": "REVERIE_train_enc.json", 26 | "val_seen": "REVERIE_val_seen_enc.json", 27 | "val_unseen": "REVERIE_val_unseen_enc.json", 28 | "test": "REVERIE_test_enc.json" 29 | } 30 | 31 | 32 | Multi: 33 | SOURCE: ['REVERIE'] 34 | Ratio: [1] 35 | LOSS_COEF: { 36 | } 37 | 38 | 39 | Model: 40 | num_l_layers: 9 41 | num_pano_layers: 2 42 | num_x_layers: 4 43 | graph_sprels: True 44 | fusion: "dynamic" 45 | enc_full_graph: True 46 | expert_policy: "spl" 47 | 48 | Optim: 49 | val_max_action_len: { 50 | "R2R": 15, 51 | "REVERIE": 15, 52 | "CVDN": 30, # from VLN-SIG 53 | "SOON": 20, # from DUET 54 | "EQA": 15, 55 | } 56 | train_max_action_len: { 57 | "R2R": 15, 58 | "REVERIE": 15, 59 | "CVDN": 15, 60 | "SOON": 15, 61 | "EQA": 15, 62 | "R2R_AUG": 15, 63 | "REVERIE_AUG": 15 64 | } -------------------------------------------------------------------------------- /configs/ablation/scanqa.yaml: -------------------------------------------------------------------------------- 1 | Feature: 2 | # param 3 | object_feature_type: "" # "ade20k" 4 | angle_feat_size: 4 5 | max_objects: 70 6 | # feature 7 | image_feat_size: 1024 8 | feature_database: 9 | "mp3d": "eva_features/mp3d_EVA02-CLIP-L-14-336.hdf5" 10 | "scan_qa": "eva_features/scanqa_EVA02-CLIP-L-14-336.hdf5" 11 | "coco": "eva_features/coco_EVA02-CLIP-L-14-336.hdf5" 12 | 13 | # object 14 | obj_feat_size: 768 15 | object_database: 16 | "reverie": "obj_features/reverie_obj_feat" 17 | "soon": "obj_features/soon_obj_feat" 18 | 19 | # task 20 | Dataset: 21 | ScanQA: 22 | DIR: "ScanQA" 23 | SPLIT: { 24 | "train": "ScanQA_v1.0_train_reformat.json", 25 | "val_unseen": "ScanQA_v1.0_val_reformat.json", 26 | "test": "ScanQA_v1.0_test_wo_obj_reformat.json" 27 | } 28 | 29 | 30 | Multi: 31 | SOURCE: ['ScanQA'] 32 | Ratio: [1] 33 | LOSS_COEF: { 34 | } 35 | 36 | 37 | Model: 38 | num_l_layers: 9 39 | num_pano_layers: 2 40 | num_x_layers: 4 41 | graph_sprels: True 42 | fusion: "dynamic" 43 | enc_full_graph: True 44 | expert_policy: "spl" 45 | 46 | Optim: 47 | val_max_action_len: { 48 | "R2R": 15, 49 | "REVERIE": 15, 50 | "CVDN": 30, # from VLN-SIG 51 | "SOON": 20, # from DUET 52 | "EQA": 15, 53 | } 54 | train_max_action_len: { 55 | "R2R": 15, 56 | "REVERIE": 15, 57 | "CVDN": 15, 58 | "SOON": 15, 59 | "EQA": 15, 60 | "R2R_AUG": 15, 61 | "REVERIE_AUG": 15 62 | } -------------------------------------------------------------------------------- /configs/ablation/soon.yaml: -------------------------------------------------------------------------------- 1 | Feature: 2 | # param 3 | object_feature_type: "" # "ade20k" 4 | angle_feat_size: 4 5 | max_objects: 70 6 | # feature 7 | image_feat_size: 1024 8 | 
feature_database: 9 | "mp3d": "eva_features/mp3d_EVA02-CLIP-L-14-336.hdf5" 10 | "scan_qa": "eva_features/scanqa_EVA02-CLIP-L-14-336.hdf5" 11 | "coco": "eva_features/coco_EVA02-CLIP-L-14-336.hdf5" 12 | 13 | # object 14 | obj_feat_size: 768 15 | object_database: 16 | "reverie": "obj_features/reverie_obj_feat" 17 | "soon": "obj_features/soon_obj_feat" 18 | 19 | # task 20 | Dataset: 21 | SOON: 22 | DIR: "SOON" # from https://github.com/cshizhe/HM3DAutoVLN 23 | SPLIT: { 24 | "train": "train_enc_pseudo_obj_ade30k_label.jsonl", 25 | "val_seen": "val_unseen_instrs_enc_pseudo_obj_ade30k_label.jsonl", 26 | "val_unseen": "val_unseen_house_enc_pseudo_obj_ade30k_label.jsonl", 27 | "test": "test_v2_enc.jsonl" 28 | } 29 | 30 | Multi: 31 | SOURCE: ['SOON'] 32 | Ratio: [1] 33 | LOSS_COEF: { 34 | } 35 | 36 | 37 | Model: 38 | num_l_layers: 9 39 | num_pano_layers: 2 40 | num_x_layers: 4 41 | graph_sprels: True 42 | fusion: "dynamic" 43 | enc_full_graph: True 44 | expert_policy: "spl" 45 | 46 | Optim: 47 | val_max_action_len: { 48 | "R2R": 15, 49 | "REVERIE": 15, 50 | "CVDN": 30, # from VLN-SIG 51 | "SOON": 20, # from DUET 52 | "EQA": 15, 53 | } 54 | train_max_action_len: { 55 | "R2R": 15, 56 | "REVERIE": 15, 57 | "CVDN": 15, 58 | "SOON": 15, 59 | "EQA": 15, 60 | "R2R_AUG": 15, 61 | "REVERIE_AUG": 15 62 | } -------------------------------------------------------------------------------- /configs/held_out/held_out_cvdn.yaml: -------------------------------------------------------------------------------- 1 | Feature: 2 | # param 3 | object_feature_type: "" # "ade20k" 4 | angle_feat_size: 4 5 | max_objects: 70 6 | # feature 7 | image_feat_size: 1024 8 | feature_database: 9 | "mp3d": "eva_features/mp3d_EVA02-CLIP-L-14-336.hdf5" 10 | "scan_qa": "eva_features/scanqa_EVA02-CLIP-L-14-336.hdf5" 11 | "coco": "eva_features/coco_EVA02-CLIP-L-14-336.hdf5" 12 | 13 | # object 14 | obj_feat_size: 768 15 | object_database: 16 | "reverie": "obj_features/reverie_obj_feat" 17 | "soon": "obj_features/soon_obj_feat" 18 | 19 | # task 20 | Dataset: 21 | R2R: 22 | DIR: "R2R" 23 | SPLIT: { 24 | "train": "FGR2R_train.json", 25 | "val_seen": "R2R_val_seen_enc.json", 26 | "val_unseen": "R2R_val_unseen_enc.json", 27 | "test": "R2R_test_enc.json" 28 | } 29 | REVERIE: 30 | DIR: "REVERIE" 31 | bbox_file: "BBoxes.json" 32 | SPLIT: { 33 | "train": "REVERIE_train_enc.json", 34 | "val_seen": "REVERIE_val_seen_enc.json", 35 | "val_unseen": "REVERIE_val_unseen_enc.json", 36 | "test": "REVERIE_test_enc.json" 37 | } 38 | CVDN: 39 | DIR: "CVDN" 40 | SPLIT: { 41 | "train": "train.json", 42 | "val_seen": "val_seen.json", 43 | "val_unseen": "val_unseen.json", 44 | "test": "test_cleaned.json" 45 | } 46 | SOON: 47 | DIR: "SOON" # from https://github.com/cshizhe/HM3DAutoVLN 48 | SPLIT: { 49 | "train": "train_enc_pseudo_obj_ade30k_label.jsonl", 50 | "val_seen": "val_unseen_instrs_enc_pseudo_obj_ade30k_label.jsonl", 51 | "val_unseen": "val_unseen_house_enc_pseudo_obj_ade30k_label.jsonl", 52 | "test": "test_v2_enc.jsonl" 53 | } 54 | ScanQA: 55 | DIR: "ScanQA" 56 | SPLIT: { 57 | "train": "ScanQA_v1.0_train_reformat.json", 58 | "val_unseen": "ScanQA_v1.0_val_reformat.json", 59 | "test": "ScanQA_v1.0_test_wo_obj_reformat.json" 60 | } 61 | EQA: 62 | DIR: "EQA_MP3D" 63 | SPLIT: { 64 | "val_unseen": "eqa_val_enc.json" 65 | } 66 | ANSWER_VOCAB: "eqa_answer_vocab.json" 67 | 68 | R2R_AUG: 69 | DIR: "R2R" 70 | SPLIT: { 71 | "train": "R2R_prevalent_aug_train_enc.jsonl" 72 | } 73 | REVERIE_AUG: 74 | DIR: "REVERIE" 75 | bbox_file: "BBoxes.json" 76 | SPLIT: { 77 | "train": 
"REVERIE_speaker_aug_enc.jsonl" 78 | } 79 | LLaVA: 80 | DIR: "LLaVA" 81 | SPLIT: { 82 | "train": "detail_23k.json" 83 | } 84 | 85 | 86 | Multi: 87 | SOURCE: ['R2R','REVERIE','SOON','ScanQA'] 88 | Ratio: [2, 1, 1, 1] 89 | LOSS_COEF: { 90 | "R2R": 2, 91 | "REVERIE": 2 92 | } 93 | 94 | 95 | Model: 96 | num_l_layers: 9 97 | num_pano_layers: 2 98 | num_x_layers: 4 99 | graph_sprels: True 100 | fusion: "dynamic" 101 | enc_full_graph: True 102 | expert_policy: "spl" 103 | 104 | Optim: 105 | val_max_action_len: { 106 | "R2R": 15, 107 | "REVERIE": 15, 108 | "CVDN": 30, # from VLN-SIG 109 | "SOON": 20, # from DUET 110 | "EQA": 15, 111 | } 112 | train_max_action_len: { 113 | "R2R": 15, 114 | "REVERIE": 15, 115 | "CVDN": 15, 116 | "SOON": 15, 117 | "EQA": 15, 118 | "R2R_AUG": 15, 119 | "REVERIE_AUG": 15 120 | } -------------------------------------------------------------------------------- /configs/held_out/held_out_reverie.yaml: -------------------------------------------------------------------------------- 1 | Feature: 2 | # param 3 | object_feature_type: "" # "ade20k" 4 | angle_feat_size: 4 5 | max_objects: 70 6 | # feature 7 | image_feat_size: 1024 8 | feature_database: 9 | "mp3d": "eva_features/mp3d_EVA02-CLIP-L-14-336.hdf5" 10 | "scan_qa": "eva_features/scanqa_EVA02-CLIP-L-14-336.hdf5" 11 | "coco": "eva_features/coco_EVA02-CLIP-L-14-336.hdf5" 12 | 13 | # object 14 | obj_feat_size: 768 15 | object_database: 16 | "reverie": "obj_features/reverie_obj_feat" 17 | "soon": "obj_features/soon_obj_feat" 18 | 19 | # task 20 | Dataset: 21 | R2R: 22 | DIR: "R2R" 23 | SPLIT: { 24 | "train": "FGR2R_train.json", 25 | "val_seen": "R2R_val_seen_enc.json", 26 | "val_unseen": "R2R_val_unseen_enc.json", 27 | "test": "R2R_test_enc.json" 28 | } 29 | REVERIE: 30 | DIR: "REVERIE" 31 | bbox_file: "BBoxes.json" 32 | SPLIT: { 33 | "train": "REVERIE_train_enc.json", 34 | "val_seen": "REVERIE_val_seen_enc.json", 35 | "val_unseen": "REVERIE_val_unseen_enc.json", 36 | "test": "REVERIE_test_enc.json" 37 | } 38 | CVDN: 39 | DIR: "CVDN" 40 | SPLIT: { 41 | "train": "train.json", 42 | "val_seen": "val_seen.json", 43 | "val_unseen": "val_unseen.json", 44 | "test": "test_cleaned.json" 45 | } 46 | SOON: 47 | DIR: "SOON" # from https://github.com/cshizhe/HM3DAutoVLN 48 | SPLIT: { 49 | "train": "train_enc_pseudo_obj_ade30k_label.jsonl", 50 | "val_seen": "val_unseen_instrs_enc_pseudo_obj_ade30k_label.jsonl", 51 | "val_unseen": "val_unseen_house_enc_pseudo_obj_ade30k_label.jsonl", 52 | "test": "test_v2_enc.jsonl" 53 | } 54 | ScanQA: 55 | DIR: "ScanQA" 56 | SPLIT: { 57 | "train": "ScanQA_v1.0_train_reformat.json", 58 | "val_unseen": "ScanQA_v1.0_val_reformat.json", 59 | "test": "ScanQA_v1.0_test_wo_obj_reformat.json" 60 | } 61 | EQA: 62 | DIR: "EQA_MP3D" 63 | SPLIT: { 64 | "val_unseen": "eqa_val_enc.json" 65 | } 66 | ANSWER_VOCAB: "eqa_answer_vocab.json" 67 | 68 | R2R_AUG: 69 | DIR: "R2R" 70 | SPLIT: { 71 | "train": "R2R_prevalent_aug_train_enc.jsonl" 72 | } 73 | REVERIE_AUG: 74 | DIR: "REVERIE" 75 | bbox_file: "BBoxes.json" 76 | SPLIT: { 77 | "train": "REVERIE_speaker_aug_enc.jsonl" 78 | } 79 | LLaVA: 80 | DIR: "LLaVA" 81 | SPLIT: { 82 | "train": "detail_23k.json" 83 | } 84 | 85 | 86 | Multi: 87 | SOURCE: ['R2R','SOON','CVDN','ScanQA'] 88 | Ratio: [2, 1, 1, 1] 89 | LOSS_COEF: { 90 | "R2R": 2 91 | } 92 | 93 | 94 | Model: 95 | num_l_layers: 9 96 | num_pano_layers: 2 97 | num_x_layers: 4 98 | graph_sprels: True 99 | fusion: "dynamic" 100 | enc_full_graph: True 101 | expert_policy: "spl" 102 | 103 | Optim: 104 | val_max_action_len: { 105 | 
"R2R": 15, 106 | "REVERIE": 15, 107 | "CVDN": 30, # from VLN-SIG 108 | "SOON": 20, # from DUET 109 | "EQA": 15, 110 | } 111 | train_max_action_len: { 112 | "R2R": 15, 113 | "REVERIE": 15, 114 | "CVDN": 15, 115 | "SOON": 15, 116 | "EQA": 15, 117 | "R2R_AUG": 15, 118 | "REVERIE_AUG": 15 119 | } -------------------------------------------------------------------------------- /configs/held_out/held_out_soon.yaml: -------------------------------------------------------------------------------- 1 | Feature: 2 | # param 3 | object_feature_type: "" # "ade20k" 4 | angle_feat_size: 4 5 | max_objects: 70 6 | # feature 7 | image_feat_size: 1024 8 | feature_database: 9 | "mp3d": "eva_features/mp3d_EVA02-CLIP-L-14-336.hdf5" 10 | "scan_qa": "eva_features/scanqa_EVA02-CLIP-L-14-336.hdf5" 11 | "coco": "eva_features/coco_EVA02-CLIP-L-14-336.hdf5" 12 | 13 | # object 14 | obj_feat_size: 768 15 | object_database: 16 | "reverie": "obj_features/reverie_obj_feat" 17 | "soon": "obj_features/soon_obj_feat" 18 | 19 | # task 20 | Dataset: 21 | R2R: 22 | DIR: "R2R" 23 | SPLIT: { 24 | "train": "FGR2R_train.json", 25 | "val_seen": "R2R_val_seen_enc.json", 26 | "val_unseen": "R2R_val_unseen_enc.json", 27 | "test": "R2R_test_enc.json" 28 | } 29 | REVERIE: 30 | DIR: "REVERIE" 31 | bbox_file: "BBoxes.json" 32 | SPLIT: { 33 | "train": "REVERIE_train_enc.json", 34 | "val_seen": "REVERIE_val_seen_enc.json", 35 | "val_unseen": "REVERIE_val_unseen_enc.json", 36 | "test": "REVERIE_test_enc.json" 37 | } 38 | CVDN: 39 | DIR: "CVDN" 40 | SPLIT: { 41 | "train": "train.json", 42 | "val_seen": "val_seen.json", 43 | "val_unseen": "val_unseen.json", 44 | "test": "test_cleaned.json" 45 | } 46 | SOON: 47 | DIR: "SOON" # from https://github.com/cshizhe/HM3DAutoVLN 48 | SPLIT: { 49 | "train": "train_enc_pseudo_obj_ade30k_label.jsonl", 50 | "val_seen": "val_unseen_instrs_enc_pseudo_obj_ade30k_label.jsonl", 51 | "val_unseen": "val_unseen_house_enc_pseudo_obj_ade30k_label.jsonl", 52 | "test": "test_v2_enc.jsonl" 53 | } 54 | ScanQA: 55 | DIR: "ScanQA" 56 | SPLIT: { 57 | "train": "ScanQA_v1.0_train_reformat.json", 58 | "val_unseen": "ScanQA_v1.0_val_reformat.json", 59 | "test": "ScanQA_v1.0_test_wo_obj_reformat.json" 60 | } 61 | EQA: 62 | DIR: "EQA_MP3D" 63 | SPLIT: { 64 | "val_unseen": "eqa_val_enc.json" 65 | } 66 | ANSWER_VOCAB: "eqa_answer_vocab.json" 67 | 68 | R2R_AUG: 69 | DIR: "R2R" 70 | SPLIT: { 71 | "train": "R2R_prevalent_aug_train_enc.jsonl" 72 | } 73 | REVERIE_AUG: 74 | DIR: "REVERIE" 75 | bbox_file: "BBoxes.json" 76 | SPLIT: { 77 | "train": "REVERIE_speaker_aug_enc.jsonl" 78 | } 79 | LLaVA: 80 | DIR: "LLaVA" 81 | SPLIT: { 82 | "train": "detail_23k.json" 83 | } 84 | 85 | 86 | Multi: 87 | SOURCE: ['R2R','REVERIE','CVDN','ScanQA'] 88 | Ratio: [2, 1, 1, 1] 89 | LOSS_COEF: { 90 | "R2R": 2, 91 | "REVERIE": 2 92 | } 93 | 94 | 95 | Model: 96 | num_l_layers: 9 97 | num_pano_layers: 2 98 | num_x_layers: 4 99 | graph_sprels: True 100 | fusion: "dynamic" 101 | enc_full_graph: True 102 | expert_policy: "spl" 103 | 104 | Optim: 105 | val_max_action_len: { 106 | "R2R": 15, 107 | "REVERIE": 15, 108 | "CVDN": 30, # from VLN-SIG 109 | "SOON": 20, # from DUET 110 | "EQA": 15, 111 | } 112 | train_max_action_len: { 113 | "R2R": 15, 114 | "REVERIE": 15, 115 | "CVDN": 15, 116 | "SOON": 15, 117 | "EQA": 15, 118 | "R2R_AUG": 15, 119 | "REVERIE_AUG": 15 120 | } -------------------------------------------------------------------------------- /configs/multi.yaml: -------------------------------------------------------------------------------- 1 | Feature: 2 | # 
param 3 | object_feature_type: "" # "ade20k" 4 | angle_feat_size: 4 5 | max_objects: 70 6 | # feature 7 | image_feat_size: 1024 8 | feature_database: 9 | "mp3d": "eva_features/mp3d_EVA02-CLIP-L-14-336.hdf5" 10 | "scan_qa": "eva_features/scanqa_EVA02-CLIP-L-14-336.hdf5" 11 | "coco": "eva_features/coco_EVA02-CLIP-L-14-336.hdf5" 12 | 13 | # object 14 | obj_feat_size: 768 15 | object_database: 16 | "reverie": "obj_features/reverie_obj_feat" 17 | "soon": "obj_features/soon_obj_feat" 18 | 19 | # task 20 | Dataset: 21 | R2R: 22 | DIR: "R2R" 23 | SPLIT: { 24 | "train": "FGR2R_train.json", 25 | "val_seen": "R2R_val_seen_enc.json", 26 | "val_unseen": "R2R_val_unseen_enc.json", 27 | "test": "R2R_test_enc.json" 28 | } 29 | REVERIE: 30 | DIR: "REVERIE" 31 | bbox_file: "BBoxes.json" 32 | SPLIT: { 33 | "train": "REVERIE_train_enc.json", 34 | "val_seen": "REVERIE_val_seen_enc.json", 35 | "val_unseen": "REVERIE_val_unseen_enc.json", 36 | "test": "REVERIE_test_enc.json" 37 | } 38 | CVDN: 39 | DIR: "CVDN" 40 | SPLIT: { 41 | "train": "train.json", 42 | "val_seen": "val_seen.json", 43 | "val_unseen": "val_unseen.json", 44 | "test": "test_cleaned.json" 45 | } 46 | SOON: 47 | DIR: "SOON" # from https://github.com/cshizhe/HM3DAutoVLN 48 | SPLIT: { 49 | "train": "train_enc_pseudo_obj_ade30k_label.jsonl", 50 | "val_seen": "val_unseen_instrs_enc_pseudo_obj_ade30k_label.jsonl", 51 | "val_unseen": "val_unseen_house_enc_pseudo_obj_ade30k_label.jsonl", 52 | "test": "test_v2_enc.jsonl" 53 | } 54 | ScanQA: 55 | DIR: "ScanQA" 56 | SPLIT: { 57 | "train": "ScanQA_v1.0_train_reformat.json", 58 | "val_unseen": "ScanQA_v1.0_val_reformat.json", 59 | "test_wo_obj": "ScanQA_v1.0_test_wo_obj_reformat.json", 60 | "test_w_obj": "ScanQA_v1.0_test_w_obj_reformat.json" 61 | } 62 | EQA: 63 | DIR: "EQA_MP3D" 64 | SPLIT: { 65 | "val_unseen": "eqa_val_enc.json" 66 | } 67 | ANSWER_VOCAB: "eqa_answer_vocab.json" 68 | 69 | R2R_AUG: 70 | DIR: "R2R" 71 | SPLIT: { 72 | "train": "R2R_prevalent_aug_train_enc.jsonl" 73 | } 74 | REVERIE_AUG: 75 | DIR: "REVERIE" 76 | bbox_file: "BBoxes.json" 77 | SPLIT: { 78 | "train": "REVERIE_speaker_aug_enc.jsonl" 79 | } 80 | LLaVA: 81 | DIR: "LLaVA" 82 | SPLIT: { 83 | "train": "detail_23k.json" 84 | } 85 | 86 | # training 87 | Pretrain: 88 | SOURCE: ['R2R_AUG', 'REVERIE_AUG', 'R2R', 'REVERIE', 'SOON', 'CVDN', 'ScanQA'] 89 | Ratio: [20, 2, 1, 1, 1, 1, 1] 90 | LOSS_COEF: { 91 | "R2R_AUG": 1, 92 | "REVERIE_AUG": 1 93 | } 94 | 95 | 96 | Multi: 97 | SOURCE: ['R2R', 'REVERIE', 'CVDN','SOON', 'ScanQA', 'LLaVA'] 98 | Ratio: [20, 5, 1, 5, 5, 5] 99 | LOSS_COEF: { 100 | } 101 | 102 | 103 | Model: 104 | num_l_layers: 9 105 | num_pano_layers: 2 106 | num_x_layers: 4 107 | graph_sprels: True 108 | fusion: "dynamic" 109 | enc_full_graph: True 110 | expert_policy: "spl" 111 | 112 | Optim: 113 | val_max_action_len: { 114 | "R2R": 15, 115 | "REVERIE": 15, 116 | "CVDN": 30, # from VLN-SIG 117 | "SOON": 20, # from DUET 118 | "EQA": 15, 119 | } 120 | train_max_action_len: { 121 | "R2R": 15, 122 | "REVERIE": 15, 123 | "CVDN": 15, 124 | "SOON": 15, 125 | "EQA": 15, 126 | "R2R_AUG": 15, 127 | "REVERIE_AUG": 15 128 | } -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaVi-Lab/NaviLLM/e069f46ed98affb221d58715a785613622e11145/models/__init__.py -------------------------------------------------------------------------------- /models/graph_utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from collections import defaultdict 4 | 5 | MAX_DIST = 30 6 | MAX_STEP = 10 7 | 8 | 9 | def calc_position_distance(a, b): 10 | # a, b: (x, y, z) 11 | dx = b[0] - a[0] 12 | dy = b[1] - a[1] 13 | dz = b[2] - a[2] 14 | dist = np.sqrt(dx ** 2 + dy ** 2 + dz ** 2) 15 | return dist 16 | 17 | 18 | def calculate_vp_rel_pos_fts(a, b, base_heading=0, base_elevation=0): 19 | # a, b: (x, y, z) 20 | dx = b[0] - a[0] 21 | dy = b[1] - a[1] 22 | dz = b[2] - a[2] 23 | xy_dist = max(np.sqrt(dx ** 2 + dy ** 2), 1e-8) 24 | xyz_dist = max(np.sqrt(dx ** 2 + dy ** 2 + dz ** 2), 1e-8) 25 | 26 | # the simulator's api is weired (x-y axis is transposed) 27 | heading = np.arcsin(dx / xy_dist) # [-pi/2, pi/2] 28 | if b[1] < a[1]: 29 | heading = np.pi - heading 30 | heading -= base_heading 31 | 32 | elevation = np.arcsin(dz / xyz_dist) # [-pi/2, pi/2] 33 | elevation -= base_elevation 34 | 35 | return heading, elevation, xyz_dist 36 | 37 | 38 | def get_angle_fts(headings, elevations, angle_feat_size): 39 | ang_fts = [np.sin(headings), np.cos(headings), np.sin(elevations), np.cos(elevations)] 40 | ang_fts = np.vstack(ang_fts).transpose().astype(np.float32) 41 | num_repeats = angle_feat_size // 4 42 | if num_repeats > 1: 43 | ang_fts = np.concatenate([ang_fts] * num_repeats, 1) 44 | return ang_fts 45 | 46 | 47 | class FloydGraph(object): 48 | def __init__(self): 49 | self._dis = defaultdict(lambda: defaultdict(lambda: 95959595)) 50 | self._point = defaultdict(lambda: defaultdict(lambda: "")) 51 | self._visited = set() 52 | 53 | def distance(self, x, y): 54 | if x == y: 55 | return 0 56 | else: 57 | return self._dis[x][y] 58 | 59 | def add_edge(self, x, y, dis): 60 | if dis < self._dis[x][y]: 61 | self._dis[x][y] = dis 62 | self._dis[y][x] = dis 63 | self._point[x][y] = "" 64 | self._point[y][x] = "" 65 | 66 | def update(self, k): 67 | for x in self._dis: 68 | for y in self._dis: 69 | if x != y: 70 | if self._dis[x][k] + self._dis[k][y] < self._dis[x][y]: 71 | self._dis[x][y] = self._dis[x][k] + self._dis[k][y] 72 | self._dis[y][x] = self._dis[x][y] 73 | self._point[x][y] = k 74 | self._point[y][x] = k 75 | self._visited.add(k) 76 | 77 | def visited(self, k): 78 | return (k in self._visited) 79 | 80 | def path(self, x, y): 81 | """ 82 | :param x: start 83 | :param y: end 84 | :return: the path from x to y [v1, v2, ..., v_n, y] 85 | """ 86 | if x == y: 87 | return [] 88 | if self._point[x][y] == "": # Direct edge 89 | return [y] 90 | else: 91 | k = self._point[x][y] 92 | # print(x, y, k) 93 | # for x1 in (x, k, y): 94 | # for x2 in (x, k, y): 95 | # print(x1, x2, "%.4f" % self._dis[x1][x2]) 96 | return self.path(x, k) + self.path(k, y) 97 | 98 | 99 | class GraphMap(object): 100 | def __init__(self, start_vp): 101 | self.start_vp = start_vp # start viewpoint 102 | 103 | self.node_positions = {} # viewpoint to position (x, y, z) 104 | self.graph = FloydGraph() # shortest path graph 105 | self.node_embeds = {} # {viewpoint: feature (sum feature, count)} 106 | self.node_stop_scores = {} # {viewpoint: prob} 107 | self.node_nav_scores = {} # {viewpoint: {t: prob}} 108 | self.node_step_ids = {} 109 | self.pooling_mode = 'mean' 110 | 111 | def update_graph(self, ob): 112 | self.node_positions[ob['viewpoint']] = ob['position'] 113 | for cc in ob['candidate']: 114 | self.node_positions[cc['viewpointId']] = cc['position'] 115 | dist = calc_position_distance(ob['position'], cc['position']) 116 | 
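# Each candidate viewpoint becomes an undirected edge weighted by Euclidean distance;
# the graph.update(...) call just below runs a Floyd-Warshall relaxation through the
# newly visited viewpoint so cached shortest paths and distances stay current.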
self.graph.add_edge(ob['viewpoint'], cc['viewpointId'], dist) 117 | self.graph.update(ob['viewpoint']) 118 | 119 | def update_node_embed(self, vp, embed, rewrite=False): 120 | if rewrite: 121 | self.node_embeds[vp] = [embed, 1] 122 | else: 123 | if vp in self.node_embeds: 124 | if self.pooling_mode == "max": 125 | pooling_features, _ = torch.max(torch.stack([self.node_embeds[vp][0], embed.clone()]), dim=0) 126 | self.node_embeds[vp][0] = pooling_features 127 | elif self.pooling_mode == "mean": 128 | self.node_embeds[vp][0] += embed.clone() 129 | else: 130 | raise NotImplementedError('`pooling_mode` Only support ["mean", "max"]') 131 | self.node_embeds[vp][1] += 1 132 | else: 133 | self.node_embeds[vp] = [embed, 1] 134 | 135 | 136 | def get_node_embed(self, vp): 137 | if self.pooling_mode == "max": 138 | return self.node_embeds[vp][0] 139 | elif self.pooling_mode == "mean": 140 | return self.node_embeds[vp][0] / self.node_embeds[vp][1] 141 | else: 142 | raise NotImplementedError('`pooling_mode` Only support ["mean", "max"]') 143 | 144 | def get_pos_fts(self, cur_vp, gmap_vpids, cur_heading, cur_elevation, angle_feat_size=4): 145 | # dim=7 (sin(heading), cos(heading), sin(elevation), cos(elevation), 146 | # line_dist, shortest_dist, shortest_step) 147 | rel_angles, rel_dists = [], [] 148 | for vp in gmap_vpids: 149 | if vp is None: 150 | rel_angles.append([0, 0]) 151 | rel_dists.append([0, 0, 0]) 152 | else: 153 | rel_heading, rel_elevation, rel_dist = calculate_vp_rel_pos_fts( 154 | self.node_positions[cur_vp], self.node_positions[vp], 155 | base_heading=cur_heading, base_elevation=cur_elevation, 156 | ) 157 | rel_angles.append([rel_heading, rel_elevation]) 158 | rel_dists.append( 159 | [rel_dist / MAX_DIST, self.graph.distance(cur_vp, vp) / MAX_DIST, \ 160 | len(self.graph.path(cur_vp, vp)) / MAX_STEP] 161 | ) 162 | rel_angles = np.array(rel_angles).astype(np.float32) 163 | rel_dists = np.array(rel_dists).astype(np.float32) 164 | rel_ang_fts = get_angle_fts(rel_angles[:, 0], rel_angles[:, 1], angle_feat_size) 165 | return np.concatenate([rel_ang_fts, rel_dists], 1) 166 | 167 | def save_to_json(self): 168 | nodes = {} 169 | for vp, pos in self.node_positions.items(): 170 | nodes[vp] = { 171 | 'location': pos, # (x, y, z) 172 | 'visited': self.graph.visited(vp), 173 | } 174 | if nodes[vp]['visited']: 175 | nodes[vp]['stop_prob'] = self.node_stop_scores[vp]['stop'] 176 | nodes[vp]['og_objid'] = self.node_stop_scores[vp]['og'] 177 | else: 178 | nodes[vp]['nav_prob'] = self.node_nav_scores[vp] 179 | 180 | edges = [] 181 | for k, v in self.graph._dis.items(): 182 | for kk in v.keys(): 183 | edges.append((k, kk)) 184 | 185 | return {'nodes': nodes, 'edges': edges} 186 | 187 | 188 | -------------------------------------------------------------------------------- /models/image_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .ops import ( 4 | create_transformer_encoder, 5 | gen_seq_masks, 6 | pad_tensors_wgrad 7 | ) 8 | 9 | 10 | class ImageEmbeddings(nn.Module): 11 | def __init__(self, config, use_obj: bool=False, fuse_obj: bool=False): 12 | super().__init__() 13 | 14 | self.img_linear = nn.Linear(config.image_feat_size, config.hidden_size) 15 | self.img_layer_norm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12) 16 | self.loc_linear = nn.Linear(config.angle_feat_size + 3, config.hidden_size) 17 | self.loc_layer_norm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12) 18 | 19 | # if 
config.obj_feat_size > 0 and config.obj_feat_size != config.image_feat_size: 20 | self.fuse_obj = fuse_obj 21 | if use_obj: 22 | if self.fuse_obj: 23 | self.obj_linear = nn.Sequential( 24 | nn.Linear(config.obj_feat_size, config.hidden_size), 25 | torch.nn.LayerNorm(config.hidden_size, eps=1e-12) 26 | ) 27 | self.obj_projector = nn.Sequential( 28 | nn.Linear(config.obj_feat_size, config.output_size), 29 | torch.nn.LayerNorm(config.output_size, eps=1e-12) 30 | ) 31 | else: 32 | self.obj_linear = self.obj_layer_norm = None 33 | self.fuse_layer = None 34 | 35 | # 0: non-navigable, 1: navigable, 2: object 36 | self.nav_type_embedding = nn.Embedding(3, config.hidden_size) 37 | 38 | # tf naming convention for layer norm 39 | self.layer_norm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12) 40 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 41 | 42 | if config.num_pano_layers > 0: 43 | self.pano_encoder = create_transformer_encoder( 44 | config, config.num_pano_layers, norm=True 45 | ) 46 | else: 47 | self.pano_encoder = None 48 | 49 | self.mapper = nn.Linear(config.hidden_size, config.output_size) 50 | 51 | def forward_panorama_per_step(self, 52 | view_img_fts, 53 | view_lens, 54 | loc_fts=None, 55 | nav_types=None, 56 | obj_img_fts=None, 57 | obj_lens=None, 58 | obj_loc_fts=None, 59 | ): 60 | ret = {} 61 | batch_size = view_img_fts.shape[0] 62 | pano_embeds = self.img_layer_norm( 63 | self.img_linear(view_img_fts) 64 | ) 65 | if loc_fts is None: 66 | loc_fts = torch.zeros(pano_embeds.shape[:2]+(7,), dtype=torch.float).to(pano_embeds.device) 67 | pano_embeds += self.loc_layer_norm(self.loc_linear(loc_fts)) 68 | 69 | if nav_types is None: 70 | nav_types = torch.ones(pano_embeds.shape[:2], dtype=torch.int).to(pano_embeds.device) 71 | pano_embeds += self.nav_type_embedding(nav_types) 72 | 73 | pano_embeds = self.layer_norm(pano_embeds) 74 | pano_embeds = self.dropout(pano_embeds) 75 | pano_masks = gen_seq_masks(view_lens) 76 | if self.pano_encoder is not None: 77 | 78 | if self.fuse_obj: 79 | obj_nav_types = torch.full(obj_img_fts.shape[:2], 2, dtype=torch.int).to(obj_img_fts.device) 80 | obj_embeds = self.obj_linear(obj_img_fts) + self.loc_layer_norm(self.loc_linear(obj_loc_fts)) + self.nav_type_embedding(obj_nav_types) 81 | fuse_embeds = [] 82 | for bn in range(batch_size): 83 | fuse_embeds.append( 84 | torch.cat([ 85 | pano_embeds[bn, :view_lens[bn]], obj_embeds[bn, :obj_lens[bn]] 86 | ], dim=0) 87 | ) 88 | fuse_embeds = pad_tensors_wgrad(fuse_embeds) 89 | fuse_masks = gen_seq_masks(view_lens+obj_lens) 90 | fuse_embeds = self.pano_encoder( 91 | fuse_embeds, src_key_padding_mask=fuse_masks.logical_not() 92 | ) 93 | pano_embeds = [fuse_embeds[bn, :view_lens[bn]] for bn in range(batch_size)] 94 | pano_embeds = pad_tensors_wgrad(pano_embeds) 95 | 96 | else: 97 | pano_embeds = self.pano_encoder( 98 | pano_embeds, src_key_padding_mask=pano_masks.logical_not() 99 | ) 100 | 101 | 102 | pano_embeds = self.mapper(pano_embeds) 103 | pano_embeds.masked_fill_(pano_masks.logical_not().unsqueeze(-1), 0) 104 | 105 | ret.update({ 106 | "pano_embeds": pano_embeds, 107 | "pano_masks": pano_masks 108 | }) 109 | 110 | # object feature 111 | if obj_img_fts is not None and obj_img_fts.shape[1] > 0: 112 | obj_embeds = self.obj_projector(obj_img_fts) 113 | obj_masks = gen_seq_masks(obj_lens) 114 | assert obj_embeds.shape[:2] == obj_loc_fts.shape[:2], f'shape of obj_embeds {obj_embeds.shape[:2]} must equal to shape of obj_loc_fts {obj_loc_fts.shape[:2]}' 115 | ret.update({ 116 | 'obj_embeds': obj_embeds, 117 | 
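# obj_embeds above come from obj_projector, i.e. raw object features mapped directly
# to the LM width (output_size), separate from the panorama-encoder path.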
'obj_loc_fts': obj_loc_fts, 118 | 'obj_masks': obj_masks 119 | }) 120 | 121 | return ret -------------------------------------------------------------------------------- /models/modified_lm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Optional, List, Union, Tuple 4 | from transformers import OPTForCausalLM, LlamaForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaConfig 5 | from transformers.modeling_outputs import CausalLMOutputWithPast 6 | from transformers.generation.logits_process import LogitsProcessor 7 | from tools.trie import Trie 8 | 9 | 10 | class TrieLogitsProcessor(LogitsProcessor): 11 | def __init__(self, trie: Trie): 12 | self.node_states = None 13 | self.trie = trie 14 | 15 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 16 | batch_size = input_ids.shape[0] 17 | if self.node_states is None: 18 | self.node_states = [self.trie.root for bn in range(batch_size)] 19 | else: 20 | for bn in range(batch_size): 21 | w = input_ids[bn, -1].item() 22 | self.node_states[bn] = self.trie.get_next_node(self.node_states[bn], w) 23 | 24 | masks = torch.zeros_like(scores, dtype=torch.bool).to(scores.device) 25 | for bn in range(batch_size): 26 | next_layer = self.trie.get_child_index(self.node_states[bn]) 27 | masks[bn][next_layer] = True 28 | 29 | scores = scores.masked_fill(~masks, float('-inf')) 30 | return scores 31 | 32 | 33 | class ModifiedLM: 34 | """ 35 | This is base class for all ModifiedLM* 36 | """ 37 | 38 | def __init__(self, extra_config): 39 | 40 | if extra_config.precision == 'fp16': 41 | self.model_type = torch.float16 42 | elif 'bf16' in extra_config.precision or 'bfloat16' in extra_config.precision: 43 | self.model_type = torch.bfloat16 44 | else: 45 | self.model_type = torch.float32 46 | 47 | self.model = self.model.to(self.model_type) 48 | self.lm_head = self.lm_head.to(self.model_type) 49 | 50 | # print("************ Use dtype: {} ************\n".format(self.model_type)) 51 | 52 | # llama-7b dim=4096, bloom dim=1024, 53 | self.hidden_size = self.config.hidden_size 54 | 55 | 56 | def init_tokenizer(self, pretrained_model_name_or_path: str): 57 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, padding_side="left", truncation_side='left') if not isinstance(self.config, LlamaConfig) else LlamaTokenizer.from_pretrained(pretrained_model_name_or_path, padding_side="left", truncation_side='left') 58 | 59 | self.cand_token = [''] 60 | self.hist_token = [''] 61 | self.obj_token = [''] 62 | self.cls_token = ['', ''] 63 | self.tokenizer.add_special_tokens( 64 | {"additional_special_tokens": self.cand_token + self.hist_token + self.obj_token + self.cls_token} 65 | ) 66 | if self.tokenizer.pad_token is None: 67 | self.tokenizer.add_special_tokens({"pad_token": ""}) 68 | 69 | self.cand_token_id = self.tokenizer.encode("".join(self.cand_token), add_special_tokens=False) 70 | self.hist_token_id = self.tokenizer.encode("".join(self.hist_token), add_special_tokens=False) 71 | self.obj_token_id = self.tokenizer.encode("".join(self.obj_token), add_special_tokens=False) 72 | self.cls_token_id = self.tokenizer.encode("".join(self.cls_token), add_special_tokens=False) 73 | self.special_token_ids = self.cand_token_id + self.hist_token_id + self.obj_token_id + self.cls_token_id 74 | 75 | self.resize_token_embeddings(len(self.tokenizer)) 76 | 77 | def tokenize(self, text: str, add_special_tokens: bool=True): 78 | 
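# Batch-tokenize with the left-padding / left-truncation tokenizer set up in
# init_tokenizer, capping each prompt at 1024 tokens and returning PyTorch tensors.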
batch_text = self.tokenizer( 79 | text, 80 | max_length=1024, 81 | padding=True, 82 | truncation=True, 83 | return_tensors="pt", 84 | add_special_tokens=add_special_tokens, 85 | return_token_type_ids=True 86 | ) 87 | return batch_text 88 | 89 | def forward( 90 | self, 91 | input_ids, 92 | attention_mask, 93 | labels=None, 94 | cand_vis=None, 95 | hist_vis=None, 96 | obj_vis=None, 97 | **kwargs 98 | ): 99 | 100 | hist_locations = (input_ids >= self.hist_token_id[0]) & (input_ids <= self.hist_token_id[-1]) 101 | cand_locations = (input_ids >= self.cand_token_id[0]) & (input_ids <= self.cand_token_id[-1]) 102 | obj_locations = (input_ids >= self.obj_token_id[0]) & (input_ids <= self.obj_token_id[-1]) 103 | 104 | inputs_embeds = self.get_input_embeddings()(input_ids) 105 | if cand_locations.sum() != 0: 106 | inputs_embeds[cand_locations] += cand_vis 107 | if hist_locations.sum() != 0: 108 | inputs_embeds[hist_locations] += hist_vis 109 | if obj_locations.sum() != 0: 110 | inputs_embeds[obj_locations] += obj_vis 111 | 112 | outputs = self.get_encoder()( 113 | attention_mask=attention_mask, 114 | inputs_embeds=inputs_embeds, 115 | **kwargs 116 | ) 117 | # outputs = self.model.transformer(*input, **kwargs) 118 | 119 | hidden_states = outputs[0] 120 | logits = self.lm_head(hidden_states) 121 | 122 | logits_mask = torch.ones_like(logits, dtype=torch.bool).to(logits.device) 123 | logits_mask[:, :, self.special_token_ids] = False 124 | logits = logits.masked_fill(~logits_mask, float('-inf')) 125 | 126 | loss = None 127 | if labels is not None: 128 | # Shift so that tokens < n predict n 129 | shift_logits = logits[..., :-1, :].contiguous() 130 | shift_labels = labels[..., 1:].contiguous() 131 | # Flatten the tokens 132 | loss_fct = nn.CrossEntropyLoss() 133 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 134 | shift_labels = shift_labels.view(-1) 135 | # Enable model parallelism 136 | shift_labels = shift_labels.to(shift_logits.device) 137 | loss = loss_fct(shift_logits, shift_labels) 138 | 139 | # logits = logits[cand_locations] 140 | return CausalLMOutputWithPast( 141 | loss=loss, 142 | logits=logits, 143 | past_key_values=outputs.past_key_values, 144 | hidden_states=hidden_states, # only store the last hidden states 145 | attentions=outputs.attentions, 146 | ) 147 | 148 | 149 | class ModifiedOPTForCasualLM(ModifiedLM, OPTForCausalLM): 150 | def __init__(self, config, extra_config): 151 | OPTForCausalLM.__init__(self, config) 152 | ModifiedLM.__init__(self, extra_config) 153 | 154 | def get_encoder(self): 155 | return self.model.decoder 156 | 157 | def prepare_inputs_for_generation( 158 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cand_vis=None, hist_vis=None, obj_vis=None, **kwargs 159 | ): 160 | model_inputs = OPTForCausalLM.prepare_inputs_for_generation( 161 | self, 162 | input_ids, 163 | past_key_values, 164 | attention_mask, 165 | inputs_embeds, 166 | **kwargs 167 | ) 168 | if not past_key_values: 169 | for k in ['cand_vis', 'hist_vis', 'obj_vis']: 170 | model_inputs[k] = eval(k) 171 | 172 | return model_inputs 173 | 174 | 175 | 176 | class ModifiedLlamaForCausalLM(ModifiedLM, LlamaForCausalLM): 177 | def __init__(self, config, extra_config): 178 | LlamaForCausalLM.__init__(self, config) 179 | ModifiedLM.__init__(self, extra_config) 180 | 181 | def get_encoder(self): 182 | return self.model 183 | 184 | def prepare_inputs_for_generation( 185 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cand_vis=None, 
hist_vis=None, obj_vis=None, **kwargs 186 | ): 187 | model_inputs = LlamaForCausalLM.prepare_inputs_for_generation( 188 | self, 189 | input_ids, 190 | past_key_values, 191 | attention_mask, 192 | inputs_embeds, 193 | **kwargs 194 | ) 195 | if not past_key_values: 196 | for k in ['cand_vis', 'hist_vis', 'obj_vis']: 197 | model_inputs[k] = eval(k) 198 | 199 | return model_inputs -------------------------------------------------------------------------------- /models/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .detr_transformer import TransformerEncoder, TransformerEncoderLayer 3 | BertLayerNorm = torch.nn.LayerNorm 4 | 5 | 6 | def create_transformer_encoder(config, num_layers, norm=False): 7 | enc_layer = TransformerEncoderLayer( 8 | config.hidden_size, config.num_attention_heads, 9 | dim_feedforward=config.intermediate_size, 10 | dropout=config.hidden_dropout_prob, 11 | activation=config.hidden_act, 12 | normalize_before=True 13 | ) 14 | if norm: 15 | norm_layer = BertLayerNorm(config.hidden_size, eps=1e-12) 16 | else: 17 | norm_layer = None 18 | return TransformerEncoder(enc_layer, num_layers, norm=norm_layer, batch_first=True) 19 | 20 | 21 | def extend_neg_masks(masks, dtype=None): 22 | """ 23 | mask from (N, L) into (N, 1(H), 1(L), L) and make it negative 24 | """ 25 | if dtype is None: 26 | dtype = torch.float 27 | extended_masks = masks.unsqueeze(1).unsqueeze(2) 28 | extended_masks = extended_masks.to(dtype=dtype) 29 | extended_masks = (1.0 - extended_masks) * -10000.0 30 | return extended_masks 31 | 32 | 33 | def gen_seq_masks(seq_lens, max_len=None): 34 | if max_len is None: 35 | max_len = max(seq_lens) 36 | batch_size = len(seq_lens) 37 | device = seq_lens.device 38 | 39 | masks = torch.arange(max_len).unsqueeze(0).repeat(batch_size, 1).to(device) 40 | masks = masks < seq_lens.unsqueeze(1) 41 | return masks 42 | 43 | 44 | def pad_tensors_wgrad(tensors, lens=None): 45 | """B x [T, ...] 
torch tensors""" 46 | if lens is None: 47 | lens = [t.size(0) for t in tensors] 48 | max_len = max(lens) 49 | batch_size = len(tensors) 50 | hid = list(tensors[0].size()[1:]) 51 | 52 | device = tensors[0].device 53 | dtype = tensors[0].dtype 54 | 55 | output = [] 56 | for i in range(batch_size): 57 | if lens[i] < max_len: 58 | tmp = torch.cat( 59 | [tensors[i], torch.zeros([max_len-lens[i]]+hid, dtype=dtype).to(device)], 60 | dim=0 61 | ) 62 | else: 63 | tmp = tensors[i] 64 | output.append(tmp) 65 | output = torch.stack(output, 0) 66 | return output 67 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict==1.10 2 | h5py==2.10.0 3 | jsonlines==2.0.0 4 | lmdb==1.4.1 5 | more_itertools==10.1.0 6 | msgpack_numpy==0.4.8 7 | msgpack_python==0.5.6 8 | networkx==2.5.1 9 | numpy==1.22.3 10 | opencv_python==4.7.0.72 11 | Pillow==10.1.0 12 | progressbar33==2.4 13 | psutil==5.9.4 14 | PyYAML==6.0.1 15 | ray==2.8.0 16 | requests==2.25.1 17 | shapely==2.0.1 18 | timm==0.9.2 19 | torch==1.10.0+cu113 20 | torchvision==0.11.0+cu113 21 | tqdm==4.64.1 22 | transformers==4.28.0 23 | sentencepiece==0.1.99 -------------------------------------------------------------------------------- /scripts/ablation/from_scratch.sh: -------------------------------------------------------------------------------- 1 | # set mp3d path 2 | # export PYTHONPATH=Matterport3DSimulator/build:$PYTHONPATH 3 | 4 | # set java path 5 | # export JAVA_HOME=$java_path 6 | # export PATH=$JAVA_HOME/bin:$PATH 7 | # export CLASSPATH=.:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jar 8 | 9 | # activate environment 10 | # conda activate navillm 11 | 12 | # training for 30 epochs 13 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 41000 train.py \ 14 | --stage multi --cfg_file configs/multi.yaml \ 15 | --data_dir data --pretrained_model_name_or_path data/models/Vicuna-7B --precision amp_bf16 --from_scratch \ 16 | --batch_size 1 --gradient_accumulation_step 8 --num_steps_per_epoch 2000 --lr 3e-5 --seed 0 --num_epochs 30 \ 17 | --enable_og --enable_summarize --enable_fgr2r \ 18 | --test_datasets CVDN SOON R2R REVERIE ScanQA \ 19 | --max_saved_checkpoints 1 --output_dir build/ablation/from_scratch -------------------------------------------------------------------------------- /scripts/ablation/single_task.sh: -------------------------------------------------------------------------------- 1 | # set mp3d path 2 | # export PYTHONPATH=Matterport3DSimulator/build:$PYTHONPATH 3 | 4 | # set java path 5 | # export JAVA_HOME=$java_path 6 | # export PATH=$JAVA_HOME/bin:$PATH 7 | # export CLASSPATH=.:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jar 8 | 9 | # activate environment 10 | # conda activate navillm 11 | 12 | # training for 30 epochs on CVDN 13 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 43000 train.py \ 14 | --stage multi --cfg_file configs/ablation/cvdn.yaml \ 15 | --data_dir data --pretrained_model_name_or_path data/models/Vicuna-7B --precision amp_bf16 \ 16 | --batch_size 1 --gradient_accumulation_step 8 --lr 3e-5 --seed 0 --num_epochs 20 \ 17 | --max_saved_checkpoints 1 --output_dir build/albation/cvdn -------------------------------------------------------------------------------- /scripts/data_tools/extract_features_coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import json 5 | import 
collections 6 | import cv2 7 | import torch 8 | import torch.nn as nn 9 | import ray 10 | from ray.util.queue import Queue 11 | from torchvision import transforms 12 | from PIL import Image 13 | import math 14 | import h5py 15 | import argparse 16 | from more_itertools import batched 17 | import psutil 18 | 19 | 20 | @ray.remote(num_gpus=1) 21 | def process_features(proc_id, out_queue, scenevp_list, args): 22 | print(f"Start process {proc_id}, there are {len(scenevp_list)} datapoints") 23 | sys.path.append("EVA/EVA-CLIP/rei") 24 | from eva_clip import create_model_and_transforms 25 | 26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | 28 | # load visual encoder 29 | model, _, transform = create_model_and_transforms(args.model_name, args.pretrained, force_custom_clip=True) 30 | visual_encoder = model.visual.to(device) 31 | visual_encoder.eval() 32 | 33 | for i, batch in enumerate(batched(scenevp_list, args.batch_size)): 34 | images = [] 35 | for item in batch: 36 | image = Image.open(item["path"]) 37 | images.append(image) 38 | 39 | vision_x = [transform(image).unsqueeze(0).to(device) for image in images] 40 | vision_x = torch.cat(vision_x, dim=0) 41 | 42 | with torch.no_grad(), torch.cuda.amp.autocast(): 43 | outs = visual_encoder.forward_features(vision_x) 44 | outs = outs.data.cpu().numpy() 45 | 46 | for i, item in enumerate(batch): 47 | out_queue.put((item["image_id"], outs[i], [])) 48 | 49 | if i%1000==0: 50 | process = psutil.Process() 51 | memory_info = process.memory_info() 52 | print(f"Memory used by current process: {memory_info.rss / (1024 * 1024):.2f} MB") 53 | 54 | out_queue.put(None) 55 | 56 | @ray.remote 57 | def write_features(out_queue, total, num_workers, args): 58 | 59 | num_finished_workers = 0 60 | num_finished_vps = 0 61 | 62 | from progressbar import ProgressBar 63 | progress_bar = ProgressBar(total) 64 | progress_bar.start() 65 | 66 | with h5py.File(args.output_file, 'w') as outf: 67 | while num_finished_workers < num_workers: 68 | res = out_queue.get() 69 | if res is None: 70 | num_finished_workers += 1 71 | else: 72 | image_id, fts, logits = res 73 | key = image_id 74 | if False: 75 | data = np.hstack([fts, logits]) 76 | else: 77 | data = fts # shape=(36, 1408) 78 | outf.create_dataset(key, data.shape, dtype='float', compression='gzip') 79 | outf[key][...] 
= data 80 | outf[key].attrs['imageId'] = image_id 81 | 82 | num_finished_vps += 1 83 | if num_finished_vps % 20000 == 0: 84 | print("num_finished_vps: ", num_finished_vps) 85 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 86 | print("data shape: ", data.shape) 87 | progress_bar.update(num_finished_vps) 88 | 89 | progress_bar.finish() 90 | 91 | import time 92 | def main(args): 93 | 94 | os.makedirs(os.path.dirname(args.output_file), exist_ok=True) 95 | 96 | 97 | image_list = [] 98 | for split in ["train2017", "val2017", "test2017"]: 99 | for filename in os.listdir(os.path.join(args.image_dir, split)): 100 | image_list.append({ 101 | "image_id": filename.split(".")[0], 102 | "path": os.path.join(args.image_dir, split, filename), 103 | }) 104 | print("Loaded %d viewpoints" % len(image_list)) 105 | print(image_list[0]) 106 | 107 | scenevp_list = image_list 108 | num_workers = min(args.num_workers, len(scenevp_list)) 109 | num_data_per_worker = len(scenevp_list) // num_workers 110 | 111 | ray.init() 112 | out_queue = Queue() 113 | processes = [] 114 | for proc_id in range(num_workers): 115 | sidx = proc_id * num_data_per_worker 116 | eidx = None if proc_id == num_workers - 1 else sidx + num_data_per_worker 117 | 118 | process = process_features.remote(proc_id, out_queue, scenevp_list[sidx: eidx], args) 119 | processes.append(process) 120 | 121 | process = write_features.remote(out_queue, len(scenevp_list), num_workers, args) 122 | processes.append(process) 123 | 124 | ray.get(processes) 125 | ray.shutdown() 126 | 127 | 128 | if __name__ == '__main__': 129 | 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument("--model_name", type=str, default="EVA02-CLIP-L-14-336") 132 | parser.add_argument("--pretrained", type=str, default="data/models/EVA02_CLIP_L_336_psz14_s6B.pt", help="the path of EVA-CLIP") 133 | parser.add_argument('--batch_size', default=8, type=int) 134 | parser.add_argument('--num_workers', type=int, default=8) 135 | parser.add_argument('--image_dir', type=str, default="data/images/coco/", help="the path of coco 2017 dir") 136 | parser.add_argument("--output_file", type=str, default="data/eva_features/coco_EVA02-CLIP-L-14-336.hdf5", help="the path of output features") 137 | args = parser.parse_args() 138 | 139 | main(args) 140 | -------------------------------------------------------------------------------- /scripts/data_tools/extract_features_mp3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import json 5 | import collections 6 | import cv2 7 | import torch 8 | import torch.nn as nn 9 | import ray 10 | from ray.util.queue import Queue 11 | from torchvision import transforms 12 | from PIL import Image 13 | import math 14 | # sys.path.append(mp3d_path) # please add the simulator path to yout python path. 
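# MatterSim is provided by the compiled Matterport3DSimulator; make its build directory
# importable before running, e.g. export PYTHONPATH=Matterport3DSimulator/build:$PYTHONPATH
# (the same hint used in the scripts under scripts/), or uncomment the sys.path line above.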
15 | import MatterSim 16 | import h5py 17 | import argparse 18 | 19 | 20 | def build_simulator(connectivity_dir, scan_dir): 21 | WIDTH = 640 22 | HEIGHT = 480 23 | VFOV = 60 24 | sim = MatterSim.Simulator() 25 | sim.setNavGraphPath(connectivity_dir) 26 | sim.setDatasetPath(scan_dir) 27 | sim.setCameraResolution(WIDTH, HEIGHT) 28 | sim.setCameraVFOV(math.radians(VFOV)) 29 | sim.setDiscretizedViewingAngles(True) 30 | sim.setDepthEnabled(False) 31 | sim.setPreloadingEnabled(False) 32 | sim.setBatchSize(1) 33 | sim.initialize() 34 | return sim 35 | 36 | @ray.remote(num_gpus=1) 37 | def process_features(proc_id, out_queue, scanvp_list, args): 38 | sys.path.append("EVA/EVA-CLIP/rei") 39 | from eva_clip import create_model_and_transforms 40 | 41 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 42 | print('start proc_id: %d' % proc_id) 43 | 44 | # Set up the simulator 45 | sim = build_simulator(args.connectivity_dir, args.scan_dir) 46 | 47 | # load visual encoder 48 | model, _, transform = create_model_and_transforms(args.model_name, args.pretrained, force_custom_clip=True) 49 | visual_encoder = model.visual.to(device) 50 | visual_encoder.eval() 51 | 52 | for scan_id, viewpoint_id in scanvp_list: 53 | # Loop all discretized views from this location 54 | images = [] 55 | for ix in range(36): 56 | if ix == 0: 57 | sim.newEpisode([scan_id], [viewpoint_id], [0], [math.radians(-30)]) 58 | elif ix % 12 == 0: 59 | sim.makeAction([0], [1.0], [1.0]) 60 | else: 61 | sim.makeAction([0], [1.0], [0]) 62 | state = sim.getState()[0] 63 | assert state.viewIndex == ix 64 | 65 | image = np.array(state.rgb, copy=True) # in BGR channel 66 | image = Image.fromarray(image[:, :, ::-1]) #cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 67 | images.append(image) 68 | 69 | vision_x = [transform(image).unsqueeze(0).to(device) for image in images] 70 | vision_x = torch.cat(vision_x, dim=0) 71 | 72 | fts = [] 73 | for k in range(0, len(images), args.batch_size): 74 | input_img = vision_x[k: k + args.batch_size] 75 | with torch.no_grad(), torch.cuda.amp.autocast(): 76 | outs = visual_encoder.forward_features(input_img) 77 | outs = outs.data.cpu().numpy() 78 | fts.append(outs) 79 | fts = np.concatenate(fts, 0) 80 | 81 | out_queue.put((scan_id, viewpoint_id, fts, [])) 82 | 83 | out_queue.put(None) 84 | 85 | @ray.remote 86 | def write_features(out_queue, total, num_workers, args): 87 | WIDTH = 640 88 | HEIGHT = 480 89 | VFOV = 60 90 | 91 | num_finished_workers = 0 92 | num_finished_vps = 0 93 | 94 | from progressbar import ProgressBar 95 | progress_bar = ProgressBar(total) 96 | progress_bar.start() 97 | 98 | with h5py.File(args.output_file, 'w') as outf: 99 | while num_finished_workers < num_workers: 100 | res = out_queue.get() 101 | if res is None: 102 | num_finished_workers += 1 103 | else: 104 | scan_id, viewpoint_id, fts, logits = res 105 | key = '%s_%s' % (scan_id, viewpoint_id) 106 | if False: 107 | data = np.hstack([fts, logits]) 108 | else: 109 | data = fts # shape=(36, 1408) 110 | outf.create_dataset(key, data.shape, dtype='float', compression='gzip') 111 | outf[key][...] 
= data 112 | outf[key].attrs['scanId'] = scan_id 113 | outf[key].attrs['viewpointId'] = viewpoint_id 114 | outf[key].attrs['image_w'] = WIDTH 115 | outf[key].attrs['image_h'] = HEIGHT 116 | outf[key].attrs['vfov'] = VFOV 117 | 118 | num_finished_vps += 1 119 | if num_finished_vps % 20 == 0: 120 | print("num_finished_vps: ",num_finished_vps) 121 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 122 | print("data shape: ", data.shape) 123 | progress_bar.update(num_finished_vps) 124 | 125 | progress_bar.finish() 126 | 127 | import time 128 | def main(args): 129 | 130 | os.makedirs(os.path.dirname(args.output_file), exist_ok=True) 131 | 132 | viewpoint_ids = [] 133 | with open(os.path.join(connectivity_dir, 'scans.txt')) as f: 134 | scans = [x.strip() for x in f] 135 | for scan in scans: 136 | with open(os.path.join(connectivity_dir, '%s_connectivity.json' % scan)) as f: 137 | data = json.load(f) 138 | viewpoint_ids.extend([(scan, x['image_id']) for x in data if x['included']]) 139 | print('Loaded %d viewpoints' % len(viewpoint_ids)) 140 | scanvp_list = viewpoint_ids 141 | num_workers = min(args.num_workers, len(scanvp_list)) 142 | num_data_per_worker = len(scanvp_list) // num_workers 143 | 144 | ray.init() 145 | out_queue = Queue() 146 | processes = [] 147 | for proc_id in range(num_workers): 148 | sidx = proc_id * num_data_per_worker 149 | eidx = None if proc_id == num_workers - 1 else sidx + num_data_per_worker 150 | 151 | process = process_features.remote(proc_id, out_queue, scanvp_list[sidx: eidx], args) 152 | processes.append(process) 153 | 154 | process = write_features.remote(out_queue, len(scanvp_list), num_workers, args) 155 | processes.append(process) 156 | 157 | ray.get(processes) 158 | ray.shutdown() 159 | 160 | 161 | if __name__ == '__main__': 162 | 163 | scan_data_dir = '/mnt/petrelfs/zhaolin/vln/nav/features/mp3d' 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument("--model_name", type=str, default="EVA02-CLIP-L-14-336") 166 | parser.add_argument("--pretrained", type=str, default="data/models/EVA02_CLIP_L_336_psz14_s6B.pt", help='the path of pre-trained model') 167 | parser.add_argument('--connectivity_dir', default='data/connectivity', help='the path of connectivity') 168 | parser.add_argument('--scan_dir', default=scan_data_dir) 169 | parser.add_argument('--batch_size', default=16, type=int) 170 | parser.add_argument('--num_workers', type=int, default=8) 171 | parser.add_argument("--output_file", type=str, default="data/eva_features/mp3d_EVA02-CLIP-L-14-336.hdf5", help="the path of output features") 172 | args = parser.parse_args() 173 | 174 | main(args) 175 | 176 | -------------------------------------------------------------------------------- /scripts/data_tools/extract_features_scanqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import json 5 | import collections 6 | import cv2 7 | import torch 8 | import torch.nn as nn 9 | import ray 10 | from ray.util.queue import Queue 11 | from torchvision import transforms 12 | from PIL import Image 13 | import math 14 | import h5py 15 | import argparse 16 | from more_itertools import batched 17 | import psutil 18 | 19 | 20 | @ray.remote(num_gpus=1) 21 | def process_features(proc_id, out_queue, scenevp_list, args): 22 | print(f"Start process {proc_id}, there are {len(scenevp_list)} datapoints") 23 | sys.path.append("EVA/EVA-CLIP/rei") 24 | from eva_clip import create_model_and_transforms 25 | 26 | device = 
torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | 28 | # load visual encoder 29 | model, _, transform = create_model_and_transforms(args.model_name, args.pretrained, force_custom_clip=True) 30 | visual_encoder = model.visual.to(device) 31 | visual_encoder.eval() 32 | 33 | # for scene_id, image_id in scenevp_list: 34 | for i, batch in enumerate(batched(scenevp_list, args.batch_size)): 35 | # Loop all discretized views from this location 36 | images = [] 37 | for item in batch: 38 | image = Image.open(item["path"]) 39 | images.append(image) 40 | 41 | vision_x = [transform(image).unsqueeze(0).to(device) for image in images] 42 | vision_x = torch.cat(vision_x, dim=0) 43 | 44 | with torch.no_grad(), torch.cuda.amp.autocast(): 45 | outs = visual_encoder.forward_features(vision_x) 46 | outs = outs.data.cpu().numpy() 47 | 48 | for i, item in enumerate(batch): 49 | out_queue.put((item["scene_id"], item["image_id"], outs[i], [])) 50 | 51 | if i%1000==0: 52 | process = psutil.Process() 53 | memory_info = process.memory_info() 54 | print(f"Memory used by current process: {memory_info.rss / (1024 * 1024):.2f} MB") 55 | 56 | out_queue.put(None) 57 | 58 | @ray.remote 59 | def write_features(out_queue, total, num_workers, args): 60 | 61 | num_finished_workers = 0 62 | num_finished_vps = 0 63 | 64 | from progressbar import ProgressBar 65 | progress_bar = ProgressBar(total) 66 | progress_bar.start() 67 | 68 | with h5py.File(args.output_file, 'w') as outf: 69 | while num_finished_workers < num_workers: 70 | res = out_queue.get() 71 | if res is None: 72 | num_finished_workers += 1 73 | else: 74 | scene_id, image_id, fts, logits = res 75 | key = '%s_%s' % (scene_id, image_id) 76 | if False: 77 | data = np.hstack([fts, logits]) 78 | else: 79 | data = fts # shape=(36, 1408) 80 | outf.create_dataset(key, data.shape, dtype='float', compression='gzip') 81 | outf[key][...] 
= data 82 | outf[key].attrs['sceneId'] = scene_id 83 | outf[key].attrs['imageId'] = image_id 84 | 85 | num_finished_vps += 1 86 | if num_finished_vps % 20000 == 0: 87 | print("num_finished_vps: ", num_finished_vps) 88 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 89 | print("data shape: ", data.shape) 90 | progress_bar.update(num_finished_vps) 91 | 92 | progress_bar.finish() 93 | 94 | import time 95 | def main(args): 96 | 97 | os.makedirs(os.path.dirname(args.output_file), exist_ok=True) 98 | 99 | 100 | image_list = [] 101 | for scene_id in os.listdir(args.image_dir): 102 | if scene_id.endswith(".py") or scene_id.endswith(".txt"): 103 | continue 104 | for filename in os.listdir(os.path.join(args.image_dir, scene_id, "color")): 105 | image_list.append({ 106 | "path": os.path.join(args.image_dir, scene_id, "color", filename), 107 | "scene_id": scene_id, 108 | "image_id": filename.split('.')[0] 109 | }) 110 | print("Loaded %d viewpoints" % len(image_list)) 111 | print(image_list[0]) 112 | 113 | scenevp_list = image_list 114 | num_workers = min(args.num_workers, len(scenevp_list)) 115 | num_data_per_worker = len(scenevp_list) // num_workers 116 | 117 | ray.init() 118 | out_queue = Queue() 119 | processes = [] 120 | for proc_id in range(num_workers): 121 | sidx = proc_id * num_data_per_worker 122 | eidx = None if proc_id == num_workers - 1 else sidx + num_data_per_worker 123 | 124 | process = process_features.remote(proc_id, out_queue, scenevp_list[sidx: eidx], args) 125 | processes.append(process) 126 | 127 | process = write_features.remote(out_queue, len(scenevp_list), num_workers, args) 128 | processes.append(process) 129 | 130 | ray.get(processes) 131 | ray.shutdown() 132 | 133 | 134 | if __name__ == '__main__': 135 | 136 | parser = argparse.ArgumentParser() 137 | parser.add_argument("--model_name", type=str, default="EVA02-CLIP-L-14-336") 138 | parser.add_argument("--pretrained", type=str, default="data/models/EVA02_CLIP_L_336_psz14_s6B.pt", help="the path of EVA-CLIP") 139 | parser.add_argument('--batch_size', default=8, type=int) 140 | parser.add_argument('--num_workers', type=int, default=8) 141 | parser.add_argument('--image_dir', type=str, default="data/ScanQA/frames_square/", help='the original ScanQA dataset with RGB frames') 142 | parser.add_argument("--output_file", type=str, default="data/eva_features/scanqa_EVA02-CLIP-L-14-336.hdf5", help="the path of output features") 143 | args = parser.parse_args() 144 | 145 | main(args) 146 | -------------------------------------------------------------------------------- /scripts/data_tools/reformat_scanqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | def get_image_metainfo(scene_id, args): 7 | path = None 8 | tmp_path = os.path.join(args.image_dir, scene_id) 9 | if os.path.exists(tmp_path): 10 | path = tmp_path 11 | # assert path is not None, f"{scene_id} cannot be None!" 
12 | if path is None: 13 | raise ValueError(f"{scene_id} cannot be None!") 14 | 15 | image_info, object_info = [], [] 16 | 17 | def load_txt(filename): 18 | pose = [] 19 | with open(filename) as f: 20 | for line in f.readlines(): 21 | numbers = [float(s) for s in line.strip('\n').split(' ')] 22 | pose.append(numbers) 23 | return pose 24 | 25 | 26 | for filename in os.listdir(os.path.join(path, "color")): 27 | pose_file = os.path.join(path, "pose", filename.split('.')[0]+".txt") 28 | if not os.path.exists(pose_file): 29 | raise ValueError(f"{pose_file} not exist.") 30 | 31 | image_info.append({ 32 | "image_id": filename.split('.')[0], 33 | "pose": load_txt(pose_file) 34 | }) 35 | 36 | 37 | return image_info 38 | 39 | 40 | 41 | def main(args): 42 | for filename in ["ScanQA_v1.0_train.json", "ScanQA_v1.0_val.json", "ScanQA_v1.0_test_w_obj.json", "ScanQA_v1.0_test_wo_obj.json"]: 43 | with open(os.path.join(args.json_dir, filename)) as f: 44 | data = json.load(f) 45 | 46 | total_data = len(data) 47 | 48 | new_data = {} 49 | not_exist = 0 50 | not_exist_scene_id = {} 51 | for item in tqdm(data): 52 | scene_id = item["scene_id"] 53 | if scene_id in not_exist_scene_id: 54 | not_exist += 1 55 | continue 56 | 57 | try: 58 | if scene_id not in new_data: 59 | image_info = get_image_metainfo(scene_id, args) 60 | new_data[scene_id] = { 61 | "annotation": [], 62 | "image_info": image_info, 63 | } 64 | except Exception as e: 65 | print(f"{e} | SceneId: {scene_id}") 66 | not_exist += 1 67 | not_exist_scene_id[scene_id] = 1 68 | continue 69 | 70 | new_data[scene_id]["annotation"].append({ 71 | "question_id": item["question_id"], 72 | "question": item["question"], 73 | "answers": item.get("answers", []), 74 | "object_ids": item.get("object_ids", []), 75 | "object_names": item.get("object_names", []), 76 | }) 77 | 78 | 79 | data_list = [] 80 | for scene_id, item in new_data.items(): 81 | item["scene_id"] = scene_id 82 | item['image_info'] = sorted(item['image_info'], key=lambda x: int(x["image_id"])) 83 | data_list.append(item) 84 | 85 | os.makedirs(args.output_dir, exist_ok=True) 86 | with open(os.path.join(args.output_dir, f"{filename.replace('.json', '')}_reformat.json"), "w") as fout: 87 | json.dump(data_list, fout) 88 | 89 | print(f"Total data: {total_data}") 90 | print(f"Not exist: {not_exist}") 91 | 92 | 93 | if __name__=="__main__": 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument("--json_dir", type=str) 96 | parser.add_argument("--image_dir", type=str) 97 | parser.add_argument("--output_dir", type=str, default="data/ScanQA") 98 | args = parser.parse_args() 99 | main(args) -------------------------------------------------------------------------------- /scripts/evaluation/eval_cvdn.sh: -------------------------------------------------------------------------------- 1 | # set mp3d path 2 | # export PYTHONPATH=Matterport3DSimulator/build:$PYTHONPATH 3 | 4 | # set java path 5 | # export JAVA_HOME=$java_path 6 | # export PATH=$JAVA_HOME/bin:$PATH 7 | # export CLASSPATH=.:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jar 8 | 9 | # activate environment 10 | # conda activate navillm 11 | 12 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 41000 train.py \ 13 | --stage multi --mode test --data_dir data --cfg_file configs/multi.yaml \ 14 | --pretrained_model_name_or_path data/models/Vicuna-7B --precision amp_bf16 \ 15 | --resume_from_checkpoint $model_path \ 16 | --test_datasets CVDN \ 17 | --batch_size 4 --output_dir build/eval --validation_split test --save_pred_results 18 | 
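Note on running the evaluation scripts: they read the checkpoint path from the environment variable $model_path, which is not set anywhere inside the script itself. A minimal invocation sketch in shell -- the checkpoint filename below is an assumption, substitute whichever checkpoint you trained or downloaded:

# hypothetical checkpoint path; the script passes it to --resume_from_checkpoint via $model_path
export model_path=data/models/navillm_checkpoint.pt
bash scripts/evaluation/eval_cvdn.sh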
-------------------------------------------------------------------------------- /scripts/evaluation/eval_r2r.sh: -------------------------------------------------------------------------------- 1 | # set mp3d path 2 | # export PYTHONPATH=Matterport3DSimulator/build:$PYTHONPATH 3 | 4 | # set java path 5 | # export JAVA_HOME=$java_path 6 | # export PATH=$JAVA_HOME/bin:$PATH 7 | # export CLASSPATH=.:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jar 8 | 9 | # activate environment 10 | # conda activate navillm 11 | 12 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 41000 train.py \ 13 | --stage multi --mode test --data_dir data --cfg_file configs/multi.yaml \ 14 | --pretrained_model_name_or_path data/models/Vicuna-7B --precision amp_bf16 \ 15 | --resume_from_checkpoint $model_path \ 16 | --test_datasets R2R \ 17 | --batch_size 4 --output_dir build/eval --validation_split test --save_pred_results 18 | -------------------------------------------------------------------------------- /scripts/evaluation/eval_reverie.sh: -------------------------------------------------------------------------------- 1 | # set mp3d path 2 | # export PYTHONPATH=Matterport3DSimulator/build:$PYTHONPATH 3 | 4 | # set java path 5 | # export JAVA_HOME=$java_path 6 | # export PATH=$JAVA_HOME/bin:$PATH 7 | # export CLASSPATH=.:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jar 8 | 9 | # activate environment 10 | # conda activate navillm 11 | 12 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 41000 train.py \ 13 | --stage multi --mode test --data_dir data --cfg_file configs/multi.yaml \ 14 | --pretrained_model_name_or_path data/models/Vicuna-7B --precision amp_bf16 \ 15 | --resume_from_checkpoint $model_path \ 16 | --test_datasets REVERIE \ 17 | --batch_size 4 --output_dir build/eval --validation_split test --save_pred_results \ 18 | --do_sample --temperature 0.01 --enable_og 19 | -------------------------------------------------------------------------------- /scripts/evaluation/eval_scanqa.sh: -------------------------------------------------------------------------------- 1 | # set mp3d path 2 | # export PYTHONPATH=Matterport3DSimulator/build:$PYTHONPATH 3 | 4 | # set java path 5 | # export JAVA_HOME=$java_path 6 | # export PATH=$JAVA_HOME/bin:$PATH 7 | # export CLASSPATH=.:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jar 8 | 9 | # activate environment 10 | # conda activate navillm 11 | 12 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 41000 train.py \ 13 | --stage multi --mode test --data_dir data --cfg_file configs/multi.yaml \ 14 | --pretrained_model_name_or_path data/models/Vicuna-7B --precision amp_bf16 \ 15 | --resume_from_checkpoint $model_path \ 16 | --test_datasets ScanQA \ 17 | --batch_size 4 --output_dir build/eval --validation_split test_w_obj --save_pred_results 18 | -------------------------------------------------------------------------------- /scripts/evaluation/eval_soon.sh: -------------------------------------------------------------------------------- 1 | # set mp3d path 2 | # export PYTHONPATH=Matterport3DSimulator/build:$PYTHONPATH 3 | 4 | # set java path 5 | # export JAVA_HOME=$java_path 6 | # export PATH=$JAVA_HOME/bin:$PATH 7 | # export CLASSPATH=.:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jar 8 | 9 | # activate environment 10 | # conda activate navillm 11 | 12 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 41000 train.py \ 13 | --stage multi --mode test --data_dir data --cfg_file configs/multi.yaml \ 14 | --pretrained_model_name_or_path data/models/Vicuna-7B --precision 
amp_bf16 \ 15 | --resume_from_checkpoint $model_path \ 16 | --test_datasets SOON \ 17 | --batch_size 4 --output_dir build/eval --validation_split test --save_pred_results \ 18 | --do_sample --temperature 0.01 --enable_og 19 | -------------------------------------------------------------------------------- /scripts/held_out/held_out_cvdn.sh: -------------------------------------------------------------------------------- 1 | # set mp3d path 2 | # export PYTHONPATH=Matterport3DSimulator/build:$PYTHONPATH 3 | 4 | # set java path 5 | # export JAVA_HOME=$java_path 6 | # export PATH=$JAVA_HOME/bin:$PATH 7 | # export CLASSPATH=.:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jar 8 | 9 | # activate environment 10 | # conda activate navillm 11 | 12 | # training for 30 epochs 13 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 41000 train.py \ 14 | --stage multi --cfg_file configs/held_out/held_out_cvdn.yaml \ 15 | --data_dir data --pretrained_model_name_or_path data/models/Vicuna-7B --precision amp_bf16 \ 16 | --batch_size 1 --gradient_accumulation_step 8 --num_steps_per_epoch 2000 --lr 3e-5 --seed 0 --num_epochs 30 \ 17 | --enable_og --enable_summarize --enable_fgr2r \ 18 | --test_datasets CVDN SOON R2R REVERIE \ 19 | --max_saved_checkpoints 1 --output_dir output/held_out/held_out_cvdn -------------------------------------------------------------------------------- /scripts/held_out/held_out_reverie.sh: -------------------------------------------------------------------------------- 1 | # set mp3d path 2 | # export PYTHONPATH=Matterport3DSimulator/build:$PYTHONPATH 3 | 4 | # set java path 5 | # export JAVA_HOME=$java_path 6 | # export PATH=$JAVA_HOME/bin:$PATH 7 | # export CLASSPATH=.:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jar 8 | 9 | # activate environment 10 | # conda activate navillm 11 | 12 | # training for 30 epochs 13 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 41000 train.py \ 14 | --stage multi --cfg_file configs/held_out/held_out_reverie.yaml \ 15 | --data_dir data --pretrained_model_name_or_path data/models/Vicuna-7B --precision amp_bf16 \ 16 | --batch_size 1 --gradient_accumulation_step 8 --num_steps_per_epoch 2000 --lr 3e-5 --seed 0 --num_epochs 30 \ 17 | --enable_og --enable_summarize --enable_fgr2r \ 18 | --test_datasets CVDN SOON R2R REVERIE \ 19 | --max_saved_checkpoints 1 --output_dir output/held_out/held_out_reverie -------------------------------------------------------------------------------- /scripts/held_out/held_out_soon.sh: -------------------------------------------------------------------------------- 1 | # set mp3d path 2 | # export PYTHONPATH=Matterport3DSimulator/build:$PYTHONPATH 3 | 4 | # set java path 5 | # export JAVA_HOME=$java_path 6 | # export PATH=$JAVA_HOME/bin:$PATH 7 | # export CLASSPATH=.:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jar 8 | 9 | # activate environment 10 | # conda activate navillm 11 | 12 | # training for 30 epochs 13 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 41000 train.py \ 14 | --stage multi --cfg_file configs/held_out/held_out_soon.yaml \ 15 | --data_dir data --pretrained_model_name_or_path data/models/Vicuna-7B --precision amp_bf16 \ 16 | --batch_size 1 --gradient_accumulation_step 8 --num_steps_per_epoch 2000 --lr 3e-5 --seed 0 --num_epochs 30 \ 17 | --enable_og --enable_summarize --enable_fgr2r \ 18 | --test_datasets CVDN SOON R2R REVERIE \ 19 | --max_saved_checkpoints 1 --output_dir output/held_out/held_out_soon -------------------------------------------------------------------------------- 
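The three held-out scripts above differ only in their config file and output directory. A sketch for launching them back to back (sequential execution is an assumption here, since each run already occupies all eight GPUs):

# runs held_out_cvdn.sh, held_out_reverie.sh and held_out_soon.sh in turn
for task in cvdn reverie soon; do
    bash scripts/held_out/held_out_${task}.sh
done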
/scripts/multi_w_pretrain.sh: -------------------------------------------------------------------------------- 1 | # set mp3d path 2 | # export PYTHONPATH=Matterport3DSimulator/build:$PYTHONPATH 3 | 4 | # set java path 5 | # export JAVA_HOME=$java_path 6 | # export PATH=$JAVA_HOME/bin:$PATH 7 | # export CLASSPATH=.:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jar 8 | 9 | # activate environment 10 | # conda activate navillm 11 | 12 | # training for 20 epochs; setting teacher_forcing_coef=1 has less variance. 13 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 41000 train.py \ 14 | --stage multi --cfg_file configs/multi.yaml \ 15 | --data_dir data --pretrained_model_name_or_path data/models/Vicuna-7B --precision amp_bf16 \ 16 | --resume_from_checkpoint output/pretrain/pretrain_39.pt \ 17 | --batch_size 1 --gradient_accumulation_step 8 --num_steps_per_epoch 2000 --lr 3e-5 --seed 0 --num_epochs 20 \ 18 | --teacher_forcing_coef 1 --enable_og --enable_summarize --enable_fgr2r \ 19 | --test_datasets CVDN SOON R2R REVERIE ScanQA \ 20 | --max_saved_checkpoints 1 --output_dir output/multi_w_pretrain \ -------------------------------------------------------------------------------- /scripts/multi_wo_pretrain.sh: -------------------------------------------------------------------------------- 1 | # set mp3d path 2 | # export PYTHONPATH=Matterport3DSimulator/build:$PYTHONPATH 3 | 4 | # set java path 5 | # export JAVA_HOME=$java_path 6 | # export PATH=$JAVA_HOME/bin:$PATH 7 | # export CLASSPATH=.:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jar 8 | 9 | # activate environment 10 | # conda activate navillm 11 | 12 | # training for 30 epochs 13 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 41000 train.py \ 14 | --stage multi --cfg_file configs/multi.yaml \ 15 | --data_dir data --pretrained_model_name_or_path data/models/Vicuna-7B --precision amp_bf16 \ 16 | --batch_size 1 --gradient_accumulation_step 8 --num_steps_per_epoch 2000 --lr 3e-5 --seed 0 --num_epochs 30 \ 17 | --enable_og --enable_summarize --enable_fgr2r \ 18 | --test_datasets CVDN SOON R2R REVERIE ScanQA \ 19 | --max_saved_checkpoints 1 --output_dir output/multi_wo_pretrain \ -------------------------------------------------------------------------------- /scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | # set mp3d path 2 | # export PYTHONPATH=Matterport3DSimulator/build:$PYTHONPATH 3 | 4 | # set java path 5 | # export JAVA_HOME=$java_path 6 | # export PATH=$JAVA_HOME/bin:$PATH 7 | # export CLASSPATH=.:$JAVA_HOME/lib/dt.jar:$JAVA_HOME/lib/tools.jar 8 | 9 | # activate environment 10 | # conda activate navillm 11 | 12 | # training for 40 epochs 13 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 43000 train.py \ 14 | --stage pretrain --cfg_file configs/multi.yaml \ 15 | --data_dir data --pretrained_model_name_or_path data/models/Vicuna-7B --precision amp_bf16 \ 16 | --batch_size 1 --gradient_accumulation_step 8 --num_steps_per_epoch 2000 --lr 3e-5 --seed 0 --num_epochs 40 \ 17 | --enable_og --enable_summarize --enable_fgr2r \ 18 | --max_saved_checkpoints 1 --output_dir output/pretrain \ -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaVi-Lab/NaviLLM/e069f46ed98affb221d58715a785613622e11145/tasks/__init__.py --------------------------------------------------------------------------------
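Taken together, the training scripts imply the following order: pretraining first, then multi-task finetuning that resumes from the pretrained checkpoint, then evaluation. A sketch of that pipeline -- the evaluation checkpoint name is a placeholder for whichever file the multi-task run actually saves:

bash scripts/pretrain.sh            # 40-epoch pretraining; multi_w_pretrain.sh expects the result at output/pretrain/pretrain_39.pt
bash scripts/multi_w_pretrain.sh    # multi-task finetuning, resumes from output/pretrain/pretrain_39.pt
model_path=output/multi_w_pretrain/<saved_checkpoint>.pt bash scripts/evaluation/eval_r2r.sh   # <saved_checkpoint> is a placeholder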
/tasks/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_agent import MetaAgent 2 | 3 | # import the agent class here 4 | from .r2r import R2RAgent, R2RAugAgent 5 | from .reverie import REVERIEAgent, REVERIEAugAgent 6 | from .cvdn import CVDNAgent 7 | from .soon import SOONAgent 8 | from .scanqa import ScanQAAgent 9 | from .llava import LLaVAAgent 10 | 11 | 12 | def load_agent(name, *args, **kwargs): 13 | cls = MetaAgent.registry[name] 14 | return cls(*args, **kwargs) -------------------------------------------------------------------------------- /tasks/agents/base_agent.py: -------------------------------------------------------------------------------- 1 | 2 | class MetaAgent(type): 3 | registry = {} 4 | 5 | def __init__(cls, name, bases, attrs): 6 | super().__init__(name, bases, attrs) 7 | if 'name' in attrs: 8 | MetaAgent.registry[attrs['name']] = cls 9 | 10 | class BaseAgent(metaclass=MetaAgent): 11 | def __init__(self, *args, **kwargs): 12 | pass 13 | 14 | def get_prompt(self, *args, **kwargs): 15 | raise NotImplementedError 16 | 17 | def prepare_prompts(self, *args, **kwargs): 18 | raise NotImplementedError 19 | 20 | def train(self, *args, **kwargs): 21 | raise NotImplementedError 22 | 23 | def validate(self, *args, **kwargs): 24 | raise NotImplementedError 25 | -------------------------------------------------------------------------------- /tasks/agents/cvdn.py: -------------------------------------------------------------------------------- 1 | from .mp3d_agent import MP3DAgent 2 | 3 | class CVDNAgent(MP3DAgent): 4 | name = "cvdn" 5 | 6 | def get_prompt(self, task, *args, **kwargs): 7 | if task == 'navigation': 8 | return self.get_navigation_prompt(*args, **kwargs) 9 | else: 10 | raise NotImplementedError 11 | 12 | def get_navigation_prompt(self, instruction, hist_num, cand_num, cls_token): 13 | # Task 14 | prompt = '### Instruction: Find the described room according the given dialog. Target: {} \n'.format(instruction) 15 | # History 16 | prompt += 'Following is the History, which contains the visual information of your previous decisions.\n' 17 | hist_text = ' '.join(['({}) '.format(i) for i in range(hist_num)]) 18 | prompt += '### History: {}\n'.format(hist_text) 19 | # Observation 20 | prompt += 'Following is the Candidate, which contains several directions you can go to at the current position, candidate (0) is stop.\n' 21 | obs_text = ' '.join(['({}) '.format(i) if i>0 else '(0) stop' for i in range(cand_num)]) 22 | prompt += '### Candidate: {}\n'.format(obs_text) 23 | # Output Hint 24 | prompt += 'Understand the dialog in the Instruction and infer the current progress based on the History and dialog. 
Then select the correct direction from the candidates to go to the target location.\n' 25 | prompt += '### Output: {}'.format(cls_token) 26 | 27 | return prompt 28 | -------------------------------------------------------------------------------- /tasks/agents/eqa.py: -------------------------------------------------------------------------------- 1 | from .mp3d_agent import MP3DAgent 2 | 3 | class EQAAgent(MP3DAgent): 4 | name = "eqa" 5 | 6 | def get_prompt(self, task, *args, **kwargs): 7 | if task == 'navigation': 8 | return self.get_navigation_prompt(*args, **kwargs) 9 | elif task == 'embodied_qa': 10 | return self.get_embodied_qa_prompt(*args, **kwargs) 11 | else: 12 | raise NotImplementedError 13 | 14 | def get_navigation_prompt(self, instruction, hist_num, cand_num, cls_token): 15 | 16 | # Task 17 | prompt = '### Instruction: Navigate following the instruction. Move to the object in "{}", and stop there. \n'.format(instruction.replace('?', '')) 18 | # History 19 | prompt += 'Following is the History, which contains the visual information of your previous decisions.\n' 20 | hist_text = ' '.join(['({}) '.format(i) for i in range(hist_num)]) 21 | prompt += '### History: {}\n'.format(hist_text) 22 | # Observation 23 | prompt += 'Following is the Candidate, which contains several directions you can go to at the current position, candidate (0) is stop.\n' 24 | obs_text = ' '.join(['({}) '.format(i) if i>0 else '(0) stop' for i in range(cand_num)]) 25 | prompt += '### Candidate: {}\n'.format(obs_text) 26 | # Output Hint 27 | prompt += 'Compare the History and Instruction to infer your current progress, and then select the correct direction from the candidates to go to the target location.\n' 28 | prompt += '### Output: {}'.format(cls_token) 29 | 30 | return prompt 31 | 32 | def get_embodied_qa_prompt(self, instruction, hist_num, cand_num): 33 | # Task 34 | prompt = f'### Instruction: Answer the question according to the scene.
\n' 35 | # History 36 | prompt += 'Following is the History, which contains the visual information of your previous decisions.\n' 37 | hist_text = ' '.join(['({}) '.format(i) for i in range(hist_num)]) 38 | prompt += '### History: {}\n'.format(hist_text) 39 | # Observation 40 | if cand_num != 0: 41 | prompt += 'Following is the Observation, which contains panoramic views at your current location.\n' 42 | obs_text = ' '.join(['({}) '.format(i) for i in range(cand_num)]) 43 | prompt += '### Candidate: {}\n'.format(obs_text) 44 | # Output Hint 45 | prompt += '### Question: {}\n'.format(instruction) 46 | prompt += '### Answer: ' 47 | 48 | return prompt 49 | -------------------------------------------------------------------------------- /tasks/agents/llava.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from .base_agent import BaseAgent 3 | 4 | class LLaVAAgent(BaseAgent): 5 | name = "llava" 6 | 7 | def get_prompt(self, task, *args, **kwargs): 8 | if task == "3dqa": 9 | return self.get_3dqa_prompt(*args, **kwargs) 10 | else: 11 | raise NotImplementedError 12 | 13 | def get_3dqa_prompt(self, ques, cand_num): 14 | prompt = "### Image: \n" + \ 15 | "### Instruction: {}\n".format(ques) + \ 16 | "### Output: " 17 | return prompt 18 | 19 | def train( 20 | self, 21 | name, 22 | batch, 23 | args, 24 | config, 25 | model, 26 | **kwargs 27 | ): 28 | assert name in ["ScanQA", "LLaVA"], 'The task name must be in [ScanQA, LLaVA]' 29 | dataset_cfg = config.Pretrain if args.stage=='pretrain' else config.Multi 30 | loss_coef = dataset_cfg.LOSS_COEF.get(name, 1.) 31 | # construct prompt 32 | prompts = [] 33 | batch_size = len(batch["question"]) 34 | # update prompts 35 | batch["prompts"] = self.prepare_prompts(batch) 36 | 37 | # forward the model 38 | lm_loss = model("3dqa", batch).loss 39 | lm_loss *= loss_coef / args.gradient_accumulation_step 40 | lm_loss.backward() 41 | 42 | return lm_loss * args.gradient_accumulation_step 43 | 44 | 45 | def validate( 46 | self, 47 | name, 48 | args, 49 | config, 50 | model, 51 | loader, 52 | **kwargs, 53 | ): 54 | assert name in ["ScanQA"] 55 | preds = [] 56 | pbar = tqdm(loader, disable=args.rank!=0) 57 | for i, batch in enumerate(pbar): 58 | generation_kwargs = { 59 | "do_sample": args.do_sample, 60 | "temperature": args.temperature, 61 | "max_new_tokens": 20 62 | } 63 | batch["prompts"] = self.prepare_prompts(batch) 64 | outputs = model("3dqa", batch, training=False, **generation_kwargs) 65 | generated_sentences = outputs["generated_sentences"] 66 | for i in range(len(batch["question"])): 67 | preds.append({ 68 | "scene_id": batch["scene_id"][i], 69 | "question_id": batch["question_id"][i], 70 | "generated_sentences": [generated_sentences[i].lower().strip()] 71 | }) 72 | 73 | return preds 74 | 75 | def prepare_prompts(self, batch): 76 | prompts = [] 77 | for bn in range(len(batch["question"])): 78 | prompts.append( 79 | self.get_prompt( 80 | '3dqa', 81 | ques = batch["question"][bn], 82 | cand_num = batch["features"][bn].shape[0] 83 | ) 84 | ) 85 | return prompts -------------------------------------------------------------------------------- /tasks/agents/r2r.py: -------------------------------------------------------------------------------- 1 | from .mp3d_agent import MP3DAgent 2 | 3 | class R2RAgent(MP3DAgent): 4 | name = "r2r" 5 | 6 | def get_prompt(self, task, *args, **kwargs): 7 | if task == 'navigation': 8 | return self.get_navigation_prompt(*args, **kwargs) 9 | elif task == 'summarization': 10 
| return self.get_summarization_prompt(*args, **kwargs) 11 | elif task == 'embodied_qa': 12 | return self.get_embodied_qa_prompt(*args, **kwargs) 13 | else: 14 | raise NotImplementedError 15 | 16 | def get_navigation_prompt(self, instruction, hist_num, cand_num, cls_token): 17 | # Task 18 | prompt = '### Instruction: Navigate following the instruction. {} \n'.format(instruction) 19 | # History 20 | prompt += 'Following is the History, which contains the visual information of your previous decisions.\n' 21 | hist_text = ' '.join(['({}) '.format(i) for i in range(hist_num)]) 22 | prompt += '### History: {}\n'.format(hist_text) 23 | # Observation 24 | prompt += 'Following is the Candidate, which contains several directions you can go to at the current position, candidate (0) is stop.\n' 25 | obs_text = ' '.join(['({}) '.format(i) if i>0 else '(0) stop' for i in range(cand_num)]) 26 | prompt += '### Candidate: {}\n'.format(obs_text) 27 | # Output Hint 28 | prompt += 'Compare the History and Instruction to infer your current progress, and then select the correct direction from the candidates to go to the target location.\n' 29 | prompt += '### Output: {}'.format(cls_token) 30 | 31 | return prompt 32 | 33 | def get_summarization_prompt(self, instruction, hist_num, cand_num): 34 | # Task 35 | prompt = f'### Instruction: Predict the fine-grained instruction based on your previous history and current location. Fine-grained instructions contain commands for each individual step. \n' 36 | # History 37 | prompt += 'Following is the History, which contains the visual information of your previous decisions.\n' 38 | hist_text = ' '.join(['({}) '.format(i) for i in range(hist_num)]) 39 | prompt += '### History: {}\n'.format(hist_text) 40 | # Observation 41 | if cand_num != 0: 42 | prompt += 'Following is the Observation, which contains panoramic views at your current location.\n' 43 | obs_text = ' '.join(['({}) '.format(i) for i in range(cand_num)]) 44 | prompt += '### Candidate: {}\n'.format(obs_text) 45 | # Output Hint 46 | prompt += 'Please generate the step-by-step instruction.\n' 47 | prompt += '### Answer: ' 48 | 49 | return prompt 50 | 51 | def get_embodied_qa_prompt(self, instruction, hist_num, cand_num): 52 | # Task 53 | prompt = f'### Instruction: answer the question. 
\n' 54 | # History 55 | if hist_num !=0: 56 | prompt += 'Following is the History, which contains the visual information of your previous decisions.\n' 57 | hist_text = ' '.join(['({}) '.format(i) for i in range(hist_num)]) 58 | prompt += '### History: {}\n'.format(hist_text) 59 | # Observation 60 | if cand_num != 0: 61 | prompt += 'Following is the Observation, which contains panoramic views at your current location.\n' 62 | obs_text = ' '.join(['({}) '.format(i) for i in range(cand_num)]) 63 | prompt += '### Candidate: {}\n'.format(obs_text) 64 | # Output Hint 65 | prompt += '### Question: {}\n'.format(instruction) 66 | prompt += '### Answer: ' 67 | 68 | return prompt 69 | 70 | 71 | class R2RAugAgent(R2RAgent): 72 | name = "r2r_aug" -------------------------------------------------------------------------------- /tasks/agents/reverie.py: -------------------------------------------------------------------------------- 1 | from .mp3d_agent import MP3DAgent 2 | 3 | class REVERIEAgent(MP3DAgent): 4 | name = "reverie" 5 | 6 | def get_prompt(self, task, *args, **kwargs): 7 | if task == 'navigation': 8 | return self.get_navigation_prompt(*args, **kwargs) 9 | elif task == 'summarization': 10 | return self.get_summarization_prompt(*args, **kwargs) 11 | elif task == 'object_grounding': 12 | return self.get_object_grounding_prompt(*args, **kwargs) 13 | else: 14 | raise NotImplementedError 15 | 16 | def get_navigation_prompt(self, instruction, hist_num, cand_num, cls_token): 17 | # Task 18 | prompt = '### Instruction: Go to the location to complete the given task. Task: {} \n'.format(instruction) 19 | # History 20 | prompt += 'Following is the History, which contains the visual information of your previous decisions.\n' 21 | hist_text = ' '.join(['({}) '.format(i) for i in range(hist_num)]) 22 | prompt += '### History: {}\n'.format(hist_text) 23 | # Observation 24 | prompt += 'Following is the Candidate, which contains several directions you can go to at the current position, candidate (0) is stop.\n' 25 | obs_text = ' '.join(['({}) '.format(i) if i>0 else '(0) stop' for i in range(cand_num)]) 26 | prompt += '### Candidate: {}\n'.format(obs_text) 27 | # Output Hint 28 | prompt += 'Explore the scene to find out the targeted room and object. Then select the correct direction from the candidates to go to the target location.\n' 29 | prompt += '### Output: {}'.format(cls_token) 30 | 31 | return prompt 32 | 33 | def get_summarization_prompt(self, instruction, hist_num, cand_num): 34 | # Task 35 | prompt = '### Instruction: Generate the task you need to complete based on your previous history and current location. 
\n' 36 | # History 37 | prompt += 'Following is the History, which contains the visual information of your previous decisions.\n' 38 | hist_text = ' '.join(['({}) '.format(i) for i in range(hist_num)]) 39 | prompt += '### History: {}\n'.format(hist_text) 40 | # Observation 41 | if cand_num != 0: 42 | prompt += 'Following is the Observation, which contains panoramic views at your current location.\n' 43 | obs_text = ' '.join(['({}) '.format(i) for i in range(cand_num)]) 44 | prompt += '### Candidate: {}\n'.format(obs_text) 45 | # Output Hint 46 | prompt += 'Please predict the task you need to complete.\n' 47 | prompt += '### Answer: ' 48 | return prompt 49 | 50 | def get_object_grounding_prompt(self, instruction, hist_num, cand_num, cls_token): 51 | 52 | # Task 53 | prompt = "Select the target object from the candidate objects based on the instruction and history.\n" 54 | prompt += '### Instruction: Go to the location to complete the given task. Task: {} \n'.format(instruction) 55 | 56 | # History 57 | prompt += 'Following is the History, which contains the visual information of your previous decisions.\n' 58 | hist_text = ' '.join(['({}) '.format(i) for i in range(hist_num)]) 59 | prompt += '### History: {}\n'.format(hist_text) 60 | 61 | # Observation 62 | prompt += 'Following is the Object, which contains several objects that you could see at the current viewpoint, option (0) indicates not exist.\n' 63 | cand_text = ' '.join(['({}) '.format(i) if i>0 else '(0) not exist' for i in range(cand_num)]) 64 | prompt += '### Object: {}\n'.format(cand_text) 65 | 66 | # Output Hint 67 | prompt += "Select the target object from the candidate objects according to the instruction.\n" 68 | prompt += '### Output: {}'.format(cls_token) 69 | 70 | return prompt 71 | 72 | class REVERIEAugAgent(REVERIEAgent): 73 | name = "reverie_aug" -------------------------------------------------------------------------------- /tasks/agents/scanqa.py: -------------------------------------------------------------------------------- 1 | from .llava import LLaVAAgent 2 | 3 | 4 | class ScanQAAgent(LLaVAAgent): 5 | name = "scanqa" 6 | 7 | def get_prompt(self, task, *args, **kwargs): 8 | if task == '3dqa': 9 | return self.get_3dqa_prompt(*args, **kwargs) 10 | else: 11 | raise NotImplementedError 12 | 13 | def get_3dqa_prompt(self, ques, cand_num): 14 | obs_text = ' '.join(["({}) ".format(i) for i in range(cand_num)]) 15 | prompt = "Please answer questions based on the observation.\n" + \ 16 | "The following is the Observation, which includes multiple images from different locations.\n" + \ 17 | "### Observation: {} \n".format(obs_text) + \ 18 | "### Question: {}\n".format(ques) + \ 19 | "### Answer: " 20 | return prompt -------------------------------------------------------------------------------- /tasks/agents/soon.py: -------------------------------------------------------------------------------- 1 | from .mp3d_agent import MP3DAgent 2 | 3 | class SOONAgent(MP3DAgent): 4 | name = "soon" 5 | 6 | def get_prompt(self, task, *args, **kwargs): 7 | if task == 'navigation': 8 | return self.get_navigation_prompt(*args, **kwargs) 9 | elif task == 'summarization': 10 | return self.get_summarization_prompt(*args, **kwargs) 11 | elif task == 'object_grounding': 12 | return self.get_object_grounding_prompt(*args, **kwargs) 13 | else: 14 | raise NotImplementedError 15 | 16 | def get_navigation_prompt(self, instruction, hist_num, cand_num, cls_token): 17 | 18 | # Task 19 | prompt = '### Instruction: Find the described target. 
Target: {} \n'.format(instruction) 20 | # History 21 | prompt += 'Following is the History, which contains the visual information of your previous decisions.\n' 22 | hist_text = ' '.join(['({}) '.format(i) for i in range(hist_num)]) 23 | prompt += '### History: {}\n'.format(hist_text) 24 | # Observation 25 | prompt += 'Following is the Candidate, which contains several directions you can go to at the current position, candidate (0) is stop.\n' 26 | obs_text = ' '.join(['({}) '.format(i) if i>0 else '(0) stop' for i in range(cand_num)]) 27 | prompt += '### Candidate: {}\n'.format(obs_text) 28 | # Output Hint 29 | prompt += 'Nearby areas and objects can assist you in locating the desired room and object. Select the correct direction from the candidates to go to the target location.\n' 30 | prompt += '### Output: {}'.format(cls_token) 31 | 32 | return prompt 33 | 34 | def get_summarization_prompt(self, instruction, hist_num, cand_num): 35 | 36 | # Task 37 | prompt = '### Instruction: Generate the target you want to find based on your previous history and current location. Describe both the target and its surroundings. \n' 38 | # History 39 | prompt += 'Following is the History, which contains the visual information of your previous decisions.\n' 40 | hist_text = ' '.join(['({}) '.format(i) for i in range(hist_num)]) 41 | prompt += '### History: {}\n'.format(hist_text) 42 | # Observation 43 | if cand_num != 0: 44 | prompt += 'Following is the Observation, which contains panoramic views at your current location.\n' 45 | obs_text = ' '.join(['({}) '.format(i) for i in range(cand_num)]) 46 | prompt += '### Candidate: {}\n'.format(obs_text) 47 | # Output Hint 48 | prompt += 'Please predict both the target you want to find and its surroundings.\n' 49 | prompt += '### Answer: ' 50 | 51 | return prompt 52 | 53 | def get_object_grounding_prompt(self, instruction, hist_num, cand_num, cls_token): 54 | 55 | # Task 56 | prompt = "Select the target object from the candidate objects based on the instruction and history.\n" 57 | prompt += '### Instruction: Find the described target. 
Target: {} \n'.format(instruction) 58 | 59 | # History 60 | prompt += 'Following is the History, which contains the visual information of your previous decisions.\n' 61 | hist_text = ' '.join(['({}) '.format(i) for i in range(hist_num)]) 62 | prompt += '### History: {}\n'.format(hist_text) 63 | 64 | # Observation 65 | prompt += 'Following is the Object, which contains several objects that you could see at the current viewpoint, option (0) indicates not exist.\n' 66 | cand_text = ' '.join(['({}) '.format(i) if i>0 else '(0) not exist' for i in range(cand_num)]) 67 | prompt += '### Object: {}\n'.format(cand_text) 68 | 69 | # Output Hint 70 | prompt += "Select the target object from the candidate objects according to the instruction.\n" 71 | prompt += '### Output: {}'.format(cls_token) 72 | 73 | return prompt 74 | 75 | -------------------------------------------------------------------------------- /tasks/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import MetaDataset 2 | 3 | # import the dataset class here 4 | from .r2r import R2RDataset 5 | from .cvdn import CVDNDataset 6 | from .soon import SOONDataset 7 | from .eqa import EQADataset 8 | from .reverie import REVERIEDataset 9 | from .r2r_aug import R2RAugDataset 10 | from .reverie_aug import REVERIEAugDataset 11 | from .llava import LLaVADataset 12 | from .scanqa import ScanQADataset 13 | 14 | def load_dataset(name, *args, **kwargs): 15 | cls = MetaDataset.registry[name] 16 | return cls(*args, **kwargs) -------------------------------------------------------------------------------- /tasks/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class MetaDataset(type): 4 | registry = {} 5 | 6 | def __init__(cls, name, bases, attrs): 7 | super().__init__(name, bases, attrs) 8 | if 'name' in attrs: 9 | MetaDataset.registry[attrs['name']] = cls 10 | 11 | class BaseDataset(torch.utils.data.Dataset, metaclass=MetaDataset): 12 | pass 13 | -------------------------------------------------------------------------------- /tasks/datasets/coco_caption.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | from PIL import Image 5 | from typing import Dict, Any, List 6 | from .llava import LLaVADataset 7 | 8 | 9 | class COCOCaptionDataset(LLaVADataset): 10 | name = "coco" 11 | 12 | def _load_data(self, config: Dict, data_dir: str): 13 | path = os.path.join(data_dir, config.coco_caption.DIR, config.coco_caption.SPLIT[self.split]) 14 | 15 | self.alldata = [] 16 | with open(path) as f: 17 | data = json.load(f) 18 | 19 | for item in data: 20 | if self.training: 21 | for sent in item["sentences"]: 22 | self.alldata.append({ 23 | "sentid": sent["sentid"], 24 | "image": item["filename"].split("_")[-1], 25 | "input": "What is the caption of this image?", 26 | "label": sent["raw"]+"", 27 | }) 28 | else: 29 | self.alldata.append({ 30 | "imgid": item["imgid"], 31 | "image": item["filename"].split("_")[-1], 32 | "input": "What is the caption of this image?", 33 | "refs": [sent["raw"] for sent in item["sentences"]] 34 | }) 35 | 36 | if self.max_datapoints: 37 | self.alldata = self.alldata[:self.max_datapoints] 38 | self.logger.info(f"There are totally {len(self.alldata)} datapoints loaded.") 39 | 40 | def __getitem__(self, index:int) -> Dict[str, Any]: 41 | item = copy.deepcopy(self.alldata[index]) 42 | 43 | # load image 44 | 
image_path = os.path.join(self.config.coco_caption.IMAGE_DIR, 'train2017', item["image"]) 45 | if not os.path.exists(image_path): 46 | image_path = os.path.join(self.config.coco_caption.IMAGE_DIR, 'val2017', item["image"]) 47 | 48 | image = Image.open(image_path).convert('RGB') 49 | 50 | if self.training: 51 | data_dict = { 52 | "sentid": item["sentid"], 53 | "image": image, 54 | "input": item["input"], 55 | "label": item["label"] 56 | } 57 | else: 58 | data_dict = { 59 | "imgid": item["imgid"], 60 | "image": image, 61 | "input": item["input"], 62 | "refs": item["refs"] 63 | } 64 | 65 | return data_dict 66 | 67 | def eval_metrics(self, preds: List[Dict[str, Any]], logger, name: str) -> Dict[str, float]: 68 | refs = {} 69 | for item in self.alldata: 70 | refs[item["imgid"]] = item["refs"] 71 | 72 | gen = {item['imgid']:item['outputs'] for item in preds} 73 | 74 | from tools.evaluation.bleu import Bleu 75 | bleu_score = Bleu() 76 | 77 | score, scores = bleu_score.compute_score(refs, gen) 78 | 79 | ret = {} 80 | for i, s in enumerate(score): 81 | ret[f"bleu-{i+1}"] = s 82 | return ret -------------------------------------------------------------------------------- /tasks/datasets/eqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | from pathlib import Path 4 | import torch.utils.data as torch_data 5 | import torch 6 | import math 7 | from collections import defaultdict 8 | import numpy as np 9 | import networkx as nx 10 | from .mp3d_envs import ( 11 | EnvBatch, new_simulator, angle_feature, 12 | get_all_point_angle_feature, load_nav_graphs, 13 | ) 14 | ERROR_MARGIN = 3.0 15 | from tools.evaluation.bleu import Bleu 16 | from tools.evaluation.rouge import Rouge 17 | from tools.evaluation.cider import Cider 18 | from .mp3d_dataset import MP3DDataset, get_anno_file_path 19 | 20 | 21 | class EQADataset(MP3DDataset): 22 | name = "eqa" 23 | 24 | def __init__( 25 | self, 26 | args, 27 | config, 28 | training=False, 29 | logger=None, 30 | source=None, 31 | ): 32 | super().__init__(args, config, training, logger, source) 33 | # answer_vocab 34 | filename = get_anno_file_path(args.data_dir, config.EQA.DIR, config.EQA.ANSWER_VOCAB) 35 | with open(filename) as f: 36 | self.answer_vocab = json.load(f) 37 | 38 | 39 | def init_feat_db(self, feat_db, obj_feat_db=None): 40 | self.feat_db = feat_db 41 | self.obj_feat_db = obj_feat_db 42 | 43 | 44 | def load_data(self, anno_file, max_instr_len=200, split='train', debug=False): 45 | """ 46 | :param anno_file: 47 | :param max_instr_len: 48 | :param debug: 49 | :return: 50 | """ 51 | with open(str(anno_file), "r") as f: 52 | data = json.load(f) 53 | new_data = [] 54 | 55 | for i, item in enumerate(data): 56 | new_item = dict(item) 57 | new_item['raw_idx'] = item['sample_idx'] 58 | new_item['instr_id'] = 'eqa_{}_{}'.format(item['sample_idx'], i) 59 | new_item['path_id'] = item['sample_idx'] 60 | new_item['data_type'] = 'eqa' 61 | new_item['heading'] = 0.0 62 | new_data.append(new_item) 63 | 64 | if debug: 65 | new_data = new_data[:20] 66 | 67 | gt_trajs = { 68 | x['instr_id']: (x['scan'], x['path']) \ 69 | for x in new_data if len(x['path']) > 1 70 | } 71 | return new_data, gt_trajs 72 | 73 | 74 | def get_obs(self, items, env, data_type=None): 75 | obs = [] 76 | 77 | for i, (feature, state) in enumerate(env.getStates()): 78 | item = items[i] 79 | base_view_id = state.viewIndex 80 | 81 | if feature is None: 82 | feature = self.feat_db.get_image_feature(state.scanId, state.location.viewpointId) 83 
| 84 | # Full features 85 | candidate = self.make_candidate(feature, state.scanId, state.location.viewpointId, state.viewIndex) 86 | 87 | # [visual_feature, angle_feature] for views 88 | feature = np.concatenate((feature, self.angle_feature[base_view_id]), -1) 89 | 90 | ob = { 91 | 'instr_id': item['instr_id'], 92 | 'scan': state.scanId, 93 | 'viewpoint': state.location.viewpointId, 94 | 'viewIndex': state.viewIndex, 95 | 'position': (state.location.x, state.location.y, state.location.z), 96 | 'heading': state.heading, 97 | 'elevation': state.elevation, 98 | 'feature': feature, 99 | 'candidate': candidate, 100 | 'navigableLocations': state.navigableLocations, 101 | 'instruction': item['question']['question_text'], 102 | 'answer': item['question']['answer_text'], 103 | # 'instr_encoding': item['instr_encoding'], 104 | 'gt_path': item['path'], 105 | 'path_id': item['path_id'], 106 | } 107 | if False: # ob['instr_id'] in self.gt_trajs: 108 | ob['distance'] = self.shortest_distances[ob['scan']][ob['viewpoint']][item['path'][-1]] 109 | else: 110 | ob['distance'] = 0 111 | obs.append(ob) 112 | return obs 113 | 114 | ########################### Evalidation ########################### 115 | def eval_metrics(self, preds, logger, name): 116 | """ 117 | Evaluate each agent trajectory based on how close it got to the goal location 118 | the path contains [view_id, angle, vofv] 119 | :param preds: 120 | :param logger: 121 | :param name: 122 | :return: 123 | """ 124 | logger.info('eval %d predictions' % (len(preds))) 125 | metrics = defaultdict(list) 126 | all_pred_ans = {} 127 | all_gt_ans = {} 128 | for item in preds: 129 | instr_id = item['instr_id'] 130 | traj = item['trajectory'] 131 | pred_ans = item['pred_answer'] 132 | gt_ans = item['gt_answer'] 133 | all_pred_ans[instr_id] = pred_ans 134 | all_gt_ans[instr_id] = [gt_ans] 135 | 136 | if instr_id not in self.gt_trajs.keys(): 137 | print("instr_id {} not in self.gt_trajs".format(instr_id)) 138 | raise NotImplementedError 139 | 140 | scan, gt_traj = self.gt_trajs[instr_id] 141 | traj_scores = self.eval_dis_item(scan, traj, gt_traj) 142 | 143 | for k, v in traj_scores.items(): 144 | metrics[k].append(v) 145 | metrics['instr_id'].append(instr_id) 146 | 147 | avg_metrics = { 148 | 'action_steps': np.mean(metrics['action_steps']), 149 | 'steps': np.mean(metrics['trajectory_steps']), 150 | 'lengths': np.mean(metrics['trajectory_lengths']), 151 | 'nav_error': np.mean(metrics['nav_error']), 152 | 'oracle_error': np.mean(metrics['oracle_error']), 153 | 'sr': np.mean(metrics['success']) * 100, 154 | 'oracle_sr': np.mean(metrics['oracle_success']) * 100, 155 | 'spl': np.mean(metrics['spl']) * 100, 156 | } 157 | 158 | # bleu_score = Bleu() 159 | # score, scores = bleu_score.compute_score(all_gt_ans, all_pred_ans) 160 | # for i, s in enumerate(score): 161 | # avg_metrics[f"bleu-{i+1}"] = s * 100 162 | 163 | # rouge_score = Rouge() 164 | # score, compute_score = rouge_score.compute_score(all_gt_ans, all_pred_ans) 165 | # avg_metrics["rouge"] = score * 100 166 | 167 | # cider_score = Cider() 168 | # score, compute_score = cider_score.compute_score(all_gt_ans, all_pred_ans) 169 | # avg_metrics["cider"] = score * 100 170 | n_correct = 0 171 | for pred in preds: 172 | if pred['pred_answer'] in all_gt_ans[pred["instr_id"]]: 173 | n_correct += 1 174 | avg_metrics["exact_match"] = n_correct / len(preds) * 100 175 | 176 | n_oracle_correct = 0 177 | for pred in preds: 178 | if pred['oracle_pred_answer'] in all_gt_ans[pred['instr_id']]: 179 | n_oracle_correct += 1 180 | 
avg_metrics["oracle_exact_match"] = n_oracle_correct / len(preds) * 100 181 | 182 | return avg_metrics, metrics 183 | 184 | def eval_dis_item(self, scan, pred_path, gt_path): 185 | scores = {} 186 | 187 | shortest_distances = self.shortest_distances[scan] 188 | 189 | path = sum(pred_path, []) 190 | assert gt_path[0] == path[0], 'Result trajectories should include the start position' 191 | 192 | nearest_position = self.get_nearest(shortest_distances, gt_path[-1], path) 193 | 194 | scores['nav_error'] = shortest_distances[path[-1]][gt_path[-1]] 195 | scores['oracle_error'] = shortest_distances[nearest_position][gt_path[-1]] 196 | 197 | scores['action_steps'] = len(pred_path) - 1 198 | scores['trajectory_steps'] = len(path) - 1 199 | scores['trajectory_lengths'] = np.sum([shortest_distances[a][b] for a, b in zip(path[:-1], path[1:])]) 200 | 201 | gt_lengths = np.sum([shortest_distances[a][b] for a, b in zip(gt_path[:-1], gt_path[1:])]) 202 | 203 | scores['success'] = float(scores['nav_error'] < ERROR_MARGIN) 204 | scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01) 205 | scores['oracle_success'] = float(scores['oracle_error'] < ERROR_MARGIN) 206 | 207 | return scores 208 | 209 | def save_json(self, results, path, item_metrics=None): 210 | if item_metrics is not None: 211 | for k in item_metrics: 212 | for item, v in zip(results, item_metrics[k]): 213 | item[k] = v 214 | 215 | with open(path, 'w') as fout: 216 | json.dump(results, fout) -------------------------------------------------------------------------------- /tasks/datasets/llava.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import torch 5 | import logging 6 | import numpy as np 7 | from collections import defaultdict 8 | from PIL import Image 9 | from typing import List, Dict, Any, Union 10 | from .base_dataset import BaseDataset 11 | 12 | 13 | class LLaVADataset(BaseDataset): 14 | name = 'llava' 15 | 16 | def __init__( 17 | self, 18 | args, 19 | config: Dict, 20 | training: bool=False, 21 | logger: logging.Logger=None, 22 | source: str=None, 23 | ): 24 | super().__init__() 25 | self.config = config 26 | self.training = training 27 | self.logger = logger 28 | self.source = source 29 | 30 | if training: 31 | self.split = 'train' 32 | else: 33 | self.split = args.validation_split 34 | 35 | self.batch_size = args.batch_size 36 | self.feat_db = None 37 | self.obj_feat_db = None 38 | self.max_datapoints = args.max_datapoints 39 | 40 | self._load_data(config, args.data_dir) 41 | 42 | 43 | def init_feat_db(self, feat_db, obj_feat_db=None): 44 | self.feat_db = feat_db 45 | self.obj_feat_db = obj_feat_db 46 | 47 | 48 | def _load_data(self, config: Dict, data_dir: str): 49 | 50 | if self.source == "LLaVA": 51 | role_mapping = { 52 | "human": "USER", 53 | "gpt": "ASSISTANT", 54 | } 55 | seps = [" ", ""] 56 | 57 | path = os.path.join(data_dir, config.LLaVA.DIR, config.LLaVA.SPLIT[self.split]) 58 | with open(path) as f: 59 | data = json.load(f) 60 | self.alldata = [] 61 | 62 | for item in data: 63 | image = item["image"] 64 | conversations = item["conversations"] 65 | prompt = "" 66 | assert len(conversations)==2, "The round of conversation must be 2." 67 | for i in range(0, len(conversations)-1, 2): 68 | assert conversations[i]["from"]=="human", f"The {i}-th utterance must come from human!" 69 | assert conversations[i+1]["from"]=="gpt", f"The {i+1}-th utterance must come from agent!" 
70 | self.alldata.append({ 71 | "id": item["id"], 72 | "turn_id": i//2, 73 | "image_id": item["image"].split(".")[0], 74 | # "input": role_mapping[conversations[i]["from"]] + ": " + conversations[i]["value"] + seps[0] + role_mapping[conversations[i+1]["from"]] + ": ", 75 | # "label": conversations[i+1]["value"] + seps[1] # indicates the end of generation. 76 | "question": conversations[i]["value"].replace("", "").strip(), 77 | "answers": [conversations[i+1]["value"]] 78 | }) 79 | 80 | # prompt += role_mapping[conversations[i]["from"]] + ": " + conversations[i]["value"] + seps[0] + role_mapping[conversations[i+1]["from"]] + ": " + conversations[i+1]["value"] + seps[1] 81 | 82 | if self.max_datapoints: 83 | self.alldata = self.alldata[:self.max_datapoints] 84 | self.logger.info(f"There are totally {len(self.alldata)} datapoints loaded.") 85 | 86 | else: 87 | raise NotImplementedError 88 | 89 | 90 | def __len__(self) -> int: 91 | return len(self.alldata) 92 | 93 | 94 | def __getitem__(self, index:int) -> Dict[str, Any]: 95 | item = copy.deepcopy(self.alldata[index]) 96 | 97 | # load image 98 | features = self.feat_db.get_image_feature(item["image_id"]) 99 | features = torch.from_numpy(np.stack(features)).unsqueeze(0) 100 | 101 | data_dict = { 102 | "id": item["id"], 103 | "image_id": item["image_id"], 104 | "question": item["question"], 105 | "answers": item["answers"], 106 | "data_type": "llava", 107 | "features": features, 108 | } 109 | 110 | return data_dict 111 | 112 | 113 | @staticmethod 114 | def collate_batch(batch_list: List[Dict], _unused: bool=False) -> Dict[str, Union[List[Any], torch.Tensor]]: 115 | data_dict = defaultdict(list) 116 | for cur_sample in batch_list: 117 | for key, val in cur_sample.items(): 118 | data_dict[key].append(val) 119 | batch_size = len(batch_list) 120 | ret = {} 121 | for key, val in data_dict.items(): 122 | try: 123 | if key in ['NotImplemented']: 124 | ret[key] = torch.stack(val, 0) 125 | else: 126 | ret[key] = val 127 | except: 128 | print('Error in collate_batch: key=%s' % key) 129 | raise TypeError 130 | 131 | ret['batch_size'] = batch_size 132 | return ret 133 | -------------------------------------------------------------------------------- /tasks/datasets/mp3d_envs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import networkx as nx 4 | import json 5 | import os 6 | import math 7 | import msgpack 8 | import msgpack_numpy 9 | msgpack_numpy.patch() 10 | try: 11 | import MatterSim 12 | except Exception as e: 13 | print(e) 14 | raise NotImplementedError 15 | 16 | 17 | def new_simulator(connectivity_dir): 18 | # Simulator image parameters 19 | WIDTH = 640 20 | HEIGHT = 480 21 | VFOV = 60 22 | 23 | sim = MatterSim.Simulator() 24 | sim.setNavGraphPath(connectivity_dir) 25 | sim.setRenderingEnabled(False) 26 | sim.setCameraResolution(WIDTH, HEIGHT) 27 | sim.setCameraVFOV(math.radians(VFOV)) 28 | sim.setDiscretizedViewingAngles(True) 29 | sim.setBatchSize(1) 30 | sim.initialize() 31 | 32 | return sim 33 | 34 | 35 | def angle_feature(heading, elevation, angle_feat_size): 36 | return np.array( 37 | [math.sin(heading), math.cos(heading), 38 | math.sin(elevation), math.cos(elevation)] * (angle_feat_size // 4), 39 | dtype=np.float32) 40 | 41 | 42 | def get_point_angle_feature(sim, angle_feat_size, baseViewId=0): 43 | feature = np.empty((36, angle_feat_size), np.float32) 44 | base_heading = (baseViewId % 12) * math.radians(30) 45 | base_elevation = (baseViewId // 12 - 1) * 
math.radians(30) 46 | 47 | for ix in range(36): 48 | if ix == 0: 49 | sim.newEpisode(['ZMojNkEp431'], ['2f4d90acd4024c269fb0efe49a8ac540'], [0], [math.radians(-30)]) 50 | elif ix % 12 == 0: 51 | sim.makeAction([0], [1.0], [1.0]) 52 | else: 53 | sim.makeAction([0], [1.0], [0]) 54 | 55 | state = sim.getState()[0] 56 | assert state.viewIndex == ix 57 | 58 | heading = state.heading - base_heading 59 | elevation = state.elevation - base_elevation 60 | 61 | feature[ix, :] = angle_feature(heading, elevation, angle_feat_size) 62 | return feature 63 | 64 | 65 | def get_all_point_angle_feature(sim, angle_feat_size): 66 | return [get_point_angle_feature(sim, angle_feat_size, baseViewId) for baseViewId in range(36)] 67 | 68 | 69 | def load_nav_graphs(connectivity_dir, scans): 70 | ''' Load connectivity graph for each scan ''' 71 | 72 | def distance(pose1, pose2): 73 | ''' Euclidean distance between two graph poses ''' 74 | return ((pose1['pose'][3] - pose2['pose'][3]) ** 2 \ 75 | + (pose1['pose'][7] - pose2['pose'][7]) ** 2 \ 76 | + (pose1['pose'][11] - pose2['pose'][11]) ** 2) ** 0.5 77 | 78 | graphs = {} 79 | for scan in scans: 80 | with open(os.path.join(connectivity_dir, '%s_connectivity.json' % scan)) as f: 81 | G = nx.Graph() 82 | positions = {} 83 | data = json.load(f) 84 | for i, item in enumerate(data): 85 | if item['included']: 86 | for j, conn in enumerate(item['unobstructed']): 87 | if conn and data[j]['included']: 88 | positions[item['image_id']] = np.array([item['pose'][3], 89 | item['pose'][7], item['pose'][11]]); 90 | assert data[j]['unobstructed'][i], 'Graph should be undirected' 91 | G.add_edge(item['image_id'], data[j]['image_id'], weight=distance(item, data[j])) 92 | nx.set_node_attributes(G, values=positions, name='position') 93 | graphs[scan] = G 94 | return graphs 95 | 96 | 97 | def normalize_angle(x): 98 | '''convert radians into (-pi, pi]''' 99 | pi2 = 2 * math.pi 100 | x = x % pi2 # [0, 2pi] 101 | if x > math.pi: 102 | x = x - pi2 103 | return x 104 | 105 | 106 | def convert_heading(x): 107 | return x % (2 * math.pi) / (2 * math.pi) # [0, 2pi] -> [0, 1) 108 | 109 | 110 | def convert_elevation(x): 111 | return (normalize_angle(x) + math.pi) / (2 * math.pi) # [0, 2pi] -> [0, 1) 112 | 113 | 114 | class EnvBatch(object): 115 | def __init__(self, connectivity_dir, feat_db=None, batch_size=1): 116 | self.feat_db = feat_db 117 | self.image_w = 640 118 | self.image_h = 480 119 | self.vfov = 60 120 | self.sims = [] 121 | for i in range(batch_size): 122 | sim = MatterSim.Simulator() 123 | sim.setNavGraphPath(connectivity_dir) 124 | sim.setRenderingEnabled(False) 125 | sim.setDiscretizedViewingAngles(True) # Set increment/decrement to 30 degree. (otherwise by radians) 126 | sim.setCameraResolution(self.image_w, self.image_h) 127 | sim.setCameraVFOV(math.radians(self.vfov)) 128 | sim.setBatchSize(1) 129 | sim.initialize() 130 | self.sims.append(sim) 131 | 132 | def newEpisodes(self, scanIds, viewpointIds, headings): 133 | for i, (scanId, viewpointId, heading) in enumerate(zip(scanIds, viewpointIds, headings)): 134 | self.sims[i].newEpisode([scanId], [viewpointId], [heading], [0]) 135 | 136 | def getStates(self): 137 | """ 138 | Get list of states augmented with precomputed image features. rgb field will be empty. 
139 | Agent's current view [0-35] (set only when viewing angles are discretized) 140 | [0-11] looking down, [12-23] looking at horizon, [24-35] looking up 141 | :return: [ ((36, 2048), sim_state) ] * batch_size 142 | """ 143 | feature_states = [] 144 | for i, sim in enumerate(self.sims): 145 | state = sim.getState()[0] 146 | 147 | if self.feat_db is None: 148 | feature = None 149 | else: 150 | feature = self.feat_db.get_image_feature(state.scanId, state.location.viewpointId) 151 | feature_states.append((feature, state)) 152 | return feature_states 153 | 154 | def makeActions(self, actions): 155 | ''' Take an action using the full state dependent action interface (with batched input). 156 | Every action element should be an (index, heading, elevation) tuple. ''' 157 | for i, (index, heading, elevation) in enumerate(actions): 158 | self.sims[i].makeAction([index], [heading], [elevation]) -------------------------------------------------------------------------------- /tasks/datasets/r2r.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from .mp3d_dataset import MP3DDataset 4 | from collections import defaultdict 5 | ERROR_MARGIN = 3.0 6 | 7 | class R2RDataset(MP3DDataset): 8 | name = "r2r" 9 | 10 | def load_data(self, anno_file, max_instr_len=200, debug=False): 11 | """ 12 | :param anno_file: 13 | :param max_instr_len: 14 | :param debug: 15 | :return: 16 | """ 17 | with open(str(anno_file), "r") as f: 18 | data = json.load(f) 19 | new_data = [] 20 | sample_index = 0 21 | 22 | for i, item in enumerate(data): 23 | # Split multiple instructions into separate entries 24 | for j, instr in enumerate(item['instructions']): 25 | new_item = dict(item) 26 | new_item['raw_idx'] = i 27 | new_item['sample_idx'] = sample_index 28 | new_item['instr_id'] = 'r2r_{}_{}'.format(item['path_id'], j) 29 | 30 | new_item['instruction'] = instr 31 | del new_item['instructions'] 32 | 33 | if 'instr_encodings' in new_item: 34 | new_item['instr_encoding'] = item['instr_encodings'][j][:max_instr_len] 35 | del new_item['instr_encodings'] 36 | 37 | if 'new_instructions' in new_item and len(eval(item['new_instructions'])) > j: 38 | new_item['fg_instruction'] = eval(item['new_instructions'])[j] 39 | new_item['fg_instruction'] = [' '.join(instr) for instr in new_item['fg_instruction']] 40 | del new_item['new_instructions'] 41 | new_item['fg_view'] = item['chunk_view'][j] 42 | fg_view = [] 43 | for idx, index in enumerate(new_item['fg_view']): 44 | index_num = index[1] - index[0] 45 | fg_view += [idx] * index_num 46 | new_item['fg_view'] = fg_view 47 | del new_item['chunk_view'] 48 | 49 | new_item['data_type'] = 'r2r' 50 | new_data.append(new_item) 51 | sample_index += 1 52 | 53 | if debug: 54 | new_data = new_data[:20] 55 | 56 | gt_trajs = { 57 | x['instr_id']: (x['scan'], x['path']) \ 58 | for x in new_data if len(x['path']) > 1 59 | } 60 | return new_data, gt_trajs 61 | 62 | 63 | def eval_metrics(self, preds, logger, name): 64 | """ 65 | Evaluate each agent trajectory based on how close it got to the goal location 66 | the path contains [view_id, angle, vofv] 67 | :param preds: 68 | :param logger: 69 | :param name: 70 | :return: 71 | """ 72 | logger.info('eval %d predictions' % (len(preds))) 73 | metrics = defaultdict(list) 74 | 75 | for item in preds: 76 | instr_id = item['instr_id'] 77 | traj = item['trajectory'] 78 | 79 | if instr_id not in self.gt_trajs.keys(): 80 | print("instr_id {} not in self.gt_trajs".format(instr_id)) 81 | raise 
NotImplementedError 82 | 83 | if name == "R2R": 84 | scan, gt_traj = self.gt_trajs[instr_id] 85 | traj_scores = self.eval_dis_item(scan, traj, gt_traj) 86 | else: 87 | raise NotImplementedError 88 | 89 | for k, v in traj_scores.items(): 90 | metrics[k].append(v) 91 | metrics['instr_id'].append(instr_id) 92 | 93 | if name in ['R2R']: 94 | avg_metrics = { 95 | 'action_steps': np.mean(metrics['action_steps']), 96 | 'steps': np.mean(metrics['trajectory_steps']), 97 | 'lengths': np.mean(metrics['trajectory_lengths']), 98 | 'nav_error': np.mean(metrics['nav_error']), 99 | 'oracle_error': np.mean(metrics['oracle_error']), 100 | 'sr': np.mean(metrics['success']) * 100, 101 | 'oracle_sr': np.mean(metrics['oracle_success']) * 100, 102 | 'spl': np.mean(metrics['spl']) * 100, 103 | } 104 | else: 105 | raise NotImplementedError 106 | return avg_metrics, metrics 107 | 108 | def eval_dis_item(self, scan, pred_path, gt_path): 109 | scores = {} 110 | 111 | shortest_distances = self.shortest_distances[scan] 112 | 113 | path = sum(pred_path, []) 114 | assert gt_path[0] == path[0], 'Result trajectories should include the start position' 115 | 116 | nearest_position = self.get_nearest(shortest_distances, gt_path[-1], path) 117 | 118 | scores['nav_error'] = shortest_distances[path[-1]][gt_path[-1]] 119 | scores['oracle_error'] = shortest_distances[nearest_position][gt_path[-1]] 120 | 121 | scores['action_steps'] = len(pred_path) - 1 122 | scores['trajectory_steps'] = len(path) - 1 123 | scores['trajectory_lengths'] = np.sum([shortest_distances[a][b] for a, b in zip(path[:-1], path[1:])]) 124 | 125 | gt_lengths = np.sum([shortest_distances[a][b] for a, b in zip(gt_path[:-1], gt_path[1:])]) 126 | 127 | scores['success'] = float(scores['nav_error'] < ERROR_MARGIN) 128 | scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01) 129 | scores['oracle_success'] = float(scores['oracle_error'] < ERROR_MARGIN) 130 | 131 | return scores 132 | 133 | def save_json(self, results, path, item_metrics=None): 134 | if item_metrics is not None: 135 | for k in item_metrics: 136 | for item, v in zip(results, item_metrics[k]): 137 | item[k] = v 138 | 139 | for item in results: 140 | item['instr_id'] = "_".join(item['instr_id'].split("_")[1:]) 141 | item['trajectory'] = [[y, 0, 0] for x in item['trajectory'] for y in x] 142 | 143 | with open(path, 'w') as fout: 144 | json.dump(results, fout) -------------------------------------------------------------------------------- /tasks/datasets/r2r_aug.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from .r2r import R2RDataset 4 | from collections import defaultdict 5 | from transformers import AutoTokenizer 6 | 7 | class R2RAugDataset(R2RDataset): 8 | name = "r2r_aug" 9 | 10 | def load_data(self, anno_file, max_instr_len=200, debug=False): 11 | """ 12 | :param anno_file: 13 | :param max_instr_len: 14 | :param debug: 15 | :return: 16 | """ 17 | if str(anno_file).endswith(".json"): 18 | return super().load_data(anno_file, max_instr_len=max_instr_len, debug=debug) 19 | 20 | with open(str(anno_file), "r") as f: 21 | data = [] 22 | for i, line in enumerate(f.readlines()): 23 | if debug and i==20: 24 | break 25 | data.append(json.loads(line.strip())) 26 | new_data = [] 27 | sample_idx = 0 28 | tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') 29 | 30 | for i, item in enumerate(data): 31 | new_item = dict(item) 32 | new_item["raw_idx"] = i 33 | 
new_item["sample_idx"] = sample_idx 34 | new_item['data_type'] = 'r2r_aug' 35 | new_item["path_id"] = None 36 | new_item["heading"] = item.get("heading", 0) 37 | new_item["instruction"] = tokenizer.decode(new_item['instr_encoding'], skip_special_tokens=True) 38 | new_data.append(new_item) 39 | sample_idx += 1 40 | 41 | if debug: 42 | new_data = new_data[:20] 43 | 44 | gt_trajs = { 45 | x['instr_id']: (x['scan'], x['path']) \ 46 | for x in new_data if len(x['path']) > 1 47 | } 48 | return new_data, gt_trajs 49 | 50 | -------------------------------------------------------------------------------- /tasks/datasets/reverie.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import numpy as np 4 | from .mp3d_dataset import MP3DDataset 5 | from collections import defaultdict 6 | 7 | class REVERIEDataset(MP3DDataset): 8 | name = "reverie" 9 | 10 | def __init__( 11 | self, 12 | args, 13 | config, 14 | training=False, 15 | logger=None, 16 | source=None, 17 | ): 18 | super().__init__(args, config, training, logger, source) 19 | self.multi_startpoints = False 20 | self.multi_endpoints = args.multi_endpoints 21 | 22 | def preprocess_item(self, item): 23 | if self.split!="train" or "end_vps" not in item or (not self.multi_startpoints and not self.multi_endpoints): 24 | return item 25 | 26 | start_vp = item["path"][0] 27 | end_vp = item["path"][-1] 28 | 29 | if self.multi_startpoints: 30 | cand_vps = [] 31 | for cvp, cpath in self.shortest_paths[item['scan']][end_vps[i]].items(): 32 | if len(cpath) >= 4 and len(cpath) <= 7: 33 | cand_vps.append(cvp) 34 | if len(cand_vps) > 0: 35 | start_vp = cand_vps[np.random.randint(len(cand_vps))] 36 | 37 | if self.multi_endpoints: 38 | end_vp = item["end_vps"][np.random.randint(len(item["end_vps"]))] 39 | 40 | item = copy.deepcopy(item) 41 | item["path"] = self.shortest_paths[item["scan"]][start_vp][end_vp] 42 | return item 43 | 44 | def load_data(self, anno_file, obj2vps, debug=False): 45 | with open(str(anno_file), "r") as f: 46 | data = json.load(f) 47 | 48 | new_data = [] 49 | sample_index = 0 50 | for i, item in enumerate(data): 51 | # Split multiple instructions into separate entries 52 | for j, instr in enumerate(item['instructions']): 53 | new_item = dict(item) 54 | 55 | if 'objId' in item: 56 | new_item['instr_id'] = '%s_%s_%s_%d' % ('reverie', str(item['path_id']), str(item['objId']), j) 57 | else: 58 | new_item['path_id'] = item['id'] 59 | new_item['instr_id'] = '%s_%s_%d' % ('reverie', item['id'], j) 60 | new_item['objId'] = None 61 | 62 | new_item['sample_idx'] = sample_index 63 | new_item['instruction'] = instr 64 | del new_item['instructions'] 65 | new_item['data_type'] = 'reverie' 66 | 67 | new_item['raw_idx'] = None 68 | new_item['instr_encoding'] = None 69 | 70 | if 'objId' in item and item['objId'] is not None: 71 | new_item['end_vps'] = obj2vps['%s_%s'%(item['scan'], item['objId'])] 72 | 73 | new_data.append(new_item) 74 | sample_index += 1 75 | if debug: 76 | new_data = new_data[:20] 77 | 78 | gt_trajs = { 79 | x['instr_id']: (x['scan'], x['path'], x['objId']) \ 80 | for x in new_data if 'objId' in x and x['objId'] is not None 81 | } 82 | 83 | return new_data, gt_trajs 84 | 85 | 86 | def load_obj2vps(self, bbox_file): 87 | obj2vps = {} 88 | bbox_data = json.load(open(bbox_file)) 89 | for scanvp, value in bbox_data.items(): 90 | scan, vp = scanvp.split('_') 91 | # for all visible objects at that viewpoint 92 | for objid, objinfo in value.items(): 93 | if objinfo['visible_pos']: 
94 | # if such object not already in the dict 95 | obj2vps.setdefault(scan+'_'+objid, []) 96 | obj2vps[scan+'_'+objid].append(vp) 97 | self.obj2vps = obj2vps 98 | return obj2vps 99 | 100 | def eval_metrics(self, preds, logger, name): 101 | """ 102 | Evaluate each agent trajectory based on how close it got to the goal location 103 | the path contains [view_id, angle, vofv] 104 | :param preds: 105 | :param logger: 106 | :param name: 107 | :return: 108 | """ 109 | logger.info('eval %d predictions' % (len(preds))) 110 | metrics = defaultdict(list) 111 | 112 | for item in preds: 113 | instr_id = item['instr_id'] 114 | traj = item['trajectory'] 115 | pred_objid = item.get('pred_objid', None) 116 | scan, gt_traj, gt_objid = self.gt_trajs[instr_id] 117 | traj_scores = self.eval_dis_item(scan, traj, pred_objid, gt_traj, gt_objid) 118 | 119 | for k, v in traj_scores.items(): 120 | metrics[k].append(v) 121 | metrics['instr_id'].append(instr_id) 122 | 123 | avg_metrics = { 124 | 'action_steps': np.mean(metrics['action_steps']), 125 | 'steps': np.mean(metrics['trajectory_steps']), 126 | 'lengths': np.mean(metrics['trajectory_lengths']), 127 | 'nav_error': np.mean(metrics['nav_error']), 128 | 'oracle_error': np.mean(metrics['oracle_error']), 129 | 'sr': np.mean(metrics['success']) * 100, 130 | 'oracle_sr': np.mean(metrics['oracle_success']) * 100, 131 | 'spl': np.mean(metrics['spl']) * 100, 132 | 'rgs': np.mean(metrics['rgs']) * 100, 133 | 'rgspl': np.mean(metrics['rgspl']) * 100 134 | } 135 | 136 | return avg_metrics, metrics 137 | 138 | def eval_dis_item(self, scan, pred_path, pred_objid, gt_path, gt_objid): 139 | scores = {} 140 | 141 | shortest_distances = self.shortest_distances[scan] 142 | 143 | path = sum(pred_path, []) 144 | assert gt_path[0] == path[0], 'Result trajectories should include the start position' 145 | 146 | nearest_position = self.get_nearest(shortest_distances, gt_path[-1], path) 147 | 148 | scores['nav_error'] = shortest_distances[path[-1]][gt_path[-1]] 149 | scores['oracle_error'] = shortest_distances[nearest_position][gt_path[-1]] 150 | 151 | scores['action_steps'] = len(pred_path) - 1 152 | scores['trajectory_steps'] = len(path) - 1 153 | scores['trajectory_lengths'] = np.sum([shortest_distances[a][b] for a, b in zip(path[:-1], path[1:])]) 154 | 155 | gt_lengths = np.sum([shortest_distances[a][b] for a, b in zip(gt_path[:-1], gt_path[1:])]) 156 | 157 | # navigation: success is to arrive to a viewpoint where the object is visible 158 | goal_viewpoints = set(self.obj2vps['%s_%s'%(scan, str(gt_objid))]) 159 | assert len(goal_viewpoints) > 0, '%s_%s'%(scan, str(gt_objid)) 160 | 161 | scores['success'] = float(path[-1] in goal_viewpoints) 162 | scores['oracle_success'] = float(any(x in goal_viewpoints for x in path)) 163 | scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01) 164 | 165 | scores['rgs'] = str(pred_objid) == str(gt_objid) 166 | scores['rgspl'] = scores['rgs'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01) 167 | 168 | return scores 169 | 170 | def get_object_info(self, item, state): 171 | # objects 172 | obj_img_fts, obj_ang_fts, obj_box_fts, obj_ids = \ 173 | self.obj_feat_db.get_object_feature( 174 | state.scanId, state.location.viewpointId, 175 | state.heading, state.elevation, self.angle_feat_size, 176 | max_objects=self.max_objects 177 | ) 178 | 179 | gt_end_vps = item.get('end_vps', []) 180 | 181 | gt_obj_id = None 182 | vp = state.location.viewpointId 183 | if vp in gt_end_vps: 184 | gt_obj_id = 
item['objId'] 185 | 186 | return { 187 | 'obj_img_fts': obj_img_fts, 188 | 'obj_ang_fts': obj_ang_fts, 189 | 'obj_box_fts': obj_box_fts, 190 | 'obj_ids': obj_ids, 191 | 'gt_end_vps': gt_end_vps, 192 | 'gt_obj_id': gt_obj_id, 193 | } 194 | 195 | def save_json(self, results, path, item_metrics=None): 196 | if item_metrics is not None: 197 | for k in item_metrics: 198 | for item, v in zip(results, item_metrics[k]): 199 | item[k] = v 200 | 201 | for item in results: 202 | item['instr_id'] = "_".join(item['instr_id'].split("_")[1:]) 203 | item['trajectory'] = [[y, 0, 0] for x in item['trajectory'] for y in x] 204 | item['predObjId'] = int(item['pred_objid']) if item['pred_objid'] is not None else 0 205 | 206 | with open(path, 'w') as f: 207 | json.dump(results, f) 208 | 209 | -------------------------------------------------------------------------------- /tasks/datasets/reverie_aug.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from .reverie import REVERIEDataset 4 | from collections import defaultdict 5 | from transformers import AutoTokenizer 6 | 7 | class REVERIEAugDataset(REVERIEDataset): 8 | name = "reverie_aug" 9 | 10 | def load_data(self, anno_file, obj2vps, debug=False): 11 | if str(anno_file).endswith("json"): 12 | return super().load_data(anno_file, obj2vps, debug=debug) 13 | 14 | with open(str(anno_file), "r") as f: 15 | data = [] 16 | for i, line in enumerate(f.readlines()): 17 | if debug and i==20: 18 | break 19 | data.append(json.loads(line.strip())) 20 | 21 | new_data = [] 22 | sample_idx = 0 23 | tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') 24 | 25 | for i, item in enumerate(data): 26 | new_item = dict(item) 27 | new_item["raw_idx"] = i 28 | new_item["sample_idx"] = sample_idx 29 | new_item['data_type'] = 'reverie_aug' 30 | new_item["instruction"] = tokenizer.decode(new_item['instr_encoding'], skip_special_tokens=True) 31 | new_item['objId'] = None 32 | new_item["path_id"] = None 33 | new_item["heading"] = item.get("heading", 0) 34 | new_item['end_vps'] = item['pos_vps'] 35 | del new_item['pos_vps'] 36 | new_data.append(new_item) 37 | sample_idx += 1 38 | 39 | if debug: 40 | new_data = new_data[:20] 41 | 42 | gt_trajs = { 43 | x['instr_id']: (x['scan'], x['path'], x['objId']) \ 44 | for x in new_data if 'objId' in x and x['objId'] is not None 45 | } 46 | 47 | return new_data, gt_trajs -------------------------------------------------------------------------------- /tasks/datasets/scanqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | from PIL import Image 5 | from typing import Dict, Any, List, Tuple 6 | from .llava import LLaVADataset 7 | import random 8 | import torch 9 | import numpy as np 10 | from tools.evaluation.bleu import Bleu 11 | from tools.evaluation.rouge import Rouge 12 | from tools.evaluation.cider import Cider 13 | from tools.evaluation.meteor import Meteor 14 | 15 | class ScanQADataset(LLaVADataset): 16 | name = "scanqa" 17 | 18 | def _load_data(self, config: Dict, data_dir: str): 19 | if config.ScanQA.DIR.startswith("/"): 20 | path = os.path.join(config.ScanQA.DIR, config.ScanQA.SPLIT[self.split]) 21 | else: 22 | path = os.path.join(data_dir, config.ScanQA.DIR, config.ScanQA.SPLIT[self.split]) 23 | 24 | self.alldata = [] 25 | with open(path) as f: 26 | data = json.load(f) 27 | for item in data: 28 | for ann in item["annotation"]: 29 | self.alldata.append({ 30 | "question_id": 
ann["question_id"], 31 | "question": ann["question"], 32 | "answers": [ans.lower() for ans in ann["answers"]], 33 | "image_info": item["image_info"], 34 | "scene_id": item["scene_id"] 35 | }) 36 | 37 | if self.max_datapoints: 38 | self.alldata = self.alldata[:self.max_datapoints] 39 | self.logger.info(f"There are totally {len(self.alldata)} datapoints loaded.") 40 | 41 | def __getitem__(self, index:int) -> Dict[str, Any]: 42 | item = copy.deepcopy(self.alldata[index]) 43 | 44 | sampled_images = random.sample(item["image_info"], min(36, len(item["image_info"]))) 45 | features = [] 46 | for d in sampled_images: 47 | fts = self.feat_db.get_image_feature(item["scene_id"], d["image_id"]) 48 | features.append(fts) 49 | features = torch.from_numpy(np.stack(features)) 50 | 51 | data_dict = { 52 | "scene_id": item["scene_id"], 53 | "question_id": item["question_id"], 54 | "question": item["question"], 55 | "answers": item["answers"], 56 | "features": features, 57 | "data_type": "scan_qa" 58 | } 59 | return data_dict 60 | 61 | def eval_metrics(self, preds: List[Dict[str, Any]], logger, name: str) -> Tuple[Dict[str, float], Dict[str, List[float]]]: 62 | ret = {} 63 | if self.split=='test': 64 | return ret 65 | 66 | refs = {} 67 | for item in self.alldata: 68 | refs[item["question_id"]] = item["answers"] 69 | gen = {item['question_id']:item['generated_sentences'] for item in preds} 70 | 71 | bleu_score = Bleu() 72 | score, scores = bleu_score.compute_score(refs, gen) 73 | for i, s in enumerate(score): 74 | ret[f"bleu-{i+1}"] = s * 100 75 | 76 | rouge_score = Rouge() 77 | score, compute_score = rouge_score.compute_score(refs, gen) 78 | ret["rouge"] = score * 100 79 | 80 | cider_score = Cider() 81 | score, compute_score = cider_score.compute_score(refs, gen) 82 | ret["cider"] = score * 100 83 | 84 | meteor_score = Meteor() 85 | score, compute_score = meteor_score.compute_score(refs, gen) 86 | ret["meteor"] = score * 100 87 | 88 | n_correct = 0 89 | metrics = {"exact_match": []} 90 | for pred in preds: 91 | if pred['generated_sentences'][0] in refs[pred["question_id"]]: 92 | n_correct += 1 93 | metrics["exact_match"].append(1.) 94 | else: 95 | metrics["exact_match"].append(0.) 
96 | ret["exact_match"] = n_correct / len(preds) * 100 97 | 98 | return ret, metrics 99 | 100 | def save_json(self, results, path, item_metrics=None): 101 | for item in results: 102 | item['answer_top10'] = item['generated_sentences'] 103 | item['pred_bbox'] = [] 104 | del item['generated_sentences'] 105 | 106 | with open(path, 'w') as f: 107 | json.dump(results, f) -------------------------------------------------------------------------------- /tasks/feature_db.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import lmdb 4 | from typing import Dict 5 | import numpy as np 6 | import msgpack 7 | import msgpack_numpy 8 | from .datasets.mp3d_envs import angle_feature, convert_elevation, convert_heading 9 | msgpack_numpy.patch() 10 | 11 | 12 | class ImageFeaturesDB(object): 13 | def __init__(self, img_ft_file: str, image_feat_size: str): 14 | self.image_feat_size = image_feat_size 15 | self.img_ft_file = img_ft_file 16 | self._feature_store = {} 17 | 18 | def get_image_feature(self, scan:str, viewpoint: str=None, load_in_memory: bool=False) -> np.ndarray: 19 | key = '%s_%s' % (scan, viewpoint) if viewpoint is not None else scan 20 | if key in self._feature_store: 21 | ft = self._feature_store[key] 22 | else: 23 | with h5py.File(self.img_ft_file, 'r') as f: 24 | ft = f[key] 25 | if len(ft.shape)==1: 26 | ft = ft[:self.image_feat_size].astype(np.float32) 27 | else: 28 | ft = ft[:, :self.image_feat_size].astype(np.float32) 29 | if load_in_memory: 30 | self._feature_store[key] = ft 31 | return ft 32 | 33 | 34 | def create_feature_db(config: Dict, image_feat_size: int, args) -> Dict[str, ImageFeaturesDB]: 35 | ret = {} 36 | for source in config: 37 | path = config[source] if config[source].startswith("/") else os.path.join(args.data_dir, config[source]) 38 | ret[source] = ImageFeaturesDB( 39 | path, 40 | image_feat_size 41 | ) 42 | return ret 43 | 44 | 45 | class REVERIEObjectFeatureDB(object): 46 | def __init__(self, obj_ft_file, obj_feat_size, im_width=640, im_height=480): 47 | self.obj_feat_size = obj_feat_size 48 | self.obj_ft_file = obj_ft_file 49 | self._feature_store = {} 50 | self.im_width = im_width 51 | self.im_height = im_height 52 | self.env = lmdb.open(self.obj_ft_file, readonly=True) 53 | 54 | def load_feature(self, scan, viewpoint, max_objects=None): 55 | key = '%s_%s' % (scan, viewpoint) 56 | if key in self._feature_store: 57 | obj_fts, obj_attrs = self._feature_store[key] 58 | else: 59 | with self.env.begin() as txn: 60 | obj_data = txn.get(key.encode('ascii')) 61 | if obj_data is not None: 62 | obj_data = msgpack.unpackb(obj_data) 63 | obj_fts = obj_data['fts'][:, :self.obj_feat_size].astype(np.float32) 64 | obj_attrs = {k: v for k, v in obj_data.items() if k != 'fts'} 65 | else: 66 | obj_fts = np.zeros((0, self.obj_feat_size), dtype=np.float32) 67 | obj_attrs = {} 68 | self._feature_store[key] = (obj_fts, obj_attrs) 69 | 70 | if max_objects is not None: 71 | obj_fts = obj_fts[:max_objects] 72 | obj_attrs = {k: v[:max_objects] for k, v in obj_attrs.items()} 73 | return obj_fts, obj_attrs 74 | 75 | def get_object_feature( 76 | self, scan, viewpoint, base_heading, base_elevation, angle_feat_size, 77 | max_objects=None 78 | ): 79 | obj_fts, obj_attrs = self.load_feature(scan, viewpoint, max_objects=max_objects) 80 | obj_ang_fts = np.zeros((len(obj_fts), angle_feat_size), dtype=np.float32) 81 | obj_box_fts = np.zeros((len(obj_fts), 3), dtype=np.float32) 82 | obj_ids = [] 83 | if len(obj_fts) > 0: 84 | for k, 
obj_ang in enumerate(obj_attrs['centers']): 85 | obj_ang_fts[k] = angle_feature( 86 | obj_ang[0] - base_heading, obj_ang[1] - base_elevation, angle_feat_size 87 | ) 88 | w, h = obj_attrs['bboxes'][k][2:] 89 | obj_box_fts[k, :2] = [h/self.im_height, w/self.im_width] 90 | obj_box_fts[k, 2] = obj_box_fts[k, 0] * obj_box_fts[k, 1] 91 | obj_ids = obj_attrs['obj_ids'] 92 | return obj_fts, obj_ang_fts, obj_box_fts, obj_ids 93 | 94 | 95 | class SOONObjectFeatureDB(object): 96 | # TODO: This class requires adapting to current modification. 97 | def __init__(self, obj_ft_file, obj_feat_size): 98 | self.obj_feat_size = obj_feat_size 99 | self.obj_ft_file = obj_ft_file 100 | self._feature_store = {} 101 | self.env = lmdb.open(obj_ft_file, readonly=True) 102 | 103 | def __del__(self): 104 | self.env.close() 105 | 106 | def load_feature(self, scan, viewpoint, max_objects=None): 107 | key = '%s_%s' % (scan, viewpoint) 108 | if key in self._feature_store: 109 | obj_fts, obj_attrs = self._feature_store[key] 110 | else: 111 | with self.env.begin() as txn: 112 | obj_data = txn.get(key.encode('ascii')) 113 | if obj_data is not None: 114 | obj_data = msgpack.unpackb(obj_data) 115 | obj_fts = obj_data['fts'][:, :self.obj_feat_size].astype(np.float32) 116 | obj_attrs = { 117 | 'directions': obj_data['2d_centers'], 118 | 'obj_ids': obj_data['obj_ids'], 119 | 'bboxes': np.array(obj_data['xyxy_bboxes']), 120 | } 121 | else: 122 | obj_attrs = {} 123 | obj_fts = np.zeros((0, self.obj_feat_size), dtype=np.float32) 124 | self._feature_store[key] = (obj_fts, obj_attrs) 125 | 126 | if max_objects is not None: 127 | obj_fts = obj_fts[:max_objects] 128 | obj_attrs = {k: v[:max_objects] for k, v in obj_attrs.items()} 129 | return obj_fts, obj_attrs 130 | 131 | def get_object_feature( 132 | self, scan, viewpoint, base_heading, base_elevation, angle_feat_size, 133 | max_objects=None 134 | ): 135 | obj_fts, obj_attrs = self.load_feature(scan, viewpoint, max_objects=max_objects) 136 | obj_ang_fts = np.zeros((len(obj_fts), angle_feat_size), dtype=np.float32) 137 | obj_loc_fts = np.zeros((len(obj_fts), 3), dtype=np.float32) 138 | obj_directions, obj_ids = [], [] 139 | if len(obj_fts) > 0: 140 | for k, obj_ang in enumerate(obj_attrs['directions']): 141 | obj_ang_fts[k] = angle_feature( 142 | obj_ang[0] - base_heading, obj_ang[1] - base_elevation, angle_feat_size 143 | ) 144 | x1, y1, x2, y2 = obj_attrs['bboxes'][k] 145 | h = y2 - y1 146 | w = x2 - x1 147 | obj_loc_fts[k, :2] = [h/224, w/224] 148 | obj_loc_fts[k, 2] = obj_loc_fts[k, 0] * obj_loc_fts[k, 1] 149 | obj_directions = [[convert_heading(x[0]), convert_elevation(x[1])] for x in obj_attrs['directions']] 150 | obj_ids = obj_attrs['obj_ids'] 151 | return obj_fts, obj_ang_fts, obj_loc_fts, obj_directions, obj_ids 152 | 153 | def create_object_feature_db(config: Dict, obj_feat_size: int, args): 154 | ret = {} 155 | for source in config: 156 | path = config[source] if config[source].startswith("/") else os.path.join(args.data_dir, config[source]) 157 | if source == 'reverie': 158 | ret[source] = REVERIEObjectFeatureDB( 159 | path, 160 | obj_feat_size 161 | ) 162 | elif source == 'soon': 163 | ret[source] = SOONObjectFeatureDB( 164 | path, 165 | obj_feat_size 166 | ) 167 | return ret -------------------------------------------------------------------------------- /tasks/loaders.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from tools.common_utils import get_dist_info 3 | import torch 4 | import torch.distributed as dist 5 
| from typing import List, Dict, Tuple, Union, Iterator 6 | from torch.utils.data.distributed import DistributedSampler 7 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 8 | from .datasets import load_dataset 9 | from .agents import load_agent 10 | 11 | 12 | def create_dataloaders(args, config, logger, training, device, feat_db=None, obj_feat_db=None, stage="multi"): 13 | if training==False and stage=='pretrain': 14 | return None, None 15 | 16 | dataset_cfg = copy.deepcopy(config.Dataset) 17 | dataset_cfg.update( 18 | config.Pretrain if stage=="pretrain" else config.Multi 19 | ) 20 | dataset_cfg.update(config.Feature) 21 | 22 | dataloaders = {} 23 | agents = {} 24 | if args.test_datasets is not None and not training: 25 | dataset_list = args.test_datasets 26 | else: 27 | dataset_list = copy.deepcopy(dataset_cfg.SOURCE) 28 | for k, task_name in enumerate(dataset_list): 29 | # load dataset by names 30 | dataset = load_dataset(task_name.lower(), args, dataset_cfg, training=training, logger=logger, source=task_name) 31 | 32 | # assign feature database 33 | if task_name in ["R2R", "REVERIE", "CVDN", "SOON", "EQA", "R2R_AUG", "REVERIE_AUG"]: 34 | task_feat_db = feat_db['mp3d'] 35 | elif task_name in ["ScanQA"]: 36 | task_feat_db = feat_db['scan_qa'] 37 | elif task_name in ["LLaVA"]: 38 | task_feat_db = feat_db["coco"] 39 | else: 40 | raise NotImplementedError 41 | 42 | # assign object database 43 | if args.enable_og: 44 | if task_name in ["REVERIE", "REVERIE_AUG"]: 45 | task_obj_feat_db = obj_feat_db['reverie'] 46 | elif task_name == "SOON": 47 | task_obj_feat_db = obj_feat_db['soon'] 48 | else: 49 | task_obj_feat_db = None 50 | else: 51 | task_obj_feat_db = None 52 | 53 | dataset.init_feat_db(feat_db=task_feat_db, obj_feat_db=task_obj_feat_db) 54 | 55 | 56 | logger.info(f"{task_name}: {len(dataset)} samples loaded") 57 | 58 | task_loader, pre_epoch = build_dataloader( 59 | dataset, distributed=args.distributed, 60 | training=training, batch_size=args.batch_size if training else args.val_batch_size, num_workers=args.workers 61 | ) 62 | 63 | if training: 64 | ratio = dataset_cfg.Ratio[k] 65 | dataloaders[task_name] = (task_loader, ratio, pre_epoch) 66 | else: 67 | dataloaders[task_name] = PrefetchLoader(task_loader, device=device) 68 | 69 | # load agents 70 | agents[task_name] = load_agent(task_name.lower(), args, getattr(dataset, "shortest_distances", None), getattr(dataset, "shortest_paths", None)) 71 | 72 | 73 | if training: 74 | meta_loader = MetaLoader( 75 | dataloaders, 76 | accum_steps=args.gradient_accumulation_step, 77 | distributed=args.distributed, 78 | device=device, 79 | off_batch_task=args.off_batch_task 80 | ) 81 | meta_loader = PrefetchLoader(meta_loader, device) 82 | 83 | if args.num_steps_per_epoch!=-1: 84 | meta_loader.num_batches = args.num_steps_per_epoch 85 | else: 86 | return dataloaders, agents 87 | return meta_loader, agents 88 | 89 | 90 | def build_dataloader(dataset, distributed, training, batch_size, num_workers): 91 | if distributed: 92 | size = dist.get_world_size() 93 | sampler = DistributedSampler( 94 | dataset, num_replicas=size, rank=dist.get_rank(), shuffle=training 95 | ) 96 | pre_epoch = sampler.set_epoch 97 | else: 98 | # not distributed 99 | if training: 100 | sampler: Union[ 101 | RandomSampler, SequentialSampler, DistributedSampler 102 | ] = RandomSampler(dataset) 103 | # sampler = SequentialSampler(dataset) # Debug Mode 104 | else: 105 | sampler = SequentialSampler(dataset) 106 | 107 | size = torch.cuda.device_count() if 
torch.cuda.is_available() else 1 108 | pre_epoch = lambda e: None 109 | 110 | # DataParallel: scale the batch size by the number of GPUs 111 | # if size > 1: 112 | # batch_size *= size 113 | 114 | loader = DataLoader( 115 | dataset, 116 | sampler=sampler, 117 | batch_size=batch_size, 118 | num_workers=num_workers, 119 | pin_memory=True, 120 | drop_last=False, 121 | collate_fn=dataset.collate_batch, 122 | ) 123 | loader.num_batches = len(loader) 124 | 125 | return loader, pre_epoch 126 | 127 | 128 | class MetaLoader: 129 | """wraps multiple data loaders""" 130 | 131 | def __init__( 132 | self, loaders, accum_steps: int = 1, distributed: bool = False, device=None, off_batch_task: bool = False, 133 | ): 134 | assert isinstance(loaders, dict) 135 | self.name2loader = {} 136 | self.name2iter = {} 137 | self.name2pre_epoch = {} 138 | self.names: List[str] = [] 139 | ratios: List[int] = [] 140 | 141 | self.num_batches = 0 142 | self.off_batch_task = off_batch_task 143 | 144 | for n, l in loaders.items(): 145 | if isinstance(l, tuple): 146 | l, r, p = l 147 | elif isinstance(l, DataLoader): 148 | r = 1 149 | p = lambda e: None 150 | else: 151 | raise ValueError() 152 | self.names.append(n) 153 | self.name2loader[n] = l 154 | self.name2iter[n] = iter(l) 155 | self.name2pre_epoch[n] = p 156 | ratios.append(r) 157 | 158 | self.num_batches += l.num_batches 159 | 160 | self.accum_steps = accum_steps 161 | self.device = device 162 | self.sampling_ratios = torch.tensor(ratios).float().to(self.device) 163 | self.distributed = distributed 164 | self.step = 0 165 | self.epoch_id = 0 166 | 167 | def get_dataset(self, name): 168 | return self.name2loader[name].dataset 169 | 170 | def __iter__(self) -> Iterator[Tuple]: 171 | """this iterator will run indefinitely""" 172 | task_id = None 173 | self.step = 0 174 | while True: 175 | # if self.step % self.accum_steps == 0: 176 | task_id = torch.multinomial(self.sampling_ratios, 1) 177 | if self.distributed and not self.off_batch_task: 178 | # make sure all process is training same task 179 | dist.broadcast(task_id, 0) 180 | self.step += 1 181 | task = self.names[task_id.cpu().item()] 182 | iter_ = self.name2iter[task] 183 | try: 184 | batch = next(iter_) 185 | except StopIteration: 186 | 187 | self.epoch_id += 1 188 | # In distributed mode, calling the set_epoch() method at the beginning of each epoch 189 | # before creating the DataLoader iterator is necessary to make shuffling work properly 190 | # across multiple epochs. Otherwise, the same ordering will be always used. 
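                # name2pre_epoch[task] is DistributedSampler.set_epoch when build_dataloader ran with distributed=True, and a no-op lambda otherwise, so this call is safe in single-process runs.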
191 | self.name2pre_epoch[task](self.epoch_id) 192 | iter_ = iter(self.name2loader[task]) 193 | batch = next(iter_) 194 | self.name2iter[task] = iter_ 195 | 196 | yield task, batch 197 | 198 | 199 | def move_to_cuda(batch: Union[List, Tuple, Dict, torch.Tensor], device: torch.device): 200 | if isinstance(batch, torch.Tensor): 201 | return batch.to(device, non_blocking=True) 202 | elif isinstance(batch, list): 203 | return [move_to_cuda(t, device) for t in batch] 204 | elif isinstance(batch, tuple): 205 | return tuple(move_to_cuda(t, device) for t in batch) 206 | elif isinstance(batch, dict): 207 | return {n: move_to_cuda(t, device) for n, t in batch.items()} 208 | return batch 209 | 210 | 211 | class PrefetchLoader(object): 212 | """ 213 | overlap compute and cuda data transfer 214 | """ 215 | def __init__(self, loader, device: torch.device): 216 | self.loader = loader 217 | self.device = device 218 | self.num_batches = self.loader.num_batches 219 | 220 | def get_dataset(self): 221 | return self.loader.dataset 222 | 223 | def __iter__(self): 224 | loader_it = iter(self.loader) 225 | self.preload(loader_it) 226 | batch = self.next(loader_it) 227 | while batch is not None: 228 | yield batch 229 | batch = self.next(loader_it) 230 | 231 | def __len__(self): 232 | return len(self.loader) 233 | 234 | def preload(self, it): 235 | try: 236 | self.batch = next(it) 237 | except StopIteration: 238 | self.batch = None 239 | return 240 | self.batch = move_to_cuda(self.batch, self.device) 241 | 242 | def next(self, it): 243 | batch = self.batch 244 | self.preload(it) 245 | return batch 246 | 247 | def __getattr__(self, name): 248 | method = self.loader.__getattribute__(name) 249 | return method 250 | 251 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaVi-Lab/NaviLLM/e069f46ed98affb221d58715a785613622e11145/tools/__init__.py -------------------------------------------------------------------------------- /tools/common_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import time 5 | import torch 6 | import torch.distributed as dist 7 | import pickle 8 | import random 9 | import shutil 10 | from easydict import EasyDict 11 | import numpy as np 12 | 13 | 14 | def worker_init_fn(worker_id, seed=666): 15 | if seed is not None: 16 | random.seed(seed + worker_id) 17 | np.random.seed(seed + worker_id) 18 | torch.manual_seed(seed + worker_id) 19 | torch.cuda.manual_seed(seed + worker_id) 20 | torch.cuda.manual_seed_all(seed + worker_id) 21 | 22 | 23 | def get_dist_info(return_gpu_per_machine=False): 24 | if torch.__version__ < '1.0': 25 | initialized = dist._initialized 26 | else: 27 | if dist.is_available(): 28 | initialized = dist.is_initialized() 29 | else: 30 | initialized = False 31 | if initialized: 32 | rank = dist.get_rank() 33 | world_size = dist.get_world_size() 34 | else: 35 | rank = 0 36 | world_size = 1 37 | 38 | if return_gpu_per_machine: 39 | gpu_per_machine = torch.cuda.device_count() 40 | return rank, world_size, gpu_per_machine 41 | 42 | return rank, world_size 43 | 44 | 45 | def create_logger(log_file=None, rank=0, log_level=logging.INFO): 46 | logger = logging.getLogger(__name__) 47 | logger.setLevel(log_level if rank == 0 else 'ERROR') 48 | formatter = logging.Formatter('%(asctime)s %(levelname)5s %(message)s') 49 | console = 
logging.StreamHandler() 50 | console.setLevel(log_level if rank == 0 else 'ERROR') 51 | console.setFormatter(formatter) 52 | logger.addHandler(console) 53 | if log_file is not None: 54 | file_handler = logging.FileHandler(filename=log_file) 55 | file_handler.setLevel(log_level if rank == 0 else 'ERROR') 56 | file_handler.setFormatter(formatter) 57 | logger.addHandler(file_handler) 58 | logger.propagate = False 59 | return logger 60 | 61 | 62 | def log_config_to_file(cfg, pre='cfg', logger=None): 63 | for key, val in cfg.items(): 64 | if isinstance(cfg[key], EasyDict): 65 | logger.info('----------- %s -----------' % (key)) 66 | log_config_to_file(cfg[key], pre=pre + '.' + key, logger=logger) 67 | continue 68 | logger.info('%s.%s: %s' % (pre, key, val)) 69 | 70 | 71 | def summary_model(model,level=2): 72 | message = "" 73 | if level < 1: 74 | return message 75 | for name1, module1 in model.named_children(): 76 | message += "[1] {}\n".format(name1) 77 | if level > 1: 78 | for name2, module2 in module1.named_children(): 79 | message += "- [2] {}\n".format(name2) 80 | if level > 2: 81 | for name3, module3 in module2.named_children(): 82 | message += " +++ [3] {}\n".format(name3) 83 | if level > 3: 84 | for name4, module4 in module3.named_children(): 85 | message += " +++++ [4] {}\n".format(name4) 86 | return message 87 | 88 | 89 | def get_world_size(): 90 | if not dist.is_available(): 91 | return 1 92 | if not dist.is_initialized(): 93 | return 1 94 | return dist.get_world_size() 95 | 96 | 97 | def get_rank(): 98 | if not dist.is_available(): 99 | return 0 100 | if not dist.is_initialized(): 101 | return 0 102 | return dist.get_rank() 103 | 104 | 105 | def all_gather(data): 106 | """ 107 | Run all_gather on arbitrary picklable data (not necessarily tensors) 108 | Args: 109 | data: any picklable object 110 | Returns: 111 | list[data]: list of data gathered from each rank 112 | """ 113 | world_size = get_world_size() 114 | if world_size == 1: 115 | return [data] 116 | 117 | # serialized to a Tensor 118 | origin_size = None 119 | if not isinstance(data, torch.Tensor): 120 | buffer = pickle.dumps(data) 121 | storage = torch.ByteStorage.from_buffer(buffer) 122 | tensor = torch.ByteTensor(storage).to("cuda") 123 | else: 124 | origin_size = data.size() 125 | tensor = data.reshape(-1) 126 | 127 | tensor_type = tensor.dtype 128 | 129 | # obtain Tensor size of each rank 130 | local_size = torch.LongTensor([tensor.numel()]).to("cuda") 131 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] 132 | dist.all_gather(size_list, local_size) 133 | size_list = [int(size.item()) for size in size_list] 134 | max_size = max(size_list) 135 | 136 | # receiving Tensor from all ranks 137 | # we pad the tensor because torch all_gather does not support 138 | # gathering tensors of different shapes 139 | tensor_list = [] 140 | for _ in size_list: 141 | tensor_list.append(torch.FloatTensor(size=(max_size,)).cuda().to(tensor_type)) 142 | if local_size != max_size: 143 | padding = torch.FloatTensor(size=(max_size - local_size,)).cuda().to(tensor_type) 144 | tensor = torch.cat((tensor, padding), dim=0) 145 | dist.all_gather(tensor_list, tensor) 146 | 147 | data_list = [] 148 | for size, tensor in zip(size_list, tensor_list): 149 | if origin_size is None: 150 | buffer = tensor.cpu().numpy().tobytes()[:size] 151 | data_list.append(pickle.loads(buffer)) 152 | else: 153 | buffer = tensor[:size] 154 | data_list.append(buffer) 155 | 156 | if origin_size is not None: 157 | new_shape = [-1] + 
list(origin_size[1:]) 158 | resized_list = [] 159 | for data in data_list: 160 | # suppose the difference of tensor size exist in first dimension 161 | data = data.reshape(new_shape) 162 | resized_list.append(data) 163 | 164 | return resized_list 165 | else: 166 | return data_list 167 | -------------------------------------------------------------------------------- /tools/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def is_global_master(args): 6 | return args.rank == 0 7 | 8 | 9 | def is_local_master(args): 10 | return args.local_rank == 0 11 | 12 | 13 | def is_master(args, local=False): 14 | return is_local_master(args) if local else is_global_master(args) 15 | 16 | 17 | def is_using_horovod(): 18 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 19 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 20 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 21 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 22 | if all([var in os.environ for var in ompi_vars]) or all( 23 | [var in os.environ for var in pmi_vars] 24 | ): 25 | return True 26 | else: 27 | return False 28 | 29 | 30 | def is_using_distributed(): 31 | if "WORLD_SIZE" in os.environ: 32 | return int(os.environ["WORLD_SIZE"]) > 1 33 | if "SLURM_NTASKS" in os.environ: 34 | return int(os.environ["SLURM_NTASKS"]) > 1 35 | return False 36 | 37 | 38 | def world_info_from_env(): 39 | local_rank = 0 40 | for v in ( 41 | "LOCAL_RANK", 42 | "MPI_LOCALRANKID", 43 | "SLURM_LOCALID", 44 | "OMPI_COMM_WORLD_LOCAL_RANK", 45 | ): 46 | if v in os.environ: 47 | local_rank = int(os.environ[v]) 48 | break 49 | global_rank = 0 50 | for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"): 51 | if v in os.environ: 52 | global_rank = int(os.environ[v]) 53 | break 54 | world_size = 1 55 | for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"): 56 | if v in os.environ: 57 | world_size = int(os.environ[v]) 58 | break 59 | 60 | return local_rank, global_rank, world_size 61 | 62 | from torch import distributed as torch_dist 63 | import subprocess 64 | def _init_dist_slurm(backend, port=None) -> None: 65 | """Initialize slurm distributed training environment. 66 | 67 | If argument ``port`` is not specified, then the master port will be system 68 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 69 | environment variable, then a default port ``29500`` will be used. 70 | 71 | Args: 72 | backend (str): Backend of torch.distributed. 73 | port (int, optional): Master port. Defaults to None. 
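    Note: the master address is taken from MASTER_ADDR if it is already set; otherwise it is resolved from the first hostname in SLURM_NODELIST via scontrol.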
74 | """ 75 | proc_id = int(os.environ['SLURM_PROCID']) 76 | ntasks = int(os.environ['SLURM_NTASKS']) 77 | node_list = os.environ['SLURM_NODELIST'] 78 | # Not sure when this environment variable could be None, so use a fallback 79 | local_rank_env = os.environ.get('SLURM_LOCALID', None) 80 | if local_rank_env is not None: 81 | local_rank = int(local_rank_env) 82 | else: 83 | num_gpus = torch.cuda.device_count() 84 | local_rank = proc_id % num_gpus 85 | torch.cuda.set_device(local_rank) 86 | addr = subprocess.getoutput( 87 | f'scontrol show hostname {node_list} | head -n1') 88 | # specify master port 89 | if port is not None: 90 | os.environ['MASTER_PORT'] = str(port) 91 | elif 'MASTER_PORT' in os.environ: 92 | pass # use MASTER_PORT in the environment variable 93 | else: 94 | # 29500 is torch.distributed default port 95 | os.environ['MASTER_PORT'] = '29500' 96 | # use MASTER_ADDR in the environment variable if it already exists 97 | if 'MASTER_ADDR' not in os.environ: 98 | os.environ['MASTER_ADDR'] = addr 99 | os.environ['WORLD_SIZE'] = str(ntasks) 100 | os.environ['LOCAL_RANK'] = str(local_rank) 101 | os.environ['RANK'] = str(proc_id) 102 | torch_dist.init_process_group(backend=backend) 103 | 104 | 105 | def init_distributed_device(args): 106 | # Distributed training = training on more than one GPU. 107 | # Works in both single and multi-node scenarios. 108 | args.distributed = False 109 | args.world_size = 1 110 | args.rank = 0 # global rank 111 | args.local_rank = 0 112 | # if args.horovod: 113 | # assert hvd is not None, "Horovod is not installed" 114 | # hvd.init() 115 | # args.local_rank = int(hvd.local_rank()) 116 | # args.rank = hvd.rank() 117 | # args.world_size = hvd.size() 118 | # args.distributed = True 119 | # os.environ["LOCAL_RANK"] = str(args.local_rank) 120 | # os.environ["RANK"] = str(args.rank) 121 | # os.environ["WORLD_SIZE"] = str(args.world_size) 122 | # elif is_using_distributed(): 123 | if is_using_distributed(): 124 | if "SLURM_PROCID" in os.environ: 125 | # DDP via SLURM 126 | args.local_rank, args.rank, args.world_size = world_info_from_env() 127 | # SLURM var -> torch.distributed vars in case needed 128 | os.environ["LOCAL_RANK"] = str(args.local_rank) 129 | os.environ["RANK"] = str(args.rank) 130 | os.environ["WORLD_SIZE"] = str(args.world_size) 131 | torch.distributed.init_process_group( 132 | backend=args.dist_backend, 133 | init_method=args.dist_url, 134 | world_size=args.world_size, 135 | rank=args.rank, 136 | ) 137 | else: 138 | # DDP via torchrun, torch.distributed.launch 139 | args.local_rank, _, _ = world_info_from_env() 140 | torch.distributed.init_process_group( 141 | backend=args.dist_backend, init_method=args.dist_url 142 | ) 143 | args.world_size = torch.distributed.get_world_size() 144 | args.rank = torch.distributed.get_rank() 145 | args.distributed = True 146 | else: 147 | DistSingleGPU = False 148 | if DistSingleGPU: 149 | # TODO in S2: "torchrun --nnodes=1 --nproc_per_node=1" Bug 150 | # DistSingleGPU = False 151 | os.environ['MASTER_ADDR'] = '127.0.0.1' 152 | from socket import socket 153 | with socket() as s: 154 | s.bind(('', 0)) 155 | free_port = str(s.getsockname()[1]) 156 | os.environ['MASTER_PORT'] = free_port 157 | # TODO 注释此处 for debug 158 | # needed to run on single gpu 159 | torch.distributed.init_process_group( 160 | backend=args.dist_backend, 161 | init_method=args.dist_url, 162 | world_size=1, 163 | rank=0, 164 | ) 165 | print('[INFO] single gpu run') 166 | # args.distributed = False 167 | else: 168 | args.distributed = False 169 
| print('[INFO] single gpu: Not distributed') 170 | 171 | if torch.cuda.is_available(): 172 | if args.distributed and not args.no_set_device_rank: 173 | device = "cuda:%d" % args.local_rank 174 | else: 175 | device = "cuda:0" 176 | torch.cuda.set_device(device) 177 | else: 178 | device = "cpu" 179 | args.device = device 180 | device = torch.device(device) 181 | if args.distributed: 182 | print('[INFO] distributed: True') 183 | return device 184 | -------------------------------------------------------------------------------- /tools/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaVi-Lab/NaviLLM/e069f46ed98affb221d58715a785613622e11145/tools/evaluation/__init__.py -------------------------------------------------------------------------------- /tools/evaluation/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | from .bleu import Bleu -------------------------------------------------------------------------------- /tools/evaluation/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from .bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | 23 | assert(gts.keys() == res.keys()) 24 | imgIds = gts.keys() 25 | 26 | bleu_scorer = BleuScorer(n=self._n) 27 | for id in imgIds: 28 | hypo = res[id] 29 | ref = gts[id] 30 | 31 | # Sanity check. 32 | assert(type(hypo) is list) 33 | assert(len(hypo) == 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | # score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0) 41 | # score, scores = bleu_scorer.compute_score(option='average', verbose=1) 42 | 43 | return score, scores 44 | 45 | def __str__(self): 46 | return 'BLEU' 47 | -------------------------------------------------------------------------------- /tools/evaluation/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | ''' Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | 24 | def precook(s, n=4, out=False): 25 | """Takes a string as input and returns an object that can be given to 26 | either cook_refs or cook_test. 
This is optional: cook_refs and cook_test 27 | can take string arguments as well.""" 28 | words = s.split() 29 | counts = defaultdict(int) 30 | for k in range(1, n + 1): 31 | for i in range(len(words) - k + 1): 32 | ngram = tuple(words[i:i + k]) 33 | counts[ngram] += 1 34 | return (len(words), counts) 35 | 36 | 37 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 38 | '''Takes a list of reference sentences for a single segment 39 | and returns an object that encapsulates everything that BLEU 40 | needs to know about them.''' 41 | 42 | reflen = [] 43 | maxcounts = {} 44 | for ref in refs: 45 | rl, counts = precook(ref, n) 46 | reflen.append(rl) 47 | for (ngram, count) in counts.items(): 48 | maxcounts[ngram] = max(maxcounts.get(ngram, 0), count) 49 | 50 | # Calculate effective reference sentence length. 51 | if eff == "shortest": 52 | reflen = min(reflen) 53 | elif eff == "average": 54 | reflen = float(sum(reflen)) / len(reflen) 55 | 56 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 57 | 58 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 59 | 60 | return (reflen, maxcounts) 61 | 62 | 63 | def cook_test(test, ref_tuple, eff=None, n=4): 64 | '''Takes a test sentence and returns an object that 65 | encapsulates everything that BLEU needs to know about it.''' 66 | 67 | testlen, counts = precook(test, n, True) 68 | reflen, refmaxcounts = ref_tuple 69 | 70 | result = {} 71 | 72 | # Calculate effective reference sentence length. 73 | 74 | if eff == "closest": 75 | result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1] 76 | else: ## i.e., "average" or "shortest" or None 77 | result["reflen"] = reflen 78 | 79 | result["testlen"] = testlen 80 | 81 | result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)] 82 | 83 | result['correct'] = [0] * n 84 | for (ngram, count) in counts.items(): 85 | result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count) 86 | 87 | return result 88 | 89 | 90 | class BleuScorer(object): 91 | """Bleu scorer. 92 | """ 93 | 94 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 95 | 96 | # special_reflen is used in oracle (proportional effective ref len for a node). 
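    # Illustrative usage, mirroring Bleu.compute_score above:
    #   scorer = BleuScorer(n=4)
    #   scorer += (hypothesis_str, [ref_str_1, ref_str_2])   # add one (hypothesis, references) pair
    #   score, scores = scorer.compute_score(option='closest')
    # where score is the corpus-level [BLEU-1, ..., BLEU-4] and scores holds the per-sentence values.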
97 | 98 | def copy(self): 99 | ''' copy the refs.''' 100 | new = BleuScorer(n=self.n) 101 | new.ctest = copy.copy(self.ctest) 102 | new.crefs = copy.copy(self.crefs) 103 | new._score = None 104 | return new 105 | 106 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 107 | ''' singular instance ''' 108 | 109 | self.n = n 110 | self.crefs = [] 111 | self.ctest = [] 112 | self.cook_append(test, refs) 113 | self.special_reflen = special_reflen 114 | 115 | def cook_append(self, test, refs): 116 | '''called by constructor and __iadd__ to avoid creating new instances.''' 117 | 118 | if refs is not None: 119 | self.crefs.append(cook_refs(refs)) 120 | if test is not None: 121 | cooked_test = cook_test(test, self.crefs[-1]) 122 | self.ctest.append(cooked_test) ## N.B.: -1 123 | else: 124 | self.ctest.append(None) # lens of crefs and ctest have to match 125 | 126 | self._score = None ## need to recompute 127 | 128 | def ratio(self, option=None): 129 | self.compute_score(option=option) 130 | return self._ratio 131 | 132 | def score_ratio(self, option=None): 133 | ''' 134 | return (bleu, len_ratio) pair 135 | ''' 136 | 137 | return self.fscore(option=option), self.ratio(option=option) 138 | 139 | def score_ratio_str(self, option=None): 140 | return "%.4f (%.2f)" % self.score_ratio(option) 141 | 142 | def reflen(self, option=None): 143 | self.compute_score(option=option) 144 | return self._reflen 145 | 146 | def testlen(self, option=None): 147 | self.compute_score(option=option) 148 | return self._testlen 149 | 150 | def retest(self, new_test): 151 | if type(new_test) is str: 152 | new_test = [new_test] 153 | assert len(new_test) == len(self.crefs), new_test 154 | self.ctest = [] 155 | for t, rs in zip(new_test, self.crefs): 156 | self.ctest.append(cook_test(t, rs)) 157 | self._score = None 158 | 159 | return self 160 | 161 | def rescore(self, new_test): 162 | ''' replace test(s) with new test(s), and returns the new score.''' 163 | 164 | return self.retest(new_test).compute_score() 165 | 166 | def size(self): 167 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 168 | return len(self.crefs) 169 | 170 | def __iadd__(self, other): 171 | '''add an instance (e.g., from another sentence).''' 172 | 173 | if type(other) is tuple: 174 | ## avoid creating new BleuScorer instances 175 | self.cook_append(other[0], other[1]) 176 | else: 177 | assert self.compatible(other), "incompatible BLEUs." 
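            # Merging another scorer: concatenate its cooked hypotheses and references, then drop the cached score so the next compute_score() re-runs over the combined set.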
178 | self.ctest.extend(other.ctest) 179 | self.crefs.extend(other.crefs) 180 | self._score = None ## need to recompute 181 | 182 | return self 183 | 184 | def compatible(self, other): 185 | return isinstance(other, BleuScorer) and self.n == other.n 186 | 187 | def single_reflen(self, option="average"): 188 | return self._single_reflen(self.crefs[0][0], option) 189 | 190 | def _single_reflen(self, reflens, option=None, testlen=None): 191 | 192 | if option == "shortest": 193 | reflen = min(reflens) 194 | elif option == "average": 195 | reflen = float(sum(reflens)) / len(reflens) 196 | elif option == "closest": 197 | reflen = min((abs(l - testlen), l) for l in reflens)[1] 198 | else: 199 | assert False, "unsupported reflen option %s" % option 200 | 201 | return reflen 202 | 203 | def recompute_score(self, option=None, verbose=0): 204 | self._score = None 205 | return self.compute_score(option, verbose) 206 | 207 | def compute_score(self, option=None, verbose=0): 208 | n = self.n 209 | small = 1e-9 210 | tiny = 1e-15 ## so that if guess is 0 still return 0 211 | bleu_list = [[] for _ in range(n)] 212 | 213 | if self._score is not None: 214 | return self._score 215 | 216 | if option is None: 217 | option = "average" if len(self.crefs) == 1 else "closest" 218 | 219 | self._testlen = 0 220 | self._reflen = 0 221 | totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n} 222 | 223 | # for each sentence 224 | for comps in self.ctest: 225 | testlen = comps['testlen'] 226 | self._testlen += testlen 227 | 228 | if self.special_reflen is None: ## need computation 229 | reflen = self._single_reflen(comps['reflen'], option, testlen) 230 | else: 231 | reflen = self.special_reflen 232 | 233 | self._reflen += reflen 234 | 235 | for key in ['guess', 'correct']: 236 | for k in range(n): 237 | totalcomps[key][k] += comps[key][k] 238 | 239 | # append per image bleu score 240 | bleu = 1. 241 | for k in range(n): 242 | bleu *= (float(comps['correct'][k]) + tiny) \ 243 | / (float(comps['guess'][k]) + small) 244 | bleu_list[k].append(bleu ** (1. / (k + 1))) 245 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 246 | if ratio < 1: 247 | for k in range(n): 248 | bleu_list[k][-1] *= math.exp(1 - 1 / ratio) 249 | 250 | if verbose > 1: 251 | print(comps, reflen) 252 | 253 | totalcomps['reflen'] = self._reflen 254 | totalcomps['testlen'] = self._testlen 255 | 256 | bleus = [] 257 | bleu = 1. 258 | for k in range(n): 259 | bleu *= float(totalcomps['correct'][k] + tiny) \ 260 | / (totalcomps['guess'][k] + small) 261 | bleus.append(bleu ** (1. 
/ (k + 1))) 262 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 263 | if ratio < 1: 264 | for k in range(n): 265 | bleus[k] *= math.exp(1 - 1 / ratio) 266 | 267 | if verbose > 0: 268 | print(totalcomps) 269 | print("ratio:", ratio) 270 | 271 | self._score = bleus 272 | return self._score, bleu_list 273 | -------------------------------------------------------------------------------- /tools/evaluation/cider/__init__.py: -------------------------------------------------------------------------------- 1 | from .cider import Cider -------------------------------------------------------------------------------- /tools/evaluation/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | from .cider_scorer import CiderScorer 11 | 12 | class Cider: 13 | """ 14 | Main Class to compute the CIDEr metric 15 | 16 | """ 17 | def __init__(self, gts=None, n=4, sigma=6.0): 18 | # set cider to sum over 1 to 4-grams 19 | self._n = n 20 | # set the standard deviation parameter for gaussian penalty 21 | self._sigma = sigma 22 | self.doc_frequency = None 23 | self.ref_len = None 24 | if gts is not None: 25 | tmp_cider = CiderScorer(gts, n=self._n, sigma=self._sigma) 26 | self.doc_frequency = tmp_cider.doc_frequency 27 | self.ref_len = tmp_cider.ref_len 28 | 29 | def compute_score(self, gts, res): 30 | """ 31 | Main function to compute CIDEr score 32 | :param gts (dict) : dictionary with key and value 33 | res (dict) : dictionary with key and value 34 | :return: cider (float) : computed CIDEr score for the corpus 35 | """ 36 | assert(gts.keys() == res.keys()) 37 | cider_scorer = CiderScorer(gts, test=res, n=self._n, sigma=self._sigma, doc_frequency=self.doc_frequency, 38 | ref_len=self.ref_len) 39 | return cider_scorer.compute_score() 40 | 41 | def __str__(self): 42 | return 'CIDEr' 43 | -------------------------------------------------------------------------------- /tools/evaluation/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import math 9 | 10 | def precook(s, n=4): 11 | """ 12 | Takes a string as input and returns an object that can be given to 13 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 14 | can take string arguments as well. 15 | :param s: string : sentence to be converted into ngrams 16 | :param n: int : number of ngrams for which representation is calculated 17 | :return: term frequency vector for occuring ngrams 18 | """ 19 | words = s.split() 20 | counts = defaultdict(int) 21 | for k in range(1,n+1): 22 | for i in range(len(words)-k+1): 23 | ngram = tuple(words[i:i+k]) 24 | counts[ngram] += 1 25 | return counts 26 | 27 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 28 | '''Takes a list of reference sentences for a single segment 29 | and returns an object that encapsulates everything that BLEU 30 | needs to know about them. 
31 | :param refs: list of string : reference sentences for some image 32 | :param n: int : number of ngrams for which (ngram) representation is calculated 33 | :return: result (list of dict) 34 | ''' 35 | return [precook(ref, n) for ref in refs] 36 | 37 | def cook_test(test, n=4): 38 | '''Takes a test sentence and returns an object that 39 | encapsulates everything that BLEU needs to know about it. 40 | :param test: list of string : hypothesis sentence for some image 41 | :param n: int : number of ngrams for which (ngram) representation is calculated 42 | :return: result (dict) 43 | ''' 44 | return precook(test, n) 45 | 46 | class CiderScorer(object): 47 | """CIDEr scorer. 48 | """ 49 | 50 | def __init__(self, refs, test=None, n=4, sigma=6.0, doc_frequency=None, ref_len=None): 51 | ''' singular instance ''' 52 | self.n = n 53 | self.sigma = sigma 54 | self.crefs = [] 55 | self.ctest = [] 56 | self.doc_frequency = defaultdict(float) 57 | self.ref_len = None 58 | 59 | for k in refs.keys(): 60 | self.crefs.append(cook_refs(refs[k])) 61 | if test is not None: 62 | self.ctest.append(cook_test(test[k][0])) ## N.B.: -1 63 | else: 64 | self.ctest.append(None) # lens of crefs and ctest have to match 65 | 66 | if doc_frequency is None and ref_len is None: 67 | # compute idf 68 | self.compute_doc_freq() 69 | # compute log reference length 70 | self.ref_len = np.log(float(len(self.crefs))) 71 | else: 72 | self.doc_frequency = doc_frequency 73 | self.ref_len = ref_len 74 | 75 | def compute_doc_freq(self): 76 | ''' 77 | Compute term frequency for reference data. 78 | This will be used to compute idf (inverse document frequency later) 79 | The term frequency is stored in the object 80 | :return: None 81 | ''' 82 | for refs in self.crefs: 83 | # refs, k ref captions of one image 84 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 85 | self.doc_frequency[ngram] += 1 86 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 87 | 88 | def compute_cider(self): 89 | def counts2vec(cnts): 90 | """ 91 | Function maps counts of ngram to vector of tfidf weights. 92 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 93 | The n-th entry of array denotes length of n-grams. 94 | :param cnts: 95 | :return: vec (array of dict), norm (array of float), length (int) 96 | """ 97 | vec = [defaultdict(float) for _ in range(self.n)] 98 | length = 0 99 | norm = [0.0 for _ in range(self.n)] 100 | for (ngram,term_freq) in cnts.items(): 101 | # give word count 1 if it doesn't appear in reference corpus 102 | df = np.log(max(1.0, self.doc_frequency[ngram])) 103 | # ngram index 104 | n = len(ngram)-1 105 | # tf (term_freq) * idf (precomputed idf) for n-grams 106 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 107 | # compute norm for the vector. the norm will be used for computing similarity 108 | norm[n] += pow(vec[n][ngram], 2) 109 | 110 | if n == 1: 111 | length += term_freq 112 | norm = [np.sqrt(n) for n in norm] 113 | return vec, norm, length 114 | 115 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 116 | ''' 117 | Compute the cosine similarity of two vectors. 
118 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 119 | :param vec_ref: array of dictionary for vector corresponding to reference 120 | :param norm_hyp: array of float for vector corresponding to hypothesis 121 | :param norm_ref: array of float for vector corresponding to reference 122 | :param length_hyp: int containing length of hypothesis 123 | :param length_ref: int containing length of reference 124 | :return: array of score for each n-grams cosine similarity 125 | ''' 126 | delta = float(length_hyp - length_ref) 127 | # measure consine similarity 128 | val = np.array([0.0 for _ in range(self.n)]) 129 | for n in range(self.n): 130 | # ngram 131 | for (ngram,count) in vec_hyp[n].items(): 132 | # vrama91 : added clipping 133 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 134 | 135 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 136 | val[n] /= (norm_hyp[n]*norm_ref[n]) 137 | 138 | assert(not math.isnan(val[n])) 139 | # vrama91: added a length based gaussian penalty 140 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 141 | return val 142 | 143 | scores = [] 144 | for test, refs in zip(self.ctest, self.crefs): 145 | # compute vector for test captions 146 | vec, norm, length = counts2vec(test) 147 | # compute vector for ref captions 148 | score = np.array([0.0 for _ in range(self.n)]) 149 | for ref in refs: 150 | vec_ref, norm_ref, length_ref = counts2vec(ref) 151 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 152 | # change by vrama91 - mean of ngram scores, instead of sum 153 | score_avg = np.mean(score) 154 | # divide by number of references 155 | score_avg /= len(refs) 156 | # multiply score by 10 157 | score_avg *= 10.0 158 | # append score of an image to the score list 159 | scores.append(score_avg) 160 | return scores 161 | 162 | def compute_score(self): 163 | # compute cider score 164 | score = self.compute_cider() 165 | # debug 166 | # print score 167 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /tools/evaluation/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | from .meteor import Meteor -------------------------------------------------------------------------------- /tools/evaluation/meteor/data/paraphrase-en.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaVi-Lab/NaviLLM/e069f46ed98affb221d58715a785613622e11145/tools/evaluation/meteor/data/paraphrase-en.gz -------------------------------------------------------------------------------- /tools/evaluation/meteor/meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaVi-Lab/NaviLLM/e069f46ed98affb221d58715a785613622e11145/tools/evaluation/meteor/meteor-1.5.jar -------------------------------------------------------------------------------- /tools/evaluation/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | # Python wrapper for METEOR implementation, by Xinlei Chen 2 | # Acknowledge Michael Denkowski for the generous discussion and help 3 | 4 | import os 5 | import subprocess 6 | import threading 7 | import tarfile 8 | import requests 9 | 10 | def download_from_url(url, path): 11 | """Download file, with logic (from tensor2tensor) for Google Drive""" 12 | if 'drive.google.com' not in url: 13 | print('Downloading %s; 
may take a few minutes' % url) 14 | r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}) 15 | with open(path, "wb") as file: 16 | file.write(r.content) 17 | return 18 | print('Downloading from Google Drive; may take a few minutes') 19 | confirm_token = None 20 | session = requests.Session() 21 | response = session.get(url, stream=True) 22 | for k, v in response.cookies.items(): 23 | if k.startswith("download_warning"): 24 | confirm_token = v 25 | 26 | if confirm_token: 27 | url = url + "&confirm=" + confirm_token 28 | response = session.get(url, stream=True) 29 | 30 | chunk_size = 16 * 1024 31 | with open(path, "wb") as f: 32 | for chunk in response.iter_content(chunk_size): 33 | if chunk: 34 | f.write(chunk) 35 | 36 | 37 | METEOR_GZ_URL = 'http://aimagelab.ing.unimore.it/speaksee/data/meteor.tgz' 38 | METEOR_JAR = 'meteor-1.5.jar' 39 | 40 | class Meteor: 41 | def __init__(self): 42 | base_path = os.path.dirname(os.path.abspath(__file__)) 43 | jar_path = os.path.join(base_path, METEOR_JAR) 44 | gz_path = os.path.join(base_path, os.path.basename(METEOR_GZ_URL)) 45 | if not os.path.isfile(jar_path): 46 | if not os.path.isfile(gz_path): 47 | download_from_url(METEOR_GZ_URL, gz_path) 48 | tar = tarfile.open(gz_path, "r") 49 | tar.extractall(path=os.path.dirname(os.path.abspath(__file__))) 50 | tar.close() 51 | os.remove(gz_path) 52 | 53 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 54 | '-', '-', '-stdio', '-l', 'en', '-norm'] 55 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 56 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 57 | stdin=subprocess.PIPE, \ 58 | stdout=subprocess.PIPE, \ 59 | stderr=subprocess.PIPE) 60 | # Used to guarantee thread safety 61 | self.lock = threading.Lock() 62 | 63 | def compute_score(self, gts, res): 64 | assert(gts.keys() == res.keys()) 65 | imgIds = gts.keys() 66 | scores = [] 67 | 68 | eval_line = 'EVAL' 69 | self.lock.acquire() 70 | for i in imgIds: 71 | assert(len(res[i]) == 1) 72 | stat = self._stat(res[i][0], gts[i]) 73 | eval_line += ' ||| {}'.format(stat) 74 | 75 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode()) 76 | self.meteor_p.stdin.flush() 77 | for i in range(0,len(imgIds)): 78 | scores.append(float(self.meteor_p.stdout.readline().strip())) 79 | score = float(self.meteor_p.stdout.readline().strip()) 80 | self.lock.release() 81 | 82 | return score, scores 83 | 84 | def _stat(self, hypothesis_str, reference_list): 85 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 86 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 87 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 88 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode()) 89 | self.meteor_p.stdin.flush() 90 | raw = self.meteor_p.stdout.readline().decode().strip() 91 | numbers = [str(int(float(n))) for n in raw.split()] 92 | return ' '.join(numbers) 93 | 94 | def __del__(self): 95 | self.lock.acquire() 96 | self.meteor_p.stdin.close() 97 | self.meteor_p.kill() 98 | self.meteor_p.wait() 99 | self.lock.release() 100 | 101 | def __str__(self): 102 | return 'METEOR' 103 | -------------------------------------------------------------------------------- /tools/evaluation/meteor/test_meteor.py: -------------------------------------------------------------------------------- 1 | from meteor import Meteor 2 | 3 | 4 | meteor_eval = Meteor() 5 | ref = {0: [u'a shoe rack with some shoes and a dog sleeping on them .', 6 | u'a small dog is curled up on top of the shoes .', 7 | 
u'various slides and other footwear rest in a metal basket outdoors .', 8 | u'a dog sleeping on a show rack in the shoes .', 9 | u'this wire metal rack holds several pairs of shoes and sandals .'], } 10 | hypo = {0: [u'a large white plate with a white plate with a white plate .'],} 11 | result = meteor_eval.compute_score(ref, hypo) 12 | print(result) -------------------------------------------------------------------------------- /tools/evaluation/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | from .rouge import Rouge -------------------------------------------------------------------------------- /tools/evaluation/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | 14 | def my_lcs(string, sub): 15 | """ 16 | Calculates longest common subsequence for a pair of tokenized strings 17 | :param string : list of str : tokens from a string split using whitespace 18 | :param sub : list of str : shorter string, also split using whitespace 19 | :returns: length (list of int): length of the longest common subsequence between the two strings 20 | 21 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 22 | """ 23 | if (len(string) < len(sub)): 24 | sub, string = string, sub 25 | 26 | lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)] 27 | 28 | for j in range(1, len(sub) + 1): 29 | for i in range(1, len(string) + 1): 30 | if (string[i - 1] == sub[j - 1]): 31 | lengths[i][j] = lengths[i - 1][j - 1] + 1 32 | else: 33 | lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1]) 34 | 35 | return lengths[len(string)][len(sub)] 36 | 37 | 38 | class Rouge(): 39 | ''' 40 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 41 | 42 | ''' 43 | 44 | def __init__(self): 45 | # vrama91: updated the value below based on discussion with Hovey 46 | self.beta = 1.2 47 | 48 | def calc_score(self, candidate, refs): 49 | """ 50 | Compute ROUGE-L score given one candidate and references for an image 51 | :param candidate: str : candidate sentence to be evaluated 52 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 53 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 54 | """ 55 | assert (len(candidate) == 1) 56 | assert (len(refs) > 0) 57 | prec = [] 58 | rec = [] 59 | 60 | # split into tokens 61 | token_c = candidate[0].split(" ") 62 | 63 | for reference in refs: 64 | # split into tokens 65 | token_r = reference.split(" ") 66 | # compute the longest common subsequence 67 | lcs = my_lcs(token_r, token_c) 68 | prec.append(lcs / float(len(token_c))) 69 | rec.append(lcs / float(len(token_r))) 70 | 71 | prec_max = max(prec) 72 | rec_max = max(rec) 73 | 74 | if (prec_max != 0 and rec_max != 0): 75 | score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max) 76 | else: 77 | score = 0.0 78 | return score 79 | 80 | def compute_score(self, gts, res): 81 | """ 82 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 83 | Invoked by evaluate_captions.py 84 | :param hypo_for_image: dict : candidate / 
test sentences with "image name" key and "tokenized sentences" as values 85 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 86 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 87 | """ 88 | assert (gts.keys() == res.keys()) 89 | imgIds = gts.keys() 90 | 91 | score = [] 92 | for id in imgIds: 93 | hypo = res[id] 94 | ref = gts[id] 95 | 96 | score.append(self.calc_score(hypo, ref)) 97 | 98 | # Sanity check. 99 | assert (type(hypo) is list) 100 | assert (len(hypo) == 1) 101 | assert (type(ref) is list) 102 | assert (len(ref) > 0) 103 | 104 | average_score = np.mean(np.array(score)) 105 | return average_score, np.array(score) 106 | 107 | def __str__(self): 108 | return 'ROUGE' 109 | -------------------------------------------------------------------------------- /tools/evaluation/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaVi-Lab/NaviLLM/e069f46ed98affb221d58715a785613622e11145/tools/evaluation/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /tools/evaluation/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import subprocess 13 | import tempfile 14 | 15 | class PTBTokenizer(object): 16 | """Python wrapper of Stanford PTBTokenizer""" 17 | 18 | corenlp_jar = 'stanford-corenlp-3.4.1.jar' 19 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 20 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 21 | 22 | @classmethod 23 | def tokenize(cls, corpus): 24 | cmd = ['java', '-cp', cls.corenlp_jar, \ 25 | 'edu.stanford.nlp.process.PTBTokenizer', \ 26 | '-preserveLines', '-lowerCase'] 27 | 28 | if isinstance(corpus, list) or isinstance(corpus, tuple): 29 | if isinstance(corpus[0], list) or isinstance(corpus[0], tuple): 30 | corpus = {i:c for i, c in enumerate(corpus)} 31 | else: 32 | corpus = {i: [c, ] for i, c in enumerate(corpus)} 33 | 34 | # prepare data for PTB Tokenizer 35 | tokenized_corpus = {} 36 | image_id = [k for k, v in list(corpus.items()) for _ in range(len(v))] 37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in corpus.items() for c in v]) 38 | 39 | # save sentences to temporary file 40 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 41 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 42 | tmp_file.write(sentences.encode()) 43 | tmp_file.close() 44 | 45 | # tokenize sentence 46 | cmd.append(os.path.basename(tmp_file.name)) 47 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 48 | stdout=subprocess.PIPE, stderr=open(os.devnull, 'w')) 49 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 50 | token_lines = token_lines.decode() 51 | lines = token_lines.split('\n') 52 | # remove temp file 53 | os.remove(tmp_file.name) 54 | 55 | # create dictionary for tokenized captions 56 | for k, line in zip(image_id, lines): 57 | if not k in tokenized_corpus: 58 | tokenized_corpus[k] = [] 59 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 60 | if w not in 
cls.punctuations]) 61 | tokenized_corpus[k].append(tokenized_caption) 62 | 63 | return tokenized_corpus -------------------------------------------------------------------------------- /tools/optims.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import glob 4 | from transformers import get_constant_schedule_with_warmup 5 | 6 | 7 | def check_checkpoint(args, model, optimizer, lr_scheduler, logger) -> int: 8 | resume_from_epoch = 0 9 | if args.resume_from_checkpoint is not None: 10 | if args.rank == 0: 11 | logger.info(f"Loading checkpoint from {args.resume_from_checkpoint}") 12 | checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu") 13 | model_state_dict = model.state_dict() 14 | state_disk = {k.replace('module.', ''): v for k, v in checkpoint['model_state_dict'].items()} 15 | update_model_state = {} 16 | for key, val in state_disk.items(): 17 | if key in model_state_dict and model_state_dict[key].shape == val.shape: 18 | update_model_state[key] = val 19 | else: 20 | logger.info( 21 | 'Ignore weight %s: %s' % (key, str(val.shape)) 22 | ) 23 | msg = model.load_state_dict(update_model_state, strict=False) 24 | logger.info(msg) 25 | 26 | if 'epoch' in checkpoint: 27 | resume_from_epoch = checkpoint['epoch'] + 1 28 | logger.info("Resume from Epoch {}".format(resume_from_epoch)) 29 | optimizer.load_state_dict(checkpoint['optimizer']) 30 | 31 | 32 | return resume_from_epoch 33 | 34 | 35 | def dist_models(args, model, logger): 36 | logger.info("*************** init model *************** ") 37 | # args.rank: global rank. 38 | total_gpus = torch.cuda.device_count() 39 | device_id = args.rank % total_gpus 40 | 41 | model.to(device_id) 42 | 43 | optimizer = torch.optim.AdamW([p for n, p in model.named_parameters() if p.requires_grad], lr=args.lr) 44 | 45 | lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.num_warmup_steps) 46 | 47 | resume_from_epoch = check_checkpoint( 48 | args, model, optimizer, lr_scheduler, logger, 49 | ) 50 | param_sums = sum(p.numel() for p in model.parameters() if p.requires_grad) 51 | logger.info("model initialized with {:.2f} M trainable parameters".format(param_sums/1000**2)) 52 | if args.distributed: 53 | from torch.nn.parallel import DistributedDataParallel as DDP 54 | model = DDP(model, device_ids=[device_id], find_unused_parameters=True) 55 | 56 | # args.batch_size: BATCH_SIZE_PER_GPU 57 | logger.info('Training in distributed mode : total_batch_size: %d' % (total_gpus * args.batch_size)) 58 | else: 59 | total_gpus = 1 60 | logger.info('Training with a single process') 61 | 62 | return model, optimizer, resume_from_epoch, lr_scheduler 63 | 64 | 65 | def save_checkpoint(model, model_path, optimizer=None, epoch: int=0, save_states: bool=False): 66 | if hasattr(model, 'module'): 67 | model = model.module 68 | 69 | state_dict = { 70 | "model_state_dict": model.state_dict() 71 | } 72 | if save_states: 73 | state_dict.update({ 74 | "optimizer": optimizer.state_dict(), 75 | "epoch": epoch, 76 | }) 77 | 78 | torch.save(state_dict, model_path) -------------------------------------------------------------------------------- /tools/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import numpy as np 4 | import torch 5 | import os 6 | import datetime 7 | import yaml 8 | from easydict import EasyDict 9 | from .distributed import world_info_from_env, init_distributed_device 10 | 
from .common_utils import create_logger, log_config_to_file 11 | from pathlib import Path 12 | 13 | 14 | def random_seed(seed=0, rank=0): 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | 21 | 22 | def read_args(): 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument('--data_dir', type=str, default='data', help="dataset root path") 26 | parser.add_argument('--cfg_file', type=str, default=None, help='dataset configs', required=True) 27 | parser.add_argument('--pretrained_model_name_or_path', default=None, type=str, required=True, help="path to tokenizer") 28 | 29 | # local fusion 30 | parser.add_argument('--off_batch_task', action='store_true', default=False, help="whether all process is training same task") 31 | parser.add_argument('--debug', action="store_true", help="debug mode") 32 | parser.add_argument('--seed', type=int, default=0) 33 | 34 | parser.add_argument("--num_epochs", type=int, default=30) 35 | parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="path to ckpt to resume from") 36 | parser.add_argument("--from_scratch", action="store_true") 37 | parser.add_argument("--batch_size", type=int, default=1) 38 | parser.add_argument("--val_batch_size", type=int, default=2) 39 | parser.add_argument("--lr", default=1e-5, type=float) 40 | parser.add_argument("--feat_dropout", type=float, default=0.4) 41 | parser.add_argument("--num_warmup_steps", type=int, default=0) 42 | parser.add_argument("--num_steps_per_epoch", type=int, default=-1) 43 | parser.add_argument("--gradient_accumulation_step", type=int, default=2) 44 | parser.add_argument( 45 | "--precision", 46 | choices=["amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"], 47 | default="fp32", 48 | help="Floating point precision.", 49 | ) 50 | parser.add_argument("--workers", type=int, default=0) 51 | 52 | # distributed training args 53 | parser.add_argument('--world_size', type=int, default=0, help='number of gpus') 54 | parser.add_argument('--local_rank', type=int, default=-1) 55 | parser.add_argument( 56 | "--dist-url", 57 | default="env://", 58 | type=str, 59 | help="url used to set up distributed training", 60 | ) 61 | parser.add_argument( 62 | "--dist-backend", default="nccl", type=str, help="distributed backend" 63 | ) 64 | parser.add_argument( 65 | "--horovod", 66 | default=False, 67 | action="store_true", 68 | help="Use horovod for distributed training.", 69 | ) 70 | parser.add_argument( 71 | "--no-set-device-rank", 72 | default=False, 73 | action="store_true", 74 | help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", 75 | ) 76 | 77 | # Save checkpoints 78 | parser.add_argument('--output_dir', type=str, default=None, required=True, help="output logs and ckpts") 79 | parser.add_argument("--max_saved_checkpoints", type=int, default=0) 80 | parser.add_argument("--save_ckpt_per_epochs", type=int, default=10) 81 | parser.add_argument("--save_latest_states", action='store_true') 82 | parser.add_argument("--save_pred_results", action="store_true") 83 | parser.add_argument("--save_detail_results", action="store_true") 84 | 85 | # training 86 | parser.add_argument('--mode', type=str, default="train", choices=["train", "test"]) 87 | parser.add_argument("--stage", type=str, required=True, choices=["pretrain", "multi"]) 88 | parser.add_argument('--ignoreid', default=-100, type=int, help="criterion: ignore label") 89 | 
parser.add_argument('--enable_og', action='store_true', default=False, help="object grounding task") 90 | parser.add_argument("--enable_summarize", action="store_true", help="perform EQA or generate instructions") 91 | parser.add_argument("--enable_fgr2r", action="store_true", help="perform fgr2r for R2R") 92 | parser.add_argument("--gen_loss_coef", type=float, default=1.) 93 | parser.add_argument("--obj_loss_coef", type=float, default=1.) 94 | parser.add_argument("--teacher_forcing_coef", type=float, default=1.) 95 | parser.add_argument("--fuse_obj", action="store_true", help="whether fuse object features for REVERIE and SOON") 96 | 97 | # datasets 98 | parser.add_argument("--multi_endpoints", type=int, default=1) 99 | parser.add_argument("--path_type", type=str, default="trusted_path", choices=["planner_path", "trusted_path"]) 100 | 101 | # evaluation 102 | parser.add_argument('--test_datasets', type=str, default=None, nargs='+') 103 | parser.add_argument('--validation_split', type=str, default="val_unseen", help="validation split: val_seen, val_unseen, test") 104 | parser.add_argument("--do_sample", action="store_true", help="do_sample in evaluation") 105 | parser.add_argument("--temperature", type=float, default=1.) 106 | 107 | 108 | # others 109 | parser.add_argument( 110 | "--max_datapoints", 111 | default=None, 112 | type=int, 113 | help="The number of datapoints used for debug." 114 | ) 115 | 116 | args = parser.parse_args() 117 | 118 | args.local_rank, args.rank, args.world_size = world_info_from_env() 119 | 120 | ###################### configurations ######################### 121 | # single-gpu or multi-gpu 122 | device_id = init_distributed_device(args) 123 | global_cfg = EasyDict(yaml.safe_load(open(str(Path(args.cfg_file).resolve())))) 124 | 125 | args.data_dir = Path(args.data_dir).resolve() 126 | 127 | # off-line image features from Matterport3D 128 | args.image_feat_size = global_cfg.Feature.image_feat_size 129 | args.obj_feat_size = global_cfg.Feature.obj_feat_size 130 | 131 | ############# Configurations ############### 132 | args.angle_feat_size = global_cfg.Feature.angle_feat_size 133 | args.enc_full_graph = global_cfg.Model.enc_full_graph 134 | args.expert_policy = global_cfg.Model.expert_policy 135 | args.num_pano_layers = global_cfg.Model.num_pano_layers 136 | 137 | os.makedirs(args.output_dir, exist_ok=True) 138 | log_file = Path(args.output_dir) / 'log.txt' 139 | 140 | logger = create_logger(log_file, rank=args.rank) 141 | logger.info('**********************Start logging**********************') 142 | gpu_list = os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ.keys() else 'ALL' 143 | logger.info('CUDA_VISIBLE_DEVICES=%s' % gpu_list) 144 | for key, val in vars(args).items(): 145 | logger.info('{:16} {}'.format(key, val)) 146 | log_config_to_file(global_cfg, logger=logger) 147 | 148 | print(" + rank: {}, + device_id: {}".format(args.local_rank, device_id)) 149 | print(f"Start running training on rank {args.rank}.") 150 | 151 | if os.path.exists(os.path.join(args.output_dir, "latest_states.pt")): 152 | state_path = os.path.join(args.output_dir, "latest_states.pt") 153 | logger.info("Resume checkpoint from {}".format(state_path)) 154 | args.resume_from_checkpoint = state_path 155 | 156 | return args, global_cfg, logger, device_id 157 | -------------------------------------------------------------------------------- /tools/trie.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The OFA-Sys 
Team. 2 | # All rights reserved. 3 | # This source code is licensed under the Apache 2.0 license 4 | # found in the LICENSE file in the root directory. 5 | 6 | from collections import defaultdict 7 | from typing import List 8 | 9 | 10 | class TreeNode(): 11 | def __init__(self): 12 | self.child = defaultdict(TreeNode) 13 | 14 | class Trie: 15 | 16 | def __init__(self, bos, eos): 17 | self.root = TreeNode() 18 | self.bos = bos 19 | self.eos = eos 20 | 21 | def insert(self, word: List[int]): 22 | cur = self.root 23 | for c in word: 24 | cur = cur.child[c] 25 | 26 | def get_child_index(self, cur: TreeNode) -> List[int]: 27 | if len(cur.child)==0: 28 | return [self.eos] 29 | return list(cur.child.keys()) 30 | 31 | def get_next_node(self, cur: TreeNode, w: int) -> TreeNode: 32 | if len(cur.child)==0: 33 | return cur 34 | return cur.child[w] --------------------------------------------------------------------------------
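
A minimal usage sketch for the Trie in tools/trie.py above. The token ids, the inserted sequences, and the BOS/EOS values here are made-up placeholders, not values taken from the repo: candidate phrases are inserted as token-id sequences, and during constrained decoding the tokens allowed at each step are the children of the current node, falling back to eos at a leaf so generation can terminate exactly at the end of an inserted phrase.

# illustrative sketch only; assumes the package layout shown above (tools/trie.py)
# and made-up token ids standing in for a real tokenizer's output
from tools.trie import Trie

BOS, EOS = 1, 2                          # assumed special token ids
trie = Trie(bos=BOS, eos=EOS)

# insert candidate phrases as token-id sequences
trie.insert([BOS, 101, 102, 103])
trie.insert([BOS, 101, 104])

# walk the trie alongside decoding, restricting each step to the allowed children
node = trie.root
for token in [BOS, 101]:
    allowed = trie.get_child_index(node)  # token ids permitted at this step
    assert token in allowed
    node = trie.get_next_node(node, token)

print(trie.get_child_index(node))         # -> [102, 104]

Returning [self.eos] at a leaf is what lets the decoder close a phrase: once an inserted sequence has been fully consumed, the only permitted continuation is the end-of-sequence token.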