├── .gitignore
├── LICENSE
├── README.md
├── config
│   └── data
│       ├── pretrain_mixed.json
│       ├── pretrain_r2r.json
│       ├── pretrain_reverie.json
│       └── pretrain_rxr.json
├── data
│   ├── __init__.py
│   ├── common.py
│   ├── dataset.py
│   ├── loader.py
│   ├── r2r_data.py
│   └── r2r_tasks.py
├── demo_r2r.py
├── docker
│   ├── Dockerfile
│   └── installation_guide_wo_docker.txt
├── engine_finetune.py
├── eval_speaker.py
├── exps
│   └── finetune.sh
├── gradio_app.py
├── images
│   └── c-instructor.png
├── landmark
│   ├── extract_landmark_r2r.py
│   ├── extract_landmark_reverie.py
│   ├── extract_landmark_rxr.py
│   └── select_eng_rxr.py
├── llama
│   ├── __init__.py
│   ├── llama.py
│   ├── llama_adapter.py
│   ├── tokenizer.py
│   └── utils.py
├── main_finetune.py
├── preprocess
│   ├── build_image_lmdb.py
│   ├── precompute_img_features_clip.py
│   └── utils.py
├── pycocoevalcap
│   ├── __init__.py
│   ├── bleu
│   │   ├── LICENSE
│   │   ├── __init__.py
│   │   ├── bleu.py
│   │   └── bleu_scorer.py
│   ├── cider
│   │   ├── __init__.py
│   │   ├── cider.py
│   │   └── cider_scorer.py
│   ├── clip_tokenizer
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   └── tokenization_clip.py
│   ├── eval.py
│   ├── meteor
│   │   ├── __init__.py
│   │   ├── data
│   │   │   └── paraphrase-en.gz
│   │   ├── meteor-1.5.jar
│   │   └── meteor.py
│   ├── rouge
│   │   ├── __init__.py
│   │   └── rouge.py
│   ├── spice
│   │   ├── __init__.py
│   │   ├── spice-1.0.jar
│   │   └── spice.py
│   ├── tokenizer
│   │   ├── __init__.py
│   │   ├── ptbtokenizer.py
│   │   └── stanford-corenlp-3.4.1.jar
│   └── utils.py
├── reduce_checkpoint.py
├── requirements.txt
└── util
    ├── bleu.py
    ├── extract_adapter_from_checkpoint.py
    ├── lr_sched.py
    └── misc.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | .DS_Store
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 | logs/
133 | *.c
134 | *.so
135 | .idea
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # C-Instructor: Controllable Navigation Instruction Generation with Chain of Thought Prompting
2 |
3 | Official implementation of the **ECCV 2024** paper **Controllable Navigation Instruction Generation with Chain of Thought Prompting** [[Link]](https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/04155.pdf).
4 |
5 | ![C-Instructor](images/c-instructor.png)
6 |
7 | ## News
8 |
9 | - 12/16/2024: Initial release 🎉🎉🎉.
10 |
11 | ## Setup
12 |
13 | We recommend using our [Dockerfile](docker/Dockerfile) to set up the environment. If you encounter any issues, please refer to [Matterport3D Simulator](https://github.com/peteanderson80/Matterport3DSimulator).
14 |
15 | ### Prerequisites
16 |
17 | - Nvidia GPU with driver >= 396.37
18 | - Install [docker](https://docs.docker.com/engine/installation/)
19 | - Install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
20 | - Note: CUDA / cuDNN toolkits do not need to be installed (these are provided by the docker image)
21 |
22 | ### Clone Repo
23 |
24 | Clone the Matterport3D Simulator repository:
25 |
26 | ```bash
27 | # Make sure to clone with --recursive
28 | git clone --recursive https://github.com/peteanderson80/Matterport3DSimulator.git
29 | cd Matterport3DSimulator
30 | ```
31 |
32 | If you didn't clone with the `--recursive` flag, then you'll need to manually clone the pybind submodule from the top-level directory:
33 |
34 | ```bash
35 | git submodule update --init --recursive
36 | ```
37 |
38 | ### Dataset Download
39 |
40 | To use the simulator you must first download the [Matterport3D Dataset](https://niessner.github.io/Matterport/), which is available after requesting access [here](https://niessner.github.io/Matterport/). The provided download script allows you to download selected data types. At a minimum you must download `matterport_skybox_images` and `undistorted_camera_parameters`. If you wish to use depth outputs, also download `undistorted_depth_images` (not required for C-Instructor).
41 |
42 | Set an environment variable to the location of the **unzipped** dataset, where `<PATH>` is the full absolute path (not a relative path or symlink) to the directory containing the individual matterport scan directories (17DRP5sb8fy, 2t7WUuJeko7, etc.):
43 |
44 | ```bash
45 | export MATTERPORT_DATA_DIR=<PATH>
46 | ```
47 |
48 | Note that if `<PATH>` is a remote sshfs mount, you will need to mount it with the `-o allow_root` option or the docker container won't be able to access this directory.
49 |
50 | ### Building using Docker
51 |
52 | Build the docker image:
53 |
54 | ```bash
55 | docker build -t mattersim:9.2-devel-ubuntu20.04 .
56 | ```
57 |
58 | Run the docker container, mounting both the git repo and the dataset:
59 |
60 | ```bash
61 | docker run -it --mount type=bind,source=$MATTERPORT_DATA_DIR,target=/root/mount/Matterport3DSimulator/data/v1/scans --volume {ACTUAL_PATH}:/root/mount/{XXX} mattersim:9.2-devel-ubuntu20.04
62 | ```
63 |
64 | Now (from inside the docker container), build the simulator code:
65 |
66 | ```bash
67 | cd /root/mount/Matterport3DSimulator
68 | mkdir build && cd build
69 | cmake -DEGL_RENDERING=ON ..
70 | make
71 | cd ../
72 | ```
73 |
74 | #### Rendering Options (GPU, CPU, off-screen)
75 |
76 | Note that there are three rendering options, which are selected using [cmake](https://cmake.org/) options during the build process (by varying line 3 in the build commands immediately above):
77 |
78 | - GPU rendering using OpenGL (requires an X server): `cmake ..` (default)
79 | - Off-screen GPU rendering using [EGL](https://www.khronos.org/egl/): `cmake -DEGL_RENDERING=ON ..`
80 | - Off-screen CPU rendering using [OSMesa](https://www.mesa3d.org/osmesa.html): `cmake -DOSMESA_RENDERING=ON ..`
81 |
82 | The recommended (fast) approach for training agents is using off-screen GPU rendering (EGL).
83 |
84 | ### Dataset Preprocessing
85 |
86 | To make data loading faster and to reduce memory usage we preprocess the `matterport_skybox_images` by downscaling and combining all cube faces into a single image. While still inside the docker container, run the following script:
87 |
88 | ```bash
89 | ./scripts/downsize_skybox.py
90 | ```
91 |
92 | This will take a while depending on the number of processes used (which is a setting in the script).
93 |
94 | After completion, the `matterport_skybox_images` subdirectories in the dataset will contain image files with filename format `<PANO_ID>_skybox_small.jpg`. By default, images are downscaled by 50% and 20 processes are used.
95 |
96 | #### Depth Outputs (Not Required for C-Instructor)
97 |
98 | If you need depth outputs as well as RGB (via `sim.setDepthEnabled(True)`), precompute matching depth skybox images by running this script:
99 |
100 | ```bash
101 | ./scripts/depth_to_skybox.py
102 | ```
103 |
104 | Depth skyboxes are generated from the `undistorted_depth_images` using a simple blending approach. As the depth images contain many missing values (corresponding to shiny, bright, transparent, and distant surfaces, which are common in the dataset), we apply a simple cross-bilateral filter based on the [NYUv2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) code to fill all but the largest holes. A couple of things to keep in mind:
105 |
106 | - We assume that the `undistorted_depth_images` are aligned to the `matterport_skybox_images`, but in fact this alignment is not perfect. For certain applications where better alignment is required (e.g., generating RGB pointclouds) it might be necessary to replace the `matterport_skybox_images` by stitching together `undistorted_color_images` (which are perfectly aligned to the `undistorted_depth_images`).
107 | - In the generated depth skyboxes, the depth value is the euclidean distance from the camera center (not the distance in the z direction). This is corrected by the simulator (see Simulator API, below).
108 |
109 | ### Running Tests
110 |
111 | Now (still from inside the docker container), run the unit tests:
112 |
113 | ```bash
114 | ./build/tests ~Timing
115 | ```
116 |
117 | Assuming all tests pass, `sim_imgs` will now contain some test images rendered by the simulator. You may also wish to test the rendering frame rate. The following command will try to load all the Matterport environments into memory (requiring around 50 GB memory), and then some information about the rendering frame rate (at 640x480 resolution, RGB outputs only) will be printed to stdout:
118 |
119 | ```bash
120 | ./build/tests Timing
121 | ```
122 |
123 | The timing test must be run separately from the other tests to get accurate results. Note that the Timing test will fail if there is insufficient memory. As long as all the other tests pass (i.e., `./build/tests ~Timing`), the install is good. Refer to the [Catch](https://github.com/philsquared/Catch) documentation for unit test configuration options.
124 |
125 | ### Precompute Image Features
126 |
127 | Copy the `preprocess` folder to `Matterport3DSimulator/tasks` and use `precompute_img_features_clip.py` to extract CLIP features, as sketched below.
128 |
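The exact flags of `precompute_img_features_clip.py` are not shown here, so the following is only a minimal sketch of the underlying idea: encode each discretized view with CLIP ViT-L/14 (768-d features, matching `image_feat_size=768` in `demo_r2r.py`) and store them in an HDF5 file such as the `vit_l_14_clip.hdf5` referenced in `config/data/*.json`. The output path, the `scan_viewpointid` keying, and the way view images are enumerated are assumptions.

```python
# Minimal sketch (not the repo's exact script): encode view images with CLIP ViT-L/14
# and store 768-d features in an HDF5 file keyed by "scan_viewpointid" (assumed layout).
import clip   # installed from https://github.com/openai/CLIP in the Dockerfile
import h5py
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)

def encode_views(image_paths):
    """Return an [N, 768] float32 array of CLIP image features."""
    imgs = torch.stack([preprocess(Image.open(p)) for p in image_paths]).to(device)
    with torch.no_grad():
        feats = model.encode_image(imgs)
    return feats.float().cpu().numpy()

# Hypothetical mapping "scan_viewpointid" -> list of discretized view images;
# the real script walks the viewpoints through the simulator instead.
view_paths_per_vp = {}
with h5py.File("img_features/vit_l_14_clip.hdf5", "w") as f:
    for key, paths in view_paths_per_vp.items():
        f.create_dataset(key, data=encode_views(paths))
```
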
129 | ### Pre-trained LLaMA Weights
130 |
131 | Obtain the LLaMA backbone weights using [this form](https://forms.gle/jk851eBVbX1m5TAv5). Please note that checkpoints from unofficial sources (e.g., BitTorrent) may contain malicious code and should be used with care. Organize the downloaded files in the following structure:
132 |
133 | ```
134 | /path/to/llama_model_weights
135 | ├── 7B
136 | │ ├── checklist.chk
137 | │ ├── consolidated.00.pth
138 | │ └── params.json
139 | └── tokenizer.model
140 | ```
141 |
142 | ### LLaMA Adapter Weights
143 |
144 | The weights of LLaMA Adapter can be obtained through [Github Release](https://github.com/OpenGVLab/LLaMA-Adapter/releases/tag/v.2.0.0).
145 |
146 | ### Data Preparation
147 |
148 | Download the annotations from HAMT [Dropbox](https://www.dropbox.com/sh/3a5j03u286px604/AABNp887W7_Fhgv13gUt4wzda?dl=0).
149 |
150 | ## Landmark Extraction
151 |
152 | Extract landmarks using scripts under `landmark`.
153 |
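The `landmark/extract_landmark_*.py` scripts produce the landmark-annotated JSONL files referenced in `config/data/*.json`. Purely as a hypothetical illustration of the kind of output they yield (the scripts' actual extraction procedure may differ, e.g., it could be LLM-based), one simple way to pull candidate landmarks from an instruction is noun-chunk extraction:

```python
# Hypothetical illustration only -- NOT the repo's actual landmark-extraction method.
# Shows the expected shape of the result: a list of landmark phrases per instruction,
# written back to JSONL. spaCy and the "instruction"/"landmarks" field names are assumed.
import json
import spacy  # assumes `python -m spacy download en_core_web_sm` has been run

nlp = spacy.load("en_core_web_sm")

def extract_landmarks(instruction: str):
    """Return candidate landmark phrases (noun chunks) mentioned in an instruction."""
    return [chunk.text.lower() for chunk in nlp(instruction).noun_chunks]

with open("train.jsonl") as fin, open("train_landmark.jsonl", "w") as fout:
    for line in fin:
        item = json.loads(line)
        item["landmarks"] = extract_landmarks(item["instruction"])
        fout.write(json.dumps(item) + "\n")
```
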
154 | ## Training
155 |
156 | ### Pre-training
157 |
158 | We pre-train the model on the PREVALENT dataset using the following command until convergence:
159 |
160 | ```bash
161 | bash exps/finetune.sh {path_to_llama}/LLaMA-7B/ {path_to_llama_adapter}/7fa55208379faf2dd862565284101b0e4a2a72114d6490a95e432cf9d9b6c813_BIAS-7B.pth config/data/pretrain_r2r.json {results_dir}
162 | ```
163 |
164 | Note that you will need to specify the arguments in `exps/finetune.sh` and `config/data/pretrain_r2r.json`.
165 |
166 | ### Fine-tuning
167 |
168 | We fine-tune the model on other VLN datasets using the following command until convergence:
169 |
170 | ```bash
171 | bash exps/finetune.sh {path_to_llama}/LLaMA-7B/ {path_to_ckpts}/{filename}-7B.pth config/data/pretrain_{dataset_name}.json {results_dir}
172 | ```
173 |
174 | Note that you will need to specify the arguments in `exps/finetune.sh` and `config/data/pretrain_{dataset_name}.json`.
175 |
176 | ## Inference
177 |
178 | Please refer to `demo_r2r.py` for inference and navigation path visualization.
179 |
180 | ## Evaluation
181 |
182 | Please refer to `pycocoevalcap/eval.py` for evaluation. To run the evaluation script, install Java and download the required Stanford CoreNLP models according to [this link](https://github.com/tylin/coco-caption/blob/master/get_stanford_models.sh).
183 |
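The bundled `pycocoevalcap` modules follow the standard coco-caption scorer interface. Below is a minimal sketch of scoring generated instructions against references, assuming the `id2path_*_val_seen.json` file written by `demo_r2r.py` (with `"inference"` and `"gt"` fields); `eval.py` may apply tokenization and report a different metric set.

```python
# Minimal sketch using the bundled coco-caption style scorers; the repo's eval.py may
# differ (e.g., PTB tokenization, METEOR/SPICE). Result path below is an assumption.
import json

from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.rouge.rouge import Rouge

with open("results_r2r/id2path_r2r_val_seen.json") as f:
    id2path = json.load(f)

# Scorers expect {id: [hypothesis]} and {id: [reference, ...]} dicts of strings.
res = {pid: [item["inference"]] for pid, item in id2path.items()}
gts = {pid: item["gt"] for pid, item in id2path.items()}

for name, scorer in [("Bleu", Bleu(4)), ("CIDEr", Cider()), ("ROUGE_L", Rouge())]:
    score, _ = scorer.compute_score(gts, res)
    print(name, score)
```
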
184 | ## Citation
185 |
186 | If you are using C-Instructor for your research, please cite the following paper:
187 |
188 | ```bibtex
189 | @inproceedings{kong2025controllable,
190 | title={Controllable navigation instruction generation with chain of thought prompting},
191 | author={Kong, Xianghao and Chen, Jinyu and Wang, Wenguan and Su, Hang and Hu, Xiaolin and Yang, Yi and Liu, Si},
192 | booktitle={European Conference on Computer Vision},
193 | pages={37--54},
194 | year={2025},
195 | organization={Springer}
196 | }
197 | ```
198 |
199 | ## Acknowledgements
200 |
201 | This project is built upon [LLaMA-Adapter](https://github.com/OpenGVLab/LLaMA-Adapter/tree/main/llama_adapter_v2_multimodal7b), [Matterport3D Simulator](https://github.com/peteanderson80/Matterport3DSimulator), [HAMT](https://github.com/cshizhe/VLN-HAMT), and [Microsoft COCO Caption Evaluation](https://github.com/tylin/coco-caption).
202 |
--------------------------------------------------------------------------------
/config/data/pretrain_mixed.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_datasets": {
3 | "R2R": {
4 | "name": "R2R",
5 | "train_traj_files": ["/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/pretrain/train_landmark_vis.jsonl",
6 | "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/RxR/data/pretrain/rxr_train_guide_landmark_vis_score.jsonl",
7 | "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/R2R/data/pretrain/train_landmark_vis_score.jsonl"
8 | ],
9 | "val_seen_traj_files": ["/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/pretrain/val_seen_landmark_vis.jsonl",
10 | "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/RxR/data/pretrain/rxr_val_seen_guide_landmark_vis_score.jsonl",
11 | "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/R2R/data/pretrain/val_seen_landmark_vis_score.jsonl"],
12 | "val_unseen_traj_files": ["/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/pretrain/val_unseen_landmark_vis.jsonl",
13 | "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/RxR/data/pretrain/rxr_val_unseen_guide_landmark_vis_score.jsonl",
14 | "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/R2R/data/pretrain/val_unseen_landmark_vis_score.jsonl"],
15 | "img_ft_file": "/data/user/kxh/instructllm/Matterport3DSimulator/img_features/vit_l_14_clip.hdf5",
16 | "scanvp_cands_file": "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/R2R/data/pretrain/scanvp_candview_relangles.json",
17 | "connectivity_dir": "/data/user/kxh/instructllm/Matterport3DSimulator/connectivity",
18 | "bboxes_file": "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/BBoxes.json",
19 | "tasks": [
20 | "sap",
21 | "itm",
22 | "lmp"
23 | ],
24 | "mix_ratio": [
25 | 4,
26 | 1,
27 | 1
28 | ]
29 | }
30 | }
31 | }
32 |
--------------------------------------------------------------------------------
/config/data/pretrain_r2r.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_datasets": {
3 | "R2R": {
4 | "name": "R2R",
5 | "train_traj_files": ["/data/user/kxh/instructllm/Matterport3DSimulator/tasks/R2R/data/pretrain/train_landmark_vis_score.jsonl",
6 | "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/R2R/data/pretrain/train_prevalent_generated_landmark.jsonl"],
7 | "val_seen_traj_files": ["/data/user/kxh/instructllm/Matterport3DSimulator/tasks/R2R/data/pretrain/val_seen_landmark_vis_score.jsonl"],
8 | "val_unseen_traj_files": ["/data/user/kxh/instructllm/Matterport3DSimulator/tasks/R2R/data/pretrain/val_unseen_landmark_vis_score.jsonl"],
9 | "img_ft_file": "/data/user/kxh/instructllm/Matterport3DSimulator/img_features/vit_l_14_clip.hdf5",
10 | "scanvp_cands_file": "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/R2R/data/pretrain/scanvp_candview_relangles.json",
11 | "connectivity_dir": "/data/user/kxh/instructllm/Matterport3DSimulator/connectivity",
12 | "bboxes_file": "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/BBoxes.json",
13 | "tasks": [
14 | "sap",
15 | "itm",
16 | "lmp"
17 | ],
18 | "mix_ratio": [
19 | 4,
20 | 1,
21 | 1
22 | ]
23 | }
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/config/data/pretrain_reverie.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_datasets": {
3 | "R2R": {
4 | "name": "R2R",
5 | "train_traj_files": ["/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/pretrain/train_landmark_vis.jsonl"
6 | ],
7 | "val_seen_traj_files": ["/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/pretrain/val_seen_landmark_vis.jsonl"],
8 | "val_unseen_traj_files": ["/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/pretrain/val_unseen_landmark_vis.jsonl"],
9 | "img_ft_file": "/data/user/kxh/instructllm/Matterport3DSimulator/img_features/vit_l_14_clip.hdf5",
10 | "scanvp_cands_file": "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/pretrain/scanvp_candview_relangles.json",
11 | "connectivity_dir": "/data/user/kxh/instructllm/Matterport3DSimulator/connectivity",
12 | "bboxes_file": "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/BBoxes.json",
13 | "tasks": [
14 | "sap",
15 | "itm",
16 | "lmp"
17 | ],
18 | "mix_ratio": [
19 | 4,
20 | 1,
21 | 1
22 | ]
23 | }
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/config/data/pretrain_rxr.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_datasets": {
3 | "R2R": {
4 | "name": "R2R",
5 | "train_traj_files": ["/data/user/kxh/instructllm/Matterport3DSimulator/tasks/RxR/data/pretrain/rxr_train_guide_landmark.jsonl"],
6 | "val_seen_traj_files": ["/data/user/kxh/instructllm/Matterport3DSimulator/tasks/RxR/data/pretrain/rxr_val_seen_guide_landmark.jsonl"],
7 | "val_unseen_traj_files": ["/data/user/kxh/instructllm/Matterport3DSimulator/tasks/RxR/data/pretrain/rxr_val_unseen_guide_landmark.jsonl"],
8 | "img_ft_file": "/data/user/kxh/instructllm/Matterport3DSimulator/img_features/vit_l_14_clip.hdf5",
9 | "scanvp_cands_file": "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/R2R/data/pretrain/scanvp_candview_relangles.json",
10 | "connectivity_dir": "/data/user/kxh/instructllm/Matterport3DSimulator/connectivity",
11 | "bboxes_file": "/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/BBoxes.json",
12 | "tasks": [
13 | "sap",
14 | "itm",
15 | "lmp"
16 | ],
17 | "mix_ratio": [
18 | 4,
19 | 1,
20 | 1
21 | ]
22 | }
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .r2r_data import MultiStepNavData
2 |
3 | from .r2r_tasks import (
4 | MlmDataset, mlm_collate,
5 | SapDataset, sap_collate,
6 | SarDataset, sar_collate,
7 | SprelDataset, sprel_collate,
8 | MrcDataset, mrc_collate,
9 | ItmDataset, itm_collate,
10 | LmpDataset, lmp_collate
11 | )
12 |
13 | from .loader import PrefetchLoader, MetaLoader, build_dataloader
14 |
--------------------------------------------------------------------------------
/data/common.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def pad_tensors(tensors, lens=None, pad=0):
6 | """B x [T, ...]"""
7 | if lens is None:
8 | lens = [t.size(0) for t in tensors]
9 | max_len = max(lens)
10 | bs = len(tensors)
11 | hid = list(tensors[0].size()[1:])
12 | size = [bs, max_len] + hid
13 |
14 | dtype = tensors[0].dtype
15 | output = torch.zeros(*size, dtype=dtype)
16 | if pad:
17 | output.data.fill_(pad)
18 | for i, (t, l) in enumerate(zip(tensors, lens)):
19 | output.data[i, :l, ...] = t.data
20 | return output
21 |
22 |
23 | def gen_seq_masks(seq_lens, max_len=None):
24 | seq_lens = np.array(seq_lens)
25 | if max_len is None:
26 | max_len = max(seq_lens)
27 | batch_size = len(seq_lens)
28 | masks = np.arange(max_len).reshape(-1, max_len).repeat(batch_size, 0)
29 | masks = masks < seq_lens.reshape(-1, 1)
30 | return masks
31 |
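A quick usage sketch of the two helpers above, assuming the module is importable as `data.common` (not part of the original file):

```python
# Usage sketch for pad_tensors / gen_seq_masks.
import torch
from data.common import pad_tensors, gen_seq_masks

feats = [torch.randn(3, 768), torch.randn(5, 768)]  # two histories of different lengths
padded = pad_tensors(feats)                          # -> tensor of shape [2, 5, 768], zero-padded
masks = gen_seq_masks([3, 5])                        # -> bool array of shape [2, 5]
print(padded.shape, masks.shape)
```
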
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import os
4 | import random
5 |
6 | import cv2
7 | import pandas as pd
8 | import torch
9 | import torchvision.transforms as transforms
10 | import yaml
11 | from torch.utils.data import Dataset
12 | from PIL import Image
13 |
14 | import llama.utils
15 | from llama import Tokenizer
16 |
17 | try:
18 | from torchvision.transforms import InterpolationMode
19 | BICUBIC = InterpolationMode.BICUBIC
20 | except ImportError:
21 | BICUBIC = Image.BICUBIC
22 |
23 |
24 | PROMPT_DICT = {
25 | "prompt_input": (
26 | "Below is an instruction that describes a task, paired with an input that provides further context. "
27 | "Write a response that appropriately completes the request.\n\n"
28 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
29 | ),
30 | "prompt_no_input": (
31 | "Below is an instruction that describes a task. "
32 | "Write a response that appropriately completes the request.\n\n"
33 | "### Instruction:\n{instruction}\n\n### Response:"
34 | ),
35 | }
36 |
37 | # create data
38 | transform_train = transforms.Compose([
39 | transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=BICUBIC,
40 | antialias=None), # 3 is bicubic
41 | transforms.ToTensor(),
42 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
43 |
44 |
45 | class AssisterDataset(Dataset):
46 | pass
47 |
48 |
49 | class CaptionDataset(Dataset):
50 | def __init__(self, data_path, dataset_name='r2r', max_words=30, tokenizer_path=None, training=True):
51 | self.data_path = data_path
52 | self.dataset_name = dataset_name
53 | print(f"Load {self.dataset_name} dataset from {self.data_path}")
54 | self.ann = self.load_data(training=training)
55 | self.max_words = max_words
56 | self.tokenizer = Tokenizer(model_path=tokenizer_path)
57 |
58 | def load_data(self, training=True):
59 | split = 'train' if training else 'val_unseen'
60 | with open(os.path.join(self.data_path, f'path_caption_{self.dataset_name}_{split}.json')) as f:
61 | caption_data = json.load(f)
62 |
63 | anno = []
64 | for id, value in caption_data.items():
 65 |             for i, gt in enumerate(value['gt']):
 66 |                 # build a fresh dict per ground-truth caption; reusing a single dict
 67 |                 # would make every appended entry alias the same object
 68 |                 item_to_append = {'captions': value['captions'], 'id': id + '_' + str(i), 'gt': gt}
 69 |                 anno.append(item_to_append)
70 |
71 | return anno
72 |
73 | def __len__(self):
74 | return len(self.ann)
75 |
76 | def __getitem__(self, index):
77 | data_item = self.ann[index]
78 | id = data_item['id']
79 | gt = data_item['gt']
80 | captions = data_item['captions']
81 |
82 | image = torch.zeros(3, 224, 224)
83 |
84 | if self.dataset_name == 'reverie':
85 | format_instruction = (
86 | "You are given captions of a sequence of views of a path in an indoor environment separated by semicolons. "
87 | "Please generate a high-level target-oriented instruction briefly for an intelligent agent to follow. "
88 | "You should only output the instruction."
89 | )
90 | elif self.dataset_name == 'r2r':
91 | format_instruction = (
92 | "You are given captions of a sequence of views of a path in an indoor environment separated by semicolons. "
93 | "Please describe the path according to the given captions in details for an intelligent agent to follow."
94 | "You should only output the instruction."
95 | )
96 | else:
97 | raise NotImplementedError(f"dataset_name {self.dataset_name} not implemented")
98 | format_input = "; ".join(captions)
99 | input1 = llama.utils.format_prompt(format_instruction, format_input)
100 | ori_prompt = input1
101 | input2 = input1 + gt
102 | input1 = torch.tensor(self.tokenizer.encode(input1, bos=True, eos=False), dtype=torch.int64)
103 | input2 = torch.tensor(self.tokenizer.encode(input2, bos=True, eos=True), dtype=torch.int64)
104 | padding = self.max_words - input2.shape[0]
105 | if padding > 0:
106 | input2 = torch.cat((input2, torch.zeros(padding, dtype=torch.int64) - 1))
107 | elif padding < 0:
108 | input2 = input2[:self.max_words]
109 | labels = copy.deepcopy(input2)
110 | labels[:len(input1)] = -1
111 | input2_mask = input2.ge(0)
112 | label_mask = labels.ge(0)
113 | input2[~input2_mask] = 0
114 | labels[~label_mask] = 0
115 | input2_mask = input2_mask.float()
116 | label_mask = label_mask.float()
117 | return input2, labels, input2_mask, image, id, ori_prompt, gt
118 |
119 |
120 | class FinetuneDataset(Dataset):
121 | def __init__(self, config_path, transform, max_words=30, tokenizer_path=None):
122 | print(f"read dataset config from {config_path}")
123 | with open(config_path, 'r') as f:
124 | self.config = yaml.load(f, Loader=yaml.FullLoader)
125 | print("DATASET CONFIG:")
126 | print(self.config)
127 | ann = []
128 | for meta_path in self.config['META']:
129 | meta_l = json.load(open(meta_path))
130 | print(f"{meta_path}: len {len(meta_l)}")
131 | ann += meta_l
132 | self.ann = ann
133 | print(f"total length: {len(self)}")
134 | self.transform = transform
135 | self.max_words = max_words
136 | self.tokenizer = Tokenizer(model_path=tokenizer_path)
137 |
138 | def __len__(self):
139 | return len(self.ann)
140 |
141 | def __getitem__(self, index):
142 | data_item = self.ann[index]
143 | if 'image' in data_item.keys():
144 | filename = data_item['image']
145 | question = data_item['conversations'][0]['value']
146 | answer = data_item['conversations'][1]['value']
147 |
148 | image = cv2.imread(filename)
149 | image = Image.fromarray(image)
150 | image = self.transform(image)
151 | format_instruction = question
152 | format_input = None
153 | else:
154 | image = torch.zeros(3, 224, 224)
155 |             format_instruction = data_item['instruction']
156 | format_input = data_item['input']
157 | answer = data_item['output']
158 | input1 = llama.utils.format_prompt(format_instruction, format_input)
159 | input2 = input1 + answer
160 | input1 = torch.tensor(self.tokenizer.encode(input1, bos=True, eos=False), dtype=torch.int64)
161 | input2 = torch.tensor(self.tokenizer.encode(input2, bos=True, eos=True), dtype=torch.int64)
162 | padding = self.max_words - input2.shape[0]
163 | if padding > 0:
164 | input2 = torch.cat((input2, torch.zeros(padding, dtype=torch.int64) - 1))
165 | elif padding < 0:
166 | input2 = input2[:self.max_words]
167 | labels = copy.deepcopy(input2)
168 | labels[:len(input1)] = -1
169 | input2_mask = input2.ge(0)
170 | label_mask = labels.ge(0)
171 | input2[~input2_mask] = 0
172 | labels[~label_mask] = 0
173 | input2_mask = input2_mask.float()
174 | label_mask = label_mask.float()
175 | return input2, labels, input2_mask, image
176 |
177 |
178 | class PretrainDataset(Dataset):
179 | def __init__(self, config_path, transform, max_words=30, tokenizer_path=None):
180 | print(f"read dataset config from {config_path}")
181 | with open(config_path, 'r') as f:
182 | self.config = yaml.load(f, Loader=yaml.FullLoader)
183 | print("DATASET CONFIG:")
184 | print(self.config)
185 | images, captions = [], []
186 | for meta_path in self.config['META']:
187 | images_this_meta, captions_this_meta = [], []
188 | for chunk in pd.read_csv(meta_path, sep='\t', lineterminator='\n', chunksize=10 ** 6):
189 | images_this_meta.extend(chunk['url'].tolist())
190 | captions_this_meta.extend(chunk['caption'].tolist())
191 | print(f"{meta_path}: len {len(images_this_meta)}")
192 | images.extend(images_this_meta)
193 | captions.extend(captions_this_meta)
194 |
195 | self.data_list = []
196 | for x, y in zip(images, captions):
197 | self.data_list.append({'url': x, 'caption': y})
198 | print(f"total length: {len(self)}")
199 | self.transform = transform
200 | self.max_words = max_words
201 | self.tokenizer = Tokenizer(model_path=tokenizer_path)
202 |
203 | def __len__(self):
204 | return len(self.data_list)
205 |
206 | def __getitem__(self, index):
207 | sample = self.data_list[index]
208 | image_path, caption = sample['url'], sample['caption']
209 | if isinstance(caption, list):
210 | caption = random.choice(caption)
211 | caption = str(caption)
212 |
213 | image = cv2.imread(image_path)
214 | image = Image.fromarray(image)
215 | image = self.transform(image)
216 |
217 | format_instruction = "Generate caption of this image"
218 | input1 = llama.utils.format_prompt(format_instruction, None)
219 | input2 = input1 + caption
220 |
221 | input1 = torch.tensor(self.tokenizer.encode(input1, bos=True, eos=False), dtype=torch.int64)
222 | input2 = torch.tensor(self.tokenizer.encode(input2, bos=True, eos=True), dtype=torch.int64)
223 | padding = self.max_words - input2.shape[0]
224 | if padding > 0:
225 | input2 = torch.cat((input2, torch.zeros(padding, dtype=torch.int64) - 1))
226 | elif padding < 0:
227 | input2 = input2[:self.max_words]
228 | labels = copy.deepcopy(input2)
229 | labels[:len(input1)] = -1
230 | input2_mask = input2.ge(0)
231 | label_mask = labels.ge(0)
232 | input2[~input2_mask] = 0
233 | labels[~label_mask] = 0
234 | input2_mask = input2_mask.float()
235 | label_mask = label_mask.float()
236 | return input2, labels, input2_mask, image
237 |
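For reference, `FinetuneDataset` above reads a YAML config whose `META` entries point to JSON files of records; a minimal sketch of such a config and its use follows (file names, paths, and field values are placeholders, mirroring the two branches of `FinetuneDataset.__getitem__`):

```python
# Sketch of driving FinetuneDataset; all file names and paths here are placeholders.
import json
import yaml

from data.dataset import FinetuneDataset, transform_train

json.dump(
    [{"instruction": "Describe the path.", "input": "kitchen; hallway", "output": "Walk ..."}],
    open("alpaca_style.json", "w"),
)
yaml.safe_dump({"META": ["alpaca_style.json"]}, open("finetune_data.yaml", "w"))

dataset = FinetuneDataset(
    "finetune_data.yaml",
    transform=transform_train,
    max_words=30,
    tokenizer_path="/path/to/llama_model_weights/tokenizer.model",  # placeholder path
)
print(len(dataset))
```
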
--------------------------------------------------------------------------------
/data/loader.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | A prefetch loader to speedup data loading
6 | Modified from Nvidia Deep Learning Examples
7 | (https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch).
8 | """
9 | from typing import List, Dict, Tuple, Union, Iterator
10 |
11 | import torch
12 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
13 | from torch.utils.data.distributed import DistributedSampler
14 | import torch.distributed as dist
15 |
16 |
17 | class MetaLoader:
18 | """wraps multiple data loaders"""
19 |
20 | def __init__(
21 | self, loaders, accum_steps: int = 1, distributed: bool = False, device=None, num_iters=None
22 | ):
23 | assert isinstance(loaders, dict)
24 | self.name2loader = {}
25 | self.name2iter = {}
26 | self.name2pre_epoch = {}
27 | self.names: List[str] = []
28 | ratios: List[int] = []
29 | for n, l in loaders.items():
30 | if isinstance(l, tuple):
31 | l, r, p = l
32 | elif isinstance(l, DataLoader):
33 | r = 1
34 | def p(e): return None
35 | else:
36 | raise ValueError()
37 | self.names.append(n)
38 | self.name2loader[n] = l
39 | self.name2iter[n] = iter(l)
40 | self.name2pre_epoch[n] = p
41 | ratios.append(r)
42 |
43 | self.accum_steps = accum_steps
44 | self.device = device
45 | self.sampling_ratios = torch.tensor(ratios).float().to(self.device)
46 | self.distributed = distributed
47 | self.step = 0
48 |
49 | self.num_iters = num_iters
50 | self.epoch_id = 0
51 |
52 | def __len__(self):
53 | if self.num_iters is None:
54 | return sum(len(l) for l in self.name2loader.values())
55 | else:
56 | return self.num_iters
57 |
58 | def __iter__(self) -> Iterator[Tuple]:
59 | """this iterator will run indefinitely if num_iters is None"""
60 | task_id = None
61 | if self.num_iters is not None:
62 | for _ in range(self.num_iters):
63 | if self.step % self.accum_steps == 0:
64 | task_id = torch.multinomial(self.sampling_ratios, 1)
65 | if self.distributed:
66 | # make sure all process is training same task
67 | dist.broadcast(task_id, 0)
68 | self.step += 1
69 | task = self.names[task_id.cpu().item()]
70 | iter_ = self.name2iter[task]
71 | try:
72 | batch = next(iter_)
73 | except StopIteration:
74 | self.epoch_id += 1
75 | # In distributed mode, calling the set_epoch() method at the beginning of each epoch
76 | # before creating the DataLoader iterator is necessary to make shuffling work properly
77 | # across multiple epochs. Otherwise, the same ordering will be always used.
78 | self.name2pre_epoch[task](self.epoch_id)
79 | iter_ = iter(self.name2loader[task])
80 | batch = next(iter_)
81 | self.name2iter[task] = iter_
82 |
83 | # yield task, batch
84 | yield batch
85 | else:
86 | while True:
87 | if self.step % self.accum_steps == 0:
88 | task_id = torch.multinomial(self.sampling_ratios, 1)
89 | if self.distributed:
90 | # make sure all process is training same task
91 | dist.broadcast(task_id, 0)
92 | self.step += 1
93 | task = self.names[task_id.cpu().item()]
94 | iter_ = self.name2iter[task]
95 | try:
96 | batch = next(iter_)
97 | except StopIteration:
98 | self.epoch_id += 1
99 | # In distributed mode, calling the set_epoch() method at the beginning of each epoch
100 | # before creating the DataLoader iterator is necessary to make shuffling work properly
101 | # across multiple epochs. Otherwise, the same ordering will be always used.
102 | self.name2pre_epoch[task](self.epoch_id)
103 | iter_ = iter(self.name2loader[task])
104 | batch = next(iter_)
105 | self.name2iter[task] = iter_
106 |
107 | # yield task, batch
108 | yield batch
109 |
110 |
111 | def move_to_cuda(batch: Union[List, Tuple, Dict, torch.Tensor], device: torch.device):
112 | if isinstance(batch, torch.Tensor):
113 | return batch.to(device, non_blocking=True)
114 | elif isinstance(batch, list):
115 | return [move_to_cuda(t, device) for t in batch]
116 | elif isinstance(batch, tuple):
117 | return tuple(move_to_cuda(t, device) for t in batch)
118 | elif isinstance(batch, dict):
119 | return {n: move_to_cuda(t, device) for n, t in batch.items()}
120 | return batch
121 |
122 |
123 | class PrefetchLoader(object):
124 | """
125 | overlap compute and cuda data transfer
126 | """
127 |
128 | def __init__(self, loader, device: torch.device):
129 | self.loader = loader
130 | self.device = device
131 |
132 | def __iter__(self):
133 | loader_it = iter(self.loader)
134 | self.preload(loader_it)
135 | batch = self.next(loader_it)
136 | while batch is not None:
137 | yield batch
138 | batch = self.next(loader_it)
139 |
140 | def __len__(self):
141 | return len(self.loader)
142 |
143 | def preload(self, it):
144 | try:
145 | self.batch = next(it)
146 | except StopIteration:
147 | self.batch = None
148 | return
149 | self.batch = move_to_cuda(self.batch, self.device)
150 |
151 | def next(self, it):
152 | batch = self.batch
153 | self.preload(it)
154 | return batch
155 |
156 | def __getattr__(self, name):
157 | method = self.loader.__getattribute__(name)
158 | return method
159 |
160 |
161 | def build_dataloader(task, dataset, collate_fn, is_train: bool, opts):
162 |
163 | batch_size = opts.batch_size
164 | # if task == 'itm':
165 | # batch_size = batch_size // 2
166 |
167 | if opts.local_rank == -1:
168 | if is_train:
169 | sampler: Union[
170 | RandomSampler, SequentialSampler, DistributedSampler
171 | ] = RandomSampler(dataset)
172 | else:
173 | sampler = SequentialSampler(dataset)
174 |
175 | size = torch.cuda.device_count() if torch.cuda.is_available() else 1
176 | def pre_epoch(e): return None
177 |
178 | # DataParallel: scale the batch size by the number of GPUs
179 | if size > 1:
180 | batch_size *= size
181 |
182 | else:
183 | size = dist.get_world_size()
184 | sampler = DistributedSampler(
185 | dataset, num_replicas=size, rank=dist.get_rank(), shuffle=is_train
186 | )
187 | pre_epoch = sampler.set_epoch
188 |
189 | loader = DataLoader(
190 | dataset,
191 | sampler=sampler,
192 | batch_size=batch_size,
193 | num_workers=opts.num_workers,
194 | pin_memory=opts.pin_mem,
195 | collate_fn=collate_fn,
196 | drop_last=False,
197 | )
198 |
199 | return loader, pre_epoch
200 |
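A minimal sketch of how these pieces compose in single-process use (assumed wiring; `main_finetune.py` builds the real task loaders and options):

```python
# Sketch: wrap two task DataLoaders in MetaLoader with 2:1 sampling, then prefetch to the device.
import torch
from torch.utils.data import DataLoader, TensorDataset

from data.loader import MetaLoader, PrefetchLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def make_loader(n):
    return DataLoader(TensorDataset(torch.randn(n, 4)), batch_size=2, shuffle=True)

loaders = {
    "sap": (make_loader(32), 2, lambda epoch: None),  # (loader, sampling ratio, pre-epoch hook)
    "itm": (make_loader(16), 1, lambda epoch: None),
}
meta = MetaLoader(loaders, accum_steps=1, distributed=False, device=device, num_iters=10)
for batch in PrefetchLoader(meta, device):
    pass  # each batch comes from a task drawn according to the sampling ratios
```
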
--------------------------------------------------------------------------------
/demo_r2r.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 |
5 | import json
6 | import networkx as nx
7 | import numpy as np
8 | import torch
9 | from easydict import EasyDict
10 | from tqdm import tqdm
11 | from PIL import Image
12 |
13 | import MatterSim
14 |
15 | import llama
16 | from data import MultiStepNavData
17 | from r2r.data_utils import load_nav_graphs
18 | from main_finetune import create_dataloaders
19 |
20 |
21 | dataset_name = "r2r"
22 | llama_dir = "/data/user/kxh/instructllm/LLaMA-7B"
23 | data_config = f"config/data/pretrain_{dataset_name}.json"
24 | llama_tokenzier_path = os.path.join(llama_dir, "tokenizer.model")
25 | matterport_connectivity_dir = "/data/user/kxh/instructllm/Matterport3DSimulator/connectivity"
26 | matterport_img_dir = "/data/user/kxh/instructllm/Matterport3D/v1/scans"
27 |
28 |
29 | def parse_args():
30 | parser = argparse.ArgumentParser("llama_adapterV2 R2R demo", add_help=False)
31 | parser.add_argument(
32 | "--batch_size",
33 | default=1,
34 | type=int,
 35 |         help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
36 | )
37 | parser.add_argument("--num_workers", default=2, type=int)
38 | parser.add_argument(
39 | "--pin_mem",
40 | action="store_true",
41 | help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
42 | )
43 | parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
44 | parser.set_defaults(pin_mem=True)
45 | parser.add_argument("--ckpt_dir", default="results_r2r", type=str)
46 | parser.add_argument("--local_rank", default=-1, type=int)
47 | parser.add_argument("--max_words", default=384, type=int, help="max number of input words")
48 |
49 | args = parser.parse_args()
50 | return args
51 |
52 |
53 | def build_dataloader(args, device):
54 | dataset_cfg = json.load(open(data_config))
55 | r2r_cfg = EasyDict(dataset_cfg["train_datasets"]["R2R"])
56 | traj_files = r2r_cfg.val_seen_traj_files
57 | # traj_files = r2r_cfg.val_unseen_traj_files
58 | val_nav_db = MultiStepNavData(
59 | traj_files,
60 | r2r_cfg.img_ft_file,
61 | r2r_cfg.scanvp_cands_file,
62 | r2r_cfg.connectivity_dir,
63 | image_prob_size=0,
64 | image_feat_size=768,
65 | angle_feat_size=4,
66 | max_txt_len=args.max_words,
67 | max_act_len=100,
68 | hist_enc_pano=True,
69 | ob_cand_pano_view=False,
70 | val_sample_num=None,
71 | in_memory=True,
72 | tokenizer_path=llama_tokenzier_path,
73 | bboxes_file=r2r_cfg.bboxes_file,
74 | )
75 | val_dataloaders = create_dataloaders(r2r_cfg, val_nav_db, None, False, device, args)
76 | val_dataloader = val_dataloaders["itm"]
77 |
78 | return val_dataloader
79 |
80 |
81 | def build_simulator(connectivity_dir, scan_dir):
82 | sim = MatterSim.Simulator()
83 | sim.setNavGraphPath(connectivity_dir)
84 | sim.setDatasetPath(scan_dir)
85 | sim.setRenderingEnabled(True)
86 | sim.setDiscretizedViewingAngles(True)
87 | sim.setCameraResolution(640, 480)
88 | sim.setCameraVFOV(math.radians(60))
89 | sim.setBatchSize(1)
90 | sim.setPreloadingEnabled(True)
91 | sim.initialize()
92 | return sim
93 |
94 |
95 | def load_graphs(connectivity_dir):
96 | """
97 | load graph from scan,
98 | Store the graph {scan_id: graph} in graphs
99 | Store the shortest path {scan_id: {view_id_x: {view_id_y: [path]} } } in paths
100 | Store the distances in distances. (Structure see above)
101 | Load connectivity graph for each scan, useful for reasoning about shortest paths
102 | :return: graphs, paths, distances
103 | """
104 | with open(os.path.join(connectivity_dir, "scans.txt"), "r") as f:
105 | scans = [scan.strip() for scan in f.readlines()]
106 | print(f"Loading navigation graphs for {len(scans)} scans")
107 | graphs = load_nav_graphs(connectivity_dir, scans)
108 | shortest_paths = {}
109 | for scan, G in graphs.items(): # compute all shortest paths
110 | shortest_paths[scan] = dict(nx.all_pairs_dijkstra_path(G))
111 | shortest_distances = {}
112 | for scan, G in graphs.items(): # compute all shortest paths
113 | shortest_distances[scan] = dict(nx.all_pairs_dijkstra_path_length(G))
114 |
115 | return graphs, shortest_paths, shortest_distances
116 |
117 |
118 | def main(args):
119 | device = "cuda" if torch.cuda.is_available() else "cpu"
120 |
121 | val_dataloader = build_dataloader(args, device)
122 |
123 | # choose from BIAS-7B, LORA-BIAS-7B, CAPTION-7B.pth
124 | model, preprocess = llama.load(
125 | os.path.join(args.ckpt_dir, "checkpoint-7B.pth"),
126 | llama_dir,
127 | device,
128 | max_batch_size=args.batch_size,
129 | max_seq_len=args.max_words,
130 | )
131 | model.eval()
132 |
133 | # prompt = llama.format_prompt('You are given a sequence of views of a path. '
134 | # 'Please describe the path in details for an intelligent agent to follow. \n\n'
135 | # 'Sample description: Walk through the kitchen passed the stove and sink, turn right after the island and walk towards the couch. Turn left and the couch and walk towards the dining room table, stop before the table. \n'
136 | # 'Description: ')
137 | # prompt = llama.format_prompt('You are given a sequence of views of a path. '
138 | # 'Please describe the path in details for an intelligent agent to follow.')
139 | # prompt = llama.format_prompt('You are a navigator to navigate in an unseen environment. You need to follow the instruction "".'
140 | # 'The past trajectory is given. You don\'t know where to go now. Generate the question you need to ask.'
141 | # 'Question: ')
142 | dataset_to_landmark_prompt = {
143 | "r2r": "You are given a sequence of views of a path. Please extract critical landmarks in the path.",
144 | "reverie": "You are given a sequence of views of a path in an indoor environment. "
145 | "Please extract several critical landmarks in the path for generating a brief high-level target-oriented instruction.",
146 | "rxr": "You are given a sequence of views of a path in an indoor environment. "
147 | "Please extract critical landmarks describing the starting position and the path.",
148 | }
149 | prompt_landmark = llama.utils.format_prompt(dataset_to_landmark_prompt[dataset_name])
150 |
151 | dataset_to_prompt = {
152 | "r2r": "You are given a sequence of views of a path in an indoor environment. "
153 | "Please describe the path according to the given landmarks in details for an intelligent agent to follow.\n"
154 | "Landmarks: {}",
155 | "reverie": "You are given a sequence of views of a path in an indoor environment and critical landmarks for a brief high-level target-oriented instruction. "
156 | "Please generate the indicated high-level target-oriented instruction briefly for an intelligent agent to follow.\n"
157 | "Landmarks: {}",
158 | "rxr": "You are given a sequence of views of a path in an indoor environment. "
159 | "Please describe the starting position and the path according to the given landmarks in details for an intelligent agent to follow.\n"
160 | "Landmarks: {}",
161 | }
162 | prompt = llama.utils.format_prompt(dataset_to_prompt[dataset_name])
163 |
164 | id2path = {}
165 | # num_correct_gt = 0
166 | # num_distance_reduce = 0
167 |
168 | traj_img_dir = os.path.join(args.ckpt_dir, "../traj_img")
169 | os.makedirs(traj_img_dir, exist_ok=True)
170 |
171 | # img_size = Image.open(os.path.join(matterport_img_dir, '1LXtFkjw3qL/matterport_skybox_images/0b22fa63d0f54a529c525afbf2e8bb25_skybox_small.jpg')).size
172 | img_size = (640, 480)
173 | sim = build_simulator(matterport_connectivity_dir, matterport_img_dir)
174 |
175 | # nav_graphs, shortest_paths, shortest_distances = load_graphs(matterport_connectivity_dir)
176 |
177 | for batch in tqdm(val_dataloader):
178 | select_indexes = []
179 | for i in range(len(batch["path_id"])):
180 | path_id = batch["path_id"][i]
181 | if path_id in id2path:
182 | id2path[path_id]["gt"].append(batch["txt"][i])
183 | else:
184 | id2path[path_id] = {"gt": [batch["txt"][i]]}
185 | select_indexes.append(i)
186 |
187 | # select_indexes = list(range(len(batch['path_id'])))
188 |
189 | batch_size = len(select_indexes)
190 | if batch_size == 0:
191 | continue
192 |
193 | prompts = [prompt_landmark] * batch_size
194 | imgs = batch["hist_img_fts"][select_indexes]
195 | ang_feats = batch["hist_ang_fts"][select_indexes]
196 | pano_img_feats = None
197 | pano_ang_feats = None
198 | if "hist_pano_img_fts" in batch:
199 | pano_img_feats = batch["hist_pano_img_fts"][select_indexes]
200 | pano_ang_feats = batch["hist_pano_ang_fts"][select_indexes]
201 | ob_img_feats = None
202 | ob_ang_feats = None
203 | # ob_attn_mask = None
204 | ob_id_seps = None
205 |
206 | # prompts = batch['ori_prompt']
207 | # imgs = batch['hist_img_fts']
208 | # ang_feats = batch['hist_ang_fts']
209 | # pano_img_feats = None
210 | # pano_ang_feats = None
211 | # if 'hist_pano_img_fts' in batch:
212 | # pano_img_feats = batch['hist_pano_img_fts']
213 | # pano_ang_feats = batch['hist_pano_ang_fts']
214 | # ob_img_feats = None
215 | # ob_ang_feats = None
216 | # # ob_attn_mask = None
217 | # ob_id_seps = None
218 | # if 'ob_img_fts' in batch:
219 | # ob_img_feats = batch['ob_img_fts']
220 | # ob_ang_feats = batch['ob_ang_fts']
221 | # # ob_attn_mask = batch['ob_attn_mask']
222 | # ob_id_seps = batch['ob_id_seps']
223 |
224 | # prompt = llama.format_prompt(f'You are a navigator to navigate in an unseen environment. You need to follow the instruction "{batch["txt"][0]}".'
225 | # 'The past trajectory is given. You don\'t know where to go now. Generate the question you need to ask.')
226 |
227 | pred_landmarks = model.generate(
228 | imgs,
229 | prompts,
230 | ang_feats=ang_feats,
231 | pano_img_feats=pano_img_feats,
232 | pano_ang_feats=pano_ang_feats,
233 | ob_img_feats=ob_img_feats,
234 | ob_ang_feats=ob_ang_feats,
235 | ob_id_seps=ob_id_seps,
236 | )
237 |
238 | prompts = [prompt.format(pred_landmark) for pred_landmark in pred_landmarks]
239 | # prompts = batch['ori_prompt'][:batch_size]
240 | results = model.generate(
241 | imgs,
242 | prompts,
243 | ang_feats=ang_feats,
244 | pano_img_feats=pano_img_feats,
245 | pano_ang_feats=pano_ang_feats,
246 | ob_img_feats=ob_img_feats,
247 | ob_ang_feats=ob_ang_feats,
248 | ob_id_seps=ob_id_seps,
249 | temperature=1.0 if dataset_name == "rxr" else 0.1,
250 | )
251 |
252 | for i in range(batch_size):
253 | sel_i = select_indexes[i]
254 | path_id = batch["path_id"][sel_i]
255 |
256 | landmark = pred_landmarks[i]
257 | id2path[path_id]["pred_landmark"] = landmark
258 | result = results[i]
259 | id2path[path_id]["inference"] = result
260 | # if "inference" not in id2path[path_id]:
261 | # id2path[path_id]["inference"] = {}
262 | # instr_id = batch["instr_id"][sel_i]
263 | # id2path[path_id]["inference"][instr_id] = result
264 |
265 |
266 | # if result == batch['gt_id'][sel_i]:
267 | # num_correct_gt += 1
268 |
269 | # t_cur = batch['hist_lens'][sel_i] - 1
270 | # scan_shortest_distances = shortest_distances[batch['scan'][sel_i]]
271 | # cur_distance = scan_shortest_distances[batch['path'][sel_i][t_cur]][batch['path'][sel_i][-1]]
272 | # for vp in scan_shortest_distances.keys():
273 | # if result == vp[:8]:
274 | # if scan_shortest_distances[vp][batch['path'][sel_i][-1]] < cur_distance:
275 | # num_distance_reduce += 1
276 | # break
277 |
278 | if not os.path.exists(os.path.join(traj_img_dir, f"{path_id}.jpg")):
279 | img_concat = Image.new("RGB", (img_size[0], img_size[1] * len(batch["path"][sel_i])))
280 | for j in range(len(batch["path"][sel_i])):
281 | sim.newEpisode(
282 | [batch["scan"][sel_i]],
283 | [batch["path"][sel_i][j]],
284 | [batch["abs_pos_angles"][sel_i][j][0]],
285 | [batch["abs_pos_angles"][sel_i][j][1]],
286 | )
287 | state = sim.getState()[0]
288 | rgb = np.array(state.rgb, copy=False) # BGR
289 |
290 | # img = Image.fromarray(rgb)
291 | # img = preprocess(img).unsqueeze(0).to(device)
292 | # caption_prompt = llama.format_prompt('Please describe this image in details.')
293 | # result = model.generate(img, [caption_prompt])[0]
294 | # print(f'Path ID: {batch["path_id"][0]}')
295 | # print(result)
296 |
297 | img = Image.fromarray(rgb[:, :, ::-1])
298 | img_concat.paste(img, (0, j * img_size[1]))
299 |
300 | # if j < len(batch['path'][0]) - 1:
301 | # for k, vp in enumerate(state.navigableLocations):
302 | # if vp.viewpointId == batch['path'][0][j + 1]:
303 | # sim.makeAction([k], [vp.rel_heading], [vp.rel_elevation])
304 | # break
305 |
306 | # img_path = os.path.join(matterport_img_dir, batch['scan'][0], 'matterport_skybox_images', f'{view}_skybox_small.jpg')
307 | # img = Image.open(img_path)
308 | # img_concat.paste(img, (0, j * img_size[1]))
309 |
310 | img_concat.save(os.path.join(traj_img_dir, f"{path_id}.jpg"))
311 |
312 | # print(f'Num Correct GT: {num_correct_gt}')
313 | # print(f'Num Distance Reduce: {num_distance_reduce}')
314 |
315 | # print(f'Total Samples: {len(val_dataloader) * args.batch_size}')
316 | # print(f'GT Acc: {num_correct_gt / (len(val_dataloader) * args.batch_size)}')
317 | # print(f'Distance Reduce Acc: {num_distance_reduce / (len(val_dataloader) * args.batch_size)}')
318 |
319 | json_file = open(os.path.join(args.ckpt_dir, f"id2path_{dataset_name}_val_seen.json"), "w")
320 | json.dump(id2path, json_file)
321 | json_file.close()
322 |
323 |
324 | if __name__ == "__main__":
325 | args = parse_args()
326 | main(args)
327 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Matterport3DSimulator
2 | # Requires nvidia gpu with driver 396.37 or higher
3 |
4 | FROM nvidia/cudagl:11.4.2-devel
5 |
6 | # Install cudnn
7 | # ENV CUDNN_VERSION 8.2.4.15
8 | # LABEL com.nvidia.cudnn.version="${CUDNN_VERSION}"
9 |
10 | # RUN apt-get update && apt-get install -y --no-install-recommends \
11 | # libcudnn8=$CUDNN_VERSION-1+cuda11.4 \
12 | # libcudnn8-dev=$CUDNN_VERSION-1+cuda11.4 \
13 | # && \
14 | # apt-mark hold libcudnn8 && \
15 | # rm -rf /var/lib/apt/lists/*
16 |
17 | # openssh-server for sshd
18 | # sudo for switch user
19 | RUN apt-get update && apt-get install -y --no-install-recommends openssh-server sudo
20 |
21 | # Allow sshd PasswordAuthentication
22 | RUN sed -i 's/#PasswordAuthentication yes/PasswordAuthentication yes/g' /etc/ssh/sshd_config
23 |
24 | # Install a few libraries to support both EGL and OSMESA options
25 | ENV DEBIAN_FRONTEND=noninteractive
26 | RUN apt-get update && apt-get install -y wget doxygen curl libjsoncpp-dev libepoxy-dev libglm-dev libosmesa6 libosmesa6-dev libglew-dev libopencv-dev python3-setuptools python3-dev python3-pip git htop tmux libaio-dev zip
27 | RUN pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple torch torchvision torchaudio opencv-python numpy pandas networkx fairscale sentencepiece gradio nvitop h5py progressbar2 lmdb jsonlines easydict tensorboard ipykernel
28 | RUN pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple deepspeed
29 | RUN pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple git+https://github.com/csuhan/timm_0_3_2.git
30 | RUN pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple git+https://github.com/openai/CLIP.git
31 |
32 | #install latest cmake
33 | # ADD https://cmake.org/files/v3.27/cmake-3.27.1-linux-x86_64.sh /cmake-3.27.1-linux-x86_64.sh
34 | # RUN mkdir /opt/cmake
35 | # RUN sh /cmake-3.27.1-linux-x86_64.sh --prefix=/opt/cmake --skip-license
36 | # RUN ln -s /opt/cmake/bin/cmake /usr/local/bin/cmake
37 | RUN cmake --version
38 |
39 | ENV PYTHONPATH=/root/mount/Matterport3DSimulator/build
40 |
--------------------------------------------------------------------------------
/docker/installation_guide_wo_docker.txt:
--------------------------------------------------------------------------------
1 | # Install glvnd
2 | sudo apt update && sudo apt install -y \
3 | pkg-config \
4 | libglvnd-dev \
5 | libgl1-mesa-dev \
6 | libegl1-mesa-dev \
7 | libgles2-mesa-dev
8 |
9 |
10 | # Install a few libraries to support both EGL and OSMESA options
11 | sudo apt update && sudo apt install -y wget doxygen curl libjsoncpp-dev libepoxy-dev libglm-dev libosmesa6 libosmesa6-dev libglew-dev libopencv-dev python3-setuptools python3-dev python3-pip git htop tmux libaio-dev zip nload
12 | pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple \
13 |     torch torchvision torchaudio opencv-python numpy pandas networkx fairscale sentencepiece gradio gpustat h5py progressbar2 lmdb jsonlines easydict tensorboard ipykernel
14 | pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple deepspeed
15 | pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple git+https://github.com/csuhan/timm_0_3_2.git
16 | pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple git+https://github.com/openai/CLIP.git
17 |
18 | # Install latest cmake
19 | sudo wget -O cmake-3.27.4-linux-x86_64.sh https://cmake.org/files/v3.27/cmake-3.27.4-linux-x86_64.sh
20 | sudo mkdir /opt/cmake
21 | sudo sh cmake-3.27.4-linux-x86_64.sh --prefix=/opt/cmake --skip-license
22 | sudo ln -s /opt/cmake/bin/cmake /usr/local/bin/cmake
23 |
24 | export PYTHONPATH=/{PATH_TO_SIMULATOR}/Matterport3DSimulator/build
25 |
--------------------------------------------------------------------------------
/engine_finetune.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | from typing import Iterable
4 |
5 | import torch
6 |
7 | import util.misc as misc
8 | import util.lr_sched as lr_sched
9 |
10 | from llama import LLaMA_adapter
11 |
12 |
13 | def train_one_epoch(
14 | model: LLaMA_adapter,
15 | data_loader: Iterable,
16 | optimizer: torch.optim.Optimizer,
17 | device: torch.device,
18 | epoch: int,
19 | loss_scaler,
20 | log_writer=None,
21 | args=None,
22 | ):
23 | model.train(True)
24 | # model.module.set_default_trainability()
25 |
26 | metric_logger = misc.MetricLogger(delimiter=" ")
27 | metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
28 | header = "Epoch: [{}]".format(epoch)
29 | print_freq = 10
30 |
31 | accum_iter = args.accum_iter
32 |
33 | optimizer.zero_grad()
34 |
35 | if log_writer is not None:
36 | print("log_dir: {}".format(log_writer.log_dir))
37 |
38 | for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
39 | # examples, labels, example_mask, imgs
40 | examples = batch["txt_ids"]
41 | labels = batch["txt_labels"]
42 | imgs = batch["hist_img_fts"]
43 | ang_feats = batch["hist_ang_fts"]
44 | pano_img_feats = None
45 | pano_ang_feats = None
46 | if "hist_pano_img_fts" in batch:
47 | pano_img_feats = batch["hist_pano_img_fts"]
48 | pano_ang_feats = batch["hist_pano_ang_fts"]
49 |
50 | ob_img_feats = None
51 | ob_ang_feats = None
52 | # ob_attn_mask = None
53 | ob_nav_types = None
54 | ob_id_seps = None
55 | ob_action_viewindex = None
56 | if "ob_img_fts" in batch:
57 | ob_img_feats = batch["ob_img_fts"]
58 | ob_ang_feats = batch["ob_ang_fts"]
59 | # ob_attn_mask = batch['ob_attn_mask']
60 | ob_nav_types = batch["ob_nav_types"]
61 | ob_id_seps = batch["ob_id_seps"]
62 | ob_action_viewindex = batch["ob_action_viewindex"]
63 |
64 | # we use a per iteration (instead of per epoch) lr scheduler
65 | if data_iter_step % accum_iter == 0:
66 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
67 |
68 | if imgs is not None:
69 | imgs = imgs.to(device, non_blocking=True)
70 | with torch.cuda.amp.autocast():
71 | c_loss, m_loss = model(
72 | examples,
73 | labels,
74 | imgs,
75 | ang_feats,
76 | pano_img_feats,
77 | pano_ang_feats,
78 | ob_img_feats,
79 | ob_ang_feats,
80 | ob_nav_types,
81 | ob_id_seps,
82 | ob_action_viewindex,
83 | )
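        # m_loss is returned for logging only; multiplying by 0 keeps it out of the training objective and its gradients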
84 | loss = c_loss + m_loss * 0
85 | loss_value = loss.item()
86 | c_loss_value = c_loss.item()
87 | m_loss_value = m_loss
88 | if not math.isfinite(loss_value):
89 | print("Loss is {}, stopping training".format(loss_value))
90 | sys.exit(1)
91 |
92 | loss /= accum_iter
93 | loss_scaler(loss, optimizer, parameters=model.parameters(), update_grad=(data_iter_step + 1) % accum_iter == 0)
94 | if (data_iter_step + 1) % accum_iter == 0:
95 | optimizer.zero_grad()
96 |
97 | torch.cuda.synchronize()
98 |
99 | metric_logger.update(closs=c_loss_value)
100 | metric_logger.update(mloss=m_loss_value)
101 |
102 | lr = optimizer.param_groups[0]["lr"]
103 | metric_logger.update(lr=lr)
104 |
105 | loss_value_reduce = misc.all_reduce_mean(loss_value)
106 | c_loss_value_reduce = misc.all_reduce_mean(c_loss_value)
107 | m_loss_value_reduce = misc.all_reduce_mean(m_loss_value)
108 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
109 | """We use epoch_1000x as the x-axis in tensorboard.
110 | This calibrates different curves when batch size changes.
111 | """
112 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
113 | log_writer.add_scalar("c_train_loss", c_loss_value_reduce, epoch_1000x)
114 | log_writer.add_scalar("m_train_loss", m_loss_value_reduce, epoch_1000x)
115 | log_writer.add_scalar("lr", lr, epoch_1000x)
116 |
117 | # gather the stats from all processes
118 | metric_logger.synchronize_between_processes()
119 | print("Averaged stats:", metric_logger)
120 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
121 |
122 |
123 | def eval_one_epoch(model: LLaMA_adapter, data_loader: Iterable, device: torch.device, epoch: int, log_writer=None):
124 | model.eval()
125 |
126 | metric_logger = misc.MetricLogger(delimiter=" ")
127 | header = "Epoch: [{}]".format(epoch)
128 | print_freq = 10
129 |
130 | for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
131 | # examples, labels, example_mask, imgs
132 | examples = batch["txt_ids"]
133 | labels = batch["txt_labels"]
134 | imgs = batch["hist_img_fts"]
135 | ang_feats = batch["hist_ang_fts"]
136 | pano_img_feats = None
137 | pano_ang_feats = None
138 | if "hist_pano_img_fts" in batch:
139 | pano_img_feats = batch["hist_pano_img_fts"]
140 | pano_ang_feats = batch["hist_pano_ang_fts"]
141 |
142 | ob_img_feats = None
143 | ob_ang_feats = None
144 | # ob_attn_mask = None
145 | ob_nav_types = None
146 | ob_id_seps = None
147 | ob_action_viewindex = None
148 | if "ob_img_fts" in batch:
149 | ob_img_feats = batch["ob_img_fts"]
150 | ob_ang_feats = batch["ob_ang_fts"]
151 | # ob_attn_mask = batch['ob_attn_mask']
152 | ob_nav_types = batch["ob_nav_types"]
153 | ob_id_seps = batch["ob_id_seps"]
154 | ob_action_viewindex = batch["ob_action_viewindex"]
155 |
156 | if imgs is not None:
157 | imgs = imgs.to(device, non_blocking=True)
158 | with torch.no_grad():
159 | with torch.cuda.amp.autocast():
160 | c_loss, m_loss = model(
161 | examples,
162 | labels,
163 | imgs,
164 | ang_feats,
165 | pano_img_feats,
166 | pano_ang_feats,
167 | ob_img_feats,
168 | ob_ang_feats,
169 | ob_nav_types,
170 | ob_id_seps,
171 | ob_action_viewindex,
172 | )
173 | loss = c_loss + m_loss * 0
174 | loss_value = loss.item()
175 | c_loss_value = c_loss.item()
176 | m_loss_value = m_loss
177 |
178 | if not math.isfinite(loss_value):
179 | print("Loss is {}, stopping training".format(loss_value))
180 | sys.exit(1)
181 |
182 | torch.cuda.synchronize()
183 |
184 | metric_logger.update(closs=c_loss_value)
185 | metric_logger.update(mloss=m_loss_value)
186 |
187 | loss_value_reduce = misc.all_reduce_mean(loss_value)
188 | c_loss_value_reduce = misc.all_reduce_mean(c_loss_value)
189 | m_loss_value_reduce = misc.all_reduce_mean(m_loss_value)
190 | if log_writer is not None:
191 | """We use epoch_1000x as the x-axis in tensorboard.
192 | This calibrates different curves when batch size changes.
193 | """
194 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
195 | if "ob_img_fts" in batch:
196 | log_writer.add_scalar("c_val_loss_sap", c_loss_value_reduce, epoch_1000x)
197 | log_writer.add_scalar("m_val_loss_sap", m_loss_value_reduce, epoch_1000x)
198 | else:
199 | log_writer.add_scalar("c_val_loss_itm", c_loss_value_reduce, epoch_1000x)
200 | log_writer.add_scalar("m_val_loss_itm", m_loss_value_reduce, epoch_1000x)
201 |
202 | # gather the stats from all processes
203 | metric_logger.synchronize_between_processes()
204 | print("Averaged stats:", metric_logger)
205 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
206 |
207 |
208 | def train_one_epoch_img(
209 | model: LLaMA_adapter,
210 | data_loader: Iterable,
211 | optimizer: torch.optim.Optimizer,
212 | device: torch.device,
213 | epoch: int,
214 | loss_scaler,
215 | log_writer=None,
216 | args=None,
217 | ):
218 | model.train(True)
219 | # model.module.set_default_trainability()
220 |
221 | metric_logger = misc.MetricLogger(delimiter=" ")
222 | metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
223 | header = "Epoch: [{}]".format(epoch)
224 | print_freq = 10
225 |
226 | accum_iter = args.accum_iter
227 |
228 | optimizer.zero_grad()
229 |
230 | if log_writer is not None:
231 | print("log_dir: {}".format(log_writer.log_dir))
232 |
233 | for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
234 | examples, labels, example_mask, imgs, gt_id, ori_prompt, gt_caption = batch
235 |
236 | # we use a per iteration (instead of per epoch) lr scheduler
237 | if data_iter_step % accum_iter == 0:
238 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
239 |
240 | if imgs is not None:
241 | imgs = imgs.to(device, non_blocking=True)
242 | with torch.cuda.amp.autocast():
243 | c_loss, m_loss = model(
244 | examples,
245 | labels,
246 | imgs,
247 | )
248 | loss = c_loss + m_loss * 0
249 | loss_value = loss.item()
250 | c_loss_value = c_loss.item()
251 | m_loss_value = m_loss
252 | if not math.isfinite(loss_value):
253 | print("Loss is {}, stopping training".format(loss_value))
254 | sys.exit(1)
255 |
256 | loss /= accum_iter
257 | loss_scaler(loss, optimizer, parameters=model.parameters(), update_grad=(data_iter_step + 1) % accum_iter == 0)
258 | if (data_iter_step + 1) % accum_iter == 0:
259 | optimizer.zero_grad()
260 |
261 | torch.cuda.synchronize()
262 |
263 | metric_logger.update(closs=c_loss_value)
264 | metric_logger.update(mloss=m_loss_value)
265 |
266 | lr = optimizer.param_groups[0]["lr"]
267 | metric_logger.update(lr=lr)
268 |
269 | loss_value_reduce = misc.all_reduce_mean(loss_value)
270 | c_loss_value_reduce = misc.all_reduce_mean(c_loss_value)
271 | m_loss_value_reduce = misc.all_reduce_mean(m_loss_value)
272 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
273 | """We use epoch_1000x as the x-axis in tensorboard.
274 | This calibrates different curves when batch size changes.
275 | """
276 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
277 | log_writer.add_scalar("c_train_loss", c_loss_value_reduce, epoch_1000x)
278 | log_writer.add_scalar("m_train_loss", m_loss_value_reduce, epoch_1000x)
279 | log_writer.add_scalar("lr", lr, epoch_1000x)
280 |
281 | # gather the stats from all processes
282 | metric_logger.synchronize_between_processes()
283 | print("Averaged stats:", metric_logger)
284 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
285 |
286 |
287 | def eval_one_epoch_img(model: LLaMA_adapter, data_loader: Iterable, device: torch.device, epoch: int, log_writer=None):
288 | model.eval()
289 |
290 | metric_logger = misc.MetricLogger(delimiter=" ")
291 | header = "Epoch: [{}]".format(epoch)
292 | print_freq = 10
293 |
294 | for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
295 | examples, labels, example_mask, imgs, gt_id, ori_prompt, gt_caption = batch
296 |
297 | if imgs is not None:
298 | imgs = imgs.to(device, non_blocking=True)
299 | with torch.no_grad():
300 | with torch.cuda.amp.autocast():
301 | c_loss, m_loss = model(
302 | examples,
303 | labels,
304 | imgs
305 | )
306 | loss = c_loss + m_loss * 0
307 | loss_value = loss.item()
308 | c_loss_value = c_loss.item()
309 | m_loss_value = m_loss
310 |
311 | if not math.isfinite(loss_value):
312 | print("Loss is {}, stopping training".format(loss_value))
313 | sys.exit(1)
314 |
315 | torch.cuda.synchronize()
316 |
317 | metric_logger.update(closs=c_loss_value)
318 | metric_logger.update(mloss=m_loss_value)
319 |
320 | loss_value_reduce = misc.all_reduce_mean(loss_value)
321 | c_loss_value_reduce = misc.all_reduce_mean(c_loss_value)
322 | m_loss_value_reduce = misc.all_reduce_mean(m_loss_value)
323 | if log_writer is not None:
324 | """We use epoch_1000x as the x-axis in tensorboard.
325 | This calibrates different curves when batch size changes.
326 | """
327 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
328 | log_writer.add_scalar("c_val_loss", c_loss_value_reduce, epoch_1000x)
329 | log_writer.add_scalar("m_val_loss", m_loss_value_reduce, epoch_1000x)
330 |
331 | # gather the stats from all processes
332 | metric_logger.synchronize_between_processes()
333 | print("Averaged stats:", metric_logger)
334 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
335 |
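# Sketch of how the epoch helpers above are typically driven (illustrative only; the variable
# names below are placeholders, and the actual entry point is main_finetune.py):
#
#     for epoch in range(args.start_epoch, args.epochs):
#         train_stats = train_one_epoch(model, data_loader_train, optimizer,
#                                       device, epoch, loss_scaler,
#                                       log_writer=log_writer, args=args)
#         val_stats = eval_one_epoch(model, data_loader_val, device, epoch,
#                                    log_writer=log_writer)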
--------------------------------------------------------------------------------
/eval_speaker.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 |
6 | from llama import Tokenizer
7 | from util.bleu import compute_bleu
8 |
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser('Speaker Evaluator', add_help=False)
12 | parser.add_argument('--ckpt_dir', default='results', type=str)
13 |
14 | args = parser.parse_args()
15 | return args
16 |
17 |
18 | def eval_speaker(input_path):
19 | json_path = os.path.join(input_path, 'id2path.json')
20 | with open(json_path, 'r') as f:
21 | id2path = json.load(f)
22 |
23 | # SENTENCE_SPLIT_REGEX = re.compile(r'(\W+)') # Split on any non-alphanumeric character
24 | tokenizer = Tokenizer('/root/mount/LLaMA-7B/tokenizer.model')
25 |
26 | refs = []
27 | candidates = []
28 | for pair in id2path.values():
29 | gt_sentence_list = pair['gt']
30 | gt_list = []
31 | for sentence in gt_sentence_list:
32 | # gt_list.append([s.strip().lower() for s in SENTENCE_SPLIT_REGEX.split(sentence.strip()) if len(s.strip()) > 0])
33 | gt_list.append(tokenizer.encode(sentence, bos=False, eos=False))
34 | refs.append(gt_list)
35 |
36 | inference_sentence = pair['inference']
37 | # inference_list = [s.strip().lower() for s in SENTENCE_SPLIT_REGEX.split(inference_sentence.strip()) if len(s.strip()) > 0]
38 | inference_list = tokenizer.encode(inference_sentence, bos=False, eos=False)
39 | candidates.append(inference_list)
40 |
41 | tup = compute_bleu(refs, candidates, smooth=False)
42 | bleu_score = tup[0]
43 | precisions = tup[1]
44 | print(f'Bleu: {bleu_score:.4f}')
45 |     print("Bleu 1: %0.4f, Bleu 2: %0.4f, Bleu 3: %0.4f, Bleu 4: %0.4f" % tuple(precisions))
46 |
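# Expected layout of <ckpt_dir>/id2path.json read above (illustrative values; only the
# 'gt' list of reference instructions and the 'inference' string are consumed):
#
#     {
#         "4370_2": {
#             "gt": ["Walk past the couch and stop at the stairs.",
#                    "Go straight ahead and wait by the staircase."],
#             "inference": "Walk past the sofa and stop near the stairs."
#         }
#     }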
47 |
48 | if __name__ == '__main__':
49 | args = parse_args()
50 | eval_speaker(args.ckpt_dir)
51 |
--------------------------------------------------------------------------------
/exps/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/bash
2 | # ori bs 16 max words 384
3 | # rxr bs 4 max words 1000
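# usage sketch (arguments are positional; paths are placeholders):
#   bash exps/finetune.sh /path/to/LLaMA /path/to/pretrained.pth /path/to/data_config.json /path/to/output_dir
# the job runs in the background and its output is appended to $OUTPUT_DIR/output.log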
4 |
5 | LLAMA_PATH="$1"
6 | PRETRAINED_PATH="$2" # path to pre-trained checkpoint
7 | CONFIG="$3"
8 | OUTPUT_DIR="$4"
9 |
10 | mkdir -p "$OUTPUT_DIR"
11 |
12 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
13 | python3 -u -m torch.distributed.launch --master_port=1112 --nproc_per_node=4 --use_env \
14 | main_finetune.py --data_config "$CONFIG" --batch_size 16 --max_words 384 \
15 | --epochs 20 --warmup_epochs 2 --blr 1e-4 --weight_decay 0.02 \
16 | --llama_path "$LLAMA_PATH" \
17 | --output_dir "$OUTPUT_DIR" \
18 | --pretrained_path "$PRETRAINED_PATH" \
19 | &>> "$OUTPUT_DIR"/output.log &
--------------------------------------------------------------------------------
/gradio_app.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import gradio as gr
3 | import torch
4 | from PIL import Image
5 |
6 | import llama
7 |
8 |
9 | device = "cuda" if torch.cuda.is_available() else "cpu"
10 |
11 | llama_dir = "/path/to/LLaMA/"
12 |
13 | model, preprocess = llama.load("BIAS-7B", llama_dir, device)
14 | model.half()
15 | model.eval()
16 |
17 | def multi_modal_generate(
18 | img_path: str,
19 | prompt: str,
20 | max_gen_len=256,
21 | temperature: float = 0.1,
22 | top_p: float = 0.75,
23 | ):
24 | try:
25 | img = Image.fromarray(cv2.imread(img_path))
26 |     except Exception:
27 | return ""
28 |
29 | img = preprocess(img).unsqueeze(0).half().to(device)
30 | prompt = llama.format_prompt(prompt)
31 |
32 | result = model.generate(img, [prompt],
33 | max_gen_len=max_gen_len,
34 | temperature=temperature,
35 | top_p=top_p)
36 | print(result[0])
37 | return result[0]
38 |
39 |
40 | def create_multi_modal_demo():
41 | with gr.Blocks() as instruct_demo:
42 | with gr.Row():
43 | with gr.Column():
44 | img = gr.Image(label='Input', type='filepath')
45 | question = gr.Textbox(lines=2, label="Prompt")
46 | max_len = gr.Slider(minimum=1, maximum=512,
47 | value=256, label="Max length")
48 | with gr.Accordion(label='Advanced options', open=False):
49 | temp = gr.Slider(minimum=0, maximum=1,
50 | value=0.1, label="Temperature")
51 | top_p = gr.Slider(minimum=0, maximum=1,
52 | value=0.75, label="Top p")
53 |
54 |                 run_button = gr.Button("Run")
55 |
56 | with gr.Column():
57 | outputs = gr.Textbox(lines=10, label="Output")
58 |
59 | inputs = [img, question, max_len, temp, top_p]
60 |
61 | examples = [
62 | ["../docs/logo_v1.png", "Please introduce this painting.", 256, 0.1, 0.75],
63 | ]
64 |
65 | gr.Examples(
66 | examples=examples,
67 | inputs=inputs,
68 | outputs=outputs,
69 | fn=multi_modal_generate,
70 | cache_examples=False
71 | )
72 |         run_button.click(fn=multi_modal_generate,
73 |                          inputs=inputs, outputs=outputs)
74 | return instruct_demo
75 |
76 |
77 | description = """
78 | # LLaMA-Adapter V2🚀
79 | The official demo for **LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model**.
80 |
81 | Please refer to our [arXiv paper](https://arxiv.org/abs/2304.15010) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details.
82 |
83 | The demo for **LLaMA-Adapter V1** is available at: [Huggingface Spaces](https://huggingface.co/spaces/csuhan/LLaMA-Adapter).
84 | """
85 |
86 | with gr.Blocks(css="h1,p {text-align: center;}") as demo:
87 | gr.Markdown(description)
88 | with gr.TabItem("Multi-Modal Interaction"):
89 | create_multi_modal_demo()
90 |
91 | demo.queue(api_open=True, concurrency_count=1).launch(share=True)
92 |
--------------------------------------------------------------------------------
/images/c-instructor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/refkxh/C-Instructor/55756e5fb3771f8dbbac0f63f075142a41906e74/images/c-instructor.png
--------------------------------------------------------------------------------
/landmark/extract_landmark_r2r.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import h5py
4 | import jsonlines
5 | import numpy as np
6 | import stanza
7 | from tqdm import tqdm
8 |
9 |
10 | scanvp_cands_file = '/data/user/kxh/instructllm/Matterport3DSimulator/tasks/R2R/data/pretrain/scanvp_candview_relangles.json'
11 | bboxes_file = '/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/BBoxes.json'
12 | img_ft_file = '/data/user/kxh/instructllm/Matterport3DSimulator/img_features/vit_l_14_clip.hdf5'
13 | img_feature_store = {}
14 |
15 |
16 | def extract_landmark_lang(nlp_pipeline, input_file, output_file):
17 | ignore_txts = ['turn', 'left', 'right', 'top', 'bottom', 'front', 'back', 'end', 'level', 'stop', 'exit', 'room', 'way', 'one', 'area']
18 | with jsonlines.open(input_file, 'r') as reader:
19 | with jsonlines.open(output_file, 'w') as writer:
20 | for item in tqdm(reader):
21 | del item['instr_encodings']
22 |
23 | in_docs = [stanza.Document([], text=instr) for instr in item['instructions']]
24 | out_docs = nlp_pipeline(in_docs)
25 | item['landmarks'] = []
26 | for out_doc in out_docs:
27 | doc_landmarks = set()
28 | for sent in out_doc.sentences:
29 | for word in sent.words:
30 | if word.upos == 'NOUN' and len(word.lemma) > 1 and word.lemma not in ignore_txts:
31 | doc_landmarks.add(word.lemma)
32 | doc_landmarks = list(doc_landmarks)
33 | item['landmarks'].append(doc_landmarks)
34 | # item = {'landmarks': item['landmarks']}
35 | writer.write(item)
36 |
37 |
38 | def get_image_feature(scan, viewpoint):
39 | key = f"{scan}_{viewpoint}"
40 | if key in img_feature_store:
41 | fts = img_feature_store[key]
42 | else:
43 | with h5py.File(img_ft_file, "r") as f:
44 | fts = f[key][...].astype(np.float32)
45 | fts = fts / np.linalg.norm(fts, axis=1, keepdims=True)
46 | img_feature_store[key] = fts
47 | return fts
48 |
49 |
50 | def get_scan2vp2obj():
51 | scan2vp2obj = {}
52 | with open(bboxes_file, 'r') as f:
53 | bbox_data = json.load(f)
54 | for scanvp, value in bbox_data.items():
55 | scan, vp = scanvp.split("_")
56 | if scan not in scan2vp2obj:
57 | scan2vp2obj[scan] = {}
58 | if vp not in scan2vp2obj[scan]:
59 | scan2vp2obj[scan][vp] = []
60 | for objinfo in value.values():
61 | if objinfo["visible_pos"]:
62 | append_objinfo = {"name": objinfo["name"].replace("#", " "), "visible_pos": objinfo["visible_pos"]}
63 | scan2vp2obj[scan][vp].append(append_objinfo)
64 | return scan2vp2obj
65 |
66 |
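# High-level intent of the scoring below: for each step along the path, an object visible
# from the candidate view that faces the next viewpoint scores positively, visibility from
# the remaining candidate views is penalised in proportion to CLIP-feature dissimilarity,
# and the step's contribution is weighted by how much the mean panorama feature changes
# between consecutive viewpoints; objects whose accumulated score exceeds 0.25 are kept
# as 'visual_landmarks'.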
67 | def extract_landmark_vis(input_file, output_file):
68 | with open(scanvp_cands_file, 'r') as f:
69 | scanvp_cands = json.load(f)
70 |
71 | scan2vp2obj = get_scan2vp2obj()
72 |
73 | with jsonlines.open(input_file, 'r') as reader:
74 | with jsonlines.open(output_file, 'w') as writer:
75 | for item in tqdm(reader):
76 | scan = item['scan']
77 | vp2obj = scan2vp2obj[scan]
78 | path_len = len(item['path'])
79 | visual_landmarks = {}
80 | for i in range(path_len - 1):
81 | cur_vp = item['path'][i]
82 | next_vp = item['path'][i + 1]
83 | cur_fts = get_image_feature(scan, cur_vp)
84 | next_fts = get_image_feature(scan, next_vp)
85 |
86 | scanvp_cur = scan + '_' + cur_vp
87 | cands = scanvp_cands[scanvp_cur]
88 | non_cand_vp_nums = []
89 | for cand_id, cand_value in cands.items():
90 | if cand_id == next_vp:
91 | cand_vp_num = cand_value[0]
92 | else:
93 | non_cand_vp_nums.append(cand_value[0])
94 |
95 | cand_objs = {}
96 | non_cand_objs = {}
97 | for obj_info in vp2obj[cur_vp]:
98 | obj_name = obj_info['name']
99 | if cand_vp_num in obj_info['visible_pos']:
100 | cand_objs[obj_name] = 1
101 | cand_vp_fts = cur_fts[cand_vp_num]
102 | for non_cand_vp_num in non_cand_vp_nums:
103 | if non_cand_vp_num in obj_info['visible_pos']:
104 | non_cand_vp_fts = cur_fts[non_cand_vp_num]
105 | feat_sim = (1 - np.dot(cand_vp_fts, non_cand_vp_fts)) * 2
106 | if obj_name not in non_cand_objs:
107 | non_cand_objs[obj_name] = feat_sim
108 | else:
109 | non_cand_objs[obj_name] += feat_sim
110 | for obj_name in cand_objs:
111 | if obj_name in non_cand_objs:
112 | cand_objs[obj_name] -= non_cand_objs[obj_name]
113 |
114 | cur_fts_mean = np.mean(cur_fts, axis=0)
115 | cur_fts_mean_norm = cur_fts_mean / np.linalg.norm(cur_fts_mean)
116 | next_fts_mean = np.mean(next_fts, axis=0)
117 | next_fts_mean_norm = next_fts_mean / np.linalg.norm(next_fts_mean)
118 | feat_sim = np.dot(cur_fts_mean_norm, next_fts_mean_norm)
119 | feat_coeff = (1 - feat_sim) * 50
120 | if obj_name in visual_landmarks:
121 | visual_landmarks[obj_name] += cand_objs[obj_name] * feat_coeff
122 | else:
123 | visual_landmarks[obj_name] = cand_objs[obj_name] * feat_coeff
124 |
125 | item['visual_landmarks'] = [obj_name for obj_name, score in visual_landmarks.items() if score > 0.25]
126 | # item = {'visual_landmarks': visual_landmarks}
127 | # item = {'visual_landmarks': item['visual_landmarks']}
128 | writer.write(item)
129 |
130 |
131 | if __name__ == '__main__':
132 | splits = ['train', 'val_seen', 'val_unseen', 'train_prevalent_generated']
133 | input_files = [split + '.jsonl' for split in splits]
134 | output_files = [split + '_landmark.jsonl' for split in splits]
135 | output_files_vis = [split + '_landmark_vis_score.jsonl' for split in splits]
136 |
137 | # nlp_pipeline = stanza.Pipeline('en', processors='tokenize,pos,lemma')
138 |
139 | # for input_file, output_file in zip(input_files, output_files):
140 | # extract_landmark_lang(nlp_pipeline, input_file, output_file)
141 |
142 | for input_file, output_file in zip(output_files, output_files_vis):
143 | extract_landmark_vis(input_file, output_file)
144 |
--------------------------------------------------------------------------------
/landmark/extract_landmark_reverie.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import jsonlines
4 | import stanza
5 | from tqdm import tqdm
6 |
7 |
8 | scanvp_cands_file = '/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/pretrain/scanvp_candview_relangles.json'
9 | bboxes_file = '/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/BBoxes.json'
10 |
11 |
12 | def extract_landmark_lang(nlp_pipeline, input_file, input_ori_file, output_file):
13 | # with open(input_ori_file, 'r') as f:
14 | # ori_data = json.load(f)
15 | # path_id_to_instr_l = {item['path_id']: item['instructions_l'] for item in ori_data}
16 |
17 | ignore_txts = ['turn', 'left', 'right', 'top', 'bottom', 'front', 'back', 'end', 'level', 'stop', 'exit', 'room', 'way', 'one', 'area']
18 | with jsonlines.open(input_file, 'r') as reader:
19 | with jsonlines.open(output_file, 'w') as writer:
20 | for item in tqdm(reader):
21 | in_docs = [stanza.Document([], text=instr) for instr in item['instructions']]
22 | out_docs = nlp_pipeline(in_docs)
23 | item['landmarks'] = []
24 | for out_doc in out_docs:
25 | doc_landmarks = set()
26 | for sent in out_doc.sentences:
27 | for word in sent.words:
28 | if word.upos == 'NOUN' and len(word.lemma) > 1 and word.lemma not in ignore_txts:
29 | doc_landmarks.add(word.lemma)
30 | doc_landmarks = list(doc_landmarks)
31 | item['landmarks'].append(doc_landmarks)
32 | # item = {'landmarks': item['landmarks']}
33 | writer.write(item)
34 |
35 |
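# Set-based variant used for REVERIE: along each path, keep the objects that are visible
# from the candidate view facing the next viewpoint but not from other candidate views at
# the same location; the union over all steps is written out as 'visual_landmarks'.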
36 | def extract_landmark_vis(input_file, output_file):
37 | with open(scanvp_cands_file, 'r') as f:
38 | scanvp_cands = json.load(f)
39 |
40 | scan2vp2obj = {}
41 | with open(bboxes_file, 'r') as f:
42 | bbox_data = json.load(f)
43 | for scanvp, value in bbox_data.items():
44 | scan, vp = scanvp.split("_")
45 | if scan not in scan2vp2obj:
46 | scan2vp2obj[scan] = {}
47 | if vp not in scan2vp2obj[scan]:
48 | scan2vp2obj[scan][vp] = []
49 | for objinfo in value.values():
50 | if objinfo["visible_pos"]:
51 | append_objinfo = {"name": objinfo["name"].replace("#", " "), "visible_pos": objinfo["visible_pos"]}
52 | scan2vp2obj[scan][vp].append(append_objinfo)
53 |
54 | with jsonlines.open(input_file, 'r') as reader:
55 | with jsonlines.open(output_file, 'w') as writer:
56 | for item in tqdm(reader):
57 | scan = item['scan']
58 | vp2obj = scan2vp2obj[scan]
59 | path_len = len(item['path'])
60 | item['visual_landmarks'] = set()
61 | for i in range(path_len - 1):
62 | cur_vp = item['path'][i]
63 | next_vp = item['path'][i + 1]
64 | scanvp_cur = scan + '_' + cur_vp
65 |
66 | cands = scanvp_cands[scanvp_cur]
67 | non_cand_vp_nums = set()
68 | for cand_id, cand_value in cands.items():
69 | if cand_id == next_vp:
70 | cand_vp_num = cand_value[0]
71 | else:
72 | non_cand_vp_nums.add(cand_value[0])
73 |
74 | cand_obj_names = set()
75 | non_cand_obj_names = set()
76 | for obj_info in vp2obj[cur_vp]:
77 | obj_name = obj_info['name']
78 | if cand_vp_num in obj_info['visible_pos']:
79 | cand_obj_names.add(obj_name)
80 | elif non_cand_vp_nums.intersection(set(obj_info['visible_pos'])):
81 | non_cand_obj_names.add(obj_name)
82 | cand_obj_names -= non_cand_obj_names
83 | item['visual_landmarks'] |= cand_obj_names
84 | item['visual_landmarks'] = list(item['visual_landmarks'])
85 | # item = {'visual_landmarks': item['visual_landmarks']}
86 | writer.write(item)
87 |
88 |
89 | if __name__ == '__main__':
90 | splits = ['train', 'val_seen', 'val_unseen']
91 | input_files = [split + '.jsonl' for split in splits]
92 | input_ori_files = ['../REVERIE_' + split + '.json' for split in splits]
93 | output_files = [split + '_landmark.jsonl' for split in splits]
94 | output_files_vis = [split + '_landmark_vis.jsonl' for split in splits]
95 |
96 | # nlp_pipeline = stanza.Pipeline('en', processors='tokenize,pos,lemma')
97 |
98 | # for input_file, input_ori_file, output_file in zip(input_files, input_ori_files, output_files):
99 | # extract_landmark_lang(nlp_pipeline, input_file, input_ori_file, output_file)
100 |
101 | for input_file, output_file in zip(output_files, output_files_vis):
102 | extract_landmark_vis(input_file, output_file)
103 |
--------------------------------------------------------------------------------
/landmark/extract_landmark_rxr.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import h5py
4 | import jsonlines
5 | import numpy as np
6 | import stanza
7 | from tqdm import tqdm
8 |
9 |
10 | scanvp_cands_file = '/data/user/kxh/instructllm/Matterport3DSimulator/tasks/R2R/data/pretrain/scanvp_candview_relangles.json'
11 | bboxes_file = '/data/user/kxh/instructllm/Matterport3DSimulator/tasks/REVERIE/data/BBoxes.json'
12 | img_ft_file = '/data/user/kxh/instructllm/Matterport3DSimulator/img_features/vit_l_14_clip.hdf5'
13 | img_feature_store = {}
14 |
15 |
16 | def extract_landmark_lang(nlp_pipeline, input_file, input_ori_file, output_file):
17 | with jsonlines.open(input_ori_file, 'r') as reader:
18 | id2instr = {item['instruction_id']: item['instruction'] for item in reader}
19 |
20 | ignore_txts = ['turn', 'left', 'right', 'top', 'bottom', 'front', 'back', 'end', 'level', 'stop', 'exit', 'room', 'way', 'one', 'area']
21 | with jsonlines.open(input_file, 'r') as reader:
22 | with jsonlines.open(output_file, 'w') as writer:
23 | for item in tqdm(reader):
24 | instr_ids_en = [instr_id for instr_id in item['instr_ids'] if instr_id in id2instr]
25 | if len(instr_ids_en) == 0:
26 | continue
27 |
28 | item['instr_ids'] = instr_ids_en
29 | item['instructions'] = [id2instr[instr_id] for instr_id in instr_ids_en]
30 |
31 | in_docs = [stanza.Document([], text=instr) for instr in item['instructions']]
32 | out_docs = nlp_pipeline(in_docs)
33 | item['landmarks'] = []
34 | for out_doc in out_docs:
35 | doc_landmarks = set()
36 | for sent in out_doc.sentences:
37 | for word in sent.words:
38 | if word.upos == 'NOUN' and len(word.lemma) > 1 and word.lemma not in ignore_txts:
39 | doc_landmarks.add(word.lemma)
40 | doc_landmarks = list(doc_landmarks)
41 | item['landmarks'].append(doc_landmarks)
42 | del item['instr_encodings']
43 | # item = {'landmarks': item['landmarks']}
44 | writer.write(item)
45 |
46 |
47 | def get_image_feature(scan, viewpoint):
48 | key = f"{scan}_{viewpoint}"
49 | if key in img_feature_store:
50 | fts = img_feature_store[key]
51 | else:
52 | with h5py.File(img_ft_file, "r") as f:
53 | fts = f[key][...].astype(np.float32)
54 | fts = fts / np.linalg.norm(fts, axis=1, keepdims=True)
55 | img_feature_store[key] = fts
56 | return fts
57 |
58 |
59 | def get_scan2vp2obj():
60 | scan2vp2obj = {}
61 | with open(bboxes_file, 'r') as f:
62 | bbox_data = json.load(f)
63 | for scanvp, value in bbox_data.items():
64 | scan, vp = scanvp.split("_")
65 | if scan not in scan2vp2obj:
66 | scan2vp2obj[scan] = {}
67 | if vp not in scan2vp2obj[scan]:
68 | scan2vp2obj[scan][vp] = []
69 | for objinfo in value.values():
70 | if objinfo["visible_pos"]:
71 | append_objinfo = {"name": objinfo["name"].replace("#", " "), "visible_pos": objinfo["visible_pos"]}
72 | scan2vp2obj[scan][vp].append(append_objinfo)
73 | return scan2vp2obj
74 |
75 |
76 | def extract_landmark_vis(input_file, output_file):
77 | with open(scanvp_cands_file, 'r') as f:
78 | scanvp_cands = json.load(f)
79 |
80 | scan2vp2obj = get_scan2vp2obj()
81 |
82 | with jsonlines.open(input_file, 'r') as reader:
83 | with jsonlines.open(output_file, 'w') as writer:
84 | for item in tqdm(reader):
85 | scan = item['scan']
86 | vp2obj = scan2vp2obj[scan]
87 | path_len = len(item['path'])
88 | visual_landmarks = {}
89 | for i in range(path_len - 1):
90 | cur_vp = item['path'][i]
91 | next_vp = item['path'][i + 1]
92 | cur_fts = get_image_feature(scan, cur_vp)
93 | next_fts = get_image_feature(scan, next_vp)
94 |
95 | scanvp_cur = scan + '_' + cur_vp
96 | cands = scanvp_cands[scanvp_cur]
97 | non_cand_vp_nums = []
98 | for cand_id, cand_value in cands.items():
99 | if cand_id == next_vp:
100 | cand_vp_num = cand_value[0]
101 | else:
102 | non_cand_vp_nums.append(cand_value[0])
103 |
104 | cand_objs = {}
105 | non_cand_objs = {}
106 | for obj_info in vp2obj[cur_vp]:
107 | obj_name = obj_info['name']
108 | if cand_vp_num in obj_info['visible_pos']:
109 | cand_objs[obj_name] = 1
110 | cand_vp_fts = cur_fts[cand_vp_num]
111 | for non_cand_vp_num in non_cand_vp_nums:
112 | if non_cand_vp_num in obj_info['visible_pos']:
113 | non_cand_vp_fts = cur_fts[non_cand_vp_num]
114 | feat_sim = (1 - np.dot(cand_vp_fts, non_cand_vp_fts)) * 2
115 | if obj_name not in non_cand_objs:
116 | non_cand_objs[obj_name] = feat_sim
117 | else:
118 | non_cand_objs[obj_name] += feat_sim
119 | for obj_name in cand_objs:
120 | if obj_name in non_cand_objs:
121 | cand_objs[obj_name] -= non_cand_objs[obj_name]
122 |
123 | cur_fts_mean = np.mean(cur_fts, axis=0)
124 | cur_fts_mean_norm = cur_fts_mean / np.linalg.norm(cur_fts_mean)
125 | next_fts_mean = np.mean(next_fts, axis=0)
126 | next_fts_mean_norm = next_fts_mean / np.linalg.norm(next_fts_mean)
127 | feat_sim = np.dot(cur_fts_mean_norm, next_fts_mean_norm)
128 | feat_coeff = (1 - feat_sim) * 50
129 | if obj_name in visual_landmarks:
130 | visual_landmarks[obj_name] += cand_objs[obj_name] * feat_coeff
131 | else:
132 | visual_landmarks[obj_name] = cand_objs[obj_name] * feat_coeff
133 |
134 | item['visual_landmarks'] = [obj_name for obj_name, score in visual_landmarks.items() if score > 0.25]
135 | # item = {'visual_landmarks': visual_landmarks}
136 | # item = {'visual_landmarks': item['visual_landmarks']}
137 | writer.write(item)
138 |
139 |
140 | if __name__ == '__main__':
141 | splits = ['train', 'val_seen', 'val_unseen']
142 | input_files = [f'rxr_{split}_guide_enc_xlmr.jsonl' for split in splits]
143 | input_ori_files = [f'../rxr_{split}_guide_enc_xlmr_en.jsonl' for split in splits]
144 | output_files = [f'rxr_{split}_guide_landmark.jsonl' for split in splits]
145 | output_files_vis = [f'rxr_{split}_guide_landmark_vis_score.jsonl' for split in splits]
146 |
147 | # nlp_pipeline = stanza.Pipeline('en', processors='tokenize,pos,lemma')
148 |
149 | # for input_file, input_ori_file, output_file in zip(input_files, input_ori_files, output_files):
150 | # extract_landmark_lang(nlp_pipeline, input_file, input_ori_file, output_file)
151 |
152 | for input_file, output_file in zip(output_files, output_files_vis):
153 | extract_landmark_vis(input_file, output_file)
154 |
--------------------------------------------------------------------------------
/landmark/select_eng_rxr.py:
--------------------------------------------------------------------------------
1 | import jsonlines
2 | from tqdm import tqdm
3 |
4 |
5 | def compute_max_len(input_file):
6 | with jsonlines.open(input_file, 'r') as reader:
7 | max_len = 0
8 | for item in tqdm(reader):
9 | max_len = max(max_len, len(item['instruction'].split()))
10 | return max_len
11 |
12 |
13 | def process(input_file, output_file):
14 | with jsonlines.open(input_file, 'r') as reader:
15 | with jsonlines.open(output_file, 'w') as writer:
16 | for item in tqdm(reader):
17 | if item['language'].startswith('en'):
18 | writer.write(item)
19 |
20 |
21 | if __name__ == '__main__':
22 | splits = ['train', 'val_train_seen', 'val_seen', 'val_unseen']
23 | input_files = [f'rxr_{split}_guide_enc_xlmr.jsonl' for split in splits]
24 | output_files = [f'rxr_{split}_guide_enc_xlmr_en.jsonl' for split in splits]
25 |
26 | for input_file, output_file in zip(input_files, output_files):
27 | process(input_file, output_file)
28 |
--------------------------------------------------------------------------------
/llama/__init__.py:
--------------------------------------------------------------------------------
1 | from .llama import ModelArgs, Transformer
2 | from .tokenizer import Tokenizer
3 | from .llama_adapter import *
4 | from .utils import format_prompt
--------------------------------------------------------------------------------
/llama/llama.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from typing import Optional, Tuple
5 | from dataclasses import dataclass
6 | import math
7 |
8 | import torch
9 | from torch import nn
10 | from torch.nn import Embedding, Linear
11 | import torch.nn.functional as F
12 |
13 |
14 | @dataclass
15 | class ModelArgs:
16 | dim: int = 512
17 | n_layers: int = 8
18 | n_heads: int = 8
19 | vocab_size: int = -1 # defined later by tokenizer
20 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
21 | norm_eps: float = 1e-5
22 |
23 | max_batch_size: int = 32
24 | max_seq_len: int = 2048
25 |
26 | w_bias: bool = False # use bias tuning
27 | w_lora: bool = False # use lora tuning
28 | lora_rank: int = 16
29 | w_new_gate: bool = False # for compatibility
30 |
31 |
32 | class RMSNorm(torch.nn.Module):
33 | def __init__(self, dim: int, eps: float = 1e-6):
34 | super().__init__()
35 | self.eps = eps
36 | self.weight = nn.Parameter(torch.ones(dim))
37 |
38 | def _norm(self, x):
39 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
40 |
41 | def forward(self, x):
42 | output = self._norm(x.float()).type_as(x)
43 | return output * self.weight
44 |
45 |
46 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
47 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
48 | t = torch.arange(end, device=freqs.device) # type: ignore
49 | freqs = torch.outer(t, freqs).float() # type: ignore
50 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
51 | return freqs_cis
52 |
53 |
54 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
55 | ndim = x.ndim
56 | assert 0 <= 1 < ndim
57 | assert freqs_cis.shape == (x.shape[1], x.shape[-1])
58 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
59 | return freqs_cis.view(*shape)
60 |
61 |
62 | def apply_rotary_emb(
63 | xq: torch.Tensor,
64 | xk: torch.Tensor,
65 | freqs_cis: torch.Tensor,
66 | ) -> Tuple[torch.Tensor, torch.Tensor]:
67 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
68 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
69 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
70 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
71 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
72 | return xq_out.type_as(xq), xk_out.type_as(xk)
73 |
74 |
75 | class Attention(nn.Module):
76 | def __init__(self, args: ModelArgs):
77 | super().__init__()
78 | self.args = args
79 |
80 | self.n_local_heads = args.n_heads
81 | self.head_dim = args.dim // args.n_heads
82 |
83 | self.wq = Linear(
84 | args.dim,
85 | args.n_heads * self.head_dim,
86 | bias=args.w_bias
87 | )
88 | self.wk = Linear(
89 | args.dim,
90 | args.n_heads * self.head_dim,
91 | bias=False
92 | )
93 | self.wv = Linear(
94 | args.dim,
95 | args.n_heads * self.head_dim,
96 | bias=False
97 | )
98 | self.wo = Linear(
99 | args.n_heads * self.head_dim,
100 | args.dim,
101 | bias=args.w_bias
102 | )
103 | if args.w_bias:
104 | nn.init.constant_(self.wq.bias.data, 0)
105 | nn.init.constant_(self.wo.bias.data, 0)
106 |
107 | self.w_lora = args.w_lora
108 | if args.w_lora:
109 | self.lora_wq_l1 = Linear(args.dim, args.lora_rank, bias=False)
110 | self.lora_wq_l2 = Linear(args.lora_rank, args.dim, bias=False)
111 |
112 | self.lora_wk_l1 = Linear(args.dim, args.lora_rank, bias=False)
113 | self.lora_wk_l2 = Linear(args.lora_rank, args.dim, bias=False)
114 |
115 | self.lora_wv_l1 = Linear(args.dim, args.lora_rank, bias=False)
116 | self.lora_wv_l2 = Linear(args.lora_rank, args.dim, bias=False)
117 |
118 | self.lora_wo_l1 = Linear(args.dim, args.lora_rank, bias=False)
119 | self.lora_wo_l2 = Linear(args.lora_rank, args.dim, bias=False)
120 | nn.init.constant_(self.lora_wq_l2.weight.data, 0)
121 | nn.init.constant_(self.lora_wk_l2.weight.data, 0)
122 | nn.init.constant_(self.lora_wv_l2.weight.data, 0)
123 | nn.init.constant_(self.lora_wo_l2.weight.data, 0)
124 |
125 | self.cache_k = None
126 | self.cache_v = None
127 |
128 | self.gate = torch.nn.Parameter(torch.zeros(1, self.n_local_heads, 1, 1))
129 |
130 | # self.ob_gate = torch.nn.Parameter(torch.zeros(1, self.n_local_heads, 1, 1))
131 |
132 | self.w_new_gate = args.w_new_gate
133 | if args.w_new_gate:
134 | self.new_gate = torch.nn.Parameter(torch.ones(1, 1, 1, 1))
135 |
136 |
137 | def train(self, mode: bool = True):
138 | if mode:
139 | self.cache_k = None
140 | self.cache_v = None
141 | else:
142 | self.cache_k = torch.zeros(
143 | (self.args.max_batch_size, self.args.max_seq_len, self.n_local_heads, self.head_dim)
144 | ).cuda()
145 | self.cache_v = torch.zeros(
146 | (self.args.max_batch_size, self.args.max_seq_len, self.n_local_heads, self.head_dim)
147 | ).cuda()
148 | return super().train(mode)
149 |
150 |
151 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
152 | bsz, seqlen, _ = x.shape
153 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
154 | if self.w_lora:
155 | xq = xq + self.lora_wq_l2(self.lora_wq_l1(x))
156 | xk = xk + self.lora_wk_l2(self.lora_wk_l1(x))
157 | xv = xv + self.lora_wv_l2(self.lora_wv_l1(x))
158 |
159 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
160 | xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
161 | xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
162 |
163 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
164 |
165 | if not self.training:
166 | self.cache_k = self.cache_k.to(xq)
167 | self.cache_v = self.cache_v.to(xq)
168 |
169 | self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
170 | self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
171 |
172 | keys = self.cache_k[:bsz, : start_pos + seqlen]
173 | values = self.cache_v[:bsz, : start_pos + seqlen]
174 | else:
175 | assert start_pos==0
176 | keys = xk
177 | values = xv
178 |
179 | if adapter is not None:
180 | adapter_len = adapter.shape[1]
181 | adapter_v = self.wv(adapter).view(bsz, adapter_len, self.n_local_heads, self.head_dim)
182 | adapter_v = adapter_v.transpose(1, 2)
183 |
184 | if adapter_len > 1:
185 | adapter_k = self.wk(adapter).view(bsz, adapter_len, self.n_local_heads, self.head_dim)
186 | adapter_k = adapter_k.transpose(1, 2)
187 |
188 | # if ob_adapter is not None:
189 | # ob_adapter_len = ob_adapter.shape[1]
190 | # ob_adapter_v = self.wv(ob_adapter).view(bsz, ob_adapter_len, self.n_local_heads, self.head_dim)
191 | # ob_adapter_v = ob_adapter_v.transpose(1, 2)
192 |
193 | # ob_adapter_k = self.wk(ob_adapter).view(bsz, ob_adapter_len, self.n_local_heads, self.head_dim)
194 | # ob_adapter_k = ob_adapter_k.transpose(1, 2)
195 |
196 | xq = xq.transpose(1, 2)
197 | keys = keys.transpose(1, 2)
198 | values = values.transpose(1, 2)
199 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
200 |
201 | if mask is not None:
202 | scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen)
203 |
204 | scores = F.softmax(scores.float(), dim=-1).type_as(xq)
205 | output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
206 |
207 | if adapter is not None:
208 | if adapter_len > 1:
209 | adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)
210 | adapter_scores = self.gate.tanh() * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)
211 | if self.w_new_gate:
212 | adapter_scores = self.new_gate * adapter_scores
213 | output = output + torch.matmul(adapter_scores, adapter_v)
214 | else:
215 | output = output + self.gate.tanh() * adapter_v
216 |
217 | # if ob_adapter is not None:
218 | # ob_adapter_scores = torch.matmul(xq, ob_adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)
219 | # ob_attn_mask = ob_attn_mask.unsqueeze(1)[:, :, :ob_adapter_scores.shape[2]]
220 | # ob_adapter_scores = ob_adapter_scores * ob_attn_mask
221 | # ob_adapter_scores = F.softmax(ob_adapter_scores.float(), dim=-1).type_as(xq) * ob_attn_mask
222 | # ob_adapter_scores = self.ob_gate.tanh() * ob_adapter_scores
223 | # output = output + torch.matmul(ob_adapter_scores, ob_adapter_v)
224 |
225 | output = output.transpose(
226 | 1, 2
227 | ).contiguous().view(bsz, seqlen, -1)
228 |
229 | if self.w_lora:
230 | return self.wo(output) + self.lora_wo_l2(self.lora_wo_l1(output))
231 | else:
232 | return self.wo(output)
233 |
234 |
235 | class FeedForward(nn.Module):
236 | def __init__(
237 | self,
238 | dim: int,
239 | hidden_dim: int,
240 | multiple_of: int,
241 | args: ModelArgs
242 | ):
243 | super().__init__()
244 | hidden_dim = int(2 * hidden_dim / 3)
245 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
246 |
247 | self.w1 = Linear(
248 | dim, hidden_dim, bias=args.w_bias
249 | )
250 | self.w2 = Linear(
251 | hidden_dim, dim, bias=args.w_bias
252 | )
253 | self.w3 = Linear(
254 | dim, hidden_dim, bias=args.w_bias
255 | )
256 | if args.w_bias:
257 | nn.init.constant_(self.w1.bias.data, 0)
258 | nn.init.constant_(self.w2.bias.data, 0)
259 | nn.init.constant_(self.w3.bias.data, 0)
260 |
261 | self.w_lora = args.w_lora
262 | if args.w_lora:
263 | self.lora_w1_l1 = Linear(dim, args.lora_rank, bias=False)
264 | self.lora_w1_l2 = Linear(args.lora_rank, hidden_dim, bias=False)
265 | self.lora_w2_l1 = Linear(hidden_dim, args.lora_rank, bias=False)
266 | self.lora_w2_l2 = Linear(args.lora_rank, dim, bias=False)
267 | self.lora_w3_l1 = Linear(dim, args.lora_rank, bias=False)
268 | self.lora_w3_l2 = Linear(args.lora_rank, hidden_dim, bias=False)
269 | nn.init.constant_(self.lora_w1_l2.weight.data, 0)
270 | nn.init.constant_(self.lora_w2_l2.weight.data, 0)
271 | nn.init.constant_(self.lora_w3_l2.weight.data, 0)
272 |
273 | def forward(self, x):
274 | if self.w_lora:
275 | out = F.silu(self.w1(x) + self.lora_w1_l2(self.lora_w1_l1(x))) * (self.w3(x) + self.lora_w3_l2(self.lora_w3_l1(x)))
276 | return self.w2(out) + self.lora_w2_l2(self.lora_w2_l1(out))
277 | else:
278 | return self.w2(F.silu(self.w1(x)) * self.w3(x))
279 |
280 |
281 | class TransformerBlock(nn.Module):
282 | def __init__(self, layer_id: int, args: ModelArgs):
283 | super().__init__()
284 | self.n_heads = args.n_heads
285 | self.dim = args.dim
286 | self.head_dim = args.dim // args.n_heads
287 | self.attention = Attention(args)
288 | self.feed_forward = FeedForward(
289 | dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of, args=args
290 | )
291 | self.layer_id = layer_id
292 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
293 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
294 |
295 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None):
296 | h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, prompt)
297 | out = h + self.feed_forward.forward(self.ffn_norm(h))
298 | return out
299 |
300 |
301 | class Transformer(nn.Module):
302 | def __init__(self, params: ModelArgs):
303 | super().__init__()
304 | self.params = params
305 | self.vocab_size = params.vocab_size
306 | self.n_layers = params.n_layers
307 | self.tok_embeddings = Embedding(
308 | params.vocab_size, params.dim
309 | )
310 |
311 | self.layers = torch.nn.ModuleList()
312 | for layer_id in range(params.n_layers):
313 | self.layers.append(TransformerBlock(layer_id, params))
314 |
315 | self.norm = RMSNorm(params.dim, eps=params.norm_eps)
316 | self.output = Linear(
317 | params.dim, params.vocab_size, bias=False
318 | )
319 |
320 | self.freqs_cis = precompute_freqs_cis(
321 | self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
322 | )
323 |
324 | @torch.inference_mode()
325 | def forward(self, tokens: torch.Tensor, start_pos: int):
326 | _bsz, seqlen = tokens.shape
327 | h = self.tok_embeddings(tokens)
328 | self.freqs_cis = self.freqs_cis.to(h.device)
329 | freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
330 |
331 | mask = None
332 | if seqlen > 1:
333 | mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
334 | mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
335 |
336 | for layer in self.layers:
337 | h = layer(h, start_pos, freqs_cis, mask)
338 | h = self.norm(h)
339 | output = self.output(h[:, -1, :]) # only compute last logits
340 | return output.float()
341 |
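# Smoke-test sketch for the backbone above (a tiny CPU-only configuration; the sizes are
# illustrative and far smaller than the released LLaMA checkpoints). Calling .eval() would
# switch Attention into KV-cache mode and allocate CUDA buffers, so the sketch keeps the
# default training flag and start_pos=0.
if __name__ == "__main__":
    args = ModelArgs(dim=64, n_layers=2, n_heads=4, vocab_size=128, max_seq_len=32)
    model = Transformer(args)
    tokens = torch.randint(0, args.vocab_size, (2, 8))
    logits = model(tokens, start_pos=0)
    print(logits.shape)  # torch.Size([2, 128]): logits are produced for the last position only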
--------------------------------------------------------------------------------
/llama/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from sentencepiece import SentencePieceProcessor
5 | from logging import getLogger
6 | from typing import List
7 | import os
8 |
9 |
10 | logger = getLogger()
11 |
12 |
13 | class Tokenizer:
14 | def __init__(self, model_path: str):
15 | # reload tokenizer
16 | assert os.path.isfile(model_path), model_path
17 | self.sp_model = SentencePieceProcessor(model_file=model_path)
18 | logger.info(f"Reloaded SentencePiece model from {model_path}")
19 |
20 | # BOS / EOS token IDs
21 | self.n_words: int = self.sp_model.vocab_size()
22 | self.bos_id: int = self.sp_model.bos_id()
23 | self.eos_id: int = self.sp_model.eos_id()
24 | self.pad_id: int = self.sp_model.pad_id()
25 | logger.info(
26 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
27 | )
28 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
29 |
30 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
31 | assert type(s) is str
32 | t = self.sp_model.encode(s)
33 | if bos:
34 | t = [self.bos_id] + t
35 | if eos:
36 | t = t + [self.eos_id]
37 | return t
38 |
39 | def decode(self, t: List[int]) -> str:
40 | return self.sp_model.decode(t)
41 |
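# Usage sketch (the tokenizer path is a placeholder for wherever the LLaMA weights are stored):
if __name__ == "__main__":
    tok = Tokenizer("/path/to/LLaMA/tokenizer.model")
    ids = tok.encode("Walk past the sofa and stop at the door.", bos=True, eos=True)
    print(len(ids), ids[:8])
    print(tok.decode(ids))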
--------------------------------------------------------------------------------
/llama/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import urllib.request
3 | import hashlib
4 | import warnings
5 |
6 | from tqdm import tqdm
7 | import torch
8 |
9 |
10 | def sample_top_p(probs, p):
11 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
12 | probs_sum = torch.cumsum(probs_sort, dim=-1)
13 | mask = probs_sum - probs_sort > p
14 | probs_sort[mask] = 0.0
15 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
16 | next_token = torch.multinomial(probs_sort, num_samples=1)
17 | next_token = torch.gather(probs_idx, -1, next_token)
18 | return next_token
19 |
20 |
21 | def format_prompt(instruction, input=None):
22 |
23 | PROMPT_DICT = {
24 | "prompt_input": (
25 | "Below is an instruction that describes a task, paired with an input that provides further context. "
26 | "Write a response that appropriately completes the request.\n\n"
27 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
28 | ),
29 | "prompt_no_input": (
30 | "Below is an instruction that describes a task. "
31 | "Write a response that appropriately completes the request.\n\n"
32 | "### Instruction:\n{instruction}\n\n### Response:"
33 | ),
34 | }
35 | if input is None:
36 | return PROMPT_DICT['prompt_no_input'].format_map({'instruction': instruction})
37 | else:
38 | return PROMPT_DICT["prompt_input"].format_map({'instruction': instruction, 'input': input})
39 |
40 |
41 | def _download(url: str, root: str):
42 | os.makedirs(root, exist_ok=True)
43 | filename = os.path.basename(url)
44 | # assume the url is https://some/path/sha256_model.pth
45 | expected_sha256 = url.split("/")[-1].split('_')[0]
46 | # expected_sha256 = url.split("/")[-2]
47 | download_target = os.path.join(root, filename)
48 |
49 | if os.path.exists(download_target) and not os.path.isfile(download_target):
50 | raise RuntimeError(f"{download_target} exists and is not a regular file")
51 |
52 | if os.path.isfile(download_target):
53 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
54 | return download_target
55 | else:
56 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
57 |
58 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
59 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
60 | while True:
61 | buffer = source.read(8192)
62 | if not buffer:
63 | break
64 |
65 | output.write(buffer)
66 | loop.update(len(buffer))
67 |
68 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
69 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
70 |
71 | return download_target
72 |
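# Usage sketch for the helpers above: nucleus (top-p) sampling over a toy distribution and
# the Alpaca-style prompt template (the instruction text is illustrative).
if __name__ == "__main__":
    probs = torch.softmax(torch.tensor([[2.0, 1.0, 0.5, 0.1]]), dim=-1)
    print(sample_top_p(probs, p=0.75))  # shape (1, 1): index of the sampled token
    print(format_prompt("Describe the path from the bedroom to the kitchen."))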
--------------------------------------------------------------------------------
/preprocess/build_image_lmdb.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import os
4 |
5 | import lmdb
6 | import numpy as np
7 | from PIL import Image
8 |
9 | import MatterSim
10 |
11 |
12 | # Simulator image parameters
13 | WIDTH = 640
14 | HEIGHT = 480
15 | VFOV = 60
16 |
17 | scan_data_dir = '../../data/v1/scans'
18 | connectivity_dir = '../../connectivity'
19 |
20 | sim = MatterSim.Simulator()
21 | sim.setDatasetPath(scan_data_dir)
22 | sim.setNavGraphPath(connectivity_dir)
23 | sim.setPreloadingEnabled(True)
24 | sim.setCameraResolution(WIDTH, HEIGHT)
25 | sim.setCameraVFOV(math.radians(VFOV))
26 | sim.setDiscretizedViewingAngles(True)
27 | sim.setBatchSize(1)
28 | sim.initialize()
29 |
30 | viewpoint_ids = []
31 | with open(os.path.join(connectivity_dir, 'scans.txt')) as f:
32 | scans = [x.strip() for x in f]
33 | for scan in scans:
34 | with open(os.path.join(connectivity_dir, f'{scan}_connectivity.json')) as f:
35 | data = json.load(f)
36 | viewpoint_ids.extend([(scan, x['image_id']) for x in data if x['included']])
37 | print(f'Loaded {len(viewpoint_ids)} viewpoints')
38 |
39 |
40 | NEWHEIGHT = 248
41 | NEWWIDTH = int(WIDTH / HEIGHT * NEWHEIGHT)
42 | print(NEWHEIGHT, NEWWIDTH)
43 |
44 | data_size_per_img = np.random.randint(255, size=(NEWHEIGHT, NEWWIDTH, 3), dtype=np.uint8).nbytes
45 | print(data_size_per_img, 36*data_size_per_img*len(viewpoint_ids))
46 |
47 | lmdb_path = '../../img_features/panoimages.lmdb'
48 |
49 | env = lmdb.open(lmdb_path, map_size=int(1e12))
50 |
51 |
52 | for i, viewpoint_id in enumerate(viewpoint_ids):
53 | scan, vp = viewpoint_id
54 | if i % 100 == 0:
55 | print(i, scan, vp)
56 |
57 | key = f'{scan}_{vp}'
58 | key_byte = key.encode('ascii')
59 |
60 | txn = env.begin(write=True)
61 |
62 | images = []
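    # The 36 discretized views cover 3 elevation rings (-30, 0, +30 degrees) x 12 headings:
    # the episode starts looking down at -30 degrees, each step rotates the heading by 30
    # degrees, and every 12th step also tilts the camera up by one elevation increment
    # (the assert below checks that state.viewIndex follows this standard ordering).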
63 | for ix in range(36):
64 | if ix == 0:
65 | sim.newEpisode([scan], [vp], [0], [math.radians(-30)])
66 | elif ix % 12 == 0:
67 | sim.makeAction([0], [1.0], [1.0])
68 | else:
69 | sim.makeAction([0], [1.0], [0])
70 | state = sim.getState()[0]
71 | assert state.viewIndex == ix
72 | image = np.array(state.rgb, copy=True) # in BGR channel
73 | # cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
74 | image = Image.fromarray(image[:, :, ::-1])
75 | # resize
76 | image = image.resize((NEWWIDTH, NEWHEIGHT), Image.LANCZOS)
77 | image = np.array(image)
78 | images.append(image)
79 | images = np.stack(images, 0)
80 |
81 | txn.put(key_byte, images)
82 | txn.commit()
83 |
84 | env.close()
85 |
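# Read-back sketch (illustrative; 'SCAN_VIEWPOINT' is a placeholder key): each stored value
# is the raw uint8 buffer of one (36, NEWHEIGHT, NEWWIDTH, 3) RGB panorama stack, e.g.
#
#     with lmdb.open(lmdb_path, readonly=True) as env_r, env_r.begin() as txn_r:
#         buf = txn_r.get('SCAN_VIEWPOINT'.encode('ascii'))
#         pano = np.frombuffer(buf, dtype=np.uint8).reshape(36, NEWHEIGHT, NEWWIDTH, 3)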
--------------------------------------------------------------------------------
/preprocess/precompute_img_features_clip.py:
--------------------------------------------------------------------------------
1 | ''' Script to precompute image features using the CLIP ViT visual encoder, using 36 discretized views
2 | at each viewpoint in 30 degree increments, and the provided camera WIDTH, HEIGHT
3 | and VFOV parameters. '''
4 |
5 | import argparse
6 | import math
7 | import os
8 |
9 | import h5py
10 | import numpy as np
11 | from PIL import Image
12 | from progressbar import ProgressBar
13 | import torch
14 | import torch.multiprocessing as mp
15 |
16 | import clip
17 | import MatterSim
18 |
19 | from utils import load_viewpoint_ids
20 |
21 |
22 | TSV_FIELDNAMES = ['scanId', 'viewpointId', 'image_w',
23 | 'image_h', 'vfov', 'features', 'logits']
24 | VIEWPOINT_SIZE = 36 # Number of discretized views from one viewpoint
25 | FEATURE_SIZE = 768
26 | LOGIT_SIZE = 1000
27 |
28 | WIDTH = 640
29 | HEIGHT = 480
30 | VFOV = 60
31 |
32 |
33 | def build_feature_extractor(model_name, checkpoint_file=None):
34 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35 |
36 | model, img_transforms = clip.load(model_name, device='cpu')
37 | model.to(device)
38 | model.eval()
39 |
40 | return model, img_transforms, device
41 |
42 |
43 | def build_simulator(connectivity_dir, scan_dir):
44 | sim = MatterSim.Simulator()
45 | sim.setNavGraphPath(connectivity_dir)
46 | sim.setDatasetPath(scan_dir)
47 | sim.setCameraResolution(WIDTH, HEIGHT)
48 | sim.setCameraVFOV(math.radians(VFOV))
49 | sim.setDiscretizedViewingAngles(True)
50 | sim.setDepthEnabled(False)
51 | sim.setPreloadingEnabled(True)
52 | sim.setBatchSize(1)
53 | sim.initialize()
54 | return sim
55 |
56 |
57 | def clip_encode_image(model, x):
58 | # modified from CLIP
59 | x = model.visual.conv1(x) # shape = [*, width, grid, grid]
60 | # shape = [*, width, grid ** 2]
61 | x = x.reshape(x.shape[0], x.shape[1], -1)
62 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
63 | x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1,
64 | x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
65 | x = x + model.visual.positional_embedding.to(x.dtype)
66 | x = model.visual.ln_pre(x)
67 |
68 | x = x.permute(1, 0, 2) # NLD -> LND
69 | x = model.visual.transformer(x)
70 | x = x.permute(1, 0, 2) # LND -> NLD
71 |
72 |     # keep only the class token (the commented line below would preserve all spatial tokens)
73 |     # x = model.visual.ln_post(x[:, :, :])
74 |     x = model.visual.ln_post(x[:, 0, :])
75 |
76 | if model.visual.proj is not None:
77 | x = x @ model.visual.proj
78 |
79 | return x
80 |
81 |
82 | def process_features(proc_id, out_queue, scanvp_list, args):
83 | print(f'start proc_id: {proc_id}')
84 |
85 | # Set up the simulator
86 | sim = build_simulator(args.connectivity_dir, args.scan_dir)
87 |
88 | # Set up PyTorch CNN model
89 | torch.set_grad_enabled(False)
90 | model, img_transforms, device = build_feature_extractor(args.model_name, args.checkpoint_file)
91 |
92 | for scan_id, viewpoint_id in scanvp_list:
93 | # Loop all discretized views from this location
94 | images = []
95 | for ix in range(VIEWPOINT_SIZE):
96 | if ix == 0:
97 | sim.newEpisode([scan_id], [viewpoint_id],
98 | [0], [math.radians(-30)])
99 | elif ix % 12 == 0:
100 | sim.makeAction([0], [1.0], [1.0])
101 | else:
102 | sim.makeAction([0], [1.0], [0])
103 | state = sim.getState()[0]
104 | assert state.viewIndex == ix
105 |
106 | image = np.array(state.rgb, copy=True) # in BGR channel
107 | # cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
108 | # image = Image.fromarray(image[:, :, ::-1])
109 | image = Image.fromarray(image)
110 | images.append(image)
111 |
112 | images = torch.stack([img_transforms(image).to(device) for image in images], 0)
113 | fts = []
114 | for k in range(0, len(images), args.batch_size):
115 | b_fts = clip_encode_image(model, images[k: k+args.batch_size])
116 | b_fts = b_fts.data.cpu().numpy() # B, 768
117 | fts.append(b_fts)
118 | fts = np.concatenate(fts, 0)
119 |
120 | out_queue.put((scan_id, viewpoint_id, fts))
121 |
122 | out_queue.put(None)
123 |
124 |
125 | def build_feature_file(args):
126 |
127 | os.makedirs(os.path.dirname(args.output_file), exist_ok=True)
128 |
129 | scanvp_list = load_viewpoint_ids(args.connectivity_dir)
130 |
131 | num_workers = min(args.num_workers, len(scanvp_list))
132 | num_data_per_worker = len(scanvp_list) // num_workers
133 |
134 | out_queue = mp.Queue()
135 | processes = []
136 | for proc_id in range(num_workers):
137 | sidx = proc_id * num_data_per_worker
138 | eidx = None if proc_id == num_workers - 1 else sidx + num_data_per_worker
139 |
140 | process = mp.Process(
141 | target=process_features,
142 | args=(proc_id, out_queue, scanvp_list[sidx: eidx], args)
143 | )
144 | process.start()
145 | processes.append(process)
146 |
147 | num_finished_workers = 0
148 | num_finished_vps = 0
149 |
150 | progress_bar = ProgressBar(max_value=len(scanvp_list))
151 | progress_bar.start()
152 |
153 | with h5py.File(args.output_file, 'w') as outf:
154 | while num_finished_workers < num_workers:
155 | res = out_queue.get()
156 | if res is None:
157 | num_finished_workers += 1
158 | else:
159 | scan_id, viewpoint_id, fts = res
160 | key = f'{scan_id}_{viewpoint_id}'
161 | data = fts
162 | outf.create_dataset(key, data.shape, dtype='float', compression='gzip')
163 | outf[key][...] = data
164 | outf[key].attrs['scanId'] = scan_id
165 | outf[key].attrs['viewpointId'] = viewpoint_id
166 | outf[key].attrs['image_w'] = WIDTH
167 | outf[key].attrs['image_h'] = HEIGHT
168 | outf[key].attrs['vfov'] = VFOV
169 |
170 | num_finished_vps += 1
171 | progress_bar.update(num_finished_vps)
172 |
173 | progress_bar.finish()
174 | for process in processes:
175 | process.join()
176 |
177 |
178 | if __name__ == '__main__':
179 | parser = argparse.ArgumentParser()
180 | parser.add_argument('--model_name', default='ViT-L/14')
181 | parser.add_argument('--checkpoint_file', default=None)
182 | parser.add_argument('--connectivity_dir', default='../../connectivity')
183 | parser.add_argument('--scan_dir', default='../../data/v1/scans')
184 | parser.add_argument('--output_file')
185 | parser.add_argument('--batch_size', default=36, type=int)
186 | parser.add_argument('--num_workers', type=int, default=8)
187 | args = parser.parse_args()
188 |
189 | mp.set_start_method('spawn')
190 |
191 | build_feature_file(args)
192 |
--------------------------------------------------------------------------------
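A minimal reading sketch (not part of the repository) for the HDF5 file written by build_feature_file() above: each key stores a (36, 768) array of CLIP ViT-L/14 features for one viewpoint plus the camera attributes. The output path below is hypothetical, standing in for whatever was passed as --output_file.

    import h5py

    with h5py.File('img_features/clip_vit_l14_36views.hdf5', 'r') as f:   # hypothetical path
        key = next(iter(f.keys()))          # keys are formatted as '<scanId>_<viewpointId>'
        fts = f[key][...]                   # (36, 768): one CLIP feature per discretized view
        print(key, fts.shape, f[key].attrs['image_w'], f[key].attrs['vfov'])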
/preprocess/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 |
5 |
6 | def load_viewpoint_ids(connectivity_dir):
7 | viewpoint_ids = []
8 | with open(os.path.join(connectivity_dir, 'scans.txt')) as f:
9 | scans = [x.strip() for x in f]
10 | for scan in scans:
11 | with open(os.path.join(connectivity_dir, f'{scan}_connectivity.json')) as f:
12 | data = json.load(f)
13 | viewpoint_ids.extend([(scan, x['image_id']) for x in data if x['included']])
14 | print(f'Loaded {len(viewpoint_ids)} viewpoints')
15 | return viewpoint_ids
16 |
17 |
18 | class Timer(object):
19 | """A simple timer."""
20 |
21 | def __init__(self):
22 | self.total_time = 0.
23 | self.calls = 0
24 | self.start_time = 0.
25 | self.diff = 0.
26 | self.average_time = 0.
27 |
28 | def tic(self):
29 | # using time.time instead of time.clock because time.clock
30 | # does not normalize for multithreading
31 | self.start_time = time.time()
32 |
33 | def toc(self, average=True):
34 | self.diff = time.time() - self.start_time
35 | self.total_time += self.diff
36 | self.calls += 1
37 | self.average_time = self.total_time / self.calls
38 | if average:
39 | return self.average_time
40 | else:
41 | return self.diff
42 |
--------------------------------------------------------------------------------
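For reference, a small usage sketch of the Timer helper above (illustrative only; do_work is a hypothetical stand-in for any repeated operation being measured):

    timer = Timer()
    for _ in range(5):
        timer.tic()
        do_work()                     # hypothetical workload
        avg = timer.toc()             # toc() returns the running average by default
    print(f'{timer.calls} calls, {avg:.3f}s average, {timer.total_time:.3f}s total')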
/pycocoevalcap/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/refkxh/C-Instructor/55756e5fb3771f8dbbac0f63f075142a41906e74/pycocoevalcap/__init__.py
--------------------------------------------------------------------------------
/pycocoevalcap/bleu/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/pycocoevalcap/bleu/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/pycocoevalcap/bleu/bleu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : bleu.py
4 | #
5 | # Description : Wrapper for BLEU scorer.
6 | #
7 | # Creation Date : 06-01-2015
8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | from .bleu_scorer import BleuScorer
12 |
13 |
14 | class Bleu:
15 | def __init__(self, n=4):
16 | # default: compute BLEU score up to 4-grams
17 | self._n = n
18 | self._hypo_for_image = {}
19 | self.ref_for_image = {}
20 |
21 | def compute_score(self, gts, res):
22 |
23 | assert(list(gts.keys()) == list(res.keys()))
24 | imgIds = list(gts.keys())
25 |
26 | bleu_scorer = BleuScorer(n=self._n)
27 | for id in imgIds:
28 | hypo = res[id]
29 | ref = gts[id]
30 |
31 | # Sanity check.
32 | assert(type(hypo) is list)
33 | assert(len(hypo) == 1)
34 | assert(type(ref) is list)
35 | assert(len(ref) >= 1)
36 |
37 | bleu_scorer += (hypo[0], ref)
38 |
39 | #score, scores = bleu_scorer.compute_score(option='shortest')
40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1)
41 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1)
42 |
43 | # return (bleu, bleu_info)
44 | return score, scores
45 |
46 | def method(self):
47 | return "Bleu"
48 |
--------------------------------------------------------------------------------
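A hedged usage sketch for the Bleu wrapper above, following the gts/res dictionary convention used by eval.py later in this package; the path id and sentences are invented, and the import assumes pycocoevalcap/ is the working directory (as eval.py does).

    from bleu.bleu import Bleu

    gts = {'path_0': ['walk past the couch and stop at the stairs .',
                      'go through the living room and wait by the staircase .']}
    res = {'path_0': ['walk through the living room and stop near the stairs .']}

    score, per_image = Bleu(4).compute_score(gts, res)
    print(dict(zip(['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4'], score)))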
/pycocoevalcap/bleu/bleu_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # bleu_scorer.py
4 | # David Chiang
5 |
6 | # Copyright (c) 2004-2006 University of Maryland. All rights
7 | # reserved. Do not redistribute without permission from the
8 | # author. Not for commercial use.
9 |
10 | # Modified by:
11 | # Hao Fang
12 | # Tsung-Yi Lin
13 |
14 | '''Provides:
15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
17 | '''
18 |
19 | import copy
20 | import sys, math, re
21 | from collections import defaultdict
22 |
23 | def precook(s, n=4, out=False):
24 | """Takes a string as input and returns an object that can be given to
25 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
26 | can take string arguments as well."""
27 | words = s.split()
28 | counts = defaultdict(int)
29 | for k in range(1,n+1):
30 | for i in range(len(words)-k+1):
31 | ngram = tuple(words[i:i+k])
32 | counts[ngram] += 1
33 | return (len(words), counts)
34 |
35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
36 | '''Takes a list of reference sentences for a single segment
37 | and returns an object that encapsulates everything that BLEU
38 | needs to know about them.'''
39 |
40 | reflen = []
41 | maxcounts = {}
42 | for ref in refs:
43 | rl, counts = precook(ref, n)
44 | reflen.append(rl)
45 | for (ngram,count) in counts.items():
46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
47 |
48 | # Calculate effective reference sentence length.
49 | if eff == "shortest":
50 | reflen = min(reflen)
51 | elif eff == "average":
52 | reflen = float(sum(reflen))/len(reflen)
53 |
54 | ## lhuang: N.B.: leave reflen computation to the very end!!
55 |
56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
57 |
58 | return (reflen, maxcounts)
59 |
60 | def cook_test(test, xxx_todo_changeme, eff=None, n=4):
61 | '''Takes a test sentence and returns an object that
62 | encapsulates everything that BLEU needs to know about it.'''
63 | (reflen, refmaxcounts) = xxx_todo_changeme
64 | testlen, counts = precook(test, n, True)
65 |
66 | result = {}
67 |
68 | # Calculate effective reference sentence length.
69 |
70 | if eff == "closest":
71 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
72 | else: ## i.e., "average" or "shortest" or None
73 | result["reflen"] = reflen
74 |
75 | result["testlen"] = testlen
76 |
77 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]
78 |
79 | result['correct'] = [0]*n
80 | for (ngram, count) in counts.items():
81 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
82 |
83 | return result
84 |
85 | class BleuScorer(object):
86 | """Bleu scorer.
87 | """
88 |
89 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
90 | # special_reflen is used in oracle (proportional effective ref len for a node).
91 |
92 | def copy(self):
93 | ''' copy the refs.'''
94 | new = BleuScorer(n=self.n)
95 | new.ctest = copy.copy(self.ctest)
96 | new.crefs = copy.copy(self.crefs)
97 | new._score = None
98 | return new
99 |
100 | def __init__(self, test=None, refs=None, n=4, special_reflen=None):
101 | ''' singular instance '''
102 |
103 | self.n = n
104 | self.crefs = []
105 | self.ctest = []
106 | self.cook_append(test, refs)
107 | self.special_reflen = special_reflen
108 |
109 | def cook_append(self, test, refs):
110 | '''called by constructor and __iadd__ to avoid creating new instances.'''
111 |
112 | if refs is not None:
113 | self.crefs.append(cook_refs(refs))
114 | if test is not None:
115 | cooked_test = cook_test(test, self.crefs[-1])
116 | self.ctest.append(cooked_test) ## N.B.: -1
117 | else:
118 | self.ctest.append(None) # lens of crefs and ctest have to match
119 |
120 | self._score = None ## need to recompute
121 |
122 | def ratio(self, option=None):
123 | self.compute_score(option=option)
124 | return self._ratio
125 |
126 | def score_ratio(self, option=None):
127 | '''return (bleu, len_ratio) pair'''
128 | return (self.fscore(option=option), self.ratio(option=option))
129 |
130 | def score_ratio_str(self, option=None):
131 | return "%.4f (%.2f)" % self.score_ratio(option)
132 |
133 | def reflen(self, option=None):
134 | self.compute_score(option=option)
135 | return self._reflen
136 |
137 | def testlen(self, option=None):
138 | self.compute_score(option=option)
139 | return self._testlen
140 |
141 | def retest(self, new_test):
142 | if type(new_test) is str:
143 | new_test = [new_test]
144 | assert len(new_test) == len(self.crefs), new_test
145 | self.ctest = []
146 | for t, rs in zip(new_test, self.crefs):
147 | self.ctest.append(cook_test(t, rs))
148 | self._score = None
149 |
150 | return self
151 |
152 | def rescore(self, new_test):
153 | ''' replace test(s) with new test(s), and returns the new score.'''
154 |
155 | return self.retest(new_test).compute_score()
156 |
157 | def size(self):
158 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
159 | return len(self.crefs)
160 |
161 | def __iadd__(self, other):
162 | '''add an instance (e.g., from another sentence).'''
163 |
164 | if type(other) is tuple:
165 | ## avoid creating new BleuScorer instances
166 | self.cook_append(other[0], other[1])
167 | else:
168 | assert self.compatible(other), "incompatible BLEUs."
169 | self.ctest.extend(other.ctest)
170 | self.crefs.extend(other.crefs)
171 | self._score = None ## need to recompute
172 |
173 | return self
174 |
175 | def compatible(self, other):
176 | return isinstance(other, BleuScorer) and self.n == other.n
177 |
178 | def single_reflen(self, option="average"):
179 | return self._single_reflen(self.crefs[0][0], option)
180 |
181 | def _single_reflen(self, reflens, option=None, testlen=None):
182 |
183 | if option == "shortest":
184 | reflen = min(reflens)
185 | elif option == "average":
186 | reflen = float(sum(reflens))/len(reflens)
187 | elif option == "closest":
188 | reflen = min((abs(l-testlen), l) for l in reflens)[1]
189 | else:
190 | assert False, "unsupported reflen option %s" % option
191 |
192 | return reflen
193 |
194 | def recompute_score(self, option=None, verbose=0):
195 | self._score = None
196 | return self.compute_score(option, verbose)
197 |
198 | def compute_score(self, option=None, verbose=0):
199 | n = self.n
200 | small = 1e-9
201 | tiny = 1e-15 ## so that if guess is 0 still return 0
202 | bleu_list = [[] for _ in range(n)]
203 |
204 | if self._score is not None:
205 | return self._score
206 |
207 | if option is None:
208 | option = "average" if len(self.crefs) == 1 else "closest"
209 |
210 | self._testlen = 0
211 | self._reflen = 0
212 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}
213 |
214 | # for each sentence
215 | for comps in self.ctest:
216 | testlen = comps['testlen']
217 | self._testlen += testlen
218 |
219 | if self.special_reflen is None: ## need computation
220 | reflen = self._single_reflen(comps['reflen'], option, testlen)
221 | else:
222 | reflen = self.special_reflen
223 |
224 | self._reflen += reflen
225 |
226 | for key in ['guess','correct']:
227 | for k in range(n):
228 | totalcomps[key][k] += comps[key][k]
229 |
230 | # append per image bleu score
231 | bleu = 1.
232 | for k in range(n):
233 | bleu *= (float(comps['correct'][k]) + tiny) \
234 | /(float(comps['guess'][k]) + small)
235 | bleu_list[k].append(bleu ** (1./(k+1)))
236 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
237 | if ratio < 1:
238 | for k in range(n):
239 | bleu_list[k][-1] *= math.exp(1 - 1/ratio)
240 |
241 | if verbose > 1:
242 | print(comps, reflen)
243 |
244 | totalcomps['reflen'] = self._reflen
245 | totalcomps['testlen'] = self._testlen
246 |
247 | bleus = []
248 | bleu = 1.
249 | for k in range(n):
250 | bleu *= float(totalcomps['correct'][k] + tiny) \
251 | / (totalcomps['guess'][k] + small)
252 | bleus.append(bleu ** (1./(k+1)))
253 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
254 | if ratio < 1:
255 | for k in range(n):
256 | bleus[k] *= math.exp(1 - 1/ratio)
257 |
258 | if verbose > 0:
259 | print(totalcomps)
260 | print("ratio:", ratio)
261 |
262 | self._score = bleus
263 | return self._score, bleu_list
264 |
--------------------------------------------------------------------------------
/pycocoevalcap/cider/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/pycocoevalcap/cider/cider.py:
--------------------------------------------------------------------------------
1 | # Filename: cider.py
2 | #
3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5 | #
6 | # Creation Date: Sun Feb 8 14:16:54 2015
7 | #
8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin
9 |
10 | from .cider_scorer import CiderScorer
11 | import pdb
12 |
13 | class Cider:
14 | """
15 | Main Class to compute the CIDEr metric
16 |
17 | """
18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
19 | # set cider to sum over 1 to 4-grams
20 | self._n = n
21 | # set the standard deviation parameter for gaussian penalty
22 | self._sigma = sigma
23 |
24 | def compute_score(self, gts, res):
25 | """
26 | Main function to compute CIDEr score
27 | :param hypo_for_image (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence>
28 | ref_for_image (dict) : dictionary with key <image> and value <tokenized reference sentence>
29 | :return: cider (float) : computed CIDEr score for the corpus
30 | """
31 |
32 | assert(list(gts.keys()) == list(res.keys()))
33 | imgIds = list(gts.keys())
34 |
35 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)
36 |
37 | for id in imgIds:
38 | hypo = res[id]
39 | ref = gts[id]
40 |
41 | # Sanity check.
42 | assert(type(hypo) is list)
43 | assert(len(hypo) == 1)
44 | assert(type(ref) is list)
45 | assert(len(ref) > 0)
46 |
47 | cider_scorer += (hypo[0], ref)
48 |
49 | (score, scores) = cider_scorer.compute_score()
50 |
51 | return score, scores
52 |
53 | def method(self):
54 | return "CIDEr"
--------------------------------------------------------------------------------
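CIDEr exposes the same gts/res interface; a minimal sketch reusing the dictionaries from the BLEU example above (again assuming pycocoevalcap/ as the working directory). Note that CIDEr is a corpus-level, idf-weighted metric, so scores are only meaningful over many items; a single pair like this degenerates to zero.

    from cider.cider import Cider

    cider_score, cider_per_image = Cider().compute_score(gts, res)   # gts/res as in the Bleu sketch
    print('CIDEr:', cider_score)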
/pycocoevalcap/cider/cider_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 |
5 | import copy
6 | from collections import defaultdict
7 | import numpy as np
8 | import pdb
9 | import math
10 |
11 | def precook(s, n=4, out=False):
12 | """
13 | Takes a string as input and returns an object that can be given to
14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
15 | can take string arguments as well.
16 | :param s: string : sentence to be converted into ngrams
17 | :param n: int : number of ngrams for which representation is calculated
18 | :return: term frequency vector for occurring ngrams
19 | """
20 | words = s.split()
21 | counts = defaultdict(int)
22 | for k in range(1,n+1):
23 | for i in range(len(words)-k+1):
24 | ngram = tuple(words[i:i+k])
25 | counts[ngram] += 1
26 | return counts
27 |
28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
29 | '''Takes a list of reference sentences for a single segment
30 | and returns an object that encapsulates everything that BLEU
31 | needs to know about them.
32 | :param refs: list of string : reference sentences for some image
33 | :param n: int : number of ngrams for which (ngram) representation is calculated
34 | :return: result (list of dict)
35 | '''
36 | return [precook(ref, n) for ref in refs]
37 |
38 | def cook_test(test, n=4):
39 | '''Takes a test sentence and returns an object that
40 | encapsulates everything that BLEU needs to know about it.
41 | :param test: list of string : hypothesis sentence for some image
42 | :param n: int : number of ngrams for which (ngram) representation is calculated
43 | :return: result (dict)
44 | '''
45 | return precook(test, n, True)
46 |
47 | class CiderScorer(object):
48 | """CIDEr scorer.
49 | """
50 |
51 | def copy(self):
52 | ''' copy the refs.'''
53 | new = CiderScorer(n=self.n)
54 | new.ctest = copy.copy(self.ctest)
55 | new.crefs = copy.copy(self.crefs)
56 | return new
57 |
58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
59 | ''' singular instance '''
60 | self.n = n
61 | self.sigma = sigma
62 | self.crefs = []
63 | self.ctest = []
64 | self.document_frequency = defaultdict(float)
65 | self.cook_append(test, refs)
66 | self.ref_len = None
67 |
68 | def cook_append(self, test, refs):
69 | '''called by constructor and __iadd__ to avoid creating new instances.'''
70 |
71 | if refs is not None:
72 | self.crefs.append(cook_refs(refs))
73 | if test is not None:
74 | self.ctest.append(cook_test(test)) ## N.B.: -1
75 | else:
76 | self.ctest.append(None) # lens of crefs and ctest have to match
77 |
78 | def size(self):
79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
80 | return len(self.crefs)
81 |
82 | def __iadd__(self, other):
83 | '''add an instance (e.g., from another sentence).'''
84 |
85 | if type(other) is tuple:
86 | ## avoid creating new CiderScorer instances
87 | self.cook_append(other[0], other[1])
88 | else:
89 | self.ctest.extend(other.ctest)
90 | self.crefs.extend(other.crefs)
91 |
92 | return self
93 | def compute_doc_freq(self):
94 | '''
95 | Compute term frequency for reference data.
96 | This will be used to compute idf (inverse document frequency later)
97 | The term frequency is stored in the object
98 | :return: None
99 | '''
100 | for refs in self.crefs:
101 | # refs, k ref captions of one image
102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
103 | self.document_frequency[ngram] += 1
104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
105 |
106 | def compute_cider(self):
107 | def counts2vec(cnts):
108 | """
109 | Function maps counts of ngram to vector of tfidf weights.
110 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
111 | The n-th entry of array denotes length of n-grams.
112 | :param cnts:
113 | :return: vec (array of dict), norm (array of float), length (int)
114 | """
115 | vec = [defaultdict(float) for _ in range(self.n)]
116 | length = 0
117 | norm = [0.0 for _ in range(self.n)]
118 | for (ngram,term_freq) in cnts.items():
119 | # give word count 1 if it doesn't appear in reference corpus
120 | df = np.log(max(1.0, self.document_frequency[ngram]))
121 | # ngram index
122 | n = len(ngram)-1
123 | # tf (term_freq) * idf (precomputed idf) for n-grams
124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df)
125 | # compute norm for the vector. the norm will be used for computing similarity
126 | norm[n] += pow(vec[n][ngram], 2)
127 |
128 | if n == 1:
129 | length += term_freq
130 | norm = [np.sqrt(n) for n in norm]
131 | return vec, norm, length
132 |
133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
134 | '''
135 | Compute the cosine similarity of two vectors.
136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis
137 | :param vec_ref: array of dictionary for vector corresponding to reference
138 | :param norm_hyp: array of float for vector corresponding to hypothesis
139 | :param norm_ref: array of float for vector corresponding to reference
140 | :param length_hyp: int containing length of hypothesis
141 | :param length_ref: int containing length of reference
142 | :return: array of score for each n-grams cosine similarity
143 | '''
144 | delta = float(length_hyp - length_ref)
145 | # measure cosine similarity
146 | val = np.array([0.0 for _ in range(self.n)])
147 | for n in range(self.n):
148 | # ngram
149 | for (ngram,count) in vec_hyp[n].items():
150 | # vrama91 : added clipping
151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
152 |
153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
154 | val[n] /= (norm_hyp[n]*norm_ref[n])
155 |
156 | assert(not math.isnan(val[n]))
157 | # vrama91: added a length based gaussian penalty
158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
159 | return val
160 |
161 | # compute log reference length
162 | self.ref_len = np.log(float(len(self.crefs)))
163 |
164 | scores = []
165 | for test, refs in zip(self.ctest, self.crefs):
166 | # compute vector for test captions
167 | vec, norm, length = counts2vec(test)
168 | # compute vector for ref captions
169 | score = np.array([0.0 for _ in range(self.n)])
170 | for ref in refs:
171 | vec_ref, norm_ref, length_ref = counts2vec(ref)
172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
173 | # change by vrama91 - mean of ngram scores, instead of sum
174 | score_avg = np.mean(score)
175 | # divide by number of references
176 | score_avg /= len(refs)
177 | # multiply score by 10
178 | score_avg *= 10.0
179 | # append score of an image to the score list
180 | scores.append(score_avg)
181 | return scores
182 |
183 | def compute_score(self, option=None, verbose=0):
184 | # compute idf
185 | self.compute_doc_freq()
186 | # assert to check document frequency
187 | assert(len(self.ctest) >= max(self.document_frequency.values()))
188 | # compute cider score
189 | score = self.compute_cider()
190 | # debug
191 | # print score
192 | return np.mean(np.array(score)), np.array(score)
--------------------------------------------------------------------------------
/pycocoevalcap/clip_tokenizer/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/refkxh/C-Instructor/55756e5fb3771f8dbbac0f63f075142a41906e74/pycocoevalcap/clip_tokenizer/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/pycocoevalcap/clip_tokenizer/tokenization_clip.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import torch
3 | import html
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 | import numpy as np
9 | import copy
10 | import string
11 |
12 |
13 | @lru_cache()
14 | def default_bpe():
15 | return "clip_tokenizer/bpe_simple_vocab_16e6.txt.gz"
16 |
17 |
18 | @lru_cache()
19 | def bytes_to_unicode():
20 | """
21 | Returns list of utf-8 byte and a corresponding list of unicode strings.
22 | The reversible bpe codes work on unicode strings.
23 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
24 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
25 | This is a significant percentage of your normal, say, 32K bpe vocab.
26 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
27 | And avoids mapping to whitespace/control characters the bpe code barfs on.
28 | """
29 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
30 | cs = bs[:]
31 | n = 0
32 | for b in range(2**8):
33 | if b not in bs:
34 | bs.append(b)
35 | cs.append(2**8+n)
36 | n += 1
37 | cs = [chr(n) for n in cs]
38 | return dict(zip(bs, cs))
39 |
40 |
41 | def get_pairs(word):
42 | """Return set of symbol pairs in a word.
43 | Word is represented as tuple of symbols (symbols being variable-length strings).
44 | """
45 | pairs = set()
46 | prev_char = word[0]
47 | for char in word[1:]:
48 | pairs.add((prev_char, char))
49 | prev_char = char
50 | return pairs
51 |
52 |
53 | def basic_clean(text):
54 | text = ftfy.fix_text(text)
55 | text = html.unescape(html.unescape(text))
56 | return text.strip()
57 |
58 |
59 | def whitespace_clean(text):
60 | text = re.sub(r'\s+', ' ', text)
61 | text = text.strip()
62 | return text
63 |
64 |
65 | class SimpleTokenizer(object):
66 | SENTENCE_SPLIT_REGEX = re.compile(r'(\W+)') # Split on any non-alphanumeric character
67 | def __init__(self, bpe_path: str = default_bpe()):
68 | self.byte_encoder = bytes_to_unicode()
69 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
70 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
71 | merges = merges[1:49152-256-2+1]
72 | merges = [tuple(merge.split()) for merge in merges]
73 | vocab = list(bytes_to_unicode().values())
74 | vocab = vocab + [v+'</w>' for v in vocab]
75 | for merge in merges:
76 | vocab.append(''.join(merge))
77 | # vocab.extend(['<|startoftext|>', '<|endoftext|>'])
78 | vocab.extend(['<|BOS|>', '<|EOS|>', '<|UNK|>', '<|MSK|>'])
79 | self.encoder = dict(zip(vocab, range(len(vocab))))
80 | self.decoder = {v: k for k, v in self.encoder.items()}
81 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
82 | # self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
83 | self.cache = {'<|BOS|>': '<|BOS|>', '<|EOS|>': '<|EOS|>', '<|UNK|>': '<|UNK|>', '<|MSK|>': '<|MSK|>'}
84 | # self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
85 | self.pat = re.compile(r"""<\|BOS\|>|<\|EOS\|>|<\|UNK\|>|<\|MSK\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
86 |
87 | self.vocab = self.encoder
88 | self.word_to_index = copy.deepcopy(self.encoder)
89 | self.index_to_word = copy.deepcopy(self.decoder)
90 | self.word_to_index['<PAD>'] = 0 # FIXME not elegant
91 | print(f"vocab size is {self.vocab_size()}")
92 |
93 | def vocab_size(self):
94 | return len(self.vocab)
95 |
96 | def bpe(self, token):
97 | if token in self.cache:
98 | return self.cache[token]
99 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
100 | pairs = get_pairs(word)
101 |
102 | if not pairs:
103 | return token+'</w>'
104 |
105 | while True:
106 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
107 | if bigram not in self.bpe_ranks:
108 | break
109 | first, second = bigram
110 | new_word = []
111 | i = 0
112 | while i < len(word):
113 | try:
114 | j = word.index(first, i)
115 | new_word.extend(word[i:j])
116 | i = j
117 | except:
118 | new_word.extend(word[i:])
119 | break
120 |
121 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
122 | new_word.append(first+second)
123 | i += 2
124 | else:
125 | new_word.append(word[i])
126 | i += 1
127 | new_word = tuple(new_word)
128 | word = new_word
129 | if len(word) == 1:
130 | break
131 | else:
132 | pairs = get_pairs(word)
133 | word = ' '.join(word)
134 | self.cache[token] = word
135 | return word
136 |
137 | def encode(self, text):
138 | bpe_tokens = [self.encoder[""]]
139 | text = whitespace_clean(basic_clean(text)).lower()
140 | for token in re.findall(self.pat, text):
141 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
142 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
143 | bpe_tokens.append(self.encoder[""])
144 | return bpe_tokens
145 |
146 | def decode(self, tokens):
147 | text = ''.join([self.decoder[token] for token in tokens])
148 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
149 | return text
150 |
151 | def tokenize(self, text):
152 | tokens = []
153 | text = whitespace_clean(basic_clean(text)).lower()
154 | for token in re.findall(self.pat, text):
155 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
156 | tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
157 | return tokens
158 |
159 | def convert_tokens_to_ids(self, tokens):
160 | return [self.encoder[bpe_token] for bpe_token in tokens]
161 |
162 | def __call__(self, texts, return_tensors='pt', padding=True, truncation=True):
163 | """
164 | Returns the tokenized representation of given input string(s)
165 | Parameters
166 | ----------
167 | texts : Union[str, List[str]]
168 | An input string or a list of input strings to tokenize
169 | context_length : int
170 | The context length to use; all CLIP models use 77 as the context length
171 |
172 | remaining params are just to have same interface with huggingface tokenizer.
173 | They don't do much.
174 | Returns
175 | -------
176 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
177 | """
178 | context_length = 100 # NOTE 100 in VLN task, because one token length is 97, one is 121
179 | if isinstance(texts, str):
180 | texts = [texts]
181 |
182 | sot_token = self.encoder[""]
183 | eot_token = self.encoder[""]
184 | all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
185 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
186 |
187 | for i, tokens in enumerate(all_tokens):
188 | if len(tokens) > context_length:
189 | # import ipdb;ipdb.set_trace()
190 | # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
191 | tokens = tokens[:context_length - 1]
192 | tokens.append(self.vocab[""]) # NOTE
193 | result[i, :len(tokens)] = torch.tensor(tokens)
194 |
195 | return result
196 |
197 | def encode_sentence(self, texts):
198 | # str -> numpy for only one sentence!!!!
199 | context_length = 100 # NOTE 100 in VLN task, because one token length is 97, one is 121
200 | if isinstance(texts, str):
201 | texts = [texts]
202 |
203 | sot_token = self.encoder[""]
204 | eot_token = self.encoder[""]
205 | all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
206 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
207 |
208 | for i, tokens in enumerate(all_tokens):
209 | if len(tokens) > context_length:
210 | # import ipdb;ipdb.set_trace()
211 | # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
212 | tokens = tokens[:context_length]
213 | # tokens.append(self.vocab["<|endoftext|>"]) # NOTE no need to add [eos]
214 | result[i, :len(tokens)] = torch.tensor(tokens)
215 | result = result.squeeze(0) # [context_length]
216 |
217 | return np.array(result)
218 |
219 | def decode_sentence(self, tokens, length=None):
220 | # numpy -> str
221 | # text = ''.join([self.decoder[token] for token in tokens])
222 | # text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
223 | text = []
224 | if length is not None:
225 | tokens = tokens[:length]
226 | for ix in tokens:
227 | if ix == 0:
228 | break
229 | else:
230 | text.append(self.decoder[ix])
231 | text = ''.join([t for t in text])
232 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
233 | return text
234 |
235 | def shrink(self, inst):
236 | # numpy -> numpy
237 | if len(inst) == 0:
238 | return inst
239 | end = np.argmax(np.array(inst) == self.encoder["<|EOS|>"])
240 | if len(inst) > 1 and inst[0] == self.encoder["<|BOS|>"]:
241 | start = 1
242 | else:
243 | start = 0
244 | return inst[start: end]
245 |
246 | @staticmethod
247 | def split_sentence(sentence):
248 | ''' Break sentence into a list of words and punctuation '''
249 | toks = []
250 | for word in [s.strip().lower() for s in SimpleTokenizer.SENTENCE_SPLIT_REGEX.split(sentence.strip()) if len(s.strip()) > 0]:
251 | # Break up any words containing punctuation only, e.g. '!?', unless it is multiple full stops e.g. '..'
252 | if all(c in string.punctuation for c in word) and not all(c in '.' for c in word):
253 | toks += list(word)
254 | else:
255 | toks.append(word)
256 | return toks
257 |
--------------------------------------------------------------------------------
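An illustrative sketch of the SimpleTokenizer above (assuming the BPE vocab is reachable at the default_bpe() path, i.e. pycocoevalcap/ is the working directory); the instruction text is made up.

    tok = SimpleTokenizer()

    ids = tok.encode_sentence('walk past the sofa and stop at the stairs')   # np.ndarray of length 100
    print(tok.decode_sentence(ids))                                          # round-trips the text
    batch = tok(['turn left at the lamp', 'go upstairs'])                    # LongTensor of shape (2, 100)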
/pycocoevalcap/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | # import sys
5 |
6 | # sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7 | from utils import Tokenizer, read_vocab
8 | # from llama import Tokenizer
9 |
10 | from tokenizer.ptbtokenizer import PTBTokenizer
11 | from bleu.bleu import Bleu
12 | from meteor.meteor import Meteor
13 | from rouge.rouge import Rouge
14 | from cider.cider import Cider
15 | from spice.spice import Spice
16 | # from wmd.wmd import WMD
17 | from clip_tokenizer.tokenization_clip import SimpleTokenizer
18 |
19 |
20 | TRAIN_VOCAB = '/data/user/kxh/instructllm/Matterport3DSimulator/tasks/R2R/data/train_vocab.txt'
21 |
22 |
23 | def parse_args():
24 | parser = argparse.ArgumentParser('Speaker Evaluator', add_help=False)
25 | parser.add_argument('--ckpt_dir', default='../results_lana', type=str)
26 |
27 | args = parser.parse_args()
28 | return args
29 |
30 |
31 | def img_to_eval_imgs(scores, img_ids, method):
32 | img2eval = {}
33 |
34 | for img_id, score in zip(img_ids, scores):
35 | if not img_id in img2eval:
36 | img2eval[img_id] = {}
37 | img2eval[img_id]["image_id"] = img_id
38 | img2eval[img_id][method] = score
39 |
40 | return img2eval
41 |
42 |
43 | def eval_speaker(input_path):
44 | json_path = os.path.join(input_path, 'id2path_reverie_val_unseen.json')
45 | with open(json_path, 'r') as f:
46 | id2path = json.load(f)
47 |
48 | # tokenizer = Tokenizer('/root/mount/LLaMA-7B/tokenizer.model')
49 | # vocab = read_vocab(TRAIN_VOCAB)
50 | # tokenizer = Tokenizer(vocab=vocab, encoding_length=1000)
51 | # tokenizer = SimpleTokenizer()
52 |
53 | refs = {}
54 | candidates = {}
55 | for id, pair in id2path.items():
56 | gt_sentence_list = pair['gt']
57 | gt_list = []
58 | for sentence in gt_sentence_list:
59 | # gt_list.append(tokenizer.encode(sentence, bos=False, eos=False))
60 | # gt_list.append(' '.join(tokenizer.split_sentence(sentence)))
61 | gt_list.append(sentence)
62 | refs[id] = gt_list
63 |
64 | inference_sentence = pair['inference']
65 | # inference_list = tokenizer.encode(inference_sentence, bos=False, eos=False)
66 | # inference_list = [' '.join(tokenizer.split_sentence(inference_sentence))]
67 | inference_list = [inference_sentence]
68 | candidates[id] = inference_list
69 |
70 | # =================================================
71 | # Tokenization
72 | # =================================================
73 | print('tokenization...')
74 | tokenizer = PTBTokenizer()
75 | refs = tokenizer.tokenize(refs)
76 | candidates = tokenizer.tokenize(candidates)
77 |
78 | # =================================================
79 | # Set up scorers
80 | # =================================================
81 | print('setting up scorers...')
82 | scorers = [
83 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
84 | (Meteor(), "METEOR"),
85 | (Rouge(), "ROUGE_L"),
86 | (Cider(), "CIDEr"),
87 | (Spice(), "SPICE"),
88 | # (WMD(), "WMD"),
89 | ]
90 | eval_dict = {}
91 |
92 | # =================================================
93 | # Compute scores
94 | # =================================================
95 | for scorer, method in scorers:
96 | print(f'computing {scorer.method()} score...')
97 | score, scores = scorer.compute_score(refs, candidates)
98 | if type(method) == list:
99 | for sc, scs, m in zip(score, scores, method):
100 | eval_dict[m] = sc
101 | img2eval = img_to_eval_imgs(scs, list(id2path.keys()), m)
102 | print("%s: %0.3f" % (m, sc))
103 | else:
104 | eval_dict[method] = score
105 | img2eval = img_to_eval_imgs(scores, list(id2path.keys()), method)
106 | print("%s: %0.3f" % (method, score))
107 |
108 | evalImgs = list(img2eval.values())
109 | print('======================= Finished =======================')
110 | print(eval_dict)
111 |
112 |
113 | if __name__ == '__main__':
114 | args = parse_args()
115 | eval_speaker(args.ckpt_dir)
116 |
--------------------------------------------------------------------------------
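For clarity, the structure that eval_speaker() above expects in id2path_reverie_val_unseen.json, inferred from how the 'gt' and 'inference' fields are read (the id and sentences here are invented):

    example_id2path = {
        "1234": {
            "gt": ["go up the stairs and wait at the bathroom door .",
                   "walk upstairs and stop outside the bathroom ."],
            "inference": "walk up the stairs and stop by the bathroom door ."
        }
    }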
/pycocoevalcap/meteor/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/pycocoevalcap/meteor/data/paraphrase-en.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/refkxh/C-Instructor/55756e5fb3771f8dbbac0f63f075142a41906e74/pycocoevalcap/meteor/data/paraphrase-en.gz
--------------------------------------------------------------------------------
/pycocoevalcap/meteor/meteor-1.5.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/refkxh/C-Instructor/55756e5fb3771f8dbbac0f63f075142a41906e74/pycocoevalcap/meteor/meteor-1.5.jar
--------------------------------------------------------------------------------
/pycocoevalcap/meteor/meteor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Python wrapper for METEOR implementation, by Xinlei Chen
4 | # Acknowledge Michael Denkowski for the generous discussion and help
5 |
6 | import os
7 | import sys
8 | import subprocess
9 | import threading
10 |
11 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
12 | METEOR_JAR = 'meteor-1.5.jar'
13 | # print METEOR_JAR
14 |
15 | class Meteor:
16 |
17 | def __init__(self):
18 | self.env = os.environ
19 | self.env['LC_ALL'] = 'en_US.UTF_8'
20 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \
21 | '-', '-', '-stdio', '-l', 'en', '-norm']
22 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \
23 | cwd=os.path.dirname(os.path.abspath(__file__)), \
24 | stdin=subprocess.PIPE, \
25 | stdout=subprocess.PIPE, \
26 | stderr=subprocess.PIPE,
27 | env=self.env, universal_newlines=True, bufsize=1)
28 | # Used to guarantee thread safety
29 | self.lock = threading.Lock()
30 |
31 | def compute_score(self, gts, res):
32 | assert(gts.keys() == res.keys())
33 | imgIds = sorted(list(gts.keys()))
34 | scores = []
35 |
36 | eval_line = 'EVAL'
37 | self.lock.acquire()
38 | for i in imgIds:
39 | assert(len(res[i]) == 1)
40 | stat = self._stat(res[i][0], gts[i])
41 | eval_line += ' ||| {}'.format(stat)
42 |
43 | # Send to METEOR
44 | self.meteor_p.stdin.write(eval_line + '\n')
45 |
46 | # Collect segment scores
47 | for i in range(len(imgIds)):
48 | score = float(self.meteor_p.stdout.readline().strip())
49 | scores.append(score)
50 |
51 | # Final score
52 | final_score = float(self.meteor_p.stdout.readline().strip())
53 | self.lock.release()
54 |
55 | return final_score, scores
56 |
57 | def method(self):
58 | return "METEOR"
59 |
60 | def _stat(self, hypothesis_str, reference_list):
61 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
62 | hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
63 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
64 | self.meteor_p.stdin.write(score_line+'\n')
65 | return self.meteor_p.stdout.readline().strip()
66 |
67 | def __del__(self):
68 | self.lock.acquire()
69 | self.meteor_p.stdin.close()
70 | self.meteor_p.kill()
71 | self.meteor_p.wait()
72 | self.lock.release()
73 |
--------------------------------------------------------------------------------
/pycocoevalcap/rouge/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'vrama91'
2 |
--------------------------------------------------------------------------------
/pycocoevalcap/rouge/rouge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
5 | # Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 |
10 | import numpy as np
11 |
12 | def my_lcs(string, sub):
13 | """
14 | Calculates longest common subsequence for a pair of tokenized strings
15 | :param string : list of str : tokens from a string split using whitespace
16 | :param sub : list of str : shorter string, also split using whitespace
17 | :returns: length (list of int): length of the longest common subsequence between the two strings
18 |
19 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
20 | """
21 | if(len(string)< len(sub)):
22 | sub, string = string, sub
23 |
24 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]
25 |
26 | for j in range(1,len(sub)+1):
27 | for i in range(1,len(string)+1):
28 | if(string[i-1] == sub[j-1]):
29 | lengths[i][j] = lengths[i-1][j-1] + 1
30 | else:
31 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])
32 |
33 | return lengths[len(string)][len(sub)]
34 |
35 | class Rouge():
36 | '''
37 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
38 |
39 | '''
40 | def __init__(self):
41 | # vrama91: updated the value below based on discussion with Hovy
42 | self.beta = 1.2
43 |
44 | def calc_score(self, candidate, refs):
45 | """
46 | Compute ROUGE-L score given one candidate and references for an image
47 | :param candidate: str : candidate sentence to be evaluated
48 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated
49 | :returns score: int (ROUGE-L score for the candidate evaluated against references)
50 | """
51 | assert(len(candidate)==1)
52 | assert(len(refs)>0)
53 | prec = []
54 | rec = []
55 |
56 | # split into tokens
57 | token_c = candidate[0].split(" ")
58 |
59 | for reference in refs:
60 | # split into tokens
61 | token_r = reference.split(" ")
62 | # compute the longest common subsequence
63 | lcs = my_lcs(token_r, token_c)
64 | prec.append(lcs/float(len(token_c)))
65 | rec.append(lcs/float(len(token_r)))
66 |
67 | prec_max = max(prec)
68 | rec_max = max(rec)
69 |
70 | if(prec_max!=0 and rec_max !=0):
71 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
72 | else:
73 | score = 0.0
74 | return score
75 |
76 | def compute_score(self, gts, res):
77 | """
78 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset
79 | Invoked by evaluate_captions.py
80 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
81 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
82 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
83 | """
84 | assert(list(gts.keys()) == list(res.keys()))
85 | imgIds = list(gts.keys())
86 |
87 | score = []
88 | for id in imgIds:
89 | hypo = res[id]
90 | ref = gts[id]
91 |
92 | score.append(self.calc_score(hypo, ref))
93 |
94 | # Sanity check.
95 | assert(type(hypo) is list)
96 | assert(len(hypo) == 1)
97 | assert(type(ref) is list)
98 | assert(len(ref) > 0)
99 |
100 | average_score = np.mean(np.array(score))
101 | return average_score, np.array(score)
102 |
103 | def method(self):
104 | return "Rouge"
105 |
--------------------------------------------------------------------------------
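A quick numeric check of the F-measure used in calc_score above, with beta = 1.2 and made-up LCS precision/recall values for the best-matching reference:

    beta = 1.2
    prec_max, rec_max = 0.5, 0.6     # hypothetical best-reference precision / recall
    f = ((1 + beta**2) * prec_max * rec_max) / (rec_max + beta**2 * prec_max)
    print(round(f, 4))               # 0.5545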
/pycocoevalcap/spice/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/refkxh/C-Instructor/55756e5fb3771f8dbbac0f63f075142a41906e74/pycocoevalcap/spice/__init__.py
--------------------------------------------------------------------------------
/pycocoevalcap/spice/spice-1.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/refkxh/C-Instructor/55756e5fb3771f8dbbac0f63f075142a41906e74/pycocoevalcap/spice/spice-1.0.jar
--------------------------------------------------------------------------------
/pycocoevalcap/spice/spice.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | import sys
4 | import subprocess
5 | import threading
6 | import json
7 | import numpy as np
8 | import ast
9 | import tempfile
10 |
11 | # Assumes spice.jar is in the same directory as spice.py. Change as needed.
12 | SPICE_JAR = 'spice-1.0.jar'
13 | TEMP_DIR = 'tmp'
14 | CACHE_DIR = 'cache'
15 |
16 | class Spice:
17 | """
18 | Main Class to compute the SPICE metric
19 | """
20 |
21 | def float_convert(self, obj):
22 | try:
23 | return float(obj)
24 | except:
25 | return np.nan
26 |
27 | def compute_score(self, gts, res):
28 | assert(sorted(gts.keys()) == sorted(res.keys()))
29 | imgIds = sorted(gts.keys())
30 |
31 | # Prepare temp input file for the SPICE scorer
32 | input_data = []
33 | for id in imgIds:
34 | hypo = res[id]
35 | ref = gts[id]
36 |
37 | # Sanity check.
38 | assert(type(hypo) is list)
39 | assert(len(hypo) == 1)
40 | assert(type(ref) is list)
41 | assert(len(ref) >= 1)
42 |
43 | input_data.append({
44 | "image_id" : id,
45 | "test" : hypo[0],
46 | "refs" : ref
47 | })
48 |
49 | cwd = os.path.dirname(os.path.abspath(__file__))
50 | temp_dir=os.path.join(cwd, TEMP_DIR)
51 | if not os.path.exists(temp_dir):
52 | os.makedirs(temp_dir)
53 | in_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
54 | in_file.write(json.dumps(input_data, indent=2).encode('utf-8'))
55 | in_file.close()
56 |
57 | # Start job
58 | out_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
59 | out_file.close()
60 | cache_dir=os.path.join(cwd, CACHE_DIR)
61 | if not os.path.exists(cache_dir):
62 | os.makedirs(cache_dir)
63 | spice_cmd = ['java', '-jar', '-Xmx64G', SPICE_JAR, in_file.name,
64 | '-cache', cache_dir,
65 | '-out', out_file.name,
66 | '-subset',
67 | '-silent'
68 | ]
69 | subprocess.check_call(spice_cmd,
70 | cwd=os.path.dirname(os.path.abspath(__file__)))
71 |
72 | # Read and process results
73 | with open(out_file.name) as data_file:
74 | results = json.load(data_file)
75 | os.remove(in_file.name)
76 | os.remove(out_file.name)
77 |
78 | imgId_to_scores = {}
79 | spice_scores = []
80 | for item in results:
81 | imgId_to_scores[item['image_id']] = item['scores']
82 | spice_scores.append(self.float_convert(item['scores']['All']['f']))
83 | average_score = np.mean(np.array(spice_scores))
84 | scores = []
85 | for image_id in imgIds:
86 | # Convert none to NaN before saving scores over subcategories
87 | score_set = {}
88 | for category,score_tuple in imgId_to_scores[image_id].items():
89 | score_set[category] = {k: self.float_convert(v) for k, v in score_tuple.items()}
90 | scores.append(score_set)
91 | return average_score, scores
92 |
93 | def method(self):
94 | return "SPICE"
95 |
96 |
97 |
--------------------------------------------------------------------------------
/pycocoevalcap/tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'hfang'
2 |
--------------------------------------------------------------------------------
/pycocoevalcap/tokenizer/ptbtokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : ptbtokenizer.py
4 | #
5 | # Description : Do the PTB Tokenization and remove punctuations.
6 | #
7 | # Creation Date : 29-12-2014
8 | # Last Modified : Thu Mar 19 09:53:35 2015
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | import os
12 | import sys
13 | import subprocess
14 | import tempfile
15 | import itertools
16 |
17 | # path to the stanford corenlp jar
18 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar'
19 |
20 | # punctuations to be removed from the sentences
21 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
22 | ".", "?", "!", ",", ":", "-", "--", "...", ";"]
23 |
24 | class PTBTokenizer:
25 | """Python wrapper of Stanford PTBTokenizer"""
26 |
27 | def tokenize(self, captions_for_image):
28 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \
29 | 'edu.stanford.nlp.process.PTBTokenizer', \
30 | '-preserveLines', '-lowerCase']
31 |
32 | # ======================================================
33 | # prepare data for PTB Tokenizer
34 | # ======================================================
35 | final_tokenized_captions_for_image = {}
36 | image_id = [k for k, v in list(captions_for_image.items()) for _ in range(len(v))]
37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in list(captions_for_image.items()) for c in v])
38 |
39 | # ======================================================
40 | # save sentences to temporary file
41 | # ======================================================
42 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
43 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
44 | tmp_file.write(sentences.encode('utf-8'))
45 | tmp_file.close()
46 |
47 | # ======================================================
48 | # tokenize sentence
49 | # ======================================================
50 | cmd.append(os.path.basename(tmp_file.name))
51 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \
52 | stdout=subprocess.PIPE)
53 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
54 | lines = token_lines.decode("utf-8").split('\n')
55 | # remove temp file
56 | os.remove(tmp_file.name)
57 |
58 | # ======================================================
59 | # create dictionary for tokenized captions
60 | # ======================================================
61 | for k, line in zip(image_id, lines):
62 | if not k in final_tokenized_captions_for_image:
63 | final_tokenized_captions_for_image[k] = []
64 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
65 | if w not in PUNCTUATIONS])
66 | final_tokenized_captions_for_image[k].append(tokenized_caption)
67 |
68 | return final_tokenized_captions_for_image
69 |
--------------------------------------------------------------------------------
/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/refkxh/C-Instructor/55756e5fb3771f8dbbac0f63f075142a41906e74/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar
--------------------------------------------------------------------------------
/pycocoevalcap/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import re
4 | import string
5 | import sys
6 | from collections import Counter, defaultdict
7 |
8 | import numpy as np
9 |
10 |
11 | # padding, unknown word, end of sentence
12 | base_vocab = ['<PAD>', '<UNK>', '<EOS>']
13 | padding_idx = base_vocab.index('<PAD>')
14 |
15 |
16 | class Tokenizer(object):
17 | ''' Class to tokenize and encode a sentence. '''
18 | SENTENCE_SPLIT_REGEX = re.compile(r'(\W+)') # Split on any non-alphanumeric character
19 |
20 | def __init__(self, vocab=None, encoding_length=20):
21 | self.encoding_length = encoding_length
22 | self.vocab = vocab
23 | self.word_to_index = {}
24 | self.index_to_word = {}
25 | if vocab:
26 | for i, word in enumerate(vocab):
27 | self.word_to_index[word] = i
28 | new_w2i = defaultdict(lambda: self.word_to_index['<UNK>'])
29 | new_w2i.update(self.word_to_index)
30 | self.word_to_index = new_w2i
31 | for key, value in self.word_to_index.items():
32 | self.index_to_word[value] = key
33 | old = self.vocab_size()
34 | self.add_word('<BOS>')
35 | assert self.vocab_size() == old+1
36 | print("OLD_VOCAB_SIZE", old)
37 | print("VOCAB_SIZE", self.vocab_size())
38 | print("VOACB", len(vocab))
39 |
40 | def finalize(self):
41 | """
42 | This is used for debug
43 | """
44 | self.word_to_index = dict(self.word_to_index) # To avoid using mis-typing tokens
45 |
46 | def add_word(self, word):
47 | assert word not in self.word_to_index
48 | self.word_to_index[word] = self.vocab_size() # vocab_size() is the index assigned to the new word
49 | self.index_to_word[self.vocab_size()] = word
50 |
51 | @staticmethod
52 | def split_sentence(sentence):
53 | ''' Break sentence into a list of words and punctuation '''
54 | toks = []
55 | for word in [s.strip().lower() for s in Tokenizer.SENTENCE_SPLIT_REGEX.split(sentence.strip()) if len(s.strip()) > 0]:
56 | # Break up any words containing punctuation only, e.g. '!?', unless it is multiple full stops e.g. '..'
57 | if all(c in string.punctuation for c in word) and not all(c in '.' for c in word):
58 | toks += list(word)
59 | else:
60 | toks.append(word)
61 | return toks
62 |
63 | def vocab_size(self):
64 | return len(self.index_to_word)
65 |
66 | def encode_sentence(self, sentence, max_length=None):
67 | if max_length is None:
68 | max_length = self.encoding_length
69 | if len(self.word_to_index) == 0:
70 | sys.exit('Tokenizer has no vocab')
71 |
72 | encoding = [self.word_to_index['<BOS>']]
73 | for word in self.split_sentence(sentence):
74 | encoding.append(self.word_to_index[word]) # Default Dict
75 | encoding.append(self.word_to_index['<EOS>'])
76 |
77 | if len(encoding) <= 2:
78 | return None
79 | #assert len(encoding) > 2
80 |
81 | if len(encoding) < max_length:
82 | encoding += [self.word_to_index['<PAD>']] * (max_length-len(encoding)) # Padding
83 | elif len(encoding) > max_length:
84 | # Cut the length with EOS
85 | encoding[max_length - 1] = self.word_to_index['<EOS>']
86 |
87 | return np.array(encoding[:max_length])
88 |
89 | def decode_sentence(self, encoding, length=None):
90 | sentence = []
91 | if length is not None:
92 | encoding = encoding[:length]
93 | for ix in encoding:
94 | if ix == self.word_to_index['<PAD>']:
95 | break
96 | else:
97 | sentence.append(self.index_to_word[ix])
98 | return " ".join(sentence)
99 |
100 | def shrink(self, inst):
101 | """
102 | :param inst: The encoded id sequence.
103 | :return: The sequence with the potential <BOS> and <EOS> removed.
104 | If there is no <EOS>, return an empty list.
105 | """
106 | if len(inst) == 0:
107 | return inst
108 | # If there is no <EOS>, argmax returns 0 and the slice below is empty
109 | end = np.argmax(np.array(inst) == self.word_to_index['<EOS>'])
110 | if len(inst) > 1 and inst[0] == self.word_to_index['<BOS>']:
111 | start = 1
112 | else:
113 | start = 0
114 | # print(inst, start, end)
115 | return inst[start: end]
116 |
117 |
118 | def load_datasets(splits):
119 | """
120 | :param splits: A list of split names.
121 | If a split is given as "something@5000", a random subset of 5000 items is drawn from that split.
122 | :return: The concatenated list of items from the requested splits.
123 | """
124 | data = []
125 | old_state = random.getstate()
126 | for split in splits:
127 | # Optionally use only part of the split (the "name@N" syntax)
128 | components = split.split("@")
129 | number = -1
130 | if len(components) > 1:
131 | split, number = components[0], int(components[1])
132 |
133 | # Load Json
134 | # if split in ['train', 'val_seen', 'val_unseen', 'test',
135 | # 'val_unseen_half1', 'val_unseen_half2', 'val_seen_half1', 'val_seen_half2']: # Add two halves for sanity check
136 | if "/" not in split:
137 | with open(f'tasks/R2R/data/R2R_{split}.json') as f:
138 | # with open('tasks/R2R/data/R4R_%s_enc.json' % split) as f: # NOTE for r4r
139 | new_data = json.load(f)
140 | else:
141 | with open(split) as f:
142 | new_data = json.load(f)
143 |
144 | # Partition
145 | if number > 0:
146 | random.seed(0)  # Make the subset deterministic; the fixed seed keeps smaller subsets nested in larger ones
147 | random.shuffle(new_data)
148 | new_data = new_data[:number]
149 |
150 | # Join
151 | data += new_data
152 | random.setstate(old_state) # Recover the state of the random generator
153 | return data
154 |
155 |
156 | def build_vocab(splits=['train'], min_count=5, start_vocab=base_vocab):
157 | ''' Build a vocab, starting with base vocab containing a few useful tokens. '''
158 | count = Counter()
159 | t = Tokenizer()
160 | data = load_datasets(splits)
161 | for item in data:
162 | for instr in item['instructions']:
163 | count.update(t.split_sentence(instr))
164 | vocab = list(start_vocab)
165 | for word, num in count.most_common():
166 | if num >= min_count:
167 | vocab.append(word)
168 | else:
169 | break
170 | return vocab
171 |
172 |
173 | def write_vocab(vocab, path):
174 | print(f'Writing vocab of size {len(vocab)} to {path}')
175 | with open(path, 'w') as f:
176 | for word in vocab:
177 | f.write(f"{word}\n")
178 |
179 |
180 | def read_vocab(path):
181 | with open(path) as f:
182 | vocab = [word.strip() for word in f.readlines()]
183 | return vocab
184 |
--------------------------------------------------------------------------------
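Usage sketch for /pycocoevalcap/utils.py above: a minimal round trip through Tokenizer. The toy vocab, the sample sentence, and encoding_length=10 are illustrative assumptions; real runs build the vocab from the R2R splits via build_vocab()/write_vocab(). The import assumes the snippet is run from the repository root.

from pycocoevalcap.utils import Tokenizer, base_vocab

vocab = base_vocab + ['walk', 'to', 'the', 'door', '.']   # <PAD>, <UNK>, <EOS> plus a few words
tok = Tokenizer(vocab=vocab, encoding_length=10)          # <BOS> is added by the constructor
ids = tok.encode_sentence('Walk to the blue door.')       # 'blue' is out-of-vocab and maps to <UNK>
print(tok.decode_sentence(tok.shrink(ids)))               # walk to the <UNK> door .

--------------------------------------------------------------------------------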
/reduce_checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 |
6 | input_dir = "results_lm_vis_final_rxr"
7 | input_file = os.path.join(input_dir, "checkpoint-7B.pth")
8 | output_file = os.path.join(input_dir, "checkpoint-7B-reduced.pth")
9 |
10 | checkpoint = torch.load(input_file, map_location="cpu")
11 | reduced_checkpoint = {}
12 |
13 | train_param_name = [
14 | "gate",
15 | "clip_proj",
16 | "clip_proj_norm",
17 | "clip_ob_proj",
18 | "clip_ob_proj_norm",
19 | "ob_ang_linear",
20 | "ob_ang_layer_norm",
21 | "visual_query",
22 | "visual_blocks",
23 | "visual_proj",
24 | "visual_proj_norm",
25 | "adapter_query",
26 | "ob_query",
27 | "action_query",
28 | "history_embeddings",
29 | "logits_temp",
30 | ]
31 |
32 | for key, value in checkpoint["model"].items():
33 | if key.startswith("llama.layers"):
34 | layer_num = int(key.split(".")[2])
35 | if layer_num >= 30:
36 | reduced_checkpoint[key] = value
37 | elif key.startswith("llama.norm"):
38 | reduced_checkpoint[key] = value
39 | else:
40 | for train_name in train_param_name:
41 | if train_name in key:
42 | reduced_checkpoint[key] = value
43 |
44 | print(f"Saved keys: {reduced_checkpoint.keys()}")
45 | checkpoint["model"] = reduced_checkpoint
46 | torch.save(checkpoint, output_file)
47 |
--------------------------------------------------------------------------------
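A sanity check one might run after /reduce_checkpoint.py above, comparing the full and reduced files. This is an illustrative sketch, not part of the repository; the paths simply reuse input_dir from the script.

import os

import torch

input_dir = "results_lm_vis_final_rxr"
full = torch.load(os.path.join(input_dir, "checkpoint-7B.pth"), map_location="cpu")["model"]
reduced = torch.load(os.path.join(input_dir, "checkpoint-7B-reduced.pth"), map_location="cpu")["model"]

def n_params(state_dict):
    # Total number of scalar parameters stored in a state dict.
    return sum(v.numel() for v in state_dict.values())

print(f"full: {n_params(full):,} params  reduced: {n_params(reduced):,} params")
print(f"dropped keys: {len(set(full) - set(reduced))}")

--------------------------------------------------------------------------------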
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu117
2 | torch==2.0.0+cu117
3 | torchvision==0.15.1+cu117
4 | fairscale
5 | sentencepiece
6 | Pillow
7 | opencv-python
8 | gradio
9 | tqdm
10 | git+https://github.com/csuhan/timm_0_3_2.git
11 | git+https://github.com/openai/CLIP.git
--------------------------------------------------------------------------------
/util/bleu.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Python implementation of BLEU and smooth-BLEU.
17 |
18 | This module provides a Python implementation of BLEU and smooth-BLEU.
19 | Smooth BLEU is computed following the method outlined in the paper:
20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic
21 | evaluation metrics for machine translation. COLING 2004.
22 | """
23 |
24 | import collections
25 | import math
26 |
27 |
28 | def _get_ngrams(segment, max_order):
29 | """Extracts all n-grams upto a given maximum order from an input segment.
30 |
31 | Args:
32 | segment: text segment from which n-grams will be extracted.
33 | max_order: maximum length in tokens of the n-grams returned by this
34 | method.
35 |
36 | Returns:
37 | The Counter containing all n-grams up to max_order in segment
38 | with a count of how many times each n-gram occurred.
39 | """
40 | ngram_counts = collections.Counter()
41 | for order in range(1, max_order + 1):
42 | for i in range(0, len(segment) - order + 1):
43 | ngram = tuple(segment[i:i+order])
44 | ngram_counts[ngram] += 1
45 | return ngram_counts
46 |
47 |
48 | def compute_bleu(reference_corpus, translation_corpus, max_order=4,
49 | smooth=False):
50 | """Computes BLEU score of translated segments against one or more references.
51 |
52 | Args:
53 | reference_corpus: list of lists of references for each translation. Each
54 | reference should be tokenized into a list of tokens.
55 | translation_corpus: list of translations to score. Each translation
56 | should be tokenized into a list of tokens.
57 | max_order: Maximum n-gram order to use when computing BLEU score.
58 | smooth: Whether or not to apply Lin et al. 2004 smoothing.
59 |
60 | Returns:
61 | 6-tuple with the BLEU score, n-gram precisions, brevity penalty, length
62 | ratio, translation length, and reference length.
63 | """
64 | matches_by_order = [0] * max_order
65 | possible_matches_by_order = [0] * max_order
66 | reference_length = 0
67 | translation_length = 0
68 | for (references, translation) in zip(reference_corpus, translation_corpus):
69 | reference_length += min(len(r) for r in references)
70 | translation_length += len(translation)
71 |
72 | merged_ref_ngram_counts = collections.Counter()
73 | for reference in references:
74 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
75 | translation_ngram_counts = _get_ngrams(translation, max_order)
76 | overlap = translation_ngram_counts & merged_ref_ngram_counts
77 | for ngram in overlap:
78 | matches_by_order[len(ngram)-1] += overlap[ngram]
79 | for order in range(1, max_order+1):
80 | possible_matches = len(translation) - order + 1
81 | if possible_matches > 0:
82 | possible_matches_by_order[order-1] += possible_matches
83 |
84 | precisions = [0] * max_order
85 | for i in range(0, max_order):
86 | if smooth:
87 | precisions[i] = ((matches_by_order[i] + 1.) /
88 | (possible_matches_by_order[i] + 1.))
89 | else:
90 | if possible_matches_by_order[i] > 0:
91 | precisions[i] = (float(matches_by_order[i]) /
92 | possible_matches_by_order[i])
93 | else:
94 | precisions[i] = 0.0
95 |
96 | if min(precisions) > 0:
97 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
98 | geo_mean = math.exp(p_log_sum)
99 | else:
100 | geo_mean = 0
101 |
102 | ratio = float(translation_length) / reference_length
103 |
104 | if ratio > 1.0:
105 | bp = 1.
106 | elif ratio == 0.:
107 | bp = 0.
108 | else:
109 | bp = math.exp(1 - 1. / ratio)
110 |
111 | bleu = geo_mean * bp
112 |
113 | return (bleu, precisions, bp, ratio, translation_length, reference_length)
114 |
--------------------------------------------------------------------------------
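An illustrative call to compute_bleu from /util/bleu.py above. The tokenized references and hypotheses are made up for demonstration, and the import assumes the snippet runs from the repository root.

from util.bleu import compute_bleu

references = [[['walk', 'to', 'the', 'door']],             # one list of references per hypothesis
              [['turn', 'left', 'at', 'the', 'sofa']]]
hypotheses = [['walk', 'to', 'the', 'door'],
              ['turn', 'left', 'near', 'the', 'sofa']]

bleu, precisions, bp, ratio, hyp_len, ref_len = compute_bleu(
    references, hypotheses, max_order=4, smooth=True)
print(f"BLEU: {bleu:.4f}  brevity penalty: {bp:.4f}  precisions: {[round(p, 3) for p in precisions]}")

--------------------------------------------------------------------------------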
/util/extract_adapter_from_checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def save(full_model, path, model_type='BIAS'):
4 | if model_type == 'BIAS':
5 | keys = [
6 | f'visual_blocks.{i}.{key}.{suffix}'
7 | for i in range(8)
8 | for key in ['norm1', 'attn.qkv', 'attn.proj', 'norm2', 'mlp.fc1', 'mlp.fc2']
9 | for suffix in ['weight', 'bias']
10 | ] + [
11 | f'llama.layers.{i}.{key}'
12 | for i in range(32)
13 | for key in ['attention.gate', 'attention.wq.bias', 'attention.wo.bias', 'feed_forward.w1.bias', 'feed_forward.w2.bias', 'feed_forward.w3.bias', 'attention_norm.weight', 'ffn_norm.weight']
14 | ] + [
15 | f'{base_key}.{suffix}'
16 | for base_key in ['clip_proj_norm', 'visual_proj_norm', 'visual_proj', 'clip_proj']
17 | for suffix in ['weight', 'bias']
18 | ] + ['llama.norm.weight', 'visual_query.weight', 'adapter_query.weight']
19 |
20 |
21 | elif model_type == 'LORA':
22 | keys = [
23 | f'visual_blocks.{i}.{key}.{suffix}'
24 | for i in range(8)
25 | for key in [f'norm{j}' for j in range(1, 3)] + ['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2']
26 | for suffix in ['weight', 'bias']
27 | ] + [
28 | f'llama.layers.{i}.{key}'
29 | for i in range(32)
30 | for key in ['attention.gate', 'attention.wq.bias', 'attention.wo.bias', 'feed_forward.w1.bias', 'feed_forward.w2.bias', 'feed_forward.w3.bias', 'attention_norm.weight', 'ffn_norm.weight']
31 | + [f'attention.lora_wk_l{j}.weight' for j in range(1, 3)]
32 | + [f'attention.lora_wo_l{j}.weight' for j in range(1, 3)]
33 | + [f'feed_forward.lora_w{k}_l{j}.weight' for k in range(1, 4) for j in range(1, 3)]
34 | + [f'attention.lora_wq_l{j}.weight' for j in range(1, 3)]
35 | + [f'attention.lora_wv_l{j}.weight' for j in range(1, 3)]
36 | + ['attention.new_gate']
37 | ] + [
38 | f'{base_key}.{suffix}'
39 | for base_key in ['clip_proj_norm', 'visual_proj_norm', 'visual_proj', 'clip_proj']
40 | for suffix in ['weight', 'bias']
41 | ] + ['llama.norm.weight', 'visual_query.weight', 'adapter_query.weight']
42 |
43 | ## TODO: Add other model types
44 |
45 | full_model_state_dict = full_model.state_dict()
46 | small_weights = {key: full_model_state_dict[key] for key in keys}
47 | if model_type == 'BIAS':
48 | wrapped_small_weights = {'model': small_weights, 'config': {'w_bias': True, 'w_lora': False, 'lora_rank': 16}}
49 | elif model_type == 'LORA':
50 | wrapped_small_weights = {'model': small_weights, 'config': {'w_bias': True, 'w_lora': True, 'lora_rank': 16}}
51 | # Save the wrapped small weights
52 | torch.save(wrapped_small_weights, path)
--------------------------------------------------------------------------------
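A hypothetical wrapper around save from /util/extract_adapter_from_checkpoint.py above. export_adapter, its arguments, and the output path are assumptions for illustration; the caller must supply a fully constructed, fine-tuned llama_adapter model.

import torch

from util.extract_adapter_from_checkpoint import save

def export_adapter(model, path="adapter-LORA.pth"):
    """Keep only the adapter/LoRA tensors of `model` and report what was written."""
    save(model, path, model_type="LORA")
    ckpt = torch.load(path, map_location="cpu")
    print(ckpt["config"])                  # e.g. {'w_bias': True, 'w_lora': True, 'lora_rank': 16}
    print(len(ckpt["model"]), "tensors kept")

--------------------------------------------------------------------------------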
/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | def adjust_learning_rate(optimizer, epoch, args):
10 | """Decay the learning rate with half-cycle cosine after warmup"""
11 | if epoch < args.warmup_epochs:
12 | lr = args.lr * epoch / args.warmup_epochs
13 | else:
14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
16 | for param_group in optimizer.param_groups:
17 | if "lr_scale" in param_group:
18 | param_group["lr"] = lr * param_group["lr_scale"]
19 | else:
20 | param_group["lr"] = lr
21 | return lr
22 |
--------------------------------------------------------------------------------
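A short trace of the warmup + half-cycle cosine schedule in /util/lr_sched.py above. The hyper-parameters in SimpleNamespace are made-up values, not the repository's defaults.

from types import SimpleNamespace

import torch

from util.lr_sched import adjust_learning_rate

args = SimpleNamespace(lr=1e-4, min_lr=1e-6, warmup_epochs=2, epochs=10)
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)

# Fractional epochs are accepted, so the schedule can be stepped once per iteration.
for epoch in (0.0, 1.0, 2.0, 6.0, 10.0):
    print(f"epoch {epoch:4.1f} -> lr {adjust_learning_rate(opt, epoch, args):.2e}")

--------------------------------------------------------------------------------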
/util/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import builtins
13 | import datetime
14 | import os
15 | import time
16 | from collections import defaultdict, deque
17 | from pathlib import Path
18 | import urllib
19 | from tqdm import tqdm
20 |
21 | import torch
22 | import torch.utils.data
23 | import torch.distributed as dist
24 | from torch import inf
25 |
26 |
27 | class SmoothedValue(object):
28 | """Track a series of values and provide access to smoothed values over a
29 | window or the global series average.
30 | """
31 |
32 | def __init__(self, window_size=20, fmt=None):
33 | if fmt is None:
34 | fmt = "{median:.4f} ({global_avg:.4f})"
35 | self.deque = deque(maxlen=window_size)
36 | self.total = 0.0
37 | self.count = 0
38 | self.fmt = fmt
39 |
40 | def update(self, value, n=1):
41 | self.deque.append(value)
42 | self.count += n
43 | self.total += value * n
44 |
45 | def synchronize_between_processes(self):
46 | """
47 | Warning: does not synchronize the deque!
48 | """
49 | if not is_dist_avail_and_initialized():
50 | return
51 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
52 | dist.barrier()
53 | dist.all_reduce(t)
54 | t = t.tolist()
55 | self.count = int(t[0])
56 | self.total = t[1]
57 |
58 | @property
59 | def median(self):
60 | d = torch.tensor(list(self.deque))
61 | return d.median().item()
62 |
63 | @property
64 | def avg(self):
65 | d = torch.tensor(list(self.deque), dtype=torch.float32)
66 | return d.mean().item()
67 |
68 | @property
69 | def global_avg(self):
70 | return self.total / self.count
71 |
72 | @property
73 | def max(self):
74 | return max(self.deque)
75 |
76 | @property
77 | def value(self):
78 | return self.deque[-1]
79 |
80 | def __str__(self):
81 | return self.fmt.format(
82 | median=self.median,
83 | avg=self.avg,
84 | global_avg=self.global_avg,
85 | max=self.max,
86 | value=self.value)
87 |
88 |
89 | class MetricLogger(object):
90 | def __init__(self, delimiter="\t"):
91 | self.meters = defaultdict(SmoothedValue)
92 | self.delimiter = delimiter
93 |
94 | def update(self, **kwargs):
95 | for k, v in kwargs.items():
96 | if v is None:
97 | continue
98 | if isinstance(v, torch.Tensor):
99 | v = v.item()
100 | assert isinstance(v, (float, int))
101 | self.meters[k].update(v)
102 |
103 | def __getattr__(self, attr):
104 | if attr in self.meters:
105 | return self.meters[attr]
106 | if attr in self.__dict__:
107 | return self.__dict__[attr]
108 | raise AttributeError("'{}' object has no attribute '{}'".format(
109 | type(self).__name__, attr))
110 |
111 | def __str__(self):
112 | loss_str = []
113 | for name, meter in self.meters.items():
114 | loss_str.append(
115 | "{}: {}".format(name, str(meter))
116 | )
117 | return self.delimiter.join(loss_str)
118 |
119 | def synchronize_between_processes(self):
120 | for meter in self.meters.values():
121 | meter.synchronize_between_processes()
122 |
123 | def add_meter(self, name, meter):
124 | self.meters[name] = meter
125 |
126 | def log_every(self, iterable, print_freq, header=None):
127 | i = 0
128 | if not header:
129 | header = ''
130 | start_time = time.time()
131 | end = time.time()
132 | iter_time = SmoothedValue(fmt='{avg:.4f}')
133 | data_time = SmoothedValue(fmt='{avg:.4f}')
134 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
135 | log_msg = [
136 | header,
137 | '[{0' + space_fmt + '}/{1}]',
138 | 'eta: {eta}',
139 | '{meters}',
140 | 'time: {time}',
141 | 'data: {data}'
142 | ]
143 | if torch.cuda.is_available():
144 | log_msg.append('max mem: {memory:.0f}')
145 | log_msg = self.delimiter.join(log_msg)
146 | MB = 1024.0 * 1024.0
147 | for obj in iterable:
148 | data_time.update(time.time() - end)
149 | yield obj
150 | iter_time.update(time.time() - end)
151 | if i % print_freq == 0 or i == len(iterable) - 1:
152 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
153 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
154 | if torch.cuda.is_available():
155 | print(log_msg.format(
156 | i, len(iterable), eta=eta_string,
157 | meters=str(self),
158 | time=str(iter_time), data=str(data_time),
159 | memory=torch.cuda.max_memory_allocated() / MB))
160 | else:
161 | print(log_msg.format(
162 | i, len(iterable), eta=eta_string,
163 | meters=str(self),
164 | time=str(iter_time), data=str(data_time)))
165 | i += 1
166 | end = time.time()
167 | total_time = time.time() - start_time
168 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
169 | print('{} Total time: {} ({:.4f} s / it)'.format(
170 | header, total_time_str, total_time / len(iterable)))
171 |
172 |
173 | def setup_for_distributed(is_master):
174 | """
175 | This function disables printing when not in master process
176 | """
177 | builtin_print = builtins.print
178 |
179 | def print(*args, **kwargs):
180 | force = kwargs.pop('force', False)
181 | force = force or (get_world_size() > 8)
182 | if is_master or force:
183 | now = datetime.datetime.now().time()
184 | builtin_print('[{}] '.format(now), end='') # print with time stamp
185 | builtin_print(*args, **kwargs)
186 |
187 | builtins.print = print
188 |
189 |
190 | def is_dist_avail_and_initialized():
191 | if not dist.is_available():
192 | return False
193 | if not dist.is_initialized():
194 | return False
195 | return True
196 |
197 |
198 | def get_world_size():
199 | if not is_dist_avail_and_initialized():
200 | return 1
201 | return dist.get_world_size()
202 |
203 |
204 | def get_rank():
205 | if not is_dist_avail_and_initialized():
206 | return 0
207 | return dist.get_rank()
208 |
209 |
210 | def is_main_process():
211 | return get_rank() == 0
212 |
213 |
214 | def save_on_master(*args, **kwargs):
215 | if is_main_process():
216 | torch.save(*args, **kwargs)
217 |
218 |
219 | def init_distributed_mode(args):
220 | if args.dist_on_itp:
221 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
222 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
223 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
224 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
225 | os.environ['LOCAL_RANK'] = str(args.gpu)
226 | os.environ['RANK'] = str(args.rank)
227 | os.environ['WORLD_SIZE'] = str(args.world_size)
228 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
229 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
230 | args.rank = int(os.environ["RANK"])
231 | args.world_size = int(os.environ['WORLD_SIZE'])
232 | args.gpu = int(os.environ['LOCAL_RANK'])
233 | elif 'SLURM_PROCID' in os.environ:
234 | args.rank = int(os.environ['SLURM_PROCID'])
235 | args.gpu = args.rank % torch.cuda.device_count()
236 | else:
237 | print('Not using distributed mode')
238 | setup_for_distributed(is_master=True) # hack
239 | args.distributed = False
240 | return
241 |
242 | args.distributed = True
243 |
244 | print("GPU::", args.gpu)
245 | torch.cuda.set_device(args.gpu)
246 | args.dist_backend = 'nccl'
247 | print('| distributed init (rank {}): {}, gpu {}'.format(
248 | args.rank, args.dist_url, args.gpu), flush=True)
249 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
250 | world_size=args.world_size, rank=args.rank)
251 | torch.distributed.barrier()
252 | setup_for_distributed(args.rank == 0)
253 |
254 |
255 | class NativeScalerWithGradNormCount:
256 | state_dict_key = "amp_scaler"
257 |
258 | def __init__(self):
259 | self._scaler = torch.cuda.amp.GradScaler()
260 |
261 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
262 | self._scaler.scale(loss).backward(create_graph=create_graph)
263 | if update_grad:
264 | if clip_grad is not None:
265 | assert parameters is not None
266 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
267 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
268 | else:
269 | self._scaler.unscale_(optimizer)
270 | norm = get_grad_norm_(parameters)
271 | self._scaler.step(optimizer)
272 | self._scaler.update()
273 | else:
274 | norm = None
275 | return norm
276 |
277 | def state_dict(self):
278 | return self._scaler.state_dict()
279 |
280 | def load_state_dict(self, state_dict):
281 | self._scaler.load_state_dict(state_dict)
282 |
283 |
284 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
285 | if isinstance(parameters, torch.Tensor):
286 | parameters = [parameters]
287 | parameters = [p for p in parameters if p.grad is not None]
288 | norm_type = float(norm_type)
289 | if len(parameters) == 0:
290 | return torch.tensor(0.)
291 | device = parameters[0].grad.device
292 | if norm_type == inf:
293 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
294 | else:
295 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
296 | return total_norm
297 |
298 |
299 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
300 | output_dir = Path(args.output_dir)
301 | epoch_name = str(epoch)
302 | if loss_scaler is not None:
303 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
304 | for checkpoint_path in checkpoint_paths:
305 | to_save = {
306 | 'model': model_without_ddp.state_dict(),
307 | 'optimizer': optimizer.state_dict(),
308 | 'epoch': epoch,
309 | 'scaler': loss_scaler.state_dict(),
310 | 'args': args,
311 | }
312 |
313 | save_on_master(to_save, checkpoint_path)
314 | else:
315 | client_state = {'epoch': epoch}
316 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
317 |
318 |
319 | def load_model(model_without_ddp, path):
320 | if path.startswith('https'):
321 | checkpoint = torch.hub.load_state_dict_from_url(
322 | path, map_location='cpu', check_hash=True)
323 | else:
324 | checkpoint = torch.load(path, map_location='cpu')
325 | new_checkpoint = {}
326 | for key, value in checkpoint['model'].items():
327 | key = key.replace("llma", "llama")
328 | new_checkpoint[key] = value
329 | print(model_without_ddp.load_state_dict(new_checkpoint, strict=False))
330 | print("Load checkpoint %s" % path)
331 |
332 |
333 | def all_reduce_mean(x):
334 | world_size = get_world_size()
335 | if world_size > 1:
336 | x_reduce = torch.tensor(x).cuda()
337 | dist.all_reduce(x_reduce)
338 | x_reduce /= world_size
339 | return x_reduce.item()
340 | else:
341 | return x
342 |
343 |
344 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
345 | decay = []
346 | no_decay = []
347 | for name, param in model.named_parameters():
348 | if not param.requires_grad:
349 | continue # frozen weights
350 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
351 | no_decay.append(param)
352 | else:
353 | decay.append(param)
354 | return [
355 | {'params': no_decay, 'weight_decay': 0.},
356 | {'params': decay, 'weight_decay': weight_decay}]
357 |
358 |
359 | class DistributedSubEpochSampler(torch.utils.data.Sampler):
360 |
361 | def __init__(self, dataset, num_replicas, rank, shuffle, split_epoch=1, seed=0):
362 | self.dataset = dataset
363 | self.num_replicas = num_replicas
364 | self.rank = rank
365 | self.shuffle = shuffle
366 | self.split_epoch = split_epoch
367 | self.seed = seed
368 |
369 | self.num_samples = len(dataset) // (num_replicas * split_epoch)
370 |
371 | def __len__(self):
372 | return self.num_samples
373 |
374 | def __iter__(self):
375 | if self.shuffle:
376 | # deterministically shuffle based on epoch and seed
377 | g = torch.Generator()
378 | g.manual_seed(self.seed + self.epoch // self.split_epoch)
379 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
380 | else:
381 | indices = list(range(len(self.dataset))) # type: ignore[arg-type]
382 |
383 | indices = indices[self.rank * self.split_epoch + self.epoch % self.split_epoch::self.num_replicas * self.split_epoch]
384 | assert len(indices) >= self.num_samples
385 | indices = indices[:self.num_samples]
386 |
387 | return iter(indices)
388 |
389 | def set_epoch(self, epoch):
390 | self.epoch = epoch  # must be called before iterating; __iter__ reads self.epoch
391 |
392 | def download(url: str, root: str):
393 | os.makedirs(root, exist_ok=True)
394 | filename = os.path.basename(url)
395 | download_target = os.path.join(root, filename)
396 |
397 | if os.path.exists(download_target) and not os.path.isfile(download_target):
398 | raise RuntimeError(f"{download_target} exists and is not a regular file")
399 |
400 | if os.path.isfile(download_target):
401 | return download_target
402 |
403 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
404 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
405 | while True:
406 | buffer = source.read(8192)
407 | if not buffer:
408 | break
409 | output.write(buffer)
410 | loop.update(len(buffer))
411 |
412 |
413 | return download_target
--------------------------------------------------------------------------------
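A single-process sketch of driving the logging helpers in /util/misc.py above from a training loop. The fake loss and the sleep are placeholders standing in for real forward/backward work.

import time

from util.misc import MetricLogger, SmoothedValue

logger = MetricLogger(delimiter="  ")
logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

for step in logger.log_every(range(100), print_freq=20, header="Epoch: [0]"):
    time.sleep(0.01)                          # stand-in for the actual training step
    logger.update(loss=1.0 / (step + 1), lr=1e-4)

logger.synchronize_between_processes()        # no-op unless torch.distributed is initialized
print("Averaged stats:", logger)

--------------------------------------------------------------------------------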