├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── ga.png
├── dataset
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── dataloader.py
│   ├── dataset_train.py
│   └── dataset_val.py
├── demo
│   ├── demo.py
│   └── run_demo.sh
├── models
│   ├── __init__.py
│   ├── configuration_llama.py
│   ├── graph3dllm.py
│   ├── helpers.py
│   ├── load_llama.py
│   ├── modeling_llama.py
│   ├── moe
│   │   ├── __init__.py
│   │   ├── layer.py
│   │   └── moe_lora.py
│   ├── position_embedding.py
│   └── transformer_vanilla
│       ├── __init__.py
│       ├── mhsa.py
│       ├── self_attention.py
│       └── transformer_block.py
├── others
│   ├── analyze_grounding_results.py
│   ├── calc_scanrefer_grounding_acc.py
│   ├── eval_offline.py
│   ├── extract_target_noun.py
│   ├── gpt_generate.py
│   ├── ground_visualize.py
│   ├── llama_tmp.py
│   ├── modify.py
│   ├── prepare_anno_stage1.py
│   ├── prepare_captions_noun.py
│   ├── prepare_describe.py
│   ├── prepare_eval_json.py
│   ├── prepare_identifier_rich.py
│   ├── prepare_multi3dref.py
│   ├── prepare_obj_align_data.py
│   ├── prepare_obj_caption.py
│   ├── prepare_ref_captions.py
│   ├── prepare_referit_anno_stage1.py
│   ├── prepare_scanqa.py
│   ├── prepare_scanrefer_grounding_train.py
│   ├── prepare_scene_align_data.py
│   ├── prepare_scene_level_dataset.py
│   ├── prepare_sqa3d.py
│   ├── prepare_train_stage3.py
│   ├── process_annos.py
│   ├── process_preds.py
│   ├── process_vil3dref_multichoice.py
│   ├── process_vil3dref_results.py
│   ├── run_chat.sh
│   ├── run_eval.sh
│   ├── run_generate.sh
│   ├── tmp.py
│   ├── tmp2.py
│   ├── visualize.py
│   └── viz.py
├── preprocess
│   ├── README.md
│   ├── prepare_filtered_mask3d_gnn_data.py
│   ├── prepare_gnn_data.py
│   ├── prepare_mask3d_data.py
│   ├── prepare_mask3d_img_feat.py
│   ├── prepare_multi3dref_annos.py
│   ├── prepare_multi3dref_location_annos.py
│   ├── prepare_nr3d_annos.py
│   ├── prepare_nr3dcaption_annos.py
│   ├── prepare_objalign_annos.py
│   ├── prepare_scan2cap_annos.py
│   ├── prepare_scan2cap_location_annos.py
│   ├── prepare_scannet_attributes.py
│   ├── prepare_scannet_attributes_clasp.py
│   ├── prepare_scannet_caption_annos.py
│   ├── prepare_scannet_mask3d_attributes.py
│   ├── prepare_scannet_region_caption_annos.py
│   ├── prepare_scanqa_annos.py
│   ├── prepare_scanrefer_annos.py
│   ├── prepare_scanrefer_location_annos.py
│   ├── prepare_sqa3d_annos.py
│   ├── prepare_sr3d_annos.py
│   ├── process_scannet_data.py
│   └── run_prepare.sh
├── prompts
│   ├── concise_description.txt
│   ├── concise_description_objxx.txt
│   ├── conv_description.txt
│   ├── dataset_generation
│   │   ├── conversation.txt
│   │   ├── detail.txt
│   │   └── textualize_obj.txt
│   ├── detailed_description.txt
│   ├── grounding_answer_templates.txt
│   ├── grounding_prompts.txt
│   ├── instruction.txt
│   ├── nr3d_caption_templates.txt
│   ├── object_caption_templates.txt
│   ├── prompts.py
│   ├── scanrefer_caption_templates.txt
│   ├── scene_align_template.txt
│   ├── score_template.txt
│   ├── score_template_old.txt
│   ├── system.txt
│   └── system_backup.txt
├── requirements.txt
├── scripts
│   ├── config-gt-pretrain.py
│   ├── config.py
│   ├── config_3dgraphllm.py
│   ├── run.sh
│   └── run_gt_pretrain.sh
├── tasks
│   ├── shared_utils.py
│   └── train.py
└── utils
    ├── __init__.py
    ├── basic_utils.py
    ├── box_utils.py
    ├── config.py
    ├── config_utils.py
    ├── distributed.py
    ├── easydict.py
    ├── eval.py
    ├── eval_tmp.py
    ├── helper.py
    ├── logger.py
    ├── optimizer.py
    ├── pc_util.py
    └── scheduler.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 | outputs/
164 | annotations/
165 | wandb/
166 | vicuna-7b-v1.5/
167 | utils/capeval/
168 | annotations
169 | Meta-Llama-3-8B-Instruct
170 | *.pth
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Chat-Scene
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 3DGraphLLM
2 |
3 | [arXiv](https://arxiv.org/abs/2412.18450)
4 | [Hugging Face](https://huggingface.co/wingrune/3DGraphLLM)
5 |
6 | In this work, we propose 3DGraphLLM, a method for constructing a learnable representation of a 3D scene graph, which serves as input for LLMs to perform 3D vision-language tasks.
7 |
8 |
9 |
10 |
11 |
12 | ## News
13 |
14 | [2024.12] We release 3DGraphLLM pre-training on GT instance segmentation scene graphs.
15 |
16 | [2024.12] We release the 3DGraphLLM [paper](https://arxiv.org/abs/2412.18450) and [code](https://github.com/CognitiveAISystems/3DGraphLLM).
17 |
18 | ### 🔥 Semantic relations boost LLM performance on 3D Referred Object Grounding and Dense Scene Captioning tasks
19 |
20 |
21 |
22 | | | [ScanRefer](https://github.com/daveredrum/ScanRefer) | | [Multi3dRefer](https://github.com/3dlg-hcvc/M3DRef-CLIP) | | [Scan2Cap](https://github.com/daveredrum/Scan2Cap) | | [ScanQA](https://github.com/ATR-DBI/ScanQA) | | [SQA3D](https://github.com/SilongYong/SQA3D) |
23 | |:----: |:---------: |:-------: |:------: |:------: |:---------: |:----------: |:------------: |:------: |:-----: |
24 | | | Acc@0.25 | Acc@0.5 | F1@0.25 | F1@0.5 | CIDEr@0.5 | B-4@0.5 | CIDEr | B-4 | EM |
25 | | [Chat-Scene](https://github.com/ZzZZCHS/Chat-Scene/tree/dev) | 55.5 | 50.2 | 57.1 | 52.3 | 77.1 | 36.3 | **87.7** | **14.3** | 54.6 |
26 | | 3DGraphLLM Vicuna-1.5 | 57.0 | 51.3 | 60.1 | 55.4 | 81.2 | 36.3 | 87.6 | 12.1 | 53.1 |
27 | | **3DGraphLLM LLAMA3-8B** | **60.2** | **54.6** | **63.0** | **58.2** | **82.9** | **37.8** | 83.1 | 12.5 | **55.2** |
28 |
29 |
30 | ## 🔨 Preparation
31 |
32 | - Prepare the environment:
33 |
34 | ```shell
35 | conda create -n 3dgraphllm python=3.9.17
36 | conda activate 3dgraphllm
37 | conda install pytorch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 pytorch-cuda=11.8 -c pytorch -c nvidia
38 | pip install -r requirements.txt
39 | ```
40 | - If you don't have root permissions to install Java (needed by the pycocoevalcap scripts for metrics such as BLEU and CIDEr), install it with conda:
41 |
42 | ```
43 | conda install -c conda-forge openjdk
44 | ```
45 |
46 |
47 | - Download LLM backbone:
48 | - We use LLAMA3-8B-Instruct in our experiments, which can be downloaded from [Hugging Face](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct).
49 |
50 | - Change the `llama_model_path` in [config.py](./scripts/config.py) to the path of `LLAMA3-8B-Instruct`.
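
  For reference, a minimal sketch of what that line in `scripts/config.py` looks like (the path below is a placeholder, not the repository default):

  ```python
  llama_model_path = "/path/to/Meta-Llama-3-8B-Instruct"  # placeholder: set to your local download
  ```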
51 |
52 |
53 | - Annotations and extracted features:
54 |
55 | Please follow the instructions in [preprocess](preprocess/).
56 |
57 |
58 | ## 🤖 Training and Inference
59 |
60 | - Pre-training on GT instance segmentation scene graphs.
61 | - Modify [run_gt_pretrain.sh](scripts/run_gt_pretrain.sh):
62 | ```python
63 | train_tag="scanrefer#scan2cap#scanqa#sqa3d#multi3dref#nr3d_caption#obj_align"
64 | val_tag="scanrefer#scan2cap#scanqa#sqa3d#multi3dref"
65 | evaluate=False
66 | ```
67 |
68 |
69 | Explanation of "train_tag" and "val_tag"
70 |
71 | - Use `#` to separate different datasets (see the example after the list below)
72 |
73 | - Datasets:
74 | - `scanrefer`: [ScanRefer](https://github.com/daveredrum/ScanRefer) Dataset
75 | - `scan2cap`: [Scan2Cap](https://github.com/daveredrum/Scan2Cap) Dataset
76 | - `scanqa`: [ScanQA](https://github.com/ATR-DBI/ScanQA) Dataset
77 | - `sqa3d`: [SQA3D](https://github.com/SilongYong/SQA3D) Dataset
78 | - `multi3dref`: [Multi3dRefer](https://github.com/3dlg-hcvc/M3DRef-CLIP) Dataset
79 |     - `nr3d_caption`: A captioning dataset derived from [Nr3D](https://github.com/referit3d/referit3d).
80 |     - `obj_align`: A dataset derived from ScanRefer, used to align object identifiers with object tokens.
81 |
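For example, to train and evaluate only on the grounding and QA subsets, the tags could be set as follows (a hypothetical combination, not the default used above):

```python
train_tag="scanrefer#scanqa"
val_tag="scanrefer#scanqa"
```
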
82 |
83 | - Run: `bash scripts/run_gt_pretrain.sh`
84 |
85 | - Training
86 | - Modify [run.sh](scripts/run.sh):
87 | ```python
88 | train_tag="scanrefer#scan2cap#scanqa#sqa3d#multi3dref#nr3d_caption#obj_align"
89 | val_tag="scanrefer#scan2cap#scanqa#sqa3d#multi3dref"
90 | evaluate=False
91 | pretrained_path="outputs/llama3-8b-gt-pretrain-2/ckpt_00_28927.pth"
92 | ```
93 | - Run: `bash scripts/run.sh`
94 |
95 |
96 | - Inference
97 |
98 | - Modify [run.sh](scripts/run.sh):
99 |
100 | ```python
101 | val_tag="scanrefer#scan2cap#scanqa#sqa3d#multi3dref"
102 | evaluate=True
103 | pretrained_path="/path/to/pretrained_model.pth"
104 | ```
105 |
106 | - Run: `bash scripts/run.sh`
107 |
108 | ## 🚀 Demo
109 |
110 | - Run: `bash demo/run_demo.sh`. You will be prompted to ask different queries about Scene 435 of ScanNet.
111 |
112 |
113 | ## 📪 Contact
114 |
115 | If you have any questions about the project, please open an issue in this repository or send an email to [Tatiana Zemskova](mailto:zemskova@airi.net).
116 |
117 | ## 📑 Citation
118 |
119 | If you find this work helpful, please consider citing our work as:
120 |
121 | ```
122 | @misc{zemskova20243dgraphllm,
123 | title={3DGraphLLM: Combining Semantic Graphs and Large Language Models for 3D Scene Understanding},
124 | author={Tatiana Zemskova and Dmitry Yudin},
125 | year={2024},
126 | eprint={2412.18450},
127 | archivePrefix={arXiv},
128 | primaryClass={cs.CV},
129 | url={https://arxiv.org/abs/2412.18450},
130 | }
131 | ```
132 |
133 |
134 | ## 😊 Acknowledgement
135 |
136 | Thanks to the following open-source projects:
137 |
138 | [Chat-Scene](https://github.com/ZzZZCHS/Chat-Scene/tree/dev)
139 |
--------------------------------------------------------------------------------
/assets/ga.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CognitiveAISystems/3DGraphLLM/418f6529e029761a52017d0eb2a5380689805c62/assets/ga.png
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import ConcatDataset, DataLoader
3 | from torchvision import transforms
4 | from torchvision.transforms import InterpolationMode
5 |
6 | from dataset.dataloader import MetaLoader
7 | from dataset.dataset_train import TrainDataset
8 | from dataset.dataset_val import ValDataset
9 |
10 | import logging
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | def create_dataset(config):
15 | if config.evaluate:
16 | train_datasets = []
17 | else:
18 | train_files = []
19 | for train_name in config.train_tag.split('#'):
20 | if train_name not in config.train_file_dict:
21 | raise NotImplementedError
22 | train_files.append(config.train_file_dict[train_name])
23 |
24 | train_datasets = []
25 | datasets = []
26 | for train_file in train_files:
27 | datasets.append(TrainDataset(ann_list=train_file, config=config))
28 | dataset = ConcatDataset(datasets)
29 | train_datasets.append(dataset)
30 |
31 | val_files = {}
32 | for val_name in config.val_tag.split('#'):
33 | if val_name not in config.val_file_dict:
34 | raise NotImplementedError
35 | val_files[val_name] = config.val_file_dict[val_name]
36 |
37 | val_datasets = []
38 | for k, v in val_files.items():
39 | datasets = []
40 | if type(v[0]) != list:
41 | v = [v]
42 | for val_file in v:
43 | datasets.append(ValDataset(ann_list=val_file, dataset_name=k, config=config))
44 | dataset = ConcatDataset(datasets)
45 | val_datasets.append(dataset)
46 |
47 | return train_datasets, val_datasets
48 |
49 |
50 | def create_sampler(datasets, shuffles, num_tasks, global_rank):
51 | samplers = []
52 | for dataset, shuffle in zip(datasets, shuffles):
53 | sampler = torch.utils.data.DistributedSampler(
54 | dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
55 | )
56 | samplers.append(sampler)
57 | return samplers
58 |
59 |
60 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
61 | loaders = []
62 | for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
63 | datasets, samplers, batch_size, num_workers, is_trains, collate_fns
64 | ):
65 | if is_train:
66 | shuffle = sampler is None
67 | drop_last = True
68 | else:
69 | shuffle = False
70 | drop_last = False
71 | loader = DataLoader(
72 | dataset,
73 | batch_size=bs,
74 | num_workers=n_worker,
75 | pin_memory=False,
76 | sampler=sampler,
77 | shuffle=shuffle,
78 | collate_fn=collate_fn,
79 | drop_last=drop_last,
80 | persistent_workers=True if n_worker > 0 else False,
81 | )
82 | loaders.append(loader)
83 | return loaders
84 |
85 |
86 | def iterate_dataloaders(dataloaders):
87 | """Alternatively generate data from multiple dataloaders,
88 | since we use `zip` to concat multiple dataloaders,
89 | the loop will end when the smaller dataloader runs out.
90 |
91 | Args:
92 | dataloaders List(DataLoader): can be a single or multiple dataloaders
93 | """
94 | for data_tuples in zip(*dataloaders):
95 | for idx, data in enumerate(data_tuples):
96 | yield dataloaders[idx].dataset.media_type, data
97 |
--------------------------------------------------------------------------------
/dataset/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from utils.distributed import get_rank, is_dist_avail_and_initialized, is_main_process
4 | import random
5 | import logging
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class MetaLoader(object):
11 | """ wraps multiple data loader """
12 | def __init__(self, name2loader):
13 | """Iterates over multiple dataloaders, it ensures all processes
14 | work on data from the same dataloader. This loader will end when
15 | the shorter dataloader raises StopIteration exception.
16 |
17 | loaders: Dict, {name: dataloader}
18 | """
19 | self.name2loader = name2loader
20 | self.name2iter = {name: iter(l) for name, l in name2loader.items()}
21 | name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())}
22 | index2name = {v: k for k, v in name2index.items()}
23 |
24 | iter_order = []
25 | for n, l in name2loader.items():
26 | iter_order.extend([name2index[n]]*len(l))
27 |
28 | random.shuffle(iter_order)
29 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)
30 |
31 | # sync
32 | if is_dist_avail_and_initialized():
33 | # make sure all processes have the same order so that
34 | # each step they will have data from the same loader
35 | dist.broadcast(iter_order, src=0)
36 | self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]
37 |
38 | logger.info(str(self))
39 |
40 | def __str__(self):
41 | output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
42 | for idx, (name, loader) in enumerate(self.name2loader.items()):
43 | output.append(
44 | f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} "
45 | )
46 | return "\n".join(output)
47 |
48 | def __len__(self):
49 | return len(self.iter_order)
50 |
51 | def __iter__(self):
52 | """ this iterator will run indefinitely """
53 | for name in self.iter_order:
54 | _iter = self.name2iter[name]
55 | batch = next(_iter)
56 | yield name, batch
57 |
--------------------------------------------------------------------------------
/dataset/dataset_train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import json
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from dataset.base_dataset import BaseDataset, update_caption
9 | import glob
10 | import random
11 | from prompts.prompts import obj_caption_wid_prompt
12 | from torch.nn.utils.rnn import pad_sequence
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 |
18 | class TrainDataset(BaseDataset):
19 |
20 | cached_feats = {}
21 |
22 | def __init__(self, ann_list, config, **kwargs):
23 | super().__init__()
24 | self.feat_dim = config.model.input_dim
25 | self.img_feat_dim = config.model.img_input_dim
26 | self.max_obj_num = config.model.max_obj_num
27 | self.knn = config.model.knn
28 | self.point_cloud_type = ann_list[5]
29 |
30 | feat_file, img_feat_file, attribute_file, anno_file, feats_gnn_file = ann_list[:5]
31 | self.attributes = torch.load(attribute_file, map_location='cpu') if attribute_file is not None else None
32 | self.anno = json.load(open(anno_file, 'r'))
33 |
34 |
35 | if len(ann_list) > 6:
36 | sample_ratio = ann_list[-1]
37 | if sample_ratio < 1:
38 | self.anno = random.sample(self.anno, int(sample_ratio * len(self.anno)))
39 |
40 | if feat_file in TrainDataset.cached_feats and img_feat_file in TrainDataset.cached_feats and feats_gnn_file in TrainDataset.cached_feats:
41 | self.scene_feats, self.scene_masks = TrainDataset.cached_feats[feat_file]
42 | self.scene_img_feats = TrainDataset.cached_feats[img_feat_file]
43 | (self.scene_gnn_feats, self.scene_foreground_ids) = TrainDataset.cached_feats[feats_gnn_file]
44 | else:
45 | if feat_file is not None and os.path.exists(feat_file):
46 | self.feats = torch.load(feat_file, map_location='cpu')
47 | else:
48 | self.feats = None
49 | if img_feat_file is not None and os.path.exists(img_feat_file):
50 | self.img_feats = torch.load(img_feat_file, map_location='cpu')
51 | else:
52 | self.img_feats = None
53 | if feats_gnn_file is not None and os.path.exists(feats_gnn_file):
54 | self.feats_edge = torch.load(feats_gnn_file, map_location='cpu')
55 | else:
56 | self.feats_edge = None
57 |
58 | if self.attributes is None:
59 | self.scene_feats = self.feats
60 | self.scene_img_feats = self.scene_masks = None
61 | else:
62 | self.scene_feats, self.scene_img_feats, self.scene_masks, self.scene_gnn_feats, self.scene_foreground_ids = self.prepare_scene_features()
63 | TrainDataset.cached_feats[feat_file] = (self.scene_feats, self.scene_masks)
64 | TrainDataset.cached_feats[img_feat_file] = self.scene_img_feats
65 | TrainDataset.cached_feats[feats_gnn_file] = (self.scene_gnn_feats, self.scene_foreground_ids)
66 |
67 |
68 | def __len__(self):
69 | return len(self.anno)
70 |
71 | def __getitem__(self, index):
72 | if self.attributes is not None and self.anno[index]['scene_id'] not in self.attributes:
73 | # print(f"{self.anno[index]['scene_id']} not in attribute file!")
74 | return self.__getitem__(random.randint(0, len(self.anno)-1))
75 | if "obj_id" in self.anno[index]:
76 | obj_id = int(self.anno[index]["obj_id"])
77 | else:
78 | obj_id = random.randint(0, self.max_obj_num - 1)
79 | if 'prompt' not in self.anno[index]:
80 | question = random.choice(obj_caption_wid_prompt).replace('', f"")
81 | else:
82 | question = self.anno[index]["prompt"]
83 | caption = self.anno[index]["caption"]
84 | scene_id, scene_feat, scene_img_feat, scene_mask, scene_locs, assigned_ids, scene_gnn_feats, scene_foreground_ids = self.get_anno(index)
85 | caption = update_caption(caption, assigned_ids)
86 | question = update_caption(question, assigned_ids)
87 | return scene_feat, scene_img_feat, scene_mask, scene_locs, obj_id, assigned_ids, scene_gnn_feats, scene_foreground_ids, caption, question
88 |
89 |
90 | def train_collate_fn(batch):
91 | scene_feats, scene_img_feats, scene_masks, scene_locs, obj_ids, assigned_ids, scene_gnn_feats, scene_foreground_ids, captions, questions = zip(*batch)
92 | batch_scene_feat = pad_sequence(scene_feats, batch_first=True)
93 | batch_scene_img_feat = pad_sequence(scene_img_feats, batch_first=True)
94 | batch_scene_mask = pad_sequence(scene_masks, batch_first=True).to(torch.bool)
95 | batch_scene_locs = pad_sequence(scene_locs, batch_first=True)
96 | batch_assigned_ids = pad_sequence(assigned_ids, batch_first=True)
97 | batch_scene_gnn_feats = pad_sequence(scene_gnn_feats, batch_first=True)
98 | # batch_detach_mask = torch.ones_like(batch_scene_mask, dtype=torch.bool)
99 | # for i in range(batch_detach_mask.shape[0]):
100 | # batch_detach_mask[i][:detach_masks[i].shape[0]] = detach_masks[i]
101 | obj_ids = torch.tensor(obj_ids)
102 | foreground_ids = scene_foreground_ids
103 | return {
104 | "scene_feat": batch_scene_feat,
105 | "scene_img_feat": batch_scene_img_feat,
106 | "scene_locs": batch_scene_locs,
107 | "scene_mask": batch_scene_mask,
108 | "assigned_ids": batch_assigned_ids,
109 | "scene_gnn_feats": batch_scene_gnn_feats,
110 | "foreground_ids": foreground_ids,
111 | # "detach_mask": batch_detach_mask,
112 | "obj_ids": obj_ids,
113 | "answers": captions,
114 | "questions": questions
115 | # "ref_captions": ref_captions,
116 | # "ids": index
117 | }
118 |
--------------------------------------------------------------------------------
/dataset/dataset_val.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import json
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from dataset.base_dataset import BaseDataset, update_caption
9 | import glob
10 | import random
11 | from prompts.prompts import obj_caption_wid_prompt
12 | from torch.nn.utils.rnn import pad_sequence
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class ValDataset(BaseDataset):
18 |
19 | cached_feats = {}
20 |
21 | def __init__(self, ann_list, dataset_name, config, **kwargs):
22 | super().__init__()
23 | self.dataset_name = dataset_name
24 | self.feat_dim = config.model.input_dim
25 | self.img_feat_dim = config.model.img_input_dim
26 | self.max_obj_num = config.model.max_obj_num
27 | self.knn = config.model.knn
28 | self.point_cloud_type = ann_list[5]
29 |
30 | feat_file, img_feat_file, attribute_file, anno_file, feats_gnn_file = ann_list[:5]
31 | self.attributes = torch.load(attribute_file, map_location='cpu') if attribute_file is not None else None
32 | self.anno = json.load(open(anno_file, 'r'))
33 |
34 | if feat_file in ValDataset.cached_feats and img_feat_file in ValDataset.cached_feats and feats_gnn_file in ValDataset.cached_feats:
35 | self.scene_feats, self.scene_masks = ValDataset.cached_feats[feat_file]
36 | self.scene_img_feats = ValDataset.cached_feats[img_feat_file]
37 | (self.scene_gnn_feats, self.scene_foreground_ids) = ValDataset.cached_feats[feats_gnn_file]
38 | else:
39 | if feat_file is not None and os.path.exists(feat_file):
40 | self.feats = torch.load(feat_file, map_location='cpu')
41 | else:
42 | self.feats = None
43 | if img_feat_file is not None and os.path.exists(img_feat_file):
44 | self.img_feats = torch.load(img_feat_file, map_location='cpu')
45 | else:
46 | self.img_feats = None
47 | if feats_gnn_file is not None and os.path.exists(feats_gnn_file):
48 | self.feats_edge = torch.load(feats_gnn_file, map_location='cpu')
49 | else:
50 | self.feats_edge = None
51 |
52 | if self.attributes is None:
53 | self.scene_feats = self.feats
54 | self.scene_img_feats = self.scene_masks = None
55 | else:
56 | self.scene_feats, self.scene_img_feats, self.scene_masks, self.scene_gnn_feats, self.scene_foreground_ids = self.prepare_scene_features()
57 | ValDataset.cached_feats[feat_file] = (self.scene_feats, self.scene_masks)
58 | ValDataset.cached_feats[img_feat_file] = self.scene_img_feats
59 | ValDataset.cached_feats[feats_gnn_file] = (self.scene_gnn_feats, self.scene_foreground_ids)
60 |
61 | def __len__(self):
62 | return len(self.anno)
63 |
64 | def __getitem__(self, index):
65 | scene_id, scene_feat, scene_img_feat, scene_mask, scene_locs, assigned_ids, scene_gnn_feats, scene_foreground_ids = self.get_anno(index)
66 | obj_id = int(self.anno[index].get('obj_id', 0))
67 | pred_id = int(self.anno[index].get('pred_id', 0))
68 | type_info = int(self.anno[index].get('sqa_type', 0))
69 | if 'sqa_type' in self.anno[index]:
70 | type_info = self.anno[index]['sqa_type']
71 | elif 'eval_type' in self.anno[index]:
72 | type_info = self.anno[index]['eval_type']
73 | elif 'type_info' in self.anno[index]:
74 | type_info = self.anno[index]['type_info']
75 | if 'prompt' not in self.anno[index]:
76 | prompt = random.choice(obj_caption_wid_prompt).replace('', f"")
77 | else:
78 | prompt = self.anno[index]["prompt"]
79 | ref_captions = self.anno[index]["ref_captions"].copy() if "ref_captions" in self.anno[index] else []
80 | qid = self.anno[index]["qid"] if "qid" in self.anno[index] else 0
81 | return scene_feat, scene_img_feat, scene_mask, scene_locs, obj_id, assigned_ids, scene_gnn_feats, scene_foreground_ids, prompt, ref_captions, scene_id, qid, pred_id, type_info
82 |
83 |
84 | def val_collate_fn(batch):
85 | scene_feats, scene_img_feats, scene_masks, scene_locs, obj_ids, assigned_ids, scene_gnn_feats, scene_foreground_ids, prompts, ref_captions, scene_ids, qids, pred_ids, type_infos = zip(*batch)
86 | batch_scene_feat = pad_sequence(scene_feats, batch_first=True)
87 | batch_scene_img_feat = pad_sequence(scene_img_feats, batch_first=True)
88 | batch_scene_mask = pad_sequence(scene_masks, batch_first=True).to(torch.bool)
89 | batch_scene_locs = pad_sequence(scene_locs, batch_first=True)
90 | batch_assigned_ids = pad_sequence(assigned_ids, batch_first=True)
91 | batch_scene_gnn_feats = pad_sequence(scene_gnn_feats, batch_first=True)
92 | obj_ids = torch.tensor(obj_ids)
93 | pred_ids = torch.tensor(pred_ids)
94 | foreground_ids = scene_foreground_ids
95 |
96 |
97 | return {
98 | "scene_feat": batch_scene_feat,
99 | "scene_img_feat": batch_scene_img_feat,
100 | "scene_locs": batch_scene_locs,
101 | "scene_mask": batch_scene_mask,
102 | "assigned_ids": batch_assigned_ids,
103 | "scene_gnn_feats": batch_scene_gnn_feats,
104 | "foreground_ids": foreground_ids,
105 | "obj_ids": obj_ids,
106 | "custom_prompt": prompts,
107 | "ref_captions": ref_captions,
108 | "scene_id": scene_ids,
109 | "qid": qids,
110 | "pred_ids": pred_ids,
111 | "type_infos": type_infos
112 | # "ids": index
113 | }
114 |
115 |
--------------------------------------------------------------------------------
/demo/run_demo.sh:
--------------------------------------------------------------------------------
1 | which_python=$(which python)
2 | export PYTHONPATH=${PYTHONPATH}:${which_python}:.
3 | echo "PYTHONPATH: ${PYTHONPATH}"
4 |
5 | python demo/demo.py scripts/config_3dgraphllm.py
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CognitiveAISystems/3DGraphLLM/418f6529e029761a52017d0eb2a5380689805c62/models/__init__.py
--------------------------------------------------------------------------------
/models/helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import torch.nn as nn
3 | from functools import partial
4 | import copy
5 |
6 |
7 | class BatchNormDim1Swap(nn.BatchNorm1d):
8 | """
9 | Used for nn.Transformer that uses a HW x N x C rep
10 | """
11 |
12 | def forward(self, x):
13 | """
14 | x: HW x N x C
15 | permute to N x C x HW
16 | Apply BN on C
17 | permute back
18 | """
19 | hw, n, c = x.shape
20 | x = x.permute(1, 2, 0)
21 | x = super(BatchNormDim1Swap, self).forward(x)
22 | # x: n x c x hw -> hw x n x c
23 | x = x.permute(2, 0, 1)
24 | return x
25 |
26 |
27 | NORM_DICT = {
28 | "bn": BatchNormDim1Swap,
29 | "bn1d": nn.BatchNorm1d,
30 | "id": nn.Identity,
31 | "ln": nn.LayerNorm,
32 | }
33 |
34 | ACTIVATION_DICT = {
35 | "relu": nn.ReLU,
36 | "gelu": nn.GELU,
37 | "silu": nn.SiLU,
38 | "leakyrelu": partial(nn.LeakyReLU, negative_slope=0.1),
39 | }
40 |
41 | WEIGHT_INIT_DICT = {
42 | "xavier_uniform": nn.init.xavier_uniform_,
43 | }
44 |
45 |
46 | class GenericMLP(nn.Module):
47 | def __init__(
48 | self,
49 | input_dim,
50 | hidden_dims,
51 | output_dim,
52 | norm_fn_name=None,
53 | activation="silu",
54 | use_conv=False,
55 | dropout=None,
56 | hidden_use_bias=False,
57 | output_use_bias=True,
58 | output_use_activation=False,
59 | output_use_norm=False,
60 | weight_init_name=None,
61 | weight_init_std=0.02
62 | ):
63 | super().__init__()
64 | activation = ACTIVATION_DICT[activation]
65 | norm = None
66 | if norm_fn_name is not None:
67 | norm = NORM_DICT[norm_fn_name]
68 | if norm_fn_name == "ln" and use_conv:
69 | norm = lambda x: nn.GroupNorm(1, x) # easier way to use LayerNorm
70 |
71 | if dropout is not None:
72 | if not isinstance(dropout, list):
73 | dropout = [dropout for _ in range(len(hidden_dims))]
74 |
75 | layers = []
76 | prev_dim = input_dim
77 | for idx, x in enumerate(hidden_dims):
78 | if use_conv:
79 | layer = nn.Conv1d(prev_dim, x, 1, bias=hidden_use_bias)
80 | else:
81 | layer = nn.Linear(prev_dim, x, bias=hidden_use_bias)
82 | layers.append(layer)
83 | if norm:
84 | layers.append(norm(x))
85 | layers.append(activation())
86 | if dropout is not None:
87 | layers.append(nn.Dropout(p=dropout[idx]))
88 | prev_dim = x
89 | if use_conv:
90 | layer = nn.Conv1d(prev_dim, output_dim, 1, bias=output_use_bias)
91 | else:
92 | layer = nn.Linear(prev_dim, output_dim, bias=output_use_bias)
93 | layers.append(layer)
94 |
95 | if output_use_norm:
96 | layers.append(norm(output_dim))
97 |
98 | if output_use_activation:
99 | layers.append(activation())
100 |
101 | self.layers = nn.Sequential(*layers)
102 | # self.weight_init_std = weight_init_std
103 | # self.apply(self._init_weights)
104 |
105 | def _init_weights(self, module):
106 | std = self.weight_init_std
107 | if isinstance(module, nn.Linear):
108 | module.weight.data.normal_(mean=0.0, std=std)
109 | if module.bias is not None:
110 | module.bias.data.zero_()
111 |
112 | # if weight_init_name is not None:
113 | # self.do_weight_init(weight_init_name)
114 | #
115 | # def do_weight_init(self, weight_init_name):
116 | # func = WEIGHT_INIT_DICT[weight_init_name]
117 | # for (_, param) in self.named_parameters():
118 | # if param.dim() > 1: # skips batchnorm/layernorm
119 | # func(param)
120 |
121 | def forward(self, x):
122 | output = self.layers(x)
123 | return output
124 |
125 |
126 | def get_clones(module, N):
127 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
128 |
--------------------------------------------------------------------------------
/models/moe/__init__.py:
--------------------------------------------------------------------------------
1 | import enum
2 | import peft
3 | from peft import PEFT_TYPE_TO_CONFIG_MAPPING
4 | from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING
5 |
6 |
7 | # register MoE LoRA
8 | class PeftType(str, enum.Enum):
9 | PROMPT_TUNING = "PROMPT_TUNING"
10 | P_TUNING = "P_TUNING"
11 | PREFIX_TUNING = "PREFIX_TUNING"
12 | LORA = "LORA"
13 | ADALORA = "ADALORA"
14 | ADAPTION_PROMPT = "ADAPTION_PROMPT"
15 | IA3 = "IA3"
16 | MOE_LORA = 'MOE_LORA'
17 |
18 | peft.PeftType = PeftType
19 |
20 | from .moe_lora import MoeLoraConfig, MoeLoraModel
21 | PEFT_TYPE_TO_CONFIG_MAPPING[peft.PeftType.MOE_LORA] = MoeLoraConfig
22 | PEFT_TYPE_TO_MODEL_MAPPING[peft.PeftType.MOE_LORA] = MoeLoraModel
23 |
24 |
25 | __all__ = [
26 | 'MoeLoraConfig',
27 | ]
28 |
--------------------------------------------------------------------------------
/models/moe/moe_lora.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from peft import LoraConfig, LoraModel, PeftType
3 | from peft.utils import _get_submodules
4 | from peft.tuners.lora import Embedding as LoraEmbedding
5 | import re
6 | import torch.nn as nn
7 | import warnings
8 | from typing import Optional
9 |
10 | from .layer import MoeLoraLayer, MoeLinear
11 |
12 |
13 | @dataclass
14 | class MoeLoraConfig(LoraConfig):
15 |
16 | num_experts: int = field(
17 | default=16,
18 | metadata={'help': 'number of experts in MoE Lora Layer'})
19 |
20 | gate_mode: str = field(
21 | default='top2_gate',
22 | metadata={'help': 'choice: [top2_gate, dual_gate]'}
23 | )
24 |
25 | def __post_init__(self):
26 | self.peft_type = PeftType.MOE_LORA
27 |
28 |
29 | class MoeLoraModel(LoraModel):
30 |
31 | def __init__(self, model, config, adapter_name):
32 | super().__init__(model, config, adapter_name)
33 |
34 | def _find_and_replace(self, adapter_name):
35 | lora_config = self.peft_config[adapter_name]
36 | loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
37 | if loaded_in_8bit:
38 | raise NotImplementedError
39 |
40 | is_target_modules_in_base_model = False
41 | kwargs = {
42 | "r": lora_config.r,
43 | "num_experts": lora_config.num_experts,
44 | "gate_mode": lora_config.gate_mode,
45 | "lora_alpha": lora_config.lora_alpha,
46 | "lora_dropout": lora_config.lora_dropout,
47 | "fan_in_fan_out": lora_config.fan_in_fan_out,
48 | "init_lora_weights": lora_config.init_lora_weights,
49 | }
50 |
51 | key_list = [key for key, _ in self.model.named_modules()]
52 | for key in key_list:
53 | if isinstance(lora_config.target_modules, str):
54 | target_module_found = re.fullmatch(lora_config.target_modules, key)
55 | else:
56 | target_module_found = any(
57 | key.endswith(target_key) for target_key in lora_config.target_modules)
58 |
59 | if target_module_found:
60 | if not is_target_modules_in_base_model:
61 | is_target_modules_in_base_model = True
62 | parent, target, target_name = _get_submodules(self.model, key)
63 |
64 | if hasattr(target, "bias"):
65 | bias = target.bias is not None
66 | else:
67 | bias = False
68 |
69 | if isinstance(target, MoeLoraLayer) and isinstance(target, nn.Linear):
70 | target.update_moe_layer(
71 | adapter_name,
72 | lora_config.r,
73 | lora_config.num_experts,
74 | lora_config.lora_alpha,
75 | lora_config.lora_dropout,
76 | lora_config.init_lora_weights)
77 |
78 | elif isinstance(target, nn.Linear):
79 | in_features, out_features = target.in_features, target.out_features
80 | if kwargs["fan_in_fan_out"]:
81 | warnings.warn(
82 | "fan_in_fan_out is set to True but the target module is "
83 | "`torch.nn.Linear`. Setting fan_in_fan_out to False."
84 | )
85 | kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
86 | new_module = MoeLinear(
87 | adapter_name, in_features, out_features, bias=bias, **kwargs)
88 |
89 | elif isinstance(target, nn.Conv1D):
90 | in_features, out_features = target.weight.ds_shape \
91 | if hasattr(target.weight, "ds_shape") else target.weight.shape
92 | if not kwargs["fan_in_fan_out"]:
93 | warnings.warn(
94 | "fan_in_fan_out is set to False but the target module is "
95 | "`torch.nn.Conv1D`. Setting fan_in_fan_out to True."
96 | )
97 | kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
98 | new_module = MoeLinear(
99 | adapter_name, in_features, out_features, bias=bias, **kwargs)
100 |
101 | else:
102 | raise RuntimeError(
103 | f"Target module {target} is not supported. Currently, only "
104 | f"``torch.nn.Linear`, torch.nn.Conv1D` and `torch.nn.Embedding` "
105 | f"are supported.")
106 |
107 | self._replace_module(parent, target_name, new_module, target)
108 |
109 | if not is_target_modules_in_base_model:
110 | raise ValueError(
111 | f"Target modules {lora_config.target_modules} not found in the base model. "
112 | f"Please check the target modules and try again."
113 | )
114 |
--------------------------------------------------------------------------------
/models/position_embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Various positional encodings for the transformer.
4 | """
5 | import math
6 | import torch
7 | from torch import nn
8 | import numpy as np
9 | from utils.pc_util import shift_scale_points
10 |
11 |
12 | class PositionalEmbedding(nn.Module):
13 | def __init__(self, sigma=1, dim=4096):
14 | super().__init__()
15 | self.sigma = sigma
16 | self.dim = dim // 2
17 | self.w = torch.randn((self.dim, 3)) * sigma
18 | self.w = nn.Parameter(self.w, requires_grad=True)
19 |
20 | def forward(self, x):
21 | bs, obj_num, _ = x.shape
22 | x = x.reshape(-1, 3)
23 | v = torch.cat([torch.sin(self.w.detach() @ x.T), torch.cos(self.w.detach() @ x.T)])
24 | v = v.T.reshape(bs, obj_num, -1)
25 | v_norm = v / v.norm(dim=-1).unsqueeze(-1)
26 | return v_norm
27 |
28 |
29 | class PositionEmbeddingCoordsSine(nn.Module):
30 | def __init__(
31 | self,
32 | temperature=10000,
33 | normalize=True,
34 | scale=None,
35 | pos_type="fourier",
36 | d_pos=None,
37 | d_in=3,
38 | gauss_scale=1.0,
39 | ):
40 | super().__init__()
41 | self.temperature = temperature
42 | self.normalize = normalize
43 | if scale is not None and normalize is False:
44 | raise ValueError("normalize should be True if scale is passed")
45 | if scale is None:
46 | scale = 2 * math.pi
47 | assert pos_type in ["sine", "fourier"]
48 | self.pos_type = pos_type
49 | self.scale = scale
50 | if pos_type == "fourier":
51 | assert d_pos is not None
52 | assert d_pos % 2 == 0
53 | # define a gaussian matrix input_ch -> output_ch
54 | B = torch.empty((d_in, d_pos // 2)).normal_()
55 | B *= gauss_scale
56 | self.register_buffer("gauss_B", B)
57 | self.d_pos = d_pos
58 |
59 | def get_sine_embeddings(self, xyz, num_channels, input_range):
60 | # clone coords so that shift/scale operations do not affect original tensor
61 | orig_xyz = xyz
62 | xyz = orig_xyz.clone()
63 |
64 | ncoords = xyz.shape[1]
65 | if self.normalize:
66 | xyz = shift_scale_points(xyz, src_range=input_range)
67 |
68 | ndim = num_channels // xyz.shape[2]
69 | if ndim % 2 != 0:
70 | ndim -= 1
71 |         # automatically handle remainder by assigning it to the first dim
72 | rems = num_channels - (ndim * xyz.shape[2])
73 |
74 | assert (
75 | ndim % 2 == 0
76 | ), f"Cannot handle odd sized ndim={ndim} where num_channels={num_channels} and xyz={xyz.shape}"
77 |
78 | final_embeds = []
79 | prev_dim = 0
80 |
81 | for d in range(xyz.shape[2]):
82 | cdim = ndim
83 | if rems > 0:
84 | # add remainder in increments of two to maintain even size
85 | cdim += 2
86 | rems -= 2
87 |
88 | if cdim != prev_dim:
89 | dim_t = torch.arange(cdim, dtype=torch.float32, device=xyz.device)
90 | dim_t = self.temperature ** (2 * (dim_t // 2) / cdim)
91 |
92 |             # create batch x cdim x ncoords embedding
93 | raw_pos = xyz[:, :, d]
94 | if self.scale:
95 | raw_pos *= self.scale
96 | pos = raw_pos[:, :, None] / dim_t
97 | pos = torch.stack(
98 | (pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=3
99 | ).flatten(2)
100 | final_embeds.append(pos)
101 | prev_dim = cdim
102 |
103 | final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1)
104 | return final_embeds
105 |
106 | def get_fourier_embeddings(self, xyz, num_channels=None, input_range=None):
107 | # Follows - https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html
108 |
109 | if num_channels is None:
110 | num_channels = self.gauss_B.shape[1] * 2
111 |
112 | bsize, npoints = xyz.shape[0], xyz.shape[1]
113 | assert num_channels > 0 and num_channels % 2 == 0
114 | d_in, max_d_out = self.gauss_B.shape[0], self.gauss_B.shape[1]
115 | d_out = num_channels // 2
116 | assert d_out <= max_d_out
117 | assert d_in == xyz.shape[-1]
118 |
119 | # clone coords so that shift/scale operations do not affect original tensor
120 | orig_xyz = xyz
121 | xyz = orig_xyz.clone()
122 |
123 | ncoords = xyz.shape[1]
124 | if self.normalize:
125 | xyz = shift_scale_points(xyz, src_range=input_range)
126 |
127 | xyz *= 2 * np.pi
128 | xyz_proj = torch.mm(xyz.view(-1, d_in), self.gauss_B[:, :d_out]).view(
129 | bsize, npoints, d_out
130 | )
131 | final_embeds = [xyz_proj.sin(), xyz_proj.cos()]
132 |
133 | # return batch x d_pos x npoints embedding
134 | final_embeds = torch.cat(final_embeds, dim=2)
135 | return final_embeds
136 |
137 | def forward(self, xyz, num_channels=None, input_range=None):
138 | assert isinstance(xyz, torch.Tensor)
139 | assert xyz.ndim == 3
140 | # xyz is batch x npoints x 3
141 | if self.pos_type == "sine":
142 | with torch.no_grad():
143 | return self.get_sine_embeddings(xyz, num_channels, input_range)
144 | elif self.pos_type == "fourier":
145 | with torch.no_grad():
146 | return self.get_fourier_embeddings(xyz, num_channels, input_range)
147 | else:
148 | raise ValueError(f"Unknown {self.pos_type}")
149 |
150 | def extra_repr(self):
151 | st = f"type={self.pos_type}, scale={self.scale}, normalize={self.normalize}"
152 | if hasattr(self, "gauss_B"):
153 | st += (
154 | f", gaussB={self.gauss_B.shape}, gaussBsum={self.gauss_B.sum().item()}"
155 | )
156 | return st
157 |
--------------------------------------------------------------------------------
/models/transformer_vanilla/__init__.py:
--------------------------------------------------------------------------------
1 | from .mhsa import MultiHeadSelfAttention
2 | from .self_attention import SelfAttention
3 | from .transformer_block import TransformerEncoder, CMT
4 |
--------------------------------------------------------------------------------
/models/transformer_vanilla/mhsa.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from einops import rearrange
4 | from torch import nn
5 |
6 |
7 | def compute_mhsa(q, k, v, scale_factor=1, mask=None):
8 | # resulted shape will be: [batch, heads, tokens, tokens]
9 | scaled_dot_prod = torch.einsum('... i d , ... j d -> ... i j', q, k) * scale_factor
10 |
11 | if mask is not None:
12 | breakpoint()
13 | # assert mask.shape == scaled_dot_prod.shape[2:]
14 | scaled_dot_prod = scaled_dot_prod.masked_fill(mask, -np.inf)
15 |
16 | attention = torch.softmax(scaled_dot_prod, dim=-1)
17 | # calc result per head
18 | return torch.einsum('... i j , ... j d -> ... i d', attention, v)
19 |
20 |
21 | class MultiHeadSelfAttention(nn.Module):
22 | def __init__(self, dim, heads=8, dim_head=None):
23 | """
24 | Implementation of multi-head attention layer of the original transformer model.
25 | einsum and einops.rearrange is used whenever possible
26 | Args:
27 | dim: token's dimension, i.e. word embedding vector size
28 | heads: the number of distinct representations to learn
29 |             dim_head: the dim of the head. In general dim_head<dim.
30 |             However, it may not necessarily be (dim/heads)
31 |         """
32 |         super().__init__()
33 |         self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
34 |         _dim = self.dim_head * heads
35 |         self.heads = heads
36 |         self.to_qvk = nn.Linear(dim, _dim * 3, bias=False)
37 |         self.W_0 = nn.Linear(_dim, dim, bias=False)
38 |         self.scale_factor = self.dim_head ** -0.5
39 |
40 |     def forward(self, x, mask=None):
41 |         assert x.dim() == 3
42 |         qkv = self.to_qvk(x)  # [batch, tokens, 3 * heads * dim_head]
43 |
44 |         # decomposition to q,v,k and cast to tuple
45 |         # the resulting shape before casting to tuple: [3, batch, heads, tokens, dim_head]
46 |         q, k, v = tuple(rearrange(qkv, 'b t (d k h) -> k b h t d ', k=3, h=self.heads))
47 |
48 | out = compute_mhsa(q, k, v, mask=mask, scale_factor=self.scale_factor)
49 |
50 | # re-compose: merge heads with dim_head
51 | out = rearrange(out, "b h t d -> b t (h d)")
52 | # Apply final linear transformation layer
53 | return self.W_0(out)
54 |
--------------------------------------------------------------------------------
/models/transformer_vanilla/self_attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from einops import rearrange
4 | from torch import nn
5 |
6 |
7 | class SelfAttention(nn.Module):
8 | """
9 | Implementation of plain self attention mechanism with einsum operations
10 | Paper: https://arxiv.org/abs/1706.03762
11 | Blog: https://theaisummer.com/transformer/
12 | """
13 |
14 | def __init__(self, dim):
15 | """
16 | Args:
17 | dim: for NLP it is the dimension of the embedding vector
18 | the last dimension size that will be provided in forward(x)
19 | where x is a 3D tensor
20 | """
21 | super().__init__()
22 | self.to_qvk = nn.Linear(dim, dim * 3, bias=False)
23 | self.scale_factor = dim ** -0.5 # 1/np.sqrt(dim)
24 |
25 | def forward(self, x, mask=None):
26 | assert x.dim() == 3, '3D tensor must be provided'
27 | qkv = self.to_qvk(x) # [batch, tokens, dim*3 ]
28 |
29 | # decomposition to q,v,k
30 | # rearrange tensor to [3, batch, tokens, dim] and cast to tuple
31 | q, k, v = tuple(rearrange(qkv, 'b t (d k) -> k b t d ', k=3))
32 |
33 | # Resulting shape: [batch, tokens, tokens]
34 | scaled_dot_prod = torch.einsum('b i d , b j d -> b i j', q, k) * self.scale_factor
35 |
36 | if mask is not None:
37 | assert mask.shape == scaled_dot_prod.shape[1:]
38 | scaled_dot_prod = scaled_dot_prod.masked_fill(mask, -np.inf)
39 |
40 | attention = torch.softmax(scaled_dot_prod, dim=-1)
41 | return torch.einsum('b i j , b j d -> b i d', attention, v)
42 |
--------------------------------------------------------------------------------
/others/analyze_grounding_results.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import jsonlines
4 |
5 | def is_explicitly_view_dependent(s):
6 | target_words = {'front', 'behind', 'back', 'right', 'left', 'facing', 'leftmost', 'rightmost',
7 | 'looking', 'across'}
8 | for word in target_words:
9 | if word in s:
10 | return True
11 | return False
12 |
13 |
14 | attrs = torch.load("annotations/scannet_val_attributes.pt")
15 |
16 |
17 | # val_file = "/root/scene-LLaMA/datasets/exprs_neurips22/gtlabelpcd_mix/nr3d/preds/val_outs.json"
18 | val_file = "outputs/2023-11-17-230123_dp0.1_lr2e-4_sta2_ep3_objscale200_scenescale50_bs1_cosine_objalign_scenealign/preds_epoch-1_step0.json"
19 | nr3d_anno_file = "/root/scene-LLaMA/datasets/referit3d/annotations/bert_tokenized/sr3d.jsonl"
20 |
21 | anno_root = "annotations" # annotation dir
22 | val_results = json.load(open(val_file))
23 | # val_results_dict = {}
24 | # for val_item in val_results:
25 | # val_results_dict[val_item["qid"]] = val_item
26 | nr3d_anno = {}
27 | with jsonlines.open(nr3d_anno_file, "r") as reader:
28 | for l in reader:
29 | nr3d_anno[l["item_id"]] = l
30 |
31 | val_num = len(val_results)
32 |
33 | easy_acc, hard_acc = 0, 0
34 | view_indep_acc, view_dep_acc = 0, 0
35 | acc = 0
36 | easy_num, hard_num = 0, 0
37 | view_indep_num, view_dep_num = 0, 0
38 |
39 |
40 | # from nltk.tokenize import sent_tokenize
41 |
42 | # for v in val_results:
43 | #
44 | # pred = v["pred"]
45 | # target = v["ref_captions"][0]
46 | # scene_id = v["scene_id"]
47 | # obj_id = v["obj_id"]
48 | # object_labels = attrs[scene_id]["objects"]
49 | # hardness = object_labels.count(object_labels[obj_id])
50 | # # print(object_labels)
51 | # # breakpoint()
52 | # caption = v["prompt"].split("the given description, \"")[1].split(",\" please provide the")[0]
53 | # tokens = sent_tokenize(caption)[0].split()
54 | # print(tokens)
55 | # # print(caption)
56 | # flag = pred == target
57 | # acc += flag
58 | # if is_explicitly_view_dependent(tokens):
59 | # view_dep_acc += flag
60 | # view_dep_num += 1
61 | # else:
62 | # view_indep_acc += flag
63 | # view_indep_num += 1
64 | # if hardness > 2:
65 | # hard_acc += flag
66 | # hard_num += 1
67 | # else:
68 | # easy_acc += flag
69 | # easy_num += 1
70 |
71 | dataset = "sr3d"
72 | val_file = f"/root/scene-LLaMA/datasets/exprs_neurips22/gtlabelpcd_mix/{dataset}/preds/val_outs.json"
73 | nr3d_anno_file = f"/root/scene-LLaMA/datasets/referit3d/annotations/bert_tokenized/{dataset}.jsonl"
74 |
75 | val_results = json.load(open(val_file))
76 |
77 | val_num = len(val_results)
78 |
79 | nr3d_anno = {}
80 | with jsonlines.open(nr3d_anno_file, "r") as reader:
81 | for l in reader:
82 | nr3d_anno[l["item_id"]] = l
83 |
84 | for k, v in val_results.items():
85 | obj_ids = v["obj_ids"]
86 | obj_logits = v["obj_logits"]
87 | obj_logits = (torch.tensor(obj_logits)).softmax(dim=-1).tolist()
88 | scene_id = nr3d_anno[k]["scan_id"]
89 | caption = nr3d_anno[k]["utterance"]
90 | target_id = nr3d_anno[k]["target_id"]
91 | instance_type = nr3d_anno[k]["instance_type"]
92 | tokens = nr3d_anno[k]["tokens"]
93 | logit_ids = zip(obj_logits, obj_ids)
94 | logit_ids = sorted(logit_ids, reverse=True)
95 | logits, ids = zip(*logit_ids)
96 | object_labels = attrs[scene_id]["objects"]
97 | hardness = object_labels.count(instance_type)
98 | # if k in val_results_dict:
99 | # pred = val_results_dict[k]["pred"]
100 | # ref_captions = val_results_dict[k]["ref_captions"]
101 | # flag = pred == ref_captions[0]
102 | # flag = 0.5
103 | if logits[1] > 0.01 and logits[2] <= 0.01:
104 | flag = 0.6
105 | else:
106 | flag = ids[0] == target_id
107 | acc += flag
108 | if is_explicitly_view_dependent(tokens):
109 | view_dep_acc += flag
110 | view_dep_num += 1
111 | else:
112 | view_indep_acc += flag
113 | view_indep_num += 1
114 | if hardness > 2:
115 | hard_acc += flag
116 | hard_num += 1
117 | else:
118 | easy_acc += flag
119 | easy_num += 1
120 |
121 | print(f"Acc: {float(acc) / val_num} {acc} {val_num}")
122 | print(f"Easy-Acc: {float(easy_acc) / easy_num} {easy_acc} {easy_num}")
123 | print(f"Hard-Acc: {float(hard_acc) / hard_num} {hard_acc} {hard_num}")
124 |
125 | print(f"View-Dep-Acc: {float(view_dep_acc) / view_dep_num}")
126 | print(f"View-Indep-Acc: {float(view_indep_acc) / view_indep_num}")
127 |
--------------------------------------------------------------------------------
/others/eval_offline.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 | sys.path.append('.')
5 |
6 | from utils.eval import calc_scanrefer_score, clean_answer, calc_scan2cap_score, calc_scanqa_score, calc_sqa3d_score, calc_multi3dref_score
7 |
8 | from pycocoevalcap.bleu.bleu import Bleu
9 | #from pycocoevalcap.meteor.meteor import Meteor
10 | from pycocoevalcap.rouge.rouge import Rouge
11 | from pycocoevalcap.cider.cider import Cider
12 | #from pycocoevalcap.spice.spice import Spice
13 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
14 |
15 | output_dir = 'outputs/3dgraphllm_2e-5_ep6'
16 |
17 | tokenizer = PTBTokenizer()
18 | scorers = [
19 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
20 | #(Meteor(), "METEOR"),
21 | (Rouge(), "ROUGE_L"),
22 | (Cider(), "CIDEr"),
23 | #(Spice(), "SPICE")
24 | ]
25 |
26 |
27 | prefix = 'preds_epoch4_step0'
28 |
29 | all_val_scores = {}
30 |
31 | #for task in ['scanqa', 'scanrefer', 'scan2cap', 'sqa3d', 'multi3dref']:
32 | for task in ['scanrefer']:
33 | save_preds = []
34 | for filename in os.listdir(output_dir):
35 | if filename.startswith(prefix) and task in filename:
36 | preds = json.load(open(os.path.join(output_dir, filename)))
37 | save_preds += preds
38 | print(len(save_preds))
39 | val_scores = {}
40 | if task == 'scanqa':
41 | val_scores = calc_scanqa_score(save_preds, tokenizer, scorers)
42 | if task == 'scanrefer':
43 | val_scores = calc_scanrefer_score(save_preds)
44 | if task == 'multi3dref':
45 | val_scores = calc_multi3dref_score(save_preds)
46 | if task == 'scan2cap':
47 | val_scores = calc_scan2cap_score(save_preds, tokenizer, scorers)
48 | if task == 'sqa3d':
49 | val_scores = calc_sqa3d_score(save_preds, tokenizer, scorers)
50 |
51 | all_val_scores = {**all_val_scores, **val_scores}
52 |
53 | print(json.dumps(all_val_scores, indent=4))
--------------------------------------------------------------------------------
/others/extract_target_noun.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import sys
4 |
5 | from models.modeling_llama import LlamaForCausalLM
6 | from transformers import LlamaTokenizer, LlamaConfig
7 | from collections import defaultdict
8 |
9 | from tqdm import tqdm
10 |
11 |
12 | llama_model_path = "model/vicuna-7b-v0"
13 |
14 | print("Loading LLaMA")
15 | llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
16 | model = LlamaForCausalLM.from_pretrained(
17 | llama_model_path,
18 | torch_dtype=torch.float16
19 | ).cuda()
20 | for p in model.parameters():
21 | p.requires_grad = False
22 | print("Loading LLaMA Done")
23 |
24 | model.eval()
25 |
26 |
27 | dataset_name = "sr3d_merge"
28 | anno_file = f"anno/{dataset_name}_captions.json"
29 | output_anno_file = f"anno/{dataset_name}_captions_noun.json"
30 | annos = json.load(open(anno_file, "r"))
31 | output_annos = defaultdict(list)
32 |
33 | # nr3d
34 | # prompt_head = "###System: Given a sentence that asks for an object in a scene. Extract the primary subject from each sentence and include any accompanying adjectives, if present. " \
35 | # "###Human: When facing the bookcases, choose the plant directly on the right, next to the right most bookcase. " \
36 | # "###Assistant: plant. " \
37 | # "###Human: The big black box between the door and the couch. " \
38 | # "###Assistant: big black box. " \
39 | # "###Human: pick the white pillow that has a pillow above and under it. " \
40 | # "###Assistant: white pillow. "
41 |
42 | # sr3d
43 | prompt_head = "###System: Given a sentence that asks for an object in a scene. Extract the primary subject from each sentence and include any accompanying adjectives, if present. " \
44 | "###Human: find the office chair that is near the copier. " \
45 | "###Assistant: office chair. " \
46 | "###Human: select the trash can that is near the printer. " \
47 | "###Assistant: trash can. " \
48 | "###Human: the monitor that is near the door. " \
49 | "###Assistant: monitor. "
50 |
51 | for i, k in tqdm(enumerate(annos.keys())):
52 | print(f"{i} / {len(annos)}")
53 | captions = annos[k]
54 | for caption in captions:
55 | end_1 = (caption.find(".") + len(caption)) % len(caption)  # index of the first "."; maps a missing "." (find() == -1) to len(caption) - 1
56 | # end_2 = (caption.find(",") + len(caption)) % len(caption)
57 | bk = end_1
58 | sen_caption = caption[:bk] + "."
59 | prompt = prompt_head + "###Human: " + sen_caption + " ###"
60 | input_token = llama_tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to("cuda:0")
61 | input_embed = model.model.embed_tokens(input_token.input_ids)
62 | outputs = model.generate(
63 | inputs_embeds=input_embed,
64 | max_new_tokens=16,
65 | num_beams=1,
66 | do_sample=True,
67 | min_length=1,
68 | top_p=0.9,
69 | repetition_penalty=1.0,
70 | length_penalty=1,
71 | temperature=1.0
72 | )
73 | output = outputs[0]
74 | if output[0] == 0:  # drop a leading <unk> token (id 0) if present
75 | output = output[1:]
76 | if output[0] == 1:  # drop a leading <s> (BOS) token (id 1) if present
77 | output = output[1:]
78 | output_text = llama_tokenizer.decode(output, add_special_tokens=False)
79 | print("INPUT:", sen_caption)
80 | print("OUTPUT:", output_text)
81 | try:
82 | output_annos[k].append(output_text.split("Assistant:")[1].split(".")[0].strip())
83 | print("EX OUTPUT:", output_text.split("Assistant:")[1].split(".")[0].strip())
84 | except Exception:
85 | print("Fail:")
86 | output_annos[k].append(caption[:bk])
87 |
88 | with open(output_anno_file, "w") as f:
89 | json.dump(output_annos, f, indent=4)
90 |
--------------------------------------------------------------------------------
/others/gpt_generate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import time
7 |
8 | NUM_SECONDS_TO_SLEEP = 0.5
9 |
10 |
11 | def get_eval(content: str, max_tokens: int):
12 | sleep_time = NUM_SECONDS_TO_SLEEP
13 | while True:
14 | try:
15 | response = openai.ChatCompletion.create(
16 | model='gpt-4',
17 | # model="gpt-3.5-turbo-0613",
18 | messages=[{
19 | 'role': 'system',
20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
21 | }, {
22 | 'role': 'user',
23 | 'content': content,
24 | }],
25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
26 | max_tokens=max_tokens,
27 | )
28 | break
29 | except Exception as e:
30 | print(e)
31 | print(f"!!sleep for {sleep_time}")
32 | time.sleep(sleep_time)
33 | sleep_time *= 2
34 |
35 | return response['choices'][0]['message']['content']
36 |
37 |
38 | def parse_score(review):
39 | try:
40 | score_pair = review.split('\n')[0]
41 | score_pair = score_pair.replace(',', ' ')
42 | sp = score_pair.split(' ')
43 | if len(sp) == 2:
44 | return [float(sp[0]), float(sp[1])]
45 | else:
46 | print('error', review)
47 | return [-1, -1]
48 | except Exception as e:
49 | print(e)
50 | print('error', review)
51 | return [-1, -1]
52 |
53 |
54 | if __name__ == '__main__':
55 | parser = argparse.ArgumentParser(description='ChatGPT-based dataset generation.')
56 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
57 | args = parser.parse_args()
58 |
59 |
60 | idx = 0
61 | prompt_head = "You are a 3D scene understanding expert specializing in 3D visual assistance. I will provide you with a catalog of objects within a 3D scene, with each object\'s information presented in the following format: object\'s ID, object\'s class name, object\'s 3D coordinates, and a concise description of the object. The object list is as follows:\n"
62 | prompt_end = "Your task is to generate a comprehensive description of the entire scene. This description should encompass an analysis of the scene's functionality, an examination of the key objects within the scene, their spatial relationships with the surrounding objects, an assessment of the arrangement of the objects within the scene, and other relevant insights. An important guideline to follow is that when referring to an object, you must explicitly include its object ID. The description should be more than 200 words and less than 300 words."
63 | import glob
64 | for split in ["train", "val"]:
65 | for file_path in glob.glob(f"annotations/scene_dataset/obj_info_list/{split}/*.json"):
66 | scene_id = file_path.split("/")[-1][:-5]
67 | print("-" * 20)
68 | print(scene_id)
69 | output_path = f"annotations/scene_dataset/gpt_generation/{split}/{scene_id}.json"
70 | if os.path.exists(output_path):
71 | print("skip")
72 | continue
73 | obj_infos = json.load(open(file_path, "r"))
74 | prompt = prompt_head + obj_infos + prompt_end
75 | print(prompt)
76 | answer = get_eval(prompt, args.max_tokens)
77 | print(answer)
78 | with open(output_path, "w") as f:
79 | json.dump(answer, f)
80 |
81 |
82 |
83 |
--------------------------------------------------------------------------------
/others/llama_tmp.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import sys
4 | sys.path.append(".")
5 | from models.modeling_llama import LlamaForCausalLM
6 | from transformers import LlamaTokenizer, LlamaConfig
7 | from collections import defaultdict
8 |
9 | from tqdm import tqdm
10 |
11 |
12 | llama_model_path = "model/vicuna-7b-v0"
13 |
14 | print("Loading LLaMA")
15 | llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
16 | model = LlamaForCausalLM.from_pretrained(
17 | llama_model_path,
18 | torch_dtype=torch.float16
19 | )
20 | # model = model.to("cuda")
21 | # print(torch.cuda.memory_allocated(device="cuda:0")/1e9)
22 | # exit()
23 | print("is training:", model.training)
24 | # for p in model.parameters():
25 | # p.requires_grad = False
26 | print("Loading LLaMA Done")
27 |
28 | llama_tokenizer.add_tokens(["", ""], special_tokens=False)
29 |
30 | model.resize_token_embeddings(len(llama_tokenizer))
31 |
32 |
33 | def get_text_len(text):
34 | return llama_tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.shape[1]
35 |
36 |
37 | def get_ids(text):
38 | return llama_tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids
39 |
40 |
41 | def get_emb(text, is_eval=False):
42 | input_ids = llama_tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids
43 | if is_eval:
44 | model.eval()
45 | else:
46 | model.train()
47 | return model.model.embed_tokens(input_ids)
48 |
49 |
50 | breakpoint()
51 |
--------------------------------------------------------------------------------
/others/modify.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | # split = 'val'
5 | # ori_annos = json.load(open(f'annotations/scanrefer_{split}_stage2_objxx.json'))
6 |
7 | # templates = [line.rstrip() for line in open('prompts/scanrefer_caption_templates.txt')]
8 |
9 | # new_annos = []
10 | # for anno in ori_annos:
11 | # obj_id = anno['obj_id']
12 | # anno['prompt'] = random.choice(templates).replace('', f"")
13 | # new_annos.append(anno)
14 |
15 | # with open(f'annotations/scanrefer_{split}_stage2_caption_OBJ.json', 'w') as f:
16 | # json.dump(new_annos, f, indent=4)
17 |
18 | # ori_annos = json.load(open(f'annotations/nr3d_{split}_stage2_objxx.json'))
19 |
20 | # templates = [line.rstrip() for line in open('prompts/nr3d_caption_templates.txt')]
21 |
22 | # new_annos = []
23 | # for anno in ori_annos:
24 | # obj_id = anno['obj_id']
25 | # anno['prompt'] = random.choice(templates).replace('', f"")
26 | # new_annos.append(anno)
27 |
28 | # with open(f'annotations/nr3d_{split}_stage2_caption_OBJ.json', 'w') as f:
29 | # json.dump(new_annos, f, indent=4)
30 |
31 |
32 | # for dataset in ["scanrefer", "nr3d"]:
33 | # x = json.load(open(f'annotations/{dataset}_val_stage2_caption_OBJ.json'))
34 | # x = random.sample(x, 100)
35 | # with open(f'annotations/{dataset}_val_stage2_caption100_OBJ.json', 'w') as f:
36 | # json.dump(x, f, indent=4)
37 |
38 | # x = json.load(open('annotations/scanqa_val_stage2_objxx.json'))
39 |
40 | # x = random.sample(x, 100)
41 |
42 | # with open('annotations/scanqa_val_stage2_objxx100.json', 'w') as f:
43 | # json.dump(x, f, indent=4)
44 |
45 |
46 | # split = 'val'
47 | # iou = '25'
48 |
49 | # ori_annos = json.load(open(f'annotations/scanrefer_pointgroup_{split}_stage2_caption_iou{iou}.json'))
50 |
51 | # templates = [line.rstrip() for line in open('prompts/scanrefer_caption_templates.txt')]
52 |
53 | # new_annos = []
54 | # for anno in ori_annos:
55 | # obj_id = anno['obj_id']
56 | # anno['prompt'] = random.choice(templates).replace('', f"{obj_id:02}")
57 | # new_annos.append(anno)
58 |
59 | # with open(f'annotations/scanrefer_pointgroup_{split}_stage2_caption_iou{iou}.json', 'w') as f:
60 | # json.dump(new_annos, f, indent=4)
61 |
62 |
63 | old_annos = json.load(open('/mnt/petrelfs/huanghaifeng/share/Chat-3D-v2/annotations/scanrefer_mask3d_val_stage2_caption_iou0.json'))
64 |
65 | print(len(old_annos))
--------------------------------------------------------------------------------
/others/prepare_anno_stage1.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from sklearn.model_selection import train_test_split
4 | from collections import defaultdict
5 |
6 | dataset_name = "nr3d"
7 |
8 | dataset_dir = f"/root/autodl-tmp/scene-LLaMA/datasets/{dataset_name}"
9 | if dataset_name == "scanrefer":
10 | dataset_train = json.load(open(os.path.join(dataset_dir, "ScanRefer_filtered_train.json"), "r"))
11 | dataset_val = json.load(open(os.path.join(dataset_dir, "ScanRefer_filtered_val.json"), "r"))
12 | else:
13 | dataset_train = json.load(open(os.path.join(dataset_dir, "train.json"), "r"))
14 | dataset_val = json.load(open(os.path.join(dataset_dir, "val.json"), "r"))
15 | captions = defaultdict(list)
16 |
17 |
18 | def process(dataset_data):
19 | new_list = []
20 | for data in dataset_data:
21 | scene_id = data["scene_id"] if dataset_name == "scanrefer" else data["scan_id"]
22 | obj_id = int(data["object_id"]) if dataset_name == "scanrefer" else int(data["tgt_idx"])
23 | feat_path = f"{scene_id}/{obj_id:03}.pt"
24 | caption = data["description"] if dataset_name == "scanrefer" else data["query"]
25 | new_data = {
26 | "pc_feat_path": feat_path,
27 | "caption": caption,
28 | "scene_id": scene_id,
29 | "obj_id": obj_id
30 | }
31 | captions[f"{scene_id}_{obj_id}"].append(caption)
32 | new_list.append(new_data)
33 | return new_list
34 |
35 |
36 | output_train = process(dataset_train)
37 | output_val = process(dataset_val)
38 |
39 |
40 | with open(f"anno/{dataset_name}_train_stage2.json", "w") as f:
41 | json.dump(output_train, f)
42 |
43 | with open(f"anno/{dataset_name}_val_stage2.json", "w") as f:
44 | json.dump(output_val, f)
45 |
46 | with open(f"anno/{dataset_name}_captions.json", "w") as f:
47 | json.dump(captions, f)
48 |
--------------------------------------------------------------------------------
/others/prepare_captions_noun.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | dataset = "sr3d_merge"
5 | val_split_path = "/root/scene-LLaMA/datasets/referit3d/annotations/splits/scannetv2_val.txt"
6 | train_split_path = "/root/scene-LLaMA/datasets/referit3d/annotations/splits/scannetv2_train.txt"
7 |
8 | caption_path = f"anno/{dataset}_captions_noun.json"
9 |
10 | train_scene_ids = []
11 | val_scene_ids = []
12 |
13 | with open(train_split_path, "r") as f:
14 | for line in f.readlines():
15 | train_scene_ids.append(line.strip())
16 |
17 | with open(val_split_path, "r") as f:
18 | for line in f.readlines():
19 | val_scene_ids.append(line.strip())
20 |
21 | captions = json.load(open(caption_path, "r"))
22 |
23 | train_captions = {}
24 | val_captions = {}
25 |
26 | for k, v in captions.items():
27 | scene_id = "_".join(k.split("_")[:-1])
28 | if scene_id in train_scene_ids:
29 | train_captions[k] = v
30 | if scene_id in val_scene_ids:
31 | val_captions[k] = v
32 |
33 | output_train_path = f"anno/{dataset}_train_captions_noun.json"
34 | output_val_path = f"anno/{dataset}_val_captions_noun.json"
35 |
36 | with open(output_train_path, "w") as f:
37 | json.dump(train_captions, f)
38 |
39 | with open(output_val_path, "w") as f:
40 | json.dump(val_captions, f)
41 |
--------------------------------------------------------------------------------
/others/prepare_describe.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 |
5 | anno_root = "anno"
6 | anno_file = os.path.join(anno_root, "scanrefer_val_conversation.json")
7 |
8 | annos = json.load(open(anno_file, "r"))
9 |
10 | # questions = []
11 | # with open("prompts/detailed_description.txt", "r") as f:
12 | # for q in f.readlines():
13 | # questions.append(q.strip())
14 |
15 | # for k, a in annos.items():
16 | # q = random.choice(questions)
17 | # annos[k] = [{
18 | # "Question": q,
19 | # "Answer": a
20 | # }]
21 | output_anno = []
22 | for k, v in annos.items():
23 | if len(v) == 0:
24 | continue
25 | scene_id = "_".join(k.split("_")[:-1])
26 | obj_id = int(k.split("_")[-1])
27 | for i in range(len(v)):
28 | q = v[i]["Question"]
29 | a = v[i]["Answer"]
30 | output_anno.append({
31 | "scene_id": scene_id,
32 | "obj_id": obj_id,
33 | "qid": i,
34 | "prompt": q,
35 | "ref": a
36 | })
37 | # if len(output_anno) >= 500:
38 | # break
39 |
40 | output_anno = sorted(output_anno, key=lambda x: f"{x['scene_id']}_{x['obj_id']:03}_{x['qid']:2}")
41 | print(len(output_anno))
42 |
43 | with open("anno/scanrefer_val_conv.json", "w") as f:
44 | json.dump(output_anno, f, indent=4)
45 |
--------------------------------------------------------------------------------
/others/prepare_eval_json.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 |
5 | anno_root = "anno"
6 | detail_refs = json.load(open(os.path.join(anno_root, "scanrefer_val_describe.json"), "r"))
7 | conv_refs = json.load(open(os.path.join(anno_root, "scanrefer_val_conversation.json"), "r"))
8 | target_infos = json.load(open(os.path.join(anno_root, "scanrefer_val_content.json"), "r"))
9 |
10 | # detail_items = set(detail_refs.keys())
11 | # conv_items = set(conv_refs.keys())
12 | # item_list = list(detail_items & conv_items)
13 | # print(len(item_list))
14 | #
15 | # sampled_items = sorted(random.sample(item_list, 100))
16 | #
17 | # with open("scripts/eval/item_list.json", "w") as f:
18 | # json.dump(sampled_items, f)
19 |
20 | item_list = json.load(open("eval/item_list.json", "r"))
21 |
22 | # detail_samples = []
23 | # conv_samples = []
24 | # qid = 0
25 | # answers = []
26 | # questions = []
27 | #
28 | # for item in item_list:
29 | # scene_id = "_".join(item.split("_")[:-1])
30 | # obj_id = int(item.split("_")[-1])
31 | # detail_samples.append({
32 | # "scene_id": scene_id,
33 | # "obj_id": obj_id,
34 | # "prompt": detail_refs[item][0]["Question"]
35 | # })
36 | # answers.append({
37 | # "question_id": qid,
38 | # "text": detail_refs[item][0]["Answer"],
39 | # "category": "detail"
40 | # })
41 | # questions.append({
42 | # "question_id": qid,
43 | # "text": detail_refs[item][0]["Question"],
44 | # "category": "detail",
45 | # "item_id": item
46 | # })
47 | # qid += 1
48 | # conv = random.choice(conv_refs[item])
49 | # while len(conv["Answer"]) == 0:
50 | # print(f"empty answer in {item}...")
51 | # conv = random.choice(conv_refs[item])
52 | # conv_samples.append({
53 | # "scene_id": scene_id,
54 | # "obj_id": obj_id,
55 | # "prompt": conv["Question"]
56 | # })
57 | # answers.append({
58 | # "question_id": qid,
59 | # "text": conv["Answer"],
60 | # "category": "conv"
61 | # })
62 | # questions.append({
63 | # "question_id": qid,
64 | # "text": conv["Question"],
65 | # "category": "conv",
66 | # "item_id": item
67 | # })
68 | # qid += 1
69 | #
70 | # with open(os.path.join(anno_root, "scanrefer_val_describe100.json"), "w") as f:
71 | # json.dump(detail_samples, f)
72 | # with open(os.path.join(anno_root, "scanrefer_val_conversation100.json"), "w") as f:
73 | # json.dump(conv_samples, f)
74 | #
75 | # with open("eval/qa200_gpt4_answer.json", "w") as f:
76 | # json.dump(answers, f, indent=4)
77 | # with open("eval/qa200_questions.json", "w") as f:
78 | # json.dump(questions, f, indent=4)
79 |
80 | old_answers = json.load(open("eval/qa200_gpt4_answer.json", "r"))
81 | old_questions = json.load(open("eval/qa200_questions.json", "r"))
82 |
83 | new_list = [item_list[0]]
84 | for i in range(1, len(item_list)):
85 | if item_list[i][:11] != item_list[i-1][:11]:
86 | new_list.append(item_list[i])
87 | new_list = random.sample(new_list, 30)
88 |
89 | new_answers = []
90 | new_questions = []
91 | qid = 0
92 |
93 | for i in range(len(old_questions)):
94 | ques = old_questions[i]
95 | ans = old_answers[i]
96 | if ques["item_id"] in new_list:
97 | ques["question_id"] = ans["answer_id"] = qid
98 | qid += 1
99 | new_questions.append(ques)
100 | new_answers.append(ans)
101 |
102 | with open("eval/qa60_gpt4_answer.json", "w") as f:
103 | json.dump(new_answers, f, indent=4)
104 | with open("eval/qa60_questions.json", "w") as f:
105 | json.dump(new_questions, f, indent=4)
106 |
--------------------------------------------------------------------------------
/others/prepare_identifier_rich.py:
--------------------------------------------------------------------------------
1 | import json
2 | import glob
3 |
4 | split = "train"
5 |
6 | annos = []
7 | for file_name in glob.glob(f"annotations/scene_dataset/gpt_generation/{split}/*.json"):
8 | scene_id = file_name.split("/")[-1][:-5]
9 | obj_id = 0
10 | prompt = "Provide a comprehensive description of the entire scene."
11 | caption = json.load(open(file_name, "r"))
12 | annos.append({
13 | "scene_id": scene_id,
14 | "obj_id": obj_id,
15 | "prompt": prompt,
16 | "caption": caption
17 | })
18 |
19 | with open(f"annotations/scene_dataset_{split}_stage2.json", "w") as f:
20 | json.dump(annos, f, indent=4)
21 |
--------------------------------------------------------------------------------
/others/prepare_multi3dref.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | split = 'train'
5 | ori_anno = json.load(open(os.path.join('annotations/multi3drefer_train_val', f"multi3drefer_{split}.json")))
6 |
7 | new_anno = []
--------------------------------------------------------------------------------
/others/prepare_obj_align_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | split = "val"
5 | with open(f"annotations/scanrefer_{split}_stage1.json", "r") as f:
6 | annos = json.load(f)
7 |
8 | if split == "train":
9 | with open(f"annotations/scannet_{split}_stage1.json", "r") as f:
10 | annos.extend(json.load(f))
11 |
12 | print(len(annos))
13 |
14 | new_annos = []
15 | unwanted_words = ["wall", "ceiling", "floor", "object", "item"]
16 |
17 | with open("prompts/obj_align_template.txt", 'r') as f:
18 | answer_templates = f.read().splitlines()
19 |
20 | for anno in annos:
21 | scene_id = anno["scene_id"]
22 | obj_id = anno["obj_id"]
23 | caption = anno["captions"][0]
24 | prompt = f"What is the ?"
25 | if any(x in caption for x in unwanted_words):
26 | continue
27 | if split == "train":
28 | answer_template = random.choice(answer_templates)
29 | if answer_template.count("{}") == 2:
30 | answer = answer_template.format(f"", caption)
31 | else:
32 | answer = answer_template.format(caption)
33 | new_annos.append({
34 | "scene_id": scene_id,
35 | "obj_id": obj_id,
36 | "prompt": prompt,
37 | "caption": answer
38 | })
39 | else:
40 | answers = []
41 | for answer_template in answer_templates:
42 | if answer_template.count("{}") == 2:
43 | answer = answer_template.format(f"", caption)
44 | else:
45 | answer = answer_template.format(caption)
46 | answers.append(answer)
47 | new_annos.append({
48 | "scene_id": scene_id,
49 | "obj_id": obj_id,
50 | "prompt": prompt,
51 | "ref_captions": answers
52 | })
53 |
54 | print(len(new_annos))
55 |
56 | with open(f"annotations/obj_align_{split}_OBJ.json", "w") as f:
57 | json.dump(new_annos, f, indent=4)
58 |
--------------------------------------------------------------------------------
/others/prepare_obj_caption.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 |
5 | split = 'val'
6 | all_objects = json.load(open('annotations/all_objects_with_full_desc_0318.json'))
7 | templates = [line.rstrip() for line in open('prompts/object_caption_templates.txt')]
8 | scan_list = [line.rstrip() for line in open(f"annotations/scannet_{split}.txt")]
9 |
10 | new_annos = []
11 | for k, v in all_objects.items():
12 | if k not in scan_list:
13 | continue
14 | for o in v['object_list']:
15 | if o['description'] is not None:
16 | new_annos.append({
17 | 'scene_id': k,
18 | 'obj_id': o['object_id'],
19 | 'prompt': random.choice(templates).replace('', f""),
20 | 'ref_captions': [o['description'].capitalize()+'.'],
21 | 'related_ids': [o['object_id']]
22 | })
23 | print(len(new_annos))
24 |
25 | with open(f'annotations/scannet_{split}_stage2_caption_OBJ.json', 'w') as f:
26 | json.dump(new_annos, f, indent=4)
27 |
--------------------------------------------------------------------------------
/others/prepare_ref_captions.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 |
5 | val_anno = json.load(open("anno/scanrefer_val_stage2.json", "r"))
6 | ref_captions = json.load(open("anno/scanrefer_captions.json", "r"))
7 | train_convs = json.load(open("anno/scanrefer_train_conversation.json", "r"))
8 | output_anno = []
9 |
10 | for k, v in train_convs.items():
11 | scene_id = "_".join(k.split("_")[:-1])
12 | obj_id = int(k.split("_")[-1])
13 | if scene_id[:7] != "scene00" or len(v) == 0:
14 | continue
15 | qa = random.choice(v)
16 | output_anno.append({
17 | "scene_id": scene_id,
18 | "obj_id": obj_id,
19 | "prompt": qa["Question"],
20 | "ref_captions": []
21 | })
22 |
23 |
24 | # cap_list = []
25 | # for k, v in ref_captions.items():
26 | # if k[:9] == "scene0000":
27 | # cap_list.append({
28 | # "scene_obj": k,
29 | # "captions": v
30 | # })
31 | # cap_list = sorted(cap_list, key=lambda x: x["scene_obj"])
32 | # with open("anno/tmp.json", "w") as f:
33 | # json.dump(cap_list, f)
34 | #
35 | # exit()
36 |
37 | # prompt_list = []
38 | # with open("prompts/conv_description.txt", "r") as f:
39 | # for line in f.readlines():
40 | # prompt = line.strip().split(" ")[-1]
41 | # prompt_list.append(prompt)
42 | #
43 | # id_set = set()
44 | #
45 | # output_anno = []
46 | #
47 | # for anno in val_anno:
48 | # scene_id = anno["scene_id"]
49 | # obj_id = anno["obj_id"]
50 | # item_id = f"{scene_id}_{obj_id}"
51 | # if scene_id[:7] == "scene00" and item_id not in id_set:
52 | # id_set.add(item_id)
53 | # prompt = random.choice(prompt_list)
54 | # output_anno.append({
55 | # "scene_id": scene_id,
56 | # "obj_id": obj_id,
57 | # "prompt": prompt,
58 | # "ref_captions": []
59 | # })
60 |
61 | print(len(output_anno))
62 | output_anno = sorted(output_anno, key=lambda x: f"{x['scene_id']}_{x['obj_id']:03}")
63 | output_anno = output_anno[:200]
64 |
65 | with open("anno/scanrefer_val_convs.json", "w") as f:
66 | json.dump(output_anno, f, indent=4)
67 |
--------------------------------------------------------------------------------
/others/prepare_referit_anno_stage1.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from sklearn.model_selection import train_test_split
4 | from collections import defaultdict
5 |
6 | dataset_name = "sr3d_merge"
7 | referit_anno_root = "/root/scene-LLaMA/datasets/referit3d/annotations"
8 |
9 | dataset_dir = os.path.join(referit_anno_root, "bert_tokenized")
10 | train_split_file = os.path.join(referit_anno_root, "splits/scannetv2_train.txt")
11 | val_split_file = os.path.join(referit_anno_root, "splits/scannetv2_val.txt")
12 | # dataset_train = json.load(open(os.path.join(dataset_dir, "train.json"), "r"))
13 | # dataset_val = json.load(open(os.path.join(dataset_dir, "val.json"), "r"))
14 |
15 | sr3d_data = []
16 | with open(os.path.join(dataset_dir, f"sr3d.jsonl"), "r") as f:
17 | for line in f:
18 | data = json.loads(line)
19 | for k, v in data.items():
20 | if type(v) == list:
21 | data[k] = frozenset(v)
22 | sr3d_data.append(data)
23 |
24 | sr3d_plus_data = []
25 | with open(os.path.join(dataset_dir, f"sr3d+.jsonl"), "r") as f:
26 | for line in f:
27 | data = json.loads(line)
28 | for k, v in data.items():
29 | if type(v) == list:
30 | data[k] = frozenset(v)
31 | sr3d_plus_data.append(data)
32 |
33 | sr3d_data = {frozenset(d.items()) for d in sr3d_data}
34 | sr3d_plus_data = {frozenset(d.items()) for d in sr3d_plus_data}
35 | sr3d_merged_data = sr3d_data | sr3d_plus_data
36 | print(len(sr3d_data), len(sr3d_plus_data), len(sr3d_merged_data))
37 | #
38 | # exit()
39 |
40 | train_scenes = []
41 | with open(train_split_file, "r") as f:
42 | for line in f.readlines():
43 | train_scenes.append(line.strip())
44 |
45 | val_scenes = []
46 | with open(val_split_file, "r") as f:
47 | for line in f.readlines():
48 | val_scenes.append(line.strip())
49 |
50 | print(len(train_scenes), len(val_scenes))
51 |
52 |
53 | train_data = []
54 | val_data = []
55 |
56 | correct_false = 0
57 | other_data = 0
58 |
59 | # with open(os.path.join(dataset_dir, f"{dataset_name}.jsonl"), "r") as f:
60 | # for line in f:
61 | # tmp_data = json.loads(line)
62 | # if not tmp_data["correct_guess"]:
63 | # correct_false += 1
64 | # continue
65 | # if tmp_data["scan_id"] in train_scenes:
66 | # train_data.append(tmp_data)
67 | # elif tmp_data["scan_id"] in val_scenes:
68 | # val_data.append(tmp_data)
69 | # else:
70 | # # print(tmp_data["scan_id"])
71 | # other_data += 1
72 |
73 | for tmp_data in sr3d_merged_data:
74 | scan_id, target_id, utterance = None, None, None
75 | for k, v in tmp_data:
76 | if k == "scan_id":
77 | scan_id = v
78 | if k == "target_id":
79 | target_id = v
80 | if k == "utterance":
81 | utterance = v
82 | tmp_data = {
83 | "scan_id": scan_id,
84 | "target_id": target_id,
85 | "utterance": utterance
86 | }
87 | if tmp_data["scan_id"] in train_scenes:
88 | train_data.append(tmp_data)
89 | elif tmp_data["scan_id"] in val_scenes:
90 | val_data.append(tmp_data)
91 | else:
92 | # print(tmp_data["scan_id"])
93 | other_data += 1
94 |
95 | print(len(train_data), len(val_data), correct_false, other_data)
96 |
97 | captions = defaultdict(list)
98 |
99 |
100 | def process(dataset_data):
101 | new_list = []
102 | for data in dataset_data:
103 | scene_id = data["scan_id"]
104 | obj_id = int(data["target_id"])
105 | feat_path = f"{scene_id}/{obj_id:03}.pt"
106 | caption = data["utterance"]
107 | new_data = {
108 | "pc_feat_path": feat_path,
109 | "caption": caption,
110 | "scene_id": scene_id,
111 | "obj_id": obj_id
112 | }
113 | captions[f"{scene_id}_{obj_id}"].append(caption)
114 | new_list.append(new_data)
115 | return new_list
116 |
117 |
118 | output_train = process(train_data)
119 | output_val = process(val_data)
120 |
121 |
122 | with open(f"anno/{dataset_name}_train_stage1.json", "w") as f:
123 | json.dump(output_train, f)
124 |
125 | with open(f"anno/{dataset_name}_val_stage1.json", "w") as f:
126 | json.dump(output_val, f)
127 |
128 | with open(f"anno/{dataset_name}_captions.json", "w") as f:
129 | json.dump(captions, f)
130 |
--------------------------------------------------------------------------------
/others/prepare_scanqa.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | replace_list = [["10", "ten"], ["12", "twelve"], ["1", "one"], ["2", "two"], ["3", "three"], ["4", "four"], ["5", "five"],
4 | ["6", "six"], ["7", "seven"], ["8", "eight"], ["9", "nine"]]
5 |
6 | split = "train"
7 | with open(f"annotations/scanqa/ScanQA_v1.0_{split}.json", "r") as f:
8 | annos = json.load(f)
9 | print(len(annos))
10 | new_annos = []
11 | for anno in annos:
12 | scene_id = anno["scene_id"]
13 | obj_ids = anno["object_ids"] if "object_ids" in anno else []
14 | question = anno["question"]
15 | # for (a, b) in replace_list:
16 | # question = question.replace(a, b)
17 |
18 | # prompt = f"Pay attention to obj{obj_ids[0]:02}"
19 | # if len(obj_ids) == 2:
20 | # prompt += f" and obj{obj_ids[1]:02}"
21 | # elif len(obj_ids) > 2:
22 | # for i in range(1, len(obj_ids)-1):
23 | # prompt += f", obj{obj_ids[i]:02}"
24 | # prompt += f", and obj{obj_ids[-1]:02}"
25 | # if len(obj_ids) > 0:
26 | # related_prompt = f"The relevant object{'s are' if len(obj_ids) > 1 else ' is'} obj{obj_ids[0]:02}"
27 | # if len(obj_ids) == 2:
28 | # related_prompt += f" and obj{obj_ids[1]:02}"
29 | # elif len(obj_ids) > 2:
30 | # for i in range(1, len(obj_ids)-1):
31 | # related_prompt += f", obj{obj_ids[i]:02}"
32 | # related_prompt += f", and obj{obj_ids[-1]:02}"
33 | # related_prompt += "."
34 | # prompt += ". " + question + " Answer the question using a single word or phrase."
35 | prompt = question + " Answer the question using a single word or phrase."
36 | # prompt = question + " The answer should be a phrase or a single word."
37 |
38 | answers = anno["answers"]
39 | if split == "val":
40 | # for i in range(len(answers)):
41 | # for (a, b) in replace_list:
42 | # answers[i] = answers[i].replace(a, b)
43 | new_annos.append({
44 | "scene_id": scene_id,
45 | "obj_id": obj_ids[0],
46 | "prompt": prompt,
47 | "ref_captions": answers
48 | })
49 | elif split == "train":
50 | for i in range(len(answers)):
51 | if i > 0 and answers[i] == answers[i-1]:
52 | continue
53 | answer = answers[i]
54 | # for (a, b) in replace_list:
55 | # answer = answer.replace(a, b)
56 | answer = answer.capitalize()
57 | if answer[-1] != ".":
58 | answer += "."
59 | # answer = "The answer is " + answer + "."
60 | new_annos.append({
61 | "scene_id": scene_id,
62 | "obj_id": obj_ids[0],
63 | "prompt": prompt,
64 | "caption": answer,
65 | # "related_ids": obj_ids
66 | })
67 | print(len(new_annos))
68 |
69 | with open(f"annotations/scanqa_{split}.json", "w") as f:
70 | json.dump(new_annos, f, indent=4)
71 |
--------------------------------------------------------------------------------
/others/prepare_scene_align_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import torch
4 | import random
5 |
6 | split = "train"
7 |
8 | annos = []
9 |
10 | if split == "val":
11 | with open(f"annotations/scanrefer_{split}_stage1.json", "r") as f:
12 | annos.extend(json.load(f))
13 |
14 | if split == "train":
15 | with open(f"annotations/scannet_{split}_stage1.json", "r") as f:
16 | annos.extend(json.load(f))
17 |
18 | scene_attrs = json.load(open("annotations/scannet_attributes.json", "r"))
19 |
20 | print(len(annos))
21 |
22 | new_annos = []
23 |
24 | for anno in annos:
25 | scene_id = anno["scene_id"]
26 | obj_id = int(anno["obj_id"])
27 | # caption = anno["captions"][0]
28 |
29 | locs = torch.tensor(scene_attrs[scene_id]["locs"])
30 | obj_num = locs.shape[0]
31 | if obj_num <= 6:
32 | continue
33 | centers = locs[:, :3]
34 | dis = ((centers - centers[obj_id])**2).sum(dim=-1)
35 | center_diff = (centers - centers[obj_id]).abs()
36 |
37 | mode = random.randint(0, 3)
38 | if mode == 0:
39 | if random.randint(0, 1) == 0:
40 | prompt = f"Which object is closest to obj{obj_id:02}?"
41 | answer_id = int(dis.topk(k=2, largest=False)[1][1])
42 | answer = f"The closest object to obj{obj_id:02} is obj{answer_id:02}."
43 | else:
44 | prompt = f"Which object is farthest from obj{obj_id:02}?"
45 | answer_id = int(dis.topk(k=2, largest=True)[1][0])
46 | answer = f"The farthest object from obj{obj_id:02} is obj{answer_id:02}."
47 | if mode == 1:
48 | a = random.randint(0, obj_num-1)
49 | b = random.randint(0, obj_num-1)
50 | while a == obj_id or b == obj_id or a == b:
51 | a = random.randint(0, obj_num - 1)
52 | b = random.randint(0, obj_num - 1)
53 | if random.randint(0, 1):
54 | prompt = f"Which object is closer to obj{obj_id:02}, obj{a:02} or obj{b:02}?"
55 | answer_id = a if dis[a] < dis[b] else b
56 | answer = f"The closer object to obj{obj_id:02} is obj{answer_id:02}."
57 | else:
58 | prompt = f"Which object is farther from obj{obj_id:02}, obj{a:02} or obj{b:02}?"
59 | answer_id = a if dis[a] > dis[b] else b
60 | answer = f"The farther object from obj{obj_id:02} is obj{answer_id:02}."
61 | if mode == 2:
62 | a = random.randint(0, obj_num - 1)
63 | b = random.randint(0, obj_num - 1)
64 | while a == obj_id or b == obj_id or a == b:
65 | a = random.randint(0, obj_num - 1)
66 | b = random.randint(0, obj_num - 1)
67 | z_a = locs[a][2] - locs[a][5]
68 | z_b = locs[b][2] - locs[b][5]
69 | if random.randint(0, 1):
70 | prompt = f"Which object is located at the higher position, obj{a:02} or obj{b:02}?"
71 | if z_a < z_b:
72 | a, b = b, a
73 | answer = f"Obj{a:02} is located at a higher position compared to obj{b:02}."
74 | else:
75 | prompt = f"Which object is located at the lower position, obj{a:02} or obj{b:02}?"
76 | if z_a > z_b:
77 | a, b = b, a
78 | answer = f"Obj{a:02} is located at a lower position compared to obj{b:02}."
79 | if mode == 3:
80 | prompt = f"List the five closest objects to obj{obj_id:02} in ascending order of their object IDs."
81 | answer_ids = dis.topk(k=6, largest=False)[1][1:6].tolist()
82 | answer_ids.sort()
83 | answer = f"The five closest objects to obj{obj_id:02} in ascending order are: obj{answer_ids[0]:02}, obj{answer_ids[1]:02}, obj{answer_ids[2]:02}, obj{answer_ids[3]:02}, and obj{answer_ids[4]:02}."
84 | # nearest
85 | if split == "train":
86 | new_annos.append({
87 | "scene_id": scene_id,
88 | "obj_id": obj_id,
89 | "prompt": prompt,
90 | "caption": answer
91 | })
92 | else:
93 | new_annos.append({
94 | "scene_id": scene_id,
95 | "obj_id": obj_id,
96 | "prompt": prompt,
97 | "ref_captions": [answer]
98 | })
99 |
100 | with open(f"annotations/scene_align_{split}.json", "w") as f:
101 | json.dump(new_annos, f, indent=4)
102 |
--------------------------------------------------------------------------------
/others/prepare_scene_level_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from collections import defaultdict
4 |
5 | split = "val"
6 |
7 | annos = json.load(open(f"annotations/scanrefer_{split}_stage2_objxx.json", "r"))
8 | attrs = torch.load(f"annotations/scannet_{split}_attributes.pt")
9 |
10 | new_annos = defaultdict(dict)
11 |
12 | for anno in annos:
13 | scene_id = anno["scene_id"]
14 | obj_id = anno["obj_id"]
15 | caption = anno["caption"] if "caption" in anno else anno["ref_captions"][0]
16 | new_annos[scene_id][f"{obj_id:03}"] = {
17 | "loc": attrs[scene_id]["locs"][obj_id][:3].tolist(),
18 | "caption": caption,
19 | "class_name": attrs[scene_id]["objects"][obj_id]
20 | }
21 |
22 |
23 | print(len(new_annos))
24 |
25 | for scene_id in new_annos.keys():
26 | message = ""
27 | for i in range(200):
28 | obj_id = f"{i:03}"
29 | if obj_id not in new_annos[scene_id]:
30 | continue
31 | obj_anno = new_annos[scene_id][obj_id]
32 | class_name = obj_anno["class_name"]
33 | loc = obj_anno["loc"]
34 | for j in range(len(loc)):
35 | loc[j] = round(loc[j], 2)
36 | caption = obj_anno["caption"]
37 | message += f"obj{i:02}: {class_name}; {loc}; {caption}\n"
38 | with open(f"annotations/scene_dataset/obj_info_list/{split}/{scene_id}.json", "w") as f:
39 | json.dump(message, f)
40 |
41 |
--------------------------------------------------------------------------------
/others/prepare_sqa3d.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import os
4 | import nltk
5 | import random
6 | from tqdm import tqdm
7 |
8 | anno_dir = 'annotations/sqa3d'
9 |
10 | # for filename in os.listdir(anno_dir):
11 | # x = json.load(open(os.path.join(anno_dir, filename)))
12 | # with open(os.path.join(anno_dir, filename), 'w') as f:
13 | # json.dump(x, f, indent=4)
14 |
15 |
16 | def convert_person_view(sentence):
17 | # first-person view to second-person view
18 | forms = {'i': 'you', 'me': 'you', 'my': 'your', 'mine': 'yours', 'am': 'are'}
19 | def translate(word):
20 | if word.lower() in forms:
21 | return forms[word.lower()]
22 | return word
23 | result = ' '.join([translate(word) for word in nltk.wordpunct_tokenize(sentence)])
24 | return result.capitalize()
25 |
26 |
27 | def get_sqa_question_type(question):
28 | question = question.lstrip()
29 | if question[:4].lower() == 'what':
30 | return 0
31 | elif question[:2].lower() == 'is':
32 | return 1
33 | elif question[:3].lower() == 'how':
34 | return 2
35 | elif question[:3].lower() == 'can':
36 | return 3
37 | elif question[:5].lower() == 'which':
38 | return 4
39 | else:
40 | return 5 # others
41 |
42 |
43 | for split in ['train', 'val']:
44 | scan_ids = []
45 | sqa_annos = []
46 | question_file = os.path.join(anno_dir, f'v1_balanced_questions_{split}_scannetv2.json')
47 | with open(question_file, 'r', encoding='utf-8') as f:
48 | question_data = json.load(f)['questions']
49 | question_map = {}
50 | for item in question_data:
51 | question_map[item['question_id']] = {
52 | 's': [item['situation']] + item['alternative_situation'], # list of str
53 | 'q': item['question'], # str
54 | }
55 |
56 | anno_file = os.path.join(anno_dir, f'v1_balanced_sqa_annotations_{split}_scannetv2.json')
57 | with open(anno_file, 'r', encoding='utf-8') as f:
58 | anno_data = json.load(f)['annotations']
59 | for item in tqdm(anno_data):
60 | scan_ids.append(item['scene_id'])
61 | # sqa_annos.append({
62 | # 's': question_map[item['question_id']]['s'], # list of str
63 | # 'q': question_map[item['question_id']]['q'], # str
64 | # 'a': [meta['answer'] for meta in item['answers']], # list of str
65 | # 'pos': np.array(list(item['position'].values())), # array (3,)
66 | # 'rot': np.array(list(item['rotation'].values())), # array (4,)
67 | # })
68 | scene_id = item['scene_id']
69 | obj_id = 0
70 | situation = random.choice(question_map[item['question_id']]['s'])
71 | question = question_map[item['question_id']]['q']
72 | question_type = get_sqa_question_type(question)
73 | prompt = situation + ' ' + question + " Answer the question using a single word or phrase."
74 | answers = [meta['answer'] for meta in item['answers']]
75 | if split == 'train':
76 | answer = random.choice(answers)
77 | answer = answer.capitalize()
78 | if answer[-1] != ".":
79 | answer += "."
80 | sqa_annos.append({
81 | 'scene_id': scene_id,
82 | 'obj_id': obj_id,
83 | 'prompt': prompt,
84 | 'caption': answer,
85 | 'sqa_type': question_type
86 | })
87 | else:
88 | sqa_annos.append({
89 | 'scene_id': scene_id,
90 | 'obj_id': obj_id,
91 | 'prompt': prompt,
92 | 'ref_captions': answers,
93 | 'sqa_type': question_type
94 | })
95 | # print(sqa_annos[-1])
96 | # breakpoint()
97 | with open(f"annotations/sqa3d_{split}.json", "w") as f:
98 | json.dump(sqa_annos, f, indent=4)
99 |
100 |
101 |
102 |
--------------------------------------------------------------------------------
/others/prepare_train_stage3.py:
--------------------------------------------------------------------------------
1 | import jsonlines
2 | import json
3 |
4 | train_split_file = "/root/scene-LLaMA/datasets/referit3d/annotations/splits/scannetv2_train.txt"
5 | train_list = []
6 | with open(train_split_file, "r") as f:
7 | for line in f.readlines():
8 | train_list.append(line.strip())
9 |
10 | nr3d_anno_file = "/root/scene-LLaMA/datasets/referit3d/annotations/bert_tokenized/nr3d.jsonl"
11 | nr3d_anno = []
12 | with jsonlines.open(nr3d_anno_file, "r") as reader:
13 | for l in reader:
14 | nr3d_anno.append(l)
15 |
16 | anno_root = "annotations" # annotation dir
17 | attribute_file = f"{anno_root}/scannet_attributes_old.json"
18 | attributes = json.load(open(attribute_file, 'r'))
19 |
20 | q_template = "Evaluate the request below and determine whether it accurately identifies the target object enclosed within the tags '' and '': {} Please respond with the answer in the following format: 'The answer is: True' if the request correctly localizes the target object, or 'The answer is: False' if it does not."
21 | a_template = "The answer is: {}."
22 |
23 | from tqdm import tqdm
24 | import random
25 | from collections import defaultdict
26 | utters = defaultdict(list)
27 | for item in nr3d_anno:
28 | scene_id = item["scan_id"]
29 | if scene_id not in train_list:
30 | continue
31 | utter = item["utterance"]
32 | utters[scene_id].append(utter)
33 | # target_id = item["target_id"]
34 |
35 | new_anno = []
36 |
37 | for item in tqdm(nr3d_anno):
38 | scene_id = item["scan_id"]
39 | if scene_id not in train_list:
40 | continue
41 | utter = item["utterance"]
42 | target_id = item["target_id"]
43 | new_anno.append({
44 | "scene_id": scene_id,
45 | "obj_id": target_id,
46 | "QA": [{
47 | "Question": q_template.format(utter),
48 | "Answer": a_template.format("True")
49 | }]
50 | })
51 |
52 | utter = random.choice(utters[scene_id])
53 | while utter == item["utterance"]:
54 | utter = random.choice(utters[scene_id])
55 | new_anno.append({
56 | "scene_id": scene_id,
57 | "obj_id": target_id,
58 | "QA": [{
59 | "Question": q_template.format(utter),
60 | "Answer": a_template.format("False")
61 | }]
62 | })
63 |
64 | with open("annotations/nr3d_train_tf.json", "w") as f:
65 | json.dump(new_anno, f)
66 |
--------------------------------------------------------------------------------
/others/process_preds.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | model = "2023-09-17-144020_dp0.1_lr5e-5_sta3_ep3"
5 | conv_pred_file = f"eval/conv100_{model}.json"
6 | detail_pred_file = f"eval/detail100_{model}.json"
7 |
8 | qa_name = "qa60"
9 | conv_preds_ = json.load(open(conv_pred_file, "r"))
10 | detail_preds_ = json.load(open(detail_pred_file, "r"))
11 | ques = json.load(open(f"eval/{qa_name}_questions.json", "r"))
12 |
13 | conv_preds = {}
14 | detail_preds = {}
15 |
16 | for pred in conv_preds_:
17 | item_id = f"{pred['scene_id']}_{pred['obj_id']}"
18 | conv_preds[item_id] = pred["pred"]
19 | for pred in detail_preds_:
20 | item_id = f"{pred['scene_id']}_{pred['obj_id']}"
21 | detail_preds[item_id] = pred["pred"]
22 |
23 | answers = []
24 |
25 | for que in ques:
26 | item_id = que["item_id"]
27 | answer = conv_preds[item_id] if que["category"] == "conv" else detail_preds[item_id]
28 | if ": " in answer:
29 | answer = ": ".join(answer.split(": ")[1:])
30 | answers.append({
31 | "answer_id": que["question_id"],
32 | "text": answer
33 | })
34 |
35 | with open(f"eval/{qa_name}_{model}_answer.json", "w") as f:
36 | json.dump(answers, f, indent=4)
37 |
--------------------------------------------------------------------------------
/others/process_vil3dref_multichoice.py:
--------------------------------------------------------------------------------
1 | """
2 | loss/og3d: 2.9594, loss/obj3d_clf: 3.3753, loss/obj3d_clf_pre: 2.0714, loss/txt_clf: 0.6708, loss/total: 10.2789, loss/cross_attn_0: 0.0032, loss/cross_attn_1: 0.0011, loss/cross_attn_2: 0.0011, loss/cross_attn_3: 0.0012, loss/self_attn_0: 0.1595, loss/self_attn_1: 0.0425, loss/self_attn_2: 0.0541, loss/self_attn_3: 0.1030, loss/hidden_state_0: 0.3919, loss/hidden_state_1: 0.0765, loss/hidden_state_2: 0.1033, loss/hidden_state_3: 0.1308, loss/hidden_state_4: 0.1337, acc/og3d: 0.6373, acc/og3d_class: 0.8903, acc/obj3d_clf: 0.6828, acc/obj3d_clf_pre: 0.6131, acc/txt_clf: 0.9281
3 | """
4 |
5 | import json
6 | import jsonlines
7 | import math
8 | import torch
9 | from random import shuffle
10 | import random
11 | split = "val"
12 | val_file = f"/root/scene-LLaMA/datasets/exprs_neurips22/gtlabelpcd_mix/nr3d/preds/{split}_outs.json"
13 | nr3d_anno_file = "/root/scene-LLaMA/datasets/referit3d/annotations/bert_tokenized/nr3d.jsonl"
14 |
15 | val_results = json.load(open(val_file))
16 |
17 | nr3d_anno = {}
18 | with jsonlines.open(nr3d_anno_file, "r") as reader:
19 | for l in reader:
20 | nr3d_anno[l["item_id"]] = l
21 |
22 | val_split_path = "/root/scene-LLaMA/datasets/referit3d/annotations/splits/scannetv2_val.txt"
23 | scene_ids = []
24 | with open(val_split_path, "r") as f:
25 | for line in f.readlines():
26 | scene_ids.append(line.strip())
27 |
28 | shuffle(scene_ids)
29 | scene_num = len(scene_ids)
30 | train_scene_num = int(scene_num * 0.8)
31 | train_scene_ids, val_scene_ids = scene_ids[:train_scene_num], scene_ids[train_scene_num:]
32 |
33 |
34 |
35 | multi_choice_template = "Given the description of \",\" I have received a list of possible objects from a robust 3D localization model: []. These objects are considered potential matches for the given description. Kindly provide the object ID that you believe is the closest match to the description. If you believe none of the listed objects are a correct match, please specify an alternative object ID."
36 |
37 | item_list = []
38 | train_output_annos = []
39 | val_output_annos = []
40 |
41 | acc = 0
42 | random_acc = 0
43 | origin_acc = 0
44 | tot_len = 0
45 | max_len = 0
46 | thres = 1e-2
47 | tot_num = 0
48 | for k, v in val_results.items():
49 | obj_ids = v["obj_ids"]
50 | obj_logits = v["obj_logits"]
51 | obj_logits = (torch.tensor(obj_logits)).softmax(dim=-1).tolist()
52 | scan_id = nr3d_anno[k]["scan_id"]
53 | utter = nr3d_anno[k]["utterance"]
54 | target_id = nr3d_anno[k]["target_id"]
55 | logit_ids = zip(obj_logits, obj_ids)
56 | logit_ids = sorted(logit_ids, reverse=True)
57 | logits, ids = zip(*logit_ids)
58 | # print(logits)
59 | # print(ids)
60 | # print(target_id)
61 | # breakpoint()
62 | can_ids = []
63 | if split == "val":
64 | for i in range(min(len(logits), 5)):
65 | if logits[i] > thres:
66 | can_ids.append(ids[i])
67 | else:
68 | can_num = random.randint(1, 5)
69 | can_ids = ids[:can_num]
70 | if len(can_ids) == 1:
71 | continue
72 | # can_ids = sorted(can_ids)
73 | id_list = ""
74 | for i in range(len(can_ids)):
75 | if i > 0:
76 | id_list += ", "
77 | id_list += f"obj{can_ids[i]:02}"
78 | if utter[-1] == ".":
79 | utter = utter[:-1]
80 | prompt = multi_choice_template.replace("", utter).replace("", id_list)
81 | answer = f"obj{target_id:02}.".capitalize()
82 | # logits = (torch.tensor(logits[:5]) / 5.).softmax(dim=-1).tolist()
83 | # print(logits)
84 | # if ids[0] == target_id:
85 | # acc += 1
86 |
87 | # item_list.append({
88 | # "can_ids": ids[:5],
89 | # "can_preds": logits[:5],
90 | # "utter": utter,
91 | # "target_id": target_id,
92 | # "scan_id": scan_id
93 | # })
94 | if scan_id in train_scene_ids:
95 | train_output_annos.append({
96 | "scene_id": scan_id,
97 | "obj_id": target_id,
98 | "prompt": prompt,
99 | "caption": answer,
100 | })
101 | else:
102 | val_output_annos.append({
103 | "scene_id": scan_id,
104 | "obj_id": target_id,
105 | "prompt": prompt,
106 | "ref_captions": [answer],
107 | "qid": k
108 | })
109 | if target_id in can_ids:
110 | acc += 1
111 | if random.choice(can_ids) == target_id:
112 | random_acc += 1
113 | if ids[0] == target_id:
114 | origin_acc += 1
115 | tot_len += len(can_ids)
116 | max_len = len(can_ids) if len(can_ids) > max_len else max_len
117 | tot_num += 1
118 |
119 | # if split == "val":
120 | # with open(f"annotations/nr3d_{split}_stage2_multichoice{str(thres)}.json", "w") as f:
121 | # json.dump(train_output_annos, f, indent=4)
122 | # else:
123 | # with open(f"annotations/nr3d_{split}_stage2_multichoice.json", "w") as f:
124 | # json.dump(val_output_annos, f, indent=4)
125 |
126 | with open(f"annotations/nr3d_train_stage2_multichoice{str(thres)}.json", "w") as f:
127 | json.dump(train_output_annos, f, indent=4)
128 | with open(f"annotations/nr3d_val_stage2_multichoice{str(thres)}.json", "w") as f:
129 | json.dump(val_output_annos, f, indent=4)
130 |
131 | print(tot_num)
132 | print("Origin Acc:", float(origin_acc) / tot_num)
133 | print("Upper Acc:", float(acc) / tot_num)
134 | print("Random Acc:", float(random_acc) / tot_num)
135 | print("mean len:", float(tot_len) / tot_num)
136 | print("max len:", max_len)
137 | # print(len(item_list))
138 | # print(item_list[:5])
139 | # exit()
140 |
--------------------------------------------------------------------------------
/others/run_chat.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | pretrained_path=outputs/2023-09-07-205742_dp0.1_lr5e-5_sta3_ep3/ckpt_02.pth
3 |
4 | CUDA_VISIBLE_DEVICES=2 python others/process_vil3dref_results.py \
5 | scripts/config_old.py \
6 | pretrained_path "$pretrained_path" \
7 | model.max_txt_len 20
8 |
--------------------------------------------------------------------------------
/others/run_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | model=2023-09-17-144020_dp0.1_lr5e-5_sta3_ep3
4 | qa_name=qa60
5 | OPENAI_API_KEY="" python eval/eval.py \
6 | --question eval/"$qa_name"_questions.json \
7 | --context anno/scanrefer_val_content.json \
8 | --answer-list \
9 | eval/"$qa_name"_gpt4_answer.json \
10 | eval/"$qa_name"_"$model"_answer.json \
11 | --rule eval/rule.json \
12 | --output eval/review_"$qa_name"_"$model".jsonl
--------------------------------------------------------------------------------
/others/run_generate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | OPENAI_API_KEY="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" python others/gpt_generate.py
--------------------------------------------------------------------------------
/others/tmp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | import os
4 | from collections import defaultdict
5 |
6 | # annos = json.load(open("annotations/obj_align_train_OBJ.json"))
7 |
8 | # new_annos = []
9 | # obj_ids = set()
10 |
11 | # for anno in annos:
12 | # if anno['scene_id'] == "scene0000_00":
13 | # tmp_id = int(anno["caption"].split("OBJ")[1][:3])
14 | # if tmp_id in obj_ids:
15 | # continue
16 | # obj_ids.add(tmp_id)
17 | # if anno["caption"].startswith("")
44 | # caption = ''
45 | # for idx in range(len(clean_text)):
46 | # if idx in p:
47 | # if p[idx][0] == '[':
48 | # caption += '['
49 | # else:
50 | # caption += ' ' + ', '.join(p[idx]) + ']'
51 | # caption += clean_text[idx]
52 | # return caption
53 |
54 |
55 | # scannet_root = '/mnt/petrelfs/share_data/huanghaifeng/data/processed/scannet'
56 | # x = json.load(open('data/step2_captions_by_scene_v2.json'))
57 |
58 | # train_annos = []
59 | # val_annos = []
60 |
61 | # train_scans = [line.rstrip() for line in open(os.path.join(scannet_root, f'train.txt'))]
62 | # val_scans = [line.rstrip() for line in open(os.path.join(scannet_root, f'val.txt'))]
63 |
64 |
65 | # import pandas as pd
66 | # from tqdm import tqdm
67 |
68 | # obj_csv = pd.read_csv('annotations/Cap3D_automated_Objaverse_no3Dword.csv', header=None)
69 | # obj_ids = []
70 | # obj_cap_dict = {}
71 | # feats = torch.load('annotations/objaverse_uni3d_feature.pt')
72 |
73 | # for obj_id, cap in tqdm(zip(obj_csv[0].values, obj_csv[1].values)):
74 | # # remove redundant quotation marks, here we do not directly strip because the mark may appear only at one side
75 | # if obj_id not in feats:
76 | # continue
77 | # if cap.startswith('"') and cap.endswith('"'):
78 | # cap = cap.strip('"')
79 | # elif cap.startswith("'") and cap.endswith("'"):
80 | # cap = cap.strip("'")
81 | # cap = cap.capitalize()
82 | # obj_ids.append(obj_id)
83 | # obj_cap_dict[obj_id] = cap
84 |
85 | # train_annos = []
86 | # val_annos = []
87 | # train_obj_ids = obj_ids[:-1000]
88 | # val_obj_ids = obj_ids[-1000:]
89 |
90 |
91 | # for obj_id in train_obj_ids:
92 | # train_annos.append({
93 | # 'scene_id': obj_id,
94 | # 'caption': obj_cap_dict[obj_id]
95 | # })
96 |
97 | # for obj_id in val_obj_ids:
98 | # val_annos.append({
99 | # 'scene_id': obj_id,
100 | # 'ref_captions': [obj_cap_dict[obj_id]]
101 | # })
102 |
103 | # print(len(train_annos))
104 | # print(len(val_annos))
105 |
106 |
107 | # with open('annotations/objaverse_caption_train.json', 'w') as f:
108 | # json.dump(train_annos, f, indent=4)
109 |
110 | # with open('annotations/objaverse_caption_val.json', 'w') as f:
111 | # json.dump(val_annos, f, indent=4)
112 |
113 |
114 | # train_feats = {}
115 | # val_feats = {}
116 |
117 | # for obj_id in train_obj_ids:
118 | # train_feats[obj_id] = feats[obj_id]
119 | # for obj_id in val_obj_ids:
120 | # val_feats[obj_id] = feats[obj_id]
121 |
122 | # torch.save(train_feats, 'annotations/objaverse_uni3d_feature_train.pt')
123 | # torch.save(val_feats, 'annotations/objaverse_uni3d_feature_val.pt')
124 |
125 |
126 | # import os
127 | # import gzip
128 | # import numpy as np
129 | # from tqdm import tqdm
130 |
131 | # folder_path = '/mnt/petrelfs/huanghaifeng/share/data/cap3d/8192_npy'
132 |
133 | # for filename in tqdm(os.listdir(folder_path)):
134 | # if filename.endswith('.npy'):
135 | # obj_id = filename.split('_8192')[0]
136 | # data = np.load(os.path.join(folder_path, filename))
137 | # with gzip.open(os.path.join(folder_path, obj_id + '.gz'), 'wb') as f:
138 | # np.save(f, data)
139 | # os.remove(os.path.join(folder_path, filename))
140 |
141 | import csv
142 | id2class = {}
143 | labels = set()
144 | class_label_file = "annotations/scannet/scannetv2-labels.combined.tsv"
145 | with open(class_label_file, "r") as f:
146 | csvreader = csv.reader(f, delimiter='\t')
147 | csvreader.__next__()
148 | for line in csvreader:
149 | id2class[line[0]] = line[1]
150 | labels.add(line[2])
151 | print(len(labels))
152 |
153 |
--------------------------------------------------------------------------------
/others/tmp2.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 |
4 | split = 'val'
5 |
6 | scanqa_anno = json.load(open(f"annotations/scanqa_{split}_stage2_objxx.json"))
7 | scanrefer_anno = json.load(open(f"annotations/scanrefer_{split}_stage2_objxx.json"))
8 |
9 | scanqa_scan_list = np.unique([x['scene_id'] for x in scanqa_anno])
10 | scanrefer_scan_list = np.unique([x['scene_id'] for x in scanrefer_anno])
11 |
12 | print(len(set(scanrefer_scan_list) - set(scanqa_scan_list)))
13 | print(len(scanqa_scan_list))
14 | print(len(scanrefer_scan_list))
15 |
--------------------------------------------------------------------------------
/preprocess/README.md:
--------------------------------------------------------------------------------
1 | ## Skip the data preparation
2 |
3 | - Chat-Scene has provided all the prepared data in [Google Drive](https://drive.google.com/drive/folders/1iwVFUkvveehvwGcAnJK3EwLxBt5ggR2c?usp=sharing). Simply download the files and place them in the annotations/ directory. You’ll then be ready to run and test the code.
4 |
5 | - We've provided preprocessed VL-SAT features for semantic relations between objects, as well as additional text annotations, in [Yandex Disk](https://disk.yandex.ru/d/LpPJgHg8Qg6BpA).
6 |
7 | - We've provided VL-SAT features for fully-connected graphs with semantic relations between objects in [Yandex Disk](https://disk.yandex.ru/d/LpPJgHg8Qg6BpA) (output_vlsat.zip)
8 |
9 | ## Prepare data
10 |
11 | - Download the ScanNet dataset by following the [ScanNet instructions](https://github.com/ScanNet/ScanNet).
12 |
13 | - Extract object masks using a pretrained 3D detector:
14 | - Use [Mask3D](https://github.com/JonasSchult/Mask3D) for instance segmentation. We used the [checkpoint](https://omnomnom.vision.rwth-aachen.de/data/mask3d/checkpoints/scannet200/scannet200_val.ckpt) pretrained on ScanNet200.
15 | - The complete predicted results (especially the masks) for the train/validation sets are too large to share (~40GB). We’ve shared the post-processed [results](https://drive.google.com/file/d/1jwQYJvkWwRmawZvNOSy6U0lnqnEiasNX/view?usp=sharing):
16 | - Unzip the `mask3d_inst_seg.tar.gz` file.
17 | - Each file under `mask3d_inst_seg` contains the predicted results for a single scene, including a list of segmented instances with their labels and segmented indices.
18 |
19 | - Process object masks and prepare annotations:
20 | - If you use Mask3D for instance segmentation, set the `segment_result_dir` in [run_prepare.sh](run_prepare.sh) to the output directory of Mask3D.
21 | - If you use the downloaded `mask3d_inst_seg` directly, set `segment_result_dir` to None and set `inst_seg_dir` to the path of `mask3d_inst_seg`.
22 | - Run: `bash preprocess/run_prepare.sh`
23 |
24 | - Extract 3D features using a pretrained 3D encoder:
25 | - Follow [Uni3D](https://github.com/baaivision/Uni3D?tab=readme-ov-file) to extract 3D features for each instance. We used the pretrained model [uni3d-g](https://huggingface.co/BAAI/Uni3D/blob/main/modelzoo/uni3d-g/model.pt).
26 | - We've also provided modified code for feature extraction in this forked [repository](https://github.com/ZzZZCHS/Uni3D). Set the `data_dir` [here](https://github.com/ZzZZCHS/Uni3D/blob/main/main.py#L620) to the path to `${processed_data_dir}/pcd_all` (`processed_data_dir` is an intermediate directory set in `run_prepare.sh`). After preparing the environment, run `bash scripts/inference.sh`.
27 |
28 | - Extract 2D features using a pretrained 2D encoder:
29 |
30 | - We followed [OpenScene](https://github.com/pengsongyou/openscene)'s code to calculate the mapping between 3D points and 2D image pixels. This allows each object to be projected onto multi-view images. Based on the projected masks on the images, we extract and merge DINOv2 features from multi-view images for each object.
31 |
32 | - [TODO] Detailed implementation will be released.
33 |
34 | - Obtain connections based on the N nearest neighbors of each object by filtering the fully connected graphs with VL-SAT features for the Mask3D segmentation. To do this, run the ```prepare_filtered_mask3d_gnn_data.py``` script after updating the paths to the directories containing the fully connected graphs for each scene, the object attributes, and the ScanNet splits. The number of nearest neighbors can be adjusted via the ```KNN``` parameter at the beginning of the script.
35 |
36 | - Obtain connections based on the N nearest neighbors of each object by filtering the fully connected graphs with VL-SAT features for the GT segmentation. To do this, run the ```prepare_gnn_data.py``` script after updating the paths to the directories containing the fully connected graphs for each scene, the object attributes, and the ScanNet splits. The number of nearest neighbors can be adjusted via the ```KNN``` parameter at the beginning of the script. A minimal sketch of this neighbor filtering is shown below.
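37 | 
38 | Both filtering scripts share the same core step: for every object, keep only the edges to its `KNN` nearest neighbors and drop the rest of the fully connected graph. The snippet below is only a minimal sketch of that idea, not the scripts' actual interface; `object_centers`, `edge_feats`, and `filter_graph_by_knn` are hypothetical names used for illustration.
39 | 
40 | ```python
41 | import numpy as np
42 | 
43 | KNN = 2  # hypothetical default; the real value is set at the top of each script
44 | 
45 | def filter_graph_by_knn(object_centers: np.ndarray, edge_feats: dict, knn: int = KNN) -> dict:
46 |     """Keep only the edges that connect each object to its `knn` nearest neighbors.
47 | 
48 |     object_centers: (N, 3) array of object bounding-box centers.
49 |     edge_feats: {(i, j): feature_vector} for the fully connected graph.
50 |     """
51 |     n = object_centers.shape[0]
52 |     # Pairwise Euclidean distances between object centers.
53 |     dists = np.linalg.norm(object_centers[:, None, :] - object_centers[None, :, :], axis=-1)
54 |     np.fill_diagonal(dists, np.inf)  # an object is never its own neighbor
55 |     kept = {}
56 |     for i in range(n):
57 |         for j in np.argsort(dists[i])[:knn]:
58 |             if (i, int(j)) in edge_feats:
59 |                 kept[(i, int(j))] = edge_feats[(i, int(j))]
60 |     return kept
61 | ```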
--------------------------------------------------------------------------------
/preprocess/prepare_mask3d_img_feat.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import sys
4 | sys.path.append('.')
5 | import torch
6 | import random
7 | from tqdm import tqdm
8 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners
9 | import argparse
10 |
11 | parser = argparse.ArgumentParser()
12 |
13 | parser.add_argument('--segmentor', required=True, type=str)
14 | parser.add_argument('--version', type=str, default='')
15 | args = parser.parse_args()
16 |
17 |
18 | segmentor = args.segmentor
19 | version = args.version
20 | # annos = json.load(open(f"annotations/scanrefer_{split}_stage2_grounding_OBJ.json", "r"))
21 | feats = torch.load(f'annotations/scannet_img_dinov2_features.pt')
22 | new_feats = {}
23 | item2iou = {}
24 | iou_thres = 0.5
25 |
26 | for split in ['train', 'val']:
27 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt"
28 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt"
29 | instance_attrs = torch.load(instance_attribute_file)
30 | scannet_attrs = torch.load(scannet_attribute_file)
31 | for k, v in tqdm(feats.items()):
32 | scene_id = '_'.join(k.split('_')[:2])
33 | if scene_id not in instance_attrs:
34 | continue
35 | obj_id = int(k.split('_')[-1])
36 | instance_locs = instance_attrs[scene_id]["locs"]
37 | scannet_locs = scannet_attrs[scene_id]["locs"]
38 | instance_num = instance_locs.shape[0]
39 | max_iou, max_id = -1, -1
40 | for pred_id in range(instance_num):
41 | pred_locs = instance_locs[pred_id].tolist()
42 | try:
43 | gt_locs = scannet_locs[obj_id].tolist()
44 |             except IndexError:  # fall back when obj_id is out of range for this scene's GT boxes
45 | gt_locs = scannet_locs[obj_id-31].tolist()
46 | # print(k)
47 | # breakpoint()
48 | # break
49 | pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:])
50 | gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:])
51 | iou = box3d_iou(pred_corners, gt_corners)
52 | if iou > max_iou:
53 | max_iou = iou
54 | max_id = pred_id
55 | item_id = f"{scene_id}_{max_id:02}"
56 | if max_iou > iou_thres and (item_id not in new_feats or item2iou[item_id] < max_iou):
57 | new_feats[item_id] = v
58 | item2iou[item_id] = max_iou
59 |
60 | torch.save(new_feats, f'annotations/scannet_img_mask3d_dinov2_features{version}.pt')
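61 | 
62 | # Summary: the DINOv2 image features loaded above are keyed by GT object ids; this script re-keys them to
63 | # the segmentor's predicted instances by matching each GT box to its highest-IoU predicted box (IoU > 0.5),
64 | # keeping only the best-matching feature per predicted instance.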
--------------------------------------------------------------------------------
/preprocess/prepare_multi3dref_annos.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import os
4 | import sys
5 | sys.path.append('.')
6 | from tqdm import tqdm
7 | import argparse
8 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners
9 | from prompts.prompts import multi3dref_prompt, ID_format
10 | import random
11 | from collections import defaultdict
12 | import string
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--segmentor', required=True, type=str)
16 | parser.add_argument('--version', type=str, default='')
17 | parser.add_argument('--train_iou_thres', type=float, default=0.75)
18 | args = parser.parse_args()
19 |
20 | segmentor = args.segmentor
21 | version = args.version
22 |
23 | for split in ['train', 'val']:
24 | annos = json.load(open(f"annotations/multi3drefer/multi3drefer_{split}.json"))
25 | new_annos = []
26 |
27 | count_all = defaultdict(int)
28 | count_used = defaultdict(int)
29 | if segmentor == 'deva':
30 | seg_gt_ious = torch.load(f"annotations/scannet_{segmentor}_seg_gt_ious.pt", map_location='cpu')
31 | else:
32 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt"
33 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt"
34 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu')
35 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu')
36 |
37 | for i, anno in tqdm(enumerate(annos)):
38 | scene_id = anno['scene_id']
39 | count_all[anno['eval_type']] += 1
40 | if segmentor == 'deva':
41 | if scene_id not in seg_gt_ious:
42 | continue
43 | seg_gt_iou = seg_gt_ious[scene_id]
44 | else:
45 | if scene_id not in instance_attrs:
46 | continue
47 | instance_locs = instance_attrs[scene_id]["locs"]
48 | scannet_locs = scannet_attrs[scene_id]["locs"]
49 | instance_num = instance_locs.shape[0]
50 | gt_ids = anno['object_ids']
51 | caption = anno['description']
52 | if caption[-1] in string.punctuation:
53 | caption = caption[:-1]
54 | prompt = random.choice(multi3dref_prompt).replace("", caption)
55 | pred_ids = []
56 | flag = 1
57 | for gt_id in gt_ids:
58 | if segmentor == 'deva':
59 | if gt_id >= seg_gt_iou.shape[1]:
60 | flag = 0
61 | break
62 | max_iou, max_id = seg_gt_iou[:, gt_id].max(0)
63 | max_iou = float(max_iou)
64 | max_id = int(max_id)
65 | else:
66 | max_iou, max_id = -1, -1
67 | for pred_id in range(instance_num):
68 | pred_locs = instance_locs[pred_id].tolist()
69 | gt_locs = scannet_locs[gt_id].tolist()
70 | pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:])
71 | gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:])
72 | iou = box3d_iou(pred_corners, gt_corners)
73 | if iou > max_iou:
74 | max_iou = iou
75 | max_id = pred_id
76 | if split == 'train' and (max_iou < args.train_iou_thres or max_id in pred_ids):
77 | flag = 0
78 | break
79 | pred_ids.append(max_id)
80 | if flag == 0:
81 | continue
82 | count_used[anno['eval_type']] += 1
83 | pred_ids = sorted(pred_ids)
84 | pred_id_strs = [ID_format.format(pred_id) for pred_id in pred_ids]
85 | if len(pred_ids) == 0:
86 | answer = "No."
87 | elif len(pred_ids) == 1:
88 | answer = f"Yes. {pred_id_strs[0]}."
89 | elif len(pred_ids) == 2:
90 | answer = f"Yes. {pred_id_strs[0]} and {pred_id_strs[1]}."
91 | else:
92 | answer = f"Yes. {', '.join(pred_id_strs[:-1])}, and {pred_id_strs[-1]}."
93 | if split == 'train':
94 | new_annos.append({
95 | 'scene_id': scene_id,
96 | 'obj_id': 0,
97 | 'prompt': prompt,
98 | 'caption': answer,
99 | 'eval_type': anno['eval_type']
100 | })
101 | else:
102 | new_annos.append({
103 | 'scene_id': scene_id,
104 | 'obj_id': 0,
105 | 'prompt': prompt,
106 | 'ref_captions': gt_ids,
107 | 'eval_type': anno['eval_type']
108 | })
109 |
110 | print(f"Split: {split}")
111 | print(f"Count all: {len(annos)}", count_all)
112 | print(f"Count used: {len(new_annos)}", count_used)
113 |
114 | with open(f"annotations/multi3dref_{segmentor}_{split}{version}.json", "w") as f:
115 | json.dump(new_annos, f, indent=4)
116 |
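117 | # Summary: for the train split a sample is kept only if every referenced GT object matches a distinct
118 | # predicted instance with IoU >= train_iou_thres; otherwise `flag` is cleared and the sample is skipped.
119 | # For the val split the original GT object ids are stored directly in `ref_captions`.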
--------------------------------------------------------------------------------
/preprocess/prepare_multi3dref_location_annos.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import os
4 | import sys
5 | sys.path.append('.')
6 | from tqdm import tqdm
7 | import argparse
8 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners
9 | from prompts.prompts import multi3dref_location_prompt, ID_format
10 | import random
11 | from collections import defaultdict
12 | import string
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--segmentor', required=True, type=str)
16 | parser.add_argument('--version', type=str, default='')
17 | parser.add_argument('--train_iou_thres', type=float, default=0.75)
18 | args = parser.parse_args()
19 |
20 | segmentor = args.segmentor
21 | version = args.version
22 |
23 | def num_to_location_token(ori_num):
24 |     ori_num = int(ori_num * 100) + 500  # quantize meters to centimeters and shift by +500 so negative coordinates map into [0, 999]
25 | if ori_num < 0:
26 | ori_num = 0
27 | if ori_num > 999:
28 | ori_num = 999
29 | return f""
30 |
31 | for split in ['train', 'val']:
32 | annos = json.load(open(f"annotations/multi3drefer/multi3drefer_{split}.json"))
33 | new_annos = []
34 |
35 | count_all = defaultdict(int)
36 | count_used = defaultdict(int)
37 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt"
38 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt"
39 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu')
40 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu')
41 |
42 | for i, anno in tqdm(enumerate(annos)):
43 | scene_id = anno['scene_id']
44 | count_all[anno['eval_type']] += 1
45 | if scene_id not in instance_attrs:
46 | continue
47 | instance_locs = instance_attrs[scene_id]["locs"]
48 | scannet_locs = scannet_attrs[scene_id]["locs"]
49 | instance_num = instance_locs.shape[0]
50 | gt_ids = anno['object_ids']
51 | caption = anno['description']
52 | if caption[-1] in string.punctuation:
53 | caption = caption[:-1]
54 | prompt = random.choice(multi3dref_location_prompt).replace("", caption)
55 | locs_caption_list = []
56 |         flag = 1  # kept for symmetry with prepare_multi3dref_annos.py; never cleared here because the IoU-matching block below is commented out
57 | for gt_id in gt_ids:
58 | max_iou, max_id = -1, -1
59 | # for pred_id in range(instance_num):
60 | # pred_locs = instance_locs[pred_id].tolist()
61 | # gt_locs = scannet_locs[gt_id].tolist()
62 | # pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:])
63 | # gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:])
64 | # iou = box3d_iou(pred_corners, gt_corners)
65 | # if iou > max_iou:
66 | # max_iou = iou
67 | # max_id = pred_id
68 | gt_locs = scannet_locs[gt_id].tolist()
69 | gt_loc_tokens = [num_to_location_token(x) for x in gt_locs]
70 | tmp_loc_caption = " " + " ".join(gt_loc_tokens) + " "
71 | locs_caption_list.append(tmp_loc_caption)
72 | if flag == 0:
73 | continue
74 | count_used[anno['eval_type']] += 1
75 | locs_caption_list = sorted(locs_caption_list)
76 | # pred_ids = sorted(pred_ids)
77 | # pred_id_strs = [ID_format.format(pred_id) for pred_id in pred_ids]
78 | if len(locs_caption_list) == 0:
79 | answer = "No."
80 | elif len(locs_caption_list) == 1:
81 | answer = f"Yes. {locs_caption_list[0]}."
82 | elif len(locs_caption_list) == 2:
83 | answer = f"Yes. {locs_caption_list[0]} and {locs_caption_list[1]}."
84 | else:
85 | answer = f"Yes. {', '.join(locs_caption_list[:-1])}, and {locs_caption_list[-1]}."
86 | if split == 'train':
87 | new_annos.append({
88 | 'scene_id': scene_id,
89 | 'obj_id': 0,
90 | 'prompt': prompt,
91 | 'caption': answer,
92 | 'eval_type': anno['eval_type']
93 | })
94 | else:
95 | new_annos.append({
96 | 'scene_id': scene_id,
97 | 'obj_id': 0,
98 | 'prompt': prompt,
99 | 'ref_captions': gt_ids,
100 | 'eval_type': anno['eval_type']
101 | })
102 |
103 | print(f"Split: {split}")
104 | print(f"Count all: {len(annos)}", count_all)
105 | print(f"Count used: {len(new_annos)}", count_used)
106 |
107 | with open(f"annotations/multi3dref_{segmentor}_{split}_location{version}.json", "w") as f:
108 | json.dump(new_annos, f, indent=4)
109 |
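110 | # Worked example for num_to_location_token above: an x-coordinate of -1.25 m maps to
111 | # int(-1.25 * 100) + 500 = 375, i.e. location token index 375 out of 1000; coordinates below
112 | # roughly -5 m or above roughly +5 m are clamped to indices 0 and 999 respectively.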
--------------------------------------------------------------------------------
/preprocess/prepare_nr3dcaption_annos.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import sys
4 | sys.path.append('.')
5 | import torch
6 | import random
7 | from tqdm import tqdm
8 | from collections import defaultdict
9 | import argparse
10 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners
11 | from prompts.prompts import nr3d_caption_prompt
12 | import csv
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 |
17 | parser.add_argument('--segmentor', required=True, type=str)
18 | parser.add_argument('--version', type=str, default='')
19 | parser.add_argument('--train_iou_thres', type=float, default=0.75)
20 | args = parser.parse_args()
21 |
22 | segmentor = args.segmentor
23 | version = args.version
24 |
25 | train_scenes = [x.strip() for x in open('annotations/scannet/scannetv2_train.txt').readlines()]
26 | val_scenes = [x.strip() for x in open('annotations/scannet/scannetv2_val.txt').readlines()]
27 | scene_lists = {
28 | 'train': train_scenes,
29 | 'val': val_scenes
30 | }
31 |
32 | raw_annos = []
33 | with open('annotations/referit3d/nr3d.csv') as csvfile:
34 | reader = csv.DictReader(csvfile)
35 | for row in reader:
36 | raw_annos.append({
37 | 'scene_id': row['scan_id'],
38 | 'obj_id': int(row['target_id']),
39 | 'caption': row['utterance']
40 | })
41 |
42 | for split in ["train"]:
43 | annos = [anno for anno in raw_annos if anno['scene_id'] in scene_lists[split]]
44 | new_annos = []
45 |
46 | if segmentor == 'deva':
47 | seg_gt_ious = torch.load(f"annotations/scannet_{segmentor}_seg_gt_ious.pt", map_location='cpu')
48 | else:
49 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt"
50 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt"
51 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu')
52 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu')
53 |
54 | for anno in tqdm(annos):
55 | scene_id = anno['scene_id']
56 | obj_id = anno['obj_id']
57 | if segmentor == 'deva':
58 | if scene_id not in seg_gt_ious:
59 | continue
60 | seg_gt_iou = seg_gt_ious[scene_id]
61 | if obj_id >= seg_gt_iou.shape[1]:
62 | continue
63 | max_iou, max_id = seg_gt_iou[:, obj_id].max(0)
64 | max_iou = float(max_iou)
65 | max_id = int(max_id)
66 | else:
67 | if scene_id not in instance_attrs:
68 | continue
69 | instance_locs = instance_attrs[scene_id]['locs']
70 | scannet_locs = scannet_attrs[scene_id]['locs']
71 | max_iou, max_id = -1, -1
72 | for pred_id in range(instance_locs.shape[0]):
73 | pred_locs = instance_locs[pred_id].tolist()
74 | gt_locs = scannet_locs[obj_id].tolist()
75 | pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:])
76 | gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:])
77 | iou = box3d_iou(pred_corners, gt_corners)
78 | if iou > max_iou:
79 | max_iou = iou
80 | max_id = pred_id
81 | prompt = random.choice(nr3d_caption_prompt).replace('', f"")
82 | if split == 'train':
83 | if max_iou >= args.train_iou_thres:
84 | new_annos.append({
85 | 'scene_id': scene_id,
86 | 'obj_id': obj_id,
87 | 'prompt': prompt,
88 | 'caption': anno['caption']
89 | })
90 | else:
91 | new_annos.append({
92 | 'scene_id': scene_id,
93 | 'obj_id': obj_id,
94 | 'prompt': prompt,
95 | 'ref_captions': [anno['caption']]
96 | })
97 | print(len(new_annos))
98 |
99 | with open(f"annotations/nr3d_caption_{segmentor}_{split}{version}.json", 'w') as f:
100 | json.dump(new_annos, f, indent=4)
101 |
--------------------------------------------------------------------------------
/preprocess/prepare_objalign_annos.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import sys
4 | sys.path.append('.')
5 | import torch
6 | import random
7 | from tqdm import tqdm
8 | from collections import defaultdict
9 | import argparse
10 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners
11 |
12 |
13 | parser = argparse.ArgumentParser()
14 |
15 | parser.add_argument('--segmentor', required=True, type=str)
16 | parser.add_argument('--version', type=str, default='')
17 | parser.add_argument('--train_iou_thres', type=float, default=0.75)
18 | args = parser.parse_args()
19 |
20 | unwanted_words = ["wall", "ceiling", "floor", "object", "item"]
21 |
22 | segmentor = args.segmentor
23 | version = args.version
24 |
25 | for split in ["train", "val"]:
26 | new_annos = []
27 |
28 | if segmentor == 'deva':
29 | seg_gt_ious = torch.load(f"annotations/scannet_{segmentor}_seg_gt_ious.pt", map_location='cpu')
30 | else:
31 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt"
32 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu')
33 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt"
34 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu')
35 |
36 | for scene_id in tqdm(scannet_attrs.keys()):
37 | if segmentor == 'deva':
38 | if scene_id not in seg_gt_ious:
39 | continue
40 | seg_gt_iou = seg_gt_ious[scene_id]
41 | segmented_num = seg_gt_iou.shape[0]
42 | gt_num = seg_gt_iou.shape[1]
43 | else:
44 | if scene_id not in instance_attrs:
45 | continue
46 | instance_locs = instance_attrs[scene_id]['locs']
47 | scannet_locs = scannet_attrs[scene_id]['locs']
48 | segmented_num = len(instance_locs)
49 | gt_num = len(scannet_locs)
50 | scannet_class_labels = scannet_attrs[scene_id]['objects']
51 | for obj_id in range(gt_num):
52 | class_label = scannet_class_labels[obj_id]
53 | if any(x in class_label for x in unwanted_words):
54 | continue
55 | if segmentor == 'deva':
56 | max_iou, max_id = seg_gt_iou[:, obj_id].max(0)
57 | max_iou = float(max_iou)
58 | max_id = int(max_id)
59 | else:
60 | max_iou, max_id = -1, -1
61 | for pred_id in range(instance_locs.shape[0]):
62 | pred_locs = instance_locs[pred_id].tolist()
63 | gt_locs = scannet_locs[obj_id].tolist()
64 | pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:])
65 | gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:])
66 | iou = box3d_iou(pred_corners, gt_corners)
67 | if iou > max_iou:
68 | max_iou = iou
69 | max_id = pred_id
70 | prompt = f"What is the ?"
71 | caption = f" is a {class_label}."
72 | if split == 'train':
73 | if max_iou >= args.train_iou_thres:
74 | new_annos.append({
75 | 'scene_id': scene_id,
76 | 'obj_id': obj_id,
77 | 'prompt': prompt,
78 | 'caption': caption
79 | })
80 | else:
81 | new_annos.append({
82 | 'scene_id': scene_id,
83 | 'obj_id': obj_id,
84 | 'prompt': prompt,
85 | 'ref_captions': [caption]
86 | })
87 |
88 | print(len(new_annos))
89 | with open(f"annotations/obj_align_{segmentor}_{split}{version}.json", 'w') as f:
90 | json.dump(new_annos, f, indent=4)
--------------------------------------------------------------------------------
/preprocess/prepare_scan2cap_location_annos.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import sys
4 | sys.path.append('.')
5 | import torch
6 | import random
7 | from tqdm import tqdm
8 | from collections import defaultdict
9 | import argparse
10 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners
11 | from prompts.prompts import scan2cap_location_prompt
12 | import nltk
13 |
14 |
15 | def capitalize_sentences(text):
16 | sentences = nltk.sent_tokenize(text)
17 | capitalized_sentences = [sentence.capitalize() for sentence in sentences]
18 | result = ' '.join(capitalized_sentences)
19 | return result
20 |
21 |
22 | parser = argparse.ArgumentParser()
23 |
24 | parser.add_argument('--segmentor', required=True, type=str)
25 | parser.add_argument('--version', type=str, default='')
26 | parser.add_argument('--train_iou_thres', type=float, default=0.75)
27 | parser.add_argument('--max_obj_num', type=int, default=150)
28 | args = parser.parse_args()
29 |
30 |
31 | def num_to_location_token(ori_num):
32 | ori_num = int(ori_num * 100) + 500
33 | if ori_num < 0:
34 | ori_num = 0
35 | if ori_num > 999:
36 | ori_num = 999
37 | return f""
38 |
39 |
40 | for split in ["train", "val"]:
41 | segmentor = args.segmentor
42 | version = args.version
43 | annos = json.load(open(f"annotations/scanrefer/ScanRefer_filtered_{split}.json", "r"))
44 | new_annos = []
45 |
46 | print(len(annos))
47 | scene_ids = set()
48 | corpus = defaultdict(list)
49 | for anno in annos:
50 | gt_key = f"{anno['scene_id']}|{anno['object_id']}"
51 | description = capitalize_sentences(anno['description'])
52 | corpus[gt_key].append(description)
53 | scene_ids.add(anno['scene_id'])
54 | scene_ids = list(scene_ids)
55 |
56 | count = [0] * args.max_obj_num
57 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt"
58 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt"
59 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu')
60 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu')
61 |
62 |
63 | covered25_num, covered50_num = 0, 0
64 | count_all = 0
65 | for scene_id in tqdm(scene_ids):
66 | if scene_id not in instance_attrs:
67 | continue
68 | instance_locs = instance_attrs[scene_id]["locs"]
69 | scannet_locs = scannet_attrs[scene_id]["locs"]
70 | segmented_num = len(instance_locs)
71 | gt_num = len(scannet_locs)
72 | gt_match_id = [-1] * gt_num
73 | gt_match_iou = [-1] * gt_num
74 | for pred_id in range(segmented_num):
75 | pred_locs = instance_locs[pred_id].tolist()
76 | pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:])
77 | max_id = max_iou = -1
78 | for gt_id in range(len(scannet_locs)):
79 | if f"{scene_id}|{gt_id}" not in corpus:
80 | continue
81 | gt_locs = scannet_locs[gt_id].tolist()
82 | gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:])
83 | iou = box3d_iou(pred_corners, gt_corners)
84 | if iou > max_iou:
85 | max_iou = iou
86 | max_id = gt_id
87 | if f"{scene_id}|{max_id}" not in corpus:
88 | continue
89 | if max_iou > gt_match_iou[max_id]:
90 | gt_match_iou[max_id] = max_iou
91 | gt_match_id[max_id] = pred_id
92 | for gt_id, pred_id in enumerate(gt_match_id):
93 | if f"{scene_id}|{gt_id}" in corpus:
94 | count_all += len(corpus[f"{scene_id}|{gt_id}"])
95 | if pred_id == -1:
96 | continue
97 | if split == 'train' and gt_match_iou[gt_id] < args.train_iou_thres:
98 | continue
99 | if gt_match_iou[gt_id] >= 0.25:
100 | covered25_num += len(corpus[f"{scene_id}|{gt_id}"])
101 | if gt_match_iou[gt_id] >= 0.5:
102 | covered50_num += len(corpus[f"{scene_id}|{gt_id}"])
103 | count[pred_id] += 1
104 | pred_locs = instance_locs[pred_id].tolist()
105 | loc_tokens = [num_to_location_token(x) for x in pred_locs]
106 | loc_caption = " " + " ".join(loc_tokens) + " "
107 | if split == 'train':
108 | for caption in corpus[f"{scene_id}|{gt_id}"]:
109 | new_annos.append({
110 | 'scene_id': scene_id,
111 | 'obj_id': gt_id,
112 | 'pred_id': pred_id,
113 | 'prompt': random.choice(scan2cap_location_prompt).replace(f"", loc_caption),
114 | "caption": caption,
115 | "iou": gt_match_iou[gt_id]
116 | })
117 | else:
118 | new_annos.append({
119 | 'scene_id': scene_id,
120 | 'obj_id': gt_id,
121 | 'pred_id': pred_id,
122 | 'prompt': random.choice(scan2cap_location_prompt).replace(f"", loc_caption),
123 | "ref_captions": corpus[f"{scene_id}|{gt_id}"],
124 | "iou": gt_match_iou[gt_id]
125 | })
126 |
127 | print(len(new_annos))
128 | print(covered25_num, covered50_num)
129 | print(count_all)
130 | # print(count)
131 |
132 | with open(f"annotations/scan2cap_{segmentor}_{split}_location{version}.json", "w") as f:
133 | json.dump(new_annos, f, indent=4)
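134 | 
135 | # Matching summary: each predicted instance votes for the annotated GT object it overlaps most, and every
136 | # GT object keeps only its highest-IoU prediction (gt_match_id / gt_match_iou); the Scan2Cap captions are
137 | # then attached to that prediction's quantized location string.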
--------------------------------------------------------------------------------
/preprocess/prepare_scannet_attributes.py:
--------------------------------------------------------------------------------
1 | from plyfile import PlyData
2 | import numpy as np
3 | import os
4 | import json
5 | import torch
6 | from collections import defaultdict
7 | from tqdm import tqdm
8 | import argparse
9 |
10 | parser = argparse.ArgumentParser()
11 |
12 | parser.add_argument('--scannet_dir', required=True, type=str,
13 | help='the path of the directory to scannet scans')
14 | args = parser.parse_args()
15 |
16 | raw_data_dir = os.path.join(args.scannet_dir, 'scans')
17 |
18 |
19 | for split in ["train", "val"]:
20 | split_file = f"annotations/scannet/scannetv2_{split}.txt"
21 | scan_names = [line.rstrip() for line in open(split_file)]
22 | print(f'{split} split scans: {len(scan_names)}')
23 | outputs = {}
24 | for scan_id in tqdm(scan_names):
25 | aggregation_path = os.path.join(raw_data_dir, scan_id, scan_id + '.aggregation.json')
26 | segs_path = os.path.join(raw_data_dir, scan_id, scan_id + '_vh_clean_2.0.010000.segs.json')
27 | scan_ply_path = os.path.join(raw_data_dir, scan_id, scan_id + '_vh_clean_2.labels.ply')
28 |
29 | data = PlyData.read(scan_ply_path)
30 | x = np.asarray(data.elements[0].data['x']).astype(np.float32)
31 | y = np.asarray(data.elements[0].data['y']).astype(np.float32)
32 | z = np.asarray(data.elements[0].data['z']).astype(np.float32)
33 | pc = np.stack([x, y, z], axis=1)
34 |
35 | align_matrix = np.eye(4)
36 | with open(os.path.join(raw_data_dir, scan_id, '%s.txt'%(scan_id)), 'r') as f:
37 | for line in f:
38 | if line.startswith('axisAlignment'):
39 | align_matrix = np.array([float(x) for x in line.strip().split()[-16:]]).astype(np.float32).reshape(4, 4)
40 | break
41 |
42 | pts = np.ones((pc.shape[0], 4), dtype=pc.dtype)
43 | pts[:, 0:3] = pc
44 | pc = np.dot(pts, align_matrix.transpose())[:, :3]
45 |
46 | scan_aggregation = json.load(open(aggregation_path))
47 | segments_info = json.load(open(segs_path))
48 | segment_indices = segments_info["segIndices"]
49 | segment_indices_dict = defaultdict(list)
50 | for i, s in enumerate(segment_indices):
51 | segment_indices_dict[s].append(i)
52 |
53 | pc_segment_label = [''] * pc.shape[0]
54 |
55 | instance_labels = []
56 | inst_locs = []
57 | for idx, object_info in enumerate(scan_aggregation['segGroups']):
58 | object_instance_label = object_info['label']
59 | object_id = object_info['objectId']
60 | segments = object_info["segments"]
61 | pc_ids = []
62 | for s in segments:
63 | pc_ids.extend(segment_indices_dict[s])
64 | object_pc = pc[pc_ids]
65 | object_center = (np.max(object_pc, axis=0) + np.min(object_pc, axis=0)) / 2.0
66 | object_size = np.max(object_pc, axis=0) - np.min(object_pc, axis=0)
67 | object_bbox = torch.from_numpy(np.concatenate([object_center, object_size], axis=0))
68 | inst_locs.append(object_bbox)
69 | instance_labels.append(object_instance_label)
70 | inst_locs = torch.stack(inst_locs, dim=0)
71 | outputs[scan_id] = {
72 | 'objects': instance_labels,
73 | 'locs': inst_locs
74 | }
75 |
76 | torch.save(outputs, f"annotations/scannet_{split}_attributes.pt")
77 |
78 |
--------------------------------------------------------------------------------
/preprocess/prepare_scannet_attributes_clasp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | import os
4 | import glob
5 | import numpy as np
6 | import argparse
7 | from tqdm import tqdm
8 |
9 | parser = argparse.ArgumentParser()
10 |
11 | parser.add_argument('--scan_dir', required=True, type=str,
12 |                     help='the path of the directory containing the preprocessed scans')
13 | parser.add_argument('--segmentor', required=True, type=str)
14 | parser.add_argument('--max_inst_num', required=True, type=int)
15 | parser.add_argument('--version', type=str, default='')
16 | args = parser.parse_args()
17 |
18 |
19 | save_feats = {}
20 | for split in ["val"]:
21 | scan_dir = args.scan_dir
22 | output_dir = "annotations"
23 | split_path = f"annotations/scannet/scannetv2_{split}.txt"
24 |
25 | scan_ids = [line.strip() for line in open(split_path).readlines()]
26 |
27 | scan_ids = sorted(scan_ids)
28 | # print(scan_ids)
29 |
30 | scans = {}
31 | for scan_id in scan_ids:
32 | pcd_path = os.path.join(scan_dir, f"{scan_id}.pth")
33 | if not os.path.exists(pcd_path):
34 | # print('skip', scan_id)
35 | continue
36 | pred_results = torch.load(pcd_path, map_location='cpu')
37 | inst_locs = []
38 | pred_boxes = pred_results['pred_boxes']
39 | num_insts = pred_boxes.shape[0]
40 | for i in range(min(num_insts, args.max_inst_num)):
41 | center = pred_boxes[i].mean(dim=0)
42 | size = pred_boxes[i][1] - pred_boxes[i][0]
43 | inst_locs.append(torch.cat([center, size], 0))
44 | save_feats[f"{scan_id}_{i:02}"] = pred_results['queries'][0][i]
45 | inst_locs = torch.stack(inst_locs, dim=0).to(torch.float32)
46 | scans[scan_id] = {
47 | # 'objects': instance_class_labels, # (n_obj, )
48 | 'locs': inst_locs, # (n_obj, 6) center xyz, whl
49 | }
50 | print(f"{split}: {len(scans)}")
51 |
52 | torch.save(scans, os.path.join(output_dir, f"scannet_{args.segmentor}_{split}_attributes{args.version}.pt"))
53 |
54 | # torch.save(save_feats, os.path.join(output_dir, f"scannet_{args.segmentor}_{args.segmentor}_feats.pt"))
--------------------------------------------------------------------------------
/preprocess/prepare_scannet_caption_annos.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import sys
4 | sys.path.append('.')
5 | import torch
6 | import random
7 | from tqdm import tqdm
8 | from collections import defaultdict
9 | import argparse
10 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners
11 | from prompts.prompts import obj_caption_wid_prompt
12 |
13 | parser = argparse.ArgumentParser()
14 |
15 | parser.add_argument('--segmentor', required=True, type=str)
16 | parser.add_argument('--version', type=str, default='')
17 | parser.add_argument('--train_iou_thres', type=float, default=0.75)
18 | args = parser.parse_args()
19 |
20 | segmentor = args.segmentor
21 | version = args.version
22 |
23 | for split in ['train']:
24 | annos = json.load(open(f"annotations/scannet_{split}_caption.json"))
25 | new_annos = []
26 |
27 | if segmentor == 'deva':
28 | seg_gt_ious = torch.load(f"annotations/scannet_{segmentor}_seg_gt_ious.pt", map_location='cpu')
29 | else:
30 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt"
31 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt"
32 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu')
33 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu')
34 |
35 | for i, anno in tqdm(enumerate(annos)):
36 | scene_id = anno['scene_id']
37 | obj_id = anno['obj_id']
38 | if segmentor == 'deva':
39 | if scene_id not in seg_gt_ious:
40 | continue
41 | seg_gt_iou = seg_gt_ious[scene_id]
42 | if obj_id >= seg_gt_iou.shape[1]:
43 | continue
44 | max_iou, max_id = seg_gt_iou[:, obj_id].max(0)
45 | max_iou = float(max_iou)
46 | max_id = int(max_id)
47 | else:
48 | if scene_id not in instance_attrs:
49 | continue
50 | instance_locs = instance_attrs[scene_id]["locs"]
51 | scannet_locs = scannet_attrs[scene_id]["locs"]
52 | instance_num = instance_locs.shape[0]
53 | max_iou, max_id = -1, -1
54 | for pred_id in range(instance_num):
55 | pred_locs = instance_locs[pred_id].tolist()
56 | gt_locs = scannet_locs[obj_id].tolist()
57 | pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:])
58 | gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:])
59 | iou = box3d_iou(pred_corners, gt_corners)
60 | if iou > max_iou:
61 | max_iou = iou
62 | max_id = pred_id
63 | if split == 'train':
64 | if max_iou > args.train_iou_thres:
65 | new_annos.append({
66 | 'scene_id': scene_id,
67 | 'obj_id': obj_id,
68 | 'prompt': random.choice(obj_caption_wid_prompt).replace('', f""),
69 | 'caption': anno['caption']
70 | })
71 | else:
72 | new_annos.append({
73 | 'scene_id': scene_id,
74 | 'obj_id': obj_id,
75 | 'prompt': random.choice(obj_caption_wid_prompt).replace('', f""),
76 | 'ref_captions': anno['ref_captions']
77 | })
78 |
79 | print(f"Split: {split}")
80 | print(f"{len(annos)} -> {len(new_annos)}")
81 |
82 | with open(f"annotations/scannet_caption_{segmentor}_{split}{version}.json", "w") as f:
83 | json.dump(new_annos, f, indent=4)
84 |
--------------------------------------------------------------------------------
/preprocess/prepare_scannet_mask3d_attributes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | import os
4 | import glob
5 | import numpy as np
6 | import argparse
7 | from tqdm import tqdm
8 |
9 | parser = argparse.ArgumentParser()
10 |
11 | parser.add_argument('--scan_dir', required=True, type=str,
12 |                     help='the path of the directory containing the preprocessed scans')
13 | parser.add_argument('--segmentor', required=True, type=str)
14 | parser.add_argument('--max_inst_num', required=True, type=int)
15 | parser.add_argument('--version', type=str, default='')
16 | args = parser.parse_args()
17 |
18 |
19 |
20 | for split in ["train", "val"]:
21 | scan_dir = os.path.join(args.scan_dir, 'pcd_all')
22 | output_dir = "annotations"
23 | split_path = f"annotations/scannet/scannetv2_{split}.txt"
24 |
25 | scan_ids = [line.strip() for line in open(split_path).readlines()]
26 |
27 | scan_ids = sorted(scan_ids)
28 | # print(scan_ids)
29 |
30 | scans = {}
31 | for scan_id in tqdm(scan_ids):
32 | pcd_path = os.path.join(scan_dir, f"{scan_id}.pth")
33 | if not os.path.exists(pcd_path):
34 | print('skip', scan_id)
35 | continue
36 | points, colors, instance_class_labels, instance_segids = torch.load(pcd_path)
37 | inst_locs = []
38 | num_insts = len(instance_class_labels)
39 | for i in range(min(num_insts, args.max_inst_num)):
40 | inst_mask = instance_segids[i]
41 | pc = points[inst_mask]
42 | if len(pc) == 0:
43 | print(scan_id, i, 'empty bbox')
44 | inst_locs.append(np.zeros(6, ).astype(np.float32))
45 | continue
46 | size = pc.max(0) - pc.min(0)
47 | center = (pc.max(0) + pc.min(0)) / 2
48 | inst_locs.append(np.concatenate([center, size], 0))
49 | inst_locs = torch.tensor(np.stack(inst_locs, 0), dtype=torch.float32)
50 | scans[scan_id] = {
51 | 'objects': instance_class_labels, # (n_obj, )
52 | 'locs': inst_locs, # (n_obj, 6) center xyz, whl
53 | }
54 |
55 | torch.save(scans, os.path.join(output_dir, f"scannet_{args.segmentor}_{split}_attributes{args.version}.pt"))
--------------------------------------------------------------------------------
/preprocess/prepare_scanqa_annos.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | for split in ['train', 'val']:
4 |     if split == 'test':  # only taken if 'test' is added to the split list above
5 | with open(f"annotations/scanqa/ScanQA_v1.0_test_w_obj.json", "r") as f:
6 | annos = json.load(f)
7 | with open(f"annotations/scanqa/ScanQA_v1.0_test_wo_obj.json", "r") as f:
8 | annos.extend(json.load(f))
9 | else:
10 | with open(f"annotations/scanqa/ScanQA_v1.0_{split}.json", "r") as f:
11 | annos = json.load(f)
12 | print(len(annos))
13 | new_annos = []
14 | for anno in annos:
15 | scene_id = anno["scene_id"]
16 | obj_ids = anno["object_ids"] if "object_ids" in anno else [0]
17 | question = anno["question"]
18 |
19 | prompt = question + " Answer the question using a single word or phrase."
20 |
21 | answers = anno["answers"] if "answers" in anno else []
22 | if split == "train":
23 | for i in range(len(answers)):
24 | if i > 0 and answers[i] == answers[i-1]:
25 | continue
26 | answer = answers[i]
27 | answer = answer.capitalize()
28 | if answer[-1] != ".":
29 | answer += "."
30 | new_annos.append({
31 | "scene_id": scene_id,
32 | "obj_id": obj_ids[0],
33 | "prompt": prompt,
34 | "caption": answer,
35 | })
36 | else:
37 | new_annos.append({
38 | "scene_id": scene_id,
39 | "obj_id": obj_ids[0],
40 | "prompt": prompt,
41 | "ref_captions": answers
42 | })
43 | print(len(new_annos))
44 |
45 | with open(f"annotations/scanqa_{split}.json", "w") as f:
46 | json.dump(new_annos, f, indent=4)
47 |
--------------------------------------------------------------------------------
/preprocess/prepare_scanrefer_annos.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import sys
4 | sys.path.append('.')
5 | import torch
6 | import random
7 | from tqdm import tqdm
8 | from collections import defaultdict
9 | import argparse
10 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners
11 | from prompts.prompts import grounding_prompt
12 | import string
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 |
17 | parser.add_argument('--segmentor', required=True, type=str)
18 | parser.add_argument('--version', type=str, default='')
19 | parser.add_argument('--train_iou_thres', type=float, default=0.75)
20 | parser.add_argument('--max_obj_num', type=int, default=150)
21 | args = parser.parse_args()
22 |
23 | segmentor = args.segmentor
24 | version = args.version
25 |
26 | for split in ["train", "val"]:
27 | count = [0] * args.max_obj_num
28 | annos = json.load(open(f"annotations/scanrefer/ScanRefer_filtered_{split}.json", "r"))
29 | annos = sorted(annos, key=lambda p: f"{p['scene_id']}_{int(p['object_id']):03}")
30 | new_annos = []
31 |
32 | if segmentor == 'deva':
33 | seg_gt_ious = torch.load(f"annotations/scannet_{segmentor}_seg_gt_ious.pt", map_location='cpu')
34 | else:
35 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt"
36 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt"
37 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu')
38 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu')
39 |
40 | iou25_count = 0
41 | iou50_count = 0
42 | # maxiou_count = 0
43 | # valid_count = 0
44 | for i, anno in tqdm(enumerate(annos)):
45 | scene_id = anno['scene_id']
46 | obj_id = int(anno['object_id'])
47 | desc = anno['description']
48 | if desc[-1] in string.punctuation:
49 | desc = desc[:-1]
50 | prompt = random.choice(grounding_prompt).replace('', desc)
51 |
52 | if segmentor == 'deva':
53 | if scene_id not in seg_gt_ious:
54 | continue
55 | seg_gt_iou = seg_gt_ious[scene_id]
56 | if obj_id >= seg_gt_iou.shape[1]:
57 | continue
58 | max_iou, max_id = seg_gt_iou[:, obj_id].max(0)
59 | max_iou = float(max_iou)
60 | max_id = int(max_id)
61 | else:
62 | if scene_id not in instance_attrs:
63 | continue
64 | instance_locs = instance_attrs[scene_id]["locs"]
65 | scannet_locs = scannet_attrs[scene_id]["locs"]
66 | instance_num = instance_locs.shape[0]
67 | max_iou, max_id = -1, -1
68 | for pred_id in range(instance_num):
69 | pred_locs = instance_locs[pred_id].tolist()
70 | gt_locs = scannet_locs[obj_id].tolist()
71 | pred_corners = construct_bbox_corners(pred_locs[:3], pred_locs[3:])
72 | gt_corners = construct_bbox_corners(gt_locs[:3], gt_locs[3:])
73 | iou = box3d_iou(pred_corners, gt_corners)
74 | if iou > max_iou:
75 | max_iou = iou
76 | max_id = pred_id
77 | # maxiou_count += max_iou
78 | # valid_count += 1
79 | if max_iou >= 0.25:
80 | iou25_count += 1
81 | if max_iou >= 0.5:
82 | iou50_count += 1
83 | count[max_id] += 1
84 | if split == "train":
85 | if max_iou >= args.train_iou_thres:
86 | new_annos.append({
87 | "scene_id": scene_id,
88 | "obj_id": max_id,
89 | "caption": f".",
90 | "prompt": prompt
91 | })
92 | else:
93 | new_annos.append({
94 | "scene_id": scene_id,
95 | "obj_id": obj_id,
96 | "ref_captions": [f"."],
97 | "prompt": prompt
98 | })
99 |
100 | print(len(new_annos))
101 | print(count)
102 | # print(maxiou_count / valid_count)
103 | # print(f"max iou@0.25: {iou25_count / len(new_annos)}")
104 | # print(f"max iou@0.5: {iou50_count / len(new_annos)}")
105 |
106 | with open(f"annotations/scanrefer_{segmentor}_{split}{version}.json", "w") as f:
107 | json.dump(new_annos, f, indent=4)
--------------------------------------------------------------------------------
/preprocess/prepare_scanrefer_location_annos.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import sys
4 | sys.path.append('.')
5 | import torch
6 | import random
7 | from tqdm import tqdm
8 | from collections import defaultdict
9 | import argparse
10 | from utils.box_utils import get_box3d_min_max, box3d_iou, construct_bbox_corners
11 | from prompts.prompts import grounding_location_prompt
12 | import string
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 |
17 | parser.add_argument('--segmentor', required=True, type=str)
18 | parser.add_argument('--version', type=str, default='')
19 | parser.add_argument('--train_iou_thres', type=float, default=0.75)
20 | parser.add_argument('--max_obj_num', type=int, default=150)
21 | args = parser.parse_args()
22 |
23 | segmentor = args.segmentor
24 | version = args.version
25 |
26 | def num_to_location_token(ori_num):
27 | ori_num = int(ori_num * 100) + 500
28 | if ori_num < 0:
29 | ori_num = 0
30 | if ori_num > 999:
31 | ori_num = 999
32 | return f""
33 |
34 | for split in ["train", "val"]:
35 | count = [0] * args.max_obj_num
36 | annos = json.load(open(f"annotations/scanrefer/ScanRefer_filtered_{split}.json", "r"))
37 | annos = sorted(annos, key=lambda p: f"{p['scene_id']}_{int(p['object_id']):03}")
38 | new_annos = []
39 |
40 | instance_attribute_file = f"annotations/scannet_{segmentor}_{split}_attributes{version}.pt"
41 | scannet_attribute_file = f"annotations/scannet_{split}_attributes.pt"
42 | instance_attrs = torch.load(instance_attribute_file, map_location='cpu')
43 | scannet_attrs = torch.load(scannet_attribute_file, map_location='cpu')
44 |
45 | iou25_count = 0
46 | iou50_count = 0
47 | # maxiou_count = 0
48 | # valid_count = 0
49 | for i, anno in tqdm(enumerate(annos)):
50 | scene_id = anno['scene_id']
51 | obj_id = int(anno['object_id'])
52 | desc = anno['description']
53 | if desc[-1] in string.punctuation:
54 | desc = desc[:-1]
55 | prompt = random.choice(grounding_location_prompt).replace('', desc)
56 |
57 | if scene_id not in instance_attrs:
58 | continue
59 | scannet_locs = scannet_attrs[scene_id]["locs"]
60 | gt_locs = scannet_locs[obj_id].tolist()
61 |
62 | gt_loc_tokens = [num_to_location_token(x) for x in gt_locs]
63 | caption = " " + " ".join(gt_loc_tokens) + " "
64 |
65 | if split == "train":
66 | new_annos.append({
67 | "scene_id": scene_id,
68 | "obj_id": obj_id,
69 | "caption": caption,
70 | "prompt": prompt
71 | })
72 | else:
73 | new_annos.append({
74 | "scene_id": scene_id,
75 | "obj_id": obj_id,
76 | "ref_captions": [caption],
77 | "prompt": prompt
78 | })
79 |
80 | print(len(new_annos))
81 | print(count)
82 | # print(maxiou_count / valid_count)
83 | # print(f"max iou@0.25: {iou25_count / len(new_annos)}")
84 | # print(f"max iou@0.5: {iou50_count / len(new_annos)}")
85 |
86 | with open(f"annotations/scanrefer_{segmentor}_{split}_location{version}.json", "w") as f:
87 | json.dump(new_annos, f, indent=4)
--------------------------------------------------------------------------------
/preprocess/prepare_sqa3d_annos.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import os
4 | import nltk
5 | import random
6 | from tqdm import tqdm
7 |
8 | anno_dir = 'annotations/sqa3d'
9 |
10 |
11 | def convert_person_view(sentence):
12 | # first-person view to second-person view
13 | forms = {'i': 'you', 'me': 'you', 'my': 'your', 'mine': 'yours', 'am': 'are'}
14 | def translate(word):
15 | if word.lower() in forms:
16 | return forms[word.lower()]
17 | return word
18 | result = ' '.join([translate(word) for word in nltk.wordpunct_tokenize(sentence)])
19 | return result.capitalize()
20 |
21 |
22 | def get_sqa_question_type(question):
23 | question = question.lstrip()
24 | if question[:4].lower() == 'what':
25 | return 0
26 | elif question[:2].lower() == 'is':
27 | return 1
28 | elif question[:3].lower() == 'how':
29 | return 2
30 | elif question[:3].lower() == 'can':
31 | return 3
32 | elif question[:5].lower() == 'which':
33 | return 4
34 | else:
35 | return 5 # others
36 |
37 |
38 | for split in ['train', 'val']:
39 | scan_ids = []
40 | sqa_annos = []
41 | question_file = os.path.join(anno_dir, f'v1_balanced_questions_{split}_scannetv2.json')
42 | with open(question_file, 'r', encoding='utf-8') as f:
43 | question_data = json.load(f)['questions']
44 | question_map = {}
45 | for item in question_data:
46 | question_map[item['question_id']] = {
47 | 's': [item['situation']] + item['alternative_situation'], # list of str
48 | 'q': item['question'], # str
49 | }
50 |
51 | anno_file = os.path.join(anno_dir, f'v1_balanced_sqa_annotations_{split}_scannetv2.json')
52 | with open(anno_file, 'r', encoding='utf-8') as f:
53 | anno_data = json.load(f)['annotations']
54 | for item in tqdm(anno_data):
55 | scan_ids.append(item['scene_id'])
56 | scene_id = item['scene_id']
57 | obj_id = 0
58 | situation = random.choice(question_map[item['question_id']]['s'])
59 | question = question_map[item['question_id']]['q']
60 | question_type = get_sqa_question_type(question)
61 | prompt = situation + ' ' + question + " Answer the question using a single word or phrase."
62 | answers = [meta['answer'] for meta in item['answers']]
63 | if split == 'train':
64 | answer = random.choice(answers)
65 | answer = answer.capitalize()
66 | if answer[-1] != ".":
67 | answer += "."
68 | sqa_annos.append({
69 | 'scene_id': scene_id,
70 | 'obj_id': obj_id,
71 | 'prompt': prompt,
72 | 'caption': answer,
73 | 'sqa_type': question_type
74 | })
75 | else:
76 | sqa_annos.append({
77 | 'scene_id': scene_id,
78 | 'obj_id': obj_id,
79 | 'prompt': prompt,
80 | 'ref_captions': answers,
81 | 'sqa_type': question_type
82 | })
83 | with open(f"annotations/sqa3d_{split}.json", "w") as f:
84 | json.dump(sqa_annos, f, indent=4)
85 |
86 |
87 |
88 |
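89 | # Example of the first-person helper above (currently not called in this script):
90 | #   convert_person_view("I am facing my desk") -> "You are facing your desk"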
--------------------------------------------------------------------------------
/preprocess/process_scannet_data.py:
--------------------------------------------------------------------------------
1 | from plyfile import PlyData
2 | import numpy as np
3 | import os
4 | import json
5 | from pytorch3d.io import load_obj
6 | import torch
7 | from collections import defaultdict
8 | from tqdm import tqdm
9 | import mmengine
10 |
11 | data_root = '/mnt/petrelfs/share_data/huanghaifeng/maoxiaohan/ScanNet_v2'
12 | raw_data_dir = os.path.join(data_root, 'scans')
13 | meta_data_dir = os.path.join(data_root, 'meta_data')
14 | output_dir = '/mnt/petrelfs/share_data/huanghaifeng/data/processed/scannet'
15 | idx2class = json.load(open('/mnt/petrelfs/share_data/huanghaifeng/referit3d/referit3d/data/mappings/scannet_idx_to_semantic_class.json'))
16 | idx2class = {int(k): v for k, v in idx2class.items()}
17 | class2idx = {v: k for k, v in idx2class.items()}
18 | scan2axis_align = json.load(open('/mnt/petrelfs/share_data/huanghaifeng/data/processed/scannet/scans_axis_alignment_matrices.json'))
19 |
20 |
21 | def process_one_scan(scan_id):
22 | save_dir = os.path.join(output_dir, 'scans', scan_id)
23 | if os.path.exists(os.path.join(save_dir, 'object_infos.json')):
24 | return
25 | # label_path = os.path.join(raw_data_dir, scan_id, "labels.instances.annotated.v2.ply")
26 | aggregation_path = os.path.join(raw_data_dir, scan_id, scan_id + '.aggregation.json')
27 | segs_path = os.path.join(raw_data_dir, scan_id, scan_id + '_vh_clean_2.0.010000.segs.json')
28 | scan_ply_path = os.path.join(raw_data_dir, scan_id, scan_id + '_vh_clean_2.labels.ply')
29 |
30 | data = PlyData.read(scan_ply_path)
31 | x = np.asarray(data.elements[0].data['x']).astype(np.float32)
32 | y = np.asarray(data.elements[0].data['y']).astype(np.float32)
33 | z = np.asarray(data.elements[0].data['z']).astype(np.float32)
34 | pc = np.stack([x, y, z], axis=1)
35 |
36 | axis_align_matrix = np.array(scan2axis_align[scan_id], dtype=np.float32).reshape(4, 4)
37 | pts = np.ones((pc.shape[0], 4), dtype=pc.dtype)
38 | pts[:, :3] = pc
39 | pc = np.dot(pts, axis_align_matrix.transpose())[:, :3]
40 |
41 | scan_aggregation = json.load(open(aggregation_path))
42 | segments_info = json.load(open(segs_path))
43 | segment_indices = segments_info["segIndices"]
44 | segment_indices_dict = defaultdict(list)
45 | for i, s in enumerate(segment_indices):
46 | segment_indices_dict[s].append(i)
47 |
48 |     pc_instance_id = np.full(pc.shape[0], -1, dtype=np.int32)  # -1 marks unassigned points (np.zeros(...) * -1 never yields -1)
49 | # pc_semantic_label_id = np.zeros(pc.shape[0]).astype(np.int32) * -1
50 | pc_segment_label = [''] * pc.shape[0]
51 |
52 | valid_ids = []
53 | all_objects = []
54 | for idx, object_info in enumerate(scan_aggregation['segGroups']):
55 | object_instance_label = object_info['label']
56 | object_id = object_info['objectId']
57 | segments = object_info["segments"]
58 | valid_ids.append(idx)
59 | pc_ids = []
60 | for s in segments:
61 | pc_ids.extend(segment_indices_dict[s])
62 | pc_instance_id[pc_ids] = object_id
63 | object_pc = pc[pc_ids]
64 | object_center = (np.max(object_pc, axis=0) + np.min(object_pc, axis=0)) / 2.0
65 | object_size = np.max(object_pc, axis=0) - np.min(object_pc, axis=0)
66 | object_bbox = np.concatenate([object_center, object_size], axis=0)
67 | all_objects.append({
68 | 'bbox': object_bbox.tolist(),
69 | 'label': object_instance_label
70 | })
71 | object_infos = {
72 | 'valid_ids': valid_ids,
73 | 'object_list': all_objects
74 | }
75 |
76 | save_dir = os.path.join(output_dir, 'scans', scan_id)
77 | if not os.path.exists(save_dir):
78 | os.makedirs(save_dir)
79 | with open(os.path.join(save_dir, 'object_infos.json'), 'w') as f:
80 | json.dump(object_infos, f, indent=4)
81 | # np.save(os.path.join(save_dir, 'object_infos.npy'), object_infos)
82 | # np.save(os.path.join(save_dir, 'axis_align_matrix.npy'), axis_align_matrix)
83 |
84 |
85 | def process_split(split):
86 | assert split in ['train', 'val', 'test']
87 | split_file = os.path.join(meta_data_dir, f'scannetv2_{split}.txt')
88 | scan_names = [line.rstrip() for line in open(split_file)]
89 | print(f'{split} split scans: {len(scan_names)}')
90 | # new_split_file = os.path.join(output_dir, f'{split}.txt')
91 | # valid_names = []
92 | # for scan_name in tqdm(scan_names):
93 | # if scan_name in mapping:
94 | # new_scan_name = mapping[scan_name]
95 | # process_one_scan(scan_name, new_scan_name)
96 | # valid_names.append(scan_name)
97 | # params = []
98 | # for scan_name in scan_names:
99 | # if scan_name in mapping:
100 | # new_scan_name = mapping[scan_name]
101 | # params.append((scan_name, new_scan_name))
102 |
103 | parallel = True
104 |
105 | if parallel:
106 | mmengine.utils.track_parallel_progress(process_one_scan, scan_names, 8)
107 | else:
108 | for scan_id in tqdm(scan_names):
109 | process_one_scan(scan_id)
110 |
111 | # if not os.path.exists(new_split_file):
112 | # with open(new_split_file, 'w') as f:
113 | # for t in valid_names:
114 | # f.write(f'{t}\n')
115 |
116 |
117 | for s in ['train', 'val']:
118 | process_split(s)
119 |
120 |
--------------------------------------------------------------------------------
/preprocess/run_prepare.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | # scannet_dir="/mnt/petrelfs/share_data/maoxiaohan/ScanNet_v2"
4 | scannet_dir="/mnt/hwfile/OpenRobotLab/huanghaifeng/data/scannet"
5 | version=""
6 | segment_result_dir="/mnt/hwfile/OpenRobotLab/huanghaifeng/data/processed/scannet/Mask3DInst"
7 | # segment_result_dir="/mnt/petrelfs/share_data/chenyilun/haifeng/Mask3DInst"
8 | inst_seg_dir=""
9 | class_label_file="annotations/scannet/scannetv2-labels.combined.tsv"
10 | max_obj_num=100
11 |
12 | train_iou_thres=0.5
13 |
14 | processed_data_dir="/mnt/hwfile/OpenRobotLab/huanghaifeng/data/processed/scannet/mask3d_ins_data"
15 | # processed_data_dir="/mnt/petrelfs/share_data/chenyilun/haifeng/mask3d_ins_data"
16 | segmentor="mask3d"
17 | # processed_data_dir="/mnt/petrelfs/share_data/chenyilun/share/mask3d/proposals"
18 | # segmentor="clasp"
19 | # segmentor="deva"
20 |
21 | python preprocess/prepare_mask3d_data.py \
22 | --scannet_dir "$scannet_dir" \
23 | --output_dir "$processed_data_dir" \
24 | --segment_dir "$segment_result_dir" \
25 | --inst_seg_dir "$inst_seg_dir" \
26 | --class_label_file "$class_label_file" \
27 | --apply_global_alignment \
28 | --num_workers 16 \
29 | --parallel
30 |
31 | python preprocess/prepare_scannet_mask3d_attributes.py \
32 | --scan_dir "$processed_data_dir" \
33 | --segmentor "$segmentor" \
34 | --max_inst_num "$max_obj_num" \
35 | --version "$version"
36 |
37 | # python preprocess/prepare_scannet_attributes.py \
38 | # --scannet_dir "$scannet_dir"
39 |
40 | # python preprocess/prepare_scannet_attributes_clasp.py \
41 | # --scan_dir "$processed_data_dir" \
42 | # --segmentor "$segmentor" \
43 | # --version "$version" \
44 | # --max_inst_num "$max_obj_num"
45 |
46 | python preprocess/prepare_scanrefer_annos.py \
47 | --segmentor "$segmentor" \
48 | --version "$version" \
49 | --train_iou_thres "$train_iou_thres" \
50 | --max_obj_num "$max_obj_num"
51 |
52 | python preprocess/prepare_scan2cap_annos.py \
53 | --segmentor "$segmentor" \
54 | --version "$version" \
55 | --train_iou_thres "$train_iou_thres" \
56 | --max_obj_num "$max_obj_num"
57 |
58 | python preprocess/prepare_objalign_annos.py \
59 | --segmentor "$segmentor" \
60 | --version "$version" \
61 | --train_iou_thres "$train_iou_thres"
62 |
63 | python preprocess/prepare_nr3dcaption_annos.py \
64 | --segmentor "$segmentor" \
65 | --version "$version" \
66 | --train_iou_thres "$train_iou_thres"
67 |
68 | python preprocess/prepare_multi3dref_annos.py \
69 | --segmentor "$segmentor" \
70 | --version "$version" \
71 | --train_iou_thres "$train_iou_thres"
72 |
73 | python preprocess/prepare_scanqa_annos.py
74 |
75 | python preprocess/prepare_sqa3d_annos.py
76 |
77 | # python preprocess/prepare_nr3d_annos.py \
78 | # --segmentor "$segmentor" \
79 | # --version "$version" \
80 | # --train_iou_thres "$train_iou_thres" \
81 | # --max_obj_num "$max_obj_num"
82 |
83 | # python preprocess/prepare_sr3d_annos.py \
84 | # --segmentor "$segmentor" \
85 | # --version "$version" \
86 | # --train_iou_thres "$train_iou_thres" \
87 | # --max_obj_num "$max_obj_num"
88 |
89 |
90 |
91 | # python preprocess/prepare_scannet_caption_annos.py \
92 | # --segmentor "$segmentor" \
93 | # --version "$version" \
94 | # --train_iou_thres "$train_iou_thres"
95 |
96 | # python preprocess/prepare_scannet_region_caption_annos.py \
97 | # --segmentor "$segmentor" \
98 | # --version "$version" \
99 | # --train_iou_thres "$train_iou_thres"
100 |
101 | # python preprocess/prepare_scanrefer_location_annos.py \
102 | # --segmentor "$segmentor" \
103 | # --version "$version" \
104 | # --train_iou_thres "$train_iou_thres" \
105 | # --max_obj_num "$max_obj_num"
106 |
107 | # python preprocess/prepare_multi3dref_location_annos.py \
108 | # --segmentor "$segmentor" \
109 | # --version "$version" \
110 | # --train_iou_thres "$train_iou_thres"
111 |
112 | # python preprocess/prepare_scan2cap_location_annos.py \
113 | # --segmentor "$segmentor" \
114 | # --version "$version" \
115 | # --train_iou_thres "$train_iou_thres" \
116 | # --max_obj_num "$max_obj_num"
--------------------------------------------------------------------------------
/prompts/concise_description.txt:
--------------------------------------------------------------------------------
1 | Describe the target object in the 3D scene concisely.
2 | Provide a brief description of the given target object in the 3D scene.
3 | Offer a succinct explanation of the target object in the 3D scene presented.
4 | Summarize the visual content of the target object in the 3D scene.
5 | Give a short and clear explanation of the previous target object in the 3D scene.
6 | Share a concise interpretation of the target object in the 3D scene provided.
7 | Present a compact description of the target object's key features in the 3D scene.
8 | Relay a brief, clear account of the target object shown in the 3D scene.
9 | Render a clear and concise summary of the target object in the 3D scene.
10 | Write a terse but informative summary of the target object in the 3D scene.
11 | Create a compact narrative representing the target object in the 3D scene presented.
--------------------------------------------------------------------------------
/prompts/concise_description_objxx.txt:
--------------------------------------------------------------------------------
1 | Provide a brief description of obj within the 3D scene.
2 | Offer a succinct overview of obj situated in the 3D environment.
3 | Summarize the key details of obj in the 3D setting.
4 | Give a concise account of obj in the 3D scene.
5 | Deliver a compact portrayal of obj within the 3D context.
6 | Outline obj in the 3D scene with brevity.
7 | Present a pithy depiction of obj in the 3D environment.
8 | Provide a concise report on obj in the 3D scene.
9 | Describe obj within the 3D scene in a succinct manner.
10 | Offer a brief 3D scene description of obj.
--------------------------------------------------------------------------------
/prompts/conv_description.txt:
--------------------------------------------------------------------------------
1 | Where is the target object?
--------------------------------------------------------------------------------
/prompts/dataset_generation/conversation.txt:
--------------------------------------------------------------------------------
1 | You are an AI visual assistant, and you are seeing an object in a 3D scene. What you see is provided as several sentences describing the object you are looking at, along with the positions of surrounding objects, which represent the content of the 3D scene. Based on these descriptions of the object and the locations of the surrounding objects, answer all the questions as if you are in the 3D scene.
2 | Design a conversation between you and a person asking about this object in the 3D scene. The answers should be in the tone of a visual AI assistant that is in the 3D scene and answering the question. Ask diverse questions and give corresponding answers.
3 | Include questions asking about the visual content of this object, including the object type, shape, attributes, functions, location, relative positions between objects, etc. Only include questions that have definite answers:
4 | (1) Questions whose contents can be confidently observed and answered based on the 3D scene.
5 | (2) Questions whose absence from the 3D scene can be confidently determined.
6 | Also include complex questions that are relevant to the object. Because you are seeing a 3D scene, complex questions should focus more on discussing spatial reasoning, human interaction in the scene, and the rationality, potential risks, and purpose of the position of objects in the room. (For example, assuming a person is interacting with this object, ask about the relative position of other objects to the person and the person's path to access those objects. It is also valuable to ask about the specific function of an object, its current location, and suggestions for placement based on background knowledge, etc.) Again, provide detailed answers when answering complex questions. For example, give detailed examples or reasoning steps to make the content more convincing and well-organized. You can include multiple paragraphs if necessary.
7 | Notice! Do not ask about uncertain details, do not ask questions that cannot be answered confidently from the given information, do not mention the descriptions themselves in the generated answers, and there is no need to specify the class or properties of the described object in the question.
--------------------------------------------------------------------------------
/prompts/dataset_generation/detail.txt:
--------------------------------------------------------------------------------
1 | You are an AI 3D visual assistant, and you are seeing an object in a 3D scene. What you see is provided as several sentences describing the object you are looking at, along with the positions of surrounding objects, which represent the content of the 3D scene. Based on these descriptions of the object and the locations of the surrounding objects, summarize and describe in detail the placement and function of this object, and how a person can access it, as if you are in the 3D scene.
2 | Importantly, do not mention any specific spatial coordinate values. The description should be more than 150 words and less than 200 words.
--------------------------------------------------------------------------------
/prompts/dataset_generation/textualize_obj.txt:
--------------------------------------------------------------------------------
1 | Descriptions: ["There is a single white armchair. placed next to the window of the room.", "The sofa chair is the corner chair. lying parallel to the wall. a small table with the lamp is present beside the chair.", "This is a white sofa chair. it is under a window.", "This is a white armchair. is next to a lamp.", "This is the corner sofa chair. a small table with a lamp can be seen near this chair."]
2 | Described object: {sofa chair:[-1.31, 3.15, 0.59]}; Neighbor objects: {window:[-1.12, 4.12, 1.59], table:[0.86, 1.61, 0.38], doorframe:[-2.25, 0.67, 1.27], windowsill:[0.88, 3.97, 0.98], windowsill:[-1.32, 3.93, 0.91], sofa chair:[0.98, 3.35, 0.71], window:[1.16, 4.18, 1.73], pillow:[1.35, 0.29, 0.46], table:[-0.15, -2.66, 0.26], tv:[-2.2, -0.55, 1.52]}
--------------------------------------------------------------------------------
/prompts/detailed_description.txt:
--------------------------------------------------------------------------------
1 | Describe the target object in detail.
2 | Provide a detailed description of the given target object in the 3D scene.
3 | Give an elaborate explanation of the target object you see in the 3D scene.
4 | Share a comprehensive rundown of the presented target object in the 3D scene.
5 | Offer a thorough analysis of the target object in the 3D scene.
6 | Explain the various aspects of the target object before you.
7 | Clarify the contents of the displayed target object with great detail.
8 | Characterize the target object using a well-detailed description.
9 | Break down the elements of the target object in a detailed manner.
10 | Walk through the important details of the target object in the 3D scene.
11 | Portray the target object with a rich, descriptive narrative.
12 | Narrate the contents of the target object with precision.
13 | Analyze the target object in a comprehensive and detailed manner.
14 | Illustrate the target object through a descriptive explanation.
15 | Examine the target object closely and share its details.
16 | Write an exhaustive depiction of the given target object in the 3D scene.
--------------------------------------------------------------------------------
/prompts/grounding_answer_templates.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CognitiveAISystems/3DGraphLLM/418f6529e029761a52017d0eb2a5380689805c62/prompts/grounding_answer_templates.txt
--------------------------------------------------------------------------------
/prompts/grounding_prompts.txt:
--------------------------------------------------------------------------------
1 | Here is a description: "" Which object aligns most closely with the provided description? Please respond with the object ID only.
2 | Here is a description: "" Can you identify the object that best fits the provided description? Please respond with the object ID only.
3 | Here is a description: "" Identify the object that is the most suitable match for the given description. Please respond with the object ID only.
4 | Here is a description: "" Which object corresponds most accurately to the provided description? Please respond with the object ID only.
5 | Here is a description: "" Choose the object that best meets the features outlined in the description. Please respond with the object ID only.
6 | Here is a description: "" Which object is the most appropriate match for the description provided? Please respond with the object ID only.
7 | Here is a description: "" Select the object that most closely matches the given description. Please respond with the object ID only.
8 | Here is a description: "" Determine which object is the optimal match for the description provided. Please respond with the object ID only.
9 | Here is a description: "" Which object aligns best with the description given? Please respond with the object ID only.
10 | Here is a description: "" Pick the object that best suits the description given. Please respond with the object ID only.
--------------------------------------------------------------------------------
/prompts/instruction.txt:
--------------------------------------------------------------------------------
1 | The conversation centers around an indoor scene. Object information: [].
--------------------------------------------------------------------------------
/prompts/nr3d_caption_templates.txt:
--------------------------------------------------------------------------------
1 | Detail the spatial positioning of the amidst surrounding elements.
2 | Illustrate the 's placement relative to its environment.
3 | Explain the 's location in correlation with nearby items.
4 | Elaborate on the 's spatial context within the scene.
5 | Describe how the is situated in relation to other elements present.
6 | Provide insight into the 's positioning among its surroundings.
7 | Discuss the relative placement of the compared to its surrounding context.
8 | Offer a depiction of the 's spatial orientation within the scene.
9 | Interpret the 's location within the broader context of the scene.
10 | Present the 's spatial relationship with other entities within the scene.
--------------------------------------------------------------------------------
/prompts/object_caption_templates.txt:
--------------------------------------------------------------------------------
1 | Portray the visual characteristics of the .
2 | Detail the outward presentation of the .
3 | Provide a depiction of the 's appearance.
4 | Illustrate how the looks.
5 | Describe the visual aspects of the .
6 | Convey the physical attributes of the .
7 | Outline the external features of the .
8 | Render the appearance of the in words.
9 | Depict the outward form of the .
10 | Elaborate on the visual representation of the .
--------------------------------------------------------------------------------
/prompts/scanrefer_caption_templates.txt:
--------------------------------------------------------------------------------
1 | Begin by detailing the visual aspects of the before delving into its spatial context among other elements within the scene.
2 | First, depict the physical characteristics of the , followed by its placement and interactions within the surrounding environment.
3 | Describe the appearance of the , then elaborate on its positioning relative to other objects in the scene.
4 | Paint a picture of the visual attributes of , then explore how it relates spatially to other elements in the scene.
5 | Start by articulating the outward features of the , then transition into its spatial alignment within the broader scene.
6 | Provide a detailed description of the appearance of before analyzing its spatial connections with other elements in the scene.
7 | Capture the essence of the appearance of , then analyze its spatial relationships within the scene's context.
8 | Detail the physical characteristics of the and subsequently examine its spatial dynamics amidst other objects in the scene.
9 | Describe the visual traits of first, then elucidate its spatial arrangements in relation to neighboring elements.
10 | Begin by outlining the appearance of , then proceed to illustrate its spatial orientation within the scene alongside other objects.
--------------------------------------------------------------------------------
/prompts/scene_align_template.txt:
--------------------------------------------------------------------------------
1 | Obj12 is the nearest object to obj00.
--------------------------------------------------------------------------------
/prompts/score_template.txt:
--------------------------------------------------------------------------------
1 | ###Human:
2 |
3 | Evaluate the request below and determine whether it accurately identifies the target object enclosed within the tags '' and '':
4 |
5 |
6 | {}
7 |
8 |
9 | For reference, I have asked a powerful 3D discriminator machine, and it indicates that the request has a {}% likelihood of being accurate in identifying the target object. However, the machine is not always correct, so you need to provide the final determination.
10 |
11 | Please respond with the answer in the following format: 'The answer is: True' if the request correctly localizes the target object, or 'The answer is: False' if it does not.
12 |
13 | ###
--------------------------------------------------------------------------------
/prompts/score_template_old.txt:
--------------------------------------------------------------------------------
1 | ###Human:
2 |
3 | {}.
4 |
5 | Please evaluate the relevance of the sentence above, which describes an object in the scene, to the target object (given between the tags "" and ""). Assign a score on a scale of 1 to 5, where higher scores indicate greater relevance. Use the following rating system: 1 for totally irrelevant, 2 for irrelevant, 3 for somewhat relevant or somewhat irrelevant, 4 for relevant, and 5 for highly relevant.
6 |
7 | Begin your response by stating the rating as "Rating: x/5.", where 'x' represents the assigned score. After assigning the rating, provide explanations for your rating.
8 |
9 | ###
--------------------------------------------------------------------------------
/prompts/system.txt:
--------------------------------------------------------------------------------
1 | A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
--------------------------------------------------------------------------------
/prompts/system_backup.txt:
--------------------------------------------------------------------------------
1 | ###System: In this task, you will be provided with extensive information regarding objects within a 3D scene. Our conversation will focus primarily on a specific target object, and I will also supply information about all other objects present in the scene.
2 | Target object: .
3 | All the objects in the scene: [].
4 |
5 |
6 | A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The conversation focuses on a 3D indoor scene, which contains dozens of 3D objects. Here is a list of object information: []. "objxxx" refers to the object ID, which will be used to identify a specific object in the following conversation.
7 |
8 | A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
9 | The conversation focuses on a 3D indoor scene, which contains dozens of 3D objects. Here is a list of object information: []. "objxx" refers to the object ID, which will be used to identify a specific object in the following conversation.
10 |
11 | A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. The conversation centers around a 3D indoor scene that encompasses numerous 3D objects. Here is a list of object information: []. Objects are separated by "," and each object is identified by an ID in the format "objxx".
12 |
13 | A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas==1.5.3
2 | transformers==4.39.3
3 | einops==0.6.1
4 | plyfile==1.0.1
5 | trimesh==3.23.1
6 | peft==0.9.0
7 | termcolor==2.3.0
8 | scipy==1.12.0
9 | pycocoevalcap==1.2
10 | sentencepiece==0.2.0
11 | protobuf==4.25.3
12 | packaging==24.0
13 | flash_attn==2.5.6
14 | mmengine==0.10.3
--------------------------------------------------------------------------------
/scripts/config_3dgraphllm.py:
--------------------------------------------------------------------------------
1 | # ========================= data ==========================
2 | anno_root = "annotations" # annotation dir
3 | pc_encoder = "uni3d"
4 | segmentor = "mask3d"
5 | version = ""
6 |
7 | seg_feat_file = f"{anno_root}/scannet_{segmentor}_{pc_encoder}_feats.pt"
8 | seg_img_feat_file = f"{anno_root}/scannet_{segmentor}_videofeats.pt"
9 | seg_val_attr_file = f"{anno_root}/scannet_{segmentor}_val_attributes.pt"
10 | seg_val_gnn_file = f"{anno_root}/scannet_{segmentor}_val_gnn_feats_2{version}.pt"
11 |
12 | val_file_dict = {
13 | 'demo': [
14 | seg_feat_file,
15 | seg_img_feat_file,
16 | seg_val_attr_file,
17 | seg_val_gnn_file,
18 | f"{segmentor}",
19 | ],
20 | }
21 |
22 | # ========================= model ==========================
23 | model = dict(
24 | llama_model_path="./Meta-Llama-3-8B-Instruct",
25 | input_dim=1024,
26 | img_input_dim=1024,
27 | attr_dim=512,
28 | scene_dim=256,
29 | pos_dim=128,
30 | encoder_num_layers=3,
31 | low_resource=False,
32 | system_path="./prompts/system.txt",
33 | instruction_path="./prompts/instruction.txt",
34 | max_txt_len=64,
35 | end_sym="",
36 | role=("USER", "ASSISTANT"),
37 | add_scene_token=False,
38 | add_img_token=True,
39 | use_lora=True,
40 | train_emb=True,
41 | train_img_proj=True,
42 | no_obj=False,
43 | max_obj_num=150,
44 | bidirection=False,
45 | add_pos_emb=False,
46 | feat_fusion=False,
47 | fuse_with_id=False,
48 | use_objid=True,
49 | use_location_token=False,
50 | knn=2,
51 | bbox_embed=False,
52 | gt_pretrain=False,
53 | nms=True,
54 | nn_distance=True,
55 | max_knn=2
56 | )
57 |
58 | lora = dict(
59 | lora_target_modules=[
60 | "q_proj",
61 | "v_proj",
62 | "k_proj",
63 | "o_proj",
64 | "gate_proj",
65 | "up_proj",
66 | "down_proj"
67 | ],
68 | lora_r=16,
69 | lora_alpha=16,
70 | lora_dropout=0.05
71 | )
72 |
73 | optimizer = dict(
74 | opt="adamW",
75 | lr=5e-6,
76 | opt_betas=[0.9, 0.999], # default
77 | weight_decay=0.02,
78 | scaler_enable=False,
79 | max_grad_norm=0.01, # requires a positive float, use -1 to disable
80 | # use a different lr for some modules, e.g., larger lr for new modules
81 | different_lr=dict(
82 | enable=False,
83 | module_names=["model.embed_tokens"],
84 | lr=[5e-4],
85 | wd=[0.02]
86 | ),
87 | )
88 |
89 | scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.1)
90 |
91 | evaluate = False
92 |
93 | # ========================= wandb ==========================
94 | wandb = dict(
95 | enable=False,
96 | entity="anonym", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
97 | project="3DGraphLLM",
98 | )
99 | dist_url = "env://"
100 | device = "cuda"
101 |
102 | # ========================= others ==========================
103 | output_dir = "./llama3-8b-gt-pretrain-2-3rscan" # output dir
104 | resume = False # if True, load optimizer and scheduler states as well
105 | debug = False
106 | log_freq = 20
107 | # eval_freq = 500
108 | seed = 42
109 |
110 | save_latest = False
111 | do_save = True
112 | auto_resume = True
113 | pretrained_path = ""
114 | img_projector_path = ""
115 |
--------------------------------------------------------------------------------
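The config above is a plain Python module built from nested dict(...) blocks. As a rough illustration of how such a file can be inspected, here is a minimal sketch using mmengine's Config reader (mmengine is pinned in requirements.txt); the project's own utils/config.py wrapper may load and merge configs differently, so treat this purely as an illustration.

# Sketch only: read scripts/config_3dgraphllm.py with mmengine and look at a few fields.
from mmengine.config import Config

cfg = Config.fromfile("scripts/config_3dgraphllm.py")
print(cfg.model["llama_model_path"])                              # ./Meta-Llama-3-8B-Instruct
print(cfg.model["knn"], cfg.lora["lora_r"], cfg.optimizer["lr"])  # 2 16 5e-06
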
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | which_python=$(which python)
2 | export PYTHONPATH=${PYTHONPATH}:${which_python}:.
3 | echo "PYTHONPATH: ${PYTHONPATH}"
4 |
5 | export MASTER_PORT=$((54000 + $RANDOM % 10000))
6 | export MASTER_ADDR=localhost
7 |
8 | epoch=3
9 | batch_size=8
10 | lr=2e-5
11 | train_emb=True
12 | train_img_proj=True
13 | add_img_token=True
14 | add_scene_token=False
15 | no_obj=False
16 | input_dim=1024 # 1024
17 | bidirection=False
18 | different_lr=False
19 | max_obj_num=150
20 | lora_r=16
21 | lora_alpha=16
22 | add_pos_emb=False
23 | feat_fusion=False
24 | fuse_with_id=False
25 | config=""
26 | max_grad_norm=0.01
27 | seed=42
28 | use_location_token=False
segmentor="mask3d"    # assumed value: passed to train.py below but not otherwise set in this script (mirrors scripts/config_3dgraphllm.py)
pc_encoder="uni3d"    # assumed value, taken from the same config defaults
29 |
30 | #llama_model_path="./vicuna-7b-v1.5"
31 | llama_model_path="./Meta-Llama-3-8B-Instruct"
32 |
33 | train_tag="scanrefer#obj_align#nr3d_caption#scan2cap#scanqa#sqa3d#multi3dref"
34 | val_tag="scanrefer#scan2cap#scanqa#sqa3d#multi3dref"
35 |
36 | evaluate=True
37 | debug=False
38 | if [ $evaluate = "True" ]; then
39 | enable_wandb=False
40 | gpu_num=1
41 | do_save=True
42 | other_info="evaluation"
43 | else
44 | enable_wandb=True
45 | gpu_num=1
46 | do_save=True
47 | other_info="chatscene"
48 | fi
49 |
50 | tag="${train_tag}__${val_tag}__${other_info}"
51 |
52 | pretrained_path="./demo/3dgraphllm.pth"
53 |
54 | OUTPUT_DIR=outputs/3dgraphllm_2e-5_ep6
55 | mkdir -p ${OUTPUT_DIR}
56 |
57 | python tasks/train.py \
58 | "$(dirname $0)/${config}config.py" \
59 | output_dir "$OUTPUT_DIR" \
60 | scheduler.epochs "$epoch" \
61 | optimizer.lr "$lr" \
62 | model.add_scene_token "$add_scene_token" \
63 | model.add_img_token "$add_img_token" \
64 | pretrained_path "$pretrained_path" \
65 | evaluate "$evaluate" \
66 | wandb.enable "$enable_wandb" \
67 | gpu_num "$gpu_num" \
68 | do_save "$do_save" \
69 | batch_size "$batch_size" \
70 | model.train_emb "$train_emb" \
71 | model.train_img_proj "$train_img_proj" \
72 | train_tag "$train_tag" \
73 | val_tag "$val_tag" \
74 | model.no_obj "$no_obj" \
75 | segmentor "$segmentor" \
76 | pc_encoder "$pc_encoder" \
77 | model.input_dim "$input_dim" \
78 | model.bidirection "$bidirection" \
79 | optimizer.different_lr.enable "$different_lr" \
80 | model.max_obj_num "$max_obj_num" \
81 | lora.lora_r "$lora_r" \
82 | lora.lora_alpha "$lora_alpha" \
83 | model.add_pos_emb "$add_pos_emb" \
84 | model.feat_fusion "$feat_fusion" \
85 | optimizer.max_grad_norm "$max_grad_norm" \
86 | seed "$seed" \
87 | model.fuse_with_id "$fuse_with_id" \
88 | model.llama_model_path "$llama_model_path" \
89 | model.use_location_token "$use_location_token" \
90 | model.gt_pretrain False
91 |
92 |
--------------------------------------------------------------------------------
/scripts/run_gt_pretrain.sh:
--------------------------------------------------------------------------------
1 | which_python=$(which python)
2 | export PYTHONPATH=${PYTHONPATH}:${which_python}:.
3 | echo "PYTHONPATH: ${PYTHONPATH}"
4 |
5 | export MASTER_PORT=$((54000 + $RANDOM % 10000))
6 | export MASTER_ADDR=localhost
7 |
8 | epoch=3
9 | batch_size=8
10 | lr=2e-5
11 | train_emb=True
12 | train_img_proj=True
13 | add_img_token=True
14 | add_scene_token=False
15 | no_obj=False
16 | input_dim=1024 # 1024
17 | bidirection=False
18 | different_lr=False
19 | max_obj_num=150
20 | lora_r=16
21 | lora_alpha=16
22 | add_pos_emb=False
23 | feat_fusion=False
24 | fuse_with_id=False
25 | config=""
26 | max_grad_norm=0.01
27 | seed=42
28 | use_location_token=False
segmentor="mask3d"    # assumed value: passed to train.py below but not otherwise set in this script (mirrors the defaults in scripts/config_3dgraphllm.py)
pc_encoder="uni3d"    # assumed value, taken from the same config defaults
29 |
30 | #llama_model_path="./vicuna-7b-v1.5"
31 | llama_model_path="./Meta-Llama-3-8B-Instruct"
32 |
33 | train_tag="scanrefer#obj_align#nr3d_caption#scan2cap#scanqa#sqa3d#multi3dref"
34 | val_tag="scanrefer#scan2cap#multi3dref#sqa3d#scanqa"
35 |
36 | evaluate=False
37 | debug=False
38 | resume=False
39 | if [ $evaluate = "True" ]; then
40 | enable_wandb=False
41 | gpu_num=1
42 | do_save=True
43 | other_info="evaluation"
44 | else
45 | enable_wandb=False
46 | gpu_num=1
47 | do_save=True
48 | other_info="chatscene"
49 | fi
50 |
51 | tag="${train_tag}__${val_tag}__${other_info}"
52 |
53 | pretrained_path=""
54 |
55 | OUTPUT_DIR=outputs/llama3-8b-gt-pretrain-2
56 |
57 | python tasks/train.py \
58 | "$(dirname $0)/${config}config-gt-pretrain.py" \
59 | output_dir "$OUTPUT_DIR" \
60 | scheduler.epochs "$epoch" \
61 | optimizer.lr "$lr" \
62 | model.add_scene_token "$add_scene_token" \
63 | model.add_img_token "$add_img_token" \
64 | pretrained_path "$pretrained_path" \
65 | evaluate "$evaluate" \
66 | wandb.enable "$enable_wandb" \
67 | gpu_num "$gpu_num" \
68 | do_save "$do_save" \
69 | batch_size "$batch_size" \
70 | model.train_emb "$train_emb" \
71 | model.train_img_proj "$train_img_proj" \
72 | train_tag "$train_tag" \
73 | val_tag "$val_tag" \
74 | model.no_obj "$no_obj" \
75 | segmentor "$segmentor" \
76 | pc_encoder "$pc_encoder" \
77 | model.input_dim "$input_dim" \
78 | model.bidirection "$bidirection" \
79 | optimizer.different_lr.enable "$different_lr" \
80 | model.max_obj_num "$max_obj_num" \
81 | lora.lora_r "$lora_r" \
82 | lora.lora_alpha "$lora_alpha" \
83 | model.add_pos_emb "$add_pos_emb" \
84 | model.feat_fusion "$feat_fusion" \
85 | optimizer.max_grad_norm "$max_grad_norm" \
86 | seed "$seed" \
87 | model.fuse_with_id "$fuse_with_id" \
88 | model.llama_model_path "$llama_model_path" \
89 | model.use_location_token "$use_location_token" \
90 | model.knn 2 \
91 | model.gt_pretrain True
92 |
--------------------------------------------------------------------------------
/tasks/shared_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import os
4 | import os.path as osp
5 | from os.path import join
6 |
7 | import torch
8 | from torch.utils.data import ConcatDataset, DataLoader
9 |
10 | from utils.optimizer import create_optimizer
11 | from utils.scheduler import create_scheduler
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def get_media_types(datasources):
17 |     """get the media types for all the dataloaders.
18 |
19 | Args:
20 | datasources (List): List of dataloaders or datasets.
21 |
22 | Returns: List. The media_types.
23 |
24 | """
25 | if isinstance(datasources[0], DataLoader):
26 | datasets = [dataloader.dataset for dataloader in datasources]
27 | else:
28 | datasets = datasources
29 | media_types = [
30 | dataset.datasets[0].media_type
31 | if isinstance(dataset, ConcatDataset)
32 | else dataset.media_type
33 | for dataset in datasets
34 | ]
35 |
36 | return media_types
37 |
38 |
39 | def setup_model(
40 | config, model_cls, find_unused_parameters=False
41 | ):
42 | logger.info("Creating model")
43 | config = copy.deepcopy(config)
44 |
45 | model = model_cls(config=config)
46 |
47 | model = model.to(torch.device(config.device))
48 | model_without_ddp = model
49 | if config.distributed:
50 | model = torch.nn.parallel.DistributedDataParallel(
51 | model,
52 | device_ids=[config.gpu],
53 | find_unused_parameters=find_unused_parameters, # `False` for image-only task
54 | gradient_as_bucket_view=True
55 | )
56 | optimizer = create_optimizer(config.optimizer, model, config)
57 | scheduler = create_scheduler(config.scheduler, optimizer)
58 | scaler = torch.cuda.amp.GradScaler(enabled=config.optimizer.scaler_enable, growth_interval=100)
59 |
60 | start_epoch = 0
61 | global_step = 0
62 |
63 | # auto resume the latest checkpoint
64 | if config.get("auto_resume", False):
65 | logger.info("Auto resuming")
66 | model_latest = join(config.output_dir, "ckpt_latest.pth")
67 | model_best = join(config.output_dir, "ckpt_best.pth")
68 | large_num = -1
69 | for p in os.listdir(config.output_dir):
70 | if 'ckpt' in p:
71 | num = p.split('_')[1].split('.')[0]
72 | if str.isnumeric(num):
73 | if int(num) > large_num:
74 | large_num = int(num)
75 | if large_num != -1:
76 | model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth")
77 | if osp.isfile(model_latest) and not config.pretrained_path:
78 | config.pretrained_path = model_latest
79 | config.resume = True
80 | elif osp.isfile(model_best) and not config.pretrained_path:
81 | config.pretrained_path = model_best
82 | config.resume = True
83 | else:
84 |             logger.info(f"No checkpoint found in {config.output_dir}")
85 |
86 | if osp.isfile(config.img_projector_path):
87 | img_projector_sd = torch.load(config.img_projector_path, map_location="cpu")
88 | msg = model_without_ddp.object_img_proj.load_state_dict(img_projector_sd)
89 | logger.info(f"Loaded pretrained image projector from {config.img_projector_path}.")
90 |
91 | if osp.isfile(config.pretrained_path):
92 | checkpoint = torch.load(config.pretrained_path, map_location="cpu")
93 | state_dict = checkpoint["model"]
94 |
95 | if config.resume:
96 | optimizer.load_state_dict(checkpoint["optimizer"])
97 | scheduler.load_state_dict(checkpoint["scheduler"])
98 | scaler.load_state_dict(checkpoint["scaler"])
99 | start_epoch = checkpoint["epoch"] + 1
100 | global_step = checkpoint["global_step"]
101 | keys_to_delete = []
102 | for name, param in state_dict.items():
103 | if name not in model_without_ddp.state_dict():
104 | continue
105 | if param.size() != model_without_ddp.state_dict()[name].size():
106 | keys_to_delete.append(name)
107 | for key in keys_to_delete:
108 | del state_dict[key]
109 | msg = model_without_ddp.load_state_dict(state_dict, strict=False)
110 | logger.info(msg)
111 | logger.info(f"Loaded checkpoint from {config.pretrained_path}.")
112 | else:
113 | logger.warning("No pretrained checkpoint provided, training from scratch.")
114 |
115 | return (
116 | model,
117 | model_without_ddp,
118 | optimizer,
119 | scheduler,
120 | scaler,
121 | start_epoch,
122 | global_step,
123 | )
124 |
--------------------------------------------------------------------------------
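For orientation, setup_model() above is the helper that builds the model, optimizer, scheduler, and AMP scaler and handles auto-resume. The sketch below shows roughly how it is called; the real call site is tasks/train.py, and the Graph3DLLM class name is assumed here (taken to be the export of models/graph3dllm.py) rather than confirmed.

# Hypothetical usage sketch of setup_model(); the model class name is an assumption.
from tasks.shared_utils import setup_model
from models.graph3dllm import Graph3DLLM    # assumed export of models/graph3dllm.py
from utils.config_utils import setup_main

config = setup_main()   # parse config file + command-line overrides, set up dist mode and logging
(model, model_without_ddp, optimizer, scheduler,
 scaler, start_epoch, global_step) = setup_model(config, model_cls=Graph3DLLM,
                                                 find_unused_parameters=True)
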
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CognitiveAISystems/3DGraphLLM/418f6529e029761a52017d0eb2a5380689805c62/utils/__init__.py
--------------------------------------------------------------------------------
/utils/box_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def get_box3d_min_max(corner):
5 |     ''' Compute min and max coordinates of an axis-aligned 3D bounding box
6 |         Note: only for axis-aligned bounding boxes
7 | 
8 |     Input:
9 |         corner: numpy array (8,3), corners of a single box, assume up direction is Z
10 |     Output:
11 |         x_min, x_max, y_min, y_max, z_min, z_max: min and max box coordinates along each axis
12 | 
13 |     '''
14 |
15 | min_coord = corner.min(axis=0)
16 | max_coord = corner.max(axis=0)
17 | x_min, x_max = min_coord[0], max_coord[0]
18 | y_min, y_max = min_coord[1], max_coord[1]
19 | z_min, z_max = min_coord[2], max_coord[2]
20 |
21 | return x_min, x_max, y_min, y_max, z_min, z_max
22 |
23 |
24 | def box3d_iou(corners1, corners2):
25 | ''' Compute 3D bounding box IoU.
26 |
27 | Input:
28 | corners1: numpy array (8,3), assume up direction is Z
29 | corners2: numpy array (8,3), assume up direction is Z
30 | Output:
31 | iou: 3D bounding box IoU
32 |
33 | '''
34 |
35 | x_min_1, x_max_1, y_min_1, y_max_1, z_min_1, z_max_1 = get_box3d_min_max(corners1)
36 | x_min_2, x_max_2, y_min_2, y_max_2, z_min_2, z_max_2 = get_box3d_min_max(corners2)
37 | xA = np.maximum(x_min_1, x_min_2)
38 | yA = np.maximum(y_min_1, y_min_2)
39 | zA = np.maximum(z_min_1, z_min_2)
40 | xB = np.minimum(x_max_1, x_max_2)
41 | yB = np.minimum(y_max_1, y_max_2)
42 | zB = np.minimum(z_max_1, z_max_2)
43 | inter_vol = np.maximum((xB - xA), 0) * np.maximum((yB - yA), 0) * np.maximum((zB - zA), 0)
44 | box_vol_1 = (x_max_1 - x_min_1) * (y_max_1 - y_min_1) * (z_max_1 - z_min_1)
45 | box_vol_2 = (x_max_2 - x_min_2) * (y_max_2 - y_min_2) * (z_max_2 - z_min_2)
46 | iou = inter_vol / (box_vol_1 + box_vol_2 - inter_vol + 1e-8)
47 |
48 | return iou
49 |
50 |
51 | def construct_bbox_corners(center, box_size):
52 | sx, sy, sz = box_size
53 | x_corners = [sx / 2, sx / 2, -sx / 2, -sx / 2, sx / 2, sx / 2, -sx / 2, -sx / 2]
54 | y_corners = [sy / 2, -sy / 2, -sy / 2, sy / 2, sy / 2, -sy / 2, -sy / 2, sy / 2]
55 | z_corners = [sz / 2, sz / 2, sz / 2, sz / 2, -sz / 2, -sz / 2, -sz / 2, -sz / 2]
56 | corners_3d = np.vstack([x_corners, y_corners, z_corners])
57 | corners_3d[0, :] = corners_3d[0, :] + center[0]
58 | corners_3d[1, :] = corners_3d[1, :] + center[1]
59 | corners_3d[2, :] = corners_3d[2, :] + center[2]
60 | corners_3d = np.transpose(corners_3d)
61 |
62 | return corners_3d
63 |
--------------------------------------------------------------------------------
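A quick worked example of the two helpers above (not part of the repo): two axis-aligned unit cubes whose centers are offset by 0.5 m along x overlap in a 0.5 x 1 x 1 slab, so the IoU should be 0.5 / 1.5, roughly 0.333.

# Example usage of construct_bbox_corners() and box3d_iou(); run from the repo root.
from utils.box_utils import construct_bbox_corners, box3d_iou

box_a = construct_bbox_corners(center=[0.0, 0.0, 0.0], box_size=[1.0, 1.0, 1.0])
box_b = construct_bbox_corners(center=[0.5, 0.0, 0.0], box_size=[1.0, 1.0, 1.0])
print(box3d_iou(box_a, box_b))   # ~0.3333
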
/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | from os.path import dirname, join
5 |
6 | from utils.config import Config
7 | from utils.distributed import init_distributed_mode, is_main_process
8 | from utils.logger import setup_logger
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | def setup_config():
14 |     """Combine yaml config and command line config with OmegaConf.
15 | Also converts types, e.g., `'None'` (str) --> `None` (None)
16 | """
17 | config = Config.get_config()
18 | if config.debug:
19 | config.wandb.enable = False
20 | return config
21 |
22 |
23 | def setup_evaluate_config(config):
24 | """setup evaluation default settings, e.g., disable wandb"""
25 | assert config.evaluate
26 | config.wandb.enable = False
27 | if config.output_dir is None:
28 | config.output_dir = join(dirname(config.pretrained_path), "eval")
29 | return config
30 |
31 |
32 | def setup_output_dir(output_dir, excludes=["code"]):
33 |     """ensure we are not overwriting an existing/non-empty output dir"""
34 | if not os.path.exists(output_dir):
35 | os.makedirs(output_dir, exist_ok=False)
36 | else:
37 | existing_dirs_files = os.listdir(output_dir) # list
38 | remaining = set(existing_dirs_files) - set(excludes)
39 | remaining = [e for e in remaining if "slurm" not in e]
40 | remaining = [e for e in remaining if ".out" not in e]
41 | # assert len(remaining) == 0, f"remaining dirs or files: {remaining}"
42 |         logger.warning(f"remaining dirs or files: {remaining}")
43 |
44 |
45 | def setup_main():
46 | """
47 | Setup config, logger, output_dir, etc.
48 | Shared for pretrain and all downstream tasks.
49 | """
50 | config = setup_config()
51 | if hasattr(config, "evaluate") and config.evaluate:
52 | config = setup_evaluate_config(config)
53 | init_distributed_mode(config)
54 |
55 | if is_main_process():
56 | setup_output_dir(config.output_dir, excludes=["code"])
57 | setup_logger(output=config.output_dir, color=True, name="vindlu")
58 | logger.info(f"config: {Config.pretty_text(config)}")
59 | Config.dump(config, os.path.join(config.output_dir, "config.json"))
60 | return config
61 |
--------------------------------------------------------------------------------
/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | import logging
5 |
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | def setup_for_distributed(is_master):
11 | import warnings
12 |
13 | builtin_warn = warnings.warn
14 |
15 | def warn(*args, **kwargs):
16 | force = kwargs.pop("force", False)
17 | if is_master or force:
18 | builtin_warn(*args, **kwargs)
19 |
20 | # Log warnings only once
21 | warnings.warn = warn
22 | warnings.simplefilter("once", UserWarning)
23 |
24 | if not is_master:
25 | logging.disable()
26 |
27 |
28 | def is_dist_avail_and_initialized():
29 | if not dist.is_available():
30 | return False
31 | if not dist.is_initialized():
32 | return False
33 | return True
34 |
35 |
36 | def get_world_size():
37 | if not is_dist_avail_and_initialized():
38 | return 1
39 | return dist.get_world_size()
40 |
41 |
42 | def get_rank():
43 | if not is_dist_avail_and_initialized():
44 | return 0
45 | return dist.get_rank()
46 |
47 |
48 | def is_main_process():
49 | return get_rank() == 0
50 |
51 |
52 | def save_on_master(*args, **kwargs):
53 | if is_main_process():
54 | torch.save(*args, **kwargs)
55 |
56 |
57 | def is_port_in_use(port):
58 | import socket
59 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
60 | return s.connect_ex(('localhost', port)) == 0
61 |
62 |
63 | def init_distributed_mode(args):
64 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
65 | # job started by torch.distributed.launch
66 | args.rank = int(os.environ["RANK"])
67 | args.world_size = int(os.environ['WORLD_SIZE'])
68 | args.gpu = int(os.environ['LOCAL_RANK'])
69 | elif 'SLURM_PROCID' in os.environ:
70 | # local rank on the current node / global rank
71 | local_rank = int(os.environ['SLURM_LOCALID'])
72 | global_rank = int(os.environ['SLURM_PROCID'])
73 | # number of processes / GPUs per node
74 |         world_size = int(os.environ["SLURM_NNODES"]) * \
75 |             int(os.environ["SLURM_TASKS_PER_NODE"].split("(")[0].split(",")[0])  # e.g. "8", "8(x2)", "2,1"
76 |
77 | args.rank = global_rank
78 | args.gpu = local_rank
79 | args.world_size = world_size
80 | else:
81 | logger.info('Not using distributed mode')
82 | args.distributed = False
83 | return
84 |
85 | args.distributed = True
86 |
87 | torch.cuda.set_device(args.gpu)
88 | args.dist_backend = 'nccl'
89 |
90 |     if "tcp" in args.dist_url:  # in SLURM, multiple programs may run on a single node
91 | dist_port = int(args.dist_url.split(":")[-1])
92 | while is_port_in_use(dist_port):
93 | dist_port += 10
94 | args.dist_url = ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)])
95 |
96 | logger.info('| distributed init (rank {}): {}'.format(
97 | args.rank, args.dist_url))
98 | if "SLURM_JOB_ID" in os.environ:
99 | logger.info(f"SLURM_JOB_ID {os.environ['SLURM_JOB_ID']}")
100 | torch.distributed.init_process_group(
101 | backend=args.dist_backend, init_method=args.dist_url,
102 | world_size=args.world_size, rank=args.rank)
103 | torch.distributed.barrier()
104 | setup_for_distributed(args.rank == 0)
105 |
106 |
107 | # Copyright (c) Facebook, Inc. and its affiliates.
108 | # copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py
109 | class GatherLayer(torch.autograd.Function):
110 | """
111 | Gather tensors from all workers with support for backward propagation:
112 | This implementation does not cut the gradients as torch.distributed.all_gather does.
113 | """
114 |
115 | @staticmethod
116 | def forward(ctx, x):
117 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
118 | dist.all_gather(output, x)
119 | return tuple(output)
120 |
121 | @staticmethod
122 | def backward(ctx, *grads):
123 | all_gradients = torch.stack(grads)
124 | dist.all_reduce(all_gradients)
125 | return all_gradients[dist.get_rank()]
126 |
127 |
128 | # copied from megavlt
129 | def gather_tensor_along_batch_with_backward(tensor, dim=0):
130 | world_size = get_world_size()
131 |
132 | if world_size < 2:
133 | return tensor
134 |
135 | tensor_list = GatherLayer.apply(tensor)
136 | tensor_list = torch.cat(tensor_list, dim=dim)
137 | return tensor_list
138 |
139 |
140 | @torch.no_grad()
141 | def gather_tensor_along_batch(tensor, dim=0):
142 | """
143 | Performs all_gather operation on the provided tensors.
144 | *** Warning ***: torch.distributed.all_gather has no gradient.
145 | """
146 | world_size = get_world_size()
147 |
148 | if world_size < 2:
149 | return tensor
150 |
151 | with torch.no_grad():
152 | tensor_list = []
153 |
154 | for _ in range(world_size):
155 | tensor_list.append(torch.zeros_like(tensor))
156 |
157 | dist.all_gather(tensor_list, tensor)
158 | tensor_list = torch.cat(tensor_list, dim=dim)
159 | return tensor_list
160 |
--------------------------------------------------------------------------------
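GatherLayer above is the piece that allows gradients to flow through an all_gather (useful for, e.g., contrastive-style losses). Below is a minimal sketch of how it can be exercised, assuming a standalone file such as gather_demo.py run from the repo root under torchrun; it is not part of the repo.

# torchrun --nproc_per_node=2 gather_demo.py
import torch
import torch.distributed as dist
from utils.distributed import gather_tensor_along_batch_with_backward

dist.init_process_group(backend="gloo")      # torchrun supplies rank/world-size env vars
rank = dist.get_rank()

local = torch.full((2, 4), float(rank), requires_grad=True)
gathered = gather_tensor_along_batch_with_backward(local)   # shape (2 * world_size, 4)
gathered.sum().backward()                    # gradients flow back to the local shard
print(rank, gathered.shape, local.grad[0, 0].item())

dist.destroy_process_group()
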
/utils/easydict.py:
--------------------------------------------------------------------------------
1 | class EasyDict(dict):
2 | """
3 | Get attributes
4 |
5 | >>> d = EasyDict({'foo':3})
6 | >>> d['foo']
7 | 3
8 | >>> d.foo
9 | 3
10 | >>> d.bar
11 | Traceback (most recent call last):
12 | ...
13 | AttributeError: 'EasyDict' object has no attribute 'bar'
14 |
15 | Works recursively
16 |
17 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
18 | >>> isinstance(d.bar, dict)
19 | True
20 | >>> d.bar.x
21 | 1
22 |
23 | Bullet-proof
24 |
25 | >>> EasyDict({})
26 | {}
27 | >>> EasyDict(d={})
28 | {}
29 | >>> EasyDict(None)
30 | {}
31 | >>> d = {'a': 1}
32 | >>> EasyDict(**d)
33 | {'a': 1}
34 |
35 | Set attributes
36 |
37 | >>> d = EasyDict()
38 | >>> d.foo = 3
39 | >>> d.foo
40 | 3
41 | >>> d.bar = {'prop': 'value'}
42 | >>> d.bar.prop
43 | 'value'
44 | >>> d
45 | {'foo': 3, 'bar': {'prop': 'value'}}
46 | >>> d.bar.prop = 'newer'
47 | >>> d.bar.prop
48 | 'newer'
49 |
50 |
51 | Values extraction
52 |
53 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
54 | >>> isinstance(d.bar, list)
55 | True
56 | >>> from operator import attrgetter
57 | >>> map(attrgetter('x'), d.bar)
58 | [1, 3]
59 | >>> map(attrgetter('y'), d.bar)
60 | [2, 4]
61 | >>> d = EasyDict()
62 | >>> d.keys()
63 | []
64 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
65 | >>> d.foo
66 | 3
67 | >>> d.bar.x
68 | 1
69 |
70 | Still like a dict though
71 |
72 | >>> o = EasyDict({'clean':True})
73 | >>> o.items()
74 | [('clean', True)]
75 |
76 | And like a class
77 |
78 | >>> class Flower(EasyDict):
79 | ... power = 1
80 | ...
81 | >>> f = Flower()
82 | >>> f.power
83 | 1
84 | >>> f = Flower({'height': 12})
85 | >>> f.height
86 | 12
87 | >>> f['power']
88 | 1
89 | >>> sorted(f.keys())
90 | ['height', 'power']
91 |
92 | update and pop items
93 | >>> d = EasyDict(a=1, b='2')
94 | >>> e = EasyDict(c=3.0, a=9.0)
95 | >>> d.update(e)
96 | >>> d.c
97 | 3.0
98 | >>> d['c']
99 | 3.0
100 | >>> d.get('c')
101 | 3.0
102 | >>> d.update(a=4, b=4)
103 | >>> d.b
104 | 4
105 | >>> d.pop('a')
106 | 4
107 | >>> d.a
108 | Traceback (most recent call last):
109 | ...
110 | AttributeError: 'EasyDict' object has no attribute 'a'
111 | """
112 |
113 | def __init__(self, d=None, **kwargs):
114 | if d is None:
115 | d = {}
116 | if kwargs:
117 | d.update(**kwargs)
118 | for k, v in d.items():
119 | setattr(self, k, v)
120 | # Class attributes
121 | for k in self.__class__.__dict__.keys():
122 | if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
123 | setattr(self, k, getattr(self, k))
124 |
125 | def __setattr__(self, name, value):
126 | if isinstance(value, (list, tuple)):
127 | value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
128 | elif isinstance(value, dict) and not isinstance(value, self.__class__):
129 | value = self.__class__(value)
130 | super(EasyDict, self).__setattr__(name, value)
131 | super(EasyDict, self).__setitem__(name, value)
132 |
133 | __setitem__ = __setattr__
134 |
135 | def update(self, e=None, **f):
136 | d = e or dict()
137 | d.update(f)
138 | for k in d:
139 | setattr(self, k, d[k])
140 |
141 | def pop(self, k, d=None):
142 | if hasattr(self, k):
143 | delattr(self, k)
144 | return super(EasyDict, self).pop(k, d)
145 |
146 |
147 | if __name__ == "__main__":
148 | import doctest
149 |     doctest.testmod()
150 | 
--------------------------------------------------------------------------------
/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | """ Scheduler Factory
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | from torch.optim import Optimizer
5 | import math
6 | from torch.optim.lr_scheduler import LambdaLR
7 |
8 |
9 | def create_scheduler(args, optimizer):
10 | lr_scheduler = None
11 | if args.sched == 'cosine':
12 | lr_scheduler = get_cosine_schedule_with_warmup(
13 | optimizer,
14 | num_warmup_steps=args.num_warmup_steps,
15 | num_training_steps=args.num_training_steps,
16 | num_cycles=0.5,
17 | min_lr_multi=args.min_lr_multi
18 | )
19 | return lr_scheduler
20 |
21 |
22 | def get_cosine_schedule_with_warmup(
23 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int,
24 | num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1
25 | ):
26 | """
27 | Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py
28 |
29 | Create a schedule with a learning rate that decreases following the values of the cosine function between the
30 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
31 | initial lr set in the optimizer.
32 | Args:
33 | optimizer ([`~torch.optim.Optimizer`]):
34 | The optimizer for which to schedule the learning rate.
35 | num_warmup_steps (`int`):
36 | The number of steps for the warmup phase.
37 | num_training_steps (`int`):
38 | The total number of training steps.
39 | num_cycles (`float`, *optional*, defaults to 0.5):
40 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
41 | following a half-cosine).
42 | min_lr_multi (`float`, *optional*, defaults to 0):
43 | The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi.
44 | last_epoch (`int`, *optional*, defaults to -1):
45 | The index of the last epoch when resuming training.
46 | Return:
47 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
48 | """
49 |
50 | def lr_lambda(current_step):
51 | if current_step < num_warmup_steps:
52 | return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps)))
53 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
54 | return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
55 |
56 | return LambdaLR(optimizer, lr_lambda, last_epoch)
57 |
--------------------------------------------------------------------------------
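To make the warmup-plus-cosine behaviour documented above concrete, here is a small self-contained check (not part of the repo): with num_warmup_steps=10, num_training_steps=100, and min_lr_multi=0.01, the learning rate ramps up linearly for 10 steps and then follows a half-cosine down to base_lr * 0.01.

# Illustrative only: trace the LR produced by get_cosine_schedule_with_warmup.
import torch
from utils.scheduler import get_cosine_schedule_with_warmup

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1.0)     # lr=1.0 so the printed values equal the multiplier
sched = get_cosine_schedule_with_warmup(opt, num_warmup_steps=10,
                                        num_training_steps=100, min_lr_multi=0.01)

for step in range(100):
    opt.step()
    sched.step()
    if step in (0, 9, 54, 99):
        print(step, round(sched.get_last_lr()[0], 4))   # 0.1, 1.0, 0.5, 0.01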