├── .gitignore
├── LICENSE_llama2
├── OneLLM_Arxiv.pdf
├── README.md
├── config
│   └── llama2
│       ├── 7B.json
│       └── tokenizer.model
├── data
│   ├── conversation_lib.py
│   ├── data_utils.py
│   ├── finetune_dataset.py
│   ├── imu_utils.py
│   ├── pretrain_dataset.py
│   └── video_utils.py
├── demos
│   ├── cli.py
│   └── multi_turn_mm.py
├── docs
│   ├── Data.md
│   └── Evaluation.md
├── engine_finetune.py
├── engine_pretrain.py
├── eval
│   ├── audio_cap_clothov2.py
│   ├── caption_eval.py
│   ├── fmri_cap_nsd.py
│   ├── image_bench_mmvet.py
│   ├── image_cap_cococap.py
│   ├── imu_cap_ego4d.py
│   ├── point_cap_pointllm.py
│   └── video_qa_msvd.py
├── exps
│   ├── image_text_pretrain_8gpu.sh
│   ├── image_text_pretrain_slurm.sh
│   ├── multimodal_text_finetune.sh
│   ├── multimodal_text_pretrain_stage2.sh
│   └── multimodal_text_pretrain_stage3.sh
├── main_finetune.py
├── main_pretrain.py
├── model
│   ├── LLM
│   │   ├── __init__.py
│   │   └── onellm.py
│   ├── __init__.py
│   ├── components.py
│   ├── lib
│   │   ├── point_utils.py
│   │   └── pointnet2
│   │       ├── pointnet2_modules.py
│   │       ├── pointnet2_utils.py
│   │       ├── pytorch_utils.py
│   │       ├── setup.py
│   │       └── src
│   │           ├── ball_query.cpp
│   │           ├── ball_query_gpu.cu
│   │           ├── ball_query_gpu.h
│   │           ├── cuda_utils.h
│   │           ├── group_points.cpp
│   │           ├── group_points_gpu.cu
│   │           ├── group_points_gpu.h
│   │           ├── interpolate.cpp
│   │           ├── interpolate_gpu.cu
│   │           ├── interpolate_gpu.h
│   │           ├── pointnet2_api.cpp
│   │           ├── sampling.cpp
│   │           ├── sampling_gpu.cu
│   │           └── sampling_gpu.h
│   ├── meta.py
│   └── tokenizer.py
├── requirements.txt
└── util
    ├── lr_sched.py
    ├── misc.py
    └── pos_embed.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.pyc
3 | *.egg-info
4 | dist
5 |
6 | output
7 | output_dir
8 | *.pth
9 | *.log
10 | weights
11 | datasets
12 | data_hub
13 | slurm*
14 | multimodal_llama2_7B
15 | weights
--------------------------------------------------------------------------------
/LICENSE_llama2:
--------------------------------------------------------------------------------
1 | LLAMA 2 COMMUNITY LICENSE AGREEMENT
2 | Llama 2 Version Release Date: July 18, 2023
3 |
4 | "Agreement" means the terms and conditions for use, reproduction, distribution and
5 | modification of the Llama Materials set forth herein.
6 |
7 | "Documentation" means the specifications, manuals and documentation
8 | accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and-
9 | libraries/llama-downloads/.
10 |
11 | "Licensee" or "you" means you, or your employer or any other person or entity (if
12 | you are entering into this Agreement on such person or entity's behalf), of the age
13 | required under applicable laws, rules or regulations to provide legal consent and that
14 | has legal authority to bind your employer or such other person or entity if you are
15 | entering in this Agreement on their behalf.
16 |
17 | "Llama 2" means the foundational large language models and software and
18 | algorithms, including machine-learning model code, trained model weights,
19 | inference-enabling code, training-enabling code, fine-tuning enabling code and other
20 | elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and-
21 | libraries/llama-downloads/.
22 |
23 | "Llama Materials" means, collectively, Meta's proprietary Llama 2 and
24 | Documentation (and any portion thereof) made available under this Agreement.
25 |
26 | "Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you
27 | are an entity, your principal place of business is in the EEA or Switzerland) and Meta
28 | Platforms, Inc. (if you are located outside of the EEA or Switzerland).
29 |
30 | By clicking "I Accept" below or by using or distributing any portion or element of the
31 | Llama Materials, you agree to be bound by this Agreement.
32 |
33 | 1. License Rights and Redistribution.
34 |
35 | a. Grant of Rights. You are granted a non-exclusive, worldwide, non-
36 | transferable and royalty-free limited license under Meta's intellectual property or
37 | other rights owned by Meta embodied in the Llama Materials to use, reproduce,
38 | distribute, copy, create derivative works of, and make modifications to the Llama
39 | Materials.
40 |
41 | b. Redistribution and Use.
42 |
43 | i. If you distribute or make the Llama Materials, or any derivative works
44 | thereof, available to a third party, you shall provide a copy of this Agreement to such
45 | third party.
46 | ii. If you receive Llama Materials, or any derivative works thereof, from
47 | a Licensee as part of an integrated end user product, then Section 2 of this
48 | Agreement will not apply to you.
49 |
50 | iii. You must retain in all copies of the Llama Materials that you
51 | distribute the following attribution notice within a "Notice" text file distributed as a
52 | part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License,
53 | Copyright (c) Meta Platforms, Inc. All Rights Reserved."
54 |
55 | iv. Your use of the Llama Materials must comply with applicable laws
56 | and regulations (including trade compliance laws and regulations) and adhere to the
57 | Acceptable Use Policy for the Llama Materials (available at
58 | https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into
59 | this Agreement.
60 |
61 | v. You will not use the Llama Materials or any output or results of the
62 | Llama Materials to improve any other large language model (excluding Llama 2 or
63 | derivative works thereof).
64 |
65 | 2. Additional Commercial Terms. If, on the Llama 2 version release date, the
66 | monthly active users of the products or services made available by or for Licensee,
67 | or Licensee's affiliates, is greater than 700 million monthly active users in the
68 | preceding calendar month, you must request a license from Meta, which Meta may
69 | grant to you in its sole discretion, and you are not authorized to exercise any of the
70 | rights under this Agreement unless or until Meta otherwise expressly grants you
71 | such rights.
72 |
73 | 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE
74 | LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE
75 | PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
76 | EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY
77 | WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR
78 | FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE
79 | FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING
80 | THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR
81 | USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS.
82 |
83 | 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE
84 | LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT,
85 | NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS
86 | AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL,
87 | CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN
88 | IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF
89 | ANY OF THE FOREGOING.
90 |
91 | 5. Intellectual Property.
92 |
93 | a. No trademark licenses are granted under this Agreement, and in
94 | connection with the Llama Materials, neither Meta nor Licensee may use any name
95 | or mark owned by or associated with the other or any of its affiliates, except as
96 | required for reasonable and customary use in describing and redistributing the
97 | Llama Materials.
98 |
99 | b. Subject to Meta's ownership of Llama Materials and derivatives made by or
100 | for Meta, with respect to any derivative works and modifications of the Llama
101 | Materials that are made by you, as between you and Meta, you are and will be the
102 | owner of such derivative works and modifications.
103 |
104 | c. If you institute litigation or other proceedings against Meta or any entity
105 | (including a cross-claim or counterclaim in a lawsuit) alleging that the Llama
106 | Materials or Llama 2 outputs or results, or any portion of any of the foregoing,
107 | constitutes infringement of intellectual property or other rights owned or licensable
108 | by you, then any licenses granted to you under this Agreement shall terminate as of
109 | the date such litigation or claim is filed or instituted. You will indemnify and hold
110 | harmless Meta from and against any claim by any third party arising out of or related
111 | to your use or distribution of the Llama Materials.
112 |
113 | 6. Term and Termination. The term of this Agreement will commence upon your
114 | acceptance of this Agreement or access to the Llama Materials and will continue in
115 | full force and effect until terminated in accordance with the terms and conditions
116 | herein. Meta may terminate this Agreement if you are in breach of any term or
117 | condition of this Agreement. Upon termination of this Agreement, you shall delete
118 | and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the
119 | termination of this Agreement.
120 |
121 | 7. Governing Law and Jurisdiction. This Agreement will be governed and
122 | construed under the laws of the State of California without regard to choice of law
123 | principles, and the UN Convention on Contracts for the International Sale of Goods
124 | does not apply to this Agreement. The courts of California shall have exclusive
125 | jurisdiction of any dispute arising out of this Agreement.
126 |
127 |
--------------------------------------------------------------------------------
/OneLLM_Arxiv.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csuhan/OneLLM/8587a4768cf376fb41f7d586e21de5d1ab1ca365/OneLLM_Arxiv.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## OneLLM: One Framework to Align All Modalities with Language
2 |
3 | [[Project Page](https://onellm.csuhan.com)] [[Paper](https://arxiv.org/abs/2312.03700)] [[HF Demo🤗](https://huggingface.co/spaces/csuhan/OneLLM)] [[Modelscope Demo🤖](https://modelscope.cn/studios/csuhan/OneLLM)] [[Model🤗](https://huggingface.co/csuhan/OneLLM-7B)] [[Data](docs/Data.md)]
4 |
5 | ## News
6 |
7 | - **2024.02.27** OneLLM is accepted by **CVPR 2024**!🎉
8 | - **2023.12.01** Release model weights and inference code.
9 |
10 | ## Contents
11 |
12 | - [Install](#install)
13 | - [Models](#models)
14 | - [Demo](#demo)
15 | - [Data](#data)
16 | - [Evaluation](#evaluation)
17 | - [Training](#training)
18 |
19 | ### Install
20 |
21 | 1. Clone the repo into a local folder.
22 |
23 | ```bash
24 | git clone https://github.com/csuhan/OneLLM
25 |
26 | cd OneLLM
27 | ```
28 |
29 | 2. Install packages.
30 |
31 | ```bash
32 | conda create -n onellm python=3.9 -y
33 | conda activate onellm
34 |
35 | pip install -r requirements.txt
36 |
37 | # install pointnet
38 | cd model/lib/pointnet2
39 | python setup.py install
40 | ```
41 |
42 | 3. Install Apex. (Optional)
43 |
44 | ```bash
45 | git clone https://github.com/NVIDIA/apex
46 | cd apex
47 | pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
48 | ```
49 |
50 | ### Models
51 |
52 | We provide a preview model on Hugging Face: [csuhan/OneLLM-7B](https://huggingface.co/csuhan/OneLLM-7B).
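
To get a local copy of the weights for the demos below, one option is the Hugging Face CLI. This is a minimal sketch, assuming you have `huggingface_hub` installed and that the repository ships the `consolidated.00-of-01.pth` checkpoint referenced later in this README:

```bash
# Download the preview weights into ${WEIGHTS_DIR} (any local directory you choose).
huggingface-cli download csuhan/OneLLM-7B --local-dir ${WEIGHTS_DIR}
```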
53 |
54 | ### Demo
55 |
56 | **Hugging Face Demo:** [csuhan/OneLLM](https://huggingface.co/spaces/csuhan/OneLLM).
57 | 
58 | **Local Demo:** Assuming you have downloaded the weights to ${WEIGHTS_DIR}, run the following command to start a Gradio demo locally.
59 |
60 | ```bash
61 | python demos/multi_turn_mm.py --gpu_ids 0 --tokenizer_path config/llama2/tokenizer.model --llama_config config/llama2/7B.json --pretrained_path ${WEIGHTS_DIR}/consolidated.00-of-01.pth
62 | ```
63 |
64 | **CLI Demo:**
65 | ```bash
66 | python demos/cli.py --image_path ${IMAGE_PATH} --gpu_ids 0 --tokenizer_path config/llama2/tokenizer.model --llama_config config/llama2/7B.json --pretrained_path ${WEIGHTS_DIR}/consolidated.00-of-01.pth
67 | ```
68 |
69 | ### Data
70 |
71 | Please check [Data.md](docs/Data.md) for more detail.
72 |
73 | ### Evaluation
74 |
75 | Please check [Evaluation.md](docs/Evaluation.md) for more detail.
76 |
77 | ### Training
78 |
79 | #### Image-Text Pretraining
80 |
81 | **Single Node 8-GPU Training**: [exps/image_text_pretrain_8gpu.sh](exps/image_text_pretrain_8gpu.sh)
82 | 
83 |
84 | ```bash
85 | torchrun --nproc_per_node=8 main_pretrain.py \
86 | --epochs 1 --dataset image \
87 | --batch_size 40 --accum_iter 16 \
88 | --model_parallel_size 1 \
89 | --data_parallel sdp \
90 | --save_consolidated \
91 | --llama_type onellm \
92 | --llama_ckpt_dir ${LLAMA_7B_PATH} \
93 | --llama_config config/llama2/7B.json \
94 | --tokenizer_path config/llama2/tokenizer.model \
95 | --auto_resume \
96 | --weight_decay 0.1 --output_dir ${OUTPUT_DIR} \
97 | --warmup_iters 2000 --lr_decay_iters 200000 --lr 5e-5 --min_lr 5e-6 --clip_grad 2 \
98 | --save_freq 1000 \
99 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log
100 | ```
101 |
102 |
103 | **Multi Nodes DDP Training**:
104 |
105 | Launch the script on all N nodes at the same time to start a multi-node DDP training job. The following is an example script for one node:
106 |
107 | ```bash
108 | MASTER_ADDR=IP_ADDRESS_OF_NODE_1
109 | NNODES=N
110 | MASTER_PORT=29500
111 | NPROC_PER_NODE=8
112 | 
113 | RANK=0 or 1 or ... or N-1
114 | 
115 | 
116 | torchrun \
117 | --nnodes=$NNODES \
118 | --nproc_per_node=$NPROC_PER_NODE \
119 | --node_rank=$RANK \
120 | --master_port=$MASTER_PORT \
121 | --master_addr=$MASTER_ADDR \
122 | main_pretrain.py \
123 | --epochs 1 --dataset image \
124 | --batch_size 40 --accum_iter 16 \
125 | --model_parallel_size 1 \
126 | --data_parallel sdp \
127 | --save_consolidated \
128 | --llama_type onellm \
129 | --llama_ckpt_dir ${LLAMA_7B_PATH} \
130 | --llama_config config/llama2/7B.json \
131 | --tokenizer_path config/llama2/tokenizer.model \
132 | --auto_resume \
133 | --weight_decay 0.1 --output_dir ${OUTPUT_DIR} \
134 | --warmup_iters 2000 --lr_decay_iters 200000 --lr 5e-5 --min_lr 5e-6 --clip_grad 2 \
135 | --save_freq 1000 \
136 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log
137 | ```
138 |
139 | **Multi Node SLURM Training**: [exps/image_text_pretrain_slurm.sh](exps/image_text_pretrain_slurm.sh)
140 | 
141 |
142 | ```bash
143 | #!/bin/bash
144 | #SBATCH --gres=gpu:8
145 | #SBATCH -n 16
146 | #SBATCH -N 2
147 | #SBATCH --cpus-per-task=16
148 |
149 | srun python -u main_pretrain.py \
150 | --epochs 1 --dataset image \
151 | --batch_size 40 --accum_iter 8 \
152 | --model_parallel_size 1 \
153 | --data_parallel sdp \
154 | --save_consolidated \
155 | --llama_type onellm \
156 | --llama_ckpt_dir ${LLAMA_7B_PATH} \
157 | --llama_config config/llama2/7B.json \
158 | --tokenizer_path config/llama2/tokenizer.model \
159 | --auto_resume \
160 | --weight_decay 0.1 --output_dir ${OUTPUT_DIR} \
161 | --warmup_iters 2000 --lr_decay_iters 200000 --lr 5e-5 --min_lr 5e-6 --clip_grad 2 \
162 | --save_freq 1000 \
163 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log
164 | ```
165 |
166 |
167 | #### Multimodal-Text Pretraining
168 |
169 | **Stage II Pretraining**: Assuming you have the pretrained `${IMAGE_TEXT_MODEL}`, run [exps/multimodal_text_pretrain_stage2.sh](exps/multimodal_text_pretrain_stage2.sh) for video-audio-point-text pretraining.
170 |
171 | **Stage III Pretraining**: Assuming you have the pretrained `${STAGE2_MODEL}`, run [exps/multimodal_text_pretrain_stage3.sh](exps/multimodal_text_pretrain_stage3.sh) for depth-normal-imu-fmri-text pretraining.
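
Both stage scripts follow the same launch pattern as the image-text pretraining commands above. The following is a sketch of how they might be invoked, assuming the scripts pick up the checkpoint and output paths from environment variables named as in this README; check each script for the variables it actually reads:

```bash
# Hypothetical invocation; adjust the paths and variable names to match the script.
export IMAGE_TEXT_MODEL=output/image_text_pretrain/consolidated.00-of-01.pth
export OUTPUT_DIR=output/stage2
bash exps/multimodal_text_pretrain_stage2.sh
```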
172 |
173 | #### Instruction Tuning
174 |
175 | Assuming you have the pretrained `${STAGE3_MODEL}`, run [exps/multimodal_text_finetune.sh](exps/multimodal_text_finetune.sh) for multimodal instruction tuning.
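
As a sketch, under the same assumption that the script reads its paths from the environment:

```bash
# Hypothetical invocation for instruction tuning; paths are placeholders.
export STAGE3_MODEL=output/stage3/consolidated.00-of-01.pth
export OUTPUT_DIR=output/finetune
bash exps/multimodal_text_finetune.sh
```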
176 |
177 | ## Citation
178 |
179 | ```
180 | @InProceedings{han2023onellm,
181 | title={OneLLM: One Framework to Align All Modalities with Language},
182 | author={Han, Jiaming and Gong, Kaixiong and Zhang, Yiyuan and Wang, Jiaqi and Zhang, Kaipeng and Lin, Dahua and Qiao, Yu and Gao, Peng and Yue, Xiangyu},
183 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
184 | year={2024}
185 | }
186 | ```
187 |
188 | ## Acknowledgement
189 |
190 | [LLaMA](https://github.com/facebookresearch/llama), [LLaMA-Adapter](https://github.com/OpenGVLab/LLaMA-Adapter), [LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory), [Meta-Transformer](https://github.com/invictus717/MetaTransformer), [ChatBridge](https://github.com/joez17/ChatBridge)
191 |
192 | ## License
193 | This project is developed based on Llama 2; please refer to the [LLAMA 2 Community License](LICENSE_llama2).
194 |
--------------------------------------------------------------------------------
/config/llama2/7B.json:
--------------------------------------------------------------------------------
1 | {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": -1}
2 |
--------------------------------------------------------------------------------
/config/llama2/tokenizer.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csuhan/OneLLM/8587a4768cf376fb41f7d586e21de5d1ab1ca365/config/llama2/tokenizer.model
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torch
3 | import torchaudio
4 | import torchvision.transforms.functional as F
5 | import torchvision.transforms as transforms
6 |
7 | try:
8 | from torchvision.transforms import InterpolationMode
9 |
10 | BICUBIC = InterpolationMode.BICUBIC
11 | except ImportError:
12 | BICUBIC = Image.BICUBIC
13 |
14 | T_random_resized_crop = transforms.Compose([
15 | transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=BICUBIC,
16 | antialias=None), # 3 is bicubic
17 | transforms.ToTensor(),
18 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
19 |
20 |
21 | # image transform
22 | transform_img_train = transforms.Compose([
23 | transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(
24 | 0.75, 1.3333), interpolation=3, antialias=None), # 3 is bicubic
25 | transforms.ToTensor(),
26 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
27 |
28 |
29 | class PairRandomResizedCrop(transforms.RandomResizedCrop):
30 | def forward(self, imgs):
31 | i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)
32 | return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias) for img in imgs]
33 |
34 |
35 | class PairToTensor(transforms.ToTensor):
36 | def __call__(self, pics):
37 | return [F.to_tensor(pic) for pic in pics]
38 |
39 |
40 | class PairNormalize(transforms.Normalize):
41 | def forward(self, tensors):
42 | return [F.normalize(tensor, self.mean, self.std, self.inplace) for tensor in tensors]
43 |
44 |
45 | transform_pairimg_train = transforms.Compose([
46 | PairRandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(
47 | 0.75, 1.3333), interpolation=3, antialias=None), # 3 is bicubic
48 | PairToTensor(),
49 | PairNormalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
50 |
51 |
52 | def pc_norm(pc):
53 | """ pc: NxC, return NxC """
54 | xyz = pc[:, :3]
55 | other_feature = pc[:, 3:]
56 |
57 | centroid = torch.mean(xyz, dim=0)
58 | xyz = xyz - centroid
59 | m = torch.max(torch.sqrt(torch.sum(xyz ** 2, dim=1)))
60 | xyz = xyz / m
61 |
62 | pc = torch.cat((xyz, other_feature), dim=1)
63 | return pc
64 |
65 |
66 | def make_audio_features(wav_name, mel_bins=128, target_length=1024, aug=False):
67 | waveform, sr = torchaudio.load(wav_name)
68 | # assert sr == 16000, 'input audio sampling rate must be 16kHz'
69 | if sr != 16000:
70 | trans = torchaudio.transforms.Resample(sr, 16000)
71 | waveform = trans(waveform)
72 |
73 | waveform = waveform - waveform.mean()
74 |
75 | fbank = torchaudio.compliance.kaldi.fbank(
76 | waveform, htk_compat=True, sample_frequency=16000, use_energy=False,
77 | window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10)
78 |
79 | n_frames = fbank.shape[0]
80 |
81 | p = target_length - n_frames
82 | if p > 0:
83 | m = torch.nn.ZeroPad2d((0, 0, 0, p))
84 | fbank = m(fbank)
85 | elif p < 0:
86 | fbank = fbank[0:target_length, :]
87 |
88 | if aug:
89 | freqm = torchaudio.transforms.FrequencyMasking(48)
90 | timem = torchaudio.transforms.TimeMasking(192)
91 | fbank = torch.transpose(fbank, 0, 1)
92 | fbank = fbank.unsqueeze(0)
93 | fbank = freqm(fbank)
94 | fbank = timem(fbank)
95 | fbank = fbank.squeeze(0)
96 | fbank = torch.transpose(fbank, 0, 1)
97 |
98 | fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
99 | return fbank
--------------------------------------------------------------------------------
/data/imu_utils.py:
--------------------------------------------------------------------------------
1 | import string
2 | import numpy as np
3 | import matplotlib.animation as animation
4 | from matplotlib import pyplot as plt
5 | import json
6 | from collections import defaultdict
7 | from bisect import bisect_left
8 | import os
9 | import torch
10 | import torchaudio
11 | torchaudio.set_audio_backend("sox_io")
12 |
13 |
14 | def load_json(json_path: str):
15 | """
16 | Load a json file
17 | """
18 | with open(json_path, "r", encoding="utf-8") as f_name:
19 | data = json.load(f_name)
20 | return data
21 |
22 |
23 | def check_window_signal(info_t, w_s, w_e):
24 | length = w_e - w_s
25 | frame_offset = int(w_s * info_t.sample_rate)
26 | num_frames = int(length * info_t.sample_rate)
27 | if frame_offset + num_frames > int(info_t.num_frames):
28 | return False
29 | else:
30 | return True
31 |
32 |
33 | def index_narrations(ann_path):
34 | narration_raw = load_json(ann_path)
35 |
36 | narration_dict = defaultdict(list)
37 | summary_dict = defaultdict(list)
38 | avg_len = []
39 | for v_id, narr in narration_raw.items():
40 | narr_list = []
41 | summ_list = []
42 | if "narration_pass_1" in narr:
43 | narr_list += narr["narration_pass_1"]["narrations"]
44 | summ_list += narr["narration_pass_1"]["summaries"]
45 | if "narration_pass_2" in narr:
46 | narr_list += narr["narration_pass_2"]["narrations"]
47 | summ_list += narr["narration_pass_2"]["summaries"]
48 |
49 | if len(narr_list) > 0:
50 | narration_dict[v_id] = [
51 | (
52 | float(n_t["timestamp_sec"]),
53 | n_t["narration_text"],
54 | n_t["annotation_uid"],
55 | n_t["timestamp_frame"],
56 | )
57 | for n_t in narr_list
58 | ]
59 | avg_len.append(len(narration_dict[v_id]))
60 | else:
61 | narration_dict[v_id] = []
62 | if len(summ_list) > 0:
63 | summary_dict[v_id] = [
64 | (
65 | float(s_t["start_sec"]),
66 | float(s_t["end_sec"]),
67 | s_t["summary_text"],
68 | )
69 | for s_t in summ_list
70 | ]
71 | else:
72 | summary_dict[v_id] = []
73 | # print(f"Number of Videos with narration {len(narration_dict)}")
74 | # print(f"Avg. narration length {np.mean(avg_len)}")
75 | # print(f"Number of Videos with summaries {len(summary_dict)}")
76 | return narration_dict, summary_dict
77 |
78 |
79 | def get_signal_info(signal_fn: str):
80 | return torchaudio.info(signal_fn)
81 |
82 |
83 | def get_signal_frames(signal_fn: str, video_start_sec: float, video_end_sec: float):
84 | """
85 | Given a signal track return the frames between video_start_sec and video_end_sec
86 | """
87 | info_t = get_signal_info(signal_fn)
88 |
89 | length = video_end_sec - video_start_sec
90 | aframes, _ = torchaudio.load(
91 | signal_fn,
92 | normalize=True,
93 | frame_offset=int(video_start_sec * info_t.sample_rate),
94 | num_frames=int(length * info_t.sample_rate),
95 | )
96 | return {"signal": aframes, "meta": info_t}
97 |
98 |
99 | def tosec(value):
100 | return value / 1000
101 |
102 |
103 | def toms(value):
104 | return value * 1000
105 |
106 |
107 | def delta(first_num: float, second_num: float):
108 | """Compute the absolute value of the difference of two numbers"""
109 | return abs(first_num - second_num)
110 |
111 |
112 | def padIMU(signal, duration_sec):
113 | """
114 | Pad the signal if necessary
115 | """
116 | expected_elements = round(duration_sec) * 200
117 |
118 | if signal.shape[0] > expected_elements:
119 | signal = signal[:expected_elements, :]
120 | elif signal.shape[0] < expected_elements:
121 | padding = expected_elements - signal.shape[0]
122 | padded_zeros = np.zeros((padding, 6))
123 | signal = np.concatenate([signal, padded_zeros], 0)
124 | # signal = signal[:expected_elements, :]
125 | return signal
126 |
127 |
128 | def resample(
129 | signals: np.ndarray,
130 | timestamps: np.ndarray,
131 | original_sample_rate: int,
132 | resample_rate: int,
133 | ):
134 | """
135 | Resamples data to new sample rate
136 | """
137 | signals = torch.as_tensor(signals)
138 | timestamps = torch.from_numpy(timestamps).unsqueeze(-1)
139 | signals = torchaudio.functional.resample(
140 | waveform=signals.data.T,
141 | orig_freq=original_sample_rate,
142 | new_freq=resample_rate,
143 | ).T.numpy()
144 |
145 | nsamples = len(signals)
146 |
147 | period = 1 / resample_rate
148 |
149 | # timestamps are expected to be shape (N, 1)
150 | initial_seconds = timestamps[0] / 1e3
151 | 
152 | ntimes = (torch.arange(nsamples) * period).view(-1, 1) + initial_seconds
153 |
154 | timestamps = (ntimes * 1e3).squeeze().numpy()
155 | return signals, timestamps
156 |
157 |
158 | def resampleIMU(signal, timestamps):
159 | sampling_rate = int(1000 * (1 / (np.mean(np.diff(timestamps)))))
160 | # resample all to 200hz
161 | if sampling_rate != 200:
162 | signal, timestamps = resample(signal, timestamps, sampling_rate, 200)
163 | return signal, timestamps
164 |
165 |
166 | def get_imu_frames(
167 | imu_path,
168 | uid: str,
169 | video_start_sec: float,
170 | video_end_sec: float,
171 | ):
172 | """
173 | Given a IMU signal return the frames between video_start_sec and video_end_sec
174 | """
175 | signal = np.load(os.path.join(imu_path, f"{uid}.npy"))
176 | signal = signal.transpose()
177 | timestamps = np.load(os.path.join(imu_path, f"{uid}_timestamps.npy"))
178 |
179 | if toms(video_start_sec) > timestamps[-1] or toms(video_end_sec) > timestamps[-1]:
180 | return None
181 |
182 | start_id = bisect_left(timestamps, toms(video_start_sec))
183 | end_id = bisect_left(timestamps, toms(video_end_sec))
184 |
185 | # make sure the retrieved window interval is correct, within a max 4 sec margin
186 | if (
187 | delta(video_start_sec, tosec(timestamps[start_id])) > 4
188 | or delta(video_end_sec, tosec(timestamps[end_id])) > 4
189 | ):
190 | return None
191 |
192 | # get the window
193 | if start_id == end_id:
194 | start_id -= 1
195 | end_id += 1
196 | signal, timestamps = signal[start_id:end_id], timestamps[start_id:end_id]
197 |
198 | if len(signal) < 10 or len(timestamps) < 10:
199 | return None
200 | # resample the signal at 200hz if necessary
201 | signal, timestamps = resampleIMU(signal, timestamps)
202 |
203 | # pad the signal if necessary
204 | signal = padIMU(signal, video_end_sec - video_start_sec)
205 |
206 | sample_dict = {
207 | "timestamp": timestamps,
208 | "signal": torch.tensor(signal.T),
209 | "sampling_rate": 200,
210 | }
211 |
212 | return sample_dict
213 |
214 |
215 | def display_animation(frames, title, save_path_gif):
216 | fig, ax = plt.subplots()
217 | frames = [[ax.imshow(frames[i])] for i in range(len(frames))]
218 | plt.title(title)
219 | ani = animation.ArtistAnimation(fig, frames)
220 | ani.save(save_path_gif, writer="imagemagick")
221 | plt.close()
222 |
223 |
224 | def display_animation_imu(frames, imu, title, save_path_gif):
225 | fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
226 | ax1.set_title(title)
227 | ax2.set_title("Acc.")
228 | ax3.set_title("Gyro.")
229 | frames = [[ax1.imshow(frames[i])] for i in range(len(frames))]
230 | ani = animation.ArtistAnimation(fig, frames)
231 |
232 | ax2.plot(imu[0].cpu().numpy(), color="red")
233 | ax2.plot(imu[1].cpu().numpy(), color="blue")
234 | ax2.plot(imu[2].cpu().numpy(), color="green")
235 | ax3.plot(imu[3].cpu().numpy(), color="red")
236 | ax3.plot(imu[4].cpu().numpy(), color="blue")
237 | ax3.plot(imu[5].cpu().numpy(), color="green")
238 | plt.tight_layout()
239 | ani.save(save_path_gif, writer="imagemagick")
240 | plt.close()
241 |
242 |
243 | def filter_narration(narration_text: str) -> bool:
244 | if "#c" in narration_text.lower():
245 | return True
246 | return False
247 |
248 |
249 | def clean_narration_text(narration_text: str) -> str:
250 | return (
251 | narration_text.replace("#C C ", "")
252 | .replace("#C", "")
253 | .replace("#unsure", "something")
254 | .strip()
255 | .strip(string.punctuation)
256 | .lower()[:128]
257 | )
258 |
--------------------------------------------------------------------------------
/data/pretrain_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable, List
2 | try:
3 | from petrel_client.client import Client
4 | except:
5 | print("petrel_client is not installed.")
6 | import json
7 | import torch
8 | from io import BytesIO
9 | import random
10 | import multiprocessing as mp
11 | import copy
12 | from torch.utils.data import Dataset
13 | from pathlib import Path
14 |
15 | import pandas as pd
16 | from PIL import Image, ImageFile
17 | ImageFile.LOAD_TRUNCATED_IMAGES = True
18 |
19 | from model.tokenizer import Tokenizer
20 | import numpy as np
21 | import warnings
22 | import bisect
23 |
24 | from .data_utils import make_audio_features, pc_norm, transform_pairimg_train, transform_img_train
25 | from . import video_utils
26 | from .imu_utils import get_imu_frames
27 |
28 |
29 | DATASETS = dict(
30 | image=dict(
31 | train=(
32 | sorted(list(Path('datasets/Pretrain/image/laion400m_new').glob('*.csv')))[:1000]+
33 | sorted(list(Path('datasets/Pretrain/image/laion_coco').glob('*.csv')))[:1000]
34 | ),
35 | test= ('datasets/Pretrain/image/coco_caps_train2017.csv',),
36 | max_words=96,
37 | ),
38 | audio=dict(
39 | train=(
40 | sorted(list(Path('datasets/Pretrain/audio/wavcaps').glob('*.csv')))
41 | ),
42 | test=None,
43 | max_words=96,
44 | ),
45 | video=dict(
46 | train=(
47 | "datasets/Pretrain/video/webvid/results_2M_train_ceph.csv",),
48 | test=None,
49 | max_words=96
50 | ),
51 | point=dict(
52 | train=(
53 | "datasets/Pretrain/point/pointllm/cap3d_pointllm_train.csv",
54 | ),
55 | test=(
56 | "datasets/Pretrain/point/pointllm/cap3d_pointllm_test.csv",
57 | ),
58 | max_words=96,
59 | ),
60 | rgbd=dict(
61 | train=(
62 | 'datasets/Pretrain/image/cc3m.csv',),
63 | test=None,
64 | replace_list=['/cc3m/', '/cc3m_depth/'],
65 | max_words=96,
66 | ),
67 | rgbn=dict(
68 | train=(
69 | 'datasets/Pretrain/image/cc3m.csv',
70 | ),
71 | test=None,
72 | replace_list=['/cc3m/', '/cc3m_normal/'],
73 | max_words=96,
74 | ),
75 | fmri=dict(
76 | train=(
77 | "datasets/Pretrain/fmri/nsd/train_sub01.csv",
78 | "datasets/Pretrain/fmri/nsd/val_sub01.csv",
79 | ),
80 | test=(
81 | "datasets/Pretrain/fmri/nsd/test_sub01.csv",
82 | ),
83 | max_words=96
84 | ),
85 | imu=dict(
86 | train=(
87 | "datasets/Pretrain/imu/ego4d/window_idx_train.json",
88 | ),
89 | test=(
90 | "datasets/Pretrain/imu/ego4d/window_idx_val.json",
91 | ),
92 | imu_path="datasets/Pretrain/imu/ego4d/v2/processed_imu/",
93 | max_words=96,
94 | )
95 | )
96 |
97 |
98 | class PretrainDataset(Dataset):
99 | def __init__(self, dataset='image', partition='train', epochs=1, tokenizer_path=None, petrel_conf=None):
100 | self.dataset = dataset
101 | input_filenames = DATASETS[dataset][partition]
102 |
103 | self.petrel_conf = petrel_conf
104 | self.client = None
105 | self.partition = partition
106 | manager = mp.Manager()
107 | self.datas = manager.list()
108 | self.captions = manager.list()
109 | print('loading csv...')
110 | for input_filename in input_filenames:
111 | print(input_filename)
112 | if dataset != 'imu':
113 | chunk = pd.read_csv(input_filename, sep='\t', on_bad_lines='skip', lineterminator='\n')
114 | self.datas.extend(chunk['url'].tolist())
115 | self.captions.extend(chunk['caption'].tolist())
116 | else:
117 | self.datas = json.load(open(input_filename))
118 | self.imu_path = DATASETS[dataset]['imu_path']
119 |
120 | self.max_words = DATASETS[dataset]['max_words']
121 | self.tokenizer = Tokenizer(model_path=tokenizer_path)
122 |
123 | self.epochs = epochs
124 |
125 | def __len__(self):
126 | return int(len(self.datas) * self.epochs)
127 |
128 | def load_trans_image(self, image_path):
129 | image = Image.open(image_path).convert('RGB')
130 | image = transform_img_train(image)
131 | return image
132 |
133 | def load_trans_image_from_ceph(self, image_path):
134 | if self.client is None:
135 | self.client = Client(conf_path=self.petrel_conf)
136 | image = self.client.get(image_path)
137 |
138 | image = memoryview(image)
139 | image = Image.open(BytesIO(image)).convert('RGB')
140 | image = transform_img_train(image)
141 | return image
142 |
143 | def load_audio(self, audio_path):
144 | fbank = make_audio_features(audio_path, mel_bins=128, aug=True)
145 | fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024]
146 | return fbank
147 |
148 | def load_video(self, video_path):
149 | video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5)
150 | return video_feats[:, :, 0]
151 |
152 | def load_video_from_ceph(self, video_path):
153 | if self.client is None:
154 | self.client = Client(conf_path=self.petrel_conf)
155 | video = self.client.get(video_path)
156 | video = memoryview(video)
157 | video = BytesIO(video)
158 | video_feats = video_utils.load_and_transform_video_data(video, video_path, clip_duration=1, clips_per_video=5)
159 | return video_feats[:, :, 0]
160 |
161 | def load_point(self, point_path):
162 | point_feat = np.load(point_path)
163 | # [8196, 6]
164 | point_feat = torch.tensor(point_feat)
165 | point_feat = pc_norm(point_feat)
166 |
167 | return point_feat
168 |
169 | def load_rgbx(self, image_path):
170 | replace_list = DATASETS[self.dataset]['replace_list']
171 | x_image_path = image_path.replace(replace_list[0], replace_list[1])
172 | image = Image.open(image_path).convert('RGB')
173 | x_image = Image.open(x_image_path).convert('RGB')
174 | x_image = x_image.resize(image.size[-2:])
175 |
176 | image, x_image = transform_pairimg_train([image, x_image])
177 |
178 | # [2, 3, H, W]
179 | image = torch.stack([image, x_image], dim=0)
180 | return image
181 |
182 | def load_rgbx_from_ceph(self, image_path):
183 | if self.client is None:
184 | self.client = Client(conf_path=self.petrel_conf)
185 |
186 | replace_list = DATASETS[self.dataset]['replace_list']
187 | x_image_path = image_path.replace(replace_list[0], replace_list[1])
188 |
189 | image = self.client.get(image_path)
190 | image = memoryview(image)
191 | image = Image.open(BytesIO(image)).convert('RGB')
192 |
193 | x_image = self.client.get(x_image_path)
194 | x_image = memoryview(x_image)
195 | x_image = Image.open(BytesIO(x_image)).convert('RGB')
196 |
197 | x_image = x_image.resize(image.size[-2:])
198 |
199 | image, x_image = transform_pairimg_train([image, x_image])
200 |
201 | # [2, 3, H, W]
202 | image = torch.stack([image, x_image], dim=0)
203 | return image
204 |
205 | def load_fmri(self, fmri_path):
206 | data = np.load(fmri_path)
207 | data = data.mean(axis=0)
208 | data = torch.tensor(data[None])
209 | return data
210 |
211 | def load_imu(self, data_dict):
212 | uid = data_dict["video_uid"]
213 | w_s = data_dict["window_start"]
214 | w_e = data_dict["window_end"]
215 |
216 | imu_data = get_imu_frames(
217 | self.imu_path, uid,
218 | video_start_sec=w_s,
219 | video_end_sec=w_e,
220 | )
221 | if imu_data is None:
222 | raise ValueError
223 | return imu_data['signal']
224 |
225 | def __getitem__(self, index):
226 | index = index % len(self.datas)
227 | if self.dataset != 'imu':
228 | data_path, caption = self.datas[index], self.captions[index]
229 | else:
230 | data_dict = self.datas[index]
231 | caption = data_dict['text']
232 | data_path = data_dict["video_uid"]
233 |
234 | if isinstance(caption, list):
235 | caption = random.choice(caption)
236 | caption = str(caption)
237 |
238 | try:
239 | if self.dataset == 'image':
240 | data = self.load_trans_image_from_ceph(data_path)
241 | elif self.dataset == 'audio':
242 | data = self.load_audio(data_path)
243 | elif self.dataset == 'video':
244 | data = self.load_video_from_ceph(data_path)
245 | elif self.dataset == 'point':
246 | data = self.load_point(data_path)
247 | elif self.dataset in ['rgbn', 'rgbd']:
248 | data = self.load_rgbx(data_path)
249 | elif self.dataset == 'fmri':
250 | data = self.load_fmri(data_path)
251 | elif self.dataset == 'imu':
252 | data_dict = self.datas[index]
253 | data = self.load_imu(data_dict)
254 | except:
255 | print(data_path, 'Not Found')
256 | rand_idx = random.randint(0, len(self) - 1)
257 | return self.__getitem__(rand_idx)
258 |
259 | caption_tokens = torch.tensor(self.tokenizer.encode(caption, bos=True, eos=True), dtype=torch.int64)
260 | input_data = caption_tokens
261 |
262 | padding = self.max_words - input_data.shape[0]
263 | if padding > 0:
264 | input_data = torch.cat((input_data, torch.zeros(padding, dtype=torch.int64)))
265 | elif padding < 0:
266 | input_data = input_data[:self.max_words]
267 | labels = copy.deepcopy(input_data)
268 |
269 | if self.partition != 'train':
270 | return input_data, labels, data, data_path, self.dataset, caption
271 |
272 | return input_data, labels, data, data_path, self.dataset
273 |
274 | def __repr__(self):
275 | return f"<PretrainDataset: dataset={self.dataset}, len={len(self)}>"
276 | 
277 | 
278 | class ConcatDataset(Dataset):
279 | r"""Dataset as a concatenation of multiple datasets."""
280 | datasets: List[Dataset]
281 | 
282 | @staticmethod
283 | def cumsum(sequence):
284 | r, s = [], 0
285 | for e in sequence:
286 | l = len(e)
287 | r.append(l + s)
288 | s += l
289 | return r
290 | 
291 | def __init__(self, datasets: Iterable[Dataset]) -> None:
292 | super().__init__()
293 | self.datasets = list(datasets)
294 | assert len(self.datasets) > 0, 'datasets should not be an empty iterable' # type: ignore[arg-type]
295 | # for d in self.datasets:
296 | # assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
297 | self.cumulative_sizes = self.cumsum(self.datasets)
298 |
299 | def __len__(self):
300 | return self.cumulative_sizes[-1]
301 |
302 | def __getitem__(self, idx):
303 | if idx < 0:
304 | if -idx > len(self):
305 | raise ValueError("absolute value of index should not exceed dataset length")
306 | idx = len(self) + idx
307 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
308 | if dataset_idx == 0:
309 | sample_idx = idx
310 | else:
311 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
312 | return self.datasets[dataset_idx][sample_idx]
313 |
314 | @property
315 | def cummulative_sizes(self):
316 | warnings.warn("cummulative_sizes attribute is renamed to "
317 | "cumulative_sizes", DeprecationWarning, stacklevel=2)
318 | return self.cumulative_sizes
319 |
320 | def get_indices(self, batch_size, world_size=1, rank_id=0):
321 | random.seed(0)
322 | real_batch_size = batch_size * world_size
323 | batch_train_indices = []
324 | num_batches = []
325 | for i in range(len(self.datasets)):
326 | # get train_indices
327 | start_idx = self.cumulative_sizes[i-1] if i>0 else 0
328 | end_idx = self.cumulative_sizes[i]
329 | train_indice = list(range(start_idx, end_idx))
330 | random.shuffle(train_indice)
331 | num_batch = int(len(self.datasets[i]) / real_batch_size)
332 | num_batches.append(num_batch)
333 | # get batch indices for each rank
334 | batch_train_indice = [
335 | train_indice[batch*real_batch_size:(batch+1)*real_batch_size][rank_id::world_size]
336 | for batch in range(num_batch)
337 | ]
338 | batch_train_indices.append(batch_train_indice)
339 | min_num_batch = min(num_batches)
340 |
341 | train_indices = []
342 | for batch in range(min_num_batch):
343 | for i in range(len(self.datasets)):
344 | train_indices.extend(batch_train_indices[i][batch])
345 |
346 | return train_indices
--------------------------------------------------------------------------------
/data/video_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from pytorchvideo import transforms as pv_transforms
5 | from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
6 | from pytorchvideo.data.encoded_video import EncodedVideo
7 | from pytorchvideo.data.encoded_video_decord import EncodedVideoDecord
8 | from torchvision import transforms
9 | from torchvision.transforms._transforms_video import NormalizeVideo
10 |
11 |
12 | def get_clip_timepoints(clip_sampler, duration):
13 | # Read out all clips in this video
14 | all_clips_timepoints = []
15 | is_last_clip = False
16 | end = 0.0
17 | while not is_last_clip:
18 | start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
19 | all_clips_timepoints.append((start, end))
20 | return all_clips_timepoints
21 |
22 |
23 |
24 | def crop_boxes(boxes, x_offset, y_offset):
25 | """
26 | Perform crop on the bounding boxes given the offsets.
27 | Args:
28 | boxes (ndarray or None): bounding boxes to perform crop. The dimension
29 | is `num boxes` x 4.
30 | x_offset (int): cropping offset in the x axis.
31 | y_offset (int): cropping offset in the y axis.
32 | Returns:
33 | cropped_boxes (ndarray or None): the cropped boxes with dimension of
34 | `num boxes` x 4.
35 | """
36 | cropped_boxes = boxes.copy()
37 | cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
38 | cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
39 |
40 | return cropped_boxes
41 |
42 |
43 | def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
44 | """
45 | Perform uniform spatial sampling on the images and corresponding boxes.
46 | Args:
47 | images (tensor): images to perform uniform crop. The dimension is
48 | `num frames` x `channel` x `height` x `width`.
49 | size (int): size of height and weight to crop the images.
50 | spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
51 | is larger than height. Or 0, 1, or 2 for top, center, and bottom
52 | crop if height is larger than width.
53 | boxes (ndarray or None): optional. Corresponding boxes to images.
54 | Dimension is `num boxes` x 4.
55 | scale_size (int): optional. If not None, resize the images to scale_size before
56 | performing any crop.
57 | Returns:
58 | cropped (tensor): images with dimension of
59 | `num frames` x `channel` x `size` x `size`.
60 | cropped_boxes (ndarray or None): the cropped boxes with dimension of
61 | `num boxes` x 4.
62 | """
63 | assert spatial_idx in [0, 1, 2]
64 | ndim = len(images.shape)
65 | if ndim == 3:
66 | images = images.unsqueeze(0)
67 | height = images.shape[2]
68 | width = images.shape[3]
69 |
70 | if scale_size is not None:
71 | if width <= height:
72 | width, height = scale_size, int(height / width * scale_size)
73 | else:
74 | width, height = int(width / height * scale_size), scale_size
75 | images = torch.nn.functional.interpolate(
76 | images,
77 | size=(height, width),
78 | mode="bilinear",
79 | align_corners=False,
80 | )
81 |
82 | y_offset = int(math.ceil((height - size) / 2))
83 | x_offset = int(math.ceil((width - size) / 2))
84 |
85 | if height > width:
86 | if spatial_idx == 0:
87 | y_offset = 0
88 | elif spatial_idx == 2:
89 | y_offset = height - size
90 | else:
91 | if spatial_idx == 0:
92 | x_offset = 0
93 | elif spatial_idx == 2:
94 | x_offset = width - size
95 | cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
96 | cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
97 | if ndim == 3:
98 | cropped = cropped.squeeze(0)
99 | return cropped, cropped_boxes
100 |
101 |
102 | class SpatialCrop(nn.Module):
103 | """
104 | Convert the video into 3 smaller clips spatially. Must be used after the
105 | temporal crops to get spatial crops, and should be used with
106 | -2 in the spatial crop at the slowfast augmentation stage (so full
107 | frames are passed in here). Will return a larger list with the
108 | 3x spatial crops as well.
109 | """
110 |
111 | def __init__(self, crop_size: int = 224, num_crops: int = 3):
112 | super().__init__()
113 | self.crop_size = crop_size
114 | if num_crops == 3:
115 | self.crops_to_ext = [0, 1, 2]
116 | self.flipped_crops_to_ext = []
117 | elif num_crops == 1:
118 | self.crops_to_ext = [1]
119 | self.flipped_crops_to_ext = []
120 | else:
121 | raise NotImplementedError("Nothing else supported yet")
122 |
123 | def forward(self, videos):
124 | """
125 | Args:
126 | videos: A list of C, T, H, W videos.
127 | Returns:
128 | videos: A list with 3x the number of elements. Each video converted
129 | to C, T, H', W' by spatial cropping.
130 | """
131 | assert isinstance(videos, list), "Must be a list of videos after temporal crops"
132 | assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
133 | res = []
134 | for video in videos:
135 | for spatial_idx in self.crops_to_ext:
136 | res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
137 | if not self.flipped_crops_to_ext:
138 | continue
139 | flipped_video = transforms.functional.hflip(video)
140 | for spatial_idx in self.flipped_crops_to_ext:
141 | res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
142 | return res
143 |
144 |
145 | def load_and_transform_video_data(
146 | video_file,
147 | video_path,
148 | clip_duration=2,
149 | clips_per_video=5,
150 | sample_rate=16000,
151 | with_audio=False
152 | ):
153 | video_transform = transforms.Compose(
154 | [
155 | pv_transforms.ShortSideScale(224),
156 | NormalizeVideo(
157 | mean=(0.48145466, 0.4578275, 0.40821073),
158 | std=(0.26862954, 0.26130258, 0.27577711),
159 | ),
160 | ]
161 | )
162 |
163 | clip_sampler = ConstantClipsPerVideoSampler(
164 | clip_duration=clip_duration, clips_per_video=clips_per_video
165 | )
166 | frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
167 |
168 | if isinstance(video_file, str):
169 | video = EncodedVideo.from_path(
170 | video_file,
171 | decoder="decord",
172 | decode_audio=with_audio,
173 | # **{"sample_rate": sample_rate},
174 | )
175 | else:
176 | video = EncodedVideoDecord(video_file, video_name=video_path, decode_video=True, decode_audio=with_audio, sample_rate=sample_rate)
177 |
178 | all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
179 |
180 | all_video = []
181 | for clip_timepoints in all_clips_timepoints:
182 | # Read the clip, get frames
183 | clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
184 | if clip is None:
185 | raise ValueError("No clip found")
186 | video_clip = frame_sampler(clip["video"])
187 | video_clip = video_clip / 255.0 # since this is float, need 0-1
188 |
189 | all_video.append(video_clip)
190 |
191 | all_video = [video_transform(clip) for clip in all_video]
192 | all_video = SpatialCrop(224, num_crops=3)(all_video)
193 |
194 | all_video = torch.stack(all_video, dim=0)
195 |
196 | if not with_audio:
197 | return all_video
198 | else:
199 | return all_video, clip['audio']
200 |
201 | if __name__ == '__main__':
202 | video_path = "datasets/InstructionTuning/video/music_aqa/MUSIC-AVQA-videos-Real/00000002.mp4"
203 | video, audio = load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5, with_audio=True)
204 | import pdb;pdb.set_trace()
205 |
--------------------------------------------------------------------------------
/demos/cli.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.path.abspath(__file__).rsplit('/', 2)[0])
4 | import argparse
5 | import multiprocessing as mp
6 | import numpy as np
7 | import torch
8 | import torch.distributed as dist
9 | from fairscale.nn.model_parallel import initialize as fs_init
10 | from util.misc import setup_for_distributed
11 | from util.misc import default_tensor_type
12 | from model.meta import MetaModel
13 | from data.conversation_lib import conv_templates
14 | from PIL import Image
15 | import torchvision.transforms as transforms
16 |
17 |
18 | T_random_resized_crop = transforms.Compose([
19 | transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3,
20 | antialias=None), # 3 is bicubic
21 | transforms.ToTensor(),
22 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
23 |
24 |
25 | def model_worker(args: argparse.Namespace) -> None:
26 | rank = 0
27 | world_size = len(args.gpu_ids)
28 | gpu_id = args.gpu_ids[rank]
29 | dist.init_process_group(
30 | backend="nccl", rank=rank, world_size=world_size,
31 | init_method=f"tcp://{args.master_addr}:{args.master_port}",
32 | )
33 | print(f"| distributed init on worker {rank}/{world_size}. "
34 | f"using gpu: {gpu_id}")
35 | fs_init.initialize_model_parallel(world_size)
36 | torch.cuda.set_device(gpu_id)
37 |
38 | torch.manual_seed(1)
39 | np.random.seed(1)
40 |
41 | # set the print behavior.
42 | setup_for_distributed(rank == 0)
43 |
44 | target_dtype = {
45 | "bf16": torch.bfloat16,
46 | "fp16": torch.float16
47 | }[args.dtype]
48 | with default_tensor_type(dtype=target_dtype, device="cuda"):
49 | model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
50 | print("Loading pretrained weights ...")
51 | checkpoint = torch.load(args.pretrained_path, map_location='cpu')
52 | msg = model.load_state_dict(checkpoint, strict=False)
53 | print("load result:\n", msg)
54 | model.cuda()
55 | model.eval()
56 | print(f"Model = {str(model)}")
57 |
58 | print('Model is ready. Please input')
59 |
60 | conv = conv_templates["v1"].copy()
61 |
62 | image = Image.open(args.image_path).convert('RGB')
63 | image = T_random_resized_crop(image).unsqueeze(0).cuda().to(target_dtype)
64 | while True:
65 | try:
66 | inp = input(f"{conv.roles[0]}: ")
67 | except EOFError:
68 | inp = ""
69 | if not inp:
70 | print("exit...")
71 | break
72 |
73 | print(f"{conv.roles[1]}: ", end="")
74 |
75 | conv.append_message(conv.roles[0], inp)
76 | conv.append_message(conv.roles[1], None)
77 |
78 | with torch.cuda.amp.autocast(dtype=target_dtype):
79 | print(conv.get_prompt())
80 | response = model.generate([conv.get_prompt()], image, 256, temperature=0.1, top_p=0.75, modal=["image"])
81 | response = response[0]
82 | response = response[len(conv.get_prompt()):].split('###')[0]
83 | print(response)
84 | conv.messages[-1][-1] = response
85 |
86 |
87 | if __name__ == "__main__":
88 | parser = argparse.ArgumentParser("LLaMA2-Accessory Chat Demo")
89 | parser.add_argument("--image_path", type=str, help="path to the input image")
90 | group = parser.add_mutually_exclusive_group()
91 | group.add_argument(
92 | "--gpu_ids", type=int, nargs="+",
93 | help="A list of space-separated gpu ids to run the model on. "
94 | "The model will span across GPUs in tensor-parallel mode.")
95 | group.add_argument("--n_gpus", type=int, help="Number of GPUs to run the model on; used when --gpu_ids is not given.")
96 | parser.add_argument(
97 | "--tokenizer_path", type=str,
98 | default="config/llama2/tokenizer.model",
99 | help="Path to the tokenizer.model file provided along with the LLaMA "
100 | "model."
101 | )
102 | parser.add_argument(
103 | "--llama_type", default="onellm", type=str, metavar="MODEL",
104 | help="LLaMA model type."
105 | )
106 | parser.add_argument(
107 | "--llama_config", type=str, required=True,
108 | default="config/llama2/7B.json",
109 | help="Path to the llama model config json."
110 | )
111 | parser.add_argument(
112 | "--model_max_seq_len", type=int, default=2048,
113 | help="Max sequence length accepted by the pretrained model."
114 | )
115 | parser.add_argument(
116 | "--pretrained_path", type=str, required=True,
117 | help="Path to the llama model checkpoints. A list of checkpoints is "
118 | "supported and will be merged from left to right.")
119 | parser.add_argument(
120 | "--master_port", type=int, default=23862,
121 | help="A port used by the PyTorch distributed module to initialize."
122 | )
123 | parser.add_argument(
124 | "--master_addr", type=str, default="127.0.0.1",
125 | help="An address used by the PyTorch distributed module to initialize."
126 | )
127 | parser.add_argument(
128 | "--dtype", type=str, choices=["fp16", "bf16"], default="fp16",
129 | help="The dtype used for model weights and inference."
130 | )
131 | args = parser.parse_args()
132 |
133 | # check and setup gpu_ids to use
134 | if args.gpu_ids is None:
135 | if args.n_gpus is None:
136 | args.n_gpus = 1
137 | assert args.n_gpus > 0, (
138 | "The demo currently must run on a positive number of GPUs."
139 | )
140 | args.gpu_ids = list(range(args.n_gpus))
141 |
142 | # using the default "fork" method messes up some imported libs (e.g.,
143 | # pandas)
144 | mp.set_start_method("spawn")
145 | model_worker(args)
146 |
--------------------------------------------------------------------------------
/demos/multi_turn_mm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.abspath(__file__).rsplit('/', 2)[0])
4 |
5 | import argparse
6 | import multiprocessing as mp
7 | import numpy as np
8 | from typing import List, Optional
9 |
10 | import torch
11 | import torch.distributed as dist
12 |
13 | from fairscale.nn.model_parallel import initialize as fs_init
14 |
15 | import gradio as gr
16 | from util.misc import setup_for_distributed
17 | from util.misc import default_tensor_type
18 | from model.meta import MetaModel
19 | from data.conversation_lib import conv_templates, SeparatorStyle
20 | from PIL import Image
21 | import torchvision.transforms as transforms
22 | from data.fintune_dataset import make_audio_features
23 | from data import video_utils
24 |
25 |
26 | T_random_resized_crop = transforms.Compose([
27 | transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3,
28 | antialias=None), # 3 is bicubic
29 | transforms.ToTensor(),
30 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
31 |
32 |
33 | def load_audio(audio_path):
34 | fbank = make_audio_features(audio_path, mel_bins=128)
35 | fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024]
36 | return fbank
37 |
38 | def load_video(video_path):
39 | video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5)
40 | return video_feats[:, :, 0]
41 |
42 |
43 | def model_worker(
44 | rank: int, args: argparse.Namespace, barrier: mp.Barrier,
45 | request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
46 | ) -> None:
47 | """
48 | The worker function that manipulates the GPU to run the inference.
49 | Exactly n_gpu workers are started, with each one operating on a separate GPU.
50 |
51 | Args:
52 | rank (int): Distributed rank of the worker.
53 | args (argparse.Namespace): All command line arguments.
54 | barrier (multiprocessing.Barrier): A barrier used to delay the start
55 | of Web UI to be after the start of the model.
56 | """
57 |
58 | world_size = len(args.gpu_ids)
59 | gpu_id = args.gpu_ids[rank]
60 | dist.init_process_group(
61 | backend="nccl", rank=rank, world_size=world_size,
62 | init_method=f"tcp://{args.master_addr}:{args.master_port}",
63 | )
64 | print(f"| distributed init on worker {rank}/{world_size}. "
65 | f"using gpu: {gpu_id}")
66 | fs_init.initialize_model_parallel(world_size)
67 | torch.cuda.set_device(gpu_id)
68 |
69 | torch.manual_seed(1)
70 | np.random.seed(1)
71 |
72 | # set the print behavior.
73 | setup_for_distributed(rank == 0)
74 |
75 | target_dtype = {
76 | "bf16": torch.bfloat16,
77 | "fp16": torch.float16
78 | }[args.dtype]
79 | with default_tensor_type(dtype=target_dtype, device="cuda"):
80 | model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
81 | print("Loading pretrained weights ...")
82 | checkpoint = torch.load(args.pretrained_path, map_location='cpu')
83 | msg = model.load_state_dict(checkpoint, strict=False)
84 | print("load result:\n", msg)
85 | model.cuda()
86 | model.eval()
87 | print(f"Model = {str(model)}")
88 |
89 | barrier.wait()
90 |
91 | while True:
92 | img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
93 | if 'image' in modality and img_path is not None:
94 | image = Image.open(img_path).convert('RGB')
95 | inputs = T_random_resized_crop(image)
96 | elif 'video' in modality and video_path is not None:
97 | inputs = load_video(video_path)
98 | elif 'audio' in modality and audio_path is not None:
99 | inputs = load_audio(audio_path)
100 | else:
101 | inputs = None
102 |
103 | if inputs is not None:
104 | inputs = inputs[None].cuda().to(target_dtype)
105 |
106 | conv = conv_templates["v1"].copy()
107 | for user, bot in chatbot:
108 | conv.append_message(conv.roles[0], user)
109 | conv.append_message(conv.roles[1], bot)
110 |
111 | with torch.cuda.amp.autocast(dtype=target_dtype):
112 | print(conv.get_prompt())
113 | for stream_response in model.stream_generate(
114 | conv.get_prompt(), inputs,
115 | max_gen_len=max_gen_len, temperature=temperature, top_p=top_p,
116 |                 modal=modality
117 | ):
118 | conv_sep = (
119 | conv.sep
120 | if conv.sep_style == SeparatorStyle.SINGLE
121 | else conv.sep2
122 | )
123 | end_pos = stream_response["text"].find(conv_sep)
124 | if end_pos != -1:
125 | stream_response["text"] = (
126 | stream_response['text'][:end_pos].rstrip() + "\n"
127 | )
128 | stream_response["end_of_content"] = True
129 |
130 | # keep a few characters if not end_of_content to avoid sending
131 | # part of conv_sep before all of it is generated.
132 | if not stream_response["end_of_content"]:
133 | if len(stream_response["text"]) < len(conv_sep):
134 | continue
135 | stream_response["text"] = (
136 | stream_response["text"][:-len(conv_sep)]
137 | )
138 |
139 | if response_queue is not None:
140 | response_queue.put(stream_response)
141 |
142 | if stream_response["end_of_content"]:
143 | break
144 |
145 |
146 | def gradio_worker(
147 | request_queues: List[mp.Queue], response_queue: mp.Queue,
148 | args: argparse.Namespace, barrier: mp.Barrier,
149 | ) -> None:
150 | """
151 |     The gradio worker is responsible for displaying the Web UI and relaying
152 |     requests to the model workers. It should be launched only once.
153 |
154 | Args:
155 | request_queues (List[mp.Queue]): A list of request queues (one for
156 | each model worker).
157 | args (argparse.Namespace): All command line arguments.
158 |         barrier (multiprocessing.Barrier): A barrier used to delay the start
159 |             of the Web UI until after the model has started.
160 | """
161 |
162 | def show_user_input(msg, chatbot):
163 | return "", chatbot + [[msg, None]]
164 |
165 | def stream_model_output(img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality):
166 | for queue in request_queues:
167 | queue.put((img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality))
168 | while True:
169 | content_piece = response_queue.get()
170 | chatbot[-1][1] = content_piece["text"]
171 | yield chatbot
172 | if content_piece["end_of_content"]:
173 | break
174 |
175 | def undo(chatbot):
176 | if len(chatbot) > 0:
177 | chatbot = chatbot[:-1]
178 | return chatbot
179 |
180 | def clear():
181 | chatbot = []
182 | msg = ""
183 | return chatbot, msg
184 |
185 |     CSS = """
186 | .contain { display: flex; flex-direction: column; }
187 | #component-0 { height: 100%; }
188 | #chatbot { flex-grow: 1; overflow: auto;}
189 | """
190 | with gr.Blocks(css=CSS) as demo:
191 | gr.Markdown("## OneLLM: One Framework to Align All Modalities with Language")
192 | with gr.Row(equal_height=True):
193 | with gr.Column(scale=1):
194 | img_path = gr.Image(label='Image Input', type='filepath')
195 | video_path = gr.Video(label='Video Input')
196 | audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload'])
197 | modality = gr.Radio(choices=['image', 'audio', 'video'], value='image', interactive=True, label='Input Modalities')
198 |
199 | with gr.Column(scale=2):
200 | chatbot = gr.Chatbot(elem_id="chatbot")
201 | msg = gr.Textbox()
202 |
203 | with gr.Row():
204 | submit_button = gr.Button("Submit", variant="primary")
205 | undo_button = gr.Button("Undo")
206 | clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, modality])
207 | with gr.Row():
208 | max_gen_len = gr.Slider(
209 | minimum=1, maximum=args.model_max_seq_len // 2,
210 | value=args.model_max_seq_len // 2, interactive=True,
211 | label="Single-turn max response length",
212 | )
213 | gen_t = gr.Slider(
214 | minimum=0, maximum=1, value=0.1, interactive=True,
215 | label="Temperature",
216 | )
217 | top_p = gr.Slider(
218 | minimum=0, maximum=1, value=0.75, interactive=True,
219 | label="Top-p",
220 | )
221 | msg.submit(
222 | show_user_input, [msg, chatbot], [msg, chatbot],
223 | ).then(
224 | stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
225 | )
226 | submit_button.click(
227 | show_user_input, [msg, chatbot], [msg, chatbot],
228 | ).then(
229 | stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
230 | )
231 | undo_button.click(undo, chatbot, chatbot)
232 | # img_path.change(clear, [], [chatbot, msg])
233 | barrier.wait()
234 | demo.queue(api_open=True).launch(share=True, max_threads=1)
235 |
236 |
237 | if __name__ == "__main__":
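    # Example launch command (illustrative; adjust the checkpoint path to your setup):
    #   python demos/multi_turn_mm.py --gpu_ids 0 \
    #       --tokenizer_path config/llama2/tokenizer.model \
    #       --llama_config config/llama2/7B.json \
    #       --pretrained_path weights/consolidated.00-of-01.pth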
238 | parser = argparse.ArgumentParser("Chat Demo")
239 | group = parser.add_mutually_exclusive_group()
240 | group.add_argument(
241 | "--gpu_ids", type=int, nargs="+",
242 | help="A list of space-separated gpu ids to run the model on. "
243 | "The model will span across GPUs in tensor-parallel mode."
244 | )
245 | parser.add_argument(
246 | "--tokenizer_path", type=str,
247 | default="config/llama2/tokenizer.model",
248 | help="Path to the tokenizer.model file provided along with the LLaMA "
249 | "model."
250 | )
251 | parser.add_argument(
252 | "--llama_type", default="onellm", type=str, metavar="MODEL",
253 | help="LLaMA model type."
254 | )
255 | parser.add_argument(
256 |         "--llama_config", type=str,
257 | default="config/llama2/7B.json",
258 | help="Path to the llama model config json."
259 | )
260 | parser.add_argument(
261 | "--model_max_seq_len", type=int, default=2048,
262 | help="Max sequence length accepted by the pretrained model."
263 | )
264 | parser.add_argument(
265 | "--pretrained_path", type=str, required=True,
266 |         help="Path to the pretrained OneLLM checkpoint, e.g. "
267 |              "consolidated.00-of-01.pth.")
268 | parser.add_argument(
269 | "--master_port", type=int, default=23862,
270 |         help="The port used by the PyTorch distributed module for initialization."
271 | )
272 | parser.add_argument(
273 | "--master_addr", type=str, default="127.0.0.1",
274 |         help="The address used by the PyTorch distributed module for initialization."
275 | )
276 | parser.add_argument(
277 | "--dtype", type=str, choices=["fp16", "bf16"], default="fp16",
278 | help="The dtype used for model weights and inference."
279 | )
280 | args = parser.parse_args()
281 |
282 | # using the default "fork" method messes up some imported libs (e.g.,
283 | # pandas)
284 | mp.set_start_method("spawn")
285 |
286 | # setup the queues and start the model workers
287 | request_queues = []
288 | response_queue = mp.Queue()
289 | worker_processes = []
290 | barrier = mp.Barrier(len(args.gpu_ids) + 1)
291 | for rank, gpu_id in enumerate(args.gpu_ids):
292 | request_queue = mp.Queue()
293 | rank_response_queue = response_queue if rank == 0 else None
294 | process = mp.Process(
295 | target=model_worker,
296 | args=(rank, args, barrier, request_queue, rank_response_queue),
297 | )
298 | process.start()
299 | worker_processes.append(process)
300 | request_queues.append(request_queue)
301 |
302 | gradio_worker(request_queues, response_queue, args, barrier)
303 |
--------------------------------------------------------------------------------
/docs/Data.md:
--------------------------------------------------------------------------------
1 | ## Data
2 |
3 | ### Data Format
4 | Here we give an overview of the data format. For details, please check the data loading code: [data/pretrain_dataset.py]() and [data/finetune_dataset.py]()
5 |
6 | #### Pretraining Data
7 | All pretraining data except IMU are organized in `.csv` format. Each `.csv` file has two columns, `caption` and `url`, separated by a tab (`\t`). For example,
8 | ```
9 | caption url
10 | Woman receiving a foot massage at health spa Stock Photo cluster_p_ssd:s3://laion400m_mmg_ssd/29347/293477138.jpg
11 | Long injury list troubles Paul Hart as Portsmouth search for some Cup form cluster_p_ssd:s3://laion400m_mmg_ssd/43069/430692001.jpg
12 | ... ...
13 | ```
14 |
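Below is a minimal sketch (not the loader used in [data/pretrain_dataset.py]()) of how such a tab-separated `.csv` could be read; the file name `image_text.csv` is only illustrative.

```python
import pandas as pd

# caption/url pairs separated by '\t', as described above
df = pd.read_csv("image_text.csv", sep="\t")

for caption, url in zip(df["caption"], df["url"]):
    # each row is one caption-url pair used for pretraining
    print(caption, url)
```
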
15 | #### Instruction Tuning Data
16 | All finetuning data are converted into a multi-turn conversation format. Each `.json` file contains a list of training samples, where each sample has the keys `id`, `image` and `conversations`. For example,
17 | ```
18 | {'id': '000000033471', 'image': 'InstructionTuning/image/coco/train2017/000000033471.jpg', 'conversations': [{'from': 'human', 'value': 'What are the colors of the bus in the image?'}, {'from': 'gpt', 'value': 'The bus in the image is white and red.'}, {'from': 'human', 'value': 'What feature can be seen on the back of the bus?'}, {'from': 'gpt', 'value': 'The back of the bus features an advertisement.'}]}
19 | ```
20 |
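As a minimal sketch (the relative path below is illustrative), a sample can be inspected like this:

```python
import json

# a list of samples, each with 'id', 'image' and 'conversations'
samples = json.load(open("llava_v1_5_mix665k_image.json"))

sample = samples[0]
print(sample["id"], sample["image"])
for turn in sample["conversations"]:
    # alternating 'human' / 'gpt' turns form the multi-turn dialogue
    print(f"{turn['from']}: {turn['value']}")
```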
21 |
22 | ### Download Links
23 |
24 | | Modality | Pretraining Dataset | Download | Instruction Tuning Dataset | Download |
25 | |----------|---------------------|----------|----------------------------|----------|
27 | | Image | [LAION-400M](https://laion.ai/blog/laion-400-open-dataset) | [link](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/laion400m.md) | LLaVA-mix665K | [link](https://github.com/haotian-liu/LLaVA#visual-instruction-tuning) |
28 | | | LAION-COCO | [link](https://laion.ai/blog/laion-coco) | COCO Caption | [link](https://cocodataset.org/#download) |
29 | | Video | WebVid-2.5M | [link](https://github.com/m-bain/webvid) | [MSRVTT Caption](https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/) | [link](https://www.mediafire.com/folder/h14iarbs62e7p/shared) |
30 | | | | | MSRVTT-QA | [link](https://github.com/xudejing/video-question-answering) |
31 | | | | | [Video Conversation](https://github.com/joez17/ChatBridge/blob/main/custom_datasets/valor_data/DATASET.md#download-multis) | [link](https://drive.google.com/file/d/1C7k8flfITJ1GxMwFSvEmBFGyevDZl1ke/view?usp=drive_link) |
32 | | Audio | [WavCaps](https://github.com/XinhaoMei/WavCaps) | [link](https://huggingface.co/datasets/cvssp/WavCaps) | [AudioCaps](https://audiocaps.github.io/) | [link](https://github.com/cdjkim/audiocaps) |
33 | | | | | [Audio Conversation](https://github.com/joez17/ChatBridge/blob/main/custom_datasets/valor_data/DATASET.md#download-multis) | [link](https://drive.google.com/file/d/1C7k8flfITJ1GxMwFSvEmBFGyevDZl1ke/view?usp=drive_link) |
34 | | Point | [Cap3D](https://github.com/crockwell/Cap3D) | [link](https://huggingface.co/datasets/RunsenXu/PointLLM/tree/main) | [Point Conversation](https://github.com/OpenRobotLab/PointLLM) | [link](https://huggingface.co/datasets/RunsenXu/PointLLM) |
35 | | Depth | CC3M | [link](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md) | LLaVA-150K | [link](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) |
36 | | Normal | CC3M | [link](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md) | LLaVA-150K | [link](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) |
37 | | IMU | Ego4D | [link](https://ego4d-data.org/docs/data/imu/) | Ego4D | [link](https://ego4d-data.org/docs/data/imu/) |
38 | | fMRI | [NSD](https://naturalscenesdataset.org) | [link](https://huggingface.co/datasets/pscotti/naturalscenesdataset) | [NSD](https://naturalscenesdataset.org) | [link](https://huggingface.co/datasets/pscotti/naturalscenesdataset) |
39 |
40 | **Notes**
41 | - The depth/normal maps are generated from [CC3M](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md) and a 50K random subset of LLaVA-150K using a pretrained [DPT](https://github.com/EPFL-VILAB/omnidata/tree/main/omnidata_tools/torch#run-our-models-on-your-own-image).
42 | - The [IMU data](https://ego4d-data.org/docs/data/imu/) is preprocessed with [this script](https://github.com/facebookresearch/imu2clip/blob/main/dataset/ego4d/preprocessing_scripts/extract_imu.py).
43 |
44 |
45 | ### Instruction Tuning Data
46 |
47 | **Annotation Download:** Please download the annotations from [this link](https://huggingface.co/datasets/csuhan/OneLLM_InstructionTuning) and put them under `datasets/InstructionTuning`.
48 |
49 | Then download the original datasets from the table above and put them under the corresponding folders. The file structure should be:
50 |
51 | ```
52 | datasets
53 | └── InstructionTuning
54 | ├── audio
55 | │ ├── audioset2
56 | │ ├── audiocap_train.json
57 | │ ├── audiocap_val.json
58 | │ └── audio_conversation.json
59 | ├── depth_normal
60 | │ ├── depth
61 | │ ├── normal
62 | │ ├── llava_instruct_50k_depth.json
63 | │ └── llava_instruct_50k_normal.json
64 | ├── fmri
65 | │ ├── NSD
66 | │ └── fmri_fixed_train.json
67 | ├── image
68 | │ ├── coco
69 | │ ├── gqa
70 | │ ├── ocr_vqa
71 | │ ├── vg
72 | │ ├── cococap_train.json
73 | │ ├── llava_v1_5_mix665k_image.json
74 | │ └── llava_v1_5_mix665k_text.json
75 | ├── imu
76 | │ ├── ego4d
77 | │ └── imu_fixed_50k.json
78 | ├── point
79 | │ ├── pointllm/8192_npy
80 | │ └── pointllm_70k.json
81 | └── video
82 | ├── msr-vtt/MSR-VTT
83 | ├── msrvtt_cap_test.json
84 | ├── msrvtt_cap_trainval.json
85 | ├── msrvtt_vqa_test.json
86 | ├── msrvtt_vqa_train.json
87 | ├── msrvtt_vqa_val.json
88 | ├── video_complex_reasoning_10k.json
89 | ├── video_conversation_10k.json
90 | └── video_detail_10k.json
91 | ```
--------------------------------------------------------------------------------
/docs/Evaluation.md:
--------------------------------------------------------------------------------
1 | ## Evaluation
2 |
3 | **Annotation Download:** Download the annotations of the evaluation datasets from [csuhan/OneLLM_Eval](https://huggingface.co/datasets/csuhan/OneLLM_Eval) and put them under `datasets/Eval`.
4 |
5 | ### Image-Text Evaluation
6 |
7 | #### COCO Caption
8 |
9 | - Download [COCO2014 Val](http://images.cocodataset.org/zips/val2014.zip) and put it under `datasets/InstructionTuning/image/coco/val2014`
10 | - Fill `pretrained_path` in [eval/image_cap_cococap.py]() and run: `python eval/image_cap_cococap.py`
11 | - Install [pycocoevalcap](https://github.com/salaniz/pycocoevalcap)
12 | - Evaluate with [eval/caption_eval.py]() (the expected prediction format is sketched below)
13 |
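For reference, [eval/image_cap_cococap.py]() writes its predictions as a list of COCO-style entries, which [eval/caption_eval.py]() then scores with pycocoevalcap. A minimal sketch of that format (the image id and caption below are illustrative):

```python
import json
import os

# one {"image_id", "caption"} entry per evaluated image
predictions = [
    {"image_id": 391895, "caption": "a man riding a motorcycle on a dirt road"},
]
os.makedirs("eval/results", exist_ok=True)
with open("eval/results/eval_cococap.json", "w") as f:
    json.dump(predictions, f)
```
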
14 | #### MMVet
15 |
16 | - Download MMVet from [mm-vet.zip](https://github.com/yuweihao/MM-Vet/releases/download/v1/mm-vet.zip) and extract it under `datasets/Eval/image/mm-vet`
17 | - Fill `pretrained_path` in [eval/image_bench_mmvet.py]() and run: `python eval/image_bench_mmvet.py`
18 | - Submit the result file to the [Online Eval Server](https://huggingface.co/spaces/whyu/MM-Vet_Evaluator)
19 |
20 | ### Video-Text Evaluation
21 |
22 | #### MSVD QA
23 | - Download [MSVD](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/) video clips from [this link](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/YouTubeClips.tar) and put it under `datasets/Eval/video/MSVD/YouTubeClips`
24 | - Fill `pretrained_path` in [eval/video_qa_msvd.py]() and run: `python eval/video_qa_msvd.py`.
25 |
26 | ### Audio-Text Evaluation
27 |
28 | #### Clotho Caption
29 | - Download [Clothov2](https://zenodo.org/records/4783391) evaluation set from [this link](https://zenodo.org/records/4783391/files/clotho_audio_evaluation.7z?download=1) and put it under `datasets/Eval/audio/clothov2/evaluation`
30 | - Fill `pretrained_path` in [eval/audio_cap_clothov2.py]() and run: `python eval/audio_cap_clothov2.py`.
31 | - Evaluate with [eval/caption_eval.py]().
32 |
33 | ### Point-Text Evaluation
34 |
35 | #### PointLLM Caption
36 | - Download PointLLM data from [this link](https://huggingface.co/datasets/RunsenXu/PointLLM)
37 | - Fill `pretrained_path` in [eval/point_cap_pointllm.py]() and run: `python eval/point_cap_pointllm.py`.
38 | - Evaluate with [eval/caption_eval.py](). The annotation file is at [datasets/Eval/point/pointllm_test_cococap.json]()
39 |
40 | ### Depth/Normal-Text Evaluation
41 |
42 | TODO
43 |
44 | ### IMU-Text Evaluation
45 |
46 | #### Ego4D IMU Caption
47 |
48 | - Download Ego4D IMU data. Please refer to [docs/Data.md]().
49 | - Fill `IMU_PATH` and `pretrained_path` in [eval/imu_cap_ego4d.py]() and run: `python eval/imu_cap_ego4d.py`.
50 | - Evaluate with [eval/caption_eval.py]()
51 |
52 | ### fMRI-Text Evaluation
53 |
54 | #### NSD Caption
55 | - Download NSD data. Please refer to [docs/Data.md]().
56 | - Fill `pretrained_path` in [eval/fmri_cap_nsd.py]() and run: `python eval/fmri_cap_nsd.py`.
57 | - Evaluate with [eval/caption_eval.py]()
--------------------------------------------------------------------------------
/engine_finetune.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | import contextlib
4 |
5 | import torch
6 | import torch.distributed as dist
7 |
8 | import util.misc as misc
9 | import util.lr_sched as lr_sched
10 |
11 | def train_one_epoch(model: torch.nn.Module,
12 | data_loader, optimizer: torch.optim.Optimizer,
13 | epoch: int, start_iter, loss_scaler,
14 | log_writer=None,
15 | args=None):
16 | model.train(True)
17 | metric_logger = misc.MetricLogger(delimiter=" ")
18 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
19 | header = 'Epoch: [{}]'.format(epoch)
20 | print_freq = 10
21 |
22 | accum_iter = args.accum_iter
23 |
24 | optimizer.zero_grad()
25 |
26 | if log_writer is not None:
27 | print('log_dir: {}'.format(log_writer.log_dir))
28 | for data_iter_step, data_img in enumerate(
29 | metric_logger.log_every(data_loader, print_freq, header, start_iter), start=start_iter
30 | ):
31 | if len(data_img) == 4:
32 | examples, labels, image, modal = data_img
33 | elif len(data_img) == 3:
34 | examples, labels, modal = data_img
35 | image = None
36 | if data_iter_step % accum_iter == 0:
37 | # lr_sched.adjust_learning_rate(optimizer, data_iter_step, args)
38 | lr_sched.adjust_learning_rate_epoch(optimizer, data_iter_step / len(data_loader) + epoch, args)
39 | update_grad = (data_iter_step + 1) % accum_iter == 0
40 |
41 | autocast_ctx = {
42 | "bf16": torch.cuda.amp.autocast(dtype=torch.bfloat16),
43 | "fp16": torch.cuda.amp.autocast(dtype=torch.float16),
44 | "tf32": contextlib.nullcontext(),
45 | }[args.precision]
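            # on gradient-accumulation steps (update_grad is False), model.no_sync()
            # skips the inter-GPU gradient synchronization; gradients are only
            # all-reduced on the step that actually performs the optimizer update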
46 | backward_ctx = contextlib.nullcontext() if update_grad else model.no_sync()
47 |
48 | with autocast_ctx:
49 | i_loss = model(examples, labels, image, modal)
50 | i_loss_value = i_loss.item()
51 | if not math.isfinite(i_loss_value):
52 | print("[Rank {}] i_loss is {}, stopping training".format(dist.get_rank(), i_loss_value), force=True)
53 | # print(image_paths, force=True)
54 | sys.exit(1)
55 | loss_value = i_loss_value
56 | with backward_ctx:
57 | grad_norm = loss_scaler(
58 | i_loss / accum_iter, optimizer, model,
59 | parameters=model.parameters(),
60 | update_grad=update_grad,
61 | clip_grad=None if args.clip_grad <= 0 else args.clip_grad,
62 | )
63 | if update_grad:
64 | assert grad_norm is not None
65 | metric_logger.update(grad_norm=grad_norm)
66 |
67 | if update_grad:
68 | optimizer.zero_grad()
69 |
70 | torch.cuda.synchronize()
71 |
72 | metric_logger.update(loss=loss_value)
73 | metric_logger.update(iloss=i_loss_value)
74 |
75 | lr = optimizer.param_groups[0]["lr"]
76 | metric_logger.update(lr=lr)
77 |
78 | # save checkpoint
79 | if data_iter_step % 1000 == 0 and data_iter_step != 0:
80 | misc.save_model(
81 | output_dir=args.output_dir,
82 | args=args, epoch=epoch, iteration=data_iter_step, model=model, optimizer=optimizer,
83 | loss_scaler=loss_scaler, dataset_state=None)
84 |
85 |             if update_grad:
86 |                 loss_value_reduce = misc.all_reduce_mean(loss_value)
87 |                 i_loss_value_reduce = misc.all_reduce_mean(i_loss_value)
88 |                 grad_norm_reduce = misc.all_reduce_mean(grad_norm)
89 | 
90 |             if log_writer is not None and update_grad:
91 |                 log_writer.add_scalar('train_loss', loss_value_reduce, data_iter_step)
92 |                 log_writer.add_scalar('i_train_loss', i_loss_value_reduce, data_iter_step)
93 |                 log_writer.add_scalar('grad_norm', grad_norm_reduce, data_iter_step)
94 |                 log_writer.add_scalar('lr', lr, data_iter_step)
97 |
98 | # gather the stats from all processes
99 | metric_logger.synchronize_between_processes()
100 | print("Averaged stats:", metric_logger)
101 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
102 |
--------------------------------------------------------------------------------
/engine_pretrain.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | import contextlib
4 |
5 | import torch
6 | import torch.distributed as dist
7 |
8 | import util.misc as misc
9 | import util.lr_sched as lr_sched
10 |
11 |
12 | def train_one_epoch(model: torch.nn.Module,
13 | data_loader, optimizer: torch.optim.Optimizer,
14 | device: torch.device, epoch: int, start_iter, loss_scaler,
15 | log_writer=None,
16 | args=None):
17 | model.train(True)
18 | metric_logger = misc.MetricLogger(delimiter=" ")
19 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
20 | header = 'Epoch: [{}]'.format(epoch)
21 | print_freq = 10
22 |
23 | accum_iter = args.accum_iter
24 |
25 | optimizer.zero_grad()
26 |
27 | if log_writer is not None:
28 | print('log_dir: {}'.format(log_writer.log_dir))
29 | for data_iter_step, data_img in enumerate(
30 | metric_logger.log_every(data_loader, print_freq, header, start_iter), start=start_iter
31 | ):
32 | examples, labels, image, image_paths, modal = data_img
33 | if data_iter_step % accum_iter == 0:
34 | lr_sched.adjust_learning_rate(optimizer, data_iter_step, args)
35 | update_grad = (data_iter_step + 1) % accum_iter == 0
36 |
37 | autocast_ctx = {
38 | "bf16": torch.cuda.amp.autocast(dtype=torch.bfloat16),
39 | "fp16": torch.cuda.amp.autocast(dtype=torch.float16),
40 | "tf32": contextlib.nullcontext(),
41 | }[args.precision]
42 | backward_ctx = contextlib.nullcontext() if update_grad else model.no_sync()
43 |
44 | with autocast_ctx:
45 | i_loss = model(examples, labels, image, modal)
46 | i_loss_value = i_loss.item()
47 | if not math.isfinite(i_loss_value):
48 | print("[Rank {}] i_loss is {}, stopping training".format(dist.get_rank(), i_loss_value), force=True)
49 | print(image_paths, force=True)
50 | sys.exit(1)
51 | loss_value = i_loss_value
52 | with backward_ctx:
53 | grad_norm = loss_scaler(
54 | i_loss / accum_iter, optimizer, model,
55 | parameters=model.parameters(),
56 | update_grad=update_grad,
57 | clip_grad=None if args.clip_grad <= 0 else args.clip_grad,
58 | )
59 | if update_grad:
60 | assert grad_norm is not None
61 | metric_logger.update(grad_norm=grad_norm)
62 |
63 | if update_grad:
64 | optimizer.zero_grad()
65 |
66 | torch.cuda.synchronize()
67 |
68 | metric_logger.update(loss=loss_value)
69 | metric_logger.update(iloss=i_loss_value)
70 |
71 | lr = optimizer.param_groups[0]["lr"]
72 | metric_logger.update(lr=lr)
73 |
74 | # save checkpoint
75 | if (data_iter_step % args.save_freq == 0 and data_iter_step != 0) or data_iter_step == len(data_loader)-1:
76 | misc.save_model(
77 | output_dir=args.output_dir,
78 | args=args, epoch=epoch, iteration=data_iter_step, model=model, optimizer=optimizer,
79 | loss_scaler=loss_scaler, dataset_state=None)
80 |
81 |         if update_grad:
82 |             loss_value_reduce = misc.all_reduce_mean(loss_value)
83 |             i_loss_value_reduce = misc.all_reduce_mean(i_loss_value)
84 |             grad_norm_reduce = misc.all_reduce_mean(grad_norm)
85 | 
86 |         if log_writer is not None and update_grad:
87 |             log_writer.add_scalar('train_loss', loss_value_reduce, data_iter_step)
88 |             log_writer.add_scalar('i_train_loss', i_loss_value_reduce, data_iter_step)
89 |             log_writer.add_scalar('grad_norm', grad_norm_reduce, data_iter_step)
90 |             log_writer.add_scalar('lr', lr, data_iter_step)
93 |
94 | # gather the stats from all processes
95 | metric_logger.synchronize_between_processes()
96 | print("Averaged stats:", metric_logger)
97 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
98 |
--------------------------------------------------------------------------------
/eval/audio_cap_clothov2.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('./')
3 | import json
4 | import os
5 | import torch
6 | from model.meta import MetaModel
7 | import tqdm
8 | from torch.utils.data import Dataset, DataLoader
9 | import multiprocessing as mp
10 | from fairscale.nn.model_parallel import initialize as fs_init
11 | from util.misc import default_tensor_type
12 | from util.misc import setup_for_distributed
13 | import numpy as np
14 | import torch.distributed as dist
15 | from data.conversation_lib import conv_templates
16 | from data.data_utils import make_audio_features
17 |
18 |
19 | def load_audio(audio_path):
20 | fbank = make_audio_features(audio_path, mel_bins=128)
21 | fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024]
22 | return fbank
23 |
24 |
25 | class AudioTextDataset(Dataset):
26 | def __init__(self) -> None:
27 | super().__init__()
28 | audio_dir = "datasets/Eval/audio/clothov2/evaluation/"
29 | self.audio_anns = json.load(open("datasets/Eval/audio/clothov2/eval_clothocap_ann.json"))
30 | self.audio_ids = [x['id'] for x in self.audio_anns['images']]
31 | self.audio_names = [x['file_name'] for x in self.audio_anns['images']]
32 | self.audio_files = [os.path.join(audio_dir, x) for x in self.audio_names]
33 |
34 | def __len__(self):
35 | return len(self.audio_files)
36 |
37 | def __getitem__(self, index):
38 | audio_file = self.audio_files[index]
39 | return load_audio(audio_file), self.audio_names[index], self.audio_ids[index]
40 |
41 | if __name__ == "__main__":
42 | pretrained_path = "path/to/pretrained/ckpt/consolidated.00-of-01.pth"
43 | answer_path = "eval/results/eval_clotho.json"
44 | os.makedirs(os.path.dirname(answer_path), exist_ok=True)
45 |
46 | mp.set_start_method("spawn")
47 | dist.init_process_group(
48 | backend="nccl", rank=0, world_size=1,
49 | init_method=f"tcp://127.0.0.1:23560")
50 | fs_init.initialize_model_parallel(1)
51 | torch.cuda.set_device(0)
52 | torch.manual_seed(1)
53 | np.random.seed(1)
54 |
55 | # set the print behavior.
56 | setup_for_distributed(True)
57 |
58 | target_dtype = {
59 | "bf16": torch.bfloat16,
60 | "fp16": torch.float16
61 | }['fp16']
62 | with default_tensor_type(dtype=target_dtype, device="cuda"):
63 | model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model")
64 |
65 | print("Loading pretrained weights ...")
66 | checkpoint = torch.load(pretrained_path, map_location='cpu')
67 | msg = model.load_state_dict(checkpoint, strict=False)
68 | print("load result:\n", msg)
69 | model.half().cuda()
70 | model.eval()
71 | print(f"Model = {str(model)}")
72 |
73 | def multi_modal_generate(images, inps, modal=['image']):
74 | images = images.cuda().to(target_dtype)
75 | prompts = []
76 | for inp in inps:
77 | conv = conv_templates["v1"].copy()
78 | conv.append_message(conv.roles[0], inp)
79 | conv.append_message(conv.roles[1], None)
80 | prompts.append(conv.get_prompt())
81 |
82 | with torch.cuda.amp.autocast(dtype=target_dtype):
83 | responses = model.generate(prompts, images, 128, temperature=0.1, top_p=0.75, modal=modal)
84 | outputs = []
85 | for response, prompt in zip(responses, prompts):
86 | response = response[len(prompt):].split('###')[0]
87 | response = response.strip()
88 | outputs.append(response)
89 | return outputs
90 |
91 | dataset = AudioTextDataset()
92 | dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False)
93 |
94 | outputs = []
95 | for data in tqdm.tqdm(dataloader):
96 | audios, audio_names, audio_ids = data
97 | prompts = ['Provide a one-sentence caption for the provided audio.'] * len(audios)
98 | results = multi_modal_generate(audios, prompts, modal=['audio'])
99 |
100 | for audio_name, audio_id, result in zip(audio_names, audio_ids, results):
101 | outputs.append({
102 | 'image_id': audio_id.item(),
103 | 'caption': result.strip()
104 | })
105 |
106 | with open(answer_path, 'w') as f:
107 | json.dump(outputs, f)
--------------------------------------------------------------------------------
/eval/caption_eval.py:
--------------------------------------------------------------------------------
1 | from pycocotools.coco import COCO
2 | from pycocoevalcap.eval import COCOEvalCap
3 |
4 | # COCO Caption
5 | annotation_file = 'datasets/Eval/image/coco_cap/coco_karpathy_val_gt.json'
6 | results_file = 'eval/results/eval_cococap.json'
7 |
8 | # Nocaps Caption
9 | # annotation_file = 'datasets/Eval/image/nocaps/nocaps_val_4500_captions.json'
10 | # results_file = 'eval/results/eval_nocaps.json'
11 |
12 | # Clotho Caption
13 | # annotation_file = 'datasets/Eval/audio/clothov2/eval_cococap_ann.json'
14 | # results_file = 'eval/results/clotho_13B.json'
15 |
16 | # AVSD
17 | # annotation_file = 'datasets/Eval/video/AVSD/test_set4DSTC7-AVSD_cococap.json'
18 | # results_file = 'eval/results/eval_avsd.json'
19 |
20 | # VATEX
21 | # annotation_file = 'datasets/Eval/video/vatex/vatex_cococap.json'
22 | # results_file = 'eval/results/eval_vatex.json'
23 |
24 | # VALOR32K
25 | # annotation_file = 'datasets/Eval/video/valor32k/test_ann_cococap.json'
26 | # results_file = 'eval/results/eval_videocap_valor.json'
27 |
28 | # fMRI Caption
29 | # annotation_file = "datasets/Eval/fmri/fmri_eval_cococap.json"
30 | # results_file = "eval/results/fmricap.json"
31 |
32 | # PointLLM Caption
33 | # annotation_file = "datasets/Eval/point/pointllm/pointllm_test_cococap.json"
34 | # results_file = "eval/results/eval_pointllm_cap.json"
35 |
36 | # IMU Caption
37 | # annotation_file = 'datasets/Eval/imu/imu_2000_cococap.json'
38 | # results_file = 'eval/results/imucap.json'
39 |
40 | # create coco object and coco_result object
41 | coco = COCO(annotation_file)
42 | coco_result = coco.loadRes(results_file)
43 |
44 | # create coco_eval object by taking coco and coco_result
45 | coco_eval = COCOEvalCap(coco, coco_result)
46 |
47 | # evaluate only on the subset of images present in the results file
48 | # by restricting the image ids to those returned by coco_result.getImgIds();
49 | # remove the following line to evaluate on the full validation set
50 | coco_eval.params['image_id'] = coco_result.getImgIds()
51 |
52 | # evaluate results
53 | # SPICE will take a few minutes the first time, but speeds up due to caching
54 | coco_eval.evaluate()
55 |
56 | # print output evaluation scores
57 | for metric, score in coco_eval.eval.items():
58 | print(f'{metric}: {score:.3f}')
59 |
--------------------------------------------------------------------------------
/eval/fmri_cap_nsd.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('./')
3 | import json
4 | import os
5 | import torch
6 | from model.meta import MetaModel
7 | import tqdm
8 | from torch.utils.data import Dataset, DataLoader
9 | import multiprocessing as mp
10 | from fairscale.nn.model_parallel import initialize as fs_init
11 | from util.misc import default_tensor_type
12 | from util.misc import setup_for_distributed
13 | import numpy as np
14 | import torch.distributed as dist
15 | from data.conversation_lib import conv_templates
16 |
17 |
18 | def load_fmri(fmri_path):
19 | data = np.load(fmri_path)
20 | data = data.mean(axis=0)
21 | data = torch.tensor(data[None])
22 | return data
23 |
24 | class CaptionDataset(Dataset):
25 | def __init__(self) -> None:
26 | super().__init__()
27 | self.fmri_anns = json.load(open("datasets/Eval/fmri/fmri_eval_cococap.json"))
28 | self.fmri_ids = [x['id'] for x in self.fmri_anns['images']]
29 | self.fmri_names = [x['file_name'] for x in self.fmri_anns['images']]
30 | self.fmri_files = self.fmri_names
31 |
32 | def __len__(self):
33 | return len(self.fmri_files)
34 |
35 | def __getitem__(self, index):
36 | fmri_file = self.fmri_files[index]
37 | return load_fmri(fmri_file), self.fmri_names[index], self.fmri_ids[index]
38 |
39 |
40 | if __name__ == "__main__":
41 | pretrained_path = "path/to/pretrained/ckpt/consolidated.00-of-01.pth"
42 | answer_path = "eval/results/eval_fmricap.json"
43 | os.makedirs(os.path.dirname(answer_path), exist_ok=True)
44 |
45 | mp.set_start_method("spawn")
46 | dist.init_process_group(
47 | backend="nccl", rank=0, world_size=1,
48 | init_method=f"tcp://127.0.0.1:23560")
49 | fs_init.initialize_model_parallel(1)
50 | torch.cuda.set_device(0)
51 | torch.manual_seed(1)
52 | np.random.seed(1)
53 | # set the print behavior.
54 | setup_for_distributed(True)
55 |
56 | target_dtype = {
57 | "bf16": torch.bfloat16,
58 | "fp16": torch.float16
59 | }['fp16']
60 | with default_tensor_type(dtype=target_dtype, device="cuda"):
61 | model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model")
62 |
63 | print("Loading pretrained weights ...")
64 | checkpoint = torch.load(pretrained_path, map_location='cpu')
65 | msg = model.load_state_dict(checkpoint, strict=False)
66 | print("load result:\n", msg)
67 | model.half().cuda()
68 | model.eval()
69 | print(f"Model = {str(model)}")
70 |
71 | def multi_modal_generate(images, inps, modal=['image']):
72 | images = images.cuda().to(target_dtype)
73 |
74 | prompts = []
75 | for inp in inps:
76 | conv = conv_templates["v1"].copy()
77 | conv.append_message(conv.roles[0], inp)
78 | conv.append_message(conv.roles[1], None)
79 | prompts.append(conv.get_prompt())
80 |
81 | with torch.cuda.amp.autocast(dtype=target_dtype):
82 | responses = model.generate(prompts, images, 128, temperature=0.1, top_p=0.75, modal=modal)
83 | outputs = []
84 | for response, prompt in zip(responses, prompts):
85 | response = response[len(prompt):].split('###')[0]
86 | response = response.strip()
87 | outputs.append(response)
88 | return outputs
89 |
90 | dataset = CaptionDataset()
91 | dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False)
92 |
93 | outputs = []
94 | for data in tqdm.tqdm(dataloader):
95 | fmris, fmri_names, fmri_ids = data
96 | prompts = ['Provide a one-sentence caption for the provided fMRI data.'] * len(fmris)
97 |
98 | results = multi_modal_generate(fmris, prompts, modal=['fmri'])
99 |
100 | for fmri_name, fmri_id, result in zip(fmri_names, fmri_ids, results):
101 | outputs.append({
102 | 'image_id': fmri_id.item(),
103 | 'caption': result.strip()
104 | })
105 | print(fmri_name, fmri_id, result.strip())
106 | print('='*10)
107 |
108 | with open(answer_path, 'w') as f:
109 | json.dump(outputs, f)
110 |
--------------------------------------------------------------------------------
/eval/image_bench_mmvet.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('./')
3 | import os
4 | import json
5 | import numpy as np
6 | from tqdm import tqdm
7 | from PIL import Image
8 | import torch
9 | import torch.distributed as dist
10 | import multiprocessing as mp
11 | from fairscale.nn.model_parallel import initialize as fs_init
12 | from util.misc import default_tensor_type
13 | from util.misc import setup_for_distributed
14 | import torchvision.transforms as transforms
15 | from model.meta import MetaModel
16 | from data.conversation_lib import conv_templates
17 |
18 |
19 | T_resized_center_crop = transforms.Compose([
20 | transforms.Resize(
21 | 336, interpolation=transforms.InterpolationMode.BICUBIC
22 | ),
23 | transforms.CenterCrop(336),
24 | transforms.ToTensor(),
25 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
26 |
27 |
28 | if __name__ == "__main__":
29 | pretrained_path = "path/to/pretrained/ckpt/consolidated.00-of-01.pth"
30 | answer_path = "eval/results/eval_mmvet.json"
31 | os.makedirs(os.path.dirname(answer_path), exist_ok=True)
32 |
33 | mp.set_start_method("spawn")
34 | dist.init_process_group(
35 | backend="nccl", rank=0, world_size=1,
36 | init_method=f"tcp://127.0.0.1:23563")
37 | fs_init.initialize_model_parallel(1)
38 | torch.cuda.set_device(0)
39 | torch.manual_seed(1)
40 | np.random.seed(1)
41 | # set the print behavior.
42 | setup_for_distributed(True)
43 |
44 | target_dtype = {
45 | "bf16": torch.bfloat16,
46 | "fp16": torch.float16
47 | }['fp16']
48 | with default_tensor_type(dtype=target_dtype, device="cuda"):
49 | model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model")
50 |
51 | print("Loading pretrained weights ...")
52 | checkpoint = torch.load(pretrained_path, map_location='cpu')
53 | msg = model.load_state_dict(checkpoint, strict=False)
54 | print("load result:\n", msg)
55 | model.half().cuda()
56 | model.eval()
57 | print(f"Model = {str(model)}")
58 |
59 | def multi_modal_generate(img_path, inp):
60 |
61 | conv = conv_templates["v1"].copy()
62 | if img_path is not None:
63 | image = Image.open(img_path).convert('RGB')
64 | image = T_resized_center_crop(image).unsqueeze(0).cuda().to(target_dtype)
65 | else:
66 | image = None
67 |
68 | conv.append_message(conv.roles[0], inp)
69 | conv.append_message(conv.roles[1], None)
70 |
71 | with torch.cuda.amp.autocast(dtype=target_dtype):
72 | response = model.generate([conv.get_prompt()], image, 256, temperature=0.1, top_p=0.75, modal=['image'])
73 | response = response[0]
74 | response = response[len(conv.get_prompt()):].split('###')[0]
75 | conv.messages[-1][-1] = response
76 | return response.strip()
77 |
78 | result = {}
79 | batch_size = 1
80 | print("Starting...")
81 | datas = json.load(open('datasets/Eval/image/mm-vet/mm-vet.json'))
82 | predictions = {}
83 | with torch.no_grad():
84 | for image_name, data in tqdm(datas.items()):
85 | image_path = os.path.join('datasets/Eval/image/mm-vet/images', data['imagename'])
86 | pred = multi_modal_generate(image_path, data['question'])
87 |             predictions[image_name] = pred
88 |
89 | with open(answer_path, 'w') as f:
90 | json.dump(predictions, f)
--------------------------------------------------------------------------------
/eval/image_cap_cococap.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('./')
3 | import os
4 | import json
5 | import numpy as np
6 | from tqdm import tqdm
7 | from PIL import Image
8 | import torch
9 | import torch.distributed as dist
10 | from torch.utils.data import Dataset, DataLoader
11 | import multiprocessing as mp
12 | from fairscale.nn.model_parallel import initialize as fs_init
13 | from util.misc import default_tensor_type
14 | from util.misc import setup_for_distributed
15 | import torchvision.transforms as transforms
16 | from model.meta import MetaModel
17 | from data.conversation_lib import conv_templates
18 |
19 |
20 | T_resized_center_crop = transforms.Compose([
21 | transforms.Resize(
22 | 224, interpolation=transforms.InterpolationMode.BICUBIC
23 | ),
24 | transforms.CenterCrop(224),
25 | transforms.ToTensor(),
26 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
27 |
28 |
29 | class CocoCapDataset(Dataset):
30 | def __init__(self) -> None:
31 | super().__init__()
32 | self.datas = json.load(open('datasets/Eval/image/coco_cap/coco_karpathy_val.json'))
33 |
34 | def __len__(self):
35 | return len(self.datas)
36 |
37 | def __getitem__(self, index):
38 | data = self.datas[index]
39 | image_path = os.path.join("datasets/InstructionTuning/image/coco/", data['image'])
40 | image = Image.open(image_path).convert('RGB')
41 | image = T_resized_center_crop(image)
42 | image_id = int(data['image'].split('_')[-1].split('.')[0])
43 | question = 'Provide a one-sentence caption for the provided image.'
44 | return image, question, image_id
45 |
46 |
47 | if __name__ == "__main__":
48 | pretrained_path = "path/to/pretrained/ckpt/consolidated.00-of-01.pth"
49 | answer_path = "eval/results/eval_cococap.json"
50 | os.makedirs(os.path.dirname(answer_path), exist_ok=True)
51 |
52 | mp.set_start_method("spawn")
53 | dist.init_process_group(
54 | backend="nccl", rank=0, world_size=1,
55 | init_method=f"tcp://127.0.0.1:23560")
56 | fs_init.initialize_model_parallel(1)
57 | torch.cuda.set_device(0)
58 | torch.manual_seed(1)
59 | np.random.seed(1)
60 | # set the print behavior.
61 | setup_for_distributed(True)
62 |
63 | target_dtype = {
64 | "bf16": torch.bfloat16,
65 | "fp16": torch.float16
66 | }['fp16']
67 | with default_tensor_type(dtype=target_dtype, device="cuda"):
68 | model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model")
69 |
70 | print("Loading pretrained weights ...")
71 | checkpoint = torch.load(pretrained_path, map_location='cpu')
72 | msg = model.load_state_dict(checkpoint, strict=False)
73 | print("load result:\n", msg)
74 | model.half().cuda()
75 | model.eval()
76 | print(f"Model = {str(model)}")
77 |
78 | def multi_modal_generate(images, inps):
79 | images = images.cuda().to(target_dtype)
80 |
81 | prompts = []
82 | for inp in inps:
83 | conv = conv_templates["v1"].copy()
84 | conv.append_message(conv.roles[0], inp)
85 | conv.append_message(conv.roles[1], None)
86 | prompts.append(conv.get_prompt())
87 |
88 | with torch.cuda.amp.autocast(dtype=target_dtype):
89 | responses = model.generate(prompts, images, 128, temperature=0.1, top_p=0.75, modal=['image'])
90 | outputs = []
91 | for response, prompt in zip(responses, prompts):
92 | response = response[len(prompt):].split('###')[0]
93 | response = response.strip()
94 | outputs.append(response)
95 | return outputs
96 |
97 | print("Starting...")
98 | dataset = CocoCapDataset()
99 | dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False)
100 |
101 | predictions = []
102 | with torch.no_grad():
103 | for data in tqdm(dataloader):
104 | images, questions, image_ids = data
105 | preds = multi_modal_generate(images, questions)
106 | for question, pred, image_id in zip(questions, preds, image_ids):
107 | predictions.append({'image_id': image_id.item(), 'caption': pred})
108 |
109 | with open(answer_path, 'w') as f:
110 | json.dump(predictions, f)
--------------------------------------------------------------------------------
/eval/imu_cap_ego4d.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('./')
3 | import json
4 | import os
5 | import torch
6 | from model.meta import MetaModel
7 | import tqdm
8 | from torch.utils.data import Dataset, DataLoader
9 | import multiprocessing as mp
10 | from fairscale.nn.model_parallel import initialize as fs_init
11 | from util.misc import default_tensor_type
12 | from util.misc import setup_for_distributed
13 | import numpy as np
14 | import torch.distributed as dist
15 | from data.conversation_lib import conv_templates
16 | from data.imu_utils import get_imu_frames
17 |
18 | IMU_PATH="FILL/IMU/PATH/HERE"
19 |
20 | def load_imu(data_dict):
21 | uid = data_dict["video_uid"]
22 | w_s = data_dict["window_start"]
23 | w_e = data_dict["window_end"]
24 |
25 | imu_data = get_imu_frames(
26 | IMU_PATH, uid,
27 | video_start_sec=w_s,
28 | video_end_sec=w_e,
29 | )
30 | if imu_data is None:
31 | raise ValueError
32 | return imu_data['signal']
33 |
34 |
35 | class CaptionDataset(Dataset):
36 | def __init__(self) -> None:
37 | super().__init__()
38 | self.imu_anns = json.load(open("datasets/Eval/imu/Ego4D/imu_2000_cococap.json"))
39 | self.imu_ids = [x['id'] for x in self.imu_anns['images']]
40 | self.imu_names = [x['file_name'] for x in self.imu_anns['images']]
41 | self.imu_files = self.imu_names
42 |
43 | def __len__(self):
44 | return len(self.imu_files)
45 |
46 | def __getitem__(self, index):
47 | return load_imu(self.imu_anns['images'][index]), self.imu_names[index], self.imu_ids[index]
48 |
49 | if __name__ == "__main__":
50 | pretrained_path = "path/to/pretrained/ckpt/consolidated.00-of-01.pth"
51 | answer_path = "eval/results/eval_imucap.json"
52 | os.makedirs(os.path.dirname(answer_path), exist_ok=True)
53 |
54 | mp.set_start_method("spawn")
55 | dist.init_process_group(
56 | backend="nccl", rank=0, world_size=1,
57 | init_method=f"tcp://127.0.0.1:23560")
58 | fs_init.initialize_model_parallel(1)
59 | torch.cuda.set_device(0)
60 | torch.manual_seed(1)
61 | np.random.seed(1)
62 | # set the print behavior.
63 | setup_for_distributed(True)
64 |
65 | target_dtype = {
66 | "bf16": torch.bfloat16,
67 | "fp16": torch.float16
68 | }['fp16']
69 | with default_tensor_type(dtype=target_dtype, device="cuda"):
70 | model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model")
71 |
72 | print("Loading pretrained weights ...")
73 | checkpoint = torch.load(pretrained_path, map_location='cpu')
74 | msg = model.load_state_dict(checkpoint, strict=False)
75 | print("load result:\n", msg)
76 | model.half().cuda()
77 | model.eval()
78 | print(f"Model = {str(model)}")
79 |
80 | def multi_modal_generate(images, inps, modal=['image']):
81 | images = images.cuda().to(target_dtype)
82 |
83 | prompts = []
84 | for inp in inps:
85 | conv = conv_templates["v1"].copy()
86 | conv.append_message(conv.roles[0], inp)
87 | conv.append_message(conv.roles[1], None)
88 | prompts.append(conv.get_prompt())
89 |
90 | with torch.cuda.amp.autocast(dtype=target_dtype):
91 | responses = model.generate(prompts, images, 128, temperature=0.1, top_p=0.75, modal=modal)
92 | outputs = []
93 | for response, prompt in zip(responses, prompts):
94 | response = response[len(prompt):].split('###')[0]
95 | response = response.strip()
96 | outputs.append(response)
97 | return outputs
98 |
99 | dataset = CaptionDataset()
100 | dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False)
101 |
102 | outputs = []
103 | for data in tqdm.tqdm(dataloader):
104 | imus, imu_names, imu_ids = data
105 | prompts = ['Describe the scene.'] * len(imus)
106 |
107 | results = multi_modal_generate(imus, prompts, modal=['imu'])
108 |
109 | for imu_name, imu_id, result in zip(imu_names, imu_ids, results):
110 | outputs.append({
111 | 'image_id': imu_id.item(),
112 | 'caption': result.strip()
113 | })
114 | print(imu_name, imu_id, result.strip())
115 | print('='*10)
116 |
117 | with open(answer_path, 'w') as f:
118 | json.dump(outputs, f)
--------------------------------------------------------------------------------
/eval/point_cap_pointllm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('./')
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | import torch
7 | import torch.distributed as dist
8 | from torch.utils.data import Dataset, DataLoader
9 | import multiprocessing as mp
10 | from fairscale.nn.model_parallel import initialize as fs_init
11 | from util.misc import default_tensor_type
12 | from util.misc import setup_for_distributed
13 | import numpy as np
14 | from model.meta import MetaModel
15 | from data.conversation_lib import conv_templates
16 | from data.data_utils import pc_norm
17 |
18 |
19 | class CaptionDataset(Dataset):
20 | def __init__(self) -> None:
21 | super().__init__()
22 | self.anns = json.load(open('datasets/Eval/point/pointllm_test.json'))
23 | self.ids = list(self.anns.keys())
24 |
25 | def __len__(self):
26 | return len(self.anns)
27 |
28 | def __getitem__(self, index):
29 | id = self.ids[index]
30 | caption = self.anns[id]
31 |
32 | file_path = f'datasets/Eval/point/pointllm/8192_npy/{id}_8192.npy'
33 |
34 | point_feat = np.load(file_path)
35 | point_feat = torch.tensor(point_feat)
36 | point_feat = pc_norm(point_feat)
37 |
38 | question = 'What is this?'
39 | answer = caption
40 | return point_feat, question, id, answer
41 |
42 |
43 | if __name__ == "__main__":
44 | pretrained_path = "path/to/pretrained/ckpt/consolidated.00-of-01.pth"
45 | answer_path = "eval/results/eval_pointllm_cap.json"
46 | os.makedirs(os.path.dirname(answer_path), exist_ok=True)
47 |
48 | mp.set_start_method("spawn")
49 | dist.init_process_group(
50 | backend="nccl", rank=0, world_size=1,
51 | init_method=f"tcp://127.0.0.1:23581")
52 | fs_init.initialize_model_parallel(1)
53 | torch.cuda.set_device(0)
54 | torch.manual_seed(1)
55 | np.random.seed(1)
56 | # set the print behavior.
57 | setup_for_distributed(True)
58 |
59 | target_dtype = {
60 | "bf16": torch.bfloat16,
61 | "fp16": torch.float16
62 | }['fp16']
63 | with default_tensor_type(dtype=target_dtype, device="cuda"):
64 | model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model")
65 |
66 | print("Loading pretrained weights ...")
67 | checkpoint = torch.load(pretrained_path, map_location='cpu')
68 | msg = model.load_state_dict(checkpoint, strict=False)
69 | print("load result:\n", msg)
70 | model.half().cuda()
71 | model.eval()
72 | print(f"Model = {str(model)}")
73 |
74 | def multi_modal_generate(images, inps):
75 | images = images.cuda().to(target_dtype)
76 | prompts = []
77 | for inp in inps:
78 | conv = conv_templates["v1"].copy()
79 | conv.append_message(conv.roles[0], inp)
80 | conv.append_message(conv.roles[1], None)
81 | prompts.append(conv.get_prompt())
82 | with torch.cuda.amp.autocast(dtype=target_dtype):
83 | responses = model.generate(prompts, images, 128, temperature=0.1, top_p=0.75, modal=['point'])
84 | outputs = []
85 | for response, prompt in zip(responses, prompts):
86 | response = response[len(prompt):].split('###')[0]
87 | response = response.strip()
88 | outputs.append(response)
89 | return outputs
90 |
91 | print("Starting...")
92 | dataset = CaptionDataset()
93 | dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False)
94 |
95 | predictions = []
96 | with torch.no_grad():
97 | for data in tqdm(dataloader):
98 | images, questions, ids, answers = data
99 | preds = multi_modal_generate(images, questions)
100 |
101 | for question, pred, id, answer in zip(questions, preds, ids, answers):
102 | predictions.append({
103 | 'object_id': id,
104 | 'model_output': pred,
105 | 'ground_truth': answer
106 | })
107 |
108 | with open(answer_path, 'w') as f:
109 | json.dump(predictions, f)
110 |
--------------------------------------------------------------------------------
/eval/video_qa_msvd.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('./')
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | import torch
7 | import torch.distributed as dist
8 | from torch.utils.data import Dataset, DataLoader
9 | import multiprocessing as mp
10 | from fairscale.nn.model_parallel import initialize as fs_init
11 | from util.misc import default_tensor_type
12 | from util.misc import setup_for_distributed
13 | import numpy as np
14 | from model.meta import MetaModel
15 | from data.conversation_lib import conv_templates
16 | from data import video_utils
17 |
18 |
19 | def load_video(video_path):
20 | video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5)
21 | return video_feats[:, :, 0]
22 |
23 |
24 | class CaptionDataset(Dataset):
25 | def __init__(self) -> None:
26 | super().__init__()
27 | self.datas = json.load(open('datasets/Eval/video/MSVD/MSVD-QA/test_qa.json'))
28 |         map_ids = [x.strip().split(' ') for x in open('datasets/Eval/video/MSVD/MSVD-QA/youtube_mapping.txt').readlines()]
29 |         self.id_to_video_ids = {x[1]: x[0] for x in map_ids}
30 |
31 | def __len__(self):
32 | return len(self.datas)
33 |
34 | def __getitem__(self, index):
35 | data = self.datas[index]
36 | video_id = 'vid'+str(data['video_id'])
37 | video_name = self.id_to_video_ids[video_id] + '.avi'
38 | image_path = os.path.join("datasets/Eval/video/MSVD/YouTubeClips", video_name)
39 | image = load_video(image_path)
40 | question_id = data['id']
41 | question = data['question'] + '\nAnswer the question using a single word or phrase.'
42 | answer = data['answer']
43 | return image, question, question_id, answer
44 |
45 |
46 | if __name__ == "__main__":
47 | pretrained_path = "path/to/pretrained/ckpt/consolidated.00-of-01.pth"
48 | answer_path = "eval/results/eval_msvd.json"
49 | os.makedirs(os.path.dirname(answer_path), exist_ok=True)
50 |
51 | mp.set_start_method("spawn")
52 | dist.init_process_group(
53 | backend="nccl", rank=0, world_size=1,
54 | init_method=f"tcp://127.0.0.1:23563")
55 | fs_init.initialize_model_parallel(1)
56 | torch.cuda.set_device(0)
57 | torch.manual_seed(1)
58 | np.random.seed(1)
59 | # set the print behavior.
60 | setup_for_distributed(True)
61 |
62 | target_dtype = {
63 | "bf16": torch.bfloat16,
64 | "fp16": torch.float16
65 | }['fp16']
66 | with default_tensor_type(dtype=target_dtype, device="cuda"):
67 | model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model")
68 |
69 | print("Loading pretrained weights ...")
70 | checkpoint = torch.load(pretrained_path, map_location='cpu')
71 | msg = model.load_state_dict(checkpoint, strict=False)
72 | print("load result:\n", msg)
73 | model.half().cuda()
74 | model.eval()
75 | print(f"Model = {str(model)}")
76 |
77 | def multi_modal_generate(images, inps):
78 | images = images.cuda().to(target_dtype)
79 |
80 | prompts = []
81 | for inp in inps:
82 | conv = conv_templates["v1"].copy()
83 | conv.append_message(conv.roles[0], inp)
84 | conv.append_message(conv.roles[1], None)
85 | prompts.append(conv.get_prompt())
86 |
87 | with torch.cuda.amp.autocast(dtype=target_dtype):
88 | responses = model.generate(prompts, images, 128, temperature=0.1, top_p=0.75, modal=['video'])
89 | outputs = []
90 | for response, prompt in zip(responses, prompts):
91 | response = response[len(prompt):].split('###')[0]
92 | response = response.strip()
93 | outputs.append(response)
94 | return outputs
95 |
96 | result = {}
97 | print("Starting...")
98 | dataset = CaptionDataset()
99 | dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False)
100 | predictions = []
101 | correct = 0
102 | with torch.no_grad():
103 | for data in tqdm(dataloader):
104 | images, questions, question_ids, answers = data
105 | preds = multi_modal_generate(images, questions)
106 | for question, pred, question_id, answer in zip(questions, preds, question_ids, answers):
107 | predictions.append({'question_id': question_id.item(), 'answer': pred, 'gt_answer': answer})
108 | pred = pred.strip().lower()
109 | answer = answer.strip().lower()
110 | if (pred in answer) or (answer in pred):
111 | correct += 1
112 |
113 | acc = float(correct) / len(dataset)
114 | print('Accuracy:', acc)
115 |
116 | with open(answer_path, 'w') as f:
117 | json.dump(predictions, f)
--------------------------------------------------------------------------------
/exps/image_text_pretrain_8gpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | LLAMA_7B_PATH=""
4 | OUTPUT_DIR=""
5 |
6 | torchrun --nproc_per_node=8 main_pretrain.py \
7 | --epochs 1 --dataset image \
8 | --batch_size 40 --accum_iter 16 \
9 | --model_parallel_size 1 \
10 | --data_parallel sdp \
11 | --save_consolidated \
12 | --llama_type onellm \
13 | --llama_ckpt_dir ${LLAMA_7B_PATH} \
14 | --llama_config config/llama2/7B.json \
15 | --tokenizer_path config/llama2/tokenizer.model \
16 | --auto_resume \
17 | --weight_decay 0.1 --output_dir ${OUTPUT_DIR} \
18 | --warmup_iters 2000 --lr_decay_iters 400000 --lr 5e-5 --min_lr 5e-6 --clip_grad 2 \
19 | --save_freq 1000 \
20 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log
--------------------------------------------------------------------------------
/exps/image_text_pretrain_slurm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p {Partition Name}
3 | #SBATCH --gres=gpu:8
4 | #SBATCH -n 16
5 | #SBATCH -N 2
6 | #SBATCH --cpus-per-task=16
7 |
8 | srun python -u main_pretrain.py \
9 | --epochs 1 --dataset image \
10 | --batch_size 40 --accum_iter 8 \
11 | --model_parallel_size 1 \
12 | --data_parallel sdp \
13 | --save_consolidated \
14 | --llama_type onellm \
15 | --llama_ckpt_dir ${LLAMA_7B_PATH} \
16 | --llama_config config/llama2/7B.json \
17 | --tokenizer_path config/llama2/tokenizer.model \
18 | --auto_resume \
19 | --weight_decay 0.1 --output_dir ${OUTPUT_DIR} \
20 | --warmup_iters 2000 --lr_decay_iters 200000 --lr 5e-5 --min_lr 5e-6 --clip_grad 2 \
21 | --save_freq 1000 \
22 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log
--------------------------------------------------------------------------------
/exps/multimodal_text_finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | STAGE3_MODEL=""
4 | OUTPUT_DIR=""
5 |
6 | torchrun --nproc_per_node=8 main_finetune.py \
7 | --epochs 1 --warmup_epochs 0.05 \
8 | --datasets image audio video point rgbd rgbn imu fmri \
9 | --max_words 2048 --batch_size 4 --accum_iter 4 \
10 | --model_parallel_size 1 \
11 | --data_parallel sdp \
12 | --checkpointing --save_consolidated \
13 | --llama_type onellm \
14 | --init_from ${STAGE3_MODEL} \
15 | --auto_resume \
16 | --weight_decay 0.0 --output_dir ${OUTPUT_DIR} \
17 | --lr 2e-5 --min_lr 0.0 --clip_grad 2 \
18 | --save_interval 1 \
19 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log
20 |
--------------------------------------------------------------------------------
/exps/multimodal_text_pretrain_stage2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | LLAMA_7B_PATH=""
4 | OUTPUT_DIR=""
5 | IMAGE_TEXT_MODEL=""
6 |
7 | torchrun --nproc_per_node=8 main_pretrain.py \
8 | --epochs 1 --dataset image audio point video \
9 | --batch_size 40 --accum_iter 16 \
10 | --model_parallel_size 1 \
11 | --data_parallel sdp \
12 | --save_consolidated \
13 | --llama_type onellm \
14 | --llama_ckpt_dir ${LLAMA_7B_PATH} \
15 | --llama_config config/llama2/7B.json \
16 | --tokenizer_path config/llama2/tokenizer.model \
17 | --init_from ${IMAGE_TEXT_MODEL} \
18 | --init_from_image \
19 | --auto_resume \
20 | --weight_decay 0.1 --output_dir ${OUTPUT_DIR} \
21 | --warmup_iters 2000 --lr_decay_iters 400000 --lr 1e-5 --min_lr 5e-6 --clip_grad 2 \
22 | --save_freq 1000 \
23 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log
--------------------------------------------------------------------------------
/exps/multimodal_text_pretrain_stage3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | LLAMA_7B_PATH=""
4 | OUTPUT_DIR=""
5 | STAGE2_MODEL=""
6 |
7 | torchrun --nproc_per_node=8 main_pretrain.py \
8 | --epochs 1 --dataset image audio point video rgbd rgbn fmri imu \
9 | --batch_size 40 --accum_iter 16 \
10 | --model_parallel_size 1 \
11 | --data_parallel sdp \
12 | --save_consolidated \
13 | --llama_type onellm \
14 | --llama_ckpt_dir ${LLAMA_7B_PATH} \
15 | --llama_config config/llama2/7B.json \
16 | --tokenizer_path config/llama2/tokenizer.model \
17 | --init_from ${STAGE2_MODEL} \
18 | --init_from_image \
19 | --auto_resume \
20 | --weight_decay 0.1 --output_dir ${OUTPUT_DIR} \
21 | --warmup_iters 2000 --lr_decay_iters 200000 --lr 1e-5 --min_lr 5e-6 --clip_grad 2 \
22 | --save_freq 1000 \
23 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log
--------------------------------------------------------------------------------
/main_finetune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import json
4 | import numpy as np
5 | import os
6 | import time
7 | from pathlib import Path
8 | import functools
9 | import multiprocessing
10 |
11 | import torch
12 | import torch.backends.cudnn as cudnn
13 | from torch.utils.tensorboard import SummaryWriter
14 | from torch.distributed.fsdp import (
15 | FullyShardedDataParallel as FSDP,
16 | MixedPrecision,
17 | ShardingStrategy,
18 | )
19 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
20 | checkpoint_wrapper,
21 | CheckpointImpl,
22 | apply_activation_checkpointing,
23 | )
24 | from torch.distributed.fsdp.wrap import (
25 | transformer_auto_wrap_policy,
26 | )
27 |
28 | from fairscale.nn.model_parallel import initialize as fs_init
29 |
30 | try:
31 | from apex.optimizers import FusedAdam as AdamW
32 | except ImportError:
33 |     import warnings; warnings.warn("cannot import FusedAdam from apex, use torch AdamW instead")
34 | from torch.optim import AdamW
35 |
36 | import util.misc as misc
37 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
38 | from model.meta import MetaModel
39 | from engine_finetune import train_one_epoch
40 |
41 | import warnings
42 | warnings.filterwarnings("ignore")
43 |
44 | from data.finetune_dataset import FinetuneDialogDataset, FinetuneDistSampler
45 |
46 |
47 | def get_args_parser():
48 | parser = argparse.ArgumentParser('OneLLM Finetuning', add_help=False)
49 | parser.add_argument('--datasets', type=str, default='image', nargs='+')
50 | parser.add_argument('--epochs', default=1, type=int)
51 | parser.add_argument('--batch_size', default=64, type=int,
52 |                         help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
53 | parser.add_argument('--accum_iter', default=4, type=int,
54 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
55 |
56 | # Model parameters
57 | parser.add_argument('--llama_type', default='llama', type=str, metavar='MODEL',
58 | help='Name of model to train')
59 | parser.add_argument("--llama_ckpt_dir", type=str, default="")
60 | parser.add_argument("--llama_config", type=str, default="config/llama2/7B.json")
61 | parser.add_argument("--tokenizer_path", type=str, default="config/llama2/tokenizer.model")
62 |
63 | # Optimizer parameters
64 | parser.add_argument('--weight_decay', type=float, default=0.02,
65 |                         help='weight decay (default: 0.02)')
66 |
67 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
68 | help='learning rate (absolute lr)')
69 | parser.add_argument('--min_lr', type=float, default=0.0001, metavar='LR',
70 | help='lower lr bound for cyclic schedulers that hit 0')
71 |
72 | parser.add_argument('--warmup_epochs', type=float, default=1.0, metavar='N',
73 | help='epoch to warmup LR')
74 |
75 | parser.add_argument('--clip_grad', type=int, default=-1,
76 | help='grad clipping norm')
77 |
78 | parser.add_argument('--output_dir', default='./output_dir',
79 | help='path where to save, empty for no saving')
80 | parser.add_argument('--log_dir', default='./output_dir',
81 | help='path where to tensorboard log')
82 | parser.add_argument('--device', default='cuda',
83 | help='device to use for training / testing')
84 | parser.add_argument('--seed', default=0, type=int)
85 | parser.add_argument('--resume', default='',
86 | help='resume from checkpoint')
87 | parser.add_argument('--auto_resume', action='store_true')
88 | parser.add_argument('--init_from', default='',
89 | help='init from checkpoint')
90 | parser.add_argument('--init_from_image', action='store_true')
91 |
92 | parser.add_argument('--num_workers', default=5, type=int)
93 | parser.add_argument('--pin_mem', action='store_true',
94 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
95 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
96 | parser.set_defaults(pin_mem=True)
97 |
98 | # distributed training parameters
99 | parser.add_argument('--world_size', default=1, type=int,
100 | help='number of distributed processes')
101 | parser.add_argument('--local_rank', default=-1, type=int)
102 | parser.add_argument('--dist_on_itp', action='store_true')
103 | parser.add_argument('--dist_url', default='env://',
104 | help='url used to set up distributed training')
105 |
106 | parser.add_argument('--model_parallel_size', type=int, default=1)
107 | parser.add_argument('--data_parallel', type=str, choices=['ddp', 'sdp', 'fsdp'], default='sdp')
108 | parser.add_argument('--precision', type=str, choices=['fp16', 'bf16', 'tf32'], default='bf16')
109 | parser.add_argument('--save_interval', type=int, default=5000)
110 | parser.add_argument('--save_consolidated', action="store_true",
111 | help="save consolidated model weights along with regular checkpoints "
112 | "used to resume training. useful for convenient deployment but "
113 | "will occupy some additional disk space.")
114 | parser.add_argument("--checkpointing", action="store_true")
115 |
116 | parser.add_argument('--max_words', type=int, default=2048)
117 | parser.add_argument('--image_words', type=int, default=30)
118 |
119 | return parser
120 |
121 |
122 | def main(args):
123 | multiprocessing.set_start_method("spawn")
124 | misc.init_distributed_mode(args)
125 | fs_init.initialize_model_parallel(args.model_parallel_size)
126 | if args.precision == "tf32":
127 | torch.backends.cuda.matmul.allow_tf32 = True
128 | torch.backends.cudnn.allow_tf32 = True
129 |
130 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
131 | print("{}".format(args).replace(', ', ',\n'))
132 |
133 | device = torch.device(args.device)
134 |
135 | # fix the seed for reproducibility
136 | seed = args.seed + misc.get_rank()
137 | torch.manual_seed(seed)
138 | np.random.seed(seed)
139 |
140 | # cudnn.benchmark = True
141 |
142 | global_rank = misc.get_rank()
143 | mp_rank = fs_init.get_model_parallel_rank()
144 | mp_world_size = fs_init.get_model_parallel_world_size()
145 | dp_rank = fs_init.get_data_parallel_rank()
146 | dp_world_size = fs_init.get_data_parallel_world_size()
147 | dp_group = fs_init.get_data_parallel_group()
148 |
149 | dataset_train = FinetuneDialogDataset(args.datasets, max_words=args.max_words, image_words=args.image_words, tokenizer_path=args.tokenizer_path)
150 |
151 | if global_rank == 0 and args.log_dir is not None:
152 | os.makedirs(args.log_dir, exist_ok=True)
153 | log_writer = SummaryWriter(log_dir=args.log_dir)
154 | else:
155 | log_writer = None
156 |
157 | # define the model
158 | model = MetaModel(args.llama_type, args.llama_config, args.llama_ckpt_dir, args.tokenizer_path)
159 | model.to(device)
160 | print("Model = %s" % str(model))
161 | if args.init_from:
162 | print("Init checkpoint from %s" % args.init_from)
163 | checkpoint = torch.load(os.path.join(args.init_from, f"consolidated.{mp_rank:02d}-of-{mp_world_size:02d}.pth"), map_location='cpu')
164 | msg = model.load_state_dict(checkpoint, strict=False)
165 | print(msg)
166 |
167 | mixed_precision_dtype = {
168 | "fp16": torch.float16,
169 | "bf16": torch.bfloat16,
170 | "tf32": torch.float32,
171 | }[args.precision]
172 | TransformerBlock = type(model.llma.layers[0])
173 | model = FSDP(
174 | model,
175 | process_group=fs_init.get_data_parallel_group(),
176 | auto_wrap_policy=functools.partial(
177 | transformer_auto_wrap_policy,
178 | transformer_layer_cls=[TransformerBlock],
179 | ),
180 | limit_all_gathers=True,
181 | use_orig_params=True,
182 | sync_module_states=True,
183 | mixed_precision=MixedPrecision(
184 | param_dtype=mixed_precision_dtype,
185 | reduce_dtype=mixed_precision_dtype,
186 | buffer_dtype=mixed_precision_dtype,
187 | ),
188 | sharding_strategy={
189 | "sdp": ShardingStrategy.SHARD_GRAD_OP,
190 | "ddp": ShardingStrategy.NO_SHARD,
191 | "fsdp": ShardingStrategy.FULL_SHARD,
192 | }[args.data_parallel],
193 | ignored_parameters=[param for param in model.parameters() if not param.requires_grad],
194 | )
195 |
196 | if args.checkpointing:
197 | print("apply gradient checkpointing")
198 | non_reentrant_wrapper = functools.partial(
199 | checkpoint_wrapper,
200 | offload_to_cpu=False,
201 | checkpoint_impl=CheckpointImpl.NO_REENTRANT,
202 | )
203 | check_fn = lambda submodule: isinstance(submodule, TransformerBlock)
204 | apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)
205 |
206 | eff_batch_size = args.batch_size * args.accum_iter * fs_init.get_data_parallel_world_size()
207 | print("effective batch size: %d" % eff_batch_size)
208 |
209 | # following timm: set wd as 0 for bias and norm layers
210 | #param_groups = misc.add_weight_decay(model, args.weight_decay)
211 | param_groups = {
212 | "decay": {"params": [], "weight_decay": args.weight_decay, "lr": args.lr},
213 | "no_decay": {"params": [], "weight_decay": 0., "lr": args.lr},
214 | "scratch_decay": {"params": [], "weight_decay": args.weight_decay, "lr": args.lr},
215 | "scratch_no_decay": {"params": [], "weight_decay": 0., "lr": args.lr},
216 | }
217 | print("Making parameter groups ...")
218 | for name, param in model.named_parameters():
219 | if not param.requires_grad:
220 | continue
221 | no_decay = name.endswith(".bias") or name.endswith("norm.weight")
222 | scratch = "llma.resample_layers" in name or "llma.resample_tokens" in name
223 | group_name = ("scratch_" if scratch else "") + ("no_decay" if no_decay else "decay")
224 | print(f"{name}: in group {group_name}")
225 | param_groups[group_name]["params"].append(param)
226 | optimizer = AdamW(
227 | [param_groups[key] for key in ["decay", "no_decay", "scratch_decay", "scratch_no_decay"]],
228 | betas=(0.9, 0.95),
229 | )
230 | print(optimizer)
231 | loss_scaler = NativeScaler(args)
232 |
233 | start_epoch = 0
234 | start_iter = 0
235 | if args.resume or args.auto_resume:
236 | start_epoch, start_iter = misc.load_model(args=args, model=model, optimizer=optimizer, loss_scaler=loss_scaler)
237 |
238 | sampler_train = FinetuneDistSampler(
239 | dataset_train, num_replicas=dp_world_size, rank=dp_rank, shuffle=True, batch_size=args.batch_size,
240 | acc_grad=args.accum_iter
241 | )
242 | data_loader_train = torch.utils.data.DataLoader(
243 | dataset_train,
244 | batch_size=args.batch_size,
245 | num_workers=args.num_workers,
246 | pin_memory=args.pin_mem,
247 | sampler=sampler_train,
248 | drop_last=True
249 | )
250 |
251 | print(f"Start training for {args.epochs} epochs")
252 | start_time = time.time()
253 | for epoch in range(start_epoch, args.epochs):
254 | if args.distributed:
255 | data_loader_train.sampler.set_epoch(epoch, start_iter)
256 |
257 | train_stats = train_one_epoch(
258 | model, data_loader_train,
259 | optimizer, epoch, start_iter, loss_scaler,
260 | log_writer=log_writer,
261 | args=args
262 | )
263 |
264 | if args.output_dir and (epoch % args.save_interval == 0 or epoch + 1 == args.epochs):
265 | misc.save_model(
266 | output_dir=args.output_dir,
267 | args=args, epoch=epoch, iteration=0, model=model, optimizer=optimizer,
268 | loss_scaler=loss_scaler, dataset_state=None,
269 | )
270 |
271 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
272 | 'epoch': epoch,
273 | **{f'val_{k}': v for k, v in train_stats.items()}}
274 |
275 | if args.output_dir and misc.is_main_process():
276 | if log_writer is not None:
277 | log_writer.flush()
278 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
279 | f.write(json.dumps(log_stats) + "\n")
280 |
281 | start_iter = 0
282 |
283 | total_time = time.time() - start_time
284 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
285 | print('Training time {}'.format(total_time_str))
286 |
287 |
288 | if __name__ == '__main__':
289 | args = get_args_parser()
290 | args = args.parse_args()
291 | if args.output_dir:
292 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
293 | main(args)
294 |
--------------------------------------------------------------------------------
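A tiny sketch of the parameter-grouping rule used in main() above, run on a few hypothetical parameter names (the real names depend on the OneLLM model definition): biases and norm weights get no weight decay, and the resampler modules trained from scratch land in the scratch_* groups.

names = [
    "llma.layers.0.attention.wq.weight",            # hypothetical names for illustration
    "llma.layers.0.attention.wq.bias",
    "llma.layers.0.attention_norm.weight",
    "llma.resample_layers.0.feed_forward.w1.weight",
]
for name in names:
    no_decay = name.endswith(".bias") or name.endswith("norm.weight")
    scratch = "llma.resample_layers" in name or "llma.resample_tokens" in name
    print(name, "->", ("scratch_" if scratch else "") + ("no_decay" if no_decay else "decay"))
# -> decay, no_decay, no_decay, scratch_decay
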
/model/LLM/__init__.py:
--------------------------------------------------------------------------------
1 | from . import onellm
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csuhan/OneLLM/8587a4768cf376fb41f7d586e21de5d1ab1ca365/model/__init__.py
--------------------------------------------------------------------------------
/model/components.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import torch
3 | import torch.nn as nn
4 |
5 | try:
6 | from apex.normalization import FusedRMSNorm as RMSNorm
7 | except ImportError:
8 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
9 |
10 | class RMSNorm(torch.nn.Module):
11 | def __init__(self, dim: int, eps: float = 1e-6):
12 | """
13 | Initialize the RMSNorm normalization layer.
14 |
15 | Args:
16 | dim (int): The dimension of the input tensor.
17 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
18 |
19 | Attributes:
20 | eps (float): A small value added to the denominator for numerical stability.
21 | weight (nn.Parameter): Learnable scaling parameter.
22 |
23 | """
24 | super().__init__()
25 | self.eps = eps
26 | self.weight = nn.Parameter(torch.ones(dim))
27 |
28 | def _norm(self, x):
29 | """
30 | Apply the RMSNorm normalization to the input tensor.
31 |
32 | Args:
33 | x (torch.Tensor): The input tensor.
34 |
35 | Returns:
36 | torch.Tensor: The normalized tensor.
37 |
38 | """
39 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
40 |
41 | def forward(self, x):
42 | """
43 | Forward pass through the RMSNorm layer.
44 |
45 | Args:
46 | x (torch.Tensor): The input tensor.
47 |
48 | Returns:
49 | torch.Tensor: The output tensor after applying RMSNorm.
50 |
51 | """
52 | output = self._norm(x.float()).type_as(x)
53 | return output * self.weight
54 |
55 |
56 |
57 |
58 |
--------------------------------------------------------------------------------
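A minimal numerical sketch of what the fallback RMSNorm above computes (pure PyTorch, not the apex FusedRMSNorm kernel): each vector along the last dimension is rescaled to roughly unit root mean square and then multiplied by the learnable weight, which is all ones at initialization.

import torch

x = torch.randn(2, 4, 8)
eps = 1e-6
weight = torch.ones(8)                                   # learnable scale in the module
y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
print(y.shape)                                           # torch.Size([2, 4, 8])
print(y.pow(2).mean(-1))                                 # ~1.0 for every vector
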
/model/lib/point_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Function
4 | import pointnet2_cuda
5 |
6 | class KNN(nn.Module):
7 | def __init__(self, neighbors, transpose_mode=True):
8 | super(KNN, self).__init__()
9 | self.neighbors = neighbors
10 |
11 | @torch.no_grad()
12 | def forward(self, support, query):
13 | """
14 | Args:
15 | support ([tensor]): [B, N, C]
16 | query ([tensor]): [B, M, C]
17 | Returns:
18 |             (dist, idx): k-NN distances [B, K, M] and neighbor indices [B, M, K]
19 | """
20 | dist = torch.cdist(support, query)
21 | k_dist = dist.topk(k=self.neighbors, dim=1, largest=False)
22 | return k_dist.values, k_dist.indices.transpose(1, 2).contiguous().int()
23 |
24 |
25 | class GroupingOperation(Function):
26 |
27 | @staticmethod
28 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
29 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
30 | """
31 | :param ctx:
32 | :param features: (B, C, N) tensor of features to group
33 |         :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
34 | :return:
35 | output: (B, C, npoint, nsample) tensor
36 | """
37 | assert features.is_contiguous()
38 | assert idx.is_contiguous()
39 |
40 | B, nfeatures, nsample = idx.size()
41 | _, C, N = features.size()
42 | output = torch.cuda.FloatTensor(B, C, nfeatures, nsample, device=features.device)
43 |
44 | pointnet2_cuda.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
45 |
46 | ctx.for_backwards = (idx, N)
47 | return output
48 |
49 | @staticmethod
50 | def backward(ctx, grad_out: torch.Tensor):
51 | """
52 | :param ctx:
53 | :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
54 | :return:
55 | grad_features: (B, C, N) gradient of the features
56 | """
57 | idx, N = ctx.for_backwards
58 |
59 | B, C, npoint, nsample = grad_out.size()
60 | grad_features = torch.zeros([B, C, N], dtype=torch.float, device=grad_out.device, requires_grad=True)
61 | grad_out_data = grad_out.data.contiguous()
62 | pointnet2_cuda.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
63 | return grad_features, None
64 |
65 | grouping_operation = GroupingOperation.apply
66 |
67 |
68 | class KNNGroup(nn.Module):
69 | def __init__(self, nsample: int,
70 | relative_xyz=True,
71 | normalize_dp=False,
72 | return_only_idx=False,
73 | **kwargs
74 | ):
75 | """[summary]
76 |
77 | Args:
78 | nsample (int): maximum number of features to gather in the ball
79 | use_xyz (bool, optional): concate xyz. Defaults to True.
80 | ret_grouped_xyz (bool, optional): [description]. Defaults to False.
81 | normalize_dp (bool, optional): [description]. Defaults to False.
82 | """
83 | super().__init__()
84 | self.nsample = nsample
85 | self.knn = KNN(nsample, transpose_mode=True)
86 | self.relative_xyz = relative_xyz
87 | self.normalize_dp = normalize_dp
88 | self.return_only_idx = return_only_idx
89 |
90 | def forward(self, query_xyz: torch.Tensor, support_xyz: torch.Tensor, features: torch.Tensor = None):
91 | """
92 | :param query_xyz: (B, N, 3) xyz coordinates of the features
93 | :param support_xyz: (B, npoint, 3) centroids
94 | :param features: (B, C, N) descriptors of the features
95 | :return:
96 | new_features: (B, 3 + C, npoint, nsample)
97 | """
98 | _, idx = self.knn(support_xyz, query_xyz)
99 | if self.return_only_idx:
100 | return idx
101 | idx = idx.int()
102 | xyz_trans = support_xyz.transpose(1, 2).contiguous()
103 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
104 | if self.relative_xyz:
105 | grouped_xyz -= query_xyz.transpose(1, 2).unsqueeze(-1) # relative position
106 | if self.normalize_dp:
107 | grouped_xyz /= torch.amax(torch.sqrt(torch.sum(grouped_xyz**2, dim=1)), dim=(1, 2)).view(-1, 1, 1, 1)
108 | if features is not None:
109 | grouped_features = grouping_operation(features, idx)
110 | return grouped_xyz, grouped_features
111 | else:
112 | return grouped_xyz, None
113 |
114 |
115 | class FurthestPointSampling(Function):
116 | @staticmethod
117 | def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
118 | """
119 | Uses iterative furthest point sampling to select a set of npoint features that have the largest
120 | minimum distance
121 | :param ctx:
122 | :param xyz: (B, N, 3) where N > npoint
123 | :param npoint: int, number of features in the sampled set
124 | :return:
125 | output: (B, npoint) tensor containing the set (idx)
126 | """
127 | assert xyz.is_contiguous()
128 |
129 | B, N, _ = xyz.size()
130 | # output = torch.cuda.IntTensor(B, npoint, device=xyz.device)
131 | # temp = torch.cuda.FloatTensor(B, N, device=xyz.device).fill_(1e10)
132 | output = torch.cuda.IntTensor(B, npoint)
133 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
134 |
135 | pointnet2_cuda.furthest_point_sampling_wrapper(
136 | B, N, npoint, xyz, temp, output)
137 | return output
138 |
139 | @staticmethod
140 | def backward(xyz, a=None):
141 | return None, None
142 |
143 | furthest_point_sample = FurthestPointSampling.apply
144 |
145 |
146 | class PointPatchEmbed(nn.Module):
147 |
148 | def __init__(self,
149 | sample_ratio=0.0625,
150 | sample_number=1024,
151 | group_size=32,
152 | in_channels=6,
153 | channels=1024,
154 | kernel_size=1,
155 | stride=1,
156 | normalize_dp=False,
157 | relative_xyz=True,
158 | ):
159 | super().__init__()
160 | self.sample_ratio = sample_ratio
161 | self.sample_number = sample_number
162 | self.group_size = group_size
163 |
164 | self.sample_fn = furthest_point_sample
165 | self.grouper = KNNGroup(self.group_size, relative_xyz=relative_xyz, normalize_dp=normalize_dp)
166 |
167 | self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=kernel_size, stride=stride)
168 |
169 |
170 | def forward(self, x):
171 | # coordinates
172 | p = x[:, :, 3:].contiguous()
173 |
174 | B, N, _ = p.shape[:3]
175 | # idx = self.sample_fn(p, int(N * self.sample_ratio)).long()
176 | idx = self.sample_fn(p, self.sample_number).long()
177 | center_p = torch.gather(p, 1, idx.unsqueeze(-1).expand(-1, -1, 3))
178 | # query neighbors.
179 | _, fj = self.grouper(center_p, p, x.permute(0, 2, 1).contiguous()) # [B, N, 6] -> [B, 6, N] -> [B, 6, 1024, 32]
180 |
181 |         # [B, 6, 1024, 32] -> conv -> [B, channels, 1024, 32] -> max over neighbors -> [B, channels, 1024, 1]
182 | fj = self.conv1(fj).max(dim=-1, keepdim=True)[0]
183 |
184 | return fj
185 |
186 |
187 | if __name__ == '__main__':
188 | model = PointPatchEmbed(channels=256).cuda()
189 | input = torch.rand(4, 16384, 6).cuda()
190 | ou = model(input)
191 | import pdb;pdb.set_trace()
--------------------------------------------------------------------------------
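For intuition, a small CPU-only reference of what grouping_operation computes, written with torch.gather instead of the pointnet2_cuda extension (an illustrative equivalent, not the code path used above):

import torch

def group_points_reference(features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # features: (B, C, N), idx: (B, npoint, nsample) -> (B, C, npoint, nsample)
    B, C, N = features.shape
    _, npoint, nsample = idx.shape
    flat_idx = idx.reshape(B, 1, npoint * nsample).expand(-1, C, -1)
    grouped = torch.gather(features, 2, flat_idx.long())
    return grouped.reshape(B, C, npoint, nsample)

feats = torch.randn(2, 6, 128)
idx = torch.randint(0, 128, (2, 32, 16))
print(group_points_reference(feats, idx).shape)  # torch.Size([2, 6, 32, 16])
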
/model/lib/pointnet2/pointnet2_modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from . import pointnet2_utils
6 | from . import pytorch_utils as pt_utils
7 | from typing import List
8 |
9 |
10 | class _PointnetSAModuleBase(nn.Module):
11 |
12 | def __init__(self):
13 | super().__init__()
14 | self.npoint = None
15 | self.groupers = None
16 | self.mlps = None
17 | self.pool_method = 'max_pool'
18 |
19 | def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
20 | """
21 | :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
22 |         :param features: (B, N, C) tensor of the descriptors of the features
23 | :param new_xyz:
24 | :return:
25 | new_xyz: (B, npoint, 3) tensor of the new features' xyz
26 | new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
27 | """
28 | new_features_list = []
29 |
30 | xyz_flipped = xyz.transpose(1, 2).contiguous()
31 | if new_xyz is None:
32 | new_xyz = pointnet2_utils.gather_operation(
33 | xyz_flipped,
34 | pointnet2_utils.furthest_point_sample(xyz, self.npoint)
35 | ).transpose(1, 2).contiguous() if self.npoint is not None else None
36 |
37 | for i in range(len(self.groupers)):
38 | new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample)
39 |
40 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
41 | if self.pool_method == 'max_pool':
42 | new_features = F.max_pool2d(
43 | new_features, kernel_size=[1, new_features.size(3)]
44 | ) # (B, mlp[-1], npoint, 1)
45 | elif self.pool_method == 'avg_pool':
46 | new_features = F.avg_pool2d(
47 | new_features, kernel_size=[1, new_features.size(3)]
48 | ) # (B, mlp[-1], npoint, 1)
49 | else:
50 | raise NotImplementedError
51 |
52 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
53 | new_features_list.append(new_features)
54 |
55 | return new_xyz, torch.cat(new_features_list, dim=1)
56 |
57 |
58 | class PointnetSAModuleMSG(_PointnetSAModuleBase):
59 | """Pointnet set abstraction layer with multiscale grouping"""
60 |
61 | def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
62 | use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
63 | """
64 | :param npoint: int
65 | :param radii: list of float, list of radii to group with
66 | :param nsamples: list of int, number of samples in each ball query
67 | :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
68 | :param bn: whether to use batchnorm
69 | :param use_xyz:
70 | :param pool_method: max_pool / avg_pool
71 | :param instance_norm: whether to use instance_norm
72 | """
73 | super().__init__()
74 |
75 | assert len(radii) == len(nsamples) == len(mlps)
76 |
77 | self.npoint = npoint
78 | self.groupers = nn.ModuleList()
79 | self.mlps = nn.ModuleList()
80 | for i in range(len(radii)):
81 | radius = radii[i]
82 | nsample = nsamples[i]
83 | self.groupers.append(
84 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
85 | if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
86 | )
87 | mlp_spec = mlps[i]
88 | if use_xyz:
89 | mlp_spec[0] += 3
90 |
91 | self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
92 | self.pool_method = pool_method
93 |
94 |
95 | class PointnetSAModule(PointnetSAModuleMSG):
96 | """Pointnet set abstraction layer"""
97 |
98 | def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
99 | bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
100 | """
101 | :param mlp: list of int, spec of the pointnet before the global max_pool
102 | :param npoint: int, number of features
103 | :param radius: float, radius of ball
104 | :param nsample: int, number of samples in the ball query
105 | :param bn: whether to use batchnorm
106 | :param use_xyz:
107 | :param pool_method: max_pool / avg_pool
108 | :param instance_norm: whether to use instance_norm
109 | """
110 | super().__init__(
111 | mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
112 | pool_method=pool_method, instance_norm=instance_norm
113 | )
114 |
115 |
116 | class PointnetFPModule(nn.Module):
117 | r"""Propigates the features of one set to another"""
118 |
119 | def __init__(self, *, mlp: List[int], bn: bool = True):
120 | """
121 | :param mlp: list of int
122 | :param bn: whether to use batchnorm
123 | """
124 | super().__init__()
125 | self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
126 |
127 | def forward(
128 | self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
129 | ) -> torch.Tensor:
130 | """
131 | :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
132 | :param known: (B, m, 3) tensor of the xyz positions of the known features
133 |         :param unknow_feats: (B, C1, n) tensor of the features to be propagated to
134 |         :param known_feats: (B, C2, m) tensor of features to be propagated
135 | :return:
136 | new_features: (B, mlp[-1], n) tensor of the features of the unknown features
137 | """
138 | if known is not None:
139 | dist, idx = pointnet2_utils.three_nn(unknown, known)
140 | dist_recip = 1.0 / (dist + 1e-8)
141 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
142 | weight = dist_recip / norm
143 |
144 | interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
145 | else:
146 | interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
147 |
148 | if unknow_feats is not None:
149 | new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n)
150 | else:
151 | new_features = interpolated_feats
152 |
153 | new_features = new_features.unsqueeze(-1)
154 | new_features = self.mlp(new_features)
155 |
156 | return new_features.squeeze(-1)
157 |
158 |
159 | if __name__ == "__main__":
160 | pass
161 |
--------------------------------------------------------------------------------
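PointnetFPModule above blends the three nearest known features with inverse-distance weights. A pure-PyTorch sketch of that weighting for intuition (a CPU reference, not the three_nn / three_interpolate CUDA kernels actually used):

import torch

unknown = torch.randn(2, 64, 3)        # (B, n, 3) points that receive features
known = torch.randn(2, 16, 3)          # (B, m, 3) points that carry features
known_feats = torch.randn(2, 8, 16)    # (B, C2, m)

dist3, idx3 = torch.cdist(unknown, known).topk(3, dim=2, largest=False)
dist_recip = 1.0 / (dist3 + 1e-8)
weight = dist_recip / dist_recip.sum(dim=2, keepdim=True)                       # rows sum to 1

expanded = known_feats.unsqueeze(2).expand(-1, -1, unknown.size(1), -1)         # (B, C2, n, m)
gathered = torch.gather(expanded, 3, idx3.unsqueeze(1).expand(-1, 8, -1, -1))   # (B, C2, n, 3)
interpolated = (gathered * weight.unsqueeze(1)).sum(dim=3)
print(interpolated.shape)              # torch.Size([2, 8, 64])
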
/model/lib/pointnet2/pointnet2_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from torch.autograd import Function
4 | import torch.nn as nn
5 | from typing import Tuple
6 |
7 | import pointnet2_cuda as pointnet2
8 |
9 |
10 | class FurthestPointSampling(Function):
11 | @staticmethod
12 | def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
13 | """
14 | Uses iterative furthest point sampling to select a set of npoint features that have the largest
15 | minimum distance
16 | :param ctx:
17 | :param xyz: (B, N, 3) where N > npoint
18 | :param npoint: int, number of features in the sampled set
19 | :return:
20 | output: (B, npoint) tensor containing the set
21 | """
22 | assert xyz.is_contiguous()
23 |
24 | B, N, _ = xyz.size()
25 | output = torch.cuda.IntTensor(B, npoint)
26 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
27 |
28 | pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
29 | return output
30 |
31 | @staticmethod
32 | def backward(xyz, a=None):
33 | return None, None
34 |
35 |
36 | furthest_point_sample = FurthestPointSampling.apply
37 |
38 |
39 | class GatherOperation(Function):
40 |
41 | @staticmethod
42 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
43 | """
44 | :param ctx:
45 | :param features: (B, C, N)
46 | :param idx: (B, npoint) index tensor of the features to gather
47 | :return:
48 | output: (B, C, npoint)
49 | """
50 | assert features.is_contiguous()
51 | assert idx.is_contiguous()
52 |
53 | B, npoint = idx.size()
54 | _, C, N = features.size()
55 | output = torch.cuda.FloatTensor(B, C, npoint)
56 |
57 | pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
58 |
59 | ctx.for_backwards = (idx, C, N)
60 | return output
61 |
62 | @staticmethod
63 | def backward(ctx, grad_out):
64 | idx, C, N = ctx.for_backwards
65 | B, npoint = idx.size()
66 |
67 | grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
68 | grad_out_data = grad_out.data.contiguous()
69 | pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
70 | return grad_features, None
71 |
72 |
73 | gather_operation = GatherOperation.apply
74 |
75 |
76 | class ThreeNN(Function):
77 |
78 | @staticmethod
79 | def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
80 | """
81 | Find the three nearest neighbors of unknown in known
82 | :param ctx:
83 | :param unknown: (B, N, 3)
84 | :param known: (B, M, 3)
85 | :return:
86 | dist: (B, N, 3) l2 distance to the three nearest neighbors
87 | idx: (B, N, 3) index of 3 nearest neighbors
88 | """
89 | assert unknown.is_contiguous()
90 | assert known.is_contiguous()
91 |
92 | B, N, _ = unknown.size()
93 | m = known.size(1)
94 | dist2 = torch.cuda.FloatTensor(B, N, 3)
95 | idx = torch.cuda.IntTensor(B, N, 3)
96 |
97 | pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
98 | return torch.sqrt(dist2), idx
99 |
100 | @staticmethod
101 | def backward(ctx, a=None, b=None):
102 | return None, None
103 |
104 |
105 | three_nn = ThreeNN.apply
106 |
107 |
108 | class ThreeInterpolate(Function):
109 |
110 | @staticmethod
111 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
112 | """
113 | Performs weight linear interpolation on 3 features
114 | :param ctx:
115 | :param features: (B, C, M) Features descriptors to be interpolated from
116 | :param idx: (B, n, 3) three nearest neighbors of the target features in features
117 | :param weight: (B, n, 3) weights
118 | :return:
119 | output: (B, C, N) tensor of the interpolated features
120 | """
121 | assert features.is_contiguous()
122 | assert idx.is_contiguous()
123 | assert weight.is_contiguous()
124 |
125 | B, c, m = features.size()
126 | n = idx.size(1)
127 | ctx.three_interpolate_for_backward = (idx, weight, m)
128 | output = torch.cuda.FloatTensor(B, c, n)
129 |
130 | pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
131 | return output
132 |
133 | @staticmethod
134 | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
135 | """
136 | :param ctx:
137 | :param grad_out: (B, C, N) tensor with gradients of outputs
138 | :return:
139 | grad_features: (B, C, M) tensor with gradients of features
140 | None:
141 | None:
142 | """
143 | idx, weight, m = ctx.three_interpolate_for_backward
144 | B, c, n = grad_out.size()
145 |
146 | grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
147 | grad_out_data = grad_out.data.contiguous()
148 |
149 | pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
150 | return grad_features, None, None
151 |
152 |
153 | three_interpolate = ThreeInterpolate.apply
154 |
155 |
156 | class GroupingOperation(Function):
157 |
158 | @staticmethod
159 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
160 | """
161 | :param ctx:
162 | :param features: (B, C, N) tensor of features to group
163 |         :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
164 | :return:
165 | output: (B, C, npoint, nsample) tensor
166 | """
167 | assert features.is_contiguous()
168 | assert idx.is_contiguous()
169 |
170 | B, nfeatures, nsample = idx.size()
171 | _, C, N = features.size()
172 | output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
173 |
174 | pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
175 |
176 | ctx.for_backwards = (idx, N)
177 | return output
178 |
179 | @staticmethod
180 | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
181 | """
182 | :param ctx:
183 | :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
184 | :return:
185 | grad_features: (B, C, N) gradient of the features
186 | """
187 | idx, N = ctx.for_backwards
188 |
189 | B, C, npoint, nsample = grad_out.size()
190 | grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
191 |
192 | grad_out_data = grad_out.data.contiguous()
193 | pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
194 | return grad_features, None
195 |
196 |
197 | grouping_operation = GroupingOperation.apply
198 |
199 |
200 | class BallQuery(Function):
201 |
202 | @staticmethod
203 | def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
204 | """
205 | :param ctx:
206 | :param radius: float, radius of the balls
207 | :param nsample: int, maximum number of features in the balls
208 | :param xyz: (B, N, 3) xyz coordinates of the features
209 | :param new_xyz: (B, npoint, 3) centers of the ball query
210 | :return:
211 |             idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
212 | """
213 | assert new_xyz.is_contiguous()
214 | assert xyz.is_contiguous()
215 |
216 | B, N, _ = xyz.size()
217 | npoint = new_xyz.size(1)
218 | idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
219 |
220 | pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
221 | return idx
222 |
223 | @staticmethod
224 | def backward(ctx, a=None):
225 | return None, None, None, None
226 |
227 |
228 | ball_query = BallQuery.apply
229 |
230 |
231 | class QueryAndGroup(nn.Module):
232 | def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
233 | """
234 | :param radius: float, radius of ball
235 | :param nsample: int, maximum number of features to gather in the ball
236 | :param use_xyz:
237 | """
238 | super().__init__()
239 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
240 |
241 | def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
242 | """
243 | :param xyz: (B, N, 3) xyz coordinates of the features
244 | :param new_xyz: (B, npoint, 3) centroids
245 | :param features: (B, C, N) descriptors of the features
246 | :return:
247 | new_features: (B, 3 + C, npoint, nsample)
248 | """
249 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
250 | xyz_trans = xyz.transpose(1, 2).contiguous()
251 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
252 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
253 |
254 | if features is not None:
255 | grouped_features = grouping_operation(features, idx)
256 | if self.use_xyz:
257 | new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, npoint, nsample)
258 | else:
259 | new_features = grouped_features
260 | else:
261 |             assert self.use_xyz, "Cannot have features=None and use_xyz=False at the same time!"
262 | new_features = grouped_xyz
263 |
264 | return new_features
265 |
266 |
267 | class GroupAll(nn.Module):
268 | def __init__(self, use_xyz: bool = True):
269 | super().__init__()
270 | self.use_xyz = use_xyz
271 |
272 | def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
273 | """
274 | :param xyz: (B, N, 3) xyz coordinates of the features
275 | :param new_xyz: ignored
276 | :param features: (B, C, N) descriptors of the features
277 | :return:
278 | new_features: (B, C + 3, 1, N)
279 | """
280 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
281 | if features is not None:
282 | grouped_features = features.unsqueeze(2)
283 | if self.use_xyz:
284 | new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, 3 + C, 1, N)
285 | else:
286 | new_features = grouped_features
287 | else:
288 | new_features = grouped_xyz
289 |
290 | return new_features
291 |
--------------------------------------------------------------------------------
/model/lib/pointnet2/pytorch_utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from typing import List, Tuple
3 |
4 |
5 | class SharedMLP(nn.Sequential):
6 |
7 | def __init__(
8 | self,
9 | args: List[int],
10 | *,
11 | bn: bool = False,
12 | activation=nn.ReLU(inplace=True),
13 | preact: bool = False,
14 | first: bool = False,
15 | name: str = "",
16 | instance_norm: bool = False,
17 | ):
18 | super().__init__()
19 |
20 | for i in range(len(args) - 1):
21 | self.add_module(
22 | name + 'layer{}'.format(i),
23 | Conv2d(
24 | args[i],
25 | args[i + 1],
26 | bn=(not first or not preact or (i != 0)) and bn,
27 | activation=activation
28 | if (not first or not preact or (i != 0)) else None,
29 | preact=preact,
30 | instance_norm=instance_norm
31 | )
32 | )
33 |
34 |
35 | class _ConvBase(nn.Sequential):
36 |
37 | def __init__(
38 | self,
39 | in_size,
40 | out_size,
41 | kernel_size,
42 | stride,
43 | padding,
44 | activation,
45 | bn,
46 | init,
47 | conv=None,
48 | batch_norm=None,
49 | bias=True,
50 | preact=False,
51 | name="",
52 | instance_norm=False,
53 | instance_norm_func=None
54 | ):
55 | super().__init__()
56 |
57 | bias = bias and (not bn)
58 | conv_unit = conv(
59 | in_size,
60 | out_size,
61 | kernel_size=kernel_size,
62 | stride=stride,
63 | padding=padding,
64 | bias=bias
65 | )
66 | init(conv_unit.weight)
67 | if bias:
68 | nn.init.constant_(conv_unit.bias, 0)
69 |
70 | if bn:
71 | if not preact:
72 | bn_unit = batch_norm(out_size)
73 | else:
74 | bn_unit = batch_norm(in_size)
75 | if instance_norm:
76 | if not preact:
77 | in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
78 | else:
79 | in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)
80 |
81 | if preact:
82 | if bn:
83 | self.add_module(name + 'bn', bn_unit)
84 |
85 | if activation is not None:
86 | self.add_module(name + 'activation', activation)
87 |
88 | if not bn and instance_norm:
89 | self.add_module(name + 'in', in_unit)
90 |
91 | self.add_module(name + 'conv', conv_unit)
92 |
93 | if not preact:
94 | if bn:
95 | self.add_module(name + 'bn', bn_unit)
96 |
97 | if activation is not None:
98 | self.add_module(name + 'activation', activation)
99 |
100 | if not bn and instance_norm:
101 | self.add_module(name + 'in', in_unit)
102 |
103 |
104 | class _BNBase(nn.Sequential):
105 |
106 | def __init__(self, in_size, batch_norm=None, name=""):
107 | super().__init__()
108 | self.add_module(name + "bn", batch_norm(in_size))
109 |
110 | nn.init.constant_(self[0].weight, 1.0)
111 | nn.init.constant_(self[0].bias, 0)
112 |
113 |
114 | class BatchNorm1d(_BNBase):
115 |
116 | def __init__(self, in_size: int, *, name: str = ""):
117 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
118 |
119 |
120 | class BatchNorm2d(_BNBase):
121 |
122 | def __init__(self, in_size: int, name: str = ""):
123 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
124 |
125 |
126 | class Conv1d(_ConvBase):
127 |
128 | def __init__(
129 | self,
130 | in_size: int,
131 | out_size: int,
132 | *,
133 | kernel_size: int = 1,
134 | stride: int = 1,
135 | padding: int = 0,
136 | activation=nn.ReLU(inplace=True),
137 | bn: bool = False,
138 | init=nn.init.kaiming_normal_,
139 | bias: bool = True,
140 | preact: bool = False,
141 | name: str = "",
142 | instance_norm=False
143 | ):
144 | super().__init__(
145 | in_size,
146 | out_size,
147 | kernel_size,
148 | stride,
149 | padding,
150 | activation,
151 | bn,
152 | init,
153 | conv=nn.Conv1d,
154 | batch_norm=BatchNorm1d,
155 | bias=bias,
156 | preact=preact,
157 | name=name,
158 | instance_norm=instance_norm,
159 | instance_norm_func=nn.InstanceNorm1d
160 | )
161 |
162 |
163 | class Conv2d(_ConvBase):
164 |
165 | def __init__(
166 | self,
167 | in_size: int,
168 | out_size: int,
169 | *,
170 | kernel_size: Tuple[int, int] = (1, 1),
171 | stride: Tuple[int, int] = (1, 1),
172 | padding: Tuple[int, int] = (0, 0),
173 | activation=nn.ReLU(inplace=True),
174 | bn: bool = False,
175 | init=nn.init.kaiming_normal_,
176 | bias: bool = True,
177 | preact: bool = False,
178 | name: str = "",
179 | instance_norm=False
180 | ):
181 | super().__init__(
182 | in_size,
183 | out_size,
184 | kernel_size,
185 | stride,
186 | padding,
187 | activation,
188 | bn,
189 | init,
190 | conv=nn.Conv2d,
191 | batch_norm=BatchNorm2d,
192 | bias=bias,
193 | preact=preact,
194 | name=name,
195 | instance_norm=instance_norm,
196 | instance_norm_func=nn.InstanceNorm2d
197 | )
198 |
199 |
200 | class FC(nn.Sequential):
201 |
202 | def __init__(
203 | self,
204 | in_size: int,
205 | out_size: int,
206 | *,
207 | activation=nn.ReLU(inplace=True),
208 | bn: bool = False,
209 | init=None,
210 | preact: bool = False,
211 | name: str = ""
212 | ):
213 | super().__init__()
214 |
215 | fc = nn.Linear(in_size, out_size, bias=not bn)
216 | if init is not None:
217 | init(fc.weight)
218 | if not bn:
219 |             nn.init.constant_(fc.bias, 0)
220 |
221 | if preact:
222 | if bn:
223 | self.add_module(name + 'bn', BatchNorm1d(in_size))
224 |
225 | if activation is not None:
226 | self.add_module(name + 'activation', activation)
227 |
228 | self.add_module(name + 'fc', fc)
229 |
230 | if not preact:
231 | if bn:
232 | self.add_module(name + 'bn', BatchNorm1d(out_size))
233 |
234 | if activation is not None:
235 | self.add_module(name + 'activation', activation)
236 |
237 |
--------------------------------------------------------------------------------
/model/lib/pointnet2/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 | setup(
5 | name='pointnet2',
6 | ext_modules=[
7 | CUDAExtension('pointnet2_cuda', [
8 | 'src/pointnet2_api.cpp',
9 |
10 | 'src/ball_query.cpp',
11 | 'src/ball_query_gpu.cu',
12 | 'src/group_points.cpp',
13 | 'src/group_points_gpu.cu',
14 | 'src/interpolate.cpp',
15 | 'src/interpolate_gpu.cu',
16 | 'src/sampling.cpp',
17 | 'src/sampling_gpu.cu',
18 | ],
19 | extra_compile_args={'cxx': ['-g'],
20 | 'nvcc': ['-O2']})
21 | ],
22 | cmdclass={'build_ext': BuildExtension}
23 | )
24 |
--------------------------------------------------------------------------------
/model/lib/pointnet2/src/ball_query.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include "ball_query_gpu.h"
8 |
9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
11 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
12 |
13 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
14 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
15 | CHECK_INPUT(new_xyz_tensor);
16 | CHECK_INPUT(xyz_tensor);
17 |     const float *new_xyz = new_xyz_tensor.data<float>();
18 |     const float *xyz = xyz_tensor.data<float>();
19 |     int *idx = idx_tensor.data<int>();
20 |
21 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
22 | ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream);
23 | return 1;
24 | }
25 |
--------------------------------------------------------------------------------
/model/lib/pointnet2/src/ball_query_gpu.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include "ball_query_gpu.h"
6 | #include "cuda_utils.h"
7 |
8 |
9 | __global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample,
10 | const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
11 | // new_xyz: (B, M, 3)
12 | // xyz: (B, N, 3)
13 | // output:
14 | // idx: (B, M, nsample)
15 | int bs_idx = blockIdx.y;
16 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
17 | if (bs_idx >= b || pt_idx >= m) return;
18 |
19 | new_xyz += bs_idx * m * 3 + pt_idx * 3;
20 | xyz += bs_idx * n * 3;
21 | idx += bs_idx * m * nsample + pt_idx * nsample;
22 |
23 | float radius2 = radius * radius;
24 | float new_x = new_xyz[0];
25 | float new_y = new_xyz[1];
26 | float new_z = new_xyz[2];
27 |
28 | int cnt = 0;
29 | for (int k = 0; k < n; ++k) {
30 | float x = xyz[k * 3 + 0];
31 | float y = xyz[k * 3 + 1];
32 | float z = xyz[k * 3 + 2];
33 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
34 | if (d2 < radius2){
35 | if (cnt == 0){
36 | for (int l = 0; l < nsample; ++l) {
37 | idx[l] = k;
38 | }
39 | }
40 | idx[cnt] = k;
41 | ++cnt;
42 | if (cnt >= nsample) break;
43 | }
44 | }
45 | }
46 |
47 |
48 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \
49 | const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) {
50 | // new_xyz: (B, M, 3)
51 | // xyz: (B, N, 3)
52 | // output:
53 | // idx: (B, M, nsample)
54 |
55 | cudaError_t err;
56 |
57 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
58 | dim3 threads(THREADS_PER_BLOCK);
59 |
60 |     ball_query_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
61 | // cudaDeviceSynchronize(); // for using printf in kernel function
62 | err = cudaGetLastError();
63 | if (cudaSuccess != err) {
64 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
65 | exit(-1);
66 | }
67 | }
--------------------------------------------------------------------------------
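A rough pure-PyTorch reference of what this kernel returns, for intuition only (a hypothetical helper, not part of the repo): each center collects up to nsample point indices within radius; unfilled slots are padded with the first in-range index, and centers with no neighbors keep index 0, matching the zero-initialized output allocated in BallQuery.forward.

import torch

def ball_query_reference(radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
    # xyz: (B, N, 3), new_xyz: (B, M, 3) -> idx: (B, M, nsample)
    dist2 = torch.cdist(new_xyz, xyz) ** 2
    B, M, _ = dist2.shape
    idx = torch.zeros(B, M, nsample, dtype=torch.long)
    for b in range(B):
        for m in range(M):
            hits = torch.nonzero(dist2[b, m] < radius ** 2, as_tuple=False).flatten()
            if hits.numel() == 0:
                continue                     # no point in range: row stays all zeros
            take = hits[:nsample]
            idx[b, m] = take[0]              # pad every slot with the first hit
            idx[b, m, :take.numel()] = take  # then overwrite the filled prefix
    return idx

idx = ball_query_reference(0.5, 8, torch.rand(1, 256, 3), torch.rand(1, 32, 3))
print(idx.shape)                             # torch.Size([1, 32, 8])
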
/model/lib/pointnet2/src/ball_query_gpu.h:
--------------------------------------------------------------------------------
1 | #ifndef _BALL_QUERY_GPU_H
2 | #define _BALL_QUERY_GPU_H
3 |
4 | #include
5 | #include
6 | #include
7 | #include
8 |
9 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
10 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);
11 |
12 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
13 | const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream);
14 |
15 | #endif
16 |
--------------------------------------------------------------------------------
/model/lib/pointnet2/src/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef _CUDA_UTILS_H
2 | #define _CUDA_UTILS_H
3 |
4 | #include
5 |
6 | #define TOTAL_THREADS 1024
7 | #define THREADS_PER_BLOCK 256
8 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
9 |
10 | inline int opt_n_threads(int work_size) {
11 |     const int pow_2 = std::log(static_cast<float>(work_size)) / std::log(2.0);
12 |
13 | return max(min(1 << pow_2, TOTAL_THREADS), 1);
14 | }
15 | #endif
16 |
--------------------------------------------------------------------------------
/model/lib/pointnet2/src/group_points.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include "group_points_gpu.h"
6 | #include
7 | #include
8 |
9 |
10 |
11 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
12 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
13 |
14 |     float *grad_points = grad_points_tensor.data<float>();
15 |     const int *idx = idx_tensor.data<int>();
16 |     const float *grad_out = grad_out_tensor.data<float>();
17 |
18 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
19 | group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream);
20 | return 1;
21 | }
22 |
23 |
24 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
25 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) {
26 |
27 |     const float *points = points_tensor.data<float>();
28 |     const int *idx = idx_tensor.data<int>();
29 |     float *out = out_tensor.data<float>();
30 |
31 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
32 | group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream);
33 | return 1;
34 | }
35 |
--------------------------------------------------------------------------------
/model/lib/pointnet2/src/group_points_gpu.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include "cuda_utils.h"
5 | #include "group_points_gpu.h"
6 |
7 |
8 | __global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample,
9 | const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) {
10 | // grad_out: (B, C, npoints, nsample)
11 | // idx: (B, npoints, nsample)
12 | // output:
13 | // grad_points: (B, C, N)
14 | int bs_idx = blockIdx.z;
15 | int c_idx = blockIdx.y;
16 | int index = blockIdx.x * blockDim.x + threadIdx.x;
17 | int pt_idx = index / nsample;
18 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
19 |
20 | int sample_idx = index % nsample;
21 | grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
22 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
23 |
24 | atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]);
25 | }
26 |
27 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
28 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
29 | // grad_out: (B, C, npoints, nsample)
30 | // idx: (B, npoints, nsample)
31 | // output:
32 | // grad_points: (B, C, N)
33 | cudaError_t err;
34 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
35 | dim3 threads(THREADS_PER_BLOCK);
36 |
37 |     group_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, grad_out, idx, grad_points);
38 |
39 | err = cudaGetLastError();
40 | if (cudaSuccess != err) {
41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
42 | exit(-1);
43 | }
44 | }
45 |
46 |
47 | __global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample,
48 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
49 | // points: (B, C, N)
50 | // idx: (B, npoints, nsample)
51 | // output:
52 | // out: (B, C, npoints, nsample)
53 | int bs_idx = blockIdx.z;
54 | int c_idx = blockIdx.y;
55 | int index = blockIdx.x * blockDim.x + threadIdx.x;
56 | int pt_idx = index / nsample;
57 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
58 |
59 | int sample_idx = index % nsample;
60 |
61 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
62 | int in_idx = bs_idx * c * n + c_idx * n + idx[0];
63 | int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
64 |
65 | out[out_idx] = points[in_idx];
66 | }
67 |
68 |
69 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
70 | const float *points, const int *idx, float *out, cudaStream_t stream) {
71 | // points: (B, C, N)
72 | // idx: (B, npoints, nsample)
73 | // output:
74 | // out: (B, C, npoints, nsample)
75 | cudaError_t err;
76 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
77 | dim3 threads(THREADS_PER_BLOCK);
78 |
79 |     group_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, points, idx, out);
80 | // cudaDeviceSynchronize(); // for using printf in kernel function
81 | err = cudaGetLastError();
82 | if (cudaSuccess != err) {
83 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
84 | exit(-1);
85 | }
86 | }
87 |
--------------------------------------------------------------------------------
/model/lib/pointnet2/src/group_points_gpu.h:
--------------------------------------------------------------------------------
1 | #ifndef _GROUP_POINTS_GPU_H
2 | #define _GROUP_POINTS_GPU_H
3 |
4 | #include
5 | #include
6 | #include
7 | #include
8 |
9 |
10 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
11 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
12 |
13 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
14 | const float *points, const int *idx, float *out, cudaStream_t stream);
15 |
16 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
18 |
19 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);
21 |
22 | #endif
23 |
--------------------------------------------------------------------------------
/model/lib/pointnet2/src/interpolate.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include "interpolate_gpu.h"
9 | #include
10 | #include
11 |
12 |
13 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
14 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
15 |     const float *unknown = unknown_tensor.data<float>();
16 |     const float *known = known_tensor.data<float>();
17 |     float *dist2 = dist2_tensor.data<float>();
18 |     int *idx = idx_tensor.data<int>();
19 |
20 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
21 | three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream);
22 | }
23 |
24 |
25 | void three_interpolate_wrapper_fast(int b, int c, int m, int n,
26 | at::Tensor points_tensor,
27 | at::Tensor idx_tensor,
28 | at::Tensor weight_tensor,
29 | at::Tensor out_tensor) {
30 |
31 |     const float *points = points_tensor.data_ptr<float>();
32 |     const float *weight = weight_tensor.data_ptr<float>();
33 |     float *out = out_tensor.data_ptr<float>();
34 |     const int *idx = idx_tensor.data_ptr<int>();
35 |
36 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
37 | three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream);
38 | }
39 |
40 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
41 | at::Tensor grad_out_tensor,
42 | at::Tensor idx_tensor,
43 | at::Tensor weight_tensor,
44 | at::Tensor grad_points_tensor) {
45 |
46 |     const float *grad_out = grad_out_tensor.data_ptr<float>();
47 |     const float *weight = weight_tensor.data_ptr<float>();
48 |     float *grad_points = grad_points_tensor.data_ptr<float>();
49 |     const int *idx = idx_tensor.data_ptr<int>();
50 |
51 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
52 | three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream);
53 | }
54 |
--------------------------------------------------------------------------------
/model/lib/pointnet2/src/interpolate_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 |
5 | #include "cuda_utils.h"
6 | #include "interpolate_gpu.h"
7 |
8 |
9 | __global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown,
10 | const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {
11 | // unknown: (B, N, 3)
12 | // known: (B, M, 3)
13 | // output:
14 | // dist2: (B, N, 3)
15 | // idx: (B, N, 3)
16 |
17 | int bs_idx = blockIdx.y;
18 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
19 | if (bs_idx >= b || pt_idx >= n) return;
20 |
21 | unknown += bs_idx * n * 3 + pt_idx * 3;
22 | known += bs_idx * m * 3;
23 | dist2 += bs_idx * n * 3 + pt_idx * 3;
24 | idx += bs_idx * n * 3 + pt_idx * 3;
25 |
26 | float ux = unknown[0];
27 | float uy = unknown[1];
28 | float uz = unknown[2];
29 |
30 | double best1 = 1e40, best2 = 1e40, best3 = 1e40;
31 | int besti1 = 0, besti2 = 0, besti3 = 0;
32 | for (int k = 0; k < m; ++k) {
33 | float x = known[k * 3 + 0];
34 | float y = known[k * 3 + 1];
35 | float z = known[k * 3 + 2];
36 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
37 | if (d < best1) {
38 | best3 = best2; besti3 = besti2;
39 | best2 = best1; besti2 = besti1;
40 | best1 = d; besti1 = k;
41 | }
42 | else if (d < best2) {
43 | best3 = best2; besti3 = besti2;
44 | best2 = d; besti2 = k;
45 | }
46 | else if (d < best3) {
47 | best3 = d; besti3 = k;
48 | }
49 | }
50 | dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
51 | idx[0] = besti1; idx[1] = besti2; idx[2] = besti3;
52 | }
53 |
54 |
55 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
56 | const float *known, float *dist2, int *idx, cudaStream_t stream) {
57 | // unknown: (B, N, 3)
58 | // known: (B, M, 3)
59 | // output:
60 | // dist2: (B, N, 3)
61 | // idx: (B, N, 3)
62 |
63 | cudaError_t err;
64 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
65 | dim3 threads(THREADS_PER_BLOCK);
66 |
67 |     three_nn_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, unknown, known, dist2, idx);
68 |
69 | err = cudaGetLastError();
70 | if (cudaSuccess != err) {
71 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
72 | exit(-1);
73 | }
74 | }
75 |
76 |
77 | __global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points,
78 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) {
79 | // points: (B, C, M)
80 | // idx: (B, N, 3)
81 | // weight: (B, N, 3)
82 | // output:
83 | // out: (B, C, N)
84 |
85 | int bs_idx = blockIdx.z;
86 | int c_idx = blockIdx.y;
87 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
88 |
89 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
90 |
91 | weight += bs_idx * n * 3 + pt_idx * 3;
92 | points += bs_idx * c * m + c_idx * m;
93 | idx += bs_idx * n * 3 + pt_idx * 3;
94 | out += bs_idx * c * n + c_idx * n;
95 |
96 | out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]];
97 | }
98 |
99 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
100 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) {
101 | // points: (B, C, M)
102 | // idx: (B, N, 3)
103 | // weight: (B, N, 3)
104 | // output:
105 | // out: (B, C, N)
106 |
107 | cudaError_t err;
108 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
109 | dim3 threads(THREADS_PER_BLOCK);
110 |     three_interpolate_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, m, n, points, idx, weight, out);
111 |
112 | err = cudaGetLastError();
113 | if (cudaSuccess != err) {
114 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
115 | exit(-1);
116 | }
117 | }
118 |
119 |
120 | __global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
121 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) {
122 | // grad_out: (B, C, N)
123 | // weight: (B, N, 3)
124 | // output:
125 | // grad_points: (B, C, M)
126 |
127 | int bs_idx = blockIdx.z;
128 | int c_idx = blockIdx.y;
129 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
130 |
131 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
132 |
133 | grad_out += bs_idx * c * n + c_idx * n + pt_idx;
134 | weight += bs_idx * n * 3 + pt_idx * 3;
135 | grad_points += bs_idx * c * m + c_idx * m;
136 | idx += bs_idx * n * 3 + pt_idx * 3;
137 |
138 |
139 | atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
140 | atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
141 | atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
142 | }
143 |
144 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
145 | const int *idx, const float *weight, float *grad_points, cudaStream_t stream) {
146 | // grad_out: (B, C, N)
147 | // weight: (B, N, 3)
148 | // output:
149 | // grad_points: (B, C, M)
150 |
151 | cudaError_t err;
152 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
153 | dim3 threads(THREADS_PER_BLOCK);
154 |     three_interpolate_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, m, grad_out, idx, weight, grad_points);
155 |
156 | err = cudaGetLastError();
157 | if (cudaSuccess != err) {
158 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
159 | exit(-1);
160 | }
161 | }
--------------------------------------------------------------------------------
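The kernels above implement 3-nearest-neighbour search followed by weighted feature interpolation: for every unknown point, three_nn returns the squared distances and indices of its three closest known points, and three_interpolate blends the corresponding feature columns with the supplied weights. A plain-PyTorch reference for checking both ops (function names here are hypothetical, not part of this repo):

    import torch

    def three_nn_ref(unknown: torch.Tensor, known: torch.Tensor):
        # unknown: (B, N, 3), known: (B, M, 3) -> dist2: (B, N, 3), idx: (B, N, 3)
        dist2 = torch.cdist(unknown, known) ** 2                      # (B, N, M) squared distances
        dist2, idx = torch.topk(dist2, k=3, dim=-1, largest=False)    # three smallest, ascending
        return dist2, idx

    def three_interpolate_ref(points: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor):
        # points: (B, C, M), idx/weight: (B, N, 3) -> out: (B, C, N)
        B, C, M = points.shape
        N = idx.shape[1]
        idx_expand = idx.reshape(B, 1, N * 3).expand(B, C, -1).long()
        gathered = torch.gather(points, 2, idx_expand).view(B, C, N, 3)
        return (gathered * weight.view(B, 1, N, 3)).sum(-1)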
/model/lib/pointnet2/src/interpolate_gpu.h:
--------------------------------------------------------------------------------
1 | #ifndef _INTERPOLATE_GPU_H
2 | #define _INTERPOLATE_GPU_H
3 |
4 | #include <torch/serialize/tensor.h>
5 | #include <vector>
6 | #include <cuda.h>
7 | #include <cuda_runtime_api.h>
8 |
9 |
10 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
11 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);
12 |
13 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
14 | const float *known, float *dist2, int *idx, cudaStream_t stream);
15 |
16 |
17 | void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor,
18 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);
19 |
20 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
21 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream);
22 |
23 |
24 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor,
25 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor);
26 |
27 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
28 | const int *idx, const float *weight, float *grad_points, cudaStream_t stream);
29 |
30 | #endif
31 |
--------------------------------------------------------------------------------
/model/lib/pointnet2/src/pointnet2_api.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/serialize/tensor.h>
2 | #include <torch/extension.h>
3 |
4 | #include "ball_query_gpu.h"
5 | #include "group_points_gpu.h"
6 | #include "sampling_gpu.h"
7 | #include "interpolate_gpu.h"
8 |
9 |
10 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
11 | m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast");
12 |
13 | m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast");
14 | m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast");
15 |
16 | m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
17 | m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");
18 |
19 | m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
20 |
21 | m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
22 | m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
23 | m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast");
24 | }
25 |
--------------------------------------------------------------------------------
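The pybind module above can also be JIT-compiled for quick experiments instead of going through setup.py; the sketch below uses torch.utils.cpp_extension.load with an assumed module name and source glob, and exercises the furthest-point-sampling wrapper end to end. Treat it as an illustration of the binding surface, not this repo's build procedure.

    import glob
    import torch
    from torch.utils.cpp_extension import load

    # JIT-build the sources; name and paths are assumptions for illustration.
    _ext = load(
        name="pointnet2_cuda_demo",
        sources=glob.glob("model/lib/pointnet2/src/*.cpp") + glob.glob("model/lib/pointnet2/src/*.cu"),
        verbose=True,
    )

    xyz = torch.rand(2, 1024, 3, device="cuda")                       # (B, N, 3) point coordinates
    temp = torch.full((2, 1024), 1e10, device="cuda")                 # scratch buffer expected by the kernel
    idx = torch.zeros(2, 256, dtype=torch.int32, device="cuda")       # output indices (B, M)
    _ext.furthest_point_sampling_wrapper(2, 1024, 256, xyz, temp, idx)
    print(idx.shape)  # torch.Size([2, 256])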
/model/lib/pointnet2/src/sampling.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/serialize/tensor.h>
2 | #include <ATen/cuda/CUDAContext.h>
3 | #include <vector>
4 | #include <cuda.h>
5 | #include <cuda_runtime_api.h>
6 | #include "sampling_gpu.h"
7 |
8 |
9 |
10 | int gather_points_wrapper_fast(int b, int c, int n, int npoints,
11 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){
12 |     const float *points = points_tensor.data_ptr<float>();
13 |     const int *idx = idx_tensor.data_ptr<int>();
14 |     float *out = out_tensor.data_ptr<float>();
15 |
16 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
17 | gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream);
18 | return 1;
19 | }
20 |
21 |
22 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
23 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
24 |
25 |     const float *grad_out = grad_out_tensor.data_ptr<float>();
26 |     const int *idx = idx_tensor.data_ptr<int>();
27 |     float *grad_points = grad_points_tensor.data_ptr<float>();
28 |
29 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
30 | gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream);
31 | return 1;
32 | }
33 |
34 |
35 | int furthest_point_sampling_wrapper(int b, int n, int m,
36 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
37 |
38 |     const float *points = points_tensor.data_ptr<float>();
39 |     float *temp = temp_tensor.data_ptr<float>();
40 |     int *idx = idx_tensor.data_ptr<int>();
41 |
42 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
43 | furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
44 | return 1;
45 | }
46 |
--------------------------------------------------------------------------------
/model/lib/pointnet2/src/sampling_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 |
4 | #include "cuda_utils.h"
5 | #include "sampling_gpu.h"
6 |
7 |
8 | __global__ void gather_points_kernel_fast(int b, int c, int n, int m,
9 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
10 | // points: (B, C, N)
11 | // idx: (B, M)
12 | // output:
13 | // out: (B, C, M)
14 |
15 | int bs_idx = blockIdx.z;
16 | int c_idx = blockIdx.y;
17 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
18 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
19 |
20 | out += bs_idx * c * m + c_idx * m + pt_idx;
21 | idx += bs_idx * m + pt_idx;
22 | points += bs_idx * c * n + c_idx * n;
23 | out[0] = points[idx[0]];
24 | }
25 |
26 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
27 | const float *points, const int *idx, float *out, cudaStream_t stream) {
28 | // points: (B, C, N)
29 | // idx: (B, npoints)
30 | // output:
31 | // out: (B, C, npoints)
32 |
33 | cudaError_t err;
34 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
35 | dim3 threads(THREADS_PER_BLOCK);
36 |
37 |     gather_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points, idx, out);
38 |
39 | err = cudaGetLastError();
40 | if (cudaSuccess != err) {
41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
42 | exit(-1);
43 | }
44 | }
45 |
46 | __global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
47 | const int *__restrict__ idx, float *__restrict__ grad_points) {
48 | // grad_out: (B, C, M)
49 | // idx: (B, M)
50 | // output:
51 | // grad_points: (B, C, N)
52 |
53 | int bs_idx = blockIdx.z;
54 | int c_idx = blockIdx.y;
55 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
56 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
57 |
58 | grad_out += bs_idx * c * m + c_idx * m + pt_idx;
59 | idx += bs_idx * m + pt_idx;
60 | grad_points += bs_idx * c * n + c_idx * n;
61 |
62 | atomicAdd(grad_points + idx[0], grad_out[0]);
63 | }
64 |
65 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
66 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
67 | // grad_out: (B, C, npoints)
68 | // idx: (B, npoints)
69 | // output:
70 | // grad_points: (B, C, N)
71 |
72 | cudaError_t err;
73 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
74 | dim3 threads(THREADS_PER_BLOCK);
75 |
76 |     gather_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, grad_out, idx, grad_points);
77 |
78 | err = cudaGetLastError();
79 | if (cudaSuccess != err) {
80 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
81 | exit(-1);
82 | }
83 | }
84 |
85 |
86 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){
87 | const float v1 = dists[idx1], v2 = dists[idx2];
88 | const int i1 = dists_i[idx1], i2 = dists_i[idx2];
89 | dists[idx1] = max(v1, v2);
90 | dists_i[idx1] = v2 > v1 ? i2 : i1;
91 | }
92 |
93 | template <unsigned int block_size>
94 | __global__ void furthest_point_sampling_kernel(int b, int n, int m,
95 | const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
96 | // dataset: (B, N, 3)
97 | // tmp: (B, N)
98 | // output:
99 | // idx: (B, M)
100 |
101 | if (m <= 0) return;
102 | __shared__ float dists[block_size];
103 | __shared__ int dists_i[block_size];
104 |
105 | int batch_index = blockIdx.x;
106 | dataset += batch_index * n * 3;
107 | temp += batch_index * n;
108 | idxs += batch_index * m;
109 |
110 | int tid = threadIdx.x;
111 | const int stride = block_size;
112 |
113 | int old = 0;
114 | if (threadIdx.x == 0)
115 | idxs[0] = old;
116 |
117 | __syncthreads();
118 | for (int j = 1; j < m; j++) {
119 | int besti = 0;
120 | float best = -1;
121 | float x1 = dataset[old * 3 + 0];
122 | float y1 = dataset[old * 3 + 1];
123 | float z1 = dataset[old * 3 + 2];
124 | for (int k = tid; k < n; k += stride) {
125 | float x2, y2, z2;
126 | x2 = dataset[k * 3 + 0];
127 | y2 = dataset[k * 3 + 1];
128 | z2 = dataset[k * 3 + 2];
129 | // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
130 | // if (mag <= 1e-3)
131 | // continue;
132 |
133 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
134 | float d2 = min(d, temp[k]);
135 | temp[k] = d2;
136 | besti = d2 > best ? k : besti;
137 | best = d2 > best ? d2 : best;
138 | }
139 | dists[tid] = best;
140 | dists_i[tid] = besti;
141 | __syncthreads();
142 |
143 | if (block_size >= 1024) {
144 | if (tid < 512) {
145 | __update(dists, dists_i, tid, tid + 512);
146 | }
147 | __syncthreads();
148 | }
149 |
150 | if (block_size >= 512) {
151 | if (tid < 256) {
152 | __update(dists, dists_i, tid, tid + 256);
153 | }
154 | __syncthreads();
155 | }
156 | if (block_size >= 256) {
157 | if (tid < 128) {
158 | __update(dists, dists_i, tid, tid + 128);
159 | }
160 | __syncthreads();
161 | }
162 | if (block_size >= 128) {
163 | if (tid < 64) {
164 | __update(dists, dists_i, tid, tid + 64);
165 | }
166 | __syncthreads();
167 | }
168 | if (block_size >= 64) {
169 | if (tid < 32) {
170 | __update(dists, dists_i, tid, tid + 32);
171 | }
172 | __syncthreads();
173 | }
174 | if (block_size >= 32) {
175 | if (tid < 16) {
176 | __update(dists, dists_i, tid, tid + 16);
177 | }
178 | __syncthreads();
179 | }
180 | if (block_size >= 16) {
181 | if (tid < 8) {
182 | __update(dists, dists_i, tid, tid + 8);
183 | }
184 | __syncthreads();
185 | }
186 | if (block_size >= 8) {
187 | if (tid < 4) {
188 | __update(dists, dists_i, tid, tid + 4);
189 | }
190 | __syncthreads();
191 | }
192 | if (block_size >= 4) {
193 | if (tid < 2) {
194 | __update(dists, dists_i, tid, tid + 2);
195 | }
196 | __syncthreads();
197 | }
198 | if (block_size >= 2) {
199 | if (tid < 1) {
200 | __update(dists, dists_i, tid, tid + 1);
201 | }
202 | __syncthreads();
203 | }
204 |
205 | old = dists_i[0];
206 | if (tid == 0)
207 | idxs[j] = old;
208 | }
209 | }
210 |
211 | void furthest_point_sampling_kernel_launcher(int b, int n, int m,
212 | const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
213 | // dataset: (B, N, 3)
214 | // tmp: (B, N)
215 | // output:
216 | // idx: (B, M)
217 |
218 | cudaError_t err;
219 | unsigned int n_threads = opt_n_threads(n);
220 |
221 |     switch (n_threads) {
222 |         case 1024:
223 |             furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
224 |         case 512:
225 |             furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
226 |         case 256:
227 |             furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
228 |         case 128:
229 |             furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
230 |         case 64:
231 |             furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
232 |         case 32:
233 |             furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
234 |         case 16:
235 |             furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
236 |         case 8:
237 |             furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
238 |         case 4:
239 |             furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
240 |         case 2:
241 |             furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
242 |         case 1:
243 |             furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
244 |         default:
245 |             furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
246 |     }
247 |
248 | err = cudaGetLastError();
249 | if (cudaSuccess != err) {
250 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
251 | exit(-1);
252 | }
253 | }
254 |
--------------------------------------------------------------------------------
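The templated kernel above performs iterative farthest point sampling: starting from index 0, it repeatedly picks the point whose minimum distance to the already-selected set is largest, using a shared-memory tree reduction within each block. A plain-PyTorch reference that should match its output on small inputs (`fps_ref` is a hypothetical name, not part of this repo):

    import torch

    def fps_ref(xyz: torch.Tensor, m: int) -> torch.Tensor:
        # xyz: (B, N, 3) -> idx: (B, m), starting from point 0 as the kernel does
        B, N, _ = xyz.shape
        idx = torch.zeros(B, m, dtype=torch.long, device=xyz.device)
        dist = torch.full((B, N), float("inf"), device=xyz.device)   # min dist to the selected set
        farthest = torch.zeros(B, dtype=torch.long, device=xyz.device)
        for j in range(m):
            idx[:, j] = farthest
            centroid = xyz[torch.arange(B), farthest].unsqueeze(1)   # (B, 1, 3)
            dist = torch.minimum(dist, ((xyz - centroid) ** 2).sum(-1))
            farthest = dist.argmax(dim=1)                            # next farthest point
        return idx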
/model/lib/pointnet2/src/sampling_gpu.h:
--------------------------------------------------------------------------------
1 | #ifndef _SAMPLING_GPU_H
2 | #define _SAMPLING_GPU_H
3 |
4 | #include <torch/serialize/tensor.h>
5 | #include <ATen/cuda/CUDAContext.h>
6 | #include <vector>
7 |
8 |
9 | int gather_points_wrapper_fast(int b, int c, int n, int npoints,
10 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
11 |
12 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
13 | const float *points, const int *idx, float *out, cudaStream_t stream);
14 |
15 |
16 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
18 |
19 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);
21 |
22 |
23 | int furthest_point_sampling_wrapper(int b, int n, int m,
24 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
25 |
26 | void furthest_point_sampling_kernel_launcher(int b, int n, int m,
27 | const float *dataset, float *temp, int *idxs, cudaStream_t stream);
28 |
29 | #endif
30 |
--------------------------------------------------------------------------------
/model/meta.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import torch
3 | import torch.nn as nn
4 | import json
5 | import os
6 | from .tokenizer import Tokenizer
7 | from . import LLM
8 |
9 | from fairscale.nn.model_parallel import initialize as fs_init
10 |
11 |
12 | class MetaModel(nn.Module):
13 |
14 | def __init__(self, llama_type, llama_config, llama_ckpt_dir=None, tokenizer_path=None):
15 | super().__init__()
16 |
17 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
18 |
19 | ModelArgs = LLM.__dict__[llama_type].ModelArgs
20 | Transformer = LLM.__dict__[llama_type].Transformer
21 |
22 | with open(llama_config, "r") as f:
23 | params = json.loads(f.read())
24 | model_args: ModelArgs = ModelArgs(
25 | max_seq_len=2048, max_batch_size=32, **params
26 | )
27 | self.tokenizer = Tokenizer(model_path=tokenizer_path)
28 | model_args.vocab_size = self.tokenizer.n_words
29 |
30 | model = Transformer(model_args)
31 | mp_rank = fs_init.get_model_parallel_rank()
32 | if llama_ckpt_dir is not None:
33 | ckpt_path = os.path.join(llama_ckpt_dir, f"consolidated.{mp_rank:02d}.pth")
34 | if os.path.exists(ckpt_path):
35 | checkpoint = torch.load(ckpt_path, map_location="cpu")
36 | msg = model.load_state_dict(checkpoint, strict=False)
37 | print(msg)
38 | else:
39 | print(f'Checkpoint not found at {ckpt_path}')
40 | self.llma = model
41 | for name, param in self.named_parameters():
42 | if param.requires_grad:
43 | print(f"Trainable param: {name}, {param.shape}, {param.dtype}")
44 | count = sum(p.numel() for p in self.parameters() if p.requires_grad)
45 | print(f"Parameter count : {count}")
46 |
47 | def forward(self, examples, labels, image=None, modal='image'):
48 | output = self.llma(examples, image=image, modal=modal)
49 | output = output[:, :-1, :]
50 | labels = labels[:, 1:]
51 |
52 | if labels.sum() == 0:
53 | c_loss = output.mean() * 0
54 | else:
55 | c_loss = self.criterion(output.reshape(-1, 32000), labels.flatten())
56 |
57 | return c_loss
58 |
59 | def generate(
60 | self,
61 | prompts: List[str],
62 | images,
63 | max_gen_len: int,
64 | temperature: float = 0.8,
65 | top_p: float = 0.95,
66 | modal = ['image'],
67 | ) -> List[str]:
68 | bsz = len(prompts)
69 | params = self.llma.params
70 | assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
71 |
72 | prompt_tokens = [self.tokenizer.encode(
73 | x, bos=True, eos=False) for x in prompts]
74 |
75 | min_prompt_size = min([len(t) for t in prompt_tokens])
76 | max_prompt_size = max([len(t) for t in prompt_tokens])
77 |
78 | total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
79 |
80 | tokens = torch.full(
81 | (bsz, total_len), self.tokenizer.pad_id).cuda().long()
82 | for k, t in enumerate(prompt_tokens):
83 | tokens[k, : len(t)] = torch.tensor(t).long()
84 | input_text_mask = tokens != self.tokenizer.pad_id
85 | start_pos = min_prompt_size
86 | prev_pos = 0
87 | for cur_pos in range(start_pos, total_len):
88 | logits = self.llma.forward_inference(tokens[:, prev_pos:cur_pos], prev_pos, images if prev_pos == 0 else None, modal=modal)
89 | if temperature > 0:
90 | probs = torch.softmax(logits / temperature, dim=-1)
91 | next_token = self.sample_top_p(probs, top_p)
92 | else:
93 | next_token = torch.argmax(logits, dim=-1)
94 | next_token = next_token.reshape(-1)
95 | # only replace token if prompt has already been generated
96 | next_token = torch.where(
97 | input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
98 | )
99 | tokens[:, cur_pos] = next_token
100 | prev_pos = cur_pos
101 |
102 | decoded = []
103 | for i, t in enumerate(tokens.tolist()):
104 | # cut to max gen len
105 | t = t[: len(prompt_tokens[i]) + max_gen_len]
106 | # cut to eos tok if any
107 | try:
108 | t = t[: t.index(self.tokenizer.eos_id)]
109 | except ValueError:
110 | pass
111 | decoded.append(self.tokenizer.decode(t))
112 | return decoded
113 |
114 | @torch.inference_mode()
115 | def stream_generate(
116 | self,
117 | prompt: str,
118 | images,
119 | max_gen_len: int,
120 | temperature: float = 0.8,
121 | top_p: float = 0.95,
122 | modal = ['image'],
123 | ):
124 | params = self.llma.params
125 |
126 | prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
127 | # truncate from the left. leave some space for generation.
128 | max_seq_len = params.max_seq_len
129 | if images is not None:
130 | max_seq_len -= self.llma.image_words
131 |
132 | max_prompt_size = max_seq_len - max_gen_len
133 | prompt_tokens = prompt_tokens[-max_prompt_size:]
134 |
135 | prompt_size = len(prompt_tokens)
136 |
137 | total_len = min(max_seq_len, max_gen_len + prompt_size)
138 |
139 | tokens = torch.full([total_len], 0).cuda().long()
140 |
141 | tokens[:len(prompt_tokens)] = torch.tensor(prompt_tokens).long()
142 | start_pos = prompt_size
143 | prev_pos = 0
144 | generate_until = start_pos
145 | for cur_pos in range(start_pos, total_len):
146 | logits = self.llma.forward_inference(tokens[None, prev_pos:cur_pos], prev_pos, images if prev_pos == 0 else None, modal = modal)
147 | if temperature > 0:
148 | probs = torch.softmax(logits / temperature, dim=-1)
149 | next_token = self.sample_top_p(probs, top_p)
150 | else:
151 | next_token = torch.argmax(logits, dim=-1)
152 | next_token = next_token.item()
153 |
154 | if next_token == self.tokenizer.eos_id:
155 | break
156 |
157 | tokens[cur_pos] = next_token
158 | prev_pos = cur_pos
159 | generate_until = cur_pos + 1
160 | yield {"text": self.tokenizer.decode(tokens[start_pos:generate_until].tolist()), "end_of_content": False}
161 |
162 | yield {"text": self.tokenizer.decode(tokens[start_pos:generate_until].tolist()), "end_of_content": True}
163 |
164 | def sample_top_p(self, probs, p):
165 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
166 | probs_sum = torch.cumsum(probs_sort, dim=-1)
167 | mask = probs_sum - probs_sort > p
168 | probs_sort[mask] = 0.0
169 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
170 | next_token = torch.multinomial(probs_sort, num_samples=1)
171 | next_token = torch.gather(probs_idx, -1, next_token)
172 | return next_token
173 |
174 | def get_image_words(self):
175 | return self.llma.image_words
--------------------------------------------------------------------------------
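sample_top_p above implements nucleus (top-p) sampling: sort the distribution, keep the smallest prefix whose cumulative probability exceeds p, renormalise it, and sample from that prefix. The standalone sketch below restates the same steps outside the class so they can be inspected in isolation:

    import torch

    def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        probs_sort[probs_sum - probs_sort > p] = 0.0            # drop the tail beyond mass p
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))   # renormalise the kept head
        next_token = torch.multinomial(probs_sort, num_samples=1)
        return torch.gather(probs_idx, -1, next_token)          # map back to vocabulary ids

    probs = torch.softmax(torch.randn(1, 32000), dim=-1)
    print(sample_top_p(probs, 0.95).item())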
/model/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from sentencepiece import SentencePieceProcessor
5 | from logging import getLogger
6 | from typing import List
7 | import os
8 |
9 |
10 | logger = getLogger()
11 |
12 |
13 | class Tokenizer:
14 | def __init__(self, model_path: str):
15 | # reload tokenizer
16 | assert os.path.isfile(model_path), model_path
17 | self.sp_model = SentencePieceProcessor(model_file=model_path)
18 | logger.info(f"Reloaded SentencePiece model from {model_path}")
19 |
20 | # BOS / EOS token IDs
21 | self.n_words: int = self.sp_model.vocab_size()
22 | self.bos_id: int = self.sp_model.bos_id()
23 | self.eos_id: int = self.sp_model.eos_id()
24 | self.pad_id: int = self.sp_model.pad_id()
25 | logger.info(
26 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
27 | )
28 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
29 |
30 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
31 | assert type(s) is str
32 | t = self.sp_model.encode(s)
33 | if bos:
34 | t = [self.bos_id] + t
35 | if eos:
36 | t = t + [self.eos_id]
37 | return t
38 |
39 | def decode(self, t: List[int]) -> str:
40 | return self.sp_model.decode(t)
41 |
--------------------------------------------------------------------------------
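A minimal round-trip through the Tokenizer above, assuming the LLaMA-2 SentencePiece model at config/llama2/tokenizer.model from the repo layout (adjust the path to wherever your weights actually live):

    from model.tokenizer import Tokenizer

    tok = Tokenizer(model_path="config/llama2/tokenizer.model")
    ids = tok.encode("A point cloud of a chair.", bos=True, eos=True)
    print(ids[:5], "...", len(ids), "tokens")
    print(tok.decode(ids))  # SentencePiece drops the BOS/EOS control ids on decode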
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu117
2 | torch==2.0.0+cu117
3 | packaging
4 | fairscale
5 | sentencepiece
6 | Pillow
7 | huggingface_hub
8 | open_clip_torch==2.23.0
9 | decord
10 | pytorchvideo==0.1.5
11 | torchaudio
12 | matplotlib
13 | flash-attn
14 | gradio
15 | pandas
16 |
--------------------------------------------------------------------------------
/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | def adjust_learning_rate(optimizer, it, args):
10 | """Decay the learning rate with half-cycle cosine after warmup"""
11 | if it < args.warmup_iters: # 1) linear warmup for warmup_iters steps
12 | lr = args.lr * it / args.warmup_iters
13 | elif it > args.lr_decay_iters: # 2) if it > lr_decay_iters, return min learning rate
14 | lr = args.min_lr
15 | else: # 3) in between, use cosine decay down to min learning rate
16 | decay_ratio = (it - args.warmup_iters) / (args.lr_decay_iters - args.warmup_iters)
17 | assert 0 <= decay_ratio <= 1
18 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
19 | lr = args.min_lr + (args.lr - args.min_lr) * coeff
20 |
21 | for param_group in optimizer.param_groups:
22 | if "lr_scale" in param_group:
23 | param_group["lr"] = lr * param_group["lr_scale"]
24 | else:
25 | param_group["lr"] = lr
26 | return lr
27 |
28 |
29 | def adjust_learning_rate_epoch(optimizer, epoch, args):
30 | """Decay the learning rate with half-cycle cosine after warmup"""
31 | if epoch < args.warmup_epochs:
32 | lr = args.lr * epoch / args.warmup_epochs
33 | else:
34 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
35 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
36 | for param_group in optimizer.param_groups:
37 | if "lr_scale" in param_group:
38 | param_group["lr"] = lr * param_group["lr_scale"]
39 | else:
40 | param_group["lr"] = lr
41 | return lr
42 |
43 |
--------------------------------------------------------------------------------
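adjust_learning_rate implements linear warmup followed by half-cycle cosine decay down to min_lr. The sketch below prints the schedule at a few iterations; the SimpleNamespace stands in for the training-script arguments it expects (lr, min_lr, warmup_iters, lr_decay_iters), and the values are purely illustrative:

    from types import SimpleNamespace
    import torch
    from util.lr_sched import adjust_learning_rate

    args = SimpleNamespace(lr=2e-4, min_lr=2e-5, warmup_iters=2000, lr_decay_iters=40000)
    opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)
    for it in [0, 1000, 2000, 20000, 40000, 50000]:
        print(it, f"{adjust_learning_rate(opt, it, args):.2e}")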
/util/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 |
10 | import numpy as np
11 | import torch
12 |
13 | # --------------------------------------------------------
14 | # 2D sine-cosine position embedding
15 | # References:
16 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
17 | # MoCo v3: https://github.com/facebookresearch/moco-v3
18 | # --------------------------------------------------------
19 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
20 | """
21 | grid_size: int of the grid height and width
22 | return:
23 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
24 | """
25 | grid_h = np.arange(grid_size, dtype=np.float32)
26 | grid_w = np.arange(grid_size, dtype=np.float32)
27 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
28 | grid = np.stack(grid, axis=0)
29 |
30 | grid = grid.reshape([2, 1, grid_size, grid_size])
31 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
32 | if cls_token:
33 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
34 | return pos_embed
35 |
36 |
37 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
38 | assert embed_dim % 2 == 0
39 |
40 | # use half of dimensions to encode grid_h
41 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
42 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
43 |
44 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
45 | return emb
46 |
47 |
48 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
49 | """
50 | embed_dim: output dimension for each position
51 | pos: a list of positions to be encoded: size (M,)
52 | out: (M, D)
53 | """
54 | assert embed_dim % 2 == 0
55 |     omega = np.arange(embed_dim // 2, dtype=np.float64)
56 | omega /= embed_dim / 2.
57 | omega = 1. / 10000**omega # (D/2,)
58 |
59 | pos = pos.reshape(-1) # (M,)
60 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
61 |
62 | emb_sin = np.sin(out) # (M, D/2)
63 | emb_cos = np.cos(out) # (M, D/2)
64 |
65 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
66 | return emb
67 |
68 |
69 | # --------------------------------------------------------
70 | # Interpolate position embeddings for high-resolution
71 | # References:
72 | # DeiT: https://github.com/facebookresearch/deit
73 | # --------------------------------------------------------
74 | def interpolate_pos_embed(model, checkpoint_model):
75 | if 'pos_embed' in checkpoint_model:
76 | pos_embed_checkpoint = checkpoint_model['pos_embed']
77 | embedding_size = pos_embed_checkpoint.shape[-1]
78 | num_patches = model.patch_embed.num_patches
79 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
80 | # height (== width) for the checkpoint position embedding
81 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
82 | # height (== width) for the new position embedding
83 | new_size = int(num_patches ** 0.5)
84 | # class_token and dist_token are kept unchanged
85 | if orig_size != new_size:
86 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
87 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
88 | # only the position tokens are interpolated
89 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
90 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
91 | pos_tokens = torch.nn.functional.interpolate(
92 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
93 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
94 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
95 | checkpoint_model['pos_embed'] = new_pos_embed
96 |
97 |
98 | def interpolate_pos_embed_online(
99 | pos_embed, orig_size, new_size, num_extra_tokens: int
100 | ):
101 | # [257, 1024]
102 | extra_tokens = pos_embed[:num_extra_tokens]
103 | pos_tokens = pos_embed[num_extra_tokens:]
104 | embedding_size = pos_tokens.shape[1]
105 | pos_tokens = pos_tokens.reshape(
106 | -1, orig_size[0], orig_size[1], embedding_size
107 | ).permute(0, 3, 1, 2)
108 | pos_tokens = torch.nn.functional.interpolate(
109 | pos_tokens, size=new_size, mode="bicubic", align_corners=False,
110 | )
111 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size)
112 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
113 | return new_pos_embed
--------------------------------------------------------------------------------
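To tie the two halves of this file together: get_2d_sincos_pos_embed builds a fixed sin-cos table for a square patch grid, and interpolate_pos_embed_online resizes such a table (excluding the extra class tokens) to a new grid with bicubic interpolation. A short usage sketch with the [257, 1024] shape mentioned in the comment above (1 class token + 16x16 patches at 1024 dims), assuming the repo root is on PYTHONPATH:

    import torch
    from util.pos_embed import get_2d_sincos_pos_embed, interpolate_pos_embed_online

    pos = torch.from_numpy(get_2d_sincos_pos_embed(1024, 16, cls_token=True)).float()
    print(pos.shape)  # torch.Size([257, 1024])

    resized = interpolate_pos_embed_online(pos, orig_size=(16, 16), new_size=(24, 24), num_extra_tokens=1)
    print(resized.shape)  # torch.Size([577, 1024]) = 1 class token + 24*24 patches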