├── .gitignore ├── LICENSE ├── README.md ├── assets └── ga.png ├── dataset ├── __init__.py ├── base_dataset.py ├── dataloader.py ├── dataset_train.py └── dataset_val.py ├── demo ├── demo.py └── run_demo.sh ├── models ├── __init__.py ├── configuration_llama.py ├── graph3dllm.py ├── helpers.py ├── load_llama.py ├── modeling_llama.py ├── moe │ ├── __init__.py │ ├── layer.py │ └── moe_lora.py ├── position_embedding.py └── transformer_vanilla │ ├── __init__.py │ ├── mhsa.py │ ├── self_attention.py │ └── transformer_block.py ├── others ├── analyze_grounding_results.py ├── calc_scanrefer_grounding_acc.py ├── eval_offline.py ├── extract_target_noun.py ├── gpt_generate.py ├── ground_visualize.py ├── llama_tmp.py ├── modify.py ├── prepare_anno_stage1.py ├── prepare_captions_noun.py ├── prepare_describe.py ├── prepare_eval_json.py ├── prepare_identifier_rich.py ├── prepare_multi3dref.py ├── prepare_obj_align_data.py ├── prepare_obj_caption.py ├── prepare_ref_captions.py ├── prepare_referit_anno_stage1.py ├── prepare_scanqa.py ├── prepare_scanrefer_grounding_train.py ├── prepare_scene_align_data.py ├── prepare_scene_level_dataset.py ├── prepare_sqa3d.py ├── prepare_train_stage3.py ├── process_annos.py ├── process_preds.py ├── process_vil3dref_multichoice.py ├── process_vil3dref_results.py ├── run_chat.sh ├── run_eval.sh ├── run_generate.sh ├── tmp.py ├── tmp2.py ├── visualize.py └── viz.py ├── preprocess ├── README.md ├── prepare_filtered_mask3d_gnn_data.py ├── prepare_gnn_data.py ├── prepare_mask3d_data.py ├── prepare_mask3d_img_feat.py ├── prepare_multi3dref_annos.py ├── prepare_multi3dref_location_annos.py ├── prepare_nr3d_annos.py ├── prepare_nr3dcaption_annos.py ├── prepare_objalign_annos.py ├── prepare_scan2cap_annos.py ├── prepare_scan2cap_location_annos.py ├── prepare_scannet_attributes.py ├── prepare_scannet_attributes_clasp.py ├── prepare_scannet_caption_annos.py ├── prepare_scannet_mask3d_attributes.py ├── prepare_scannet_region_caption_annos.py ├── prepare_scanqa_annos.py ├── prepare_scanrefer_annos.py ├── prepare_scanrefer_location_annos.py ├── prepare_sqa3d_annos.py ├── prepare_sr3d_annos.py ├── process_scannet_data.py └── run_prepare.sh ├── prompts ├── concise_description.txt ├── concise_description_objxx.txt ├── conv_description.txt ├── dataset_generation │ ├── conversation.txt │ ├── detail.txt │ └── textualize_obj.txt ├── detailed_description.txt ├── grounding_answer_templates.txt ├── grounding_prompts.txt ├── instruction.txt ├── nr3d_caption_templates.txt ├── object_caption_templates.txt ├── prompts.py ├── scanrefer_caption_templates.txt ├── scene_align_template.txt ├── score_template.txt ├── score_template_old.txt ├── system.txt └── system_backup.txt ├── requirements.txt ├── scripts ├── config-gt-pretrain.py ├── config.py ├── config_3dgraphllm.py ├── run.sh └── run_gt_pretrain.sh ├── tasks ├── shared_utils.py └── train.py └── utils ├── __init__.py ├── basic_utils.py ├── box_utils.py ├── config.py ├── config_utils.py ├── distributed.py ├── easydict.py ├── eval.py ├── eval_tmp.py ├── helper.py ├── logger.py ├── optimizer.py ├── pc_util.py └── scheduler.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | 
share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | outputs/ 164 | annotations/ 165 | wandb/ 166 | vicuna-7b-v1.5/ 167 | utils/capeval/ 168 | annotations 169 | Meta-Llama-3-8B-Instruct 170 | *.pth -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Chat-Scene 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 3DGraphLLM 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2412.18450-b31b1b.svg)](https://arxiv.org/abs/2412.18450) 4 | [![Huggingace](https://img.shields.io/badge/Weights-3DGraphLLM-blue?logo=HuggingFace)](https://huggingface.co/wingrune/3DGraphLLM) 5 | 6 | In this work, we propose 3DGraphLLM, a method for constructing a learnable representation of a 3D scene graph, which serves as input for LLMs to perform 3D vision-language tasks. 7 | 8 |

9 | ![3DGraphLLM graphical abstract](assets/ga.png)
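As a rough illustration of the sentence above (a sketch only: the tensor names, shapes, and interleaving order are simplifications for exposition, not the repository's actual code), each object node contributes its own embedding plus the embeddings of its k nearest neighbours and of the semantic relations to them, and the flattened result is what the LLM consumes:

```python
import torch

def flatten_scene_graph(obj_feats, edge_feats, knn_ids):
    """Illustrative sketch. obj_feats: (N, D) object embeddings; edge_feats: (N, K, D)
    embeddings of the semantic relation from object i to its k-th nearest neighbour;
    knn_ids: (N, K) indices of those neighbours."""
    tokens = []
    for i in range(obj_feats.shape[0]):
        tokens.append(obj_feats[i])                      # the object itself
        for k in range(knn_ids.shape[1]):
            tokens.append(edge_feats[i, k])              # relation to the k-th neighbour
            tokens.append(obj_feats[knn_ids[i, k]])      # the neighbour object
    return torch.stack(tokens)                           # (N * (1 + 2K), D) token sequence for the LLM
```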

11 | 12 | ## News 13 | 14 | [2024.12] We release 3DGraphLLM pre-training on GT instance segmentation scene graphs 15 | 16 | [2024.12] We release 3DGraphLLM [paper](https://arxiv.org/abs/2412.18450) [code](https://github.com/CognitiveAISystems/3DGraphLLM) 17 | 18 | ### 🔥 Semantic relations boost LLM performance on 3D Referred Object Grounding and Dense Scene Captioning tasks 19 | 20 | 21 | 22 | | | [ScanRefer](https://github.com/daveredrum/ScanRefer) | | [Multi3dRefer](https://github.com/3dlg-hcvc/M3DRef-CLIP) | | [Scan2Cap](https://github.com/daveredrum/Scan2Cap) | | [ScanQA](https://github.com/ATR-DBI/ScanQA) | | [SQA3D](https://github.com/SilongYong/SQA3D) | 23 | |:----: |:---------: |:-------: |:------: |:------: |:---------: |:----------: |:------------: |:------: |:-----: | 24 | | | Acc@0.25 | Acc@0.5 | F1@0.25 | F1@0.5 | CIDEr@0.5 | B-4@0.5 | CIDEr | B-4 | EM | 25 | | [Chat-Scene](https://github.com/ZzZZCHS/Chat-Scene/tree/dev) | 55.5 | 50.2 | 57.1 | 52.3 | 77.1 | 36.3 | **87.7** | **14.3** | 54.6 | 26 | | 3DGraphLLM Vicuna-1.5 | 57.0 | 51.3 | 60.1 | 55.4 | 81.2 | 36.3 | 87.6 | 12.1 | 53.1 | 27 | **3DGraphLLM LLAMA3-8B** | **60.2** | **54.6** | **63.0** | **58.2** | **82.9** | **37.8** | 83.1 | 12.5 | **55.2** | 28 | 29 | 30 | ## 🔨 Preparation 31 | 32 | - Prepare the environment: 33 | 34 | ```shell 35 | conda create -n 3dgraphllm python=3.9.17 36 | conda activate 3dgraphllm 37 | conda install pytorch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 pytorch-cuda=11.8 -c pytorch -c nvidia 38 | pip install -r requirements.txt 39 | ``` 40 | - If you don't have root permissions to install java (needed for pycocoeval scripts for metrics such as BLEU and CIDER), install it with conda: 41 | 42 | ``` 43 | conda install -c conda-forge openjdk 44 | ``` 45 | 46 | 47 | - Download LLM backbone: 48 | - We use LLAMA3-8B-Instruct in our experiments, which can be downloaded from [Hugging Face](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct). 49 | 50 | - Change the `llama_model_path` in [config.py](./scripts/config.py) to the path of `LLAMA3-8B-Instruct`. 51 | 52 | 53 | - Annotations and extracted features: 54 | 55 | Please follow the instructions in [preprocess](preprocess/). 56 | 57 | 58 | ## 🤖 Training and Inference 59 | 60 | - Pre-training on GT instance segmentation scene graphs. 61 | - Modify [run_gt_pretrain.sh](scripts/run_gt_pretrain.sh): 62 | ```python 63 | train_tag="scanrefer#scan2cap#scanqa#sqa3d#multi3dref#nr3d_caption#obj_align" 64 | val_tag="scanrefer#scan2cap#scanqa#sqa3d#multi3dref" 65 | evaluate=False 66 | ``` 67 | 68 |
69 | Explanation of "train_tag" and "val_tag" 70 | 71 | - Use `#` to separate different datasets (see the parsing sketch below) 72 | 73 | - Datasets: 74 | - `scanrefer`: [ScanRefer](https://github.com/daveredrum/ScanRefer) Dataset 75 | - `scan2cap`: [Scan2Cap](https://github.com/daveredrum/Scan2Cap) Dataset 76 | - `scanqa`: [ScanQA](https://github.com/ATR-DBI/ScanQA) Dataset 77 | - `sqa3d`: [SQA3D](https://github.com/SilongYong/SQA3D) Dataset 78 | - `multi3dref`: [Multi3dRefer](https://github.com/3dlg-hcvc/M3DRef-CLIP) Dataset 79 | - `nr3d_caption`: A captioning dataset derived from [Nr3D](https://github.com/referit3d/referit3d). 80 | - `obj_align`: A dataset derived from ScanRefer that aligns object identifiers with object tokens. 81 | 82 |
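As referenced in the list above, a minimal sketch of how the tags are consumed, following `create_dataset` in `dataset/__init__.py` (the real code also raises `NotImplementedError` for unknown tags; the literal values here are placeholders):

```python
# Sketch of the tag handling in dataset/__init__.py; the dict normally comes from the
# training config (config.train_file_dict / config.val_file_dict).
train_tag = "scanrefer#scan2cap#scanqa"
train_file_dict = {"scanrefer": "...", "scan2cap": "...", "scanqa": "..."}
train_files = [train_file_dict[name] for name in train_tag.split("#")]  # one annotation entry per dataset
```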
83 | - Run: `bash scripts/run_gt_pretrain.sh` 84 | 85 | - Training 86 | - Modify [run.sh](scripts/run.sh): 87 | ```python 88 | train_tag="scanrefer#scan2cap#scanqa#sqa3d#multi3dref#nr3d_caption#obj_align" 89 | val_tag="scanrefer#scan2cap#scanqa#sqa3d#multi3dref" 90 | evaluate=False 91 | pretrained_path="outputs/llama3-8b-gt-pretrain-2/ckpt_00_28927.pth" 92 | ``` 93 | - Run: `bash scripts/run.sh` 94 | 95 | 96 | - Inference 97 | 98 | - Modify [run.sh](scripts/run.sh): 99 | 100 | ```python 101 | val_tag="scanrefer#scan2cap#scanqa#sqa3d#multi3dref" 102 | evaluate=True 103 | pretrained_path="/path/to/pretrained_model.pth" 104 | ``` 105 | 106 | - Run: `bash scripts/run.sh` 107 | 108 | ## 🚀 Demo 109 | 110 | - Run: `bash demo/run_demo.sh`. You will be prompted to ask different queries about Scene 435 of ScanNet. 111 | 112 | 113 | ## 📪 Contact 114 | 115 | If you have any questions about the project, please open an issue in this repository or send an email to [Tatiana Zemskova](zemskova@airi.net). 116 | 117 | ## 📑 Citation 118 | 119 | If you find this work helpful, please consider citing our work as: 120 | 121 | ``` 122 | @misc{zemskova20243dgraphllm, 123 | title={3DGraphLLM: Combining Semantic Graphs and Large Language Models for 3D Scene Understanding}, 124 | author={Tatiana Zemskova and Dmitry Yudin}, 125 | year={2024}, 126 | eprint={2412.18450}, 127 | archivePrefix={arXiv}, 128 | primaryClass={cs.CV}, 129 | url={https://arxiv.org/abs/2412.18450}, 130 | } 131 | ``` 132 | 133 | 134 | ## 😊 Acknowledgement 135 | 136 | Thanks to the open source of the following projects: 137 | 138 | [Chat-Scene](https://github.com/ZzZZCHS/Chat-Scene/tree/dev) 139 | -------------------------------------------------------------------------------- /assets/ga.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CognitiveAISystems/3DGraphLLM/418f6529e029761a52017d0eb2a5380689805c62/assets/ga.png -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import ConcatDataset, DataLoader 3 | from torchvision import transforms 4 | from torchvision.transforms import InterpolationMode 5 | 6 | from dataset.dataloader import MetaLoader 7 | from dataset.dataset_train import TrainDataset 8 | from dataset.dataset_val import ValDataset 9 | 10 | import logging 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def create_dataset(config): 15 | if config.evaluate: 16 | train_datasets = [] 17 | else: 18 | train_files = [] 19 | for train_name in config.train_tag.split('#'): 20 | if train_name not in config.train_file_dict: 21 | raise NotImplementedError 22 | train_files.append(config.train_file_dict[train_name]) 23 | 24 | train_datasets = [] 25 | datasets = [] 26 | for train_file in train_files: 27 | datasets.append(TrainDataset(ann_list=train_file, config=config)) 28 | dataset = ConcatDataset(datasets) 29 | train_datasets.append(dataset) 30 | 31 | val_files = {} 32 | for val_name in config.val_tag.split('#'): 33 | if val_name not in config.val_file_dict: 34 | raise NotImplementedError 35 | val_files[val_name] = config.val_file_dict[val_name] 36 | 37 | val_datasets = [] 38 | for k, v in val_files.items(): 39 | datasets = [] 40 | if type(v[0]) != list: 41 | v = [v] 42 | for val_file in v: 43 | datasets.append(ValDataset(ann_list=val_file, dataset_name=k, config=config)) 44 | dataset = 
ConcatDataset(datasets) 45 | val_datasets.append(dataset) 46 | 47 | return train_datasets, val_datasets 48 | 49 | 50 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 51 | samplers = [] 52 | for dataset, shuffle in zip(datasets, shuffles): 53 | sampler = torch.utils.data.DistributedSampler( 54 | dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle 55 | ) 56 | samplers.append(sampler) 57 | return samplers 58 | 59 | 60 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 61 | loaders = [] 62 | for dataset, sampler, bs, n_worker, is_train, collate_fn in zip( 63 | datasets, samplers, batch_size, num_workers, is_trains, collate_fns 64 | ): 65 | if is_train: 66 | shuffle = sampler is None 67 | drop_last = True 68 | else: 69 | shuffle = False 70 | drop_last = False 71 | loader = DataLoader( 72 | dataset, 73 | batch_size=bs, 74 | num_workers=n_worker, 75 | pin_memory=False, 76 | sampler=sampler, 77 | shuffle=shuffle, 78 | collate_fn=collate_fn, 79 | drop_last=drop_last, 80 | persistent_workers=True if n_worker > 0 else False, 81 | ) 82 | loaders.append(loader) 83 | return loaders 84 | 85 | 86 | def iterate_dataloaders(dataloaders): 87 | """Alternatively generate data from multiple dataloaders, 88 | since we use `zip` to concat multiple dataloaders, 89 | the loop will end when the smaller dataloader runs out. 90 | 91 | Args: 92 | dataloaders List(DataLoader): can be a single or multiple dataloaders 93 | """ 94 | for data_tuples in zip(*dataloaders): 95 | for idx, data in enumerate(data_tuples): 96 | yield dataloaders[idx].dataset.media_type, data 97 | -------------------------------------------------------------------------------- /dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from utils.distributed import get_rank, is_dist_avail_and_initialized, is_main_process 4 | import random 5 | import logging 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class MetaLoader(object): 11 | """ wraps multiple data loader """ 12 | def __init__(self, name2loader): 13 | """Iterates over multiple dataloaders, it ensures all processes 14 | work on data from the same dataloader. This loader will end when 15 | the shorter dataloader raises StopIteration exception. 
16 | 17 | loaders: Dict, {name: dataloader} 18 | """ 19 | self.name2loader = name2loader 20 | self.name2iter = {name: iter(l) for name, l in name2loader.items()} 21 | name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())} 22 | index2name = {v: k for k, v in name2index.items()} 23 | 24 | iter_order = [] 25 | for n, l in name2loader.items(): 26 | iter_order.extend([name2index[n]]*len(l)) 27 | 28 | random.shuffle(iter_order) 29 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8) 30 | 31 | # sync 32 | if is_dist_avail_and_initialized(): 33 | # make sure all processes have the same order so that 34 | # each step they will have data from the same loader 35 | dist.broadcast(iter_order, src=0) 36 | self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()] 37 | 38 | logger.info(str(self)) 39 | 40 | def __str__(self): 41 | output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"] 42 | for idx, (name, loader) in enumerate(self.name2loader.items()): 43 | output.append( 44 | f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} " 45 | ) 46 | return "\n".join(output) 47 | 48 | def __len__(self): 49 | return len(self.iter_order) 50 | 51 | def __iter__(self): 52 | """ this iterator will run indefinitely """ 53 | for name in self.iter_order: 54 | _iter = self.name2iter[name] 55 | batch = next(_iter) 56 | yield name, batch 57 | -------------------------------------------------------------------------------- /dataset/dataset_train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from dataset.base_dataset import BaseDataset, update_caption 9 | import glob 10 | import random 11 | from prompts.prompts import obj_caption_wid_prompt 12 | from torch.nn.utils.rnn import pad_sequence 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | 18 | class TrainDataset(BaseDataset): 19 | 20 | cached_feats = {} 21 | 22 | def __init__(self, ann_list, config, **kwargs): 23 | super().__init__() 24 | self.feat_dim = config.model.input_dim 25 | self.img_feat_dim = config.model.img_input_dim 26 | self.max_obj_num = config.model.max_obj_num 27 | self.knn = config.model.knn 28 | self.point_cloud_type = ann_list[5] 29 | 30 | feat_file, img_feat_file, attribute_file, anno_file, feats_gnn_file = ann_list[:5] 31 | self.attributes = torch.load(attribute_file, map_location='cpu') if attribute_file is not None else None 32 | self.anno = json.load(open(anno_file, 'r')) 33 | 34 | 35 | if len(ann_list) > 6: 36 | sample_ratio = ann_list[-1] 37 | if sample_ratio < 1: 38 | self.anno = random.sample(self.anno, int(sample_ratio * len(self.anno))) 39 | 40 | if feat_file in TrainDataset.cached_feats and img_feat_file in TrainDataset.cached_feats and feats_gnn_file in TrainDataset.cached_feats: 41 | self.scene_feats, self.scene_masks = TrainDataset.cached_feats[feat_file] 42 | self.scene_img_feats = TrainDataset.cached_feats[img_feat_file] 43 | (self.scene_gnn_feats, self.scene_foreground_ids) = TrainDataset.cached_feats[feats_gnn_file] 44 | else: 45 | if feat_file is not None and os.path.exists(feat_file): 46 | self.feats = torch.load(feat_file, map_location='cpu') 47 | else: 48 | self.feats = None 49 | if img_feat_file is not None and os.path.exists(img_feat_file): 50 | self.img_feats = torch.load(img_feat_file, map_location='cpu') 51 | else: 52 | 
self.img_feats = None 53 | if feats_gnn_file is not None and os.path.exists(feats_gnn_file): 54 | self.feats_edge = torch.load(feats_gnn_file, map_location='cpu') 55 | else: 56 | self.feats_edge = None 57 | 58 | if self.attributes is None: 59 | self.scene_feats = self.feats 60 | self.scene_img_feats = self.scene_masks = None 61 | else: 62 | self.scene_feats, self.scene_img_feats, self.scene_masks, self.scene_gnn_feats, self.scene_foreground_ids = self.prepare_scene_features() 63 | TrainDataset.cached_feats[feat_file] = (self.scene_feats, self.scene_masks) 64 | TrainDataset.cached_feats[img_feat_file] = self.scene_img_feats 65 | TrainDataset.cached_feats[feats_gnn_file] = (self.scene_gnn_feats, self.scene_foreground_ids) 66 | 67 | 68 | def __len__(self): 69 | return len(self.anno) 70 | 71 | def __getitem__(self, index): 72 | if self.attributes is not None and self.anno[index]['scene_id'] not in self.attributes: 73 | # print(f"{self.anno[index]['scene_id']} not in attribute file!") 74 | return self.__getitem__(random.randint(0, len(self.anno)-1)) 75 | if "obj_id" in self.anno[index]: 76 | obj_id = int(self.anno[index]["obj_id"]) 77 | else: 78 | obj_id = random.randint(0, self.max_obj_num - 1) 79 | if 'prompt' not in self.anno[index]: 80 | question = random.choice(obj_caption_wid_prompt).replace('', f"") 81 | else: 82 | question = self.anno[index]["prompt"] 83 | caption = self.anno[index]["caption"] 84 | scene_id, scene_feat, scene_img_feat, scene_mask, scene_locs, assigned_ids, scene_gnn_feats, scene_foreground_ids = self.get_anno(index) 85 | caption = update_caption(caption, assigned_ids) 86 | question = update_caption(question, assigned_ids) 87 | return scene_feat, scene_img_feat, scene_mask, scene_locs, obj_id, assigned_ids, scene_gnn_feats, scene_foreground_ids, caption, question 88 | 89 | 90 | def train_collate_fn(batch): 91 | scene_feats, scene_img_feats, scene_masks, scene_locs, obj_ids, assigned_ids, scene_gnn_feats, scene_foreground_ids, captions, questions = zip(*batch) 92 | batch_scene_feat = pad_sequence(scene_feats, batch_first=True) 93 | batch_scene_img_feat = pad_sequence(scene_img_feats, batch_first=True) 94 | batch_scene_mask = pad_sequence(scene_masks, batch_first=True).to(torch.bool) 95 | batch_scene_locs = pad_sequence(scene_locs, batch_first=True) 96 | batch_assigned_ids = pad_sequence(assigned_ids, batch_first=True) 97 | batch_scene_gnn_feats = pad_sequence(scene_gnn_feats, batch_first=True) 98 | # batch_detach_mask = torch.ones_like(batch_scene_mask, dtype=torch.bool) 99 | # for i in range(batch_detach_mask.shape[0]): 100 | # batch_detach_mask[i][:detach_masks[i].shape[0]] = detach_masks[i] 101 | obj_ids = torch.tensor(obj_ids) 102 | foreground_ids = scene_foreground_ids 103 | return { 104 | "scene_feat": batch_scene_feat, 105 | "scene_img_feat": batch_scene_img_feat, 106 | "scene_locs": batch_scene_locs, 107 | "scene_mask": batch_scene_mask, 108 | "assigned_ids": batch_assigned_ids, 109 | "scene_gnn_feats": batch_scene_gnn_feats, 110 | "foreground_ids": foreground_ids, 111 | # "detach_mask": batch_detach_mask, 112 | "obj_ids": obj_ids, 113 | "answers": captions, 114 | "questions": questions 115 | # "ref_captions": ref_captions, 116 | # "ids": index 117 | } 118 | -------------------------------------------------------------------------------- /dataset/dataset_val.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from 
dataset.base_dataset import BaseDataset, update_caption 9 | import glob 10 | import random 11 | from prompts.prompts import obj_caption_wid_prompt 12 | from torch.nn.utils.rnn import pad_sequence 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class ValDataset(BaseDataset): 18 | 19 | cached_feats = {} 20 | 21 | def __init__(self, ann_list, dataset_name, config, **kwargs): 22 | super().__init__() 23 | self.dataset_name = dataset_name 24 | self.feat_dim = config.model.input_dim 25 | self.img_feat_dim = config.model.img_input_dim 26 | self.max_obj_num = config.model.max_obj_num 27 | self.knn = config.model.knn 28 | self.point_cloud_type = ann_list[5] 29 | 30 | feat_file, img_feat_file, attribute_file, anno_file, feats_gnn_file = ann_list[:5] 31 | self.attributes = torch.load(attribute_file, map_location='cpu') if attribute_file is not None else None 32 | self.anno = json.load(open(anno_file, 'r')) 33 | 34 | if feat_file in ValDataset.cached_feats and img_feat_file in ValDataset.cached_feats and feats_gnn_file in ValDataset.cached_feats: 35 | self.scene_feats, self.scene_masks = ValDataset.cached_feats[feat_file] 36 | self.scene_img_feats = ValDataset.cached_feats[img_feat_file] 37 | (self.scene_gnn_feats, self.scene_foreground_ids) = ValDataset.cached_feats[feats_gnn_file] 38 | else: 39 | if feat_file is not None and os.path.exists(feat_file): 40 | self.feats = torch.load(feat_file, map_location='cpu') 41 | else: 42 | self.feats = None 43 | if img_feat_file is not None and os.path.exists(img_feat_file): 44 | self.img_feats = torch.load(img_feat_file, map_location='cpu') 45 | else: 46 | self.img_feats = None 47 | if feats_gnn_file is not None and os.path.exists(feats_gnn_file): 48 | self.feats_edge = torch.load(feats_gnn_file, map_location='cpu') 49 | else: 50 | self.feats_edge = None 51 | 52 | if self.attributes is None: 53 | self.scene_feats = self.feats 54 | self.scene_img_feats = self.scene_masks = None 55 | else: 56 | self.scene_feats, self.scene_img_feats, self.scene_masks, self.scene_gnn_feats, self.scene_foreground_ids = self.prepare_scene_features() 57 | ValDataset.cached_feats[feat_file] = (self.scene_feats, self.scene_masks) 58 | ValDataset.cached_feats[img_feat_file] = self.scene_img_feats 59 | ValDataset.cached_feats[feats_gnn_file] = (self.scene_gnn_feats, self.scene_foreground_ids) 60 | 61 | def __len__(self): 62 | return len(self.anno) 63 | 64 | def __getitem__(self, index): 65 | scene_id, scene_feat, scene_img_feat, scene_mask, scene_locs, assigned_ids, scene_gnn_feats, scene_foreground_ids = self.get_anno(index) 66 | obj_id = int(self.anno[index].get('obj_id', 0)) 67 | pred_id = int(self.anno[index].get('pred_id', 0)) 68 | type_info = int(self.anno[index].get('sqa_type', 0)) 69 | if 'sqa_type' in self.anno[index]: 70 | type_info = self.anno[index]['sqa_type'] 71 | elif 'eval_type' in self.anno[index]: 72 | type_info = self.anno[index]['eval_type'] 73 | elif 'type_info' in self.anno[index]: 74 | type_info = self.anno[index]['type_info'] 75 | if 'prompt' not in self.anno[index]: 76 | prompt = random.choice(obj_caption_wid_prompt).replace('', f"") 77 | else: 78 | prompt = self.anno[index]["prompt"] 79 | ref_captions = self.anno[index]["ref_captions"].copy() if "ref_captions" in self.anno[index] else [] 80 | qid = self.anno[index]["qid"] if "qid" in self.anno[index] else 0 81 | return scene_feat, scene_img_feat, scene_mask, scene_locs, obj_id, assigned_ids, scene_gnn_feats, scene_foreground_ids, prompt, ref_captions, scene_id, qid, pred_id, type_info 82 | 83 | 84 | 
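# val_collate_fn (below) batches the tuples returned by ValDataset.__getitem__:
# variable-length tensors (scene/object features, image features, masks, locations,
# assigned ids, GNN edge features) are padded with pad_sequence(batch_first=True),
# obj_ids / pred_ids are collected into 1-D tensors, and the remaining per-sample
# fields (foreground ids, prompts, reference captions, scene ids, qids, type infos)
# are passed through unchanged.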
def val_collate_fn(batch): 85 | scene_feats, scene_img_feats, scene_masks, scene_locs, obj_ids, assigned_ids, scene_gnn_feats, scene_foreground_ids, prompts, ref_captions, scene_ids, qids, pred_ids, type_infos = zip(*batch) 86 | batch_scene_feat = pad_sequence(scene_feats, batch_first=True) 87 | batch_scene_img_feat = pad_sequence(scene_img_feats, batch_first=True) 88 | batch_scene_mask = pad_sequence(scene_masks, batch_first=True).to(torch.bool) 89 | batch_scene_locs = pad_sequence(scene_locs, batch_first=True) 90 | batch_assigned_ids = pad_sequence(assigned_ids, batch_first=True) 91 | batch_scene_gnn_feats = pad_sequence(scene_gnn_feats, batch_first=True) 92 | obj_ids = torch.tensor(obj_ids) 93 | pred_ids = torch.tensor(pred_ids) 94 | foreground_ids = scene_foreground_ids 95 | 96 | 97 | return { 98 | "scene_feat": batch_scene_feat, 99 | "scene_img_feat": batch_scene_img_feat, 100 | "scene_locs": batch_scene_locs, 101 | "scene_mask": batch_scene_mask, 102 | "assigned_ids": batch_assigned_ids, 103 | "scene_gnn_feats": batch_scene_gnn_feats, 104 | "foreground_ids": foreground_ids, 105 | "obj_ids": obj_ids, 106 | "custom_prompt": prompts, 107 | "ref_captions": ref_captions, 108 | "scene_id": scene_ids, 109 | "qid": qids, 110 | "pred_ids": pred_ids, 111 | "type_infos": type_infos 112 | # "ids": index 113 | } 114 | 115 | -------------------------------------------------------------------------------- /demo/run_demo.sh: -------------------------------------------------------------------------------- 1 | which_python=$(which python) 2 | export PYTHONPATH=${PYTHONPATH}:${which_python}:. 3 | echo "PYTHONPATH: ${PYTHONPATH}" 4 | 5 | python demo/demo.py scripts/config_3dgraphllm.py -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CognitiveAISystems/3DGraphLLM/418f6529e029761a52017d0eb2a5380689805c62/models/__init__.py -------------------------------------------------------------------------------- /models/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
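# Shared building blocks used by the model code: BatchNormDim1Swap (BatchNorm1d applied
# to (HW, N, C) transformer-style tensors), the NORM_DICT / ACTIVATION_DICT /
# WEIGHT_INIT_DICT lookup tables, a configurable GenericMLP, and get_clones for
# duplicating a module N times.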
2 | import torch.nn as nn 3 | from functools import partial 4 | import copy 5 | 6 | 7 | class BatchNormDim1Swap(nn.BatchNorm1d): 8 | """ 9 | Used for nn.Transformer that uses a HW x N x C rep 10 | """ 11 | 12 | def forward(self, x): 13 | """ 14 | x: HW x N x C 15 | permute to N x C x HW 16 | Apply BN on C 17 | permute back 18 | """ 19 | hw, n, c = x.shape 20 | x = x.permute(1, 2, 0) 21 | x = super(BatchNormDim1Swap, self).forward(x) 22 | # x: n x c x hw -> hw x n x c 23 | x = x.permute(2, 0, 1) 24 | return x 25 | 26 | 27 | NORM_DICT = { 28 | "bn": BatchNormDim1Swap, 29 | "bn1d": nn.BatchNorm1d, 30 | "id": nn.Identity, 31 | "ln": nn.LayerNorm, 32 | } 33 | 34 | ACTIVATION_DICT = { 35 | "relu": nn.ReLU, 36 | "gelu": nn.GELU, 37 | "silu": nn.SiLU, 38 | "leakyrelu": partial(nn.LeakyReLU, negative_slope=0.1), 39 | } 40 | 41 | WEIGHT_INIT_DICT = { 42 | "xavier_uniform": nn.init.xavier_uniform_, 43 | } 44 | 45 | 46 | class GenericMLP(nn.Module): 47 | def __init__( 48 | self, 49 | input_dim, 50 | hidden_dims, 51 | output_dim, 52 | norm_fn_name=None, 53 | activation="silu", 54 | use_conv=False, 55 | dropout=None, 56 | hidden_use_bias=False, 57 | output_use_bias=True, 58 | output_use_activation=False, 59 | output_use_norm=False, 60 | weight_init_name=None, 61 | weight_init_std=0.02 62 | ): 63 | super().__init__() 64 | activation = ACTIVATION_DICT[activation] 65 | norm = None 66 | if norm_fn_name is not None: 67 | norm = NORM_DICT[norm_fn_name] 68 | if norm_fn_name == "ln" and use_conv: 69 | norm = lambda x: nn.GroupNorm(1, x) # easier way to use LayerNorm 70 | 71 | if dropout is not None: 72 | if not isinstance(dropout, list): 73 | dropout = [dropout for _ in range(len(hidden_dims))] 74 | 75 | layers = [] 76 | prev_dim = input_dim 77 | for idx, x in enumerate(hidden_dims): 78 | if use_conv: 79 | layer = nn.Conv1d(prev_dim, x, 1, bias=hidden_use_bias) 80 | else: 81 | layer = nn.Linear(prev_dim, x, bias=hidden_use_bias) 82 | layers.append(layer) 83 | if norm: 84 | layers.append(norm(x)) 85 | layers.append(activation()) 86 | if dropout is not None: 87 | layers.append(nn.Dropout(p=dropout[idx])) 88 | prev_dim = x 89 | if use_conv: 90 | layer = nn.Conv1d(prev_dim, output_dim, 1, bias=output_use_bias) 91 | else: 92 | layer = nn.Linear(prev_dim, output_dim, bias=output_use_bias) 93 | layers.append(layer) 94 | 95 | if output_use_norm: 96 | layers.append(norm(output_dim)) 97 | 98 | if output_use_activation: 99 | layers.append(activation()) 100 | 101 | self.layers = nn.Sequential(*layers) 102 | # self.weight_init_std = weight_init_std 103 | # self.apply(self._init_weights) 104 | 105 | def _init_weights(self, module): 106 | std = self.weight_init_std 107 | if isinstance(module, nn.Linear): 108 | module.weight.data.normal_(mean=0.0, std=std) 109 | if module.bias is not None: 110 | module.bias.data.zero_() 111 | 112 | # if weight_init_name is not None: 113 | # self.do_weight_init(weight_init_name) 114 | # 115 | # def do_weight_init(self, weight_init_name): 116 | # func = WEIGHT_INIT_DICT[weight_init_name] 117 | # for (_, param) in self.named_parameters(): 118 | # if param.dim() > 1: # skips batchnorm/layernorm 119 | # func(param) 120 | 121 | def forward(self, x): 122 | output = self.layers(x) 123 | return output 124 | 125 | 126 | def get_clones(module, N): 127 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 128 | -------------------------------------------------------------------------------- /models/moe/__init__.py: 
-------------------------------------------------------------------------------- 1 | import enum 2 | import peft 3 | from peft import PEFT_TYPE_TO_CONFIG_MAPPING 4 | from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING 5 | 6 | 7 | # register MoE LoRA 8 | class PeftType(str, enum.Enum): 9 | PROMPT_TUNING = "PROMPT_TUNING" 10 | P_TUNING = "P_TUNING" 11 | PREFIX_TUNING = "PREFIX_TUNING" 12 | LORA = "LORA" 13 | ADALORA = "ADALORA" 14 | ADAPTION_PROMPT = "ADAPTION_PROMPT" 15 | IA3 = "IA3" 16 | MOE_LORA = 'MOE_LORA' 17 | 18 | peft.PeftType = PeftType 19 | 20 | from .moe_lora import MoeLoraConfig, MoeLoraModel 21 | PEFT_TYPE_TO_CONFIG_MAPPING[peft.PeftType.MOE_LORA] = MoeLoraConfig 22 | PEFT_TYPE_TO_MODEL_MAPPING[peft.PeftType.MOE_LORA] = MoeLoraModel 23 | 24 | 25 | __all__ = [ 26 | 'MoeLoraConfig', 27 | ] 28 | -------------------------------------------------------------------------------- /models/moe/moe_lora.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from peft import LoraConfig, LoraModel, PeftType 3 | from peft.utils import _get_submodules 4 | from peft.tuners.lora import Embedding as LoraEmbedding 5 | import re 6 | import torch.nn as nn 7 | import warnings 8 | from typing import Optional 9 | 10 | from .layer import MoeLoraLayer, MoeLinear 11 | 12 | 13 | @dataclass 14 | class MoeLoraConfig(LoraConfig): 15 | 16 | num_experts: int = field( 17 | default=16, 18 | metadata={'help': 'number of experts in MoE Lora Layer'}) 19 | 20 | gate_mode: str = field( 21 | default='top2_gate', 22 | metadata={'help': 'choice: [top2_gate, dual_gate]'} 23 | ) 24 | 25 | def __post_init__(self): 26 | self.peft_type = PeftType.MOE_LORA 27 | 28 | 29 | class MoeLoraModel(LoraModel): 30 | 31 | def __init__(self, model, config, adapter_name): 32 | super().__init__(model, config, adapter_name) 33 | 34 | def _find_and_replace(self, adapter_name): 35 | lora_config = self.peft_config[adapter_name] 36 | loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) 37 | if loaded_in_8bit: 38 | raise NotImplementedError 39 | 40 | is_target_modules_in_base_model = False 41 | kwargs = { 42 | "r": lora_config.r, 43 | "num_experts": lora_config.num_experts, 44 | "gate_mode": lora_config.gate_mode, 45 | "lora_alpha": lora_config.lora_alpha, 46 | "lora_dropout": lora_config.lora_dropout, 47 | "fan_in_fan_out": lora_config.fan_in_fan_out, 48 | "init_lora_weights": lora_config.init_lora_weights, 49 | } 50 | 51 | key_list = [key for key, _ in self.model.named_modules()] 52 | for key in key_list: 53 | if isinstance(lora_config.target_modules, str): 54 | target_module_found = re.fullmatch(lora_config.target_modules, key) 55 | else: 56 | target_module_found = any( 57 | key.endswith(target_key) for target_key in lora_config.target_modules) 58 | 59 | if target_module_found: 60 | if not is_target_modules_in_base_model: 61 | is_target_modules_in_base_model = True 62 | parent, target, target_name = _get_submodules(self.model, key) 63 | 64 | if hasattr(target, "bias"): 65 | bias = target.bias is not None 66 | else: 67 | bias = False 68 | 69 | if isinstance(target, MoeLoraLayer) and isinstance(target, nn.Linear): 70 | target.update_moe_layer( 71 | adapter_name, 72 | lora_config.r, 73 | lora_config.num_experts, 74 | lora_config.lora_alpha, 75 | lora_config.lora_dropout, 76 | lora_config.init_lora_weights) 77 | 78 | elif isinstance(target, nn.Linear): 79 | in_features, out_features = target.in_features, target.out_features 80 | if 
kwargs["fan_in_fan_out"]: 81 | warnings.warn( 82 | "fan_in_fan_out is set to True but the target module is " 83 | "`torch.nn.Linear`. Setting fan_in_fan_out to False." 84 | ) 85 | kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False 86 | new_module = MoeLinear( 87 | adapter_name, in_features, out_features, bias=bias, **kwargs) 88 | 89 | elif isinstance(target, nn.Conv1D): 90 | in_features, out_features = target.weight.ds_shape \ 91 | if hasattr(target.weight, "ds_shape") else target.weight.shape 92 | if not kwargs["fan_in_fan_out"]: 93 | warnings.warn( 94 | "fan_in_fan_out is set to False but the target module is " 95 | "`torch.nn.Conv1D`. Setting fan_in_fan_out to True." 96 | ) 97 | kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True 98 | new_module = MoeLinear( 99 | adapter_name, in_features, out_features, bias=bias, **kwargs) 100 | 101 | else: 102 | raise RuntimeError( 103 | f"Target module {target} is not supported. Currently, only " 104 | f"``torch.nn.Linear`, torch.nn.Conv1D` and `torch.nn.Embedding` " 105 | f"are supported.") 106 | 107 | self._replace_module(parent, target_name, new_module, target) 108 | 109 | if not is_target_modules_in_base_model: 110 | raise ValueError( 111 | f"Target modules {lora_config.target_modules} not found in the base model. " 112 | f"Please check the target modules and try again." 113 | ) 114 | -------------------------------------------------------------------------------- /models/position_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | import numpy as np 9 | from utils.pc_util import shift_scale_points 10 | 11 | 12 | class PositionalEmbedding(nn.Module): 13 | def __init__(self, sigma=1, dim=4096): 14 | super().__init__() 15 | self.sigma = sigma 16 | self.dim = dim // 2 17 | self.w = torch.randn((self.dim, 3)) * sigma 18 | self.w = nn.Parameter(self.w, requires_grad=True) 19 | 20 | def forward(self, x): 21 | bs, obj_num, _ = x.shape 22 | x = x.reshape(-1, 3) 23 | v = torch.cat([torch.sin(self.w.detach() @ x.T), torch.cos(self.w.detach() @ x.T)]) 24 | v = v.T.reshape(bs, obj_num, -1) 25 | v_norm = v / v.norm(dim=-1).unsqueeze(-1) 26 | return v_norm 27 | 28 | 29 | class PositionEmbeddingCoordsSine(nn.Module): 30 | def __init__( 31 | self, 32 | temperature=10000, 33 | normalize=True, 34 | scale=None, 35 | pos_type="fourier", 36 | d_pos=None, 37 | d_in=3, 38 | gauss_scale=1.0, 39 | ): 40 | super().__init__() 41 | self.temperature = temperature 42 | self.normalize = normalize 43 | if scale is not None and normalize is False: 44 | raise ValueError("normalize should be True if scale is passed") 45 | if scale is None: 46 | scale = 2 * math.pi 47 | assert pos_type in ["sine", "fourier"] 48 | self.pos_type = pos_type 49 | self.scale = scale 50 | if pos_type == "fourier": 51 | assert d_pos is not None 52 | assert d_pos % 2 == 0 53 | # define a gaussian matrix input_ch -> output_ch 54 | B = torch.empty((d_in, d_pos // 2)).normal_() 55 | B *= gauss_scale 56 | self.register_buffer("gauss_B", B) 57 | self.d_pos = d_pos 58 | 59 | def get_sine_embeddings(self, xyz, num_channels, input_range): 60 | # clone coords so that shift/scale operations do not affect original tensor 61 | orig_xyz = xyz 62 | xyz = orig_xyz.clone() 63 | 64 | ncoords = xyz.shape[1] 65 | if self.normalize: 66 | xyz = 
shift_scale_points(xyz, src_range=input_range) 67 | 68 | ndim = num_channels // xyz.shape[2] 69 | if ndim % 2 != 0: 70 | ndim -= 1 71 | # automatically handle remainder by assiging it to the first dim 72 | rems = num_channels - (ndim * xyz.shape[2]) 73 | 74 | assert ( 75 | ndim % 2 == 0 76 | ), f"Cannot handle odd sized ndim={ndim} where num_channels={num_channels} and xyz={xyz.shape}" 77 | 78 | final_embeds = [] 79 | prev_dim = 0 80 | 81 | for d in range(xyz.shape[2]): 82 | cdim = ndim 83 | if rems > 0: 84 | # add remainder in increments of two to maintain even size 85 | cdim += 2 86 | rems -= 2 87 | 88 | if cdim != prev_dim: 89 | dim_t = torch.arange(cdim, dtype=torch.float32, device=xyz.device) 90 | dim_t = self.temperature ** (2 * (dim_t // 2) / cdim) 91 | 92 | # create batch x cdim x nccords embedding 93 | raw_pos = xyz[:, :, d] 94 | if self.scale: 95 | raw_pos *= self.scale 96 | pos = raw_pos[:, :, None] / dim_t 97 | pos = torch.stack( 98 | (pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=3 99 | ).flatten(2) 100 | final_embeds.append(pos) 101 | prev_dim = cdim 102 | 103 | final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1) 104 | return final_embeds 105 | 106 | def get_fourier_embeddings(self, xyz, num_channels=None, input_range=None): 107 | # Follows - https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html 108 | 109 | if num_channels is None: 110 | num_channels = self.gauss_B.shape[1] * 2 111 | 112 | bsize, npoints = xyz.shape[0], xyz.shape[1] 113 | assert num_channels > 0 and num_channels % 2 == 0 114 | d_in, max_d_out = self.gauss_B.shape[0], self.gauss_B.shape[1] 115 | d_out = num_channels // 2 116 | assert d_out <= max_d_out 117 | assert d_in == xyz.shape[-1] 118 | 119 | # clone coords so that shift/scale operations do not affect original tensor 120 | orig_xyz = xyz 121 | xyz = orig_xyz.clone() 122 | 123 | ncoords = xyz.shape[1] 124 | if self.normalize: 125 | xyz = shift_scale_points(xyz, src_range=input_range) 126 | 127 | xyz *= 2 * np.pi 128 | xyz_proj = torch.mm(xyz.view(-1, d_in), self.gauss_B[:, :d_out]).view( 129 | bsize, npoints, d_out 130 | ) 131 | final_embeds = [xyz_proj.sin(), xyz_proj.cos()] 132 | 133 | # return batch x d_pos x npoints embedding 134 | final_embeds = torch.cat(final_embeds, dim=2) 135 | return final_embeds 136 | 137 | def forward(self, xyz, num_channels=None, input_range=None): 138 | assert isinstance(xyz, torch.Tensor) 139 | assert xyz.ndim == 3 140 | # xyz is batch x npoints x 3 141 | if self.pos_type == "sine": 142 | with torch.no_grad(): 143 | return self.get_sine_embeddings(xyz, num_channels, input_range) 144 | elif self.pos_type == "fourier": 145 | with torch.no_grad(): 146 | return self.get_fourier_embeddings(xyz, num_channels, input_range) 147 | else: 148 | raise ValueError(f"Unknown {self.pos_type}") 149 | 150 | def extra_repr(self): 151 | st = f"type={self.pos_type}, scale={self.scale}, normalize={self.normalize}" 152 | if hasattr(self, "gauss_B"): 153 | st += ( 154 | f", gaussB={self.gauss_B.shape}, gaussBsum={self.gauss_B.sum().item()}" 155 | ) 156 | return st 157 | -------------------------------------------------------------------------------- /models/transformer_vanilla/__init__.py: -------------------------------------------------------------------------------- 1 | from .mhsa import MultiHeadSelfAttention 2 | from .self_attention import SelfAttention 3 | from .transformer_block import TransformerEncoder, CMT 4 | -------------------------------------------------------------------------------- 
/models/transformer_vanilla/mhsa.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import rearrange 4 | from torch import nn 5 | 6 | 7 | def compute_mhsa(q, k, v, scale_factor=1, mask=None): 8 | # resulted shape will be: [batch, heads, tokens, tokens] 9 | scaled_dot_prod = torch.einsum('... i d , ... j d -> ... i j', q, k) * scale_factor 10 | 11 | if mask is not None: 12 | breakpoint() 13 | # assert mask.shape == scaled_dot_prod.shape[2:] 14 | scaled_dot_prod = scaled_dot_prod.masked_fill(mask, -np.inf) 15 | 16 | attention = torch.softmax(scaled_dot_prod, dim=-1) 17 | # calc result per head 18 | return torch.einsum('... i j , ... j d -> ... i d', attention, v) 19 | 20 | 21 | class MultiHeadSelfAttention(nn.Module): 22 | def __init__(self, dim, heads=8, dim_head=None): 23 | """ 24 | Implementation of multi-head attention layer of the original transformer model. 25 | einsum and einops.rearrange is used whenever possible 26 | Args: 27 | dim: token's dimension, i.e. word embedding vector size 28 | heads: the number of distinct representations to learn 29 | dim_head: the dim of the head. In general dim_head k b h t d ', k=3, h=self.heads)) 47 | 48 | out = compute_mhsa(q, k, v, mask=mask, scale_factor=self.scale_factor) 49 | 50 | # re-compose: merge heads with dim_head 51 | out = rearrange(out, "b h t d -> b t (h d)") 52 | # Apply final linear transformation layer 53 | return self.W_0(out) 54 | -------------------------------------------------------------------------------- /models/transformer_vanilla/self_attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import rearrange 4 | from torch import nn 5 | 6 | 7 | class SelfAttention(nn.Module): 8 | """ 9 | Implementation of plain self attention mechanism with einsum operations 10 | Paper: https://arxiv.org/abs/1706.03762 11 | Blog: https://theaisummer.com/transformer/ 12 | """ 13 | 14 | def __init__(self, dim): 15 | """ 16 | Args: 17 | dim: for NLP it is the dimension of the embedding vector 18 | the last dimension size that will be provided in forward(x) 19 | where x is a 3D tensor 20 | """ 21 | super().__init__() 22 | self.to_qvk = nn.Linear(dim, dim * 3, bias=False) 23 | self.scale_factor = dim ** -0.5 # 1/np.sqrt(dim) 24 | 25 | def forward(self, x, mask=None): 26 | assert x.dim() == 3, '3D tensor must be provided' 27 | qkv = self.to_qvk(x) # [batch, tokens, dim*3 ] 28 | 29 | # decomposition to q,v,k 30 | # rearrange tensor to [3, batch, tokens, dim] and cast to tuple 31 | q, k, v = tuple(rearrange(qkv, 'b t (d k) -> k b t d ', k=3)) 32 | 33 | # Resulting shape: [batch, tokens, tokens] 34 | scaled_dot_prod = torch.einsum('b i d , b j d -> b i j', q, k) * self.scale_factor 35 | 36 | if mask is not None: 37 | assert mask.shape == scaled_dot_prod.shape[1:] 38 | scaled_dot_prod = scaled_dot_prod.masked_fill(mask, -np.inf) 39 | 40 | attention = torch.softmax(scaled_dot_prod, dim=-1) 41 | return torch.einsum('b i j , b j d -> b i d', attention, v) 42 | -------------------------------------------------------------------------------- /others/analyze_grounding_results.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import jsonlines 4 | 5 | def is_explicitly_view_dependent(s): 6 | target_words = {'front', 'behind', 'back', 'right', 'left', 'facing', 'leftmost', 'rightmost', 7 | 'looking', 
'across'} 8 | for word in target_words: 9 | if word in s: 10 | return True 11 | return False 12 | 13 | 14 | attrs = torch.load("annotations/scannet_val_attributes.pt") 15 | 16 | 17 | # val_file = "/root/scene-LLaMA/datasets/exprs_neurips22/gtlabelpcd_mix/nr3d/preds/val_outs.json" 18 | val_file = "outputs/2023-11-17-230123_dp0.1_lr2e-4_sta2_ep3_objscale200_scenescale50_bs1_cosine_objalign_scenealign/preds_epoch-1_step0.json" 19 | nr3d_anno_file = "/root/scene-LLaMA/datasets/referit3d/annotations/bert_tokenized/sr3d.jsonl" 20 | 21 | anno_root = "annotations" # annotation dir 22 | val_results = json.load(open(val_file)) 23 | # val_results_dict = {} 24 | # for val_item in val_results: 25 | # val_results_dict[val_item["qid"]] = val_item 26 | nr3d_anno = {} 27 | with jsonlines.open(nr3d_anno_file, "r") as reader: 28 | for l in reader: 29 | nr3d_anno[l["item_id"]] = l 30 | 31 | val_num = len(val_results) 32 | 33 | easy_acc, hard_acc = 0, 0 34 | view_indep_acc, view_dep_acc = 0, 0 35 | acc = 0 36 | easy_num, hard_num = 0, 0 37 | view_indep_num, view_dep_num = 0, 0 38 | 39 | 40 | # from nltk.tokenize import sent_tokenize 41 | 42 | # for v in val_results: 43 | # 44 | # pred = v["pred"] 45 | # target = v["ref_captions"][0] 46 | # scene_id = v["scene_id"] 47 | # obj_id = v["obj_id"] 48 | # object_labels = attrs[scene_id]["objects"] 49 | # hardness = object_labels.count(object_labels[obj_id]) 50 | # # print(object_labels) 51 | # # breakpoint() 52 | # caption = v["prompt"].split("the given description, \"")[1].split(",\" please provide the")[0] 53 | # tokens = sent_tokenize(caption)[0].split() 54 | # print(tokens) 55 | # # print(caption) 56 | # flag = pred == target 57 | # acc += flag 58 | # if is_explicitly_view_dependent(tokens): 59 | # view_dep_acc += flag 60 | # view_dep_num += 1 61 | # else: 62 | # view_indep_acc += flag 63 | # view_indep_num += 1 64 | # if hardness > 2: 65 | # hard_acc += flag 66 | # hard_num += 1 67 | # else: 68 | # easy_acc += flag 69 | # easy_num += 1 70 | 71 | dataset = "sr3d" 72 | val_file = f"/root/scene-LLaMA/datasets/exprs_neurips22/gtlabelpcd_mix/{dataset}/preds/val_outs.json" 73 | nr3d_anno_file = f"/root/scene-LLaMA/datasets/referit3d/annotations/bert_tokenized/{dataset}.jsonl" 74 | 75 | val_results = json.load(open(val_file)) 76 | 77 | val_num = len(val_results) 78 | 79 | nr3d_anno = {} 80 | with jsonlines.open(nr3d_anno_file, "r") as reader: 81 | for l in reader: 82 | nr3d_anno[l["item_id"]] = l 83 | 84 | for k, v in val_results.items(): 85 | obj_ids = v["obj_ids"] 86 | obj_logits = v["obj_logits"] 87 | obj_logits = (torch.tensor(obj_logits)).softmax(dim=-1).tolist() 88 | scene_id = nr3d_anno[k]["scan_id"] 89 | caption = nr3d_anno[k]["utterance"] 90 | target_id = nr3d_anno[k]["target_id"] 91 | instance_type = nr3d_anno[k]["instance_type"] 92 | tokens = nr3d_anno[k]["tokens"] 93 | logit_ids = zip(obj_logits, obj_ids) 94 | logit_ids = sorted(logit_ids, reverse=True) 95 | logits, ids = zip(*logit_ids) 96 | object_labels = attrs[scene_id]["objects"] 97 | hardness = object_labels.count(instance_type) 98 | # if k in val_results_dict: 99 | # pred = val_results_dict[k]["pred"] 100 | # ref_captions = val_results_dict[k]["ref_captions"] 101 | # flag = pred == ref_captions[0] 102 | # flag = 0.5 103 | if logits[1] > 0.01 and logits[2] <= 0.01: 104 | flag = 0.6 105 | else: 106 | flag = ids[0] == target_id 107 | acc += flag 108 | if is_explicitly_view_dependent(tokens): 109 | view_dep_acc += flag 110 | view_dep_num += 1 111 | else: 112 | view_indep_acc += flag 113 | view_indep_num 
+= 1 114 | if hardness > 2: 115 | hard_acc += flag 116 | hard_num += 1 117 | else: 118 | easy_acc += flag 119 | easy_num += 1 120 | 121 | print(f"Acc: {float(acc) / val_num} {acc} {val_num}") 122 | print(f"Easy-Acc: {float(easy_acc) / easy_num} {easy_acc} {easy_num}") 123 | print(f"Hard-Acc: {float(hard_acc) / hard_num} {hard_acc} {hard_num}") 124 | 125 | print(f"View-Dep-Acc: {float(view_dep_acc) / view_dep_num}") 126 | print(f"View-Indep-Acc: {float(view_indep_acc) / view_indep_num}") 127 | -------------------------------------------------------------------------------- /others/eval_offline.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | sys.path.append('.') 5 | 6 | from utils.eval import calc_scanrefer_score, clean_answer, calc_scan2cap_score, calc_scanqa_score, calc_sqa3d_score, calc_multi3dref_score 7 | 8 | from pycocoevalcap.bleu.bleu import Bleu 9 | #from pycocoevalcap.meteor.meteor import Meteor 10 | from pycocoevalcap.rouge.rouge import Rouge 11 | from pycocoevalcap.cider.cider import Cider 12 | #from pycocoevalcap.spice.spice import Spice 13 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 14 | 15 | output_dir = 'outputs/3dgraphllm_2e-5_ep6' 16 | 17 | tokenizer = PTBTokenizer() 18 | scorers = [ 19 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 20 | #(Meteor(), "METEOR"), 21 | (Rouge(), "ROUGE_L"), 22 | (Cider(), "CIDEr"), 23 | #(Spice(), "SPICE") 24 | ] 25 | 26 | 27 | prefix = 'preds_epoch4_step0' 28 | 29 | all_val_scores = {} 30 | 31 | #for task in ['scanqa', 'scanrefer', 'scan2cap', 'sqa3d', 'multi3dref']: 32 | for task in ['scanrefer']: 33 | save_preds = [] 34 | for filename in os.listdir(output_dir): 35 | if filename.startswith(prefix) and task in filename: 36 | preds = json.load(open(os.path.join(output_dir, filename))) 37 | save_preds += preds 38 | print(len(save_preds)) 39 | val_scores = {} 40 | if task == 'scanqa': 41 | val_scores = calc_scanqa_score(save_preds, tokenizer, scorers) 42 | if task == 'scanrefer': 43 | val_scores = calc_scanrefer_score(save_preds) 44 | if task == 'multi3dref': 45 | val_scores = calc_multi3dref_score(save_preds) 46 | if task == 'scan2cap': 47 | val_scores = calc_scan2cap_score(save_preds, tokenizer, scorers) 48 | if task == 'sqa3d': 49 | val_scores = calc_sqa3d_score(save_preds, tokenizer, scorers) 50 | 51 | all_val_scores = {**all_val_scores, **val_scores} 52 | 53 | print(json.dumps(all_val_scores, indent=4)) -------------------------------------------------------------------------------- /others/extract_target_noun.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import sys 4 | 5 | from models.modeling_llama import LlamaForCausalLM 6 | from transformers import LlamaTokenizer, LlamaConfig 7 | from collections import defaultdict 8 | 9 | from tqdm import tqdm 10 | 11 | 12 | llama_model_path = "model/vicuna-7b-v0" 13 | 14 | print("Loading LLaMA") 15 | llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False) 16 | model = LlamaForCausalLM.from_pretrained( 17 | llama_model_path, 18 | torch_dtype=torch.float16 19 | ).cuda() 20 | for p in model.parameters(): 21 | p.requires_grad = False 22 | print("Loading LLaMA Done") 23 | 24 | model.eval() 25 | 26 | 27 | dataset_name = "sr3d_merge" 28 | anno_file = f"anno/{dataset_name}_captions.json" 29 | output_anno_file = f"anno/{dataset_name}_captions_noun.json" 30 | annos = json.load(open(anno_file, "r")) 
31 | output_annos = defaultdict(list) 32 | 33 | # nr3d 34 | # prompt_head = "###System: Given a sentence that asks for an object in a scene. Extract the primary subject from each sentence and include any accompanying adjectives, if present. " \ 35 | # "###Human: When facing the bookcases, choose the plant directly on the right, next to the right most bookcase. " \ 36 | # "###Assistant: plant. " \ 37 | # "###Human: The big black box between the door and the couch. " \ 38 | # "###Assistant: big black box. " \ 39 | # "###Human: pick the white pillow that has a pillow above and under it. " \ 40 | # "###Assistant: white pillow. " 41 | 42 | # sr3d 43 | prompt_head = "###System: Given a sentence that asks for an object in a scene. Extract the primary subject from each sentence and include any accompanying adjectives, if present. " \ 44 | "###Human: find the office chair that is near the copier. " \ 45 | "###Assistant: office chair. " \ 46 | "###Human: select the trash can that is near the printer. " \ 47 | "###Assistant: trash can. " \ 48 | "###Human: the monitor that is near the door. " \ 49 | "###Assistant: monitor. " 50 | 51 | for i, k in tqdm(enumerate(annos.keys())): 52 | print(f"{i} / {len(annos)}") 53 | captions = annos[k] 54 | for caption in captions: 55 | end_1 = (caption.find(".") + len(caption)) % len(caption) 56 | # end_2 = (caption.find(",") + len(caption)) % len(caption) 57 | bk = end_1 58 | sen_caption = caption[:bk] + "." 59 | prompt = prompt_head + "###Human: " + sen_caption + " ###" 60 | input_token = llama_tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to("cuda:0") 61 | input_embed = model.model.embed_tokens(input_token.input_ids) 62 | outputs = model.generate( 63 | inputs_embeds=input_embed, 64 | max_new_tokens=16, 65 | num_beams=1, 66 | do_sample=True, 67 | min_length=1, 68 | top_p=0.9, 69 | repetition_penalty=1.0, 70 | length_penalty=1, 71 | temperature=1.0 72 | ) 73 | output = outputs[0] 74 | if output[0] == 0: 75 | output = output[1:] 76 | if output[0] == 1: 77 | output = output[1:] 78 | output_text = llama_tokenizer.decode(output, add_special_tokens=False) 79 | print("INPUT:", sen_caption) 80 | print("OUTPUT:", output_text) 81 | try: 82 | output_annos[k].append(output_text.split("Assistant:")[1].split(".")[0].strip()) 83 | print("EX OUTPUT:", output_text.split("Assistant:")[1].split(".")[0].strip()) 84 | except Exception: 85 | print("Fail:") 86 | output_annos[k].append(caption[:bk]) 87 | 88 | with open(output_anno_file, "w") as f: 89 | json.dump(output_annos, f, indent=4) 90 | -------------------------------------------------------------------------------- /others/gpt_generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | sleep_time = NUM_SECONDS_TO_SLEEP 13 | while True: 14 | try: 15 | response = openai.ChatCompletion.create( 16 | model='gpt-4', 17 | # model="gpt-3.5-turbo-0613", 18 | messages=[{ 19 | 'role': 'system', 20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 
21 | }, { 22 | 'role': 'user', 23 | 'content': content, 24 | }], 25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 26 | max_tokens=max_tokens, 27 | ) 28 | break 29 | except Exception as e: 30 | print(e) 31 | print(f"!!sleep for {sleep_time}") 32 | time.sleep(sleep_time) 33 | sleep_time *= 2 34 | 35 | return response['choices'][0]['message']['content'] 36 | 37 | 38 | def parse_score(review): 39 | try: 40 | score_pair = review.split('\n')[0] 41 | score_pair = score_pair.replace(',', ' ') 42 | sp = score_pair.split(' ') 43 | if len(sp) == 2: 44 | return [float(sp[0]), float(sp[1])] 45 | else: 46 | print('error', review) 47 | return [-1, -1] 48 | except Exception as e: 49 | print(e) 50 | print('error', review) 51 | return [-1, -1] 52 | 53 | 54 | if __name__ == '__main__': 55 | parser = argparse.ArgumentParser(description='ChatGPT-based dataset generation.') 56 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 57 | args = parser.parse_args() 58 | 59 | 60 | idx = 0 61 | prompt_head = "You are a 3D scene understanding expert specializing in 3D visual assistance. I will provide you with a catalog of objects within a 3D scene, with each object\'s information presented in the following format: object\'s ID, object\'s class name, object\'s 3D coordinates, and a concise description of the object. The object list is as follows:\n" 62 | prompt_end = "Your task is to generate a comprehensive description of the entire scene. This description should encompass an analysis of the scene's functionality, an examination of the key objects within the scene, their spatial relationships with the surrounding objects, an assessment of the arrangement of the objects within the scene, and other relevant insights. An important guideline to follow is that when referring to an object, you must explicitly include its object ID. The description should be more than 200 words and less than 300 words." 
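# Generation loop: for each scene listed under obj_info_list/{split}, the textualized
# object list is wrapped between prompt_head and prompt_end and sent to GPT-4 through
# get_eval(), which retries with exponential backoff on API errors. Outputs are cached
# per scene in gpt_generation/{split}/{scene_id}.json, so already-generated scenes are skipped.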
63 | import glob 64 | for split in ["train", "val"]: 65 | for file_path in glob.glob(f"annotations/scene_dataset/obj_info_list/{split}/*.json"): 66 | scene_id = file_path.split("/")[-1][:-5] 67 | print("-" * 20) 68 | print(scene_id) 69 | output_path = f"annotations/scene_dataset/gpt_generation/{split}/{scene_id}.json" 70 | if os.path.exists(output_path): 71 | print("skip") 72 | continue 73 | obj_infos = json.load(open(file_path, "r")) 74 | prompt = prompt_head + obj_infos + prompt_end 75 | print(prompt) 76 | answer = get_eval(prompt, args.max_tokens) 77 | print(answer) 78 | with open(output_path, "w") as f: 79 | json.dump(answer, f) 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /others/llama_tmp.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import sys 4 | sys.path.append(".") 5 | from models.modeling_llama import LlamaForCausalLM 6 | from transformers import LlamaTokenizer, LlamaConfig 7 | from collections import defaultdict 8 | 9 | from tqdm import tqdm 10 | 11 | 12 | llama_model_path = "model/vicuna-7b-v0" 13 | 14 | print("Loading LLaMA") 15 | llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False) 16 | model = LlamaForCausalLM.from_pretrained( 17 | llama_model_path, 18 | torch_dtype=torch.float16 19 | ) 20 | # model = model.to("cuda") 21 | # print(torch.cuda.memory_allocated(device="cuda:0")/1e9) 22 | # exit() 23 | print("is training:", model.training) 24 | # for p in model.parameters(): 25 | # p.requires_grad = False 26 | print("Loading LLaMA Done") 27 | 28 | llama_tokenizer.add_tokens(["", ""], special_tokens=False) 29 | 30 | model.resize_token_embeddings(len(llama_tokenizer)) 31 | 32 | 33 | def get_text_len(text): 34 | return llama_tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.shape[1] 35 | 36 | 37 | def get_ids(text): 38 | return llama_tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids 39 | 40 | 41 | def get_emb(text, is_eval=False): 42 | input_ids = llama_tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids 43 | if is_eval: 44 | model.eval() 45 | else: 46 | model.train() 47 | return model.model.embed_tokens(input_ids) 48 | 49 | 50 | breakpoint() 51 | -------------------------------------------------------------------------------- /others/modify.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | # split = 'val' 5 | # ori_annos = json.load(open(f'annotations/scanrefer_{split}_stage2_objxx.json')) 6 | 7 | # templates = [line.rstrip() for line in open('prompts/scanrefer_caption_templates.txt')] 8 | 9 | # new_annos = [] 10 | # for anno in ori_annos: 11 | # obj_id = anno['obj_id'] 12 | # anno['prompt'] = random.choice(templates).replace('', f"") 13 | # new_annos.append(anno) 14 | 15 | # with open(f'annotations/scanrefer_{split}_stage2_caption_OBJ.json', 'w') as f: 16 | # json.dump(new_annos, f, indent=4) 17 | 18 | # ori_annos = json.load(open(f'annotations/nr3d_{split}_stage2_objxx.json')) 19 | 20 | # templates = [line.rstrip() for line in open('prompts/nr3d_caption_templates.txt')] 21 | 22 | # new_annos = [] 23 | # for anno in ori_annos: 24 | # obj_id = anno['obj_id'] 25 | # anno['prompt'] = random.choice(templates).replace('', f"") 26 | # new_annos.append(anno) 27 | 28 | # with open(f'annotations/nr3d_{split}_stage2_caption_OBJ.json', 'w') as f: 29 | # json.dump(new_annos, f, indent=4) 30 
| 31 | 32 | # for dataset in ["scanrefer", "nr3d"]: 33 | # x = json.load(open(f'annotations/{dataset}_val_stage2_caption_OBJ.json')) 34 | # x = random.sample(x, 100) 35 | # with open(f'annotations/{dataset}_val_stage2_caption100_OBJ.json', 'w') as f: 36 | # json.dump(x, f, indent=4) 37 | 38 | # x = json.load(open('annotations/scanqa_val_stage2_objxx.json')) 39 | 40 | # x = random.sample(x, 100) 41 | 42 | # with open('annotations/scanqa_val_stage2_objxx100.json', 'w') as f: 43 | # json.dump(x, f, indent=4) 44 | 45 | 46 | # split = 'val' 47 | # iou = '25' 48 | 49 | # ori_annos = json.load(open(f'annotations/scanrefer_pointgroup_{split}_stage2_caption_iou{iou}.json')) 50 | 51 | # templates = [line.rstrip() for line in open('prompts/scanrefer_caption_templates.txt')] 52 | 53 | # new_annos = [] 54 | # for anno in ori_annos: 55 | # obj_id = anno['obj_id'] 56 | # anno['prompt'] = random.choice(templates).replace('', f"{obj_id:02}") 57 | # new_annos.append(anno) 58 | 59 | # with open(f'annotations/scanrefer_pointgroup_{split}_stage2_caption_iou{iou}.json', 'w') as f: 60 | # json.dump(new_annos, f, indent=4) 61 | 62 | 63 | old_annos = json.load(open('/mnt/petrelfs/huanghaifeng/share/Chat-3D-v2/annotations/scanrefer_mask3d_val_stage2_caption_iou0.json')) 64 | 65 | print(len(old_annos)) -------------------------------------------------------------------------------- /others/prepare_anno_stage1.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from sklearn.model_selection import train_test_split 4 | from collections import defaultdict 5 | 6 | dataset_name = "nr3d" 7 | 8 | dataset_dir = f"/root/autodl-tmp/scene-LLaMA/datasets/{dataset_name}" 9 | if dataset_name == "scanrefer": 10 | dataset_train = json.load(open(os.path.join(dataset_dir, "ScanRefer_filtered_train.json"), "r")) 11 | dataset_val = json.load(open(os.path.join(dataset_dir, "ScanRefer_filtered_val.json"), "r")) 12 | else: 13 | dataset_train = json.load(open(os.path.join(dataset_dir, "train.json"), "r")) 14 | dataset_val = json.load(open(os.path.join(dataset_dir, "val.json"), "r")) 15 | captions = defaultdict(list) 16 | 17 | 18 | def process(dataset_data): 19 | new_list = [] 20 | for data in dataset_data: 21 | scene_id = data["scene_id"] if dataset_name == "scanrefer" else data["scan_id"] 22 | obj_id = int(data["object_id"]) if dataset_name == "scanrefer" else int(data["tgt_idx"]) 23 | feat_path = f"{scene_id}/{obj_id:03}.pt" 24 | caption = data["description"] if dataset_name == "scanrefer" else data["query"] 25 | new_data = { 26 | "pc_feat_path": feat_path, 27 | "caption": caption, 28 | "scene_id": scene_id, 29 | "obj_id": obj_id 30 | } 31 | captions[f"{scene_id}_{obj_id}"].append(caption) 32 | new_list.append(new_data) 33 | return new_list 34 | 35 | 36 | output_train = process(dataset_train) 37 | output_val = process(dataset_val) 38 | 39 | 40 | with open(f"anno/{dataset_name}_train_stage2.json", "w") as f: 41 | json.dump(output_train, f) 42 | 43 | with open(f"anno/{dataset_name}_val_stage2.json", "w") as f: 44 | json.dump(output_val, f) 45 | 46 | with open(f"anno/{dataset_name}_captions.json", "w") as f: 47 | json.dump(captions, f) 48 | -------------------------------------------------------------------------------- /others/prepare_captions_noun.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | dataset = "sr3d_merge" 5 | val_split_path = 
"/root/scene-LLaMA/datasets/referit3d/annotations/splits/scannetv2_val.txt" 6 | train_split_path = "/root/scene-LLaMA/datasets/referit3d/annotations/splits/scannetv2_train.txt" 7 | 8 | caption_path = f"anno/{dataset}_captions_noun.json" 9 | 10 | train_scene_ids = [] 11 | val_scene_ids = [] 12 | 13 | with open(train_split_path, "r") as f: 14 | for line in f.readlines(): 15 | train_scene_ids.append(line.strip()) 16 | 17 | with open(val_split_path, "r") as f: 18 | for line in f.readlines(): 19 | val_scene_ids.append(line.strip()) 20 | 21 | captions = json.load(open(caption_path, "r")) 22 | 23 | train_captions = {} 24 | val_captions = {} 25 | 26 | for k, v in captions.items(): 27 | scene_id = "_".join(k.split("_")[:-1]) 28 | if scene_id in train_scene_ids: 29 | train_captions[k] = v 30 | if scene_id in val_scene_ids: 31 | val_captions[k] = v 32 | 33 | output_train_path = f"anno/{dataset}_train_captions_noun.json" 34 | output_val_path = f"anno/{dataset}_val_captions_noun.json" 35 | 36 | with open(output_train_path, "w") as f: 37 | json.dump(train_captions, f) 38 | 39 | with open(output_val_path, "w") as f: 40 | json.dump(val_captions, f) 41 | -------------------------------------------------------------------------------- /others/prepare_describe.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | anno_root = "anno" 6 | anno_file = os.path.join(anno_root, "scanrefer_val_conversation.json") 7 | 8 | annos = json.load(open(anno_file, "r")) 9 | 10 | # questions = [] 11 | # with open("prompts/detailed_description.txt", "r") as f: 12 | # for q in f.readlines(): 13 | # questions.append(q.strip()) 14 | 15 | # for k, a in annos.items(): 16 | # q = random.choice(questions) 17 | # annos[k] = [{ 18 | # "Question": q, 19 | # "Answer": a 20 | # }] 21 | output_anno = [] 22 | for k, v in annos.items(): 23 | if len(v) == 0: 24 | continue 25 | scene_id = "_".join(k.split("_")[:-1]) 26 | obj_id = int(k.split("_")[-1]) 27 | for i in range(len(v)): 28 | q = v[i]["Question"] 29 | a = v[i]["Answer"] 30 | output_anno.append({ 31 | "scene_id": scene_id, 32 | "obj_id": obj_id, 33 | "qid": i, 34 | "prompt": q, 35 | "ref": a 36 | }) 37 | # if len(output_anno) >= 500: 38 | # break 39 | 40 | output_anno = sorted(output_anno, key=lambda x: f"{x['scene_id']}_{x['obj_id']:03}_{x['qid']:2}") 41 | print(len(output_anno)) 42 | 43 | with open("anno/scanrefer_val_conv.json", "w") as f: 44 | json.dump(output_anno, f, indent=4) 45 | -------------------------------------------------------------------------------- /others/prepare_eval_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | anno_root = "anno" 6 | detail_refs = json.load(open(os.path.join(anno_root, "scanrefer_val_describe.json"), "r")) 7 | conv_refs = json.load(open(os.path.join(anno_root, "scanrefer_val_conversation.json"), "r")) 8 | target_infos = json.load(open(os.path.join(anno_root, "scanrefer_val_content.json"), "r")) 9 | 10 | # detail_items = set(detail_refs.keys()) 11 | # conv_items = set(conv_refs.keys()) 12 | # item_list = list(detail_items & conv_items) 13 | # print(len(item_list)) 14 | # 15 | # sampled_items = sorted(random.sample(item_list, 100)) 16 | # 17 | # with open("scripts/eval/item_list.json", "w") as f: 18 | # json.dump(sampled_items, f) 19 | 20 | item_list = json.load(open("eval/item_list.json", "r")) 21 | 22 | # detail_samples = [] 23 | # conv_samples = [] 24 | # qid = 0 
25 | # answers = [] 26 | # questions = [] 27 | # 28 | # for item in item_list: 29 | # scene_id = "_".join(item.split("_")[:-1]) 30 | # obj_id = int(item.split("_")[-1]) 31 | # detail_samples.append({ 32 | # "scene_id": scene_id, 33 | # "obj_id": obj_id, 34 | # "prompt": detail_refs[item][0]["Question"] 35 | # }) 36 | # answers.append({ 37 | # "question_id": qid, 38 | # "text": detail_refs[item][0]["Answer"], 39 | # "category": "detail" 40 | # }) 41 | # questions.append({ 42 | # "question_id": qid, 43 | # "text": detail_refs[item][0]["Question"], 44 | # "category": "detail", 45 | # "item_id": item 46 | # }) 47 | # qid += 1 48 | # conv = random.choice(conv_refs[item]) 49 | # while len(conv["Answer"]) == 0: 50 | # print(f"empty answer in {item}...") 51 | # conv = random.choice(conv_refs[item]) 52 | # conv_samples.append({ 53 | # "scene_id": scene_id, 54 | # "obj_id": obj_id, 55 | # "prompt": conv["Question"] 56 | # }) 57 | # answers.append({ 58 | # "question_id": qid, 59 | # "text": conv["Answer"], 60 | # "category": "conv" 61 | # }) 62 | # questions.append({ 63 | # "question_id": qid, 64 | # "text": conv["Question"], 65 | # "category": "conv", 66 | # "item_id": item 67 | # }) 68 | # qid += 1 69 | # 70 | # with open(os.path.join(anno_root, "scanrefer_val_describe100.json"), "w") as f: 71 | # json.dump(detail_samples, f) 72 | # with open(os.path.join(anno_root, "scanrefer_val_conversation100.json"), "w") as f: 73 | # json.dump(conv_samples, f) 74 | # 75 | # with open("eval/qa200_gpt4_answer.json", "w") as f: 76 | # json.dump(answers, f, indent=4) 77 | # with open("eval/qa200_questions.json", "w") as f: 78 | # json.dump(questions, f, indent=4) 79 | 80 | old_answers = json.load(open("eval/qa200_gpt4_answer.json", "r")) 81 | old_questions = json.load(open("eval/qa200_questions.json", "r")) 82 | 83 | new_list = [item_list[0]] 84 | for i in range(1, len(item_list)): 85 | if item_list[i][:11] != item_list[i-1][:11]: 86 | new_list.append(item_list[i]) 87 | new_list = random.sample(new_list, 30) 88 | 89 | new_answers = [] 90 | new_questions = [] 91 | qid = 0 92 | 93 | for i in range(len(old_questions)): 94 | ques = old_questions[i] 95 | ans = old_answers[i] 96 | if ques["item_id"] in new_list: 97 | ques["question_id"] = ans["answer_id"] = qid 98 | qid += 1 99 | new_questions.append(ques) 100 | new_answers.append(ans) 101 | 102 | with open("eval/qa60_gpt4_answer.json", "w") as f: 103 | json.dump(new_answers, f, indent=4) 104 | with open("eval/qa60_questions.json", "w") as f: 105 | json.dump(new_questions, f, indent=4) 106 | -------------------------------------------------------------------------------- /others/prepare_identifier_rich.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | 4 | split = "train" 5 | 6 | annos = [] 7 | for file_name in glob.glob(f"annotations/scene_dataset/gpt_generation/{split}/*.json"): 8 | scene_id = file_name.split("/")[-1][:-5] 9 | obj_id = 0 10 | prompt = "Provide a comprehensive description of the entire scene." 
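# Each GPT-generated scene description becomes one scene-level sample: obj_id is fixed
# to 0 and the generic prompt above is paired with the generated caption loaded from the
# per-scene JSON file.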
11 | caption = json.load(open(file_name, "r")) 12 | annos.append({ 13 | "scene_id": scene_id, 14 | "obj_id": obj_id, 15 | "prompt": prompt, 16 | "caption": caption 17 | }) 18 | 19 | with open(f"annotations/scene_dataset_{split}_stage2.json", "w") as f: 20 | json.dump(annos, f, indent=4) 21 | -------------------------------------------------------------------------------- /others/prepare_multi3dref.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | split = 'train' 5 | ori_anno = json.load(open(os.path.join('annotations/multi3drefer_train_val', f"multi3drefer_{split}.json"))) 6 | 7 | new_anno = [] -------------------------------------------------------------------------------- /others/prepare_obj_align_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | split = "val" 5 | with open(f"annotations/scanrefer_{split}_stage1.json", "r") as f: 6 | annos = json.load(f) 7 | 8 | if split == "train": 9 | with open(f"annotations/scannet_{split}_stage1.json", "r") as f: 10 | annos.extend(json.load(f)) 11 | 12 | print(len(annos)) 13 | 14 | new_annos = [] 15 | unwanted_words = ["wall", "ceiling", "floor", "object", "item"] 16 | 17 | with open("prompts/obj_align_template.txt", 'r') as f: 18 | answer_templates = f.read().splitlines() 19 | 20 | for anno in annos: 21 | scene_id = anno["scene_id"] 22 | obj_id = anno["obj_id"] 23 | caption = anno["captions"][0] 24 | prompt = f"What is the ?" 25 | if any(x in caption for x in unwanted_words): 26 | continue 27 | if split == "train": 28 | answer_template = random.choice(answer_templates) 29 | if answer_template.count("{}") == 2: 30 | answer = answer_template.format(f"", caption) 31 | else: 32 | answer = answer_template.format(caption) 33 | new_annos.append({ 34 | "scene_id": scene_id, 35 | "obj_id": obj_id, 36 | "prompt": prompt, 37 | "caption": answer 38 | }) 39 | else: 40 | answers = [] 41 | for answer_template in answer_templates: 42 | if answer_template.count("{}") == 2: 43 | answer = answer_template.format(f"", caption) 44 | else: 45 | answer = answer_template.format(caption) 46 | answers.append(answer) 47 | new_annos.append({ 48 | "scene_id": scene_id, 49 | "obj_id": obj_id, 50 | "prompt": prompt, 51 | "ref_captions": answers 52 | }) 53 | 54 | print(len(new_annos)) 55 | 56 | with open(f"annotations/obj_align_{split}_OBJ.json", "w") as f: 57 | json.dump(new_annos, f, indent=4) 58 | -------------------------------------------------------------------------------- /others/prepare_obj_caption.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | split = 'val' 6 | all_objects = json.load(open('annotations/all_objects_with_full_desc_0318.json')) 7 | templates = [line.rstrip() for line in open('prompts/object_caption_templates.txt')] 8 | scan_list = [line.rstrip() for line in open(f"annotations/scannet_{split}.txt")] 9 | 10 | new_annos = [] 11 | for k, v in all_objects.items(): 12 | if k not in scan_list: 13 | continue 14 | for o in v['object_list']: 15 | if o['description'] is not None: 16 | new_annos.append({ 17 | 'scene_id': k, 18 | 'obj_id': o['object_id'], 19 | 'prompt': random.choice(templates).replace('', f""), 20 | 'ref_captions': [o['description'].capitalize()+'.'], 21 | 'related_ids': [o['object_id']] 22 | }) 23 | print(len(new_annos)) 24 | 25 | with open(f'annotations/scannet_{split}_stage2_caption_OBJ.json', 'w') as f: 
26 | json.dump(new_annos, f, indent=4) 27 | -------------------------------------------------------------------------------- /others/prepare_ref_captions.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | val_anno = json.load(open("anno/scanrefer_val_stage2.json", "r")) 6 | ref_captions = json.load(open("anno/scanrefer_captions.json", "r")) 7 | train_convs = json.load(open("anno/scanrefer_train_conversation.json", "r")) 8 | output_anno = [] 9 | 10 | for k, v in train_convs.items(): 11 | scene_id = "_".join(k.split("_")[:-1]) 12 | obj_id = int(k.split("_")[-1]) 13 | if scene_id[:7] != "scene00" or len(v) == 0: 14 | continue 15 | qa = random.choice(v) 16 | output_anno.append({ 17 | "scene_id": scene_id, 18 | "obj_id": obj_id, 19 | "prompt": qa["Question"], 20 | "ref_captions": [] 21 | }) 22 | 23 | 24 | # cap_list = [] 25 | # for k, v in ref_captions.items(): 26 | # if k[:9] == "scene0000": 27 | # cap_list.append({ 28 | # "scene_obj": k, 29 | # "captions": v 30 | # }) 31 | # cap_list = sorted(cap_list, key=lambda x: x["scene_obj"]) 32 | # with open("anno/tmp.json", "w") as f: 33 | # json.dump(cap_list, f) 34 | # 35 | # exit() 36 | 37 | # prompt_list = [] 38 | # with open("prompts/conv_description.txt", "r") as f: 39 | # for line in f.readlines(): 40 | # prompt = line.strip().split(" ")[-1] 41 | # prompt_list.append(prompt) 42 | # 43 | # id_set = set() 44 | # 45 | # output_anno = [] 46 | # 47 | # for anno in val_anno: 48 | # scene_id = anno["scene_id"] 49 | # obj_id = anno["obj_id"] 50 | # item_id = f"{scene_id}_{obj_id}" 51 | # if scene_id[:7] == "scene00" and item_id not in id_set: 52 | # id_set.add(item_id) 53 | # prompt = random.choice(prompt_list) 54 | # output_anno.append({ 55 | # "scene_id": scene_id, 56 | # "obj_id": obj_id, 57 | # "prompt": prompt, 58 | # "ref_captions": [] 59 | # }) 60 | 61 | print(len(output_anno)) 62 | output_anno = sorted(output_anno, key=lambda x: f"{x['scene_id']}_{x['obj_id']:03}") 63 | output_anno = output_anno[:200] 64 | 65 | with open("anno/scanrefer_val_convs.json", "w") as f: 66 | json.dump(output_anno, f, indent=4) 67 | -------------------------------------------------------------------------------- /others/prepare_referit_anno_stage1.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from sklearn.model_selection import train_test_split 4 | from collections import defaultdict 5 | 6 | dataset_name = "sr3d_merge" 7 | referit_anno_root = "/root/scene-LLaMA/datasets/referit3d/annotations" 8 | 9 | dataset_dir = os.path.join(referit_anno_root, "bert_tokenized") 10 | train_split_file = os.path.join(referit_anno_root, "splits/scannetv2_train.txt") 11 | val_split_file = os.path.join(referit_anno_root, "splits/scannetv2_val.txt") 12 | # dataset_train = json.load(open(os.path.join(dataset_dir, "train.json"), "r")) 13 | # dataset_val = json.load(open(os.path.join(dataset_dir, "val.json"), "r")) 14 | 15 | sr3d_data = [] 16 | with open(os.path.join(dataset_dir, f"sr3d.jsonl"), "r") as f: 17 | for line in f: 18 | data = json.loads(line) 19 | for k, v in data.items(): 20 | if type(v) == list: 21 | data[k] = frozenset(v) 22 | sr3d_data.append(data) 23 | 24 | sr3d_plus_data = [] 25 | with open(os.path.join(dataset_dir, f"sr3d+.jsonl"), "r") as f: 26 | for line in f: 27 | data = json.loads(line) 28 | for k, v in data.items(): 29 | if type(v) == list: 30 | data[k] = frozenset(v) 31 | sr3d_plus_data.append(data) 32 | 33 
| sr3d_data = {frozenset(d.items()) for d in sr3d_data} 34 | sr3d_plus_data = {frozenset(d.items()) for d in sr3d_plus_data} 35 | sr3d_merged_data = sr3d_data | sr3d_plus_data 36 | print(len(sr3d_data), len(sr3d_plus_data), len(sr3d_merged_data)) 37 | # 38 | # exit() 39 | 40 | train_scenes = [] 41 | with open(train_split_file, "r") as f: 42 | for line in f.readlines(): 43 | train_scenes.append(line.strip()) 44 | 45 | val_scenes = [] 46 | with open(val_split_file, "r") as f: 47 | for line in f.readlines(): 48 | val_scenes.append(line.strip()) 49 | 50 | print(len(train_scenes), len(val_scenes)) 51 | 52 | 53 | train_data = [] 54 | val_data = [] 55 | 56 | correct_false = 0 57 | other_data = 0 58 | 59 | # with open(os.path.join(dataset_dir, f"{dataset_name}.jsonl"), "r") as f: 60 | # for line in f: 61 | # tmp_data = json.loads(line) 62 | # if not tmp_data["correct_guess"]: 63 | # correct_false += 1 64 | # continue 65 | # if tmp_data["scan_id"] in train_scenes: 66 | # train_data.append(tmp_data) 67 | # elif tmp_data["scan_id"] in val_scenes: 68 | # val_data.append(tmp_data) 69 | # else: 70 | # # print(tmp_data["scan_id"]) 71 | # other_data += 1 72 | 73 | for tmp_data in sr3d_merged_data: 74 | scan_id, target_id, utterance = None, None, None 75 | for k, v in tmp_data: 76 | if k == "scan_id": 77 | scan_id = v 78 | if k == "target_id": 79 | target_id = v 80 | if k == "utterance": 81 | utterance = v 82 | tmp_data = { 83 | "scan_id": scan_id, 84 | "target_id": target_id, 85 | "utterance": utterance 86 | } 87 | if tmp_data["scan_id"] in train_scenes: 88 | train_data.append(tmp_data) 89 | elif tmp_data["scan_id"] in val_scenes: 90 | val_data.append(tmp_data) 91 | else: 92 | # print(tmp_data["scan_id"]) 93 | other_data += 1 94 | 95 | print(len(train_data), len(val_data), correct_false, other_data) 96 | 97 | captions = defaultdict(list) 98 | 99 | 100 | def process(dataset_data): 101 | new_list = [] 102 | for data in dataset_data: 103 | scene_id = data["scan_id"] 104 | obj_id = int(data["target_id"]) 105 | feat_path = f"{scene_id}/{obj_id:03}.pt" 106 | caption = data["utterance"] 107 | new_data = { 108 | "pc_feat_path": feat_path, 109 | "caption": caption, 110 | "scene_id": scene_id, 111 | "obj_id": obj_id 112 | } 113 | captions[f"{scene_id}_{obj_id}"].append(caption) 114 | new_list.append(new_data) 115 | return new_list 116 | 117 | 118 | output_train = process(train_data) 119 | output_val = process(val_data) 120 | 121 | 122 | with open(f"anno/{dataset_name}_train_stage1.json", "w") as f: 123 | json.dump(output_train, f) 124 | 125 | with open(f"anno/{dataset_name}_val_stage1.json", "w") as f: 126 | json.dump(output_val, f) 127 | 128 | with open(f"anno/{dataset_name}_captions.json", "w") as f: 129 | json.dump(captions, f) 130 | -------------------------------------------------------------------------------- /others/prepare_scanqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | replace_list = [["10", "ten"], ["12", "twelve"], ["1", "one"], ["2", "two"], ["3", "three"], ["4", "four"], ["5", "five"], 4 | ["6", "six"], ["7", "seven"], ["8", "eight"], ["9", "nine"]] 5 | 6 | split = "train" 7 | with open(f"annotations/scanqa/ScanQA_v1.0_{split}.json", "r") as f: 8 | annos = json.load(f) 9 | print(len(annos)) 10 | new_annos = [] 11 | for anno in annos: 12 | scene_id = anno["scene_id"] 13 | obj_ids = anno["object_ids"] if "object_ids" in anno else [] 14 | question = anno["question"] 15 | # for (a, b) in replace_list: 16 | # question = question.replace(a, 
b) 17 | 18 | # prompt = f"Pay attention to obj{obj_ids[0]:02}" 19 | # if len(obj_ids) == 2: 20 | # prompt += f" and obj{obj_ids[1]:02}" 21 | # elif len(obj_ids) > 2: 22 | # for i in range(1, len(obj_ids)-1): 23 | # prompt += f", obj{obj_ids[i]:02}" 24 | # prompt += f", and obj{obj_ids[-1]:02}" 25 | # if len(obj_ids) > 0: 26 | # related_prompt = f"The relevant object{'s are' if len(obj_ids) > 1 else ' is'} obj{obj_ids[0]:02}" 27 | # if len(obj_ids) == 2: 28 | # related_prompt += f" and obj{obj_ids[1]:02}" 29 | # elif len(obj_ids) > 2: 30 | # for i in range(1, len(obj_ids)-1): 31 | # related_prompt += f", obj{obj_ids[i]:02}" 32 | # related_prompt += f", and obj{obj_ids[-1]:02}" 33 | # related_prompt += "." 34 | # prompt += ". " + question + " Answer the question using a single word or phrase." 35 | prompt = question + " Answer the question using a single word or phrase." 36 | # prompt = question + " The answer should be a phrase or a single word." 37 | 38 | answers = anno["answers"] 39 | if split == "val": 40 | # for i in range(len(answers)): 41 | # for (a, b) in replace_list: 42 | # answers[i] = answers[i].replace(a, b) 43 | new_annos.append({ 44 | "scene_id": scene_id, 45 | "obj_id": obj_ids[0], 46 | "prompt": prompt, 47 | "ref_captions": answers 48 | }) 49 | elif split == "train": 50 | for i in range(len(answers)): 51 | if i > 0 and answers[i] == answers[i-1]: 52 | continue 53 | answer = answers[i] 54 | # for (a, b) in replace_list: 55 | # answer = answer.replace(a, b) 56 | answer = answer.capitalize() 57 | if answer[-1] != ".": 58 | answer += "." 59 | # answer = "The answer is " + answer + "." 60 | new_annos.append({ 61 | "scene_id": scene_id, 62 | "obj_id": obj_ids[0], 63 | "prompt": prompt, 64 | "caption": answer, 65 | # "related_ids": obj_ids 66 | }) 67 | print(len(new_annos)) 68 | 69 | with open(f"annotations/scanqa_{split}.json", "w") as f: 70 | json.dump(new_annos, f, indent=4) 71 | -------------------------------------------------------------------------------- /others/prepare_scene_align_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import torch 4 | import random 5 | 6 | split = "train" 7 | 8 | annos = [] 9 | 10 | if split == "val": 11 | with open(f"annotations/scanrefer_{split}_stage1.json", "r") as f: 12 | annos.extend(json.load(f)) 13 | 14 | if split == "train": 15 | with open(f"annotations/scannet_{split}_stage1.json", "r") as f: 16 | annos.extend(json.load(f)) 17 | 18 | scene_attrs = json.load(open("annotations/scannet_attributes.json", "r")) 19 | 20 | print(len(annos)) 21 | 22 | new_annos = [] 23 | 24 | for anno in annos: 25 | scene_id = anno["scene_id"] 26 | obj_id = int(anno["obj_id"]) 27 | # caption = anno["captions"][0] 28 | 29 | locs = torch.tensor(scene_attrs[scene_id]["locs"]) 30 | obj_num = locs.shape[0] 31 | if obj_num <= 6: 32 | continue 33 | centers = locs[:, :3] 34 | dis = ((centers - centers[obj_id])**2).sum(dim=-1) 35 | center_diff = (centers - centers[obj_id]).abs() 36 | 37 | mode = random.randint(0, 3) 38 | if mode == 0: 39 | if random.randint(0, 1) == 0: 40 | prompt = f"Which object is closest to obj{obj_id:02}?" 41 | answer_id = int(dis.topk(k=2, largest=False)[1][1]) 42 | answer = f"The closest object to obj{obj_id:02} is obj{answer_id:02}." 43 | else: 44 | prompt = f"Which object is farthest from obj{obj_id:02}?" 45 | answer_id = int(dis.topk(k=2, largest=True)[1][0]) 46 | answer = f"The farthest object from obj{obj_id:02} is obj{answer_id:02}." 
47 | if mode == 1: 48 | a = random.randint(0, obj_num-1) 49 | b = random.randint(0, obj_num-1) 50 | while a == obj_id or b == obj_id or a == b: 51 | a = random.randint(0, obj_num - 1) 52 | b = random.randint(0, obj_num - 1) 53 | if random.randint(0, 1): 54 | prompt = f"Which object is closer to obj{obj_id:02}, obj{a:02} or obj{b:02}?" 55 | answer_id = a if dis[a] < dis[b] else b 56 | answer = f"The closer object to obj{obj_id:02} is obj{answer_id:02}." 57 | else: 58 | prompt = f"Which object is farther from obj{obj_id:02}, obj{a:02} or obj{b:02}?" 59 | answer_id = a if dis[a] > dis[b] else b 60 | answer = f"The farther object from obj{obj_id:02} is obj{answer_id:02}." 61 | if mode == 2: 62 | a = random.randint(0, obj_num - 1) 63 | b = random.randint(0, obj_num - 1) 64 | while a == obj_id or b == obj_id or a == b: 65 | a = random.randint(0, obj_num - 1) 66 | b = random.randint(0, obj_num - 1) 67 | z_a = locs[a][2] - locs[a][5] 68 | z_b = locs[b][2] - locs[b][5] 69 | if random.randint(0, 1): 70 | prompt = f"Which object is located at the higher position, obj{a:02} or obj{b:02}?" 71 | if z_a < z_b: 72 | a, b = b, a 73 | answer = f"Obj{a:02} is located at a higher position compared to obj{b:02}." 74 | else: 75 | prompt = f"Which object is located at the lower position, obj{a:02} or obj{b:02}?" 76 | if z_a > z_b: 77 | a, b = b, a 78 | answer = f"Obj{a:02} is located at a lower position compared to obj{b:02}." 79 | if mode == 3: 80 | prompt = f"List the five closest objects to obj{obj_id:02} in ascending order of their object IDs." 81 | answer_ids = dis.topk(k=6, largest=False)[1][1:6].tolist() 82 | answer_ids.sort() 83 | answer = f"The five closest objects to obj{obj_id:02} in ascending order are: obj{answer_ids[0]:02}, obj{answer_ids[1]:02}, obj{answer_ids[2]:02}, obj{answer_ids[3]:02}, and obj{answer_ids[4]:02}." 
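# Summary of the four QA modes built from the box centers in scannet_attributes.json:
#   mode 0 - closest / farthest object to obj_id;
#   mode 1 - which of two random objects is closer / farther from obj_id;
#   mode 2 - which of two random objects sits higher / lower (comparing z values derived from their boxes);
#   mode 3 - the five nearest objects to obj_id, reported in ascending ID order.
# Train samples store the answer under "caption"; val samples store it under "ref_captions".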
84 | # nearest 85 | if split == "train": 86 | new_annos.append({ 87 | "scene_id": scene_id, 88 | "obj_id": obj_id, 89 | "prompt": prompt, 90 | "caption": answer 91 | }) 92 | else: 93 | new_annos.append({ 94 | "scene_id": scene_id, 95 | "obj_id": obj_id, 96 | "prompt": prompt, 97 | "ref_captions": [answer] 98 | }) 99 | 100 | with open(f"annotations/scene_align_{split}.json", "w") as f: 101 | json.dump(new_annos, f, indent=4) 102 | -------------------------------------------------------------------------------- /others/prepare_scene_level_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from collections import defaultdict 4 | 5 | split = "val" 6 | 7 | annos = json.load(open(f"annotations/scanrefer_{split}_stage2_objxx.json", "r")) 8 | attrs = torch.load(f"annotations/scannet_{split}_attributes.pt") 9 | 10 | new_annos = defaultdict(dict) 11 | 12 | for anno in annos: 13 | scene_id = anno["scene_id"] 14 | obj_id = anno["obj_id"] 15 | caption = anno["caption"] if "caption" in anno else anno["ref_captions"][0] 16 | new_annos[scene_id][f"{obj_id:03}"] = { 17 | "loc": attrs[scene_id]["locs"][obj_id][:3].tolist(), 18 | "caption": caption, 19 | "class_name": attrs[scene_id]["objects"][obj_id] 20 | } 21 | 22 | 23 | print(len(new_annos)) 24 | 25 | for scene_id in new_annos.keys(): 26 | message = "" 27 | for i in range(200): 28 | obj_id = f"{i:03}" 29 | if obj_id not in new_annos[scene_id]: 30 | continue 31 | obj_anno = new_annos[scene_id][obj_id] 32 | class_name = obj_anno["class_name"] 33 | loc = obj_anno["loc"] 34 | for j in range(len(loc)): 35 | loc[j] = round(loc[j], 2) 36 | caption = obj_anno["caption"] 37 | message += f"obj{i:02}: {class_name}; {loc}; {caption}\n" 38 | with open(f"annotations/scene_dataset/obj_info_list/{split}/{scene_id}.json", "w") as f: 39 | json.dump(message, f) 40 | 41 | -------------------------------------------------------------------------------- /others/prepare_sqa3d.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | import nltk 5 | import random 6 | from tqdm import tqdm 7 | 8 | anno_dir = 'annotations/sqa3d' 9 | 10 | # for filename in os.listdir(anno_dir): 11 | # x = json.load(open(os.path.join(anno_dir, filename))) 12 | # with open(os.path.join(anno_dir, filename), 'w') as f: 13 | # json.dump(x, f, indent=4) 14 | 15 | 16 | def convert_person_view(sentence): 17 | # first-person view to second-person view 18 | forms = {'i': 'you', 'me': 'you', 'my': 'your', 'mine': 'yours', 'am': 'are'} 19 | def translate(word): 20 | if word.lower() in forms: 21 | return forms[word.lower()] 22 | return word 23 | result = ' '.join([translate(word) for word in nltk.wordpunct_tokenize(sentence)]) 24 | return result.capitalize() 25 | 26 | 27 | def get_sqa_question_type(question): 28 | question = question.lstrip() 29 | if question[:4].lower() == 'what': 30 | return 0 31 | elif question[:2].lower() == 'is': 32 | return 1 33 | elif question[:3].lower() == 'how': 34 | return 2 35 | elif question[:3].lower() == 'can': 36 | return 3 37 | elif question[:5].lower() == 'which': 38 | return 4 39 | else: 40 | return 5 # others 41 | 42 | 43 | for split in ['train', 'val']: 44 | scan_ids = [] 45 | sqa_annos = [] 46 | question_file = os.path.join(anno_dir, f'v1_balanced_questions_{split}_scannetv2.json') 47 | with open(question_file, 'r', encoding='utf-8') as f: 48 | question_data = json.load(f)['questions'] 49 | question_map = {} 50 
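# question_map indexes each question_id to its situation descriptions (the original one
# plus the alternatives) and the question text; the answers come from the separate
# annotation file loaded below. Each prompt is "<situation> <question> Answer the question
# using a single word or phrase.", and get_sqa_question_type() tags the sample by its
# leading word (what / is / how / can / which / others).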
| for item in question_data: 51 | question_map[item['question_id']] = { 52 | 's': [item['situation']] + item['alternative_situation'], # list of str 53 | 'q': item['question'], # str 54 | } 55 | 56 | anno_file = os.path.join(anno_dir, f'v1_balanced_sqa_annotations_{split}_scannetv2.json') 57 | with open(anno_file, 'r', encoding='utf-8') as f: 58 | anno_data = json.load(f)['annotations'] 59 | for item in tqdm(anno_data): 60 | scan_ids.append(item['scene_id']) 61 | # sqa_annos.append({ 62 | # 's': question_map[item['question_id']]['s'], # list of str 63 | # 'q': question_map[item['question_id']]['q'], # str 64 | # 'a': [meta['answer'] for meta in item['answers']], # list of str 65 | # 'pos': np.array(list(item['position'].values())), # array (3,) 66 | # 'rot': np.array(list(item['rotation'].values())), # array (4,) 67 | # }) 68 | scene_id = item['scene_id'] 69 | obj_id = 0 70 | situation = random.choice(question_map[item['question_id']]['s']) 71 | question = question_map[item['question_id']]['q'] 72 | question_type = get_sqa_question_type(question) 73 | prompt = situation + ' ' + question + " Answer the question using a single word or phrase." 74 | answers = [meta['answer'] for meta in item['answers']] 75 | if split == 'train': 76 | answer = random.choice(answers) 77 | answer = answer.capitalize() 78 | if answer[-1] != ".": 79 | answer += "." 80 | sqa_annos.append({ 81 | 'scene_id': scene_id, 82 | 'obj_id': obj_id, 83 | 'prompt': prompt, 84 | 'caption': answer, 85 | 'sqa_type': question_type 86 | }) 87 | else: 88 | sqa_annos.append({ 89 | 'scene_id': scene_id, 90 | 'obj_id': obj_id, 91 | 'prompt': prompt, 92 | 'ref_captions': answers, 93 | 'sqa_type': question_type 94 | }) 95 | # print(sqa_annos[-1]) 96 | # breakpoint() 97 | with open(f"annotations/sqa3d_{split}.json", "w") as f: 98 | json.dump(sqa_annos, f, indent=4) 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /others/prepare_train_stage3.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | import json 3 | 4 | train_split_file = "/root/scene-LLaMA/datasets/referit3d/annotations/splits/scannetv2_train.txt" 5 | train_list = [] 6 | with open(train_split_file, "r") as f: 7 | for line in f.readlines(): 8 | train_list.append(line.strip()) 9 | 10 | nr3d_anno_file = "/root/scene-LLaMA/datasets/referit3d/annotations/bert_tokenized/nr3d.jsonl" 11 | nr3d_anno = [] 12 | with jsonlines.open(nr3d_anno_file, "r") as reader: 13 | for l in reader: 14 | nr3d_anno.append(l) 15 | 16 | anno_root = "annotations" # annotation dir 17 | attribute_file = f"{anno_root}/scannet_attributes_old.json" 18 | attributes = json.load(open(attribute_file, 'r')) 19 | 20 | q_template = "Evaluate the request below and determine whether it accurately identifies the target object enclosed within the tags '' and '': {} Please respond with the answer in the following format: 'The answer is: True' if the request correctly localizes the target object, or 'The answer is: False' if it does not." 21 | a_template = "The answer is: {}." 
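# True/False verification data: every Nr3D utterance from a train scene produces a
# positive sample (its own target object, answer "True") and a negative sample that pairs
# the same target object with a different, randomly drawn utterance from the same scene
# (answer "False"), e.g.
#   {"scene_id": ..., "obj_id": target_id,
#    "QA": [{"Question": q_template.format(utter), "Answer": "The answer is: False."}]}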
22 | 23 | from tqdm import tqdm 24 | import random 25 | from collections import defaultdict 26 | utters = defaultdict(list) 27 | for item in nr3d_anno: 28 | scene_id = item["scan_id"] 29 | if scene_id not in train_list: 30 | continue 31 | utter = item["utterance"] 32 | utters[scene_id].append(utter) 33 | # target_id = item["target_id"] 34 | 35 | new_anno = [] 36 | 37 | for item in tqdm(nr3d_anno): 38 | scene_id = item["scan_id"] 39 | if scene_id not in train_list: 40 | continue 41 | utter = item["utterance"] 42 | target_id = item["target_id"] 43 | new_anno.append({ 44 | "scene_id": scene_id, 45 | "obj_id": target_id, 46 | "QA": [{ 47 | "Question": q_template.format(utter), 48 | "Answer": a_template.format("True") 49 | }] 50 | }) 51 | 52 | utter = random.choice(utters[scene_id]) 53 | while utter == item["utterance"]: 54 | utter = random.choice(utters[scene_id]) 55 | new_anno.append({ 56 | "scene_id": scene_id, 57 | "obj_id": target_id, 58 | "QA": [{ 59 | "Question": q_template.format(utter), 60 | "Answer": a_template.format("False") 61 | }] 62 | }) 63 | 64 | with open("annotations/nr3d_train_tf.json", "w") as f: 65 | json.dump(new_anno, f) 66 | -------------------------------------------------------------------------------- /others/process_preds.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | model = "2023-09-17-144020_dp0.1_lr5e-5_sta3_ep3" 5 | conv_pred_file = f"eval/conv100_{model}.json" 6 | detail_pred_file = f"eval/detail100_{model}.json" 7 | 8 | qa_name = "qa60" 9 | conv_preds_ = json.load(open(conv_pred_file, "r")) 10 | detail_preds_ = json.load(open(detail_pred_file, "r")) 11 | ques = json.load(open(f"eval/{qa_name}_questions.json", "r")) 12 | 13 | conv_preds = {} 14 | detail_preds = {} 15 | 16 | for pred in conv_preds_: 17 | item_id = f"{pred['scene_id']}_{pred['obj_id']}" 18 | conv_preds[item_id] = pred["pred"] 19 | for pred in detail_preds_: 20 | item_id = f"{pred['scene_id']}_{pred['obj_id']}" 21 | detail_preds[item_id] = pred["pred"] 22 | 23 | answers = [] 24 | 25 | for que in ques: 26 | item_id = que["item_id"] 27 | answer = conv_preds[item_id] if que["category"] == "conv" else detail_preds[item_id] 28 | if ": " in answer: 29 | answer = ": ".join(answer.split(": ")[1:]) 30 | answers.append({ 31 | "answer_id": que["question_id"], 32 | "text": answer 33 | }) 34 | 35 | with open(f"eval/{qa_name}_{model}_answer.json", "w") as f: 36 | json.dump(answers, f, indent=4) 37 | -------------------------------------------------------------------------------- /others/process_vil3dref_multichoice.py: -------------------------------------------------------------------------------- 1 | """ 2 | loss/og3d: 2.9594, loss/obj3d_clf: 3.3753, loss/obj3d_clf_pre: 2.0714, loss/txt_clf: 0.6708, loss/total: 10.2789, loss/cross_attn_0: 0.0032, loss/cross_attn_1: 0.0011, loss/cross_attn_2: 0.0011, loss/cross_attn_3: 0.0012, loss/self_attn_0: 0.1595, loss/self_attn_1: 0.0425, loss/self_attn_2: 0.0541, loss/self_attn_3: 0.1030, loss/hidden_state_0: 0.3919, loss/hidden_state_1: 0.0765, loss/hidden_state_2: 0.1033, loss/hidden_state_3: 0.1308, loss/hidden_state_4: 0.1337, acc/og3d: 0.6373, acc/og3d_class: 0.8903, acc/obj3d_clf: 0.6828, acc/obj3d_clf_pre: 0.6131, acc/txt_clf: 0.9281 3 | """ 4 | 5 | import json 6 | import jsonlines 7 | import math 8 | import torch 9 | from random import shuffle 10 | import random 11 | split = "val" 12 | val_file = f"/root/scene-LLaMA/datasets/exprs_neurips22/gtlabelpcd_mix/nr3d/preds/{split}_outs.json" 13 
| nr3d_anno_file = "/root/scene-LLaMA/datasets/referit3d/annotations/bert_tokenized/nr3d.jsonl" 14 | 15 | val_results = json.load(open(val_file)) 16 | 17 | nr3d_anno = {} 18 | with jsonlines.open(nr3d_anno_file, "r") as reader: 19 | for l in reader: 20 | nr3d_anno[l["item_id"]] = l 21 | 22 | val_split_path = "/root/scene-LLaMA/datasets/referit3d/annotations/splits/scannetv2_val.txt" 23 | scene_ids = [] 24 | with open(val_split_path, "r") as f: 25 | for line in f.readlines(): 26 | scene_ids.append(line.strip()) 27 | 28 | shuffle(scene_ids) 29 | scene_num = len(scene_ids) 30 | train_scene_num = int(scene_num * 0.8) 31 | train_scene_ids, val_scene_ids = scene_ids[:train_scene_num], scene_ids[train_scene_num:] 32 | 33 | 34 | 35 | multi_choice_template = "Given the description of \",\" I have received a list of possible objects from a robust 3D localization model: []. These objects are considered potential matches for the given description. Kindly provide the object ID that you believe is the closest match to the description. If you believe none of the listed objects are a correct match, please specify an alternative object ID." 36 | 37 | item_list = [] 38 | train_output_annos = [] 39 | val_output_annos = [] 40 | 41 | acc = 0 42 | random_acc = 0 43 | origin_acc = 0 44 | tot_len = 0 45 | max_len = 0 46 | thres = 1e-2 47 | tot_num = 0 48 | for k, v in val_results.items(): 49 | obj_ids = v["obj_ids"] 50 | obj_logits = v["obj_logits"] 51 | obj_logits = (torch.tensor(obj_logits)).softmax(dim=-1).tolist() 52 | scan_id = nr3d_anno[k]["scan_id"] 53 | utter = nr3d_anno[k]["utterance"] 54 | target_id = nr3d_anno[k]["target_id"] 55 | logit_ids = zip(obj_logits, obj_ids) 56 | logit_ids = sorted(logit_ids, reverse=True) 57 | logits, ids = zip(*logit_ids) 58 | # print(logits) 59 | # print(ids) 60 | # print(target_id) 61 | # breakpoint() 62 | can_ids = [] 63 | if split == "val": 64 | for i in range(min(len(logits), 5)): 65 | if logits[i] > thres: 66 | can_ids.append(ids[i]) 67 | else: 68 | can_num = random.randint(1, 5) 69 | can_ids = ids[:can_num] 70 | if len(can_ids) == 1: 71 | continue 72 | # can_ids = sorted(can_ids) 73 | id_list = "" 74 | for i in range(len(can_ids)): 75 | if i > 0: 76 | id_list += ", " 77 | id_list += f"obj{can_ids[i]:02}" 78 | if utter[-1] == ".": 79 | utter = utter[:-1] 80 | prompt = multi_choice_template.replace("", utter).replace("", id_list) 81 | answer = f"obj{target_id:02}.".capitalize() 82 | # logits = (torch.tensor(logits[:5]) / 5.).softmax(dim=-1).tolist() 83 | # print(logits) 84 | # if ids[0] == target_id: 85 | # acc += 1 86 | 87 | # item_list.append({ 88 | # "can_ids": ids[:5], 89 | # "can_preds": logits[:5], 90 | # "utter": utter, 91 | # "target_id": target_id, 92 | # "scan_id": scan_id 93 | # }) 94 | if scan_id in train_scene_ids: 95 | train_output_annos.append({ 96 | "scene_id": scan_id, 97 | "obj_id": target_id, 98 | "prompt": prompt, 99 | "caption": answer, 100 | }) 101 | else: 102 | val_output_annos.append({ 103 | "scene_id": scan_id, 104 | "obj_id": target_id, 105 | "prompt": prompt, 106 | "ref_captions": [answer], 107 | "qid": k 108 | }) 109 | if target_id in can_ids: 110 | acc += 1 111 | if random.choice(can_ids) == target_id: 112 | random_acc += 1 113 | if ids[0] == target_id: 114 | origin_acc += 1 115 | tot_len += len(can_ids) 116 | max_len = len(can_ids) if len(can_ids) > max_len else max_len 117 | tot_num += 1 118 | 119 | # if split == "val": 120 | # with open(f"annotations/nr3d_{split}_stage2_multichoice{str(thres)}.json", "w") as f: 121 | # 
json.dump(train_output_annos, f, indent=4) 122 | # else: 123 | # with open(f"annotations/nr3d_{split}_stage2_multichoice.json", "w") as f: 124 | # json.dump(val_output_annos, f, indent=4) 125 | 126 | with open(f"annotations/nr3d_train_stage2_multichoice{str(thres)}.json", "w") as f: 127 | json.dump(train_output_annos, f, indent=4) 128 | with open(f"annotations/nr3d_val_stage2_multichoice{str(thres)}.json", "w") as f: 129 | json.dump(val_output_annos, f, indent=4) 130 | 131 | print(tot_num) 132 | print("Origin Acc:", float(origin_acc) / tot_num) 133 | print("Upper Acc:", float(acc) / tot_num) 134 | print("Random Acc:", float(random_acc) / tot_num) 135 | print("mean len:", float(tot_len) / tot_num) 136 | print("max len:", max_len) 137 | # print(len(item_list)) 138 | # print(item_list[:5]) 139 | # exit() 140 | -------------------------------------------------------------------------------- /others/run_chat.sh: -------------------------------------------------------------------------------- 1 | 2 | pretrained_path=outputs/2023-09-07-205742_dp0.1_lr5e-5_sta3_ep3/ckpt_02.pth 3 | 4 | CUDA_VISIBLE_DEVICES=2 python others/process_vil3dref_results.py \ 5 | scripts/config_old.py \ 6 | pretrained_path "$pretrained_path" \ 7 | model.max_txt_len 20 8 | -------------------------------------------------------------------------------- /others/run_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model=2023-09-17-144020_dp0.1_lr5e-5_sta3_ep3 4 | qa_name=qa60 5 | OPENAI_API_KEY="" python eval/eval.py \ 6 | --question eval/"$qa_name"_questions.json \ 7 | --context anno/scanrefer_val_content.json \ 8 | --answer-list \ 9 | eval/"$qa_name"_gpt4_answer.json \ 10 | eval/"$qa_name"_"$model"_answer.json \ 11 | --rule eval/rule.json \ 12 | --output eval/review_"$qa_name"_"$model".jsonl -------------------------------------------------------------------------------- /others/run_generate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OPENAI_API_KEY="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" python others/gpt_generate.py -------------------------------------------------------------------------------- /others/tmp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import os 4 | from collections import defaultdict 5 | 6 | # annos = json.load(open("annotations/obj_align_train_OBJ.json")) 7 | 8 | # new_annos = [] 9 | # obj_ids = set() 10 | 11 | # for anno in annos: 12 | # if anno['scene_id'] == "scene0000_00": 13 | # tmp_id = int(anno["caption"].split("OBJ")[1][:3]) 14 | # if tmp_id in obj_ids: 15 | # continue 16 | # obj_ids.add(tmp_id) 17 | # if anno["caption"].startswith("") 44 | # caption = '' 45 | # for idx in range(len(clean_text)): 46 | # if idx in p: 47 | # if p[idx][0] == '[': 48 | # caption += '[' 49 | # else: 50 | # caption += ' ' + ', '.join(p[idx]) + ']' 51 | # caption += clean_text[idx] 52 | # return caption 53 | 54 | 55 | # scannet_root = '/mnt/petrelfs/share_data/huanghaifeng/data/processed/scannet' 56 | # x = json.load(open('data/step2_captions_by_scene_v2.json')) 57 | 58 | # train_annos = [] 59 | # val_annos = [] 60 | 61 | # train_scans = [line.rstrip() for line in open(os.path.join(scannet_root, f'train.txt'))] 62 | # val_scans = [line.rstrip() for line in open(os.path.join(scannet_root, f'val.txt'))] 63 | 64 | 65 | # import pandas as pd 66 | # from tqdm import tqdm 67 | 68 | # obj_csv = 
pd.read_csv('annotations/Cap3D_automated_Objaverse_no3Dword.csv', header=None) 69 | # obj_ids = [] 70 | # obj_cap_dict = {} 71 | # feats = torch.load('annotations/objaverse_uni3d_feature.pt') 72 | 73 | # for obj_id, cap in tqdm(zip(obj_csv[0].values, obj_csv[1].values)): 74 | # # remove redundant quotation marks, here we do not directly strip because the mark may appear only at one side 75 | # if obj_id not in feats: 76 | # continue 77 | # if cap.startswith('"') and cap.endswith('"'): 78 | # cap = cap.strip('"') 79 | # elif cap.startswith("'") and cap.endswith("'"): 80 | # cap = cap.strip("'") 81 | # cap = cap.capitalize() 82 | # obj_ids.append(obj_id) 83 | # obj_cap_dict[obj_id] = cap 84 | 85 | # train_annos = [] 86 | # val_annos = [] 87 | # train_obj_ids = obj_ids[:-1000] 88 | # val_obj_ids = obj_ids[-1000:] 89 | 90 | 91 | # for obj_id in train_obj_ids: 92 | # train_annos.append({ 93 | # 'scene_id': obj_id, 94 | # 'caption': obj_cap_dict[obj_id] 95 | # }) 96 | 97 | # for obj_id in val_obj_ids: 98 | # val_annos.append({ 99 | # 'scene_id': obj_id, 100 | # 'ref_captions': [obj_cap_dict[obj_id]] 101 | # }) 102 | 103 | # print(len(train_annos)) 104 | # print(len(val_annos)) 105 | 106 | 107 | # with open('annotations/objaverse_caption_train.json', 'w') as f: 108 | # json.dump(train_annos, f, indent=4) 109 | 110 | # with open('annotations/objaverse_caption_val.json', 'w') as f: 111 | # json.dump(val_annos, f, indent=4) 112 | 113 | 114 | # train_feats = {} 115 | # val_feats = {} 116 | 117 | # for obj_id in train_obj_ids: 118 | # train_feats[obj_id] = feats[obj_id] 119 | # for obj_id in val_obj_ids: 120 | # val_feats[obj_id] = feats[obj_id] 121 | 122 | # torch.save(train_feats, 'annotations/objaverse_uni3d_feature_train.pt') 123 | # torch.save(val_feats, 'annotations/objaverse_uni3d_feature_val.pt') 124 | 125 | 126 | # import os 127 | # import gzip 128 | # import numpy as np 129 | # from tqdm import tqdm 130 | 131 | # folder_path = '/mnt/petrelfs/huanghaifeng/share/data/cap3d/8192_npy' 132 | 133 | # for filename in tqdm(os.listdir(folder_path)): 134 | # if filename.endswith('.npy'): 135 | # obj_id = filename.split('_8192')[0] 136 | # data = np.load(os.path.join(folder_path, filename)) 137 | # with gzip.open(os.path.join(folder_path, obj_id + '.gz'), 'wb') as f: 138 | # np.save(f, data) 139 | # os.remove(os.path.join(folder_path, filename)) 140 | 141 | import csv 142 | id2class = {} 143 | labels = set() 144 | class_label_file = "annotations/scannet/scannetv2-labels.combined.tsv" 145 | with open(class_label_file, "r") as f: 146 | csvreader = csv.reader(f, delimiter='\t') 147 | csvreader.__next__() 148 | for line in csvreader: 149 | id2class[line[0]] = line[1] 150 | labels.add(line[2]) 151 | print(len(labels)) 152 | 153 | -------------------------------------------------------------------------------- /others/tmp2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | split = 'val' 5 | 6 | scanqa_anno = json.load(open(f"annotations/scanqa_{split}_stage2_objxx.json")) 7 | scanrefer_anno = json.load(open(f"annotations/scanrefer_{split}_stage2_objxx.json")) 8 | 9 | scanqa_scan_list = np.unique([x['scene_id'] for x in scanqa_anno]) 10 | scanrefer_scan_list = np.unique([x['scene_id'] for x in scanrefer_anno]) 11 | 12 | print(len(set(scanrefer_scan_list) - set(scanqa_scan_list))) 13 | print(len(scanqa_scan_list)) 14 | print(len(scanrefer_scan_list)) 15 | 
-------------------------------------------------------------------------------- /preprocess/README.md: -------------------------------------------------------------------------------- 1 | ## Skip the data preparation 2 | 3 | - Chat-Scene has provided all the prepared data in [Google Drive](https://drive.google.com/drive/folders/1iwVFUkvveehvwGcAnJK3EwLxBt5ggR2c?usp=sharing). Simply download the files and place them in the annotations/ directory. You’ll then be ready to run and test the code. 4 | 5 | - We've provided preprocessed VL-SAT features for semantic relations between objects as well as additional text annotations in [Yandex Disk](https://disk.yandex.ru/d/LpPJgHg8Qg6BpA) 6 | 7 | - We've provided VL-SAT features for fully-connected graphs with semantic relations between objects in [Yandex Disk](https://disk.yandex.ru/d/LpPJgHg8Qg6BpA) (output_vlsat.zip) 8 | 9 | ## Prepare data 10 | 11 | - Download the ScanNet dataset by following the [ScanNet instructions](https://github.com/ScanNet/ScanNet). 12 | 13 | - Extract object masks using a pretrained 3D detector: 14 | - Use [Mask3D](https://github.com/JonasSchult/Mask3D) for instance segmentation. We used the [checkpoint](https://omnomnom.vision.rwth-aachen.de/data/mask3d/checkpoints/scannet200/scannet200_val.ckpt) pretrained on ScanNet200. 15 | - The complete predicted results (especially the masks) for the train/validation sets are too large to share (~40GB). We’ve shared the post-processed [results](https://drive.google.com/file/d/1jwQYJvkWwRmawZvNOSy6U0lnqnEiasNX/view?usp=sharing): 16 | - Unzip the `mask3d_inst_seg.tar.gz` file. 17 | - Each file under `mask3d_inst_seg` contains the predicted results for a single scene, including a list of segmented instances with their labels and segmented indices. 18 | 19 | - Process object masks and prepare annotations: 20 | - If you use Mask3D for instance segmentation, set the `segment_result_dir` in [run_prepare.sh](run_prepare.sh) to the output directory of Mask3D. 21 | - If you use the downloaded `mask3d_inst_seg` directly, set `segment_result_dir` to None and set `inst_seg_dir` to the path of `mask3d_inst_seg`. 22 | - Run: `bash preprocess/run_prepare.sh` 23 | 24 | - Extract 3D features using a pretrained 3D encoder: 25 | - Follow [Uni3D](https://github.com/baaivision/Uni3D?tab=readme-ov-file) to extract 3D features for each instance. We used the pretrained model [uni3d-g](https://huggingface.co/BAAI/Uni3D/blob/main/modelzoo/uni3d-g/model.pt). 26 | - We've also provided modified code for feature extraction in this forked [repository](https://github.com/ZzZZCHS/Uni3D). Set the `data_dir` [here](https://github.com/ZzZZCHS/Uni3D/blob/main/main.py#L620) to the path to `${processed_data_dir}/pcd_all` (`processed_data_dir` is an intermediate directory set in `run_prepare.sh`). After preparing the environment, run `bash scripts/inference.sh`. 27 | 28 | - Extract 2D features using a pretrained 2D encoder: 29 | 30 | - We followed [OpenScene](https://github.com/pengsongyou/openscene)'s code to calculate the mapping between 3D points and 2D image pixels. This allows each object to be projected onto multi-view images. Based on the projected masks on the images, we extract and merge DINOv2 features from multi-view images for each object. 31 | 32 | - [TODO] Detailed implementation will be released. 33 | 34 | - Obtain connections based on the N nearest neighbors for each object, filter the fully connected graphs with VLSAT features for Mask3D segmentation. 
To achieve this, run the ```prepare_filtered_mask3d_gnn_data.py``` script after updating the paths to the directories containing the fully connected graphs for each scene, the object attributes, and the ScanNet splits. The number of nearest neighbors can be adjusted by modifying the ```KNN``` parameter at the beginning of the ```prepare_filtered_mask3d_gnn_data.py``` script. 35 | 36 | - Obtain connections based on the N nearest neighbors for each object, filter the fully connected graphs with VLSAT features for GT segmentation. To achieve this, run the ```prepare_gnn_data.py``` script after updating the paths to the directories containing the fully connected graphs for each scene, the object attributes, and the ScanNet splits. The number of nearest neighbors can be adjusted by modifying the ```KNN``` parameter at the beginning of the ```prepare_gnn_data.py``` script. -------------------------------------------------------------------------------- /preprocess/prepare_mask3d_img_feat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import sys 4 | sys.path.append('.') 5 | import torch 6 | import random 7 | from tqdm import tqdm 8 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners 9 | import argparse 10 | 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument('--segmentor', required=True, type=str) 14 | parser.add_argument('--version', type=str, default='') 15 | args = parser.parse_args() 16 | 17 | 18 | segmentor = args.segmentor 19 | version = args.version 20 | # annos = json.load(open(f"annotations/scanrefer_{split}_stage2_grounding_OBJ.json", "r")) 21 | feats = torch.load(f'annotations/scannet_img_dinov2_features.pt') 22 | new_feats = {} 23 | item2iou = {} 24 | iou_thres = 0.5 25 | 26 | for split in ['train', 'val']: 27 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt" 28 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt" 29 | instance_attrs = torch.load(instance_attribute_file) 30 | scannet_attrs = torch.load(scannet_attribute_file) 31 | for k, v in tqdm(feats.items()): 32 | scene_id = '_'.join(k.split('_')[:2]) 33 | if scene_id not in instance_attrs: 34 | continue 35 | obj_id = int(k.split('_')[-1]) 36 | instance_locs = instance_attrs[scene_id]["locs"] 37 | scannet_locs = scannet_attrs[scene_id]["locs"] 38 | instance_num = instance_locs.shape[0] 39 | max_iou, max_id = -1, -1 40 | for pred_id in range(instance_num): 41 | pred_locs = instance_locs[pred_id].tolist() 42 | try: 43 | gt_locs = scannet_locs[obj_id].tolist() 44 | except: 45 | gt_locs = scannet_locs[obj_id-31].tolist() 46 | # print(k) 47 | # breakpoint() 48 | # break 49 | pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:]) 50 | gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:]) 51 | iou = box3d_iou(pred_corners, gt_corners) 52 | if iou > max_iou: 53 | max_iou = iou 54 | max_id = pred_id 55 | item_id = f"{scene_id}_{max_id:02}" 56 | if max_iou > iou_thres and (item_id not in new_feats or item2iou[item_id] < max_iou): 57 | new_feats[item_id] = v 58 | item2iou[item_id] = max_iou 59 | 60 | torch.save(new_feats, f'annotations/scannet_img_mask3d_dinov2_features{version}.pt') -------------------------------------------------------------------------------- /preprocess/prepare_multi3dref_annos.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import os 4 
| import sys 5 | sys.path.append('.') 6 | from tqdm import tqdm 7 | import argparse 8 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners 9 | from prompts.prompts import multi3dref_prompt, ID_format 10 | import random 11 | from collections import defaultdict 12 | import string 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--segmentor', required=True, type=str) 16 | parser.add_argument('--version', type=str, default='') 17 | parser.add_argument('--train_iou_thres', type=float, default=0.75) 18 | args = parser.parse_args() 19 | 20 | segmentor = args.segmentor 21 | version = args.version 22 | 23 | for split in ['train', 'val']: 24 | annos = json.load(open(f"annotations/multi3drefer/multi3drefer_{split}.json")) 25 | new_annos = [] 26 | 27 | count_all = defaultdict(int) 28 | count_used = defaultdict(int) 29 | if segmentor == 'deva': 30 | seg_gt_ious = torch.load(f"annotations/scannet_{segmentor}_seg_gt_ious.pt", map_location='cpu') 31 | else: 32 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt" 33 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt" 34 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu') 35 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu') 36 | 37 | for i, anno in tqdm(enumerate(annos)): 38 | scene_id = anno['scene_id'] 39 | count_all[anno['eval_type']] += 1 40 | if segmentor == 'deva': 41 | if scene_id not in seg_gt_ious: 42 | continue 43 | seg_gt_iou = seg_gt_ious[scene_id] 44 | else: 45 | if scene_id not in instance_attrs: 46 | continue 47 | instance_locs = instance_attrs[scene_id]["locs"] 48 | scannet_locs = scannet_attrs[scene_id]["locs"] 49 | instance_num = instance_locs.shape[0] 50 | gt_ids = anno['object_ids'] 51 | caption = anno['description'] 52 | if caption[-1] in string.punctuation: 53 | caption = caption[:-1] 54 | prompt = random.choice(multi3dref_prompt).replace("", caption) 55 | pred_ids = [] 56 | flag = 1 57 | for gt_id in gt_ids: 58 | if segmentor == 'deva': 59 | if gt_id >= seg_gt_iou.shape[1]: 60 | flag = 0 61 | break 62 | max_iou, max_id = seg_gt_iou[:, gt_id].max(0) 63 | max_iou = float(max_iou) 64 | max_id = int(max_id) 65 | else: 66 | max_iou, max_id = -1, -1 67 | for pred_id in range(instance_num): 68 | pred_locs = instance_locs[pred_id].tolist() 69 | gt_locs = scannet_locs[gt_id].tolist() 70 | pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:]) 71 | gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:]) 72 | iou = box3d_iou(pred_corners, gt_corners) 73 | if iou > max_iou: 74 | max_iou = iou 75 | max_id = pred_id 76 | if split == 'train' and (max_iou < args.train_iou_thres or max_id in pred_ids): 77 | flag = 0 78 | break 79 | pred_ids.append(max_id) 80 | if flag == 0: 81 | continue 82 | count_used[anno['eval_type']] += 1 83 | pred_ids = sorted(pred_ids) 84 | pred_id_strs = [ID_format.format(pred_id) for pred_id in pred_ids] 85 | if len(pred_ids) == 0: 86 | answer = "No." 87 | elif len(pred_ids) == 1: 88 | answer = f"Yes. {pred_id_strs[0]}." 89 | elif len(pred_ids) == 2: 90 | answer = f"Yes. {pred_id_strs[0]} and {pred_id_strs[1]}." 91 | else: 92 | answer = f"Yes. {', '.join(pred_id_strs[:-1])}, and {pred_id_strs[-1]}." 
93 | if split == 'train': 94 | new_annos.append({ 95 | 'scene_id': scene_id, 96 | 'obj_id': 0, 97 | 'prompt': prompt, 98 | 'caption': answer, 99 | 'eval_type': anno['eval_type'] 100 | }) 101 | else: 102 | new_annos.append({ 103 | 'scene_id': scene_id, 104 | 'obj_id': 0, 105 | 'prompt': prompt, 106 | 'ref_captions': gt_ids, 107 | 'eval_type': anno['eval_type'] 108 | }) 109 | 110 | print(f"Split: {split}") 111 | print(f"Count all: {len(annos)}", count_all) 112 | print(f"Count used: {len(new_annos)}", count_used) 113 | 114 | with open(f"annotations/multi3dref_{segmentor}_{split}{version}.json", "w") as f: 115 | json.dump(new_annos, f, indent=4) 116 | -------------------------------------------------------------------------------- /preprocess/prepare_multi3dref_location_annos.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import os 4 | import sys 5 | sys.path.append('.') 6 | from tqdm import tqdm 7 | import argparse 8 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners 9 | from prompts.prompts import multi3dref_location_prompt, ID_format 10 | import random 11 | from collections import defaultdict 12 | import string 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--segmentor', required=True, type=str) 16 | parser.add_argument('--version', type=str, default='') 17 | parser.add_argument('--train_iou_thres', type=float, default=0.75) 18 | args = parser.parse_args() 19 | 20 | segmentor = args.segmentor 21 | version = args.version 22 | 23 | def num_to_location_token(ori_num): 24 | ori_num = int(ori_num * 100) + 500 25 | if ori_num < 0: 26 | ori_num = 0 27 | if ori_num > 999: 28 | ori_num = 999 29 | return f"" 30 | 31 | for split in ['train', 'val']: 32 | annos = json.load(open(f"annotations/multi3drefer/multi3drefer_{split}.json")) 33 | new_annos = [] 34 | 35 | count_all = defaultdict(int) 36 | count_used = defaultdict(int) 37 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt" 38 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt" 39 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu') 40 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu') 41 | 42 | for i, anno in tqdm(enumerate(annos)): 43 | scene_id = anno['scene_id'] 44 | count_all[anno['eval_type']] += 1 45 | if scene_id not in instance_attrs: 46 | continue 47 | instance_locs = instance_attrs[scene_id]["locs"] 48 | scannet_locs = scannet_attrs[scene_id]["locs"] 49 | instance_num = instance_locs.shape[0] 50 | gt_ids = anno['object_ids'] 51 | caption = anno['description'] 52 | if caption[-1] in string.punctuation: 53 | caption = caption[:-1] 54 | prompt = random.choice(multi3dref_location_prompt).replace("", caption) 55 | locs_caption_list = [] 56 | flag = 1 57 | for gt_id in gt_ids: 58 | max_iou, max_id = -1, -1 59 | # for pred_id in range(instance_num): 60 | # pred_locs = instance_locs[pred_id].tolist() 61 | # gt_locs = scannet_locs[gt_id].tolist() 62 | # pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:]) 63 | # gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:]) 64 | # iou = box3d_iou(pred_corners, gt_corners) 65 | # if iou > max_iou: 66 | # max_iou = iou 67 | # max_id = pred_id 68 | gt_locs = scannet_locs[gt_id].tolist() 69 | gt_loc_tokens = [num_to_location_token(x) for x in gt_locs] 70 | tmp_loc_caption = " " + " ".join(gt_loc_tokens) + " " 71 | 
locs_caption_list.append(tmp_loc_caption) 72 | if flag == 0: 73 | continue 74 | count_used[anno['eval_type']] += 1 75 | locs_caption_list = sorted(locs_caption_list) 76 | # pred_ids = sorted(pred_ids) 77 | # pred_id_strs = [ID_format.format(pred_id) for pred_id in pred_ids] 78 | if len(locs_caption_list) == 0: 79 | answer = "No." 80 | elif len(locs_caption_list) == 1: 81 | answer = f"Yes. {locs_caption_list[0]}." 82 | elif len(locs_caption_list) == 2: 83 | answer = f"Yes. {locs_caption_list[0]} and {locs_caption_list[1]}." 84 | else: 85 | answer = f"Yes. {', '.join(locs_caption_list[:-1])}, and {locs_caption_list[-1]}." 86 | if split == 'train': 87 | new_annos.append({ 88 | 'scene_id': scene_id, 89 | 'obj_id': 0, 90 | 'prompt': prompt, 91 | 'caption': answer, 92 | 'eval_type': anno['eval_type'] 93 | }) 94 | else: 95 | new_annos.append({ 96 | 'scene_id': scene_id, 97 | 'obj_id': 0, 98 | 'prompt': prompt, 99 | 'ref_captions': gt_ids, 100 | 'eval_type': anno['eval_type'] 101 | }) 102 | 103 | print(f"Split: {split}") 104 | print(f"Count all: {len(annos)}", count_all) 105 | print(f"Count used: {len(new_annos)}", count_used) 106 | 107 | with open(f"annotations/multi3dref_{segmentor}_{split}_location{version}.json", "w") as f: 108 | json.dump(new_annos, f, indent=4) 109 | -------------------------------------------------------------------------------- /preprocess/prepare_nr3dcaption_annos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import sys 4 | sys.path.append('.') 5 | import torch 6 | import random 7 | from tqdm import tqdm 8 | from collections import defaultdict 9 | import argparse 10 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners 11 | from prompts.prompts import nr3d_caption_prompt 12 | import csv 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--segmentor', required=True, type=str) 18 | parser.add_argument('--version', type=str, default='') 19 | parser.add_argument('--train_iou_thres', type=float, default=0.75) 20 | args = parser.parse_args() 21 | 22 | segmentor = args.segmentor 23 | version = args.version 24 | 25 | train_scenes = [x.strip() for x in open('annotations/scannet/scannetv2_train.txt').readlines()] 26 | val_scenes = [x.strip() for x in open('annotations/scannet/scannetv2_val.txt').readlines()] 27 | scene_lists = { 28 | 'train': train_scenes, 29 | 'val': val_scenes 30 | } 31 | 32 | raw_annos = [] 33 | with open('annotations/referit3d/nr3d.csv') as csvfile: 34 | reader = csv.DictReader(csvfile) 35 | for row in reader: 36 | raw_annos.append({ 37 | 'scene_id': row['scan_id'], 38 | 'obj_id': int(row['target_id']), 39 | 'caption': row['utterance'] 40 | }) 41 | 42 | for split in ["train"]: 43 | annos = [anno for anno in raw_annos if anno['scene_id'] in scene_lists[split]] 44 | new_annos = [] 45 | 46 | if segmentor == 'deva': 47 | seg_gt_ious = torch.load(f"annotations/scannet_{segmentor}_seg_gt_ious.pt", map_location='cpu') 48 | else: 49 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt" 50 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt" 51 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu') 52 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu') 53 | 54 | for anno in tqdm(annos): 55 | scene_id = anno['scene_id'] 56 | obj_id = anno['obj_id'] 57 | if segmentor == 'deva': 58 | if scene_id not in seg_gt_ious: 59 | continue 60 | 
seg_gt_iou = seg_gt_ious[scene_id] 61 | if obj_id >= seg_gt_iou.shape[1]: 62 | continue 63 | max_iou, max_id = seg_gt_iou[:, obj_id].max(0) 64 | max_iou = float(max_iou) 65 | max_id = int(max_id) 66 | else: 67 | if scene_id not in instance_attrs: 68 | continue 69 | instance_locs = instance_attrs[scene_id]['locs'] 70 | scannet_locs = scannet_attrs[scene_id]['locs'] 71 | max_iou, max_id = -1, -1 72 | for pred_id in range(instance_locs.shape[0]): 73 | pred_locs = instance_locs[pred_id].tolist() 74 | gt_locs = scannet_locs[obj_id].tolist() 75 | pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:]) 76 | gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:]) 77 | iou = box3d_iou(pred_corners, gt_corners) 78 | if iou > max_iou: 79 | max_iou = iou 80 | max_id = pred_id 81 | prompt = random.choice(nr3d_caption_prompt).replace('', f"") 82 | if split == 'train': 83 | if max_iou >= args.train_iou_thres: 84 | new_annos.append({ 85 | 'scene_id': scene_id, 86 | 'obj_id': obj_id, 87 | 'prompt': prompt, 88 | 'caption': anno['caption'] 89 | }) 90 | else: 91 | new_annos.append({ 92 | 'scene_id': scene_id, 93 | 'obj_id': obj_id, 94 | 'prompt': prompt, 95 | 'ref_captions': [anno['caption']] 96 | }) 97 | print(len(new_annos)) 98 | 99 | with open(f"annotations/nr3d_caption_{segmentor}_{split}{version}.json", 'w') as f: 100 | json.dump(new_annos, f, indent=4) 101 | -------------------------------------------------------------------------------- /preprocess/prepare_objalign_annos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import sys 4 | sys.path.append('.') 5 | import torch 6 | import random 7 | from tqdm import tqdm 8 | from collections import defaultdict 9 | import argparse 10 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument('--segmentor', required=True, type=str) 16 | parser.add_argument('--version', type=str, default='') 17 | parser.add_argument('--train_iou_thres', type=float, default=0.75) 18 | args = parser.parse_args() 19 | 20 | unwanted_words = ["wall", "ceiling", "floor", "object", "item"] 21 | 22 | segmentor = args.segmentor 23 | version = args.version 24 | 25 | for split in ["train", "val"]: 26 | new_annos = [] 27 | 28 | if segmentor == 'deva': 29 | seg_gt_ious = torch.load(f"annotations/scannet_{segmentor}_seg_gt_ious.pt", map_location='cpu') 30 | else: 31 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt" 32 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu') 33 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt" 34 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu') 35 | 36 | for scene_id in tqdm(scannet_attrs.keys()): 37 | if segmentor == 'deva': 38 | if scene_id not in seg_gt_ious: 39 | continue 40 | seg_gt_iou = seg_gt_ious[scene_id] 41 | segmented_num = seg_gt_iou.shape[0] 42 | gt_num = seg_gt_iou.shape[1] 43 | else: 44 | if scene_id not in instance_attrs: 45 | continue 46 | instance_locs = instance_attrs[scene_id]['locs'] 47 | scannet_locs = scannet_attrs[scene_id]['locs'] 48 | segmented_num = len(instance_locs) 49 | gt_num = len(scannet_locs) 50 | scannet_class_labels = scannet_attrs[scene_id]['objects'] 51 | for obj_id in range(gt_num): 52 | class_label = scannet_class_labels[obj_id] 53 | if any(x in class_label for x in unwanted_words): 54 | continue 55 | if segmentor == 
'deva': 56 | max_iou, max_id = seg_gt_iou[:, obj_id].max(0) 57 | max_iou = float(max_iou) 58 | max_id = int(max_id) 59 | else: 60 | max_iou, max_id = -1, -1 61 | for pred_id in range(instance_locs.shape[0]): 62 | pred_locs = instance_locs[pred_id].tolist() 63 | gt_locs = scannet_locs[obj_id].tolist() 64 | pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:]) 65 | gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:]) 66 | iou = box3d_iou(pred_corners, gt_corners) 67 | if iou > max_iou: 68 | max_iou = iou 69 | max_id = pred_id 70 | prompt = f"What is the ?" 71 | caption = f" is a {class_label}." 72 | if split == 'train': 73 | if max_iou >= args.train_iou_thres: 74 | new_annos.append({ 75 | 'scene_id': scene_id, 76 | 'obj_id': obj_id, 77 | 'prompt': prompt, 78 | 'caption': caption 79 | }) 80 | else: 81 | new_annos.append({ 82 | 'scene_id': scene_id, 83 | 'obj_id': obj_id, 84 | 'prompt': prompt, 85 | 'ref_captions': [caption] 86 | }) 87 | 88 | print(len(new_annos)) 89 | with open(f"annotations/obj_align_{segmentor}_{split}{version}.json", 'w') as f: 90 | json.dump(new_annos, f, indent=4) -------------------------------------------------------------------------------- /preprocess/prepare_scan2cap_location_annos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import sys 4 | sys.path.append('.') 5 | import torch 6 | import random 7 | from tqdm import tqdm 8 | from collections import defaultdict 9 | import argparse 10 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners 11 | from prompts.prompts import scan2cap_location_prompt 12 | import nltk 13 | 14 | 15 | def capitalize_sentences(text): 16 | sentences = nltk.sent_tokenize(text) 17 | capitalized_sentences = [sentence.capitalize() for sentence in sentences] 18 | result = ' '.join(capitalized_sentences) 19 | return result 20 | 21 | 22 | parser = argparse.ArgumentParser() 23 | 24 | parser.add_argument('--segmentor', required=True, type=str) 25 | parser.add_argument('--version', type=str, default='') 26 | parser.add_argument('--train_iou_thres', type=float, default=0.75) 27 | parser.add_argument('--max_obj_num', type=int, default=150) 28 | args = parser.parse_args() 29 | 30 | 31 | def num_to_location_token(ori_num): 32 | ori_num = int(ori_num * 100) + 500 33 | if ori_num < 0: 34 | ori_num = 0 35 | if ori_num > 999: 36 | ori_num = 999 37 | return f"" 38 | 39 | 40 | for split in ["train", "val"]: 41 | segmentor = args.segmentor 42 | version = args.version 43 | annos = json.load(open(f"annotations/scanrefer/ScanRefer_filtered_{split}.json", "r")) 44 | new_annos = [] 45 | 46 | print(len(annos)) 47 | scene_ids = set() 48 | corpus = defaultdict(list) 49 | for anno in annos: 50 | gt_key = f"{anno['scene_id']}|{anno['object_id']}" 51 | description = capitalize_sentences(anno['description']) 52 | corpus[gt_key].append(description) 53 | scene_ids.add(anno['scene_id']) 54 | scene_ids = list(scene_ids) 55 | 56 | count = [0] * args.max_obj_num 57 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt" 58 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt" 59 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu') 60 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu') 61 | 62 | 63 | covered25_num, covered50_num = 0, 0 64 | count_all = 0 65 | for scene_id in tqdm(scene_ids): 66 | if scene_id not in instance_attrs: 67 | continue 68 | 
instance_locs = instance_attrs[scene_id]["locs"] 69 | scannet_locs = scannet_attrs[scene_id]["locs"] 70 | segmented_num = len(instance_locs) 71 | gt_num = len(scannet_locs) 72 | gt_match_id = [-1] * gt_num 73 | gt_match_iou = [-1] * gt_num 74 | for pred_id in range(segmented_num): 75 | pred_locs = instance_locs[pred_id].tolist() 76 | pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:]) 77 | max_id = max_iou = -1 78 | for gt_id in range(len(scannet_locs)): 79 | if f"{scene_id}|{gt_id}" not in corpus: 80 | continue 81 | gt_locs = scannet_locs[gt_id].tolist() 82 | gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:]) 83 | iou = box3d_iou(pred_corners, gt_corners) 84 | if iou > max_iou: 85 | max_iou = iou 86 | max_id = gt_id 87 | if f"{scene_id}|{max_id}" not in corpus: 88 | continue 89 | if max_iou > gt_match_iou[max_id]: 90 | gt_match_iou[max_id] = max_iou 91 | gt_match_id[max_id] = pred_id 92 | for gt_id, pred_id in enumerate(gt_match_id): 93 | if f"{scene_id}|{gt_id}" in corpus: 94 | count_all += len(corpus[f"{scene_id}|{gt_id}"]) 95 | if pred_id == -1: 96 | continue 97 | if split == 'train' and gt_match_iou[gt_id] < args.train_iou_thres: 98 | continue 99 | if gt_match_iou[gt_id] >= 0.25: 100 | covered25_num += len(corpus[f"{scene_id}|{gt_id}"]) 101 | if gt_match_iou[gt_id] >= 0.5: 102 | covered50_num += len(corpus[f"{scene_id}|{gt_id}"]) 103 | count[pred_id] += 1 104 | pred_locs = instance_locs[pred_id].tolist() 105 | loc_tokens = [num_to_location_token(x) for x in pred_locs] 106 | loc_caption = " " + " ".join(loc_tokens) + " " 107 | if split == 'train': 108 | for caption in corpus[f"{scene_id}|{gt_id}"]: 109 | new_annos.append({ 110 | 'scene_id': scene_id, 111 | 'obj_id': gt_id, 112 | 'pred_id': pred_id, 113 | 'prompt': random.choice(scan2cap_location_prompt).replace(f"", loc_caption), 114 | "caption": caption, 115 | "iou": gt_match_iou[gt_id] 116 | }) 117 | else: 118 | new_annos.append({ 119 | 'scene_id': scene_id, 120 | 'obj_id': gt_id, 121 | 'pred_id': pred_id, 122 | 'prompt': random.choice(scan2cap_location_prompt).replace(f"", loc_caption), 123 | "ref_captions": corpus[f"{scene_id}|{gt_id}"], 124 | "iou": gt_match_iou[gt_id] 125 | }) 126 | 127 | print(len(new_annos)) 128 | print(covered25_num, covered50_num) 129 | print(count_all) 130 | # print(count) 131 | 132 | with open(f"annotations/scan2cap_{segmentor}_{split}_location{version}.json", "w") as f: 133 | json.dump(new_annos, f, indent=4) -------------------------------------------------------------------------------- /preprocess/prepare_scannet_attributes.py: -------------------------------------------------------------------------------- 1 | from plyfile import PlyData 2 | import numpy as np 3 | import os 4 | import json 5 | import torch 6 | from collections import defaultdict 7 | from tqdm import tqdm 8 | import argparse 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument('--scannet_dir', required=True, type=str, 13 | help='the path of the directory to scannet scans') 14 | args = parser.parse_args() 15 | 16 | raw_data_dir = os.path.join(args.scannet_dir, 'scans') 17 | 18 | 19 | for split in ["train", "val"]: 20 | split_file = f"annotations/scannet/scannetv2_{split}.txt" 21 | scan_names = [line.rstrip() for line in open(split_file)] 22 | print(f'{split} split scans: {len(scan_names)}') 23 | outputs = {} 24 | for scan_id in tqdm(scan_names): 25 | aggregation_path = os.path.join(raw_data_dir, scan_id, scan_id + '.aggregation.json') 26 | segs_path = os.path.join(raw_data_dir, scan_id, 
scan_id + '_vh_clean_2.0.010000.segs.json') 27 | scan_ply_path = os.path.join(raw_data_dir, scan_id, scan_id + '_vh_clean_2.labels.ply') 28 | 29 | data = PlyData.read(scan_ply_path) 30 | x = np.asarray(data.elements[0].data['x']).astype(np.float32) 31 | y = np.asarray(data.elements[0].data['y']).astype(np.float32) 32 | z = np.asarray(data.elements[0].data['z']).astype(np.float32) 33 | pc = np.stack([x, y, z], axis=1) 34 | 35 | align_matrix = np.eye(4) 36 | with open(os.path.join(raw_data_dir, scan_id, '%s.txt'%(scan_id)), 'r') as f: 37 | for line in f: 38 | if line.startswith('axisAlignment'): 39 | align_matrix = np.array([float(x) for x in line.strip().split()[-16:]]).astype(np.float32).reshape(4, 4) 40 | break 41 | 42 | pts = np.ones((pc.shape[0], 4), dtype=pc.dtype) 43 | pts[:, 0:3] = pc 44 | pc = np.dot(pts, align_matrix.transpose())[:, :3] 45 | 46 | scan_aggregation = json.load(open(aggregation_path)) 47 | segments_info = json.load(open(segs_path)) 48 | segment_indices = segments_info["segIndices"] 49 | segment_indices_dict = defaultdict(list) 50 | for i, s in enumerate(segment_indices): 51 | segment_indices_dict[s].append(i) 52 | 53 | pc_segment_label = [''] * pc.shape[0] 54 | 55 | instance_labels = [] 56 | inst_locs = [] 57 | for idx, object_info in enumerate(scan_aggregation['segGroups']): 58 | object_instance_label = object_info['label'] 59 | object_id = object_info['objectId'] 60 | segments = object_info["segments"] 61 | pc_ids = [] 62 | for s in segments: 63 | pc_ids.extend(segment_indices_dict[s]) 64 | object_pc = pc[pc_ids] 65 | object_center = (np.max(object_pc, axis=0) + np.min(object_pc, axis=0)) / 2.0 66 | object_size = np.max(object_pc, axis=0) - np.min(object_pc, axis=0) 67 | object_bbox = torch.from_numpy(np.concatenate([object_center, object_size], axis=0)) 68 | inst_locs.append(object_bbox) 69 | instance_labels.append(object_instance_label) 70 | inst_locs = torch.stack(inst_locs, dim=0) 71 | outputs[scan_id] = { 72 | 'objects': instance_labels, 73 | 'locs': inst_locs 74 | } 75 | 76 | torch.save(outputs, f"annotations/scannet_{split}_attributes.pt") 77 | 78 | -------------------------------------------------------------------------------- /preprocess/prepare_scannet_attributes_clasp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import os 4 | import glob 5 | import numpy as np 6 | import argparse 7 | from tqdm import tqdm 8 | 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument('--scan_dir', required=True, type=str, 12 | help='the path of the directory to be saved preprocessed scans') 13 | parser.add_argument('--segmentor', required=True, type=str) 14 | parser.add_argument('--max_inst_num', required=True, type=int) 15 | parser.add_argument('--version', type=str, default='') 16 | args = parser.parse_args() 17 | 18 | 19 | save_feats = {} 20 | for split in ["val"]: 21 | scan_dir = args.scan_dir 22 | output_dir = "annotations" 23 | split_path = f"annotations/scannet/scannetv2_{split}.txt" 24 | 25 | scan_ids = [line.strip() for line in open(split_path).readlines()] 26 | 27 | scan_ids = sorted(scan_ids) 28 | # print(scan_ids) 29 | 30 | scans = {} 31 | for scan_id in scan_ids: 32 | pcd_path = os.path.join(scan_dir, f"{scan_id}.pth") 33 | if not os.path.exists(pcd_path): 34 | # print('skip', scan_id) 35 | continue 36 | pred_results = torch.load(pcd_path, map_location='cpu') 37 | inst_locs = [] 38 | pred_boxes = pred_results['pred_boxes'] 39 | num_insts = pred_boxes.shape[0] 40 | for i in 
range(min(num_insts, args.max_inst_num)): 41 | center = pred_boxes[i].mean(dim=0) 42 | size = pred_boxes[i][1] - pred_boxes[i][0] 43 | inst_locs.append(torch.cat([center, size], 0)) 44 | save_feats[f"{scan_id}_{i:02}"] = pred_results['queries'][0][i] 45 | inst_locs = torch.stack(inst_locs, dim=0).to(torch.float32) 46 | scans[scan_id] = { 47 | # 'objects': instance_class_labels, # (n_obj, ) 48 | 'locs': inst_locs, # (n_obj, 6) center xyz, whl 49 | } 50 | print(f"{split}: {len(scans)}") 51 | 52 | torch.save(scans, os.path.join(output_dir, f"scannet_{args.segmentor}_{split}_attributes{args.version}.pt")) 53 | 54 | # torch.save(save_feats, os.path.join(output_dir, f"scannet_{args.segmentor}_{args.segmentor}_feats.pt")) -------------------------------------------------------------------------------- /preprocess/prepare_scannet_caption_annos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import sys 4 | sys.path.append('.') 5 | import torch 6 | import random 7 | from tqdm import tqdm 8 | from collections import defaultdict 9 | import argparse 10 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners 11 | from prompts.prompts import obj_caption_wid_prompt 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument('--segmentor', required=True, type=str) 16 | parser.add_argument('--version', type=str, default='') 17 | parser.add_argument('--train_iou_thres', type=float, default=0.75) 18 | args = parser.parse_args() 19 | 20 | segmentor = args.segmentor 21 | version = args.version 22 | 23 | for split in ['train']: 24 | annos = json.load(open(f"annotations/scannet_{split}_caption.json")) 25 | new_annos = [] 26 | 27 | if segmentor == 'deva': 28 | seg_gt_ious = torch.load(f"annotations/scannet_{segmentor}_seg_gt_ious.pt", map_location='cpu') 29 | else: 30 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt" 31 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt" 32 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu') 33 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu') 34 | 35 | for i, anno in tqdm(enumerate(annos)): 36 | scene_id = anno['scene_id'] 37 | obj_id = anno['obj_id'] 38 | if segmentor == 'deva': 39 | if scene_id not in seg_gt_ious: 40 | continue 41 | seg_gt_iou = seg_gt_ious[scene_id] 42 | if obj_id >= seg_gt_iou.shape[1]: 43 | continue 44 | max_iou, max_id = seg_gt_iou[:, obj_id].max(0) 45 | max_iou = float(max_iou) 46 | max_id = int(max_id) 47 | else: 48 | if scene_id not in instance_attrs: 49 | continue 50 | instance_locs = instance_attrs[scene_id]["locs"] 51 | scannet_locs = scannet_attrs[scene_id]["locs"] 52 | instance_num = instance_locs.shape[0] 53 | max_iou, max_id = -1, -1 54 | for pred_id in range(instance_num): 55 | pred_locs = instance_locs[pred_id].tolist() 56 | gt_locs = scannet_locs[obj_id].tolist() 57 | pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:]) 58 | gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:]) 59 | iou = box3d_iou(pred_corners, gt_corners) 60 | if iou > max_iou: 61 | max_iou = iou 62 | max_id = pred_id 63 | if split == 'train': 64 | if max_iou > args.train_iou_thres: 65 | new_annos.append({ 66 | 'scene_id': scene_id, 67 | 'obj_id': obj_id, 68 | 'prompt': random.choice(obj_caption_wid_prompt).replace('', f""), 69 | 'caption': anno['caption'] 70 | }) 71 | else: 72 | new_annos.append({ 73 | 'scene_id': scene_id, 
74 | 'obj_id': obj_id, 75 | 'prompt': random.choice(obj_caption_wid_prompt).replace('', f""), 76 | 'ref_captions': anno['ref_captions'] 77 | }) 78 | 79 | print(f"Split: {split}") 80 | print(f"{len(annos)} -> {len(new_annos)}") 81 | 82 | with open(f"annotations/scannet_caption_{segmentor}_{split}{version}.json", "w") as f: 83 | json.dump(new_annos, f, indent=4) 84 | -------------------------------------------------------------------------------- /preprocess/prepare_scannet_mask3d_attributes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import os 4 | import glob 5 | import numpy as np 6 | import argparse 7 | from tqdm import tqdm 8 | 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument('--scan_dir', required=True, type=str, 12 | help='the path of the directory to be saved preprocessed scans') 13 | parser.add_argument('--segmentor', required=True, type=str) 14 | parser.add_argument('--max_inst_num', required=True, type=int) 15 | parser.add_argument('--version', type=str, default='') 16 | args = parser.parse_args() 17 | 18 | 19 | 20 | for split in ["train", "val"]: 21 | scan_dir = os.path.join(args.scan_dir, 'pcd_all') 22 | output_dir = "annotations" 23 | split_path = f"annotations/scannet/scannetv2_{split}.txt" 24 | 25 | scan_ids = [line.strip() for line in open(split_path).readlines()] 26 | 27 | scan_ids = sorted(scan_ids) 28 | # print(scan_ids) 29 | 30 | scans = {} 31 | for scan_id in tqdm(scan_ids): 32 | pcd_path = os.path.join(scan_dir, f"{scan_id}.pth") 33 | if not os.path.exists(pcd_path): 34 | print('skip', scan_id) 35 | continue 36 | points, colors, instance_class_labels, instance_segids = torch.load(pcd_path) 37 | inst_locs = [] 38 | num_insts = len(instance_class_labels) 39 | for i in range(min(num_insts, args.max_inst_num)): 40 | inst_mask = instance_segids[i] 41 | pc = points[inst_mask] 42 | if len(pc) == 0: 43 | print(scan_id, i, 'empty bbox') 44 | inst_locs.append(np.zeros(6, ).astype(np.float32)) 45 | continue 46 | size = pc.max(0) - pc.min(0) 47 | center = (pc.max(0) + pc.min(0)) / 2 48 | inst_locs.append(np.concatenate([center, size], 0)) 49 | inst_locs = torch.tensor(np.stack(inst_locs, 0), dtype=torch.float32) 50 | scans[scan_id] = { 51 | 'objects': instance_class_labels, # (n_obj, ) 52 | 'locs': inst_locs, # (n_obj, 6) center xyz, whl 53 | } 54 | 55 | torch.save(scans, os.path.join(output_dir, f"scannet_{args.segmentor}_{split}_attributes{args.version}.pt")) -------------------------------------------------------------------------------- /preprocess/prepare_scanqa_annos.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | for split in ['train', 'val']: 4 | if split == 'test': 5 | with open(f"annotations/scanqa/ScanQA_v1.0_test_w_obj.json", "r") as f: 6 | annos = json.load(f) 7 | with open(f"annotations/scanqa/ScanQA_v1.0_test_wo_obj.json", "r") as f: 8 | annos.extend(json.load(f)) 9 | else: 10 | with open(f"annotations/scanqa/ScanQA_v1.0_{split}.json", "r") as f: 11 | annos = json.load(f) 12 | print(len(annos)) 13 | new_annos = [] 14 | for anno in annos: 15 | scene_id = anno["scene_id"] 16 | obj_ids = anno["object_ids"] if "object_ids" in anno else [0] 17 | question = anno["question"] 18 | 19 | prompt = question + " Answer the question using a single word or phrase." 
20 | 21 | answers = anno["answers"] if "answers" in anno else [] 22 | if split == "train": 23 | for i in range(len(answers)): 24 | if i > 0 and answers[i] == answers[i-1]: 25 | continue 26 | answer = answers[i] 27 | answer = answer.capitalize() 28 | if answer[-1] != ".": 29 | answer += "." 30 | new_annos.append({ 31 | "scene_id": scene_id, 32 | "obj_id": obj_ids[0], 33 | "prompt": prompt, 34 | "caption": answer, 35 | }) 36 | else: 37 | new_annos.append({ 38 | "scene_id": scene_id, 39 | "obj_id": obj_ids[0], 40 | "prompt": prompt, 41 | "ref_captions": answers 42 | }) 43 | print(len(new_annos)) 44 | 45 | with open(f"annotations/scanqa_{split}.json", "w") as f: 46 | json.dump(new_annos, f, indent=4) 47 | -------------------------------------------------------------------------------- /preprocess/prepare_scanrefer_annos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import sys 4 | sys.path.append('.') 5 | import torch 6 | import random 7 | from tqdm import tqdm 8 | from collections import defaultdict 9 | import argparse 10 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners 11 | from prompts.prompts import grounding_prompt 12 | import string 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--segmentor', required=True, type=str) 18 | parser.add_argument('--version', type=str, default='') 19 | parser.add_argument('--train_iou_thres', type=float, default=0.75) 20 | parser.add_argument('--max_obj_num', type=int, default=150) 21 | args = parser.parse_args() 22 | 23 | segmentor = args.segmentor 24 | version = args.version 25 | 26 | for split in ["train", "val"]: 27 | count = [0] * args.max_obj_num 28 | annos = json.load(open(f"annotations/scanrefer/ScanRefer_filtered_{split}.json", "r")) 29 | annos = sorted(annos, key=lambda p: f"{p['scene_id']}_{int(p['object_id']):03}") 30 | new_annos = [] 31 | 32 | if segmentor == 'deva': 33 | seg_gt_ious = torch.load(f"annotations/scannet_{segmentor}_seg_gt_ious.pt", map_location='cpu') 34 | else: 35 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt" 36 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt" 37 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu') 38 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu') 39 | 40 | iou25_count = 0 41 | iou50_count = 0 42 | # maxiou_count = 0 43 | # valid_count = 0 44 | for i, anno in tqdm(enumerate(annos)): 45 | scene_id = anno['scene_id'] 46 | obj_id = int(anno['object_id']) 47 | desc = anno['description'] 48 | if desc[-1] in string.punctuation: 49 | desc = desc[:-1] 50 | prompt = random.choice(grounding_prompt).replace('', desc) 51 | 52 | if segmentor == 'deva': 53 | if scene_id not in seg_gt_ious: 54 | continue 55 | seg_gt_iou = seg_gt_ious[scene_id] 56 | if obj_id >= seg_gt_iou.shape[1]: 57 | continue 58 | max_iou, max_id = seg_gt_iou[:, obj_id].max(0) 59 | max_iou = float(max_iou) 60 | max_id = int(max_id) 61 | else: 62 | if scene_id not in instance_attrs: 63 | continue 64 | instance_locs = instance_attrs[scene_id]["locs"] 65 | scannet_locs = scannet_attrs[scene_id]["locs"] 66 | instance_num = instance_locs.shape[0] 67 | max_iou, max_id = -1, -1 68 | for pred_id in range(instance_num): 69 | pred_locs = instance_locs[pred_id].tolist() 70 | gt_locs = scannet_locs[obj_id].tolist() 71 | pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:]) 72 | gt_corners = 
construct_bbox_corners(gt_locs[:3], gt_locs[3:]) 73 | iou = box3d_iou(pred_corners, gt_corners) 74 | if iou > max_iou: 75 | max_iou = iou 76 | max_id = pred_id 77 | # maxiou_count += max_iou 78 | # valid_count += 1 79 | if max_iou >= 0.25: 80 | iou25_count += 1 81 | if max_iou >= 0.5: 82 | iou50_count += 1 83 | count[max_id] += 1 84 | if split == "train": 85 | if max_iou >= args.train_iou_thres: 86 | new_annos.append({ 87 | "scene_id": scene_id, 88 | "obj_id": max_id, 89 | "caption": f".", 90 | "prompt": prompt 91 | }) 92 | else: 93 | new_annos.append({ 94 | "scene_id": scene_id, 95 | "obj_id": obj_id, 96 | "ref_captions": [f"."], 97 | "prompt": prompt 98 | }) 99 | 100 | print(len(new_annos)) 101 | print(count) 102 | # print(maxiou_count / valid_count) 103 | # print(f"max iou@0.25: {iou25_count / len(new_annos)}") 104 | # print(f"max iou@0.5: {iou50_count / len(new_annos)}") 105 | 106 | with open(f"annotations/scanrefer_{segmentor}_{split}{version}.json", "w") as f: 107 | json.dump(new_annos, f, indent=4) -------------------------------------------------------------------------------- /preprocess/prepare_scanrefer_location_annos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import sys 4 | sys.path.append('.') 5 | import torch 6 | import random 7 | from tqdm import tqdm 8 | from collections import defaultdict 9 | import argparse 10 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners 11 | from prompts.prompts import grounding_location_prompt 12 | import string 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--segmentor', required=True, type=str) 18 | parser.add_argument('--version', type=str, default='') 19 | parser.add_argument('--train_iou_thres', type=float, default=0.75) 20 | parser.add_argument('--max_obj_num', type=int, default=150) 21 | args = parser.parse_args() 22 | 23 | segmentor = args.segmentor 24 | version = args.version 25 | 26 | def num_to_location_token(ori_num): 27 | ori_num = int(ori_num * 100) + 500 28 | if ori_num < 0: 29 | ori_num = 0 30 | if ori_num > 999: 31 | ori_num = 999 32 | return f"" 33 | 34 | for split in ["train", "val"]: 35 | count = [0] * args.max_obj_num 36 | annos = json.load(open(f"annotations/scanrefer/ScanRefer_filtered_{split}.json", "r")) 37 | annos = sorted(annos, key=lambda p: f"{p['scene_id']}_{int(p['object_id']):03}") 38 | new_annos = [] 39 | 40 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt" 41 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt" 42 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu') 43 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu') 44 | 45 | iou25_count = 0 46 | iou50_count = 0 47 | # maxiou_count = 0 48 | # valid_count = 0 49 | for i, anno in tqdm(enumerate(annos)): 50 | scene_id = anno['scene_id'] 51 | obj_id = int(anno['object_id']) 52 | desc = anno['description'] 53 | if desc[-1] in string.punctuation: 54 | desc = desc[:-1] 55 | prompt = random.choice(grounding_location_prompt).replace('', desc) 56 | 57 | if scene_id not in instance_attrs: 58 | continue 59 | scannet_locs = scannet_attrs[scene_id]["locs"] 60 | gt_locs = scannet_locs[obj_id].tolist() 61 | 62 | gt_loc_tokens = [num_to_location_token(x) for x in gt_locs] 63 | caption = " " + " ".join(gt_loc_tokens) + " " 64 | 65 | if split == "train": 66 | new_annos.append({ 67 | "scene_id": scene_id, 68 | "obj_id": obj_id, 
69 | "caption": caption, 70 | "prompt": prompt 71 | }) 72 | else: 73 | new_annos.append({ 74 | "scene_id": scene_id, 75 | "obj_id": obj_id, 76 | "ref_captions": [caption], 77 | "prompt": prompt 78 | }) 79 | 80 | print(len(new_annos)) 81 | print(count) 82 | # print(maxiou_count / valid_count) 83 | # print(f"max iou@0.25: {iou25_count / len(new_annos)}") 84 | # print(f"max iou@0.5: {iou50_count / len(new_annos)}") 85 | 86 | with open(f"annotations/scanrefer_{segmentor}_{split}_location{version}.json", "w") as f: 87 | json.dump(new_annos, f, indent=4) -------------------------------------------------------------------------------- /preprocess/prepare_sqa3d_annos.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | import nltk 5 | import random 6 | from tqdm import tqdm 7 | 8 | anno_dir = 'annotations/sqa3d' 9 | 10 | 11 | def convert_person_view(sentence): 12 | # first-person view to second-person view 13 | forms = {'i': 'you', 'me': 'you', 'my': 'your', 'mine': 'yours', 'am': 'are'} 14 | def translate(word): 15 | if word.lower() in forms: 16 | return forms[word.lower()] 17 | return word 18 | result = ' '.join([translate(word) for word in nltk.wordpunct_tokenize(sentence)]) 19 | return result.capitalize() 20 | 21 | 22 | def get_sqa_question_type(question): 23 | question = question.lstrip() 24 | if question[:4].lower() == 'what': 25 | return 0 26 | elif question[:2].lower() == 'is': 27 | return 1 28 | elif question[:3].lower() == 'how': 29 | return 2 30 | elif question[:3].lower() == 'can': 31 | return 3 32 | elif question[:5].lower() == 'which': 33 | return 4 34 | else: 35 | return 5 # others 36 | 37 | 38 | for split in ['train', 'val']: 39 | scan_ids = [] 40 | sqa_annos = [] 41 | question_file = os.path.join(anno_dir, f'v1_balanced_questions_{split}_scannetv2.json') 42 | with open(question_file, 'r', encoding='utf-8') as f: 43 | question_data = json.load(f)['questions'] 44 | question_map = {} 45 | for item in question_data: 46 | question_map[item['question_id']] = { 47 | 's': [item['situation']] + item['alternative_situation'], # list of str 48 | 'q': item['question'], # str 49 | } 50 | 51 | anno_file = os.path.join(anno_dir, f'v1_balanced_sqa_annotations_{split}_scannetv2.json') 52 | with open(anno_file, 'r', encoding='utf-8') as f: 53 | anno_data = json.load(f)['annotations'] 54 | for item in tqdm(anno_data): 55 | scan_ids.append(item['scene_id']) 56 | scene_id = item['scene_id'] 57 | obj_id = 0 58 | situation = random.choice(question_map[item['question_id']]['s']) 59 | question = question_map[item['question_id']]['q'] 60 | question_type = get_sqa_question_type(question) 61 | prompt = situation + ' ' + question + " Answer the question using a single word or phrase." 62 | answers = [meta['answer'] for meta in item['answers']] 63 | if split == 'train': 64 | answer = random.choice(answers) 65 | answer = answer.capitalize() 66 | if answer[-1] != ".": 67 | answer += "." 
68 | sqa_annos.append({ 69 | 'scene_id': scene_id, 70 | 'obj_id': obj_id, 71 | 'prompt': prompt, 72 | 'caption': answer, 73 | 'sqa_type': question_type 74 | }) 75 | else: 76 | sqa_annos.append({ 77 | 'scene_id': scene_id, 78 | 'obj_id': obj_id, 79 | 'prompt': prompt, 80 | 'ref_captions': answers, 81 | 'sqa_type': question_type 82 | }) 83 | with open(f"annotations/sqa3d_{split}.json", "w") as f: 84 | json.dump(sqa_annos, f, indent=4) 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /preprocess/process_scannet_data.py: -------------------------------------------------------------------------------- 1 | from plyfile import PlyData 2 | import numpy as np 3 | import os 4 | import json 5 | from pytorch3d.io import load_obj 6 | import torch 7 | from collections import defaultdict 8 | from tqdm import tqdm 9 | import mmengine 10 | 11 | data_root = '/mnt/petrelfs/share_data/huanghaifeng/maoxiaohan/ScanNet_v2' 12 | raw_data_dir = os.path.join(data_root, 'scans') 13 | meta_data_dir = os.path.join(data_root, 'meta_data') 14 | output_dir = '/mnt/petrelfs/share_data/huanghaifeng/data/processed/scannet' 15 | idx2class = json.load(open('/mnt/petrelfs/share_data/huanghaifeng/referit3d/referit3d/data/mappings/scannet_idx_to_semantic_class.json')) 16 | idx2class = {int(k): v for k, v in idx2class.items()} 17 | class2idx = {v: k for k, v in idx2class.items()} 18 | scan2axis_align = json.load(open('/mnt/petrelfs/share_data/huanghaifeng/data/processed/scannet/scans_axis_alignment_matrices.json')) 19 | 20 | 21 | def process_one_scan(scan_id): 22 | save_dir = os.path.join(output_dir, 'scans', scan_id) 23 | if os.path.exists(os.path.join(save_dir, 'object_infos.json')): 24 | return 25 | # label_path = os.path.join(raw_data_dir, scan_id, "labels.instances.annotated.v2.ply") 26 | aggregation_path = os.path.join(raw_data_dir, scan_id, scan_id + '.aggregation.json') 27 | segs_path = os.path.join(raw_data_dir, scan_id, scan_id + '_vh_clean_2.0.010000.segs.json') 28 | scan_ply_path = os.path.join(raw_data_dir, scan_id, scan_id + '_vh_clean_2.labels.ply') 29 | 30 | data = PlyData.read(scan_ply_path) 31 | x = np.asarray(data.elements[0].data['x']).astype(np.float32) 32 | y = np.asarray(data.elements[0].data['y']).astype(np.float32) 33 | z = np.asarray(data.elements[0].data['z']).astype(np.float32) 34 | pc = np.stack([x, y, z], axis=1) 35 | 36 | axis_align_matrix = np.array(scan2axis_align[scan_id], dtype=np.float32).reshape(4, 4) 37 | pts = np.ones((pc.shape[0], 4), dtype=pc.dtype) 38 | pts[:, :3] = pc 39 | pc = np.dot(pts, axis_align_matrix.transpose())[:, :3] 40 | 41 | scan_aggregation = json.load(open(aggregation_path)) 42 | segments_info = json.load(open(segs_path)) 43 | segment_indices = segments_info["segIndices"] 44 | segment_indices_dict = defaultdict(list) 45 | for i, s in enumerate(segment_indices): 46 | segment_indices_dict[s].append(i) 47 | 48 | pc_instance_id = np.full(pc.shape[0], -1, dtype=np.int32)  # points not covered by any instance keep -1 (np.zeros(...) * -1 would leave them 0) 49 | # pc_semantic_label_id = np.zeros(pc.shape[0]).astype(np.int32) * -1 50 | pc_segment_label = [''] * pc.shape[0] 51 | 52 | valid_ids = [] 53 | all_objects = [] 54 | for idx, object_info in enumerate(scan_aggregation['segGroups']): 55 | object_instance_label = object_info['label'] 56 | object_id = object_info['objectId'] 57 | segments = object_info["segments"] 58 | valid_ids.append(idx) 59 | pc_ids = [] 60 | for s in segments: 61 | pc_ids.extend(segment_indices_dict[s]) 62 | pc_instance_id[pc_ids] = object_id 63 | object_pc = pc[pc_ids] 64 | object_center
= (np.max(object_pc, axis=0) + np.min(object_pc, axis=0)) / 2.0 65 | object_size = np.max(object_pc, axis=0) - np.min(object_pc, axis=0) 66 | object_bbox = np.concatenate([object_center, object_size], axis=0) 67 | all_objects.append({ 68 | 'bbox': object_bbox.tolist(), 69 | 'label': object_instance_label 70 | }) 71 | object_infos = { 72 | 'valid_ids': valid_ids, 73 | 'object_list': all_objects 74 | } 75 | 76 | save_dir = os.path.join(output_dir, 'scans', scan_id) 77 | if not os.path.exists(save_dir): 78 | os.makedirs(save_dir) 79 | with open(os.path.join(save_dir, 'object_infos.json'), 'w') as f: 80 | json.dump(object_infos, f, indent=4) 81 | # np.save(os.path.join(save_dir, 'object_infos.npy'), object_infos) 82 | # np.save(os.path.join(save_dir, 'axis_align_matrix.npy'), axis_align_matrix) 83 | 84 | 85 | def process_split(split): 86 | assert split in ['train', 'val', 'test'] 87 | split_file = os.path.join(meta_data_dir, f'scannetv2_{split}.txt') 88 | scan_names = [line.rstrip() for line in open(split_file)] 89 | print(f'{split} split scans: {len(scan_names)}') 90 | # new_split_file = os.path.join(output_dir, f'{split}.txt') 91 | # valid_names = [] 92 | # for scan_name in tqdm(scan_names): 93 | # if scan_name in mapping: 94 | # new_scan_name = mapping[scan_name] 95 | # process_one_scan(scan_name, new_scan_name) 96 | # valid_names.append(scan_name) 97 | # params = [] 98 | # for scan_name in scan_names: 99 | # if scan_name in mapping: 100 | # new_scan_name = mapping[scan_name] 101 | # params.append((scan_name, new_scan_name)) 102 | 103 | parallel = True 104 | 105 | if parallel: 106 | mmengine.utils.track_parallel_progress(process_one_scan, scan_names, 8) 107 | else: 108 | for scan_id in tqdm(scan_names): 109 | process_one_scan(scan_id) 110 | 111 | # if not os.path.exists(new_split_file): 112 | # with open(new_split_file, 'w') as f: 113 | # for t in valid_names: 114 | # f.write(f'{t}\n') 115 | 116 | 117 | for s in ['train', 'val']: 118 | process_split(s) 119 | 120 | -------------------------------------------------------------------------------- /preprocess/run_prepare.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | # scannet_dir="/mnt/petrelfs/share_data/maoxiaohan/ScanNet_v2" 4 | scannet_dir="/mnt/hwfile/OpenRobotLab/huanghaifeng/data/scannet" 5 | version="" 6 | segment_result_dir="/mnt/hwfile/OpenRobotLab/huanghaifeng/data/processed/scannet/Mask3DInst" 7 | # segment_result_dir="/mnt/petrelfs/share_data/chenyilun/haifeng/Mask3DInst" 8 | inst_seg_dir="" 9 | class_label_file="annotations/scannet/scannetv2-labels.combined.tsv" 10 | max_obj_num=100 11 | 12 | train_iou_thres=0.5 13 | 14 | processed_data_dir="/mnt/hwfile/OpenRobotLab/huanghaifeng/data/processed/scannet/mask3d_ins_data" 15 | # processed_data_dir="/mnt/petrelfs/share_data/chenyilun/haifeng/mask3d_ins_data" 16 | segmentor="mask3d" 17 | # processed_data_dir="/mnt/petrelfs/share_data/chenyilun/share/mask3d/proposals" 18 | # segmentor="clasp" 19 | # segmentor="deva" 20 | 21 | python preprocess/prepare_mask3d_data.py \ 22 | --scannet_dir "$scannet_dir" \ 23 | --output_dir "$processed_data_dir" \ 24 | --segment_dir "$segment_result_dir" \ 25 | --inst_seg_dir "$inst_seg_dir" \ 26 | --class_label_file "$class_label_file" \ 27 | --apply_global_alignment \ 28 | --num_workers 16 \ 29 | --parallel 30 | 31 | python preprocess/prepare_scannet_mask3d_attributes.py \ 32 | --scan_dir "$processed_data_dir" \ 33 | --segmentor "$segmentor" \ 34 | --max_inst_num "$max_obj_num" \ 35 | --version "$version" 36 | 37 | # python preprocess/prepare_scannet_attributes.py \ 38 | # --scannet_dir "$scannet_dir" 39 | 40 | # python preprocess/prepare_scannet_attributes_clasp.py \ 41 | # --scan_dir "$processed_data_dir" \ 42 | # --segmentor "$segmentor" \ 43 | # --version "$version" \ 44 | # --max_inst_num "$max_obj_num" 45 | 46 | python preprocess/prepare_scanrefer_annos.py \ 47 | --segmentor "$segmentor" \ 48 | --version "$version" \ 49 | --train_iou_thres "$train_iou_thres" \ 50 | --max_obj_num "$max_obj_num" 51 | 52 | python preprocess/prepare_scan2cap_annos.py \ 53 | --segmentor "$segmentor" \ 54 | --version "$version" \ 55 | --train_iou_thres "$train_iou_thres" \ 56 | --max_obj_num "$max_obj_num" 57 | 58 | python preprocess/prepare_objalign_annos.py \ 59 | --segmentor "$segmentor" \ 60 | --version "$version" \ 61 | --train_iou_thres "$train_iou_thres" 62 | 63 | python preprocess/prepare_nr3dcaption_annos.py \ 64 | --segmentor "$segmentor" \ 65 | --version "$version" \ 66 | --train_iou_thres "$train_iou_thres" 67 | 68 | python preprocess/prepare_multi3dref_annos.py \ 69 | --segmentor "$segmentor" \ 70 | --version "$version" \ 71 | --train_iou_thres "$train_iou_thres" 72 | 73 | python preprocess/prepare_scanqa_annos.py 74 | 75 | python preprocess/prepare_sqa3d_annos.py 76 | 77 | # python preprocess/prepare_nr3d_annos.py \ 78 | # --segmentor "$segmentor" \ 79 | # --version "$version" \ 80 | # --train_iou_thres "$train_iou_thres" \ 81 | # --max_obj_num "$max_obj_num" 82 | 83 | # python preprocess/prepare_sr3d_annos.py \ 84 | # --segmentor "$segmentor" \ 85 | # --version "$version" \ 86 | # --train_iou_thres "$train_iou_thres" \ 87 | # --max_obj_num "$max_obj_num" 88 | 89 | 90 | 91 | # python preprocess/prepare_scannet_caption_annos.py \ 92 | # --segmentor "$segmentor" \ 93 | # --version "$version" \ 94 | # --train_iou_thres "$train_iou_thres" 95 | 96 | # python preprocess/prepare_scannet_region_caption_annos.py \ 97 | # --segmentor "$segmentor" \ 98 | # --version "$version" \ 99 | # --train_iou_thres "$train_iou_thres" 100 | 101 | # python preprocess/prepare_scanrefer_location_annos.py \ 102 | # --segmentor "$segmentor" \ 103 | # --version "$version" \ 
104 | # --train_iou_thres "$train_iou_thres" \ 105 | # --max_obj_num "$max_obj_num" 106 | 107 | # python preprocess/prepare_multi3dref_location_annos.py \ 108 | # --segmentor "$segmentor" \ 109 | # --version "$version" \ 110 | # --train_iou_thres "$train_iou_thres" 111 | 112 | # python preprocess/prepare_scan2cap_location_annos.py \ 113 | # --segmentor "$segmentor" \ 114 | # --version "$version" \ 115 | # --train_iou_thres "$train_iou_thres" \ 116 | # --max_obj_num "$max_obj_num" -------------------------------------------------------------------------------- /prompts/concise_description.txt: -------------------------------------------------------------------------------- 1 | Describe the target object in the 3D scene concisely. 2 | Provide a brief description of the given target object in the 3D scene. 3 | Offer a succinct explanation of the target object in the 3D scene presented. 4 | Summarize the visual content of the target object in the 3D scene. 5 | Give a short and clear explanation of the target object in the 3D scene. 6 | Share a concise interpretation of the target object in the 3D scene provided. 7 | Present a compact description of the target object's key features in the 3D scene. 8 | Relay a brief, clear account of the target object shown in the 3D scene. 9 | Render a clear and concise summary of the target object in the 3D scene. 10 | Write a terse but informative summary of the target object in the 3D scene. 11 | Create a compact narrative representing the target object in the 3D scene presented. -------------------------------------------------------------------------------- /prompts/concise_description_objxx.txt: -------------------------------------------------------------------------------- 1 | Provide a brief description of obj within the 3D scene. 2 | Offer a succinct overview of obj situated in the 3D environment. 3 | Summarize the key details of obj in the 3D setting. 4 | Give a concise account of obj in the 3D scene. 5 | Deliver a compact portrayal of obj within the 3D context. 6 | Outline obj in the 3D scene with brevity. 7 | Present a pithy depiction of obj in the 3D environment. 8 | Provide a concise report on obj in the 3D scene. 9 | Describe obj within the 3D scene in a succinct manner. 10 | Offer a brief 3D scene description of obj. -------------------------------------------------------------------------------- /prompts/conv_description.txt: -------------------------------------------------------------------------------- 1 | Where is the target object? -------------------------------------------------------------------------------- /prompts/dataset_generation/conversation.txt: -------------------------------------------------------------------------------- 1 | You are an AI visual assistant, and you are seeing an object in a 3D scene. What you see is provided with several sentences, describing the same object you are looking at, and the position of surrounding objects in the 3D scene to represent the content of the 3D scene. Based on these descriptions of this object and the location of surrounding objects in the 3D scene, answer all the questions as if you are in the 3D scene. 2 | Design a conversation between you and a person asking about this object in the 3D scene. The answers should be in the tone of a visual AI assistant that is in the 3D scene and answering the question. Ask diverse questions and give corresponding answers.
3 | Include questions asking about the visual content of this object, including the object type, shape, attributes, functions, location, relative positions between objects, etc. Only include questions that have definite answers: 4 | (1) Questions whose contents can be confidently observed and answered based on the 3D scene. 5 | (2) Questions whose absence from the 3D scene can be confidently determined. 6 | Also include complex questions that are relevant to the object. Because you are seeing a 3D scene, complex questions should focus more on discussing spatial reasoning, human interaction in the scene, and the rationality, potential risk, and purpose of the position of objects in the room. (For example, assuming a person is interacting with this object, ask about the relative position of other objects to the person and the person's path to access those objects. It is valuable to ask about the specific function of an object, its current location, and suggestions for placement based on background knowledge, etc.) Again, provide detailed answers when answering complex questions. For example, give detailed examples or reasoning steps to make the content more convincing and well-organized. You can include multiple paragraphs if necessary. 7 | Notice! Do not ask about uncertain details, do not ask questions that cannot be answered confidently from the given information, do not refer to the provided descriptions in the generated answers, and there is no need to specify the class or properties of the described object in the question. -------------------------------------------------------------------------------- /prompts/dataset_generation/detail.txt: -------------------------------------------------------------------------------- 1 | You are an AI 3D visual assistant, and you are seeing an object in a 3D scene. What you see is provided with several sentences, describing the same object you are looking at, and the position of surrounding objects in the 3D scene to represent the content of the 3D scene. Based on these descriptions of this object and the location of surrounding objects in the 3D scene, summarize and describe the placement and function of this object, and how a person can access this object, in detail, as if you are in the 3D scene. 2 | Importantly, do not mention any specific spatial coordinate values. The description should be more than 150 words and less than 200 words. -------------------------------------------------------------------------------- /prompts/dataset_generation/textualize_obj.txt: -------------------------------------------------------------------------------- 1 | Descriptions: ["There is a single white armchair. placed next to the window of the room.", "The sofa chair is the corner chair. lying parallel to the wall. a small table with the lamp is present beside the chair.", "This is a white sofa chair. it is under a window.", "This is a white armchair. is next to a lamp.", "This is the corner sofa chair.
a small table with a lamp can be seen near this chair."] 2 | Described object: {sofa chair:[-1.31, 3.15, 0.59]}; Neighbor objects: {window:[-1.12, 4.12, 1.59], table:[0.86, 1.61, 0.38], doorframe:[-2.25, 0.67, 1.27], windowsill:[0.88, 3.97, 0.98], windowsill:[-1.32, 3.93, 0.91], sofa chair:[0.98, 3.35, 0.71], window:[1.16, 4.18, 1.73], pillow:[1.35, 0.29, 0.46], table:[-0.15, -2.66, 0.26], tv:[-2.2, -0.55, 1.52]} -------------------------------------------------------------------------------- /prompts/detailed_description.txt: -------------------------------------------------------------------------------- 1 | Describe the target object in detail. 2 | Provide a detailed description of the given target object in the 3D scene. 3 | Give an elaborate explanation of the target object you see in the 3D scene. 4 | Share a comprehensive rundown of the presented target object in the 3D scene. 5 | Offer a thorough analysis of the target object in the 3D scene. 6 | Explain the various aspects of the target object before you. 7 | Clarify the contents of the displayed target object with great detail. 8 | Characterize the target object using a well-detailed description. 9 | Break down the elements of the target object in a detailed manner. 10 | Walk through the important details of the target object in the 3D scene. 11 | Portray the target object with a rich, descriptive narrative. 12 | Narrate the contents of the target object with precision. 13 | Analyze the target object in a comprehensive and detailed manner. 14 | Illustrate the target object through a descriptive explanation. 15 | Examine the target object closely and share its details. 16 | Write an exhaustive depiction of the given target object in the 3D scene. -------------------------------------------------------------------------------- /prompts/grounding_answer_templates.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CognitiveAISystems/3DGraphLLM/418f6529e029761a52017d0eb2a5380689805c62/prompts/grounding_answer_templates.txt -------------------------------------------------------------------------------- /prompts/grounding_prompts.txt: -------------------------------------------------------------------------------- 1 | Here is a description: "" Which object aligns most closely with the provided description? Please respond with the object ID only. 2 | Here is a description: "" Can you identify the object that best fits the provided description? Please respond with the object ID only. 3 | Here is a description: "" Identify the object that is the most suitable match for the given description. Please respond with the object ID only. 4 | Here is a description: "" Which object corresponds most accurately to the provided description? Please respond with the object ID only. 5 | Here is a description: "" Choose the object that best meets the features outlined in the description. Please respond with the object ID only. 6 | Here is a description: "" Which object is the most appropriate match for the description provided? Please respond with the object ID only. 7 | Here is a description: "" Select the object that most closely matches the given description. Please respond with the object ID only. 8 | Here is a description: "" Determine which object is the optimal match for the description provided. Please respond with the object ID only. 9 | Here is a description: "" Which object aligns best with the description given? Please respond with the object ID only. 
10 | Here is a description: "" Pick the object that best suits the description given. Please respond with the object ID only. -------------------------------------------------------------------------------- /prompts/instruction.txt: -------------------------------------------------------------------------------- 1 | The conversation centers around an indoor scene. Object information: []. -------------------------------------------------------------------------------- /prompts/nr3d_caption_templates.txt: -------------------------------------------------------------------------------- 1 | Detail the spatial positioning of the amidst surrounding elements. 2 | Illustrate the 's placement relative to its environment. 3 | Explain the 's location in correlation with nearby items. 4 | Elaborate on the 's spatial context within the scene. 5 | Describe how the is situated in relation to other elements present. 6 | Provide insight into the 's positioning among its surroundings. 7 | Discuss the relative placement of the compared to its surrounding context. 8 | Offer a depiction of the 's spatial orientation within the scene. 9 | Interpret the 's location within the broader context of the scene. 10 | Present the 's spatial relationship with other entities within the scene. -------------------------------------------------------------------------------- /prompts/object_caption_templates.txt: -------------------------------------------------------------------------------- 1 | Portray the visual characteristics of the . 2 | Detail the outward presentation of the . 3 | Provide a depiction of the 's appearance. 4 | Illustrate how the looks. 5 | Describe the visual aspects of the . 6 | Convey the physical attributes of the . 7 | Outline the external features of the . 8 | Render the appearance of the in words. 9 | Depict the outward form of the . 10 | Elaborate on the visual representation of the . -------------------------------------------------------------------------------- /prompts/scanrefer_caption_templates.txt: -------------------------------------------------------------------------------- 1 | Begin by detailing the visual aspects of the before delving into its spatial context among other elements within the scene. 2 | First, depict the physical characteristics of the , followed by its placement and interactions within the surrounding environment. 3 | Describe the appearance of the , then elaborate on its positioning relative to other objects in the scene. 4 | Paint a picture of the visual attributes of , then explore how it relates spatially to other elements in the scene. 5 | Start by articulating the outward features of the , then transition into its spatial alignment within the broader scene. 6 | Provide a detailed description of the appearance of before analyzing its spatial connections with other elements in the scene. 7 | Capture the essence of the appearance of , then analyze its spatial relationships within the scene's context. 8 | Detail the physical characteristics of the and subsequently examine its spatial dynamics amidst other objects in the scene. 9 | Describe the visual traits of first, then elucidate its spatial arrangements in relation to neighboring elements. 10 | Begin by outlining the appearance of , then proceed to illustrate its spatial orientation within the scene alongside other objects. 
-------------------------------------------------------------------------------- /prompts/scene_align_template.txt: -------------------------------------------------------------------------------- 1 | Obj12 is the nearest object to obj00. -------------------------------------------------------------------------------- /prompts/score_template.txt: -------------------------------------------------------------------------------- 1 | ###Human: 2 | 3 | Evaluate the request below and determine whether it accurately identifies the target object enclosed within the tags '' and '': 4 | 5 | 6 | {} 7 | 8 | 9 | For reference, I have asked a powerful 3D discriminator machine, and it indicates that the request has a {}% likelihood of being accurate in identifying the target object. However, the machine is not always correct, so you need to provide the final determination. 10 | 11 | Please respond with the answer in the following format: 'The answer is: True' if the request correctly localizes the target object, or 'The answer is: False' if it does not. 12 | 13 | ### -------------------------------------------------------------------------------- /prompts/score_template_old.txt: -------------------------------------------------------------------------------- 1 | ###Human: 2 | 3 | {}. 4 | 5 | Please evaluate the relevance of the sentence above, which describes an object in the scene, to the target object (given between the tags "" and ""). Assign a score on a scale of 1 to 5, where higher scores indicate greater relevance. Use the following rating system: 1 for totally irrelevant, 2 for irrelevant, 3 for somewhat relevant or somewhat irrelevant, 4 for relevant, and 5 for highly relevant. 6 | 7 | Begin your response by stating the rating as "Rating: x/5.", where 'x' represents the assigned score. After assigning the rating, provide explanations for your rating. 8 | 9 | ### -------------------------------------------------------------------------------- /prompts/system.txt: -------------------------------------------------------------------------------- 1 | A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. -------------------------------------------------------------------------------- /prompts/system_backup.txt: -------------------------------------------------------------------------------- 1 | ###System: In this task, you will be provided with extensive information regarding objects within a 3D scene. Our conversation will focus primarily on a specific target object, and I will also supply information about all other objects present in the scene. 2 | Target object: . 3 | All the objects in the scene: []. 4 | 5 | 6 | A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The conversation focuses on a 3D indoor scene, which contains dozens of 3D objects. Here is a list of object information: []. "objxxx" refers to the object ID, which will be used to identify a specific object in the following conversation. 7 | 8 | A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. 9 | The conversation focuses on a 3D indoor scene, which contains dozens of 3D objects. Here is a list of object information: [].
"objxx" refers to the object ID, which will be used to identify a specific object in the following conversation. 10 | 11 | A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. The conversation centers around a 3D indoor scene that encompasses numerous 3D objects. Here is a list of object information: []. Objects are separated by "," and each object is identified by an ID in the format "objxx". 12 | 13 | A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==1.5.3 2 | transformers==4.39.3 3 | einops==0.6.1 4 | plyfile==1.0.1 5 | trimesh==3.23.1 6 | peft==0.9.0 7 | termcolor==2.3.0 8 | scipy==1.12.0 9 | pycocoevalcap==1.2 10 | sentencepiece==0.2.0 11 | protobuf==4.25.3 12 | packaging==24.0 13 | flash_attn==2.5.6 14 | mmengine==0.10.3 -------------------------------------------------------------------------------- /scripts/config_3dgraphllm.py: -------------------------------------------------------------------------------- 1 | # ========================= data ========================== 2 | anno_root = "annotations" # annotation dir 3 | pc_encoder = "uni3d" 4 | segmentor = "mask3d" 5 | version = "" 6 | 7 | seg_feat_file = f"{anno_root}/scannet_{segmentor}_{pc_encoder}_feats.pt" 8 | seg_img_feat_file = f"{anno_root}/scannet_{segmentor}_videofeats.pt" 9 | seg_val_attr_file = f"{anno_root}/scannet_{segmentor}_val_attributes.pt" 10 | seg_val_gnn_file = f"{anno_root}/scannet_{segmentor}_val_gnn_feats_2{version}.pt" 11 | 12 | val_file_dict = { 13 | 'demo': [ 14 | seg_feat_file, 15 | seg_img_feat_file, 16 | seg_val_attr_file, 17 | seg_val_gnn_file, 18 | f"{segmentor}", 19 | ], 20 | } 21 | 22 | # ========================= model ========================== 23 | model = dict( 24 | llama_model_path="./Meta-Llama-3-8B-Instruct", 25 | input_dim=1024, 26 | img_input_dim=1024, 27 | attr_dim=512, 28 | scene_dim=256, 29 | pos_dim=128, 30 | encoder_num_layers=3, 31 | low_resource=False, 32 | system_path="./prompts/system.txt", 33 | instruction_path="./prompts/instruction.txt", 34 | max_txt_len=64, 35 | end_sym="", 36 | role=("USER", "ASSISTANT"), 37 | add_scene_token=False, 38 | add_img_token=True, 39 | use_lora=True, 40 | train_emb=True, 41 | train_img_proj=True, 42 | no_obj=False, 43 | max_obj_num=150, 44 | bidirection=False, 45 | add_pos_emb=False, 46 | feat_fusion=False, 47 | fuse_with_id=False, 48 | use_objid=True, 49 | use_location_token=False, 50 | knn=2, 51 | bbox_embed=False, 52 | gt_pretrain=False, 53 | nms=True, 54 | nn_distance=True, 55 | max_knn=2 56 | ) 57 | 58 | lora = dict( 59 | lora_target_modules=[ 60 | "q_proj", 61 | "v_proj", 62 | "k_proj", 63 | "o_proj", 64 | "gate_proj", 65 | "up_proj", 66 | "down_proj" 67 | ], 68 | lora_r=16, 69 | lora_alpha=16, 70 | lora_dropout=0.05 71 | ) 72 | 73 | optimizer = dict( 74 | opt="adamW", 75 | lr=5e-6, 76 | opt_betas=[0.9, 0.999], # default 77 | weight_decay=0.02, 78 | scaler_enable=False, 79 | max_grad_norm=0.01, # requires a positive float, use -1 to disable 80 | # use a different lr for some modules, e.g., larger lr for new modules 81 | different_lr=dict( 82 | enable=False, 83 | module_names=["model.embed_tokens"], 84 | lr=[5e-4], 85 | wd=[0.02] 86 | ), 87 | ) 88 | 
89 | scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.1) 90 | 91 | evaluate = False 92 | 93 | # ========================= wandb ========================== 94 | wandb = dict( 95 | enable=False, 96 | entity="anonym", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init 97 | project="3DGraphLLM", 98 | ) 99 | dist_url = "env://" 100 | device = "cuda" 101 | 102 | # ========================= others ========================== 103 | output_dir = "./llama3-8b-gt-pretrain-2-3rscan" # output dir 104 | resume = False # if True, load optimizer and scheduler states as well 105 | debug = False 106 | log_freq = 20 107 | # eval_freq = 500 108 | seed = 42 109 | 110 | save_latest = False 111 | do_save = True 112 | auto_resume = True 113 | pretrained_path = "" 114 | img_projector_path = "" 115 | 116 | debug=False -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | which_python=$(which python) 2 | export PYTHONPATH=${PYTHONPATH}:${which_python}:. 3 | echo "PYTHONPATH: ${PYTHONPATH}" 4 | 5 | export MASTER_PORT=$((54000 + $RANDOM % 10000)) 6 | export MASTER_ADDR=localhost 7 | 8 | epoch=3 9 | batch_size=8 10 | lr=2e-5 11 | train_emb=True 12 | train_img_proj=True 13 | add_img_token=True 14 | add_scene_token=False 15 | no_obj=False 16 | input_dim=1024 # 1024 17 | bidirection=False 18 | different_lr=False 19 | max_obj_num=150 20 | lora_r=16 21 | lora_alpha=16 22 | add_pos_emb=False 23 | feat_fusion=False 24 | fuse_with_id=False 25 | config="" 26 | max_grad_norm=0.01 27 | seed=42 28 | use_location_token=False 29 | 30 | #llama_model_path="./vicuna-7b-v1.5" 31 | llama_model_path="./Meta-Llama-3-8B-Instruct" 32 | 33 | train_tag="scanrefer#obj_align#nr3d_caption#scan2cap#scanqa#sqa3d#multi3dref" 34 | val_tag="scanrefer#scan2cap#scanqa#sqa3d#multi3dref" 35 | 36 | evaluate=True 37 | debug=False 38 | if [ $evaluate = "True" ]; then 39 | enable_wandb=False 40 | gpu_num=1 41 | do_save=True 42 | other_info="evaluation" 43 | else 44 | enable_wandb=True 45 | gpu_num=1 46 | do_save=True 47 | other_info="chatscene" 48 | fi 49 | 50 | tag="${train_tag}__${val_tag}__${other_info}" 51 | 52 | pretrained_path="./demo/3dgraphllm.pth" 53 | 54 | OUTPUT_DIR=outputs/3dgraphllm_2e-5_ep6 55 | mkdir -p ${OUTPUT_DIR} 56 | 57 | python tasks/train.py \ 58 | "$(dirname $0)/${config}config.py" \ 59 | output_dir "$OUTPUT_DIR" \ 60 | scheduler.epochs "$epoch" \ 61 | optimizer.lr "$lr" \ 62 | model.add_scene_token "$add_scene_token" \ 63 | model.add_img_token "$add_img_token" \ 64 | pretrained_path "$pretrained_path" \ 65 | evaluate "$evaluate" \ 66 | wandb.enable "$enable_wandb" \ 67 | gpu_num "$gpu_num" \ 68 | do_save "$do_save" \ 69 | batch_size "$batch_size" \ 70 | model.train_emb "$train_emb" \ 71 | model.train_img_proj "$train_img_proj" \ 72 | train_tag "$train_tag" \ 73 | val_tag "$val_tag" \ 74 | model.no_obj "$no_obj" \ 75 | segmentor "$segmentor" \ 76 | pc_encoder "$pc_encoder" \ 77 | model.input_dim "$input_dim" \ 78 | model.bidirection "$bidirection" \ 79 | optimizer.different_lr.enable "$different_lr" \ 80 | model.max_obj_num "$max_obj_num" \ 81 | lora.lora_r "$lora_r" \ 82 | lora.lora_alpha "$lora_alpha" \ 83 | model.add_pos_emb "$add_pos_emb" \ 84 | model.feat_fusion "$feat_fusion" \ 85 | optimizer.max_grad_norm "$max_grad_norm" \ 86 | seed "$seed" \ 87 | model.fuse_with_id "$fuse_with_id" \ 88 | model.llama_model_path "$llama_model_path" \ 89 | 
model.use_location_token "$use_location_token" \ 90 | model.gt_pretrain False 91 | 92 | -------------------------------------------------------------------------------- /scripts/run_gt_pretrain.sh: -------------------------------------------------------------------------------- 1 | which_python=$(which python) 2 | export PYTHONPATH=${PYTHONPATH}:${which_python}:. 3 | echo "PYTHONPATH: ${PYTHONPATH}" 4 | 5 | export MASTER_PORT=$((54000 + $RANDOM % 10000)) 6 | export MASTER_ADDR=localhost 7 | 8 | epoch=3 9 | batch_size=8 10 | lr=2e-5 11 | train_emb=True 12 | train_img_proj=True 13 | add_img_token=True 14 | add_scene_token=False 15 | no_obj=False 16 | input_dim=1024 # 1024 17 | bidirection=False 18 | different_lr=False 19 | max_obj_num=150 20 | lora_r=16 21 | lora_alpha=16 22 | add_pos_emb=False 23 | feat_fusion=False 24 | fuse_with_id=False 25 | config="" 26 | max_grad_norm=0.01 27 | seed=42 28 | use_location_token=False 29 | 30 | #llama_model_path="./vicuna-7b-v1.5" 31 | llama_model_path="./Meta-Llama-3-8B-Instruct" 32 | 33 | train_tag="scanrefer#obj_align#nr3d_caption#scan2cap#scanqa#sqa3d#multi3dref" 34 | val_tag="scanrefer#scan2cap#multi3dref#sqa3d#scanqa" 35 | 36 | evaluate=False 37 | debug=False 38 | resume=False 39 | if [ $evaluate = "True" ]; then 40 | enable_wandb=False 41 | gpu_num=1 42 | do_save=True 43 | other_info="evaluation" 44 | else 45 | enable_wandb=False 46 | gpu_num=1 47 | do_save=True 48 | other_info="chatscene" 49 | fi 50 | 51 | tag="${train_tag}__${val_tag}__${other_info}" 52 | 53 | pretrained_path="" 54 | 55 | OUTPUT_DIR=outputs/llama3-8b-gt-pretrain-2 56 | 57 | python tasks/train.py \ 58 | "$(dirname $0)/${config}config-gt-pretrain.py" \ 59 | output_dir "$OUTPUT_DIR" \ 60 | scheduler.epochs "$epoch" \ 61 | optimizer.lr "$lr" \ 62 | model.add_scene_token "$add_scene_token" \ 63 | model.add_img_token "$add_img_token" \ 64 | pretrained_path "$pretrained_path" \ 65 | evaluate "$evaluate" \ 66 | wandb.enable "$enable_wandb" \ 67 | gpu_num "$gpu_num" \ 68 | do_save "$do_save" \ 69 | batch_size "$batch_size" \ 70 | model.train_emb "$train_emb" \ 71 | model.train_img_proj "$train_img_proj" \ 72 | train_tag "$train_tag" \ 73 | val_tag "$val_tag" \ 74 | model.no_obj "$no_obj" \ 75 | segmentor "$segmentor" \ 76 | pc_encoder "$pc_encoder" \ 77 | model.input_dim "$input_dim" \ 78 | model.bidirection "$bidirection" \ 79 | optimizer.different_lr.enable "$different_lr" \ 80 | model.max_obj_num "$max_obj_num" \ 81 | lora.lora_r "$lora_r" \ 82 | lora.lora_alpha "$lora_alpha" \ 83 | model.add_pos_emb "$add_pos_emb" \ 84 | model.feat_fusion "$feat_fusion" \ 85 | optimizer.max_grad_norm "$max_grad_norm" \ 86 | seed "$seed" \ 87 | model.fuse_with_id "$fuse_with_id" \ 88 | model.llama_model_path "$llama_model_path" \ 89 | model.use_location_token "$use_location_token" \ 90 | model.knn 2 \ 91 | model.gt_pretrain True 92 | -------------------------------------------------------------------------------- /tasks/shared_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | import os.path as osp 5 | from os.path import join 6 | 7 | import torch 8 | from torch.utils.data import ConcatDataset, DataLoader 9 | 10 | from utils.optimizer import create_optimizer 11 | from utils.scheduler import create_scheduler 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def get_media_types(datasources): 17 | """get the media types for all the dataloaders.
18 | 19 | Args: 20 | datasources (List): List of dataloaders or datasets. 21 | 22 | Returns: List. The media_types. 23 | 24 | """ 25 | if isinstance(datasources[0], DataLoader): 26 | datasets = [dataloader.dataset for dataloader in datasources] 27 | else: 28 | datasets = datasources 29 | media_types = [ 30 | dataset.datasets[0].media_type 31 | if isinstance(dataset, ConcatDataset) 32 | else dataset.media_type 33 | for dataset in datasets 34 | ] 35 | 36 | return media_types 37 | 38 | 39 | def setup_model( 40 | config, model_cls, find_unused_parameters=False 41 | ): 42 | logger.info("Creating model") 43 | config = copy.deepcopy(config) 44 | 45 | model = model_cls(config=config) 46 | 47 | model = model.to(torch.device(config.device)) 48 | model_without_ddp = model 49 | if config.distributed: 50 | model = torch.nn.parallel.DistributedDataParallel( 51 | model, 52 | device_ids=[config.gpu], 53 | find_unused_parameters=find_unused_parameters, # `False` for image-only task 54 | gradient_as_bucket_view=True 55 | ) 56 | optimizer = create_optimizer(config.optimizer, model, config) 57 | scheduler = create_scheduler(config.scheduler, optimizer) 58 | scaler = torch.cuda.amp.GradScaler(enabled=config.optimizer.scaler_enable, growth_interval=100) 59 | 60 | start_epoch = 0 61 | global_step = 0 62 | 63 | # auto resume the latest checkpoint 64 | if config.get("auto_resume", False): 65 | logger.info("Auto resuming") 66 | model_latest = join(config.output_dir, "ckpt_latest.pth") 67 | model_best = join(config.output_dir, "ckpt_best.pth") 68 | large_num = -1 69 | for p in os.listdir(config.output_dir): 70 | if 'ckpt' in p: 71 | num = p.split('_')[1].split('.')[0] 72 | if str.isnumeric(num): 73 | if int(num) > large_num: 74 | large_num = int(num) 75 | if large_num != -1: 76 | model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth") 77 | if osp.isfile(model_latest) and not config.pretrained_path: 78 | config.pretrained_path = model_latest 79 | config.resume = True 80 | elif osp.isfile(model_best) and not config.pretrained_path: 81 | config.pretrained_path = model_best 82 | config.resume = True 83 | else: 84 | logger.info(f"Not found checkpoint in {config.output_dir}") 85 | 86 | if osp.isfile(config.img_projector_path): 87 | img_projector_sd = torch.load(config.img_projector_path, map_location="cpu") 88 | msg = model_without_ddp.object_img_proj.load_state_dict(img_projector_sd) 89 | logger.info(f"Loaded pretrained image projector from {config.img_projector_path}.") 90 | 91 | if osp.isfile(config.pretrained_path): 92 | checkpoint = torch.load(config.pretrained_path, map_location="cpu") 93 | state_dict = checkpoint["model"] 94 | 95 | if config.resume: 96 | optimizer.load_state_dict(checkpoint["optimizer"]) 97 | scheduler.load_state_dict(checkpoint["scheduler"]) 98 | scaler.load_state_dict(checkpoint["scaler"]) 99 | start_epoch = checkpoint["epoch"] + 1 100 | global_step = checkpoint["global_step"] 101 | keys_to_delete = [] 102 | for name, param in state_dict.items(): 103 | if name not in model_without_ddp.state_dict(): 104 | continue 105 | if param.size() != model_without_ddp.state_dict()[name].size(): 106 | keys_to_delete.append(name) 107 | for key in keys_to_delete: 108 | del state_dict[key] 109 | msg = model_without_ddp.load_state_dict(state_dict, strict=False) 110 | logger.info(msg) 111 | logger.info(f"Loaded checkpoint from {config.pretrained_path}.") 112 | else: 113 | logger.warning("No pretrained checkpoint provided, training from scratch.") 114 | 115 | return ( 116 | model, 117 | 
model_without_ddp, 118 | optimizer, 119 | scheduler, 120 | scaler, 121 | start_epoch, 122 | global_step, 123 | ) 124 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CognitiveAISystems/3DGraphLLM/418f6529e029761a52017d0eb2a5380689805c62/utils/__init__.py -------------------------------------------------------------------------------- /utils/box_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_box3d_min_max(corner): 5 | ''' Compute min and max coordinates for a 3D bounding box 6 | Note: only for axis-aligned bounding boxes 7 | 8 | Input: 9 | corner: numpy array (8,3), assume up direction is Z 10 | Output: 11 | min and max coordinates of the axis-aligned 3D bounding box along each axis 12 | 13 | ''' 14 | 15 | min_coord = corner.min(axis=0) 16 | max_coord = corner.max(axis=0) 17 | x_min, x_max = min_coord[0], max_coord[0] 18 | y_min, y_max = min_coord[1], max_coord[1] 19 | z_min, z_max = min_coord[2], max_coord[2] 20 | 21 | return x_min, x_max, y_min, y_max, z_min, z_max 22 | 23 | 24 | def box3d_iou(corners1, corners2): 25 | ''' Compute 3D bounding box IoU. 26 | 27 | Input: 28 | corners1: numpy array (8,3), assume up direction is Z 29 | corners2: numpy array (8,3), assume up direction is Z 30 | Output: 31 | iou: 3D bounding box IoU 32 | 33 | ''' 34 | 35 | x_min_1, x_max_1, y_min_1, y_max_1, z_min_1, z_max_1 = get_box3d_min_max(corners1) 36 | x_min_2, x_max_2, y_min_2, y_max_2, z_min_2, z_max_2 = get_box3d_min_max(corners2) 37 | xA = np.maximum(x_min_1, x_min_2) 38 | yA = np.maximum(y_min_1, y_min_2) 39 | zA = np.maximum(z_min_1, z_min_2) 40 | xB = np.minimum(x_max_1, x_max_2) 41 | yB = np.minimum(y_max_1, y_max_2) 42 | zB = np.minimum(z_max_1, z_max_2) 43 | inter_vol = np.maximum((xB - xA), 0) * np.maximum((yB - yA), 0) * np.maximum((zB - zA), 0) 44 | box_vol_1 = (x_max_1 - x_min_1) * (y_max_1 - y_min_1) * (z_max_1 - z_min_1) 45 | box_vol_2 = (x_max_2 - x_min_2) * (y_max_2 - y_min_2) * (z_max_2 - z_min_2) 46 | iou = inter_vol / (box_vol_1 + box_vol_2 - inter_vol + 1e-8) 47 | 48 | return iou 49 | 50 | 51 | def construct_bbox_corners(center, box_size): 52 | sx, sy, sz = box_size 53 | x_corners = [sx / 2, sx / 2, -sx / 2, -sx / 2, sx / 2, sx / 2, -sx / 2, -sx / 2] 54 | y_corners = [sy / 2, -sy / 2, -sy / 2, sy / 2, sy / 2, -sy / 2, -sy / 2, sy / 2] 55 | z_corners = [sz / 2, sz / 2, sz / 2, sz / 2, -sz / 2, -sz / 2, -sz / 2, -sz / 2] 56 | corners_3d = np.vstack([x_corners, y_corners, z_corners]) 57 | corners_3d[0, :] = corners_3d[0, :] + center[0] 58 | corners_3d[1, :] = corners_3d[1, :] + center[1] 59 | corners_3d[2, :] = corners_3d[2, :] + center[2] 60 | corners_3d = np.transpose(corners_3d) 61 | 62 | return corners_3d 63 | -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from os.path import dirname, join 5 | 6 | from utils.config import Config 7 | from utils.distributed import init_distributed_mode, is_main_process 8 | from utils.logger import setup_logger 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def setup_config(): 14 | """Combine yaml config and command line config with OmegaConf.
15 | Also converts types, e.g., `'None'` (str) --> `None` (None) 16 | """ 17 | config = Config.get_config() 18 | if config.debug: 19 | config.wandb.enable = False 20 | return config 21 | 22 | 23 | def setup_evaluate_config(config): 24 | """setup evaluation default settings, e.g., disable wandb""" 25 | assert config.evaluate 26 | config.wandb.enable = False 27 | if config.output_dir is None: 28 | config.output_dir = join(dirname(config.pretrained_path), "eval") 29 | return config 30 | 31 | 32 | def setup_output_dir(output_dir, excludes=["code"]): 33 | """ensure we are not overwriting an existing/non-empty output dir""" 34 | if not os.path.exists(output_dir): 35 | os.makedirs(output_dir, exist_ok=False) 36 | else: 37 | existing_dirs_files = os.listdir(output_dir) # list 38 | remaining = set(existing_dirs_files) - set(excludes) 39 | remaining = [e for e in remaining if "slurm" not in e] 40 | remaining = [e for e in remaining if ".out" not in e] 41 | # assert len(remaining) == 0, f"remaining dirs or files: {remaining}" 42 | logger.warning(f"remaining dirs or files: {remaining}") 43 | 44 | 45 | def setup_main(): 46 | """ 47 | Setup config, logger, output_dir, etc. 48 | Shared for pretrain and all downstream tasks. 49 | """ 50 | config = setup_config() 51 | if hasattr(config, "evaluate") and config.evaluate: 52 | config = setup_evaluate_config(config) 53 | init_distributed_mode(config) 54 | 55 | if is_main_process(): 56 | setup_output_dir(config.output_dir, excludes=["code"]) 57 | setup_logger(output=config.output_dir, color=True, name="vindlu") 58 | logger.info(f"config: {Config.pretty_text(config)}") 59 | Config.dump(config, os.path.join(config.output_dir, "config.json")) 60 | return config 61 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | import logging 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def setup_for_distributed(is_master): 11 | import warnings 12 | 13 | builtin_warn = warnings.warn 14 | 15 | def warn(*args, **kwargs): 16 | force = kwargs.pop("force", False) 17 | if is_master or force: 18 | builtin_warn(*args, **kwargs) 19 | 20 | # Log warnings only once 21 | warnings.warn = warn 22 | warnings.simplefilter("once", UserWarning) 23 | 24 | if not is_master: 25 | logging.disable() 26 | 27 | 28 | def is_dist_avail_and_initialized(): 29 | if not dist.is_available(): 30 | return False 31 | if not dist.is_initialized(): 32 | return False 33 | return True 34 | 35 | 36 | def get_world_size(): 37 | if not is_dist_avail_and_initialized(): 38 | return 1 39 | return dist.get_world_size() 40 | 41 | 42 | def get_rank(): 43 | if not is_dist_avail_and_initialized(): 44 | return 0 45 | return dist.get_rank() 46 | 47 | 48 | def is_main_process(): 49 | return get_rank() == 0 50 | 51 | 52 | def save_on_master(*args, **kwargs): 53 | if is_main_process(): 54 | torch.save(*args, **kwargs) 55 | 56 | 57 | def is_port_in_use(port): 58 | import socket 59 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 60 | return s.connect_ex(('localhost', port)) == 0 61 | 62 | 63 | def init_distributed_mode(args): 64 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 65 | # job started by torch.distributed.launch 66 | args.rank = int(os.environ["RANK"]) 67 | args.world_size = int(os.environ['WORLD_SIZE']) 68 | args.gpu = int(os.environ['LOCAL_RANK']) 69 | elif 'SLURM_PROCID' in
os.environ: 70 | # local rank on the current node / global rank 71 | local_rank = int(os.environ['SLURM_LOCALID']) 72 | global_rank = int(os.environ['SLURM_PROCID']) 73 | # number of processes / GPUs per node 74 | world_size = int(os.environ["SLURM_NNODES"]) * \ 75 | int(os.environ["SLURM_TASKS_PER_NODE"][0]) 76 | 77 | args.rank = global_rank 78 | args.gpu = local_rank 79 | args.world_size = world_size 80 | else: 81 | logger.info('Not using distributed mode') 82 | args.distributed = False 83 | return 84 | 85 | args.distributed = True 86 | 87 | torch.cuda.set_device(args.gpu) 88 | args.dist_backend = 'nccl' 89 | 90 | if "tcp" in args.dist_url: # in slurm, multiple program runs in a single node 91 | dist_port = int(args.dist_url.split(":")[-1]) 92 | while is_port_in_use(dist_port): 93 | dist_port += 10 94 | args.dist_url = ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)]) 95 | 96 | logger.info('| distributed init (rank {}): {}'.format( 97 | args.rank, args.dist_url)) 98 | if "SLURM_JOB_ID" in os.environ: 99 | logger.info(f"SLURM_JOB_ID {os.environ['SLURM_JOB_ID']}") 100 | torch.distributed.init_process_group( 101 | backend=args.dist_backend, init_method=args.dist_url, 102 | world_size=args.world_size, rank=args.rank) 103 | torch.distributed.barrier() 104 | setup_for_distributed(args.rank == 0) 105 | 106 | 107 | # Copyright (c) Facebook, Inc. and its affiliates. 108 | # copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py 109 | class GatherLayer(torch.autograd.Function): 110 | """ 111 | Gather tensors from all workers with support for backward propagation: 112 | This implementation does not cut the gradients as torch.distributed.all_gather does. 113 | """ 114 | 115 | @staticmethod 116 | def forward(ctx, x): 117 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 118 | dist.all_gather(output, x) 119 | return tuple(output) 120 | 121 | @staticmethod 122 | def backward(ctx, *grads): 123 | all_gradients = torch.stack(grads) 124 | dist.all_reduce(all_gradients) 125 | return all_gradients[dist.get_rank()] 126 | 127 | 128 | # copied from megavlt 129 | def gather_tensor_along_batch_with_backward(tensor, dim=0): 130 | world_size = get_world_size() 131 | 132 | if world_size < 2: 133 | return tensor 134 | 135 | tensor_list = GatherLayer.apply(tensor) 136 | tensor_list = torch.cat(tensor_list, dim=dim) 137 | return tensor_list 138 | 139 | 140 | @torch.no_grad() 141 | def gather_tensor_along_batch(tensor, dim=0): 142 | """ 143 | Performs all_gather operation on the provided tensors. 144 | *** Warning ***: torch.distributed.all_gather has no gradient. 145 | """ 146 | world_size = get_world_size() 147 | 148 | if world_size < 2: 149 | return tensor 150 | 151 | with torch.no_grad(): 152 | tensor_list = [] 153 | 154 | for _ in range(world_size): 155 | tensor_list.append(torch.zeros_like(tensor)) 156 | 157 | dist.all_gather(tensor_list, tensor) 158 | tensor_list = torch.cat(tensor_list, dim=dim) 159 | return tensor_list 160 | -------------------------------------------------------------------------------- /utils/easydict.py: -------------------------------------------------------------------------------- 1 | class EasyDict(dict): 2 | """ 3 | Get attributes 4 | 5 | >>> d = EasyDict({'foo':3}) 6 | >>> d['foo'] 7 | 3 8 | >>> d.foo 9 | 3 10 | >>> d.bar 11 | Traceback (most recent call last): 12 | ... 
13 | AttributeError: 'EasyDict' object has no attribute 'bar' 14 | 15 | Works recursively 16 | 17 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) 18 | >>> isinstance(d.bar, dict) 19 | True 20 | >>> d.bar.x 21 | 1 22 | 23 | Bullet-proof 24 | 25 | >>> EasyDict({}) 26 | {} 27 | >>> EasyDict(d={}) 28 | {} 29 | >>> EasyDict(None) 30 | {} 31 | >>> d = {'a': 1} 32 | >>> EasyDict(**d) 33 | {'a': 1} 34 | 35 | Set attributes 36 | 37 | >>> d = EasyDict() 38 | >>> d.foo = 3 39 | >>> d.foo 40 | 3 41 | >>> d.bar = {'prop': 'value'} 42 | >>> d.bar.prop 43 | 'value' 44 | >>> d 45 | {'foo': 3, 'bar': {'prop': 'value'}} 46 | >>> d.bar.prop = 'newer' 47 | >>> d.bar.prop 48 | 'newer' 49 | 50 | 51 | Values extraction 52 | 53 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) 54 | >>> isinstance(d.bar, list) 55 | True 56 | >>> from operator import attrgetter 57 | >>> map(attrgetter('x'), d.bar) 58 | [1, 3] 59 | >>> map(attrgetter('y'), d.bar) 60 | [2, 4] 61 | >>> d = EasyDict() 62 | >>> d.keys() 63 | [] 64 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) 65 | >>> d.foo 66 | 3 67 | >>> d.bar.x 68 | 1 69 | 70 | Still like a dict though 71 | 72 | >>> o = EasyDict({'clean':True}) 73 | >>> o.items() 74 | [('clean', True)] 75 | 76 | And like a class 77 | 78 | >>> class Flower(EasyDict): 79 | ... power = 1 80 | ... 81 | >>> f = Flower() 82 | >>> f.power 83 | 1 84 | >>> f = Flower({'height': 12}) 85 | >>> f.height 86 | 12 87 | >>> f['power'] 88 | 1 89 | >>> sorted(f.keys()) 90 | ['height', 'power'] 91 | 92 | update and pop items 93 | >>> d = EasyDict(a=1, b='2') 94 | >>> e = EasyDict(c=3.0, a=9.0) 95 | >>> d.update(e) 96 | >>> d.c 97 | 3.0 98 | >>> d['c'] 99 | 3.0 100 | >>> d.get('c') 101 | 3.0 102 | >>> d.update(a=4, b=4) 103 | >>> d.b 104 | 4 105 | >>> d.pop('a') 106 | 4 107 | >>> d.a 108 | Traceback (most recent call last): 109 | ... 
110 | AttributeError: 'EasyDict' object has no attribute 'a' 111 | """ 112 | 113 | def __init__(self, d=None, **kwargs): 114 | if d is None: 115 | d = {} 116 | if kwargs: 117 | d.update(**kwargs) 118 | for k, v in d.items(): 119 | setattr(self, k, v) 120 | # Class attributes 121 | for k in self.__class__.__dict__.keys(): 122 | if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"): 123 | setattr(self, k, getattr(self, k)) 124 | 125 | def __setattr__(self, name, value): 126 | if isinstance(value, (list, tuple)): 127 | value = [self.__class__(x) if isinstance(x, dict) else x for x in value] 128 | elif isinstance(value, dict) and not isinstance(value, self.__class__): 129 | value = self.__class__(value) 130 | super(EasyDict, self).__setattr__(name, value) 131 | super(EasyDict, self).__setitem__(name, value) 132 | 133 | __setitem__ = __setattr__ 134 | 135 | def update(self, e=None, **f): 136 | d = e or dict() 137 | d.update(f) 138 | for k in d: 139 | setattr(self, k, d[k]) 140 | 141 | def pop(self, k, d=None): 142 | if hasattr(self, k): 143 | delattr(self, k) 144 | return super(EasyDict, self).pop(k, d) 145 | 146 | 147 | if __name__ == "__main__": 148 | import doctest 149 | doctest.testmod() 150 | -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from torch.optim import Optimizer 5 | import math 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | 9 | def create_scheduler(args, optimizer): 10 | lr_scheduler = None 11 | if args.sched == 'cosine': 12 | lr_scheduler = get_cosine_schedule_with_warmup( 13 | optimizer, 14 | num_warmup_steps=args.num_warmup_steps, 15 | num_training_steps=args.num_training_steps, 16 | num_cycles=0.5, 17 | min_lr_multi=args.min_lr_multi 18 | ) 19 | return lr_scheduler 20 | 21 | 22 | def get_cosine_schedule_with_warmup( 23 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, 24 | num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1 25 | ): 26 | """ 27 | Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py 28 | 29 | Create a schedule with a learning rate that decreases following the values of the cosine function from the 30 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 31 | initial lr set in the optimizer. 32 | Args: 33 | optimizer ([`~torch.optim.Optimizer`]): 34 | The optimizer for which to schedule the learning rate. 35 | num_warmup_steps (`int`): 36 | The number of steps for the warmup phase. 37 | num_training_steps (`int`): 38 | The total number of training steps. 39 | num_cycles (`float`, *optional*, defaults to 0.5): 40 | The number of waves in the cosine schedule (the default is to just decrease from the max value to 0 41 | following a half-cosine). 42 | min_lr_multi (`float`, *optional*, defaults to 0): 43 | The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi. 44 | last_epoch (`int`, *optional*, defaults to -1): 45 | The index of the last epoch when resuming training. 46 | Return: 47 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
48 | """ 49 | 50 | def lr_lambda(current_step): 51 | if current_step < num_warmup_steps: 52 | return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps))) 53 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 54 | return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 55 | 56 | return LambdaLR(optimizer, lr_lambda, last_epoch) 57 | --------------------------------------------------------------------------------