├── .gitignore ├── LICENSE_llama2 ├── OneLLM_Arxiv.pdf ├── README.md ├── config └── llama2 │ ├── 7B.json │ └── tokenizer.model ├── data ├── conversation_lib.py ├── data_utils.py ├── finetune_dataset.py ├── imu_utils.py ├── pretrain_dataset.py └── video_utils.py ├── demos ├── cli.py └── multi_turn_mm.py ├── docs ├── Data.md └── Evaluation.md ├── engine_finetune.py ├── engine_pretrain.py ├── eval ├── audio_cap_clothov2.py ├── caption_eval.py ├── fmri_cap_nsd.py ├── image_bench_mmvet.py ├── image_cap_cococap.py ├── imu_cap_ego4d.py ├── point_cap_pointllm.py └── video_qa_msvd.py ├── exps ├── image_text_pretrain_8gpu.sh ├── image_text_pretrain_slurm.sh ├── multimodal_text_finetune.sh ├── multimodal_text_pretrain_stage2.sh └── multimodal_text_pretrain_stage3.sh ├── main_finetune.py ├── main_pretrain.py ├── model ├── LLM │ ├── __init__.py │ └── onellm.py ├── __init__.py ├── components.py ├── lib │ ├── point_utils.py │ └── pointnet2 │ │ ├── pointnet2_modules.py │ │ ├── pointnet2_utils.py │ │ ├── pytorch_utils.py │ │ ├── setup.py │ │ └── src │ │ ├── ball_query.cpp │ │ ├── ball_query_gpu.cu │ │ ├── ball_query_gpu.h │ │ ├── cuda_utils.h │ │ ├── group_points.cpp │ │ ├── group_points_gpu.cu │ │ ├── group_points_gpu.h │ │ ├── interpolate.cpp │ │ ├── interpolate_gpu.cu │ │ ├── interpolate_gpu.h │ │ ├── pointnet2_api.cpp │ │ ├── sampling.cpp │ │ ├── sampling_gpu.cu │ │ └── sampling_gpu.h ├── meta.py └── tokenizer.py ├── requirements.txt └── util ├── lr_sched.py ├── misc.py └── pos_embed.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | *.egg-info 4 | dist 5 | 6 | output 7 | output_dir 8 | *.pth 9 | *.log 10 | weights 11 | datasets 12 | data_hub 13 | slurm* 14 | multimodal_llama2_7B 15 | weights -------------------------------------------------------------------------------- /LICENSE_llama2: -------------------------------------------------------------------------------- 1 | LLAMA 2 COMMUNITY LICENSE AGREEMENT 2 | Llama 2 Version Release Date: July 18, 2023 3 | 4 | "Agreement" means the terms and conditions for use, reproduction, distribution and 5 | modification of the Llama Materials set forth herein. 6 | 7 | "Documentation" means the specifications, manuals and documentation 8 | accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and- 9 | libraries/llama-downloads/. 10 | 11 | "Licensee" or "you" means you, or your employer or any other person or entity (if 12 | you are entering into this Agreement on such person or entity's behalf), of the age 13 | required under applicable laws, rules or regulations to provide legal consent and that 14 | has legal authority to bind your employer or such other person or entity if you are 15 | entering in this Agreement on their behalf. 16 | 17 | "Llama 2" means the foundational large language models and software and 18 | algorithms, including machine-learning model code, trained model weights, 19 | inference-enabling code, training-enabling code, fine-tuning enabling code and other 20 | elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and- 21 | libraries/llama-downloads/. 22 | 23 | "Llama Materials" means, collectively, Meta's proprietary Llama 2 and 24 | Documentation (and any portion thereof) made available under this Agreement. 25 | 26 | "Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you 27 | are an entity, your principal place of business is in the EEA or Switzerland) and Meta 28 | Platforms, Inc. 
(if you are located outside of the EEA or Switzerland). 29 | 30 | By clicking "I Accept" below or by using or distributing any portion or element of the 31 | Llama Materials, you agree to be bound by this Agreement. 32 | 33 | 1. License Rights and Redistribution. 34 | 35 | a. Grant of Rights. You are granted a non-exclusive, worldwide, non- 36 | transferable and royalty-free limited license under Meta's intellectual property or 37 | other rights owned by Meta embodied in the Llama Materials to use, reproduce, 38 | distribute, copy, create derivative works of, and make modifications to the Llama 39 | Materials. 40 | 41 | b. Redistribution and Use. 42 | 43 | i. If you distribute or make the Llama Materials, or any derivative works 44 | thereof, available to a third party, you shall provide a copy of this Agreement to such 45 | third party. 46 | ii. If you receive Llama Materials, or any derivative works thereof, from 47 | a Licensee as part of an integrated end user product, then Section 2 of this 48 | Agreement will not apply to you. 49 | 50 | iii. You must retain in all copies of the Llama Materials that you 51 | distribute the following attribution notice within a "Notice" text file distributed as a 52 | part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License, 53 | Copyright (c) Meta Platforms, Inc. All Rights Reserved." 54 | 55 | iv. Your use of the Llama Materials must comply with applicable laws 56 | and regulations (including trade compliance laws and regulations) and adhere to the 57 | Acceptable Use Policy for the Llama Materials (available at 58 | https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into 59 | this Agreement. 60 | 61 | v. You will not use the Llama Materials or any output or results of the 62 | Llama Materials to improve any other large language model (excluding Llama 2 or 63 | derivative works thereof). 64 | 65 | 2. Additional Commercial Terms. If, on the Llama 2 version release date, the 66 | monthly active users of the products or services made available by or for Licensee, 67 | or Licensee's affiliates, is greater than 700 million monthly active users in the 68 | preceding calendar month, you must request a license from Meta, which Meta may 69 | grant to you in its sole discretion, and you are not authorized to exercise any of the 70 | rights under this Agreement unless or until Meta otherwise expressly grants you 71 | such rights. 72 | 73 | 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE 74 | LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE 75 | PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 76 | EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY 77 | WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR 78 | FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE 79 | FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING 80 | THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR 81 | USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS. 82 | 83 | 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE 84 | LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, 85 | NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS 86 | AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, 87 | CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN 88 | IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF 89 | ANY OF THE FOREGOING. 90 | 91 | 5. Intellectual Property. 92 | 93 | a. 
No trademark licenses are granted under this Agreement, and in 94 | connection with the Llama Materials, neither Meta nor Licensee may use any name 95 | or mark owned by or associated with the other or any of its affiliates, except as 96 | required for reasonable and customary use in describing and redistributing the 97 | Llama Materials. 98 | 99 | b. Subject to Meta's ownership of Llama Materials and derivatives made by or 100 | for Meta, with respect to any derivative works and modifications of the Llama 101 | Materials that are made by you, as between you and Meta, you are and will be the 102 | owner of such derivative works and modifications. 103 | 104 | c. If you institute litigation or other proceedings against Meta or any entity 105 | (including a cross-claim or counterclaim in a lawsuit) alleging that the Llama 106 | Materials or Llama 2 outputs or results, or any portion of any of the foregoing, 107 | constitutes infringement of intellectual property or other rights owned or licensable 108 | by you, then any licenses granted to you under this Agreement shall terminate as of 109 | the date such litigation or claim is filed or instituted. You will indemnify and hold 110 | harmless Meta from and against any claim by any third party arising out of or related 111 | to your use or distribution of the Llama Materials. 112 | 113 | 6. Term and Termination. The term of this Agreement will commence upon your 114 | acceptance of this Agreement or access to the Llama Materials and will continue in 115 | full force and effect until terminated in accordance with the terms and conditions 116 | herein. Meta may terminate this Agreement if you are in breach of any term or 117 | condition of this Agreement. Upon termination of this Agreement, you shall delete 118 | and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the 119 | termination of this Agreement. 120 | 121 | 7. Governing Law and Jurisdiction. This Agreement will be governed and 122 | construed under the laws of the State of California without regard to choice of law 123 | principles, and the UN Convention on Contracts for the International Sale of Goods 124 | does not apply to this Agreement. The courts of California shall have exclusive 125 | jurisdiction of any dispute arising out of this Agreement. 126 | 127 | -------------------------------------------------------------------------------- /OneLLM_Arxiv.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csuhan/OneLLM/8587a4768cf376fb41f7d586e21de5d1ab1ca365/OneLLM_Arxiv.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## OneLLM: One Framework to Align All Modalities with Language 2 | 3 | [[Project Page](https://onellm.csuhan.com)] [[Paper](https://arxiv.org/abs/2312.03700)] [[HF Demo🤗](https://huggingface.co/spaces/csuhan/OneLLM)] [[Modelscope Demo🤖](https://modelscope.cn/studios/csuhan/OneLLM)] [[Model🤗](https://huggingface.co/csuhan/OneLLM-7B)] [[Data](docs/Data.md)] 4 | 5 | ## News 6 | 7 | - **2024.02.27** OneLLM is accepted by **CVPR 2024**!🎉 8 | - **2023.12.01** Release model weights and inference code. 9 | 10 | ## Contents 11 | 12 | - [Install](#install) 13 | - [Models](#models) 14 | - [Demo](#demo) 15 | - [Data](#data) 16 | - [Evaluation](#evaluation) 17 | - [Training](#training) 18 | 19 | ### Install 20 | 21 | 1. Clone the repo into a local folder. 
22 | 23 | ```bash 24 | git clone https://github.com/csuhan/OneLLM 25 | 26 | cd OneLLM 27 | ``` 28 | 29 | 2. Install packages. 30 | 31 | ```bash 32 | conda create -n onellm python=3.9 -y 33 | conda activate onellm 34 | 35 | pip install -r requirements.txt 36 | 37 | # install pointnet 38 | cd model/lib/pointnet2 39 | python setup.py install 40 | ``` 41 | 42 | 3. Install Apex (optional). 43 | 44 | ```bash 45 | git clone https://github.com/NVIDIA/apex 46 | cd apex 47 | pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ 48 | ``` 49 | 50 | ### Models 51 | 52 | We provide a preview model on Hugging Face: [csuhan/OneLLM-7B](https://huggingface.co/csuhan/OneLLM-7B). 53 | 54 | ### Demo 55 | 56 | **Hugging Face Demo:** [csuhan/OneLLM](https://huggingface.co/spaces/csuhan/OneLLM). 57 | 58 | **Local Demo:** Assuming you have downloaded the weights to ${WEIGHTS_DIR}, run the following command to start a Gradio demo locally. 59 | 60 | ```bash 61 | python demos/multi_turn_mm.py --gpu_ids 0 --tokenizer_path config/llama2/tokenizer.model --llama_config config/llama2/7B.json --pretrained_path ${WEIGHTS_DIR}/consolidated.00-of-01.pth 62 | ``` 63 | 64 | **CLI Demo:** 65 | ```bash 66 | python demos/cli.py --image_path ${IMAGE_PATH} --gpu_ids 0 --tokenizer_path config/llama2/tokenizer.model --llama_config config/llama2/7B.json --pretrained_path ${WEIGHTS_DIR}/consolidated.00-of-01.pth 67 | ``` 68 | 69 | ### Data 70 | 71 | Please check [Data.md](docs/Data.md) for more details. 72 | 73 | ### Evaluation 74 | 75 | Please check [Evaluation.md](docs/Evaluation.md) for more details. 76 | 77 | ### Training 78 | 79 | #### Image-Text Pretraining 80 | 81 | **Single Node 8-GPU Training**: [exps/image_text_pretrain_8gpu.sh](exps/image_text_pretrain_8gpu.sh) 82 |
Show More 83 | 84 | ```bash 85 | torchrun --nproc_per_node=8 main_pretrain.py \ 86 | --epochs 1 --dataset image \ 87 | --batch_size 40 --accum_iter 16 \ 88 | --model_parallel_size 1 \ 89 | --data_parallel sdp \ 90 | --save_consolidated \ 91 | --llama_type onellm \ 92 | --llama_ckpt_dir ${LLAMA_7B_PATH} \ 93 | --llama_config config/llama2/7B.json \ 94 | --tokenizer_path config/llama2/tokenizer.model \ 95 | --auto_resume \ 96 | --weight_decay 0.1 --output_dir ${OUTPUT_DIR} \ 97 | --warmup_iters 2000 --lr_decay_iters 200000 --lr 5e-5 --min_lr 5e-6 --clip_grad 2 \ 98 | --save_freq 1000 \ 99 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log 100 | ``` 101 |
102 | 103 | **Multi-Node DDP Training**: 104 | 105 | Run the script below on each of the N nodes at the same time to launch multi-node DDP training. The following is an example for one node: 106 | 107 | ```bash 108 | MASTER_ADDR=IP_ADDRESS_OF_NODE_1 109 | NNODES=N 110 | MASTER_PORT=29500 111 | NPROC_PER_NODE=8 112 | 113 | RANK=0 or 1 or ... or N-1  # set a different node rank on each node 114 | 115 | # launch training on this node 116 | torchrun \ 117 | --nnodes=$NNODES \ 118 | --nproc_per_node=$NPROC_PER_NODE \ 119 | --node_rank=$RANK \ 120 | --master_port=$MASTER_PORT \ 121 | --master_addr=$MASTER_ADDR \ 122 | main_pretrain.py \ 123 | --epochs 1 --dataset image \ 124 | --batch_size 40 --accum_iter 16 \ 125 | --model_parallel_size 1 \ 126 | --data_parallel sdp \ 127 | --save_consolidated \ 128 | --llama_type onellm \ 129 | --llama_ckpt_dir ${LLAMA_7B_PATH} \ 130 | --llama_config config/llama2/7B.json \ 131 | --tokenizer_path config/llama2/tokenizer.model \ 132 | --auto_resume \ 133 | --weight_decay 0.1 --output_dir ${OUTPUT_DIR} \ 134 | --warmup_iters 2000 --lr_decay_iters 200000 --lr 5e-5 --min_lr 5e-6 --clip_grad 2 \ 135 | --save_freq 1000 \ 136 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log 137 | ``` 138 | 139 | **Multi-Node SLURM Training**: [exps/image_text_pretrain_slurm.sh](exps/image_text_pretrain_slurm.sh) 140 |
Show More 141 | 142 | ```bash 143 | #!/bin/bash 144 | #SBATCH --gres=gpu:8 145 | #SBATCH -n 16 146 | #SBATCH -N 2 147 | #SBATCH --cpus-per-task=16 148 | 149 | srun python -u main_pretrain.py \ 150 | --epochs 1 --dataset image \ 151 | --batch_size 40 --accum_iter 8 \ 152 | --model_parallel_size 1 \ 153 | --data_parallel sdp \ 154 | --save_consolidated \ 155 | --llama_type onellm \ 156 | --llama_ckpt_dir ${LLAMA_7B_PATH} \ 157 | --llama_config config/llama2/7B.json \ 158 | --tokenizer_path config/llama2/tokenizer.model \ 159 | --auto_resume \ 160 | --weight_decay 0.1 --output_dir ${OUTPUT_DIR} \ 161 | --warmup_iters 2000 --lr_decay_iters 200000 --lr 5e-5 --min_lr 5e-6 --clip_grad 2 \ 162 | --save_freq 1000 \ 163 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log 164 | ``` 165 |
166 | 167 | #### Multimodal-Text Pretraining 168 | 169 | **Stage II Pretraining**: Assume we have the pretrained `${IMAGE_TEXT_MODEL}`, run [exps/multimodal_text_pretrain_stage2.sh](exps/multimodal_text_pretrain_stage2.sh) for video-audio-point-text pretraining. 170 | 171 | **Stage III Pretraining**: Assume we have the pretrained `${STAGE2_MODEL}`, run [exps/multimodal_text_pretrain_stage3.sh](exps/multimodal_text_pretrain_stage3.sh) for depth-normal-imu-fmri-text pretraining. 172 | 173 | #### Instruction Tuning 174 | 175 | Assume we have the pretrained `${STAGE3_MODEL}`, run [exps/multimodal_text_finetune.sh](exps/multimodal_text_finetune.sh) for multimodal instruction tuning. 176 | 177 | ## Citation 178 | 179 | ``` 180 | @InProceedings{han2023onellm, 181 | title={OneLLM: One Framework to Align All Modalities with Language}, 182 | author={Han, Jiaming and Gong, Kaixiong and Zhang, Yiyuan and Wang, Jiaqi and Zhang, Kaipeng and Lin, Dahua and Qiao, Yu and Gao, Peng and Yue, Xiangyu}, 183 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 184 | year={2024} 185 | } 186 | ``` 187 | 188 | ## Acknowledgement 189 | 190 | [LLaMA](https://github.com/facebookresearch/llama), [LLaMA-Adapter](https://github.com/OpenGVLab/LLaMA-Adapter), [LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory), [Meta-Transformer](https://github.com/invictus717/MetaTransformer), [ChatBridge](https://github.com/joez17/ChatBridge) 191 | 192 | ## License 193 | This project is developed based on Llama 2, please refer to the [LLAMA 2 Community License](LICENSE_llama2). 194 | -------------------------------------------------------------------------------- /config/llama2/7B.json: -------------------------------------------------------------------------------- 1 | {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": -1} 2 | -------------------------------------------------------------------------------- /config/llama2/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csuhan/OneLLM/8587a4768cf376fb41f7d586e21de5d1ab1ca365/config/llama2/tokenizer.model -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | import torchaudio 4 | import torchvision.transforms.functional as F 5 | import torchvision.transforms as transforms 6 | 7 | try: 8 | from torchvision.transforms import InterpolationMode 9 | 10 | BICUBIC = InterpolationMode.BICUBIC 11 | except ImportError: 12 | BICUBIC = Image.BICUBIC 13 | 14 | T_random_resized_crop = transforms.Compose([ 15 | transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=BICUBIC, 16 | antialias=None), # 3 is bicubic 17 | transforms.ToTensor(), 18 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])]) 19 | 20 | 21 | # image transform 22 | transform_img_train = transforms.Compose([ 23 | transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=( 24 | 0.75, 1.3333), interpolation=3, antialias=None), # 3 is bicubic 25 | transforms.ToTensor(), 26 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])]) 27 | 28 | 29 | class PairRandomResizedCrop(transforms.RandomResizedCrop): 30 | def 
forward(self, imgs): 31 | i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio) 32 | return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias) for img in imgs] 33 | 34 | 35 | class PairToTensor(transforms.ToTensor): 36 | def __call__(self, pics): 37 | return [F.to_tensor(pic) for pic in pics] 38 | 39 | 40 | class PairNormalize(transforms.Normalize): 41 | def forward(self, tensors): 42 | return [F.normalize(tensor, self.mean, self.std, self.inplace) for tensor in tensors] 43 | 44 | 45 | transform_pairimg_train = transforms.Compose([ 46 | PairRandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=( 47 | 0.75, 1.3333), interpolation=3, antialias=None), # 3 is bicubic 48 | PairToTensor(), 49 | PairNormalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])]) 50 | 51 | 52 | def pc_norm(pc): 53 | """ pc: NxC, return NxC """ 54 | xyz = pc[:, :3] 55 | other_feature = pc[:, 3:] 56 | 57 | centroid = torch.mean(xyz, dim=0) 58 | xyz = xyz - centroid 59 | m = torch.max(torch.sqrt(torch.sum(xyz ** 2, dim=1))) 60 | xyz = xyz / m 61 | 62 | pc = torch.cat((xyz, other_feature), dim=1) 63 | return pc 64 | 65 | 66 | def make_audio_features(wav_name, mel_bins=128, target_length=1024, aug=False): 67 | waveform, sr = torchaudio.load(wav_name) 68 | # assert sr == 16000, 'input audio sampling rate must be 16kHz' 69 | if sr != 16000: 70 | trans = torchaudio.transforms.Resample(sr, 16000) 71 | waveform = trans(waveform) 72 | 73 | waveform = waveform - waveform.mean() 74 | 75 | fbank = torchaudio.compliance.kaldi.fbank( 76 | waveform, htk_compat=True, sample_frequency=16000, use_energy=False, 77 | window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10) 78 | 79 | n_frames = fbank.shape[0] 80 | 81 | p = target_length - n_frames 82 | if p > 0: 83 | m = torch.nn.ZeroPad2d((0, 0, 0, p)) 84 | fbank = m(fbank) 85 | elif p < 0: 86 | fbank = fbank[0:target_length, :] 87 | 88 | if aug: 89 | freqm = torchaudio.transforms.FrequencyMasking(48) 90 | timem = torchaudio.transforms.TimeMasking(192) 91 | fbank = torch.transpose(fbank, 0, 1) 92 | fbank = fbank.unsqueeze(0) 93 | fbank = freqm(fbank) 94 | fbank = timem(fbank) 95 | fbank = fbank.squeeze(0) 96 | fbank = torch.transpose(fbank, 0, 1) 97 | 98 | fbank = (fbank - (-4.2677393)) / (4.5689974 * 2) 99 | return fbank -------------------------------------------------------------------------------- /data/imu_utils.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | import matplotlib.animation as animation 4 | from matplotlib import pyplot as plt 5 | import json 6 | from collections import defaultdict 7 | from bisect import bisect_left 8 | import os 9 | import torch 10 | import torchaudio 11 | torchaudio.set_audio_backend("sox_io") 12 | 13 | 14 | def load_json(json_path: str): 15 | """ 16 | Load a json file 17 | """ 18 | with open(json_path, "r", encoding="utf-8") as f_name: 19 | data = json.load(f_name) 20 | return data 21 | 22 | 23 | def check_window_signal(info_t, w_s, w_e): 24 | length = w_e - w_s 25 | frame_offset = int(w_s * info_t.sample_rate) 26 | num_frames = int(length * info_t.sample_rate) 27 | if frame_offset + num_frames > int(info_t.num_frames): 28 | return False 29 | else: 30 | return True 31 | 32 | 33 | def index_narrations(ann_path): 34 | narration_raw = load_json(ann_path) 35 | 36 | narration_dict = defaultdict(list) 37 | summary_dict = defaultdict(list) 38 | avg_len = [] 39 | for v_id, 
narr in narration_raw.items(): 40 | narr_list = [] 41 | summ_list = [] 42 | if "narration_pass_1" in narr: 43 | narr_list += narr["narration_pass_1"]["narrations"] 44 | summ_list += narr["narration_pass_1"]["summaries"] 45 | if "narration_pass_2" in narr: 46 | narr_list += narr["narration_pass_2"]["narrations"] 47 | summ_list += narr["narration_pass_2"]["summaries"] 48 | 49 | if len(narr_list) > 0: 50 | narration_dict[v_id] = [ 51 | ( 52 | float(n_t["timestamp_sec"]), 53 | n_t["narration_text"], 54 | n_t["annotation_uid"], 55 | n_t["timestamp_frame"], 56 | ) 57 | for n_t in narr_list 58 | ] 59 | avg_len.append(len(narration_dict[v_id])) 60 | else: 61 | narration_dict[v_id] = [] 62 | if len(summ_list) > 0: 63 | summary_dict[v_id] = [ 64 | ( 65 | float(s_t["start_sec"]), 66 | float(s_t["end_sec"]), 67 | s_t["summary_text"], 68 | ) 69 | for s_t in summ_list 70 | ] 71 | else: 72 | summary_dict[v_id] = [] 73 | # print(f"Number of Videos with narration {len(narration_dict)}") 74 | # print(f"Avg. narration length {np.mean(avg_len)}") 75 | # print(f"Number of Videos with summaries {len(summary_dict)}") 76 | return narration_dict, summary_dict 77 | 78 | 79 | def get_signal_info(signal_fn: str): 80 | return torchaudio.info(signal_fn) 81 | 82 | 83 | def get_signal_frames(signal_fn: str, video_start_sec: float, video_end_sec: float): 84 | """ 85 | Given a signal track return the frames between video_start_sec and video_end_sec 86 | """ 87 | info_t = get_signal_info(signal_fn) 88 | 89 | length = video_end_sec - video_start_sec 90 | aframes, _ = torchaudio.load( 91 | signal_fn, 92 | normalize=True, 93 | frame_offset=int(video_start_sec * info_t.sample_rate), 94 | num_frames=int(length * info_t.sample_rate), 95 | ) 96 | return {"signal": aframes, "meta": info_t} 97 | 98 | 99 | def tosec(value): 100 | return value / 1000 101 | 102 | 103 | def toms(value): 104 | return value * 1000 105 | 106 | 107 | def delta(first_num: float, second_num: float): 108 | """Compute the absolute value of the difference of two numbers""" 109 | return abs(first_num - second_num) 110 | 111 | 112 | def padIMU(signal, duration_sec): 113 | """ 114 | Pad the signal if necessary 115 | """ 116 | expected_elements = round(duration_sec) * 200 117 | 118 | if signal.shape[0] > expected_elements: 119 | signal = signal[:expected_elements, :] 120 | elif signal.shape[0] < expected_elements: 121 | padding = expected_elements - signal.shape[0] 122 | padded_zeros = np.zeros((padding, 6)) 123 | signal = np.concatenate([signal, padded_zeros], 0) 124 | # signal = signal[:expected_elements, :] 125 | return signal 126 | 127 | 128 | def resample( 129 | signals: np.ndarray, 130 | timestamps: np.ndarray, 131 | original_sample_rate: int, 132 | resample_rate: int, 133 | ): 134 | """ 135 | Resamples data to new sample rate 136 | """ 137 | signals = torch.as_tensor(signals) 138 | timestamps = torch.from_numpy(timestamps).unsqueeze(-1) 139 | signals = torchaudio.functional.resample( 140 | waveform=signals.data.T, 141 | orig_freq=original_sample_rate, 142 | new_freq=resample_rate, 143 | ).T.numpy() 144 | 145 | nsamples = len(signals) 146 | 147 | period = 1 / resample_rate 148 | 149 | # timestamps are expected to be shape (N, 1) 150 | initital_seconds = timestamps[0] / 1e3 151 | 152 | ntimes = (torch.arange(nsamples) * period).view(-1, 1) + initital_seconds 153 | 154 | timestamps = (ntimes * 1e3).squeeze().numpy() 155 | return signals, timestamps 156 | 157 | 158 | def resampleIMU(signal, timestamps): 159 | sampling_rate = int(1000 * (1 / 
(np.mean(np.diff(timestamps))))) 160 | # resample all to 200hz 161 | if sampling_rate != 200: 162 | signal, timestamps = resample(signal, timestamps, sampling_rate, 200) 163 | return signal, timestamps 164 | 165 | 166 | def get_imu_frames( 167 | imu_path, 168 | uid: str, 169 | video_start_sec: float, 170 | video_end_sec: float, 171 | ): 172 | """ 173 | Given a IMU signal return the frames between video_start_sec and video_end_sec 174 | """ 175 | signal = np.load(os.path.join(imu_path, f"{uid}.npy")) 176 | signal = signal.transpose() 177 | timestamps = np.load(os.path.join(imu_path, f"{uid}_timestamps.npy")) 178 | 179 | if toms(video_start_sec) > timestamps[-1] or toms(video_end_sec) > timestamps[-1]: 180 | return None 181 | 182 | start_id = bisect_left(timestamps, toms(video_start_sec)) 183 | end_id = bisect_left(timestamps, toms(video_end_sec)) 184 | 185 | # make sure the retrieved window interval are correct by a max of 1 sec margin 186 | if ( 187 | delta(video_start_sec, tosec(timestamps[start_id])) > 4 188 | or delta(video_end_sec, tosec(timestamps[end_id])) > 4 189 | ): 190 | return None 191 | 192 | # get the window 193 | if start_id == end_id: 194 | start_id -= 1 195 | end_id += 1 196 | signal, timestamps = signal[start_id:end_id], timestamps[start_id:end_id] 197 | 198 | if len(signal) < 10 or len(timestamps) < 10: 199 | return None 200 | # resample the signal at 200hz if necessary 201 | signal, timestamps = resampleIMU(signal, timestamps) 202 | 203 | # pad the signal if necessary 204 | signal = padIMU(signal, video_end_sec - video_start_sec) 205 | 206 | sample_dict = { 207 | "timestamp": timestamps, 208 | "signal": torch.tensor(signal.T), 209 | "sampling_rate": 200, 210 | } 211 | 212 | return sample_dict 213 | 214 | 215 | def display_animation(frames, title, save_path_gif): 216 | fig, ax = plt.subplots() 217 | frames = [[ax.imshow(frames[i])] for i in range(len(frames))] 218 | plt.title(title) 219 | ani = animation.ArtistAnimation(fig, frames) 220 | ani.save(save_path_gif, writer="imagemagick") 221 | plt.close() 222 | 223 | 224 | def display_animation_imu(frames, imu, title, save_path_gif): 225 | fig, (ax1, ax2, ax3) = plt.subplots(3, 1) 226 | ax1.set_title(title) 227 | ax2.set_title("Acc.") 228 | ax3.set_title("Gyro.") 229 | frames = [[ax1.imshow(frames[i])] for i in range(len(frames))] 230 | ani = animation.ArtistAnimation(fig, frames) 231 | 232 | ax2.plot(imu[0].cpu().numpy(), color="red") 233 | ax2.plot(imu[1].cpu().numpy(), color="blue") 234 | ax2.plot(imu[2].cpu().numpy(), color="green") 235 | ax3.plot(imu[3].cpu().numpy(), color="red") 236 | ax3.plot(imu[4].cpu().numpy(), color="blue") 237 | ax3.plot(imu[5].cpu().numpy(), color="green") 238 | plt.tight_layout() 239 | ani.save(save_path_gif, writer="imagemagick") 240 | plt.close() 241 | 242 | 243 | def filter_narration(narration_text: str) -> bool: 244 | if "#c" in narration_text.lower(): 245 | return True 246 | return False 247 | 248 | 249 | def clean_narration_text(narration_text: str) -> str: 250 | return ( 251 | narration_text.replace("#C C ", "") 252 | .replace("#C", "") 253 | .replace("#unsure", "something") 254 | .strip() 255 | .strip(string.punctuation) 256 | .lower()[:128] 257 | ) 258 | -------------------------------------------------------------------------------- /data/pretrain_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, List 2 | try: 3 | from petrel_client.client import Client 4 | except: 5 | print("petrel_client is not installed.") 
6 | import json 7 | import torch 8 | from io import BytesIO 9 | import random 10 | import multiprocessing as mp 11 | import copy 12 | from torch.utils.data import Dataset 13 | from pathlib import Path 14 | 15 | import pandas as pd 16 | from PIL import Image, ImageFile 17 | ImageFile.LOAD_TRUNCATED_IMAGES = True 18 | 19 | from model.tokenizer import Tokenizer 20 | import numpy as np 21 | import warnings 22 | import bisect 23 | 24 | from .data_utils import make_audio_features, pc_norm, transform_pairimg_train, transform_img_train 25 | from . import video_utils 26 | from .imu_utils import get_imu_frames 27 | 28 | 29 | DATASETS = dict( 30 | image=dict( 31 | train=( 32 | sorted(list(Path('datasets/Pretrain/image/laion400m_new').glob('*.csv')))[:1000]+ 33 | sorted(list(Path('datasets/Pretrain/image/laion_coco').glob('*.csv')))[:1000] 34 | ), 35 | test= ('datasets/Pretrain/image/coco_caps_train2017.csv',), 36 | max_words=96, 37 | ), 38 | audio=dict( 39 | train=( 40 | sorted(list(Path('datasets/Pretrain/audio/wavcaps').glob('*.csv'))) 41 | ), 42 | test=None, 43 | max_words=96, 44 | ), 45 | video=dict( 46 | train=( 47 | "datasets/Pretrain/video/webvid/results_2M_train_ceph.csv",), 48 | test=None, 49 | max_words=96 50 | ), 51 | point=dict( 52 | train=( 53 | "datasets/Pretrain/point/pointllm/cap3d_pointllm_train.csv", 54 | ), 55 | test=( 56 | "datasets/Pretrain/point/pointllm/cap3d_pointllm_test.csv", 57 | ), 58 | max_words=96, 59 | ), 60 | rgbd=dict( 61 | train=( 62 | 'datasets/Pretrain/image/cc3m.csv',), 63 | test=None, 64 | replace_list=['/cc3m/', '/cc3m_depth/'], 65 | max_words=96, 66 | ), 67 | rgbn=dict( 68 | train=( 69 | 'datasets/Pretrain/image/cc3m.csv', 70 | ), 71 | test=None, 72 | replace_list=['/cc3m/', '/cc3m_normal/'], 73 | max_words=96, 74 | ), 75 | fmri=dict( 76 | train=( 77 | "datasets/Pretrain/fmri/nsd/train_sub01.csv", 78 | "datasets/Pretrain/fmri/nsd/val_sub01.csv", 79 | ), 80 | test=( 81 | "datasets/Pretrain/fmri/nsd/test_sub01.csv", 82 | ), 83 | max_words=96 84 | ), 85 | imu=dict( 86 | train=( 87 | "datasets/Pretrain/imu/ego4d/window_idx_train.json", 88 | ), 89 | test=( 90 | "datasets/Pretrain/imu/ego4d/window_idx_val.json", 91 | ), 92 | imu_path="datasets/Pretrain/imu/ego4d/v2/processed_imu/", 93 | max_words=96, 94 | ) 95 | ) 96 | 97 | 98 | class PretrainDataset(Dataset): 99 | def __init__(self, dataset='image', partition='train', epochs=1, tokenizer_path=None, petrel_conf=None): 100 | self.dataset = dataset 101 | input_filenames = DATASETS[dataset][partition] 102 | 103 | self.petrel_conf = petrel_conf 104 | self.client = None 105 | self.partition = partition 106 | manager = mp.Manager() 107 | self.datas = manager.list() 108 | self.captions = manager.list() 109 | print('loading csv...') 110 | for input_filename in input_filenames: 111 | print(input_filename) 112 | if dataset != 'imu': 113 | chunk = pd.read_csv(input_filename, sep='\t', on_bad_lines='skip', lineterminator='\n') 114 | self.datas.extend(chunk['url'].tolist()) 115 | self.captions.extend(chunk['caption'].tolist()) 116 | else: 117 | self.datas = json.load(open(input_filename)) 118 | self.imu_path = DATASETS[dataset]['imu_path'] 119 | 120 | self.max_words = DATASETS[dataset]['max_words'] 121 | self.tokenizer = Tokenizer(model_path=tokenizer_path) 122 | 123 | self.epochs = epochs 124 | 125 | def __len__(self): 126 | return int(len(self.datas) * self.epochs) 127 | 128 | def load_trans_image(self, image_path): 129 | image = Image.open(image_path).convert('RGB') 130 | image = transform_img_train(image) 131 | return image 
132 | 133 | def load_trans_image_from_ceph(self, image_path): 134 | if self.client is None: 135 | self.client = Client(conf_path=self.petrel_conf) 136 | image = self.client.get(image_path) 137 | 138 | image = memoryview(image) 139 | image = Image.open(BytesIO(image)).convert('RGB') 140 | image = transform_img_train(image) 141 | return image 142 | 143 | def load_audio(self, audio_path): 144 | fbank = make_audio_features(audio_path, mel_bins=128, aug=True) 145 | fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024] 146 | return fbank 147 | 148 | def load_video(self, video_path): 149 | video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5) 150 | return video_feats[:, :, 0] 151 | 152 | def load_video_from_ceph(self, video_path): 153 | if self.client is None: 154 | self.client = Client(conf_path=self.petrel_conf) 155 | video = self.client.get(video_path) 156 | video = memoryview(video) 157 | video = BytesIO(video) 158 | video_feats = video_utils.load_and_transform_video_data(video, video_path, clip_duration=1, clips_per_video=5) 159 | return video_feats[:, :, 0] 160 | 161 | def load_point(self, point_path): 162 | point_feat = np.load(point_path) 163 | # [8196, 6] 164 | point_feat = torch.tensor(point_feat) 165 | point_feat = pc_norm(point_feat) 166 | 167 | return point_feat 168 | 169 | def load_rgbx(self, image_path): 170 | replace_list = DATASETS[self.dataset]['replace_list'] 171 | x_image_path = image_path.replace(replace_list[0], replace_list[1]) 172 | image = Image.open(image_path).convert('RGB') 173 | x_image = Image.open(x_image_path).convert('RGB') 174 | x_image = x_image.resize(image.size[-2:]) 175 | 176 | image, x_image = transform_pairimg_train([image, x_image]) 177 | 178 | # [2, 3, H, W] 179 | image = torch.stack([image, x_image], dim=0) 180 | return image 181 | 182 | def load_rgbx_from_ceph(self, image_path): 183 | if self.client is None: 184 | self.client = Client(conf_path=self.petrel_conf) 185 | 186 | replace_list = DATASETS[self.dataset]['replace_list'] 187 | x_image_path = image_path.replace(replace_list[0], replace_list[1]) 188 | 189 | image = self.client.get(image_path) 190 | image = memoryview(image) 191 | image = Image.open(BytesIO(image)).convert('RGB') 192 | 193 | x_image = self.client.get(x_image_path) 194 | x_image = memoryview(x_image) 195 | x_image = Image.open(BytesIO(x_image)).convert('RGB') 196 | 197 | x_image = x_image.resize(image.size[-2:]) 198 | 199 | image, x_image = transform_pairimg_train([image, x_image]) 200 | 201 | # [2, 3, H, W] 202 | image = torch.stack([image, x_image], dim=0) 203 | return image 204 | 205 | def load_fmri(self, fmri_path): 206 | data = np.load(fmri_path) 207 | data = data.mean(axis=0) 208 | data = torch.tensor(data[None]) 209 | return data 210 | 211 | def load_imu(self, data_dict): 212 | uid = data_dict["video_uid"] 213 | w_s = data_dict["window_start"] 214 | w_e = data_dict["window_end"] 215 | 216 | imu_data = get_imu_frames( 217 | self.imu_path, uid, 218 | video_start_sec=w_s, 219 | video_end_sec=w_e, 220 | ) 221 | if imu_data is None: 222 | raise ValueError 223 | return imu_data['signal'] 224 | 225 | def __getitem__(self, index): 226 | index = index % len(self.datas) 227 | if self.dataset != 'imu': 228 | data_path, caption = self.datas[index], self.captions[index] 229 | else: 230 | data_dict = self.datas[index] 231 | caption = data_dict['text'] 232 | data_path = data_dict["video_uid"] 233 | 234 | if isinstance(caption, list): 235 | caption = random.choice(caption) 236 | 
caption = str(caption) 237 | 238 | try: 239 | if self.dataset == 'image': 240 | data = self.load_trans_image_from_ceph(data_path) 241 | elif self.dataset == 'audio': 242 | data = self.load_audio(data_path) 243 | elif self.dataset == 'video': 244 | data = self.load_video_from_ceph(data_path) 245 | elif self.dataset == 'point': 246 | data = self.load_point(data_path) 247 | elif self.dataset in ['rgbn', 'rgbd']: 248 | data = self.load_rgbx(data_path) 249 | elif self.dataset == 'fmri': 250 | data = self.load_fmri(data_path) 251 | elif self.dataset == 'imu': 252 | data_dict = self.datas[index] 253 | data = self.load_imu(data_dict) 254 | except: 255 | print(data_path, 'Not Found') 256 | rand_idx = random.randint(0, len(self)) 257 | return self.__getitem__(rand_idx) 258 | 259 | caption_tokens = torch.tensor(self.tokenizer.encode(caption, bos=True, eos=True), dtype=torch.int64) 260 | input_data = caption_tokens 261 | 262 | padding = self.max_words - input_data.shape[0] 263 | if padding > 0: 264 | input_data = torch.cat((input_data, torch.zeros(padding, dtype=torch.int64))) 265 | elif padding < 0: 266 | input_data = input_data[:self.max_words] 267 | labels = copy.deepcopy(input_data) 268 | 269 | if self.partition != 'train': 270 | return input_data, labels, data, data_path, self.dataset, caption 271 | 272 | return input_data, labels, data, data_path, self.dataset 273 | 274 | def __repr__(self): 275 | return f" None: 292 | super().__init__() 293 | self.datasets = list(datasets) 294 | assert len(self.datasets) > 0, 'datasets should not be an empty iterable' # type: ignore[arg-type] 295 | # for d in self.datasets: 296 | # assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset" 297 | self.cumulative_sizes = self.cumsum(self.datasets) 298 | 299 | def __len__(self): 300 | return self.cumulative_sizes[-1] 301 | 302 | def __getitem__(self, idx): 303 | if idx < 0: 304 | if -idx > len(self): 305 | raise ValueError("absolute value of index should not exceed dataset length") 306 | idx = len(self) + idx 307 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 308 | if dataset_idx == 0: 309 | sample_idx = idx 310 | else: 311 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 312 | return self.datasets[dataset_idx][sample_idx] 313 | 314 | @property 315 | def cummulative_sizes(self): 316 | warnings.warn("cummulative_sizes attribute is renamed to " 317 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 318 | return self.cumulative_sizes 319 | 320 | def get_indices(self, batch_size, world_size=1, rank_id=0): 321 | random.seed(0) 322 | real_batch_size = batch_size * world_size 323 | batch_train_indices = [] 324 | num_batches = [] 325 | for i in range(len(self.datasets)): 326 | # get train_indices 327 | start_idx = self.cumulative_sizes[i-1] if i>0 else 0 328 | end_idx = self.cumulative_sizes[i] 329 | train_indice = list(range(start_idx, end_idx)) 330 | random.shuffle(train_indice) 331 | num_batch = int(len(self.datasets[i]) / real_batch_size) 332 | num_batches.append(num_batch) 333 | # get batch indices for each rank 334 | batch_train_indice = [ 335 | train_indice[batch*real_batch_size:(batch+1)*real_batch_size][rank_id::world_size] 336 | for batch in range(num_batch) 337 | ] 338 | batch_train_indices.append(batch_train_indice) 339 | min_num_batch = min(num_batches) 340 | 341 | train_indices = [] 342 | for batch in range(min_num_batch): 343 | for i in range(len(self.datasets)): 344 | train_indices.extend(batch_train_indices[i][batch]) 345 | 346 | return 
train_indices -------------------------------------------------------------------------------- /data/video_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from pytorchvideo import transforms as pv_transforms 5 | from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler 6 | from pytorchvideo.data.encoded_video import EncodedVideo 7 | from pytorchvideo.data.encoded_video_decord import EncodedVideoDecord 8 | from torchvision import transforms 9 | from torchvision.transforms._transforms_video import NormalizeVideo 10 | 11 | 12 | def get_clip_timepoints(clip_sampler, duration): 13 | # Read out all clips in this video 14 | all_clips_timepoints = [] 15 | is_last_clip = False 16 | end = 0.0 17 | while not is_last_clip: 18 | start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) 19 | all_clips_timepoints.append((start, end)) 20 | return all_clips_timepoints 21 | 22 | 23 | 24 | def crop_boxes(boxes, x_offset, y_offset): 25 | """ 26 | Perform crop on the bounding boxes given the offsets. 27 | Args: 28 | boxes (ndarray or None): bounding boxes to perform crop. The dimension 29 | is `num boxes` x 4. 30 | x_offset (int): cropping offset in the x axis. 31 | y_offset (int): cropping offset in the y axis. 32 | Returns: 33 | cropped_boxes (ndarray or None): the cropped boxes with dimension of 34 | `num boxes` x 4. 35 | """ 36 | cropped_boxes = boxes.copy() 37 | cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset 38 | cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset 39 | 40 | return cropped_boxes 41 | 42 | 43 | def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): 44 | """ 45 | Perform uniform spatial sampling on the images and corresponding boxes. 46 | Args: 47 | images (tensor): images to perform uniform crop. The dimension is 48 | `num frames` x `channel` x `height` x `width`. 49 | size (int): size of height and weight to crop the images. 50 | spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width 51 | is larger than height. Or 0, 1, or 2 for top, center, and bottom 52 | crop if height is larger than width. 53 | boxes (ndarray or None): optional. Corresponding boxes to images. 54 | Dimension is `num boxes` x 4. 55 | scale_size (int): optional. If not None, resize the images to scale_size before 56 | performing any crop. 57 | Returns: 58 | cropped (tensor): images with dimension of 59 | `num frames` x `channel` x `size` x `size`. 60 | cropped_boxes (ndarray or None): the cropped boxes with dimension of 61 | `num boxes` x 4. 
62 | """ 63 | assert spatial_idx in [0, 1, 2] 64 | ndim = len(images.shape) 65 | if ndim == 3: 66 | images = images.unsqueeze(0) 67 | height = images.shape[2] 68 | width = images.shape[3] 69 | 70 | if scale_size is not None: 71 | if width <= height: 72 | width, height = scale_size, int(height / width * scale_size) 73 | else: 74 | width, height = int(width / height * scale_size), scale_size 75 | images = torch.nn.functional.interpolate( 76 | images, 77 | size=(height, width), 78 | mode="bilinear", 79 | align_corners=False, 80 | ) 81 | 82 | y_offset = int(math.ceil((height - size) / 2)) 83 | x_offset = int(math.ceil((width - size) / 2)) 84 | 85 | if height > width: 86 | if spatial_idx == 0: 87 | y_offset = 0 88 | elif spatial_idx == 2: 89 | y_offset = height - size 90 | else: 91 | if spatial_idx == 0: 92 | x_offset = 0 93 | elif spatial_idx == 2: 94 | x_offset = width - size 95 | cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] 96 | cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None 97 | if ndim == 3: 98 | cropped = cropped.squeeze(0) 99 | return cropped, cropped_boxes 100 | 101 | 102 | class SpatialCrop(nn.Module): 103 | """ 104 | Convert the video into 3 smaller clips spatially. Must be used after the 105 | temporal crops to get spatial crops, and should be used with 106 | -2 in the spatial crop at the slowfast augmentation stage (so full 107 | frames are passed in here). Will return a larger list with the 108 | 3x spatial crops as well. 109 | """ 110 | 111 | def __init__(self, crop_size: int = 224, num_crops: int = 3): 112 | super().__init__() 113 | self.crop_size = crop_size 114 | if num_crops == 3: 115 | self.crops_to_ext = [0, 1, 2] 116 | self.flipped_crops_to_ext = [] 117 | elif num_crops == 1: 118 | self.crops_to_ext = [1] 119 | self.flipped_crops_to_ext = [] 120 | else: 121 | raise NotImplementedError("Nothing else supported yet") 122 | 123 | def forward(self, videos): 124 | """ 125 | Args: 126 | videos: A list of C, T, H, W videos. 127 | Returns: 128 | videos: A list with 3x the number of elements. Each video converted 129 | to C, T, H', W' by spatial cropping. 
130 | """ 131 | assert isinstance(videos, list), "Must be a list of videos after temporal crops" 132 | assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)" 133 | res = [] 134 | for video in videos: 135 | for spatial_idx in self.crops_to_ext: 136 | res.append(uniform_crop(video, self.crop_size, spatial_idx)[0]) 137 | if not self.flipped_crops_to_ext: 138 | continue 139 | flipped_video = transforms.functional.hflip(video) 140 | for spatial_idx in self.flipped_crops_to_ext: 141 | res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0]) 142 | return res 143 | 144 | 145 | def load_and_transform_video_data( 146 | video_file, 147 | video_path, 148 | clip_duration=2, 149 | clips_per_video=5, 150 | sample_rate=16000, 151 | with_audio=False 152 | ): 153 | video_transform = transforms.Compose( 154 | [ 155 | pv_transforms.ShortSideScale(224), 156 | NormalizeVideo( 157 | mean=(0.48145466, 0.4578275, 0.40821073), 158 | std=(0.26862954, 0.26130258, 0.27577711), 159 | ), 160 | ] 161 | ) 162 | 163 | clip_sampler = ConstantClipsPerVideoSampler( 164 | clip_duration=clip_duration, clips_per_video=clips_per_video 165 | ) 166 | frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration) 167 | 168 | if isinstance(video_file, str): 169 | video = EncodedVideo.from_path( 170 | video_file, 171 | decoder="decord", 172 | decode_audio=with_audio, 173 | # **{"sample_rate": sample_rate}, 174 | ) 175 | else: 176 | video = EncodedVideoDecord(video_file, video_name=video_path, decode_video=True, decode_audio=with_audio, sample_rate=sample_rate) 177 | 178 | all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration) 179 | 180 | all_video = [] 181 | for clip_timepoints in all_clips_timepoints: 182 | # Read the clip, get frames 183 | clip = video.get_clip(clip_timepoints[0], clip_timepoints[1]) 184 | if clip is None: 185 | raise ValueError("No clip found") 186 | video_clip = frame_sampler(clip["video"]) 187 | video_clip = video_clip / 255.0 # since this is float, need 0-1 188 | 189 | all_video.append(video_clip) 190 | 191 | all_video = [video_transform(clip) for clip in all_video] 192 | all_video = SpatialCrop(224, num_crops=3)(all_video) 193 | 194 | all_video = torch.stack(all_video, dim=0) 195 | 196 | if not with_audio: 197 | return all_video 198 | else: 199 | return all_video, clip['audio'] 200 | 201 | if __name__ == '__main__': 202 | video_path = "datasets/InstructionTuning/video/music_aqa/MUSIC-AVQA-videos-Real/00000002.mp4" 203 | video, audio = load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5, with_audio=True) 204 | import pdb;pdb.set_trace() 205 | -------------------------------------------------------------------------------- /demos/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.abspath(__file__).rsplit('/', 2)[0]) 4 | import argparse 5 | import multiprocessing as mp 6 | import numpy as np 7 | import torch 8 | import torch.distributed as dist 9 | from fairscale.nn.model_parallel import initialize as fs_init 10 | from util.misc import setup_for_distributed 11 | from util.misc import default_tensor_type 12 | from model.meta import MetaModel 13 | from data.conversation_lib import conv_templates 14 | from PIL import Image 15 | import torchvision.transforms as transforms 16 | 17 | 18 | T_random_resized_crop = transforms.Compose([ 19 | transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), 
interpolation=3, 20 | antialias=None), # 3 is bicubic 21 | transforms.ToTensor(), 22 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])]) 23 | 24 | 25 | def model_worker(args: argparse.Namespace) -> None: 26 | rank = 0 27 | world_size = len(args.gpu_ids) 28 | gpu_id = args.gpu_ids[rank] 29 | dist.init_process_group( 30 | backend="nccl", rank=rank, world_size=world_size, 31 | init_method=f"tcp://{args.master_addr}:{args.master_port}", 32 | ) 33 | print(f"| distributed init on worker {rank}/{world_size}. " 34 | f"using gpu: {gpu_id}") 35 | fs_init.initialize_model_parallel(world_size) 36 | torch.cuda.set_device(gpu_id) 37 | 38 | torch.manual_seed(1) 39 | np.random.seed(1) 40 | 41 | # set the print behavior. 42 | setup_for_distributed(rank == 0) 43 | 44 | target_dtype = { 45 | "bf16": torch.bfloat16, 46 | "fp16": torch.float16 47 | }[args.dtype] 48 | with default_tensor_type(dtype=target_dtype, device="cuda"): 49 | model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path) 50 | print("Loading pretrained weights ...") 51 | checkpoint = torch.load(args.pretrained_path, map_location='cpu') 52 | msg = model.load_state_dict(checkpoint, strict=False) 53 | print("load result:\n", msg) 54 | model.cuda() 55 | model.eval() 56 | print(f"Model = {str(model)}") 57 | 58 | print('Model is ready. Please input') 59 | 60 | conv = conv_templates["v1"].copy() 61 | 62 | image = Image.open(args.image_path).convert('RGB') 63 | image = T_random_resized_crop(image).unsqueeze(0).cuda().to(target_dtype) 64 | while True: 65 | try: 66 | inp = input(f"{conv.roles[0]}: ") 67 | except EOFError: 68 | inp = "" 69 | if not inp: 70 | print("exit...") 71 | break 72 | 73 | print(f"{conv.roles[1]}: ", end="") 74 | 75 | conv.append_message(conv.roles[0], inp) 76 | conv.append_message(conv.roles[1], None) 77 | 78 | with torch.cuda.amp.autocast(dtype=target_dtype): 79 | print(conv.get_prompt()) 80 | response = model.generate([conv.get_prompt()], image, 256, temperature=0.1, top_p=0.75, modal=["image"]) 81 | response = response[0] 82 | response = response[len(conv.get_prompt()):].split('###')[0] 83 | print(response) 84 | conv.messages[-1][-1] = response 85 | 86 | 87 | if __name__ == "__main__": 88 | parser = argparse.ArgumentParser("LLaMA2-Accessory Chat Demo") 89 | group = parser.add_mutually_exclusive_group() 90 | group.add_argument("--image_path", type=str, help="path to the input image") 91 | group.add_argument( 92 | "--gpu_ids", type=int, nargs="+", 93 | help="A list of space-separated gpu ids to run the model on. " 94 | "The model will span across GPUs in tensor-parallel mode." 95 | ) 96 | parser.add_argument( 97 | "--tokenizer_path", type=str, 98 | default="config/llama2/tokenizer.model", 99 | help="Path to the tokenizer.model file provided along with the LLaMA " 100 | "model." 101 | ) 102 | parser.add_argument( 103 | "--llama_type", default="onellm", type=str, metavar="MODEL", 104 | help="LLaMA model type." 105 | ) 106 | parser.add_argument( 107 | "--llama_config", type=str, required=True, 108 | default="config/llama2/7B.json", 109 | help="Path to the llama model config json." 110 | ) 111 | parser.add_argument( 112 | "--model_max_seq_len", type=int, default=2048, 113 | help="Max sequence length accepted by the pretrained model." 114 | ) 115 | parser.add_argument( 116 | "--pretrained_path", type=str, required=True, 117 | help="Path to the llama model checkpoints. 
A list of checkpoints is " 118 | "supported and will be merged from left to right.") 119 | parser.add_argument( 120 | "--master_port", type=int, default=23862, 121 | help="A port used by the PyTorch distributed module to initialize." 122 | ) 123 | parser.add_argument( 124 | "--master_addr", type=str, default="127.0.0.1", 125 | help="An address used by the PyTorch distributed module to initialize." 126 | ) 127 | parser.add_argument( 128 | "--dtype", type=str, choices=["fp16", "bf16"], default="fp16", 129 | help="The dtype used for model weights and inference." 130 | ) 131 | args = parser.parse_args() 132 | 133 | # check and setup gpu_ids to use 134 | if args.gpu_ids is None: 135 | if args.n_gpus is None: 136 | args.n_gpus = 1 137 | assert args.n_gpus > 0, ( 138 | "The demo currently must run on a positive number of GPUs." 139 | ) 140 | args.gpu_ids = list(range(args.n_gpus)) 141 | 142 | # using the default "fork" method messes up some imported libs (e.g., 143 | # pandas) 144 | mp.set_start_method("spawn") 145 | model_worker(args) 146 | -------------------------------------------------------------------------------- /demos/multi_turn_mm.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.abspath(__file__).rsplit('/', 2)[0]) 4 | 5 | import argparse 6 | import multiprocessing as mp 7 | import numpy as np 8 | from typing import List, Optional 9 | 10 | import torch 11 | import torch.distributed as dist 12 | 13 | from fairscale.nn.model_parallel import initialize as fs_init 14 | 15 | import gradio as gr 16 | from util.misc import setup_for_distributed 17 | from util.misc import default_tensor_type 18 | from model.meta import MetaModel 19 | from data.conversation_lib import conv_templates, SeparatorStyle 20 | from PIL import Image 21 | import torchvision.transforms as transforms 22 | from data.fintune_dataset import make_audio_features 23 | from data import video_utils 24 | 25 | 26 | T_random_resized_crop = transforms.Compose([ 27 | transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3, 28 | antialias=None), # 3 is bicubic 29 | transforms.ToTensor(), 30 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])]) 31 | 32 | 33 | def load_audio(audio_path): 34 | fbank = make_audio_features(audio_path, mel_bins=128) 35 | fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024] 36 | return fbank 37 | 38 | def load_video(video_path): 39 | video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5) 40 | return video_feats[:, :, 0] 41 | 42 | 43 | def model_worker( 44 | rank: int, args: argparse.Namespace, barrier: mp.Barrier, 45 | request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None, 46 | ) -> None: 47 | """ 48 | The worker function that manipulates the GPU to run the inference. 49 | Exact n_gpu workers are started, with each one operating on a separate GPU. 50 | 51 | Args: 52 | rank (int): Distributed rank of the worker. 53 | args (argparse.Namespace): All command line arguments. 54 | barrier (multiprocessing.Barrier): A barrier used to delay the start 55 | of Web UI to be after the start of the model. 
56 | """ 57 | 58 | world_size = len(args.gpu_ids) 59 | gpu_id = args.gpu_ids[rank] 60 | dist.init_process_group( 61 | backend="nccl", rank=rank, world_size=world_size, 62 | init_method=f"tcp://{args.master_addr}:{args.master_port}", 63 | ) 64 | print(f"| distributed init on worker {rank}/{world_size}. " 65 | f"using gpu: {gpu_id}") 66 | fs_init.initialize_model_parallel(world_size) 67 | torch.cuda.set_device(gpu_id) 68 | 69 | torch.manual_seed(1) 70 | np.random.seed(1) 71 | 72 | # set the print behavior. 73 | setup_for_distributed(rank == 0) 74 | 75 | target_dtype = { 76 | "bf16": torch.bfloat16, 77 | "fp16": torch.float16 78 | }[args.dtype] 79 | with default_tensor_type(dtype=target_dtype, device="cuda"): 80 | model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path) 81 | print("Loading pretrained weights ...") 82 | checkpoint = torch.load(args.pretrained_path, map_location='cpu') 83 | msg = model.load_state_dict(checkpoint, strict=False) 84 | print("load result:\n", msg) 85 | model.cuda() 86 | model.eval() 87 | print(f"Model = {str(model)}") 88 | 89 | barrier.wait() 90 | 91 | while True: 92 | img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get() 93 | if 'image' in modality and img_path is not None: 94 | image = Image.open(img_path).convert('RGB') 95 | inputs = T_random_resized_crop(image) 96 | elif 'video' in modality and video_path is not None: 97 | inputs = load_video(video_path) 98 | elif 'audio' in modality and audio_path is not None: 99 | inputs = load_audio(audio_path) 100 | else: 101 | inputs = None 102 | 103 | if inputs is not None: 104 | inputs = inputs[None].cuda().to(target_dtype) 105 | 106 | conv = conv_templates["v1"].copy() 107 | for user, bot in chatbot: 108 | conv.append_message(conv.roles[0], user) 109 | conv.append_message(conv.roles[1], bot) 110 | 111 | with torch.cuda.amp.autocast(dtype=target_dtype): 112 | print(conv.get_prompt()) 113 | for stream_response in model.stream_generate( 114 | conv.get_prompt(), inputs, 115 | max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, 116 | modal = modality 117 | ): 118 | conv_sep = ( 119 | conv.sep 120 | if conv.sep_style == SeparatorStyle.SINGLE 121 | else conv.sep2 122 | ) 123 | end_pos = stream_response["text"].find(conv_sep) 124 | if end_pos != -1: 125 | stream_response["text"] = ( 126 | stream_response['text'][:end_pos].rstrip() + "\n" 127 | ) 128 | stream_response["end_of_content"] = True 129 | 130 | # keep a few characters if not end_of_content to avoid sending 131 | # part of conv_sep before all of it is generated. 132 | if not stream_response["end_of_content"]: 133 | if len(stream_response["text"]) < len(conv_sep): 134 | continue 135 | stream_response["text"] = ( 136 | stream_response["text"][:-len(conv_sep)] 137 | ) 138 | 139 | if response_queue is not None: 140 | response_queue.put(stream_response) 141 | 142 | if stream_response["end_of_content"]: 143 | break 144 | 145 | 146 | def gradio_worker( 147 | request_queues: List[mp.Queue], response_queue: mp.Queue, 148 | args: argparse.Namespace, barrier: mp.Barrier, 149 | ) -> None: 150 | """ 151 | The gradio worker is responsible for displaying the WebUI and relay the 152 | requests to model workers. It should be launched only once. 153 | 154 | Args: 155 | request_queues (List[mp.Queue]): A list of request queues (one for 156 | each model worker). 157 | args (argparse.Namespace): All command line arguments. 
158 | barrier (multiprocessing.Barrier): A barrier used to delay the start 159 | of Web UI to be after the start of the model. 160 | """ 161 | 162 | def show_user_input(msg, chatbot): 163 | return "", chatbot + [[msg, None]] 164 | 165 | def stream_model_output(img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality): 166 | for queue in request_queues: 167 | queue.put((img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality)) 168 | while True: 169 | content_piece = response_queue.get() 170 | chatbot[-1][1] = content_piece["text"] 171 | yield chatbot 172 | if content_piece["end_of_content"]: 173 | break 174 | 175 | def undo(chatbot): 176 | if len(chatbot) > 0: 177 | chatbot = chatbot[:-1] 178 | return chatbot 179 | 180 | def clear(): 181 | chatbot = [] 182 | msg = "" 183 | return chatbot, msg 184 | 185 | CSS =""" 186 | .contain { display: flex; flex-direction: column; } 187 | #component-0 { height: 100%; } 188 | #chatbot { flex-grow: 1; overflow: auto;} 189 | """ 190 | with gr.Blocks(css=CSS) as demo: 191 | gr.Markdown("## OneLLM: One Framework to Align All Modalities with Language") 192 | with gr.Row(equal_height=True): 193 | with gr.Column(scale=1): 194 | img_path = gr.Image(label='Image Input', type='filepath') 195 | video_path = gr.Video(label='Video Input') 196 | audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload']) 197 | modality = gr.Radio(choices=['image', 'audio', 'video'], value='image', interactive=True, label='Input Modalities') 198 | 199 | with gr.Column(scale=2): 200 | chatbot = gr.Chatbot(elem_id="chatbot") 201 | msg = gr.Textbox() 202 | 203 | with gr.Row(): 204 | submit_button = gr.Button("Submit", variant="primary") 205 | undo_button = gr.Button("Undo") 206 | clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, modality]) 207 | with gr.Row(): 208 | max_gen_len = gr.Slider( 209 | minimum=1, maximum=args.model_max_seq_len // 2, 210 | value=args.model_max_seq_len // 2, interactive=True, 211 | label="Single-turn max response length", 212 | ) 213 | gen_t = gr.Slider( 214 | minimum=0, maximum=1, value=0.1, interactive=True, 215 | label="Temperature", 216 | ) 217 | top_p = gr.Slider( 218 | minimum=0, maximum=1, value=0.75, interactive=True, 219 | label="Top-p", 220 | ) 221 | msg.submit( 222 | show_user_input, [msg, chatbot], [msg, chatbot], 223 | ).then( 224 | stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot, 225 | ) 226 | submit_button.click( 227 | show_user_input, [msg, chatbot], [msg, chatbot], 228 | ).then( 229 | stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot, 230 | ) 231 | undo_button.click(undo, chatbot, chatbot) 232 | # img_path.change(clear, [], [chatbot, msg]) 233 | barrier.wait() 234 | demo.queue(api_open=True).launch(share=True, max_threads=1) 235 | 236 | 237 | if __name__ == "__main__": 238 | parser = argparse.ArgumentParser("Chat Demo") 239 | group = parser.add_mutually_exclusive_group() 240 | group.add_argument( 241 | "--gpu_ids", type=int, nargs="+", 242 | help="A list of space-separated gpu ids to run the model on. " 243 | "The model will span across GPUs in tensor-parallel mode." 244 | ) 245 | parser.add_argument( 246 | "--tokenizer_path", type=str, 247 | default="config/llama2/tokenizer.model", 248 | help="Path to the tokenizer.model file provided along with the LLaMA " 249 | "model." 
250 | ) 251 | parser.add_argument( 252 | "--llama_type", default="onellm", type=str, metavar="MODEL", 253 | help="LLaMA model type." 254 | ) 255 | parser.add_argument( 256 | "--llama_config", type=str, required=True, 257 | default="config/llama2/7B.json", 258 | help="Path to the llama model config json." 259 | ) 260 | parser.add_argument( 261 | "--model_max_seq_len", type=int, default=2048, 262 | help="Max sequence length accepted by the pretrained model." 263 | ) 264 | parser.add_argument( 265 | "--pretrained_path", type=str, required=True, 266 | help="Path to the llama model checkpoints. A list of checkpoints is " 267 | "supported and will be merged from left to right.") 268 | parser.add_argument( 269 | "--master_port", type=int, default=23862, 270 | help="A port used by the PyTorch distributed module to initialize." 271 | ) 272 | parser.add_argument( 273 | "--master_addr", type=str, default="127.0.0.1", 274 | help="An address used by the PyTorch distributed module to initialize." 275 | ) 276 | parser.add_argument( 277 | "--dtype", type=str, choices=["fp16", "bf16"], default="fp16", 278 | help="The dtype used for model weights and inference." 279 | ) 280 | args = parser.parse_args() 281 | 282 | # using the default "fork" method messes up some imported libs (e.g., 283 | # pandas) 284 | mp.set_start_method("spawn") 285 | 286 | # setup the queues and start the model workers 287 | request_queues = [] 288 | response_queue = mp.Queue() 289 | worker_processes = [] 290 | barrier = mp.Barrier(len(args.gpu_ids) + 1) 291 | for rank, gpu_id in enumerate(args.gpu_ids): 292 | request_queue = mp.Queue() 293 | rank_response_queue = response_queue if rank == 0 else None 294 | process = mp.Process( 295 | target=model_worker, 296 | args=(rank, args, barrier, request_queue, rank_response_queue), 297 | ) 298 | process.start() 299 | worker_processes.append(process) 300 | request_queues.append(request_queue) 301 | 302 | gradio_worker(request_queues, response_queue, args, barrier) 303 | -------------------------------------------------------------------------------- /docs/Data.md: -------------------------------------------------------------------------------- 1 | ## Data 2 | 3 | ### Data Format 4 | Here we give an overview of data format. For details, please check the data loading code: [data/pretrain_dataset.py]() and [data/fintune_dataset.py]() 5 | 6 | #### Pretraining Data 7 | All the data except IMU are organized in `.csv` format. Each `.csv` has two columns: `caption` and `url`. `\t` is used as the delimiter. For example, 8 | ``` 9 | caption url 10 | Woman receiving a foot massage at health spa Stock Photo cluster_p_ssd:s3://laion400m_mmg_ssd/29347/293477138.jpg 11 | Long injury list troubles Paul Hart as Portsmouth search for some Cup form cluster_p_ssd:s3://laion400m_mmg_ssd/43069/430692001.jpg 12 | ... ... 13 | ``` 14 | 15 | #### Instruction Tuning Data 16 | All finetuning data are converted into multi-turn conversation format. The `.json` file contains a list of training samples, where each sample contains the following keys: `id`, `image` and `conversations`. 
For example, 17 | ``` 18 | {'id': '000000033471', 'image': 'InstructionTuning/image/coco/train2017/000000033471.jpg', 'conversations': [{'from': 'human', 'value': 'What are the colors of the bus in the image?'}, {'from': 'gpt', 'value': 'The bus in the image is white and red.'}, {'from': 'human', 'value': 'What feature can be seen on the back of the bus?'}, {'from': 'gpt', 'value': 'The back of the bus features an advertisement.'}]} 19 | ``` 20 | 21 | 22 | ### Download Links 23 | 24 | | Stage | Pretraining | | Instruction Tuning | | 25 | |----------|-------------|----------|--------------------|----------| 26 | | Modality | Dataset | Download | Dataset | Download | 27 | | Image | [LAION-400M](https://laion.ai/blog/laion-400-open-dataset) | [link](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/laion400m.md) | LLaVA-mix665K | [link](https://github.com/haotian-liu/LLaVA#visual-instruction-tuning) | 28 | | | LAION-COCO | [link](https://laion.ai/blog/laion-coco) | COCO Caption | [link](https://cocodataset.org/#download) | 29 | | Video | WebVid-2.5M | [link](https://github.com/m-bain/webvid) | [MSRVTT Caption](https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/) | [link](https://www.mediafire.com/folder/h14iarbs62e7p/shared) | 30 | | | | | MSRVTT-QA | [link](https://github.com/xudejing/video-question-answering) | 31 | | | | | [Video Conversation](https://github.com/joez17/ChatBridge/blob/main/custom_datasets/valor_data/DATASET.md#download-multis) | [link](https://drive.google.com/file/d/1C7k8flfITJ1GxMwFSvEmBFGyevDZl1ke/view?usp=drive_link) | 32 | | Audio | [WavCaps](https://github.com/XinhaoMei/WavCaps) | [link](https://huggingface.co/datasets/cvssp/WavCaps) | [AudioCaps](https://audiocaps.github.io/) | [link](https://github.com/cdjkim/audiocaps) | 33 | | | | | [Audio Conversation](https://github.com/joez17/ChatBridge/blob/main/custom_datasets/valor_data/DATASET.md#download-multis) | [link](https://drive.google.com/file/d/1C7k8flfITJ1GxMwFSvEmBFGyevDZl1ke/view?usp=drive_link) | 34 | | Point | [Cap3D](https://github.com/crockwell/Cap3D) | [link](https://huggingface.co/datasets/RunsenXu/PointLLM/tree/main) | [Point Conversation](https://github.com/OpenRobotLab/PointLLM) | [link](https://huggingface.co/datasets/RunsenXu/PointLLM) | 35 | | Depth | CC3M | [link](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md) | LLaVA-150K | [link](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) | 36 | | Normal | CC3M | [link](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md) | LLaVA-150K | [link](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) | 37 | | IMU | Ego4D | [link](https://ego4d-data.org/docs/data/imu/) | Ego4D | [link](https://ego4d-data.org/docs/data/imu/) | 38 | | fMRI | [NSD](https://naturalscenesdataset.org) | [link](https://huggingface.co/datasets/pscotti/naturalscenesdataset) | [NSD](https://naturalscenesdataset.org) | [link](https://huggingface.co/datasets/pscotti/naturalscenesdataset) | 39 | 40 | **Notes** 41 | - The depth/normal map are generated from [CC3M](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md) and 50K random-subset of LLaVA-150K using a pretrained [DPT](https://github.com/EPFL-VILAB/omnidata/tree/main/omnidata_tools/torch#run-our-models-on-your-own-image). 
42 | - The [IMU data](https://ego4d-data.org/docs/data/imu/) is preprocessed with [this script](https://github.com/facebookresearch/imu2clip/blob/main/dataset/ego4d/preprocessing_scripts/extract_imu.py). 43 | 44 | 45 | ### Instruction Tuning Data 46 | 47 | **Annotation Download:** Please download the annotation from [this link](https://huggingface.co/datasets/csuhan/OneLLM_InstructionTuning) and put them under `datasets/InstructionTuning`. 48 | 49 | Then download original datasets from the above table and put them under corresponding folders. The file structure should be: 50 | 51 | ``` 52 | datasets 53 | └── InstructionTuning 54 | ├── audio 55 | │ ├── audioset2 56 | │ ├── audiocap_train.json 57 | │ ├── audiocap_val.json 58 | │ └── audio_conversation.json 59 | ├── depth_normal 60 | │ ├── depth 61 | │ ├── normal 62 | │ ├── llava_instruct_50k_depth.json 63 | │ └── llava_instruct_50k_normal.json 64 | ├── fmri 65 | │ ├── NSD 66 | │ └── fmri_fixed_train.json 67 | ├── image 68 | │ ├── coco 69 | │ ├── gqa 70 | │ ├── ocr_vqa 71 | │ ├── vg 72 | │ ├── cococap_train.json 73 | │ ├── llava_v1_5_mix665k_image.json 74 | │ └── llava_v1_5_mix665k_text.json 75 | ├── imu 76 | │ ├── ego4d 77 | │ └── imu_fixed_50k.json 78 | ├── point 79 | │ ├── pointllm/8192_npy 80 | │ └── pointllm_70k.json 81 | └── video 82 | ├── msr-vtt/MSR-VTT 83 | ├── msrvtt_cap_test.json 84 | ├── msrvtt_cap_trainval.json 85 | ├── msrvtt_vqa_test.json 86 | ├── msrvtt_vqa_train.json 87 | ├── msrvtt_vqa_val.json 88 | ├── video_complex_reasoning_10k.json 89 | ├── video_conversation_10k.json 90 | └── video_detail_10k.json 91 | ``` -------------------------------------------------------------------------------- /docs/Evaluation.md: -------------------------------------------------------------------------------- 1 | ## Evaluation 2 | 3 | **Annotation Download:** Download the annotations of evaluation datasets from: [csuhan/OneLLM_Eval](https://huggingface.co/datasets/csuhan/OneLLM_Eval), and put it under `datasets/Eval`. 4 | 5 | ### Image-Text Evaluation 6 | 7 | #### COCO Caption 8 | 9 | - Download [COCO2014 Val](http://images.cocodataset.org/zips/val2014.zip) and put it under `datasets/InstructionTuning/image/coco/val2014` 10 | - Fill `pretrained_path` in [eval/image_cap_cococap.py]() and run: `python eval/image_cap_cococap.py` 11 | - Install `https://github.com/salaniz/pycocoevalcap` 12 | - Evaluate with [eval/caption_eval.py]() 13 | 14 | #### MMVet 15 | 16 | - Download MMVet from [mm-vet.zip](https://github.com/yuweihao/MM-Vet/releases/download/v1/mm-vet.zip) and put it under `datasets/Eval/image/mm-vet` 17 | - Fill `pretrained_path` in [eval/image_bench_mmvet.py]() and run: `python eval/image_bench_mmvet.py` 18 | - Submit the result file to [Oneline Eval Server](https://huggingface.co/spaces/whyu/MM-Vet_Evaluator) 19 | 20 | ### Video-Text Evaluation 21 | 22 | #### MSVD QA 23 | - Download [MSVD](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/) video clips from [this link](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/YouTubeClips.tar) and put it under `datasets/Eval/video/MSVD/YouTubeClips` 24 | - Fill `pretrained_path` in [eval/video_qa_msvd.py]() and run: `python eval/video_qa_msvd.py`. 
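Most caption-style evaluations in this file (COCO Caption above; Clotho, Ego4D IMU and NSD below) share the same two-step flow: the `eval/*_cap_*.py` script dumps its predictions as a COCO-style results file, and `eval/caption_eval.py` then scores that file against the matching annotation JSON. A minimal sketch of the results file these scripts write, using hypothetical ids and captions:

```python
import json

# Hypothetical entries: each "image_id" must be an integer id present in the
# COCO-style annotation file selected in eval/caption_eval.py.
predictions = [
    {"image_id": 42, "caption": "a woman receiving a foot massage at a spa"},
    {"image_id": 43, "caption": "rain falling steadily on a tin roof"},
]

with open("eval/results/eval_cococap.json", "w") as f:
    json.dump(predictions, f)
```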
25 | 26 | ### Audio-Text Evaluation 27 | 28 | #### Clotho Caption 29 | - Download [Clothov2](https://zenodo.org/records/4783391) evaluation set from [this link](https://zenodo.org/records/4783391/files/clotho_audio_evaluation.7z?download=1) and put it under `datasets/Eval/audio/clothov2/evaluation` 30 | - Fill `pretrained_path` in [eval/audio_cap_clothov2.py]() and run: `python eval/audio_cap_clothov2.py`. 31 | - Evaluate with [eval/caption_eval.py](). 32 | 33 | ### Point-Text Evaluation 34 | 35 | #### PointLLM Caption 36 | - Download PointLLM data from [this link](https://huggingface.co/datasets/RunsenXu/PointLLM) 37 | - Fill `pretrained_path` in [eval/point_cap_pointllm.py]() and run: `python eval/point_cap_pointllm.py`. 38 | - Evaluate with [eval/caption_eval.py](). The annotation file is at [datasets/Eval/point/pointllm_test_cococap.json]() 39 | 40 | ### Depth/Normal-Text Evaluation 41 | 42 | TODO 43 | 44 | ### IMU-Text Evaluation 45 | 46 | #### Ego4D IMU Caption 47 | 48 | - Download Ego4D IMU data. Please refer to [docs/Data.md](). 49 | - Fill `IMU_PATH` and `pretrained_path` in [eval/imu_cap_ego4d.py]() and run: `python eval/imu_cap_ego4d.py`. 50 | - Evaluate with [eval/caption_eval.py]() 51 | 52 | ### fMRI-Text Evaluation 53 | 54 | #### NSD Caption 55 | - Download NSD data. Please refer to [docs/Data.md](). 56 | - Fill `pretrained_path` in [eval/fmri_cap_nsd.py]() and run: `python eval/fmri_cap_nsd.py`. 57 | - Evaluate with [eval/caption_eval.py]() -------------------------------------------------------------------------------- /engine_finetune.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | import contextlib 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | import util.misc as misc 9 | import util.lr_sched as lr_sched 10 | 11 | def train_one_epoch(model: torch.nn.Module, 12 | data_loader, optimizer: torch.optim.Optimizer, 13 | epoch: int, start_iter, loss_scaler, 14 | log_writer=None, 15 | args=None): 16 | model.train(True) 17 | metric_logger = misc.MetricLogger(delimiter=" ") 18 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 19 | header = 'Epoch: [{}]'.format(epoch) 20 | print_freq = 10 21 | 22 | accum_iter = args.accum_iter 23 | 24 | optimizer.zero_grad() 25 | 26 | if log_writer is not None: 27 | print('log_dir: {}'.format(log_writer.log_dir)) 28 | for data_iter_step, data_img in enumerate( 29 | metric_logger.log_every(data_loader, print_freq, header, start_iter), start=start_iter 30 | ): 31 | if len(data_img) == 4: 32 | examples, labels, image, modal = data_img 33 | elif len(data_img) == 3: 34 | examples, labels, modal = data_img 35 | image = None 36 | if data_iter_step % accum_iter == 0: 37 | # lr_sched.adjust_learning_rate(optimizer, data_iter_step, args) 38 | lr_sched.adjust_learning_rate_epoch(optimizer, data_iter_step / len(data_loader) + epoch, args) 39 | update_grad = (data_iter_step + 1) % accum_iter == 0 40 | 41 | autocast_ctx = { 42 | "bf16": torch.cuda.amp.autocast(dtype=torch.bfloat16), 43 | "fp16": torch.cuda.amp.autocast(dtype=torch.float16), 44 | "tf32": contextlib.nullcontext(), 45 | }[args.precision] 46 | backward_ctx = contextlib.nullcontext() if update_grad else model.no_sync() 47 | 48 | with autocast_ctx: 49 | i_loss = model(examples, labels, image, modal) 50 | i_loss_value = i_loss.item() 51 | if not math.isfinite(i_loss_value): 52 | print("[Rank {}] i_loss is {}, stopping training".format(dist.get_rank(), i_loss_value), force=True) 
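            # Stop the whole run when the loss turns NaN/Inf (typically an fp16 overflow
            # or a corrupted batch). Unlike engine_pretrain.py, the finetune batches carry
            # no image_paths, which is why the path debug print below is left disabled.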
53 | # print(image_paths, force=True) 54 | sys.exit(1) 55 | loss_value = i_loss_value 56 | with backward_ctx: 57 | grad_norm = loss_scaler( 58 | i_loss / accum_iter, optimizer, model, 59 | parameters=model.parameters(), 60 | update_grad=update_grad, 61 | clip_grad=None if args.clip_grad <= 0 else args.clip_grad, 62 | ) 63 | if update_grad: 64 | assert grad_norm is not None 65 | metric_logger.update(grad_norm=grad_norm) 66 | 67 | if update_grad: 68 | optimizer.zero_grad() 69 | 70 | torch.cuda.synchronize() 71 | 72 | metric_logger.update(loss=loss_value) 73 | metric_logger.update(iloss=i_loss_value) 74 | 75 | lr = optimizer.param_groups[0]["lr"] 76 | metric_logger.update(lr=lr) 77 | 78 | # save checkpoint 79 | if data_iter_step % 1000 == 0 and data_iter_step != 0: 80 | misc.save_model( 81 | output_dir=args.output_dir, 82 | args=args, epoch=epoch, iteration=data_iter_step, model=model, optimizer=optimizer, 83 | loss_scaler=loss_scaler, dataset_state=None) 84 | 85 | if update_grad: 86 | loss_value_reduce = misc.all_reduce_mean(loss_value) 87 | i_loss_value_reduce = misc.all_reduce_mean(i_loss_value) 88 | if update_grad: 89 | grad_norm_reduce = misc.all_reduce_mean(grad_norm) 90 | 91 | if log_writer is not None and update_grad: 92 | log_writer.add_scalar('train_loss', loss_value_reduce, data_iter_step) 93 | log_writer.add_scalar('i_train_loss', i_loss_value_reduce, data_iter_step) 94 | if update_grad: 95 | log_writer.add_scalar('grad_norm', grad_norm_reduce, data_iter_step) 96 | log_writer.add_scalar('lr', lr, data_iter_step) 97 | 98 | # gather the stats from all processes 99 | metric_logger.synchronize_between_processes() 100 | print("Averaged stats:", metric_logger) 101 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 102 | -------------------------------------------------------------------------------- /engine_pretrain.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | import contextlib 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | import util.misc as misc 9 | import util.lr_sched as lr_sched 10 | 11 | 12 | def train_one_epoch(model: torch.nn.Module, 13 | data_loader, optimizer: torch.optim.Optimizer, 14 | device: torch.device, epoch: int, start_iter, loss_scaler, 15 | log_writer=None, 16 | args=None): 17 | model.train(True) 18 | metric_logger = misc.MetricLogger(delimiter=" ") 19 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 20 | header = 'Epoch: [{}]'.format(epoch) 21 | print_freq = 10 22 | 23 | accum_iter = args.accum_iter 24 | 25 | optimizer.zero_grad() 26 | 27 | if log_writer is not None: 28 | print('log_dir: {}'.format(log_writer.log_dir)) 29 | for data_iter_step, data_img in enumerate( 30 | metric_logger.log_every(data_loader, print_freq, header, start_iter), start=start_iter 31 | ): 32 | examples, labels, image, image_paths, modal = data_img 33 | if data_iter_step % accum_iter == 0: 34 | lr_sched.adjust_learning_rate(optimizer, data_iter_step, args) 35 | update_grad = (data_iter_step + 1) % accum_iter == 0 36 | 37 | autocast_ctx = { 38 | "bf16": torch.cuda.amp.autocast(dtype=torch.bfloat16), 39 | "fp16": torch.cuda.amp.autocast(dtype=torch.float16), 40 | "tf32": contextlib.nullcontext(), 41 | }[args.precision] 42 | backward_ctx = contextlib.nullcontext() if update_grad else model.no_sync() 43 | 44 | with autocast_ctx: 45 | i_loss = model(examples, labels, image, modal) 46 | i_loss_value = i_loss.item() 47 | if not 
math.isfinite(i_loss_value): 48 | print("[Rank {}] i_loss is {}, stopping training".format(dist.get_rank(), i_loss_value), force=True) 49 | print(image_paths, force=True) 50 | sys.exit(1) 51 | loss_value = i_loss_value 52 | with backward_ctx: 53 | grad_norm = loss_scaler( 54 | i_loss / accum_iter, optimizer, model, 55 | parameters=model.parameters(), 56 | update_grad=update_grad, 57 | clip_grad=None if args.clip_grad <= 0 else args.clip_grad, 58 | ) 59 | if update_grad: 60 | assert grad_norm is not None 61 | metric_logger.update(grad_norm=grad_norm) 62 | 63 | if update_grad: 64 | optimizer.zero_grad() 65 | 66 | torch.cuda.synchronize() 67 | 68 | metric_logger.update(loss=loss_value) 69 | metric_logger.update(iloss=i_loss_value) 70 | 71 | lr = optimizer.param_groups[0]["lr"] 72 | metric_logger.update(lr=lr) 73 | 74 | # save checkpoint 75 | if (data_iter_step % args.save_freq == 0 and data_iter_step != 0) or data_iter_step == len(data_loader)-1: 76 | misc.save_model( 77 | output_dir=args.output_dir, 78 | args=args, epoch=epoch, iteration=data_iter_step, model=model, optimizer=optimizer, 79 | loss_scaler=loss_scaler, dataset_state=None) 80 | 81 | if update_grad: 82 | loss_value_reduce = misc.all_reduce_mean(loss_value) 83 | i_loss_value_reduce = misc.all_reduce_mean(i_loss_value) 84 | if update_grad: 85 | grad_norm_reduce = misc.all_reduce_mean(grad_norm) 86 | 87 | if log_writer is not None and update_grad: 88 | log_writer.add_scalar('train_loss', loss_value_reduce, data_iter_step) 89 | log_writer.add_scalar('i_train_loss', i_loss_value_reduce, data_iter_step) 90 | if update_grad: 91 | log_writer.add_scalar('grad_norm', grad_norm_reduce, data_iter_step) 92 | log_writer.add_scalar('lr', lr, data_iter_step) 93 | 94 | # gather the stats from all processes 95 | metric_logger.synchronize_between_processes() 96 | print("Averaged stats:", metric_logger) 97 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 98 | -------------------------------------------------------------------------------- /eval/audio_cap_clothov2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | import json 4 | import os 5 | import torch 6 | from model.meta import MetaModel 7 | import tqdm 8 | from torch.utils.data import Dataset, DataLoader 9 | import multiprocessing as mp 10 | from fairscale.nn.model_parallel import initialize as fs_init 11 | from util.misc import default_tensor_type 12 | from util.misc import setup_for_distributed 13 | import numpy as np 14 | import torch.distributed as dist 15 | from data.conversation_lib import conv_templates 16 | from data.data_utils import make_audio_features 17 | 18 | 19 | def load_audio(audio_path): 20 | fbank = make_audio_features(audio_path, mel_bins=128) 21 | fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024] 22 | return fbank 23 | 24 | 25 | class AudioTextDataset(Dataset): 26 | def __init__(self) -> None: 27 | super().__init__() 28 | audio_dir = "datasets/Eval/audio/clothov2/evaluation/" 29 | self.audio_anns = json.load(open("datasets/Eval/audio/clothov2/eval_clothocap_ann.json")) 30 | self.audio_ids = [x['id'] for x in self.audio_anns['images']] 31 | self.audio_names = [x['file_name'] for x in self.audio_anns['images']] 32 | self.audio_files = [os.path.join(audio_dir, x) for x in self.audio_names] 33 | 34 | def __len__(self): 35 | return len(self.audio_files) 36 | 37 | def __getitem__(self, index): 38 | audio_file = self.audio_files[index] 39 | return load_audio(audio_file), 
self.audio_names[index], self.audio_ids[index] 40 | 41 | if __name__ == "__main__": 42 | pretrained_path = "path/to/pretrained/ckpt/consolidated.00-of-01.pth" 43 | answer_path = "eval/results/eval_clotho.json" 44 | os.makedirs(os.path.dirname(answer_path), exist_ok=True) 45 | 46 | mp.set_start_method("spawn") 47 | dist.init_process_group( 48 | backend="nccl", rank=0, world_size=1, 49 | init_method=f"tcp://127.0.0.1:23560") 50 | fs_init.initialize_model_parallel(1) 51 | torch.cuda.set_device(0) 52 | torch.manual_seed(1) 53 | np.random.seed(1) 54 | 55 | # set the print behavior. 56 | setup_for_distributed(True) 57 | 58 | target_dtype = { 59 | "bf16": torch.bfloat16, 60 | "fp16": torch.float16 61 | }['fp16'] 62 | with default_tensor_type(dtype=target_dtype, device="cuda"): 63 | model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model") 64 | 65 | print("Loading pretrained weights ...") 66 | checkpoint = torch.load(pretrained_path, map_location='cpu') 67 | msg = model.load_state_dict(checkpoint, strict=False) 68 | print("load result:\n", msg) 69 | model.half().cuda() 70 | model.eval() 71 | print(f"Model = {str(model)}") 72 | 73 | def multi_modal_generate(images, inps, modal=['image']): 74 | images = images.cuda().to(target_dtype) 75 | prompts = [] 76 | for inp in inps: 77 | conv = conv_templates["v1"].copy() 78 | conv.append_message(conv.roles[0], inp) 79 | conv.append_message(conv.roles[1], None) 80 | prompts.append(conv.get_prompt()) 81 | 82 | with torch.cuda.amp.autocast(dtype=target_dtype): 83 | responses = model.generate(prompts, images, 128, temperature=0.1, top_p=0.75, modal=modal) 84 | outputs = [] 85 | for response, prompt in zip(responses, prompts): 86 | response = response[len(prompt):].split('###')[0] 87 | response = response.strip() 88 | outputs.append(response) 89 | return outputs 90 | 91 | dataset = AudioTextDataset() 92 | dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False) 93 | 94 | outputs = [] 95 | for data in tqdm.tqdm(dataloader): 96 | audios, audio_names, audio_ids = data 97 | prompts = ['Provide a one-sentence caption for the provided audio.'] * len(audios) 98 | results = multi_modal_generate(audios, prompts, modal=['audio']) 99 | 100 | for audio_name, audio_id, result in zip(audio_names, audio_ids, results): 101 | outputs.append({ 102 | 'image_id': audio_id.item(), 103 | 'caption': result.strip() 104 | }) 105 | 106 | with open(answer_path, 'w') as f: 107 | json.dump(outputs, f) -------------------------------------------------------------------------------- /eval/caption_eval.py: -------------------------------------------------------------------------------- 1 | from pycocotools.coco import COCO 2 | from pycocoevalcap.eval import COCOEvalCap 3 | 4 | # COCO Caption 5 | annotation_file = 'datasets/Eval/image/coco_cap/coco_karpathy_val_gt.json' 6 | results_file = 'eval/results/eval_cococap.json' 7 | 8 | # Nocaps Caption 9 | # annotation_file = 'datasets/Eval/image/nocaps/nocaps_val_4500_captions.json' 10 | # results_file = 'eval/results/eval_nocaps.json' 11 | 12 | # Clotho Caption 13 | # annotation_file = 'datasets/Eval/audio/clothov2/eval_cococap_ann.json' 14 | # results_file = 'eval/results/clotho_13B.json' 15 | 16 | # AVSD 17 | # annotation_file = 'datasets/Eval/video/AVSD/test_set4DSTC7-AVSD_cococap.json' 18 | # results_file = 'eval/results/eval_avsd.json' 19 | 20 | # VATEX 21 | # annotation_file = 'datasets/Eval/video/vatex/vatex_cococap.json' 22 | # results_file = 'eval/results/eval_vatex.json' 23 | 24 | # 
VALOR32K 25 | # annotation_file = 'datasets/Eval/video/valor32k/test_ann_cococap.json' 26 | # results_file = 'eval/results/eval_videocap_valor.json' 27 | 28 | # fMRI Caption 29 | # annotation_file = "datasets/Eval/fmri/fmri_eval_cococap.json" 30 | # results_file = "eval/results/fmricap.json" 31 | 32 | # PointLLM Caption 33 | # annotation_file = "datasets/Eval/point/pointllm/pointllm_test_cococap.json" 34 | # results_file = "eval/results/eval_pointllm_cap.json" 35 | 36 | # IMU Caption 37 | # annotation_file = 'datasets/Eval/imu/imu_2000_cococap.json' 38 | # results_file = 'eval/results/imucap.json' 39 | 40 | # create coco object and coco_result object 41 | coco = COCO(annotation_file) 42 | coco_result = coco.loadRes(results_file) 43 | 44 | # create coco_eval object by taking coco and coco_result 45 | coco_eval = COCOEvalCap(coco, coco_result) 46 | 47 | # evaluate on a subset of images by setting 48 | # coco_eval.params['image_id'] = coco_result.getImgIds() 49 | # please remove this line when evaluating the full validation set 50 | coco_eval.params['image_id'] = coco_result.getImgIds() 51 | 52 | # evaluate results 53 | # SPICE will take a few minutes the first time, but speeds up due to caching 54 | coco_eval.evaluate() 55 | 56 | # print output evaluation scores 57 | for metric, score in coco_eval.eval.items(): 58 | print(f'{metric}: {score:.3f}') 59 | -------------------------------------------------------------------------------- /eval/fmri_cap_nsd.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | import json 4 | import os 5 | import torch 6 | from model.meta import MetaModel 7 | import tqdm 8 | from torch.utils.data import Dataset, DataLoader 9 | import multiprocessing as mp 10 | from fairscale.nn.model_parallel import initialize as fs_init 11 | from util.misc import default_tensor_type 12 | from util.misc import setup_for_distributed 13 | import numpy as np 14 | import torch.distributed as dist 15 | from data.conversation_lib import conv_templates 16 | 17 | 18 | def load_fmri(fmri_path): 19 | data = np.load(fmri_path) 20 | data = data.mean(axis=0) 21 | data = torch.tensor(data[None]) 22 | return data 23 | 24 | class CaptionDataset(Dataset): 25 | def __init__(self) -> None: 26 | super().__init__() 27 | self.fmri_anns = json.load(open("datasets/Eval/fmri/fmri_eval_cococap.json")) 28 | self.fmri_ids = [x['id'] for x in self.fmri_anns['images']] 29 | self.fmri_names = [x['file_name'] for x in self.fmri_anns['images']] 30 | self.fmri_files = self.fmri_names 31 | 32 | def __len__(self): 33 | return len(self.fmri_files) 34 | 35 | def __getitem__(self, index): 36 | fmri_file = self.fmri_files[index] 37 | return load_fmri(fmri_file), self.fmri_names[index], self.fmri_ids[index] 38 | 39 | 40 | if __name__ == "__main__": 41 | pretrained_path = "path/to/pretrained/ckpt/consolidated.00-of-01.pth" 42 | answer_path = "eval/results/eval_fmricap.json" 43 | os.makedirs(os.path.dirname(answer_path), exist_ok=True) 44 | 45 | mp.set_start_method("spawn") 46 | dist.init_process_group( 47 | backend="nccl", rank=0, world_size=1, 48 | init_method=f"tcp://127.0.0.1:23560") 49 | fs_init.initialize_model_parallel(1) 50 | torch.cuda.set_device(0) 51 | torch.manual_seed(1) 52 | np.random.seed(1) 53 | # set the print behavior. 
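    # This eval runs as a single process; torch.distributed is initialized with
    # world_size=1 above only because fairscale model-parallel initialization
    # requires an existing process group. setup_for_distributed(True) keeps
    # printing enabled for this lone master process.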
54 | setup_for_distributed(True) 55 | 56 | target_dtype = { 57 | "bf16": torch.bfloat16, 58 | "fp16": torch.float16 59 | }['fp16'] 60 | with default_tensor_type(dtype=target_dtype, device="cuda"): 61 | model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model") 62 | 63 | print("Loading pretrained weights ...") 64 | checkpoint = torch.load(pretrained_path, map_location='cpu') 65 | msg = model.load_state_dict(checkpoint, strict=False) 66 | print("load result:\n", msg) 67 | model.half().cuda() 68 | model.eval() 69 | print(f"Model = {str(model)}") 70 | 71 | def multi_modal_generate(images, inps, modal=['image']): 72 | images = images.cuda().to(target_dtype) 73 | 74 | prompts = [] 75 | for inp in inps: 76 | conv = conv_templates["v1"].copy() 77 | conv.append_message(conv.roles[0], inp) 78 | conv.append_message(conv.roles[1], None) 79 | prompts.append(conv.get_prompt()) 80 | 81 | with torch.cuda.amp.autocast(dtype=target_dtype): 82 | responses = model.generate(prompts, images, 128, temperature=0.1, top_p=0.75, modal=modal) 83 | outputs = [] 84 | for response, prompt in zip(responses, prompts): 85 | response = response[len(prompt):].split('###')[0] 86 | response = response.strip() 87 | outputs.append(response) 88 | return outputs 89 | 90 | dataset = CaptionDataset() 91 | dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False) 92 | 93 | outputs = [] 94 | for data in tqdm.tqdm(dataloader): 95 | fmris, fmri_names, fmri_ids = data 96 | prompts = ['Provide a one-sentence caption for the provided fMRI data.'] * len(fmris) 97 | 98 | results = multi_modal_generate(fmris, prompts, modal=['fmri']) 99 | 100 | for fmri_name, fmri_id, result in zip(fmri_names, fmri_ids, results): 101 | outputs.append({ 102 | 'image_id': fmri_id.item(), 103 | 'caption': result.strip() 104 | }) 105 | print(fmri_name, fmri_id, result.strip()) 106 | print('='*10) 107 | 108 | with open(answer_path, 'w') as f: 109 | json.dump(outputs, f) 110 | -------------------------------------------------------------------------------- /eval/image_bench_mmvet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | import os 4 | import json 5 | import numpy as np 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import torch 9 | import torch.distributed as dist 10 | import multiprocessing as mp 11 | from fairscale.nn.model_parallel import initialize as fs_init 12 | from util.misc import default_tensor_type 13 | from util.misc import setup_for_distributed 14 | import torchvision.transforms as transforms 15 | from model.meta import MetaModel 16 | from data.conversation_lib import conv_templates 17 | 18 | 19 | T_resized_center_crop = transforms.Compose([ 20 | transforms.Resize( 21 | 336, interpolation=transforms.InterpolationMode.BICUBIC 22 | ), 23 | transforms.CenterCrop(336), 24 | transforms.ToTensor(), 25 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])]) 26 | 27 | 28 | if __name__ == "__main__": 29 | pretrained_path = "path/to/pretrained/ckpt/consolidated.00-of-01.pth" 30 | answer_path = "eval/results/eval_mmvet.json" 31 | os.makedirs(os.path.dirname(answer_path), exist_ok=True) 32 | 33 | mp.set_start_method("spawn") 34 | dist.init_process_group( 35 | backend="nccl", rank=0, world_size=1, 36 | init_method=f"tcp://127.0.0.1:23563") 37 | fs_init.initialize_model_parallel(1) 38 | torch.cuda.set_device(0) 39 | torch.manual_seed(1) 40 | np.random.seed(1) 41 | 
# set the print behavior. 42 | setup_for_distributed(True) 43 | 44 | target_dtype = { 45 | "bf16": torch.bfloat16, 46 | "fp16": torch.float16 47 | }['fp16'] 48 | with default_tensor_type(dtype=target_dtype, device="cuda"): 49 | model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model") 50 | 51 | print("Loading pretrained weights ...") 52 | checkpoint = torch.load(pretrained_path, map_location='cpu') 53 | msg = model.load_state_dict(checkpoint, strict=False) 54 | print("load result:\n", msg) 55 | model.half().cuda() 56 | model.eval() 57 | print(f"Model = {str(model)}") 58 | 59 | def multi_modal_generate(img_path, inp): 60 | 61 | conv = conv_templates["v1"].copy() 62 | if img_path is not None: 63 | image = Image.open(img_path).convert('RGB') 64 | image = T_resized_center_crop(image).unsqueeze(0).cuda().to(target_dtype) 65 | else: 66 | image = None 67 | 68 | conv.append_message(conv.roles[0], inp) 69 | conv.append_message(conv.roles[1], None) 70 | 71 | with torch.cuda.amp.autocast(dtype=target_dtype): 72 | response = model.generate([conv.get_prompt()], image, 256, temperature=0.1, top_p=0.75, modal=['image']) 73 | response = response[0] 74 | response = response[len(conv.get_prompt()):].split('###')[0] 75 | conv.messages[-1][-1] = response 76 | return response.strip() 77 | 78 | result = {} 79 | batch_size = 1 80 | print("Starting...") 81 | datas = json.load(open('datasets/Eval/image/mm-vet/mm-vet.json')) 82 | predictions = {} 83 | with torch.no_grad(): 84 | for image_name, data in tqdm(datas.items()): 85 | image_path = os.path.join('datasets/Eval/image/mm-vet/images', data['imagename']) 86 | pred = multi_modal_generate(image_path, data['question']) 87 | predictions[image_name]=pred 88 | 89 | with open(answer_path, 'w') as f: 90 | json.dump(predictions, f) -------------------------------------------------------------------------------- /eval/image_cap_cococap.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | import os 4 | import json 5 | import numpy as np 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import torch 9 | import torch.distributed as dist 10 | from torch.utils.data import Dataset, DataLoader 11 | import multiprocessing as mp 12 | from fairscale.nn.model_parallel import initialize as fs_init 13 | from util.misc import default_tensor_type 14 | from util.misc import setup_for_distributed 15 | import torchvision.transforms as transforms 16 | from model.meta import MetaModel 17 | from data.conversation_lib import conv_templates 18 | 19 | 20 | T_resized_center_crop = transforms.Compose([ 21 | transforms.Resize( 22 | 224, interpolation=transforms.InterpolationMode.BICUBIC 23 | ), 24 | transforms.CenterCrop(224), 25 | transforms.ToTensor(), 26 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])]) 27 | 28 | 29 | class CocoCapDataset(Dataset): 30 | def __init__(self) -> None: 31 | super().__init__() 32 | self.datas = json.load(open('datasets/Eval/image/coco_cap/coco_karpathy_val.json')) 33 | 34 | def __len__(self): 35 | return len(self.datas) 36 | 37 | def __getitem__(self, index): 38 | data = self.datas[index] 39 | image_path = os.path.join("datasets/InstructionTuning/image/coco/", data['image']) 40 | image = Image.open(image_path).convert('RGB') 41 | image = T_resized_center_crop(image) 42 | image_id = int(data['image'].split('_')[-1].split('.')[0]) 43 | question = 'Provide a one-sentence caption for the provided 
image.' 44 | return image, question, image_id 45 | 46 | 47 | if __name__ == "__main__": 48 | pretrained_path = "path/to/pretrained/ckpt/consolidated.00-of-01.pth" 49 | answer_path = "eval/results/eval_cococap.json" 50 | os.makedirs(os.path.dirname(answer_path), exist_ok=True) 51 | 52 | mp.set_start_method("spawn") 53 | dist.init_process_group( 54 | backend="nccl", rank=0, world_size=1, 55 | init_method=f"tcp://127.0.0.1:23560") 56 | fs_init.initialize_model_parallel(1) 57 | torch.cuda.set_device(0) 58 | torch.manual_seed(1) 59 | np.random.seed(1) 60 | # set the print behavior. 61 | setup_for_distributed(True) 62 | 63 | target_dtype = { 64 | "bf16": torch.bfloat16, 65 | "fp16": torch.float16 66 | }['fp16'] 67 | with default_tensor_type(dtype=target_dtype, device="cuda"): 68 | model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model") 69 | 70 | print("Loading pretrained weights ...") 71 | checkpoint = torch.load(pretrained_path, map_location='cpu') 72 | msg = model.load_state_dict(checkpoint, strict=False) 73 | print("load result:\n", msg) 74 | model.half().cuda() 75 | model.eval() 76 | print(f"Model = {str(model)}") 77 | 78 | def multi_modal_generate(images, inps): 79 | images = images.cuda().to(target_dtype) 80 | 81 | prompts = [] 82 | for inp in inps: 83 | conv = conv_templates["v1"].copy() 84 | conv.append_message(conv.roles[0], inp) 85 | conv.append_message(conv.roles[1], None) 86 | prompts.append(conv.get_prompt()) 87 | 88 | with torch.cuda.amp.autocast(dtype=target_dtype): 89 | responses = model.generate(prompts, images, 128, temperature=0.1, top_p=0.75, modal=['image']) 90 | outputs = [] 91 | for response, prompt in zip(responses, prompts): 92 | response = response[len(prompt):].split('###')[0] 93 | response = response.strip() 94 | outputs.append(response) 95 | return outputs 96 | 97 | print("Starting...") 98 | dataset = CocoCapDataset() 99 | dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False) 100 | 101 | predictions = [] 102 | with torch.no_grad(): 103 | for data in tqdm(dataloader): 104 | images, questions, image_ids = data 105 | preds = multi_modal_generate(images, questions) 106 | for question, pred, image_id in zip(questions, preds, image_ids): 107 | predictions.append({'image_id': image_id.item(), 'caption': pred}) 108 | 109 | with open(answer_path, 'w') as f: 110 | json.dump(predictions, f) -------------------------------------------------------------------------------- /eval/imu_cap_ego4d.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | import json 4 | import os 5 | import torch 6 | from model.meta import MetaModel 7 | import tqdm 8 | from torch.utils.data import Dataset, DataLoader 9 | import multiprocessing as mp 10 | from fairscale.nn.model_parallel import initialize as fs_init 11 | from util.misc import default_tensor_type 12 | from util.misc import setup_for_distributed 13 | import numpy as np 14 | import torch.distributed as dist 15 | from data.conversation_lib import conv_templates 16 | from data.imu_utils import get_imu_frames 17 | 18 | IMU_PATH="FILL/IMU/PATH/HERE" 19 | 20 | def load_imu(data_dict): 21 | uid = data_dict["video_uid"] 22 | w_s = data_dict["window_start"] 23 | w_e = data_dict["window_end"] 24 | 25 | imu_data = get_imu_frames( 26 | IMU_PATH, uid, 27 | video_start_sec=w_s, 28 | video_end_sec=w_e, 29 | ) 30 | if imu_data is None: 31 | raise ValueError 32 | return imu_data['signal'] 33 | 34 | 35 | class 
CaptionDataset(Dataset): 36 | def __init__(self) -> None: 37 | super().__init__() 38 | self.imu_anns = json.load(open("datasets/Eval/imu/Ego4D/imu_2000_cococap.json")) 39 | self.imu_ids = [x['id'] for x in self.imu_anns['images']] 40 | self.imu_names = [x['file_name'] for x in self.imu_anns['images']] 41 | self.imu_files = self.imu_names 42 | 43 | def __len__(self): 44 | return len(self.imu_files) 45 | 46 | def __getitem__(self, index): 47 | return load_imu(self.imu_anns['images'][index]), self.imu_names[index], self.imu_ids[index] 48 | 49 | if __name__ == "__main__": 50 | pretrained_path = "path/to/pretrained/ckpt/consolidated.00-of-01.pth" 51 | answer_path = "eval/results/eval_imucap.json" 52 | os.makedirs(os.path.dirname(answer_path), exist_ok=True) 53 | 54 | mp.set_start_method("spawn") 55 | dist.init_process_group( 56 | backend="nccl", rank=0, world_size=1, 57 | init_method=f"tcp://127.0.0.1:23560") 58 | fs_init.initialize_model_parallel(1) 59 | torch.cuda.set_device(0) 60 | torch.manual_seed(1) 61 | np.random.seed(1) 62 | # set the print behavior. 63 | setup_for_distributed(True) 64 | 65 | target_dtype = { 66 | "bf16": torch.bfloat16, 67 | "fp16": torch.float16 68 | }['fp16'] 69 | with default_tensor_type(dtype=target_dtype, device="cuda"): 70 | model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model") 71 | 72 | print("Loading pretrained weights ...") 73 | checkpoint = torch.load(pretrained_path, map_location='cpu') 74 | msg = model.load_state_dict(checkpoint, strict=False) 75 | print("load result:\n", msg) 76 | model.half().cuda() 77 | model.eval() 78 | print(f"Model = {str(model)}") 79 | 80 | def multi_modal_generate(images, inps, modal=['image']): 81 | images = images.cuda().to(target_dtype) 82 | 83 | prompts = [] 84 | for inp in inps: 85 | conv = conv_templates["v1"].copy() 86 | conv.append_message(conv.roles[0], inp) 87 | conv.append_message(conv.roles[1], None) 88 | prompts.append(conv.get_prompt()) 89 | 90 | with torch.cuda.amp.autocast(dtype=target_dtype): 91 | responses = model.generate(prompts, images, 128, temperature=0.1, top_p=0.75, modal=modal) 92 | outputs = [] 93 | for response, prompt in zip(responses, prompts): 94 | response = response[len(prompt):].split('###')[0] 95 | response = response.strip() 96 | outputs.append(response) 97 | return outputs 98 | 99 | dataset = CaptionDataset() 100 | dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False) 101 | 102 | outputs = [] 103 | for data in tqdm.tqdm(dataloader): 104 | imus, imu_names, imu_ids = data 105 | prompts = ['Describe the scene.'] * len(imus) 106 | 107 | results = multi_modal_generate(imus, prompts, modal=['imu']) 108 | 109 | for imu_name, imu_id, result in zip(imu_names, imu_ids, results): 110 | outputs.append({ 111 | 'image_id': imu_id.item(), 112 | 'caption': result.strip() 113 | }) 114 | print(imu_name, imu_id, result.strip()) 115 | print('='*10) 116 | 117 | with open(answer_path, 'w') as f: 118 | json.dump(outputs, f) -------------------------------------------------------------------------------- /eval/point_cap_pointllm.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import torch 7 | import torch.distributed as dist 8 | from torch.utils.data import Dataset, DataLoader 9 | import multiprocessing as mp 10 | from fairscale.nn.model_parallel import initialize as fs_init 11 | from util.misc import default_tensor_type 
12 | from util.misc import setup_for_distributed 13 | import numpy as np 14 | from model.meta import MetaModel 15 | from data.conversation_lib import conv_templates 16 | from data.data_utils import pc_norm 17 | 18 | 19 | class CaptionDataset(Dataset): 20 | def __init__(self) -> None: 21 | super().__init__() 22 | self.anns = json.load(open('datasets/Eval/point/pointllm_test.json')) 23 | self.ids = list(self.anns.keys()) 24 | 25 | def __len__(self): 26 | return len(self.anns) 27 | 28 | def __getitem__(self, index): 29 | id = self.ids[index] 30 | caption = self.anns[id] 31 | 32 | file_path = f'datasets/Eval/point/pointllm/8192_npy/{id}_8192.npy' 33 | 34 | point_feat = np.load(file_path) 35 | point_feat = torch.tensor(point_feat) 36 | point_feat = pc_norm(point_feat) 37 | 38 | question = 'What is this?' 39 | answer = caption 40 | return point_feat, question, id, answer 41 | 42 | 43 | if __name__ == "__main__": 44 | pretrained_path = "path/to/pretrained/ckpt/consolidated.00-of-01.pth" 45 | answer_path = "eval/results/eval_pointllm_cap.json" 46 | os.makedirs(os.path.dirname(answer_path), exist_ok=True) 47 | 48 | mp.set_start_method("spawn") 49 | dist.init_process_group( 50 | backend="nccl", rank=0, world_size=1, 51 | init_method=f"tcp://127.0.0.1:23581") 52 | fs_init.initialize_model_parallel(1) 53 | torch.cuda.set_device(0) 54 | torch.manual_seed(1) 55 | np.random.seed(1) 56 | # set the print behavior. 57 | setup_for_distributed(True) 58 | 59 | target_dtype = { 60 | "bf16": torch.bfloat16, 61 | "fp16": torch.float16 62 | }['fp16'] 63 | with default_tensor_type(dtype=target_dtype, device="cuda"): 64 | model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model") 65 | 66 | print("Loading pretrained weights ...") 67 | checkpoint = torch.load(pretrained_path, map_location='cpu') 68 | msg = model.load_state_dict(checkpoint, strict=False) 69 | print("load result:\n", msg) 70 | model.half().cuda() 71 | model.eval() 72 | print(f"Model = {str(model)}") 73 | 74 | def multi_modal_generate(images, inps): 75 | images = images.cuda().to(target_dtype) 76 | prompts = [] 77 | for inp in inps: 78 | conv = conv_templates["v1"].copy() 79 | conv.append_message(conv.roles[0], inp) 80 | conv.append_message(conv.roles[1], None) 81 | prompts.append(conv.get_prompt()) 82 | with torch.cuda.amp.autocast(dtype=target_dtype): 83 | responses = model.generate(prompts, images, 128, temperature=0.1, top_p=0.75, modal=['point']) 84 | outputs = [] 85 | for response, prompt in zip(responses, prompts): 86 | response = response[len(prompt):].split('###')[0] 87 | response = response.strip() 88 | outputs.append(response) 89 | return outputs 90 | 91 | print("Starting...") 92 | dataset = CaptionDataset() 93 | dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False) 94 | 95 | predictions = [] 96 | with torch.no_grad(): 97 | for data in tqdm(dataloader): 98 | images, questions, ids, answers = data 99 | preds = multi_modal_generate(images, questions) 100 | 101 | for question, pred, id, answer in zip(questions, preds, ids, answers): 102 | predictions.append({ 103 | 'object_id': id, 104 | 'model_output': pred, 105 | 'ground_truth': answer 106 | }) 107 | 108 | with open(answer_path, 'w') as f: 109 | json.dump(predictions, f) 110 | -------------------------------------------------------------------------------- /eval/video_qa_msvd.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | import os 4 | import json 5 | 
from tqdm import tqdm 6 | import torch 7 | import torch.distributed as dist 8 | from torch.utils.data import Dataset, DataLoader 9 | import multiprocessing as mp 10 | from fairscale.nn.model_parallel import initialize as fs_init 11 | from util.misc import default_tensor_type 12 | from util.misc import setup_for_distributed 13 | import numpy as np 14 | from model.meta import MetaModel 15 | from data.conversation_lib import conv_templates 16 | from data import video_utils 17 | 18 | 19 | def load_video(video_path): 20 | video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5) 21 | return video_feats[:, :, 0] 22 | 23 | 24 | class CaptionDataset(Dataset): 25 | def __init__(self) -> None: 26 | super().__init__() 27 | self.datas = json.load(open('datasets/Eval/video/MSVD/MSVD-QA/test_qa.json')) 28 | map_ids =[x.strip().split(' ') for x in open('datasets/Eval/video/MSVD/MSVD-QA/youtube_mapping.txt').readlines()] 29 | self.id_to_video_ids = {x[1]:x[0] for x in map_ids} 30 | 31 | def __len__(self): 32 | return len(self.datas) 33 | 34 | def __getitem__(self, index): 35 | data = self.datas[index] 36 | video_id = 'vid'+str(data['video_id']) 37 | video_name = self.id_to_video_ids[video_id] + '.avi' 38 | image_path = os.path.join("datasets/Eval/video/MSVD/YouTubeClips", video_name) 39 | image = load_video(image_path) 40 | question_id = data['id'] 41 | question = data['question'] + '\nAnswer the question using a single word or phrase.' 42 | answer = data['answer'] 43 | return image, question, question_id, answer 44 | 45 | 46 | if __name__ == "__main__": 47 | pretrained_path = "path/to/pretrained/ckpt/consolidated.00-of-01.pth" 48 | answer_path = "eval/results/eval_msvd.json" 49 | os.makedirs(os.path.dirname(answer_path), exist_ok=True) 50 | 51 | mp.set_start_method("spawn") 52 | dist.init_process_group( 53 | backend="nccl", rank=0, world_size=1, 54 | init_method=f"tcp://127.0.0.1:23563") 55 | fs_init.initialize_model_parallel(1) 56 | torch.cuda.set_device(0) 57 | torch.manual_seed(1) 58 | np.random.seed(1) 59 | # set the print behavior. 
60 | setup_for_distributed(True) 61 | 62 | target_dtype = { 63 | "bf16": torch.bfloat16, 64 | "fp16": torch.float16 65 | }['fp16'] 66 | with default_tensor_type(dtype=target_dtype, device="cuda"): 67 | model = MetaModel("onellm", "config/llama2/7B.json", None, "config/llama2/tokenizer.model") 68 | 69 | print("Loading pretrained weights ...") 70 | checkpoint = torch.load(pretrained_path, map_location='cpu') 71 | msg = model.load_state_dict(checkpoint, strict=False) 72 | print("load result:\n", msg) 73 | model.half().cuda() 74 | model.eval() 75 | print(f"Model = {str(model)}") 76 | 77 | def multi_modal_generate(images, inps): 78 | images = images.cuda().to(target_dtype) 79 | 80 | prompts = [] 81 | for inp in inps: 82 | conv = conv_templates["v1"].copy() 83 | conv.append_message(conv.roles[0], inp) 84 | conv.append_message(conv.roles[1], None) 85 | prompts.append(conv.get_prompt()) 86 | 87 | with torch.cuda.amp.autocast(dtype=target_dtype): 88 | responses = model.generate(prompts, images, 128, temperature=0.1, top_p=0.75, modal=['video']) 89 | outputs = [] 90 | for response, prompt in zip(responses, prompts): 91 | response = response[len(prompt):].split('###')[0] 92 | response = response.strip() 93 | outputs.append(response) 94 | return outputs 95 | 96 | result = {} 97 | print("Starting...") 98 | dataset = CaptionDataset() 99 | dataloader = DataLoader(dataset, batch_size=32, shuffle=False, drop_last=False) 100 | predictions = [] 101 | correct = 0 102 | with torch.no_grad(): 103 | for data in tqdm(dataloader): 104 | images, questions, question_ids, answers = data 105 | preds = multi_modal_generate(images, questions) 106 | for question, pred, question_id, answer in zip(questions, preds, question_ids, answers): 107 | predictions.append({'question_id': question_id.item(), 'answer': pred, 'gt_answer': answer}) 108 | pred = pred.strip().lower() 109 | answer = answer.strip().lower() 110 | if (pred in answer) or (answer in pred): 111 | correct += 1 112 | 113 | acc = float(correct) / len(dataset) 114 | print('Accuracy:', acc) 115 | 116 | with open(answer_path, 'w') as f: 117 | json.dump(predictions, f) -------------------------------------------------------------------------------- /exps/image_text_pretrain_8gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | LLAMA_7B_PATH="" 4 | OUTPUT_DIR="" 5 | 6 | torchrun --nproc_per_node=8 main_pretrain.py \ 7 | --epochs 1 --dataset image \ 8 | --batch_size 40 --accum_iter 16 \ 9 | --model_parallel_size 1 \ 10 | --data_parallel sdp \ 11 | --save_consolidated \ 12 | --llama_type onellm \ 13 | --llama_ckpt_dir ${LLAMA_7B_PATH} \ 14 | --llama_config config/llama2/7B.json \ 15 | --tokenizer_path config/llama2/tokenizer.model \ 16 | --auto_resume \ 17 | --weight_decay 0.1 --output_dir ${OUTPUT_DIR} \ 18 | --warmup_iters 2000 --lr_decay_iters 400000 --lr 5e-5 --min_lr 5e-6 --clip_grad 2 \ 19 | --save_freq 1000 \ 20 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log -------------------------------------------------------------------------------- /exps/image_text_pretrain_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p {Partition Name} 3 | #SBATCH --gres=gpu:8 4 | #SBATCH -n 16 5 | #SBATCH -N 2 6 | #SBATCH --cpus-per-task=16 7 | 8 | srun python -u main_pretrain.py \ 9 | --epochs 1 --dataset image \ 10 | --batch_size 40 --accum_iter 8 \ 11 | --model_parallel_size 1 \ 12 | --data_parallel sdp \ 13 | --save_consolidated \ 14 | --llama_type onellm \ 
15 | --llama_ckpt_dir ${LLAMA_7B_PATH} \ 16 | --llama_config config/llama2/7B.json \ 17 | --tokenizer_path config/llama2/tokenizer.model \ 18 | --auto_resume \ 19 | --weight_decay 0.1 --output_dir ${OUTPUT_DIR} \ 20 | --warmup_iters 2000 --lr_decay_iters 200000 --lr 5e-5 --min_lr 5e-6 --clip_grad 2 \ 21 | --save_freq 1000 \ 22 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log -------------------------------------------------------------------------------- /exps/multimodal_text_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | STAGE3_MODEL="" 4 | OUTPUT_DIR="" 5 | 6 | torchrun --nproc_per_node=8 main_finetune.py \ 7 | --epochs 1 --warmup_epochs 0.05 \ 8 | --datasets image audio video point rgbd rgbn imu fmri \ 9 | --max_words 2048 --batch_size 4 --accum_iter 4 \ 10 | --model_parallel_size 1 \ 11 | --data_parallel sdp \ 12 | --checkpointing --save_consolidated \ 13 | --llama_type onellm \ 14 | --init_from ${STAGE3_MODEL} \ 15 | --auto_resume \ 16 | --weight_decay 0.0 --output_dir ${OUTPUT_DIR} \ 17 | --lr 2e-5 --min_lr 0.0 --clip_grad 2 \ 18 | --save_interval 1 \ 19 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log 20 | -------------------------------------------------------------------------------- /exps/multimodal_text_pretrain_stage2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | LLAMA_7B_PATH="" 4 | OUTPUT_DIR="" 5 | IMAGE_TEXT_MODEL="" 6 | 7 | torchrun --nproc_per_node=8 main_pretrain.py \ 8 | --epochs 1 --dataset image audio point video \ 9 | --batch_size 40 --accum_iter 16 \ 10 | --model_parallel_size 1 \ 11 | --data_parallel sdp \ 12 | --save_consolidated \ 13 | --llama_type onellm \ 14 | --llama_ckpt_dir ${LLAMA_7B_PATH} \ 15 | --llama_config config/llama2/7B.json \ 16 | --tokenizer_path config/llama2/tokenizer.model \ 17 | --init_from ${IMAGE_TEXT_MODEL} \ 18 | --init_from_image \ 19 | --auto_resume \ 20 | --weight_decay 0.1 --output_dir ${OUTPUT_DIR} \ 21 | --warmup_iters 2000 --lr_decay_iters 400000 --lr 1e-5 --min_lr 5e-6 --clip_grad 2 \ 22 | --save_freq 1000 \ 23 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log -------------------------------------------------------------------------------- /exps/multimodal_text_pretrain_stage3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | LLAMA_7B_PATH="" 4 | OUTPUT_DIR="" 5 | STAGE2_MODEL="" 6 | 7 | torchrun --nproc_per_node=8 main_pretrain.py \ 8 | --epochs 1 --dataset image audio point video rgbd rgbn fmri imu \ 9 | --batch_size 40 --accum_iter 16 \ 10 | --model_parallel_size 1 \ 11 | --data_parallel sdp \ 12 | --save_consolidated \ 13 | --llama_type onellm \ 14 | --llama_ckpt_dir ${LLAMA_7B_PATH} \ 15 | --llama_config config/llama2/7B.json \ 16 | --tokenizer_path config/llama2/tokenizer.model \ 17 | --init_from ${STAGE2_MODEL} \ 18 | --init_from_image \ 19 | --auto_resume \ 20 | --weight_decay 0.1 --output_dir ${OUTPUT_DIR} \ 21 | --warmup_iters 2000 --lr_decay_iters 200000 --lr 1e-5 --min_lr 5e-6 --clip_grad 2 \ 22 | --save_freq 1000 \ 23 | 2>&1 | tee -a ${OUTPUT_DIR}/output.log -------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import numpy as np 5 | import os 6 | import time 7 | from pathlib import Path 8 | import functools 9 | import multiprocessing 10 | 11 | import torch 12 | import 
torch.backends.cudnn as cudnn 13 | from torch.utils.tensorboard import SummaryWriter 14 | from torch.distributed.fsdp import ( 15 | FullyShardedDataParallel as FSDP, 16 | MixedPrecision, 17 | ShardingStrategy, 18 | ) 19 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 20 | checkpoint_wrapper, 21 | CheckpointImpl, 22 | apply_activation_checkpointing, 23 | ) 24 | from torch.distributed.fsdp.wrap import ( 25 | transformer_auto_wrap_policy, 26 | ) 27 | 28 | from fairscale.nn.model_parallel import initialize as fs_init 29 | import warnings 30 | try: 31 | from apex.optimizers import FusedAdam as AdamW 32 | except ImportError: 33 | warnings.warn("cannot import FusedAdam from apex, use torch AdamW instead") 34 | from torch.optim import AdamW 35 | 36 | import util.misc as misc 37 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 38 | from model.meta import MetaModel 39 | from engine_finetune import train_one_epoch 40 | 41 | 42 | warnings.filterwarnings("ignore") 43 | 44 | from data.finetune_dataset import FinetuneDialogDataset, FinetuneDistSampler 45 | 46 | 47 | def get_args_parser(): 48 | parser = argparse.ArgumentParser('OneLLM Finetuning', add_help=False) 49 | parser.add_argument('--datasets', type=str, default='image', nargs='+') 50 | parser.add_argument('--epochs', default=1, type=int) 51 | parser.add_argument('--batch_size', default=64, type=int, 52 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)') 53 | parser.add_argument('--accum_iter', default=4, type=int, 54 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 55 | 56 | # Model parameters 57 | parser.add_argument('--llama_type', default='llama', type=str, metavar='MODEL', 58 | help='Name of model to train') 59 | parser.add_argument("--llama_ckpt_dir", type=str, default="") 60 | parser.add_argument("--llama_config", type=str, default="config/llama2/7B.json") 61 | parser.add_argument("--tokenizer_path", type=str, default="config/llama2/tokenizer.model") 62 | 63 | # Optimizer parameters 64 | parser.add_argument('--weight_decay', type=float, default=0.02, 65 | help='weight decay (default: 0.02)') 66 | 67 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', 68 | help='learning rate (absolute lr)') 69 | parser.add_argument('--min_lr', type=float, default=0.0001, metavar='LR', 70 | help='lower lr bound for cyclic schedulers that hit 0') 71 | 72 | parser.add_argument('--warmup_epochs', type=float, default=1.0, metavar='N', 73 | help='epoch to warmup LR') 74 | 75 | parser.add_argument('--clip_grad', type=int, default=-1, 76 | help='grad clipping norm') 77 | 78 | parser.add_argument('--output_dir', default='./output_dir', 79 | help='path where to save, empty for no saving') 80 | parser.add_argument('--log_dir', default='./output_dir', 81 | help='path where to tensorboard log') 82 | parser.add_argument('--device', default='cuda', 83 | help='device to use for training / testing') 84 | parser.add_argument('--seed', default=0, type=int) 85 | parser.add_argument('--resume', default='', 86 | help='resume from checkpoint') 87 | parser.add_argument('--auto_resume', action='store_true') 88 | parser.add_argument('--init_from', default='', 89 | help='init from checkpoint') 90 | parser.add_argument('--init_from_image', action='store_true') 91 | 92 | parser.add_argument('--num_workers', default=5, type=int) 93 | parser.add_argument('--pin_mem', action='store_true', 94 | help='Pin CPU memory in DataLoader 
for more efficient (sometimes) transfer to GPU.') 95 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 96 | parser.set_defaults(pin_mem=True) 97 | 98 | # distributed training parameters 99 | parser.add_argument('--world_size', default=1, type=int, 100 | help='number of distributed processes') 101 | parser.add_argument('--local_rank', default=-1, type=int) 102 | parser.add_argument('--dist_on_itp', action='store_true') 103 | parser.add_argument('--dist_url', default='env://', 104 | help='url used to set up distributed training') 105 | 106 | parser.add_argument('--model_parallel_size', type=int, default=1) 107 | parser.add_argument('--data_parallel', type=str, choices=['ddp', 'sdp', 'fsdp'], default='sdp') 108 | parser.add_argument('--precision', type=str, choices=['fp16', 'bf16', 'tf32'], default='bf16') 109 | parser.add_argument('--save_interval', type=int, default=5000) 110 | parser.add_argument('--save_consolidated', action="store_true", 111 | help="save consolidated model weights along with regular checkpoints " 112 | "used to resume training. useful for convenient deployment but " 113 | "will occupy some additional disk space.") 114 | parser.add_argument("--checkpointing", action="store_true") 115 | 116 | parser.add_argument('--max_words', type=int, default=2048) 117 | parser.add_argument('--image_words', type=int, default=30) 118 | 119 | return parser 120 | 121 | 122 | def main(args): 123 | multiprocessing.set_start_method("spawn") 124 | misc.init_distributed_mode(args) 125 | fs_init.initialize_model_parallel(args.model_parallel_size) 126 | if args.precision == "tf32": 127 | torch.backends.cuda.matmul.allow_tf32 = True 128 | torch.backends.cudnn.allow_tf32 = True 129 | 130 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 131 | print("{}".format(args).replace(', ', ',\n')) 132 | 133 | device = torch.device(args.device) 134 | 135 | # fix the seed for reproducibility 136 | seed = args.seed + misc.get_rank() 137 | torch.manual_seed(seed) 138 | np.random.seed(seed) 139 | 140 | # cudnn.benchmark = True 141 | 142 | global_rank = misc.get_rank() 143 | mp_rank = fs_init.get_model_parallel_rank() 144 | mp_world_size = fs_init.get_model_parallel_world_size() 145 | dp_rank = fs_init.get_data_parallel_rank() 146 | dp_world_size = fs_init.get_data_parallel_world_size() 147 | dp_group = fs_init.get_data_parallel_group() 148 | 149 | dataset_train = FinetuneDialogDataset(args.datasets, max_words=args.max_words, image_words=args.image_words, tokenizer_path=args.tokenizer_path) 150 | 151 | if global_rank == 0 and args.log_dir is not None: 152 | os.makedirs(args.log_dir, exist_ok=True) 153 | log_writer = SummaryWriter(log_dir=args.log_dir) 154 | else: 155 | log_writer = None 156 | 157 | # define the model 158 | model = MetaModel(args.llama_type, args.llama_config, args.llama_ckpt_dir, args.tokenizer_path) 159 | model.to(device) 160 | print("Model = %s" % str(model)) 161 | if args.init_from: 162 | print("Init checkpoint from %s" % args.init_from) 163 | checkpoint = torch.load(os.path.join(args.init_from, f"consolidated.{mp_rank:02d}-of-{mp_world_size:02d}.pth"), map_location='cpu') 164 | msg = model.load_state_dict(checkpoint, strict=False) 165 | print(msg) 166 | 167 | mixed_precision_dtype = { 168 | "fp16": torch.float16, 169 | "bf16": torch.bfloat16, 170 | "tf32": torch.float32, 171 | }[args.precision] 172 | TransformerBlock = type(model.llma.layers[0]) 173 | model = FSDP( 174 | model, 175 | process_group=fs_init.get_data_parallel_group(), 176 | 
auto_wrap_policy=functools.partial( 177 | transformer_auto_wrap_policy, 178 | transformer_layer_cls=[TransformerBlock], 179 | ), 180 | limit_all_gathers=True, 181 | use_orig_params=True, 182 | sync_module_states=True, 183 | mixed_precision=MixedPrecision( 184 | param_dtype=mixed_precision_dtype, 185 | reduce_dtype=mixed_precision_dtype, 186 | buffer_dtype=mixed_precision_dtype, 187 | ), 188 | sharding_strategy={ 189 | "sdp": ShardingStrategy.SHARD_GRAD_OP, 190 | "ddp": ShardingStrategy.NO_SHARD, 191 | "fsdp": ShardingStrategy.FULL_SHARD, 192 | }[args.data_parallel], 193 | ignored_parameters=[param for param in model.parameters() if not param.requires_grad], 194 | ) 195 | 196 | if args.checkpointing: 197 | print("apply gradient checkpointing") 198 | non_reentrant_wrapper = functools.partial( 199 | checkpoint_wrapper, 200 | offload_to_cpu=False, 201 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 202 | ) 203 | check_fn = lambda submodule: isinstance(submodule, TransformerBlock) 204 | apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn) 205 | 206 | eff_batch_size = args.batch_size * args.accum_iter * fs_init.get_data_parallel_world_size() 207 | print("effective batch size: %d" % eff_batch_size) 208 | 209 | # following timm: set wd as 0 for bias and norm layers 210 | #param_groups = misc.add_weight_decay(model, args.weight_decay) 211 | param_groups = { 212 | "decay": {"params": [], "weight_decay": args.weight_decay, "lr": args.lr}, 213 | "no_decay": {"params": [], "weight_decay": 0., "lr": args.lr}, 214 | "scratch_decay": {"params": [], "weight_decay": args.weight_decay, "lr": args.lr}, 215 | "scratch_no_decay": {"params": [], "weight_decay": 0., "lr": args.lr}, 216 | } 217 | print("Making parameter groups ...") 218 | for name, param in model.named_parameters(): 219 | if not param.requires_grad: 220 | continue 221 | no_decay = name.endswith(".bias") or name.endswith("norm.weight") 222 | scratch = "llma.resample_layers" in name or "llma.resample_tokens" in name 223 | group_name = ("scratch_" if scratch else "") + ("no_decay" if no_decay else "decay") 224 | print(f"{name}: in group {group_name}") 225 | param_groups[group_name]["params"].append(param) 226 | optimizer = AdamW( 227 | [param_groups[key] for key in ["decay", "no_decay", "scratch_decay", "scratch_no_decay"]], 228 | betas=(0.9, 0.95), 229 | ) 230 | print(optimizer) 231 | loss_scaler = NativeScaler(args) 232 | 233 | start_epoch = 0 234 | start_iter = 0 235 | if args.resume or args.auto_resume: 236 | start_epoch, start_iter = misc.load_model(args=args, model=model, optimizer=optimizer, loss_scaler=loss_scaler) 237 | 238 | sampler_train = FinetuneDistSampler( 239 | dataset_train, num_replicas=dp_world_size, rank=dp_rank, shuffle=True, batch_size=args.batch_size, 240 | acc_grad=args.accum_iter 241 | ) 242 | data_loader_train = torch.utils.data.DataLoader( 243 | dataset_train, 244 | batch_size=args.batch_size, 245 | num_workers=args.num_workers, 246 | pin_memory=args.pin_mem, 247 | sampler=sampler_train, 248 | drop_last=True 249 | ) 250 | 251 | print(f"Start training for {args.epochs} epochs") 252 | start_time = time.time() 253 | for epoch in range(start_epoch, args.epochs): 254 | if args.distributed: 255 | data_loader_train.sampler.set_epoch(epoch, start_iter) 256 | 257 | train_stats = train_one_epoch( 258 | model, data_loader_train, 259 | optimizer, epoch, start_iter, loss_scaler, 260 | log_writer=log_writer, 261 | args=args 262 | ) 263 | 264 | if args.output_dir and (epoch % 
args.save_interval == 0 or epoch + 1 == args.epochs): 265 | misc.save_model( 266 | output_dir=args.output_dir, 267 | args=args, epoch=epoch, iteration=0, model=model, optimizer=optimizer, 268 | loss_scaler=loss_scaler, dataset_state=None, 269 | ) 270 | 271 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 272 | 'epoch': epoch, 273 | **{f'val_{k}': v for k, v in train_stats.items()}} 274 | 275 | if args.output_dir and misc.is_main_process(): 276 | if log_writer is not None: 277 | log_writer.flush() 278 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 279 | f.write(json.dumps(log_stats) + "\n") 280 | 281 | start_iter = 0 282 | 283 | total_time = time.time() - start_time 284 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 285 | print('Training time {}'.format(total_time_str)) 286 | 287 | 288 | if __name__ == '__main__': 289 | args = get_args_parser() 290 | args = args.parse_args() 291 | if args.output_dir: 292 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 293 | main(args) 294 | -------------------------------------------------------------------------------- /model/LLM/__init__.py: -------------------------------------------------------------------------------- 1 | from . import onellm -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csuhan/OneLLM/8587a4768cf376fb41f7d586e21de5d1ab1ca365/model/__init__.py -------------------------------------------------------------------------------- /model/components.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | import torch.nn as nn 4 | 5 | try: 6 | from apex.normalization import FusedRMSNorm as RMSNorm 7 | except ImportError: 8 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") 9 | 10 | class RMSNorm(torch.nn.Module): 11 | def __init__(self, dim: int, eps: float = 1e-6): 12 | """ 13 | Initialize the RMSNorm normalization layer. 14 | 15 | Args: 16 | dim (int): The dimension of the input tensor. 17 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 18 | 19 | Attributes: 20 | eps (float): A small value added to the denominator for numerical stability. 21 | weight (nn.Parameter): Learnable scaling parameter. 22 | 23 | """ 24 | super().__init__() 25 | self.eps = eps 26 | self.weight = nn.Parameter(torch.ones(dim)) 27 | 28 | def _norm(self, x): 29 | """ 30 | Apply the RMSNorm normalization to the input tensor. 31 | 32 | Args: 33 | x (torch.Tensor): The input tensor. 34 | 35 | Returns: 36 | torch.Tensor: The normalized tensor. 37 | 38 | """ 39 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 40 | 41 | def forward(self, x): 42 | """ 43 | Forward pass through the RMSNorm layer. 44 | 45 | Args: 46 | x (torch.Tensor): The input tensor. 47 | 48 | Returns: 49 | torch.Tensor: The output tensor after applying RMSNorm. 
50 | 51 | """ 52 | output = self._norm(x.float()).type_as(x) 53 | return output * self.weight 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /model/lib/point_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | import pointnet2_cuda 5 | 6 | class KNN(nn.Module): 7 | def __init__(self, neighbors, transpose_mode=True): 8 | super(KNN, self).__init__() 9 | self.neighbors = neighbors 10 | 11 | @torch.no_grad() 12 | def forward(self, support, query): 13 | """ 14 | Args: 15 | support ([tensor]): [B, N, C] 16 | query ([tensor]): [B, M, C] 17 | Returns: 18 | [int]: neighbor idx. [B, M, K] 19 | """ 20 | dist = torch.cdist(support, query) 21 | k_dist = dist.topk(k=self.neighbors, dim=1, largest=False) 22 | return k_dist.values, k_dist.indices.transpose(1, 2).contiguous().int() 23 | 24 | 25 | class GroupingOperation(Function): 26 | 27 | @staticmethod 28 | @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 29 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: 30 | """ 31 | :param ctx: 32 | :param features: (B, C, N) tensor of features to group 33 | :param idx: (B, npoint, nsample) tensor containing the indicies of features to group with 34 | :return: 35 | output: (B, C, npoint, nsample) tensor 36 | """ 37 | assert features.is_contiguous() 38 | assert idx.is_contiguous() 39 | 40 | B, nfeatures, nsample = idx.size() 41 | _, C, N = features.size() 42 | output = torch.cuda.FloatTensor(B, C, nfeatures, nsample, device=features.device) 43 | 44 | pointnet2_cuda.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output) 45 | 46 | ctx.for_backwards = (idx, N) 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_out: torch.Tensor): 51 | """ 52 | :param ctx: 53 | :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward 54 | :return: 55 | grad_features: (B, C, N) gradient of the features 56 | """ 57 | idx, N = ctx.for_backwards 58 | 59 | B, C, npoint, nsample = grad_out.size() 60 | grad_features = torch.zeros([B, C, N], dtype=torch.float, device=grad_out.device, requires_grad=True) 61 | grad_out_data = grad_out.data.contiguous() 62 | pointnet2_cuda.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data) 63 | return grad_features, None 64 | 65 | grouping_operation = GroupingOperation.apply 66 | 67 | 68 | class KNNGroup(nn.Module): 69 | def __init__(self, nsample: int, 70 | relative_xyz=True, 71 | normalize_dp=False, 72 | return_only_idx=False, 73 | **kwargs 74 | ): 75 | """[summary] 76 | 77 | Args: 78 | nsample (int): maximum number of features to gather in the ball 79 | use_xyz (bool, optional): concate xyz. Defaults to True. 80 | ret_grouped_xyz (bool, optional): [description]. Defaults to False. 81 | normalize_dp (bool, optional): [description]. Defaults to False. 
82 | """ 83 | super().__init__() 84 | self.nsample = nsample 85 | self.knn = KNN(nsample, transpose_mode=True) 86 | self.relative_xyz = relative_xyz 87 | self.normalize_dp = normalize_dp 88 | self.return_only_idx = return_only_idx 89 | 90 | def forward(self, query_xyz: torch.Tensor, support_xyz: torch.Tensor, features: torch.Tensor = None): 91 | """ 92 | :param query_xyz: (B, N, 3) xyz coordinates of the features 93 | :param support_xyz: (B, npoint, 3) centroids 94 | :param features: (B, C, N) descriptors of the features 95 | :return: 96 | new_features: (B, 3 + C, npoint, nsample) 97 | """ 98 | _, idx = self.knn(support_xyz, query_xyz) 99 | if self.return_only_idx: 100 | return idx 101 | idx = idx.int() 102 | xyz_trans = support_xyz.transpose(1, 2).contiguous() 103 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 104 | if self.relative_xyz: 105 | grouped_xyz -= query_xyz.transpose(1, 2).unsqueeze(-1) # relative position 106 | if self.normalize_dp: 107 | grouped_xyz /= torch.amax(torch.sqrt(torch.sum(grouped_xyz**2, dim=1)), dim=(1, 2)).view(-1, 1, 1, 1) 108 | if features is not None: 109 | grouped_features = grouping_operation(features, idx) 110 | return grouped_xyz, grouped_features 111 | else: 112 | return grouped_xyz, None 113 | 114 | 115 | class FurthestPointSampling(Function): 116 | @staticmethod 117 | def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor: 118 | """ 119 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 120 | minimum distance 121 | :param ctx: 122 | :param xyz: (B, N, 3) where N > npoint 123 | :param npoint: int, number of features in the sampled set 124 | :return: 125 | output: (B, npoint) tensor containing the set (idx) 126 | """ 127 | assert xyz.is_contiguous() 128 | 129 | B, N, _ = xyz.size() 130 | # output = torch.cuda.IntTensor(B, npoint, device=xyz.device) 131 | # temp = torch.cuda.FloatTensor(B, N, device=xyz.device).fill_(1e10) 132 | output = torch.cuda.IntTensor(B, npoint) 133 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10) 134 | 135 | pointnet2_cuda.furthest_point_sampling_wrapper( 136 | B, N, npoint, xyz, temp, output) 137 | return output 138 | 139 | @staticmethod 140 | def backward(xyz, a=None): 141 | return None, None 142 | 143 | furthest_point_sample = FurthestPointSampling.apply 144 | 145 | 146 | class PointPatchEmbed(nn.Module): 147 | 148 | def __init__(self, 149 | sample_ratio=0.0625, 150 | sample_number=1024, 151 | group_size=32, 152 | in_channels=6, 153 | channels=1024, 154 | kernel_size=1, 155 | stride=1, 156 | normalize_dp=False, 157 | relative_xyz=True, 158 | ): 159 | super().__init__() 160 | self.sample_ratio = sample_ratio 161 | self.sample_number = sample_number 162 | self.group_size = group_size 163 | 164 | self.sample_fn = furthest_point_sample 165 | self.grouper = KNNGroup(self.group_size, relative_xyz=relative_xyz, normalize_dp=normalize_dp) 166 | 167 | self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=kernel_size, stride=stride) 168 | 169 | 170 | def forward(self, x): 171 | # coordinates 172 | p = x[:, :, 3:].contiguous() 173 | 174 | B, N, _ = p.shape[:3] 175 | # idx = self.sample_fn(p, int(N * self.sample_ratio)).long() 176 | idx = self.sample_fn(p, self.sample_number).long() 177 | center_p = torch.gather(p, 1, idx.unsqueeze(-1).expand(-1, -1, 3)) 178 | # query neighbors. 
179 | _, fj = self.grouper(center_p, p, x.permute(0, 2, 1).contiguous()) # [B, N, 6] -> [B, 6, N] -> [B, 6, 1024, 32] 180 | 181 | # [B, 6, 1024] -> [B, channels, 1024, 1] 182 | fj = self.conv1(fj).max(dim=-1, keepdim=True)[0] 183 | 184 | return fj 185 | 186 | 187 | if __name__ == '__main__': 188 | model = PointPatchEmbed(channels=256).cuda() 189 | input = torch.rand(4, 16384, 6).cuda() 190 | ou = model(input) 191 | import pdb;pdb.set_trace() -------------------------------------------------------------------------------- /model/lib/pointnet2/pointnet2_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from . import pointnet2_utils 6 | from . import pytorch_utils as pt_utils 7 | from typing import List 8 | 9 | 10 | class _PointnetSAModuleBase(nn.Module): 11 | 12 | def __init__(self): 13 | super().__init__() 14 | self.npoint = None 15 | self.groupers = None 16 | self.mlps = None 17 | self.pool_method = 'max_pool' 18 | 19 | def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor): 20 | """ 21 | :param xyz: (B, N, 3) tensor of the xyz coordinates of the features 22 | :param features: (B, N, C) tensor of the descriptors of the the features 23 | :param new_xyz: 24 | :return: 25 | new_xyz: (B, npoint, 3) tensor of the new features' xyz 26 | new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors 27 | """ 28 | new_features_list = [] 29 | 30 | xyz_flipped = xyz.transpose(1, 2).contiguous() 31 | if new_xyz is None: 32 | new_xyz = pointnet2_utils.gather_operation( 33 | xyz_flipped, 34 | pointnet2_utils.furthest_point_sample(xyz, self.npoint) 35 | ).transpose(1, 2).contiguous() if self.npoint is not None else None 36 | 37 | for i in range(len(self.groupers)): 38 | new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample) 39 | 40 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) 41 | if self.pool_method == 'max_pool': 42 | new_features = F.max_pool2d( 43 | new_features, kernel_size=[1, new_features.size(3)] 44 | ) # (B, mlp[-1], npoint, 1) 45 | elif self.pool_method == 'avg_pool': 46 | new_features = F.avg_pool2d( 47 | new_features, kernel_size=[1, new_features.size(3)] 48 | ) # (B, mlp[-1], npoint, 1) 49 | else: 50 | raise NotImplementedError 51 | 52 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) 53 | new_features_list.append(new_features) 54 | 55 | return new_xyz, torch.cat(new_features_list, dim=1) 56 | 57 | 58 | class PointnetSAModuleMSG(_PointnetSAModuleBase): 59 | """Pointnet set abstraction layer with multiscale grouping""" 60 | 61 | def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True, 62 | use_xyz: bool = True, pool_method='max_pool', instance_norm=False): 63 | """ 64 | :param npoint: int 65 | :param radii: list of float, list of radii to group with 66 | :param nsamples: list of int, number of samples in each ball query 67 | :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale 68 | :param bn: whether to use batchnorm 69 | :param use_xyz: 70 | :param pool_method: max_pool / avg_pool 71 | :param instance_norm: whether to use instance_norm 72 | """ 73 | super().__init__() 74 | 75 | assert len(radii) == len(nsamples) == len(mlps) 76 | 77 | self.npoint = npoint 78 | self.groupers = nn.ModuleList() 79 | self.mlps = 
nn.ModuleList() 80 | for i in range(len(radii)): 81 | radius = radii[i] 82 | nsample = nsamples[i] 83 | self.groupers.append( 84 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) 85 | if npoint is not None else pointnet2_utils.GroupAll(use_xyz) 86 | ) 87 | mlp_spec = mlps[i] 88 | if use_xyz: 89 | mlp_spec[0] += 3 90 | 91 | self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm)) 92 | self.pool_method = pool_method 93 | 94 | 95 | class PointnetSAModule(PointnetSAModuleMSG): 96 | """Pointnet set abstraction layer""" 97 | 98 | def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None, 99 | bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False): 100 | """ 101 | :param mlp: list of int, spec of the pointnet before the global max_pool 102 | :param npoint: int, number of features 103 | :param radius: float, radius of ball 104 | :param nsample: int, number of samples in the ball query 105 | :param bn: whether to use batchnorm 106 | :param use_xyz: 107 | :param pool_method: max_pool / avg_pool 108 | :param instance_norm: whether to use instance_norm 109 | """ 110 | super().__init__( 111 | mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz, 112 | pool_method=pool_method, instance_norm=instance_norm 113 | ) 114 | 115 | 116 | class PointnetFPModule(nn.Module): 117 | r"""Propigates the features of one set to another""" 118 | 119 | def __init__(self, *, mlp: List[int], bn: bool = True): 120 | """ 121 | :param mlp: list of int 122 | :param bn: whether to use batchnorm 123 | """ 124 | super().__init__() 125 | self.mlp = pt_utils.SharedMLP(mlp, bn=bn) 126 | 127 | def forward( 128 | self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor 129 | ) -> torch.Tensor: 130 | """ 131 | :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features 132 | :param known: (B, m, 3) tensor of the xyz positions of the known features 133 | :param unknow_feats: (B, C1, n) tensor of the features to be propigated to 134 | :param known_feats: (B, C2, m) tensor of features to be propigated 135 | :return: 136 | new_features: (B, mlp[-1], n) tensor of the features of the unknown features 137 | """ 138 | if known is not None: 139 | dist, idx = pointnet2_utils.three_nn(unknown, known) 140 | dist_recip = 1.0 / (dist + 1e-8) 141 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 142 | weight = dist_recip / norm 143 | 144 | interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight) 145 | else: 146 | interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1)) 147 | 148 | if unknow_feats is not None: 149 | new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n) 150 | else: 151 | new_features = interpolated_feats 152 | 153 | new_features = new_features.unsqueeze(-1) 154 | new_features = self.mlp(new_features) 155 | 156 | return new_features.squeeze(-1) 157 | 158 | 159 | if __name__ == "__main__": 160 | pass 161 | -------------------------------------------------------------------------------- /model/lib/pointnet2/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch.autograd import Function 4 | import torch.nn as nn 5 | from typing import Tuple 6 | 7 | import pointnet2_cuda as pointnet2 8 | 9 | 10 | class FurthestPointSampling(Function): 
11 | @staticmethod 12 | def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor: 13 | """ 14 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 15 | minimum distance 16 | :param ctx: 17 | :param xyz: (B, N, 3) where N > npoint 18 | :param npoint: int, number of features in the sampled set 19 | :return: 20 | output: (B, npoint) tensor containing the set 21 | """ 22 | assert xyz.is_contiguous() 23 | 24 | B, N, _ = xyz.size() 25 | output = torch.cuda.IntTensor(B, npoint) 26 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10) 27 | 28 | pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output) 29 | return output 30 | 31 | @staticmethod 32 | def backward(xyz, a=None): 33 | return None, None 34 | 35 | 36 | furthest_point_sample = FurthestPointSampling.apply 37 | 38 | 39 | class GatherOperation(Function): 40 | 41 | @staticmethod 42 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: 43 | """ 44 | :param ctx: 45 | :param features: (B, C, N) 46 | :param idx: (B, npoint) index tensor of the features to gather 47 | :return: 48 | output: (B, C, npoint) 49 | """ 50 | assert features.is_contiguous() 51 | assert idx.is_contiguous() 52 | 53 | B, npoint = idx.size() 54 | _, C, N = features.size() 55 | output = torch.cuda.FloatTensor(B, C, npoint) 56 | 57 | pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output) 58 | 59 | ctx.for_backwards = (idx, C, N) 60 | return output 61 | 62 | @staticmethod 63 | def backward(ctx, grad_out): 64 | idx, C, N = ctx.for_backwards 65 | B, npoint = idx.size() 66 | 67 | grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) 68 | grad_out_data = grad_out.data.contiguous() 69 | pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data) 70 | return grad_features, None 71 | 72 | 73 | gather_operation = GatherOperation.apply 74 | 75 | 76 | class ThreeNN(Function): 77 | 78 | @staticmethod 79 | def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 80 | """ 81 | Find the three nearest neighbors of unknown in known 82 | :param ctx: 83 | :param unknown: (B, N, 3) 84 | :param known: (B, M, 3) 85 | :return: 86 | dist: (B, N, 3) l2 distance to the three nearest neighbors 87 | idx: (B, N, 3) index of 3 nearest neighbors 88 | """ 89 | assert unknown.is_contiguous() 90 | assert known.is_contiguous() 91 | 92 | B, N, _ = unknown.size() 93 | m = known.size(1) 94 | dist2 = torch.cuda.FloatTensor(B, N, 3) 95 | idx = torch.cuda.IntTensor(B, N, 3) 96 | 97 | pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx) 98 | return torch.sqrt(dist2), idx 99 | 100 | @staticmethod 101 | def backward(ctx, a=None, b=None): 102 | return None, None 103 | 104 | 105 | three_nn = ThreeNN.apply 106 | 107 | 108 | class ThreeInterpolate(Function): 109 | 110 | @staticmethod 111 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: 112 | """ 113 | Performs weight linear interpolation on 3 features 114 | :param ctx: 115 | :param features: (B, C, M) Features descriptors to be interpolated from 116 | :param idx: (B, n, 3) three nearest neighbors of the target features in features 117 | :param weight: (B, n, 3) weights 118 | :return: 119 | output: (B, C, N) tensor of the interpolated features 120 | """ 121 | assert features.is_contiguous() 122 | assert idx.is_contiguous() 123 | assert weight.is_contiguous() 124 | 125 | B, c, m = features.size() 126 | n = idx.size(1) 127 
| ctx.three_interpolate_for_backward = (idx, weight, m) 128 | output = torch.cuda.FloatTensor(B, c, n) 129 | 130 | pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output) 131 | return output 132 | 133 | @staticmethod 134 | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 135 | """ 136 | :param ctx: 137 | :param grad_out: (B, C, N) tensor with gradients of outputs 138 | :return: 139 | grad_features: (B, C, M) tensor with gradients of features 140 | None: 141 | None: 142 | """ 143 | idx, weight, m = ctx.three_interpolate_for_backward 144 | B, c, n = grad_out.size() 145 | 146 | grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_()) 147 | grad_out_data = grad_out.data.contiguous() 148 | 149 | pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data) 150 | return grad_features, None, None 151 | 152 | 153 | three_interpolate = ThreeInterpolate.apply 154 | 155 | 156 | class GroupingOperation(Function): 157 | 158 | @staticmethod 159 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: 160 | """ 161 | :param ctx: 162 | :param features: (B, C, N) tensor of features to group 163 | :param idx: (B, npoint, nsample) tensor containing the indicies of features to group with 164 | :return: 165 | output: (B, C, npoint, nsample) tensor 166 | """ 167 | assert features.is_contiguous() 168 | assert idx.is_contiguous() 169 | 170 | B, nfeatures, nsample = idx.size() 171 | _, C, N = features.size() 172 | output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) 173 | 174 | pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output) 175 | 176 | ctx.for_backwards = (idx, N) 177 | return output 178 | 179 | @staticmethod 180 | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 181 | """ 182 | :param ctx: 183 | :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward 184 | :return: 185 | grad_features: (B, C, N) gradient of the features 186 | """ 187 | idx, N = ctx.for_backwards 188 | 189 | B, C, npoint, nsample = grad_out.size() 190 | grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) 191 | 192 | grad_out_data = grad_out.data.contiguous() 193 | pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data) 194 | return grad_features, None 195 | 196 | 197 | grouping_operation = GroupingOperation.apply 198 | 199 | 200 | class BallQuery(Function): 201 | 202 | @staticmethod 203 | def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor: 204 | """ 205 | :param ctx: 206 | :param radius: float, radius of the balls 207 | :param nsample: int, maximum number of features in the balls 208 | :param xyz: (B, N, 3) xyz coordinates of the features 209 | :param new_xyz: (B, npoint, 3) centers of the ball query 210 | :return: 211 | idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls 212 | """ 213 | assert new_xyz.is_contiguous() 214 | assert xyz.is_contiguous() 215 | 216 | B, N, _ = xyz.size() 217 | npoint = new_xyz.size(1) 218 | idx = torch.cuda.IntTensor(B, npoint, nsample).zero_() 219 | 220 | pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx) 221 | return idx 222 | 223 | @staticmethod 224 | def backward(ctx, a=None): 225 | return None, None, None, None 226 | 227 | 228 | ball_query = BallQuery.apply 229 | 230 | 231 | class 
QueryAndGroup(nn.Module): 232 | def __init__(self, radius: float, nsample: int, use_xyz: bool = True): 233 | """ 234 | :param radius: float, radius of ball 235 | :param nsample: int, maximum number of features to gather in the ball 236 | :param use_xyz: 237 | """ 238 | super().__init__() 239 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 240 | 241 | def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]: 242 | """ 243 | :param xyz: (B, N, 3) xyz coordinates of the features 244 | :param new_xyz: (B, npoint, 3) centroids 245 | :param features: (B, C, N) descriptors of the features 246 | :return: 247 | new_features: (B, 3 + C, npoint, nsample) 248 | """ 249 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 250 | xyz_trans = xyz.transpose(1, 2).contiguous() 251 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 252 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 253 | 254 | if features is not None: 255 | grouped_features = grouping_operation(features, idx) 256 | if self.use_xyz: 257 | new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, npoint, nsample) 258 | else: 259 | new_features = grouped_features 260 | else: 261 | assert self.use_xyz, "Cannot have not features and not use xyz as a feature!" 262 | new_features = grouped_xyz 263 | 264 | return new_features 265 | 266 | 267 | class GroupAll(nn.Module): 268 | def __init__(self, use_xyz: bool = True): 269 | super().__init__() 270 | self.use_xyz = use_xyz 271 | 272 | def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None): 273 | """ 274 | :param xyz: (B, N, 3) xyz coordinates of the features 275 | :param new_xyz: ignored 276 | :param features: (B, C, N) descriptors of the features 277 | :return: 278 | new_features: (B, C + 3, 1, N) 279 | """ 280 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 281 | if features is not None: 282 | grouped_features = features.unsqueeze(2) 283 | if self.use_xyz: 284 | new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, 3 + C, 1, N) 285 | else: 286 | new_features = grouped_features 287 | else: 288 | new_features = grouped_xyz 289 | 290 | return new_features 291 | -------------------------------------------------------------------------------- /model/lib/pointnet2/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import List, Tuple 3 | 4 | 5 | class SharedMLP(nn.Sequential): 6 | 7 | def __init__( 8 | self, 9 | args: List[int], 10 | *, 11 | bn: bool = False, 12 | activation=nn.ReLU(inplace=True), 13 | preact: bool = False, 14 | first: bool = False, 15 | name: str = "", 16 | instance_norm: bool = False, 17 | ): 18 | super().__init__() 19 | 20 | for i in range(len(args) - 1): 21 | self.add_module( 22 | name + 'layer{}'.format(i), 23 | Conv2d( 24 | args[i], 25 | args[i + 1], 26 | bn=(not first or not preact or (i != 0)) and bn, 27 | activation=activation 28 | if (not first or not preact or (i != 0)) else None, 29 | preact=preact, 30 | instance_norm=instance_norm 31 | ) 32 | ) 33 | 34 | 35 | class _ConvBase(nn.Sequential): 36 | 37 | def __init__( 38 | self, 39 | in_size, 40 | out_size, 41 | kernel_size, 42 | stride, 43 | padding, 44 | activation, 45 | bn, 46 | init, 47 | conv=None, 48 | batch_norm=None, 49 | bias=True, 50 | preact=False, 51 | name="", 52 | instance_norm=False, 53 | instance_norm_func=None 54 | ): 55 | 
super().__init__() 56 | 57 | bias = bias and (not bn) 58 | conv_unit = conv( 59 | in_size, 60 | out_size, 61 | kernel_size=kernel_size, 62 | stride=stride, 63 | padding=padding, 64 | bias=bias 65 | ) 66 | init(conv_unit.weight) 67 | if bias: 68 | nn.init.constant_(conv_unit.bias, 0) 69 | 70 | if bn: 71 | if not preact: 72 | bn_unit = batch_norm(out_size) 73 | else: 74 | bn_unit = batch_norm(in_size) 75 | if instance_norm: 76 | if not preact: 77 | in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False) 78 | else: 79 | in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False) 80 | 81 | if preact: 82 | if bn: 83 | self.add_module(name + 'bn', bn_unit) 84 | 85 | if activation is not None: 86 | self.add_module(name + 'activation', activation) 87 | 88 | if not bn and instance_norm: 89 | self.add_module(name + 'in', in_unit) 90 | 91 | self.add_module(name + 'conv', conv_unit) 92 | 93 | if not preact: 94 | if bn: 95 | self.add_module(name + 'bn', bn_unit) 96 | 97 | if activation is not None: 98 | self.add_module(name + 'activation', activation) 99 | 100 | if not bn and instance_norm: 101 | self.add_module(name + 'in', in_unit) 102 | 103 | 104 | class _BNBase(nn.Sequential): 105 | 106 | def __init__(self, in_size, batch_norm=None, name=""): 107 | super().__init__() 108 | self.add_module(name + "bn", batch_norm(in_size)) 109 | 110 | nn.init.constant_(self[0].weight, 1.0) 111 | nn.init.constant_(self[0].bias, 0) 112 | 113 | 114 | class BatchNorm1d(_BNBase): 115 | 116 | def __init__(self, in_size: int, *, name: str = ""): 117 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 118 | 119 | 120 | class BatchNorm2d(_BNBase): 121 | 122 | def __init__(self, in_size: int, name: str = ""): 123 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 124 | 125 | 126 | class Conv1d(_ConvBase): 127 | 128 | def __init__( 129 | self, 130 | in_size: int, 131 | out_size: int, 132 | *, 133 | kernel_size: int = 1, 134 | stride: int = 1, 135 | padding: int = 0, 136 | activation=nn.ReLU(inplace=True), 137 | bn: bool = False, 138 | init=nn.init.kaiming_normal_, 139 | bias: bool = True, 140 | preact: bool = False, 141 | name: str = "", 142 | instance_norm=False 143 | ): 144 | super().__init__( 145 | in_size, 146 | out_size, 147 | kernel_size, 148 | stride, 149 | padding, 150 | activation, 151 | bn, 152 | init, 153 | conv=nn.Conv1d, 154 | batch_norm=BatchNorm1d, 155 | bias=bias, 156 | preact=preact, 157 | name=name, 158 | instance_norm=instance_norm, 159 | instance_norm_func=nn.InstanceNorm1d 160 | ) 161 | 162 | 163 | class Conv2d(_ConvBase): 164 | 165 | def __init__( 166 | self, 167 | in_size: int, 168 | out_size: int, 169 | *, 170 | kernel_size: Tuple[int, int] = (1, 1), 171 | stride: Tuple[int, int] = (1, 1), 172 | padding: Tuple[int, int] = (0, 0), 173 | activation=nn.ReLU(inplace=True), 174 | bn: bool = False, 175 | init=nn.init.kaiming_normal_, 176 | bias: bool = True, 177 | preact: bool = False, 178 | name: str = "", 179 | instance_norm=False 180 | ): 181 | super().__init__( 182 | in_size, 183 | out_size, 184 | kernel_size, 185 | stride, 186 | padding, 187 | activation, 188 | bn, 189 | init, 190 | conv=nn.Conv2d, 191 | batch_norm=BatchNorm2d, 192 | bias=bias, 193 | preact=preact, 194 | name=name, 195 | instance_norm=instance_norm, 196 | instance_norm_func=nn.InstanceNorm2d 197 | ) 198 | 199 | 200 | class FC(nn.Sequential): 201 | 202 | def __init__( 203 | self, 204 | in_size: int, 205 | out_size: int, 206 | *, 207 | activation=nn.ReLU(inplace=True), 
208 | bn: bool = False, 209 | init=None, 210 | preact: bool = False, 211 | name: str = "" 212 | ): 213 | super().__init__() 214 | 215 | fc = nn.Linear(in_size, out_size, bias=not bn) 216 | if init is not None: 217 | init(fc.weight) 218 | if not bn: 219 | nn.init.constant(fc.bias, 0) 220 | 221 | if preact: 222 | if bn: 223 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 224 | 225 | if activation is not None: 226 | self.add_module(name + 'activation', activation) 227 | 228 | self.add_module(name + 'fc', fc) 229 | 230 | if not preact: 231 | if bn: 232 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 233 | 234 | if activation is not None: 235 | self.add_module(name + 'activation', activation) 236 | 237 | -------------------------------------------------------------------------------- /model/lib/pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='pointnet2', 6 | ext_modules=[ 7 | CUDAExtension('pointnet2_cuda', [ 8 | 'src/pointnet2_api.cpp', 9 | 10 | 'src/ball_query.cpp', 11 | 'src/ball_query_gpu.cu', 12 | 'src/group_points.cpp', 13 | 'src/group_points_gpu.cu', 14 | 'src/interpolate.cpp', 15 | 'src/interpolate_gpu.cu', 16 | 'src/sampling.cpp', 17 | 'src/sampling_gpu.cu', 18 | ], 19 | extra_compile_args={'cxx': ['-g'], 20 | 'nvcc': ['-O2']}) 21 | ], 22 | cmdclass={'build_ext': BuildExtension} 23 | ) 24 | -------------------------------------------------------------------------------- /model/lib/pointnet2/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "ball_query_gpu.h" 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 12 | 13 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 14 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) { 15 | CHECK_INPUT(new_xyz_tensor); 16 | CHECK_INPUT(xyz_tensor); 17 | const float *new_xyz = new_xyz_tensor.data(); 18 | const float *xyz = xyz_tensor.data(); 19 | int *idx = idx_tensor.data(); 20 | 21 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); 22 | ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream); 23 | return 1; 24 | } 25 | -------------------------------------------------------------------------------- /model/lib/pointnet2/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "ball_query_gpu.h" 6 | #include "cuda_utils.h" 7 | 8 | 9 | __global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, 10 | const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) { 11 | // new_xyz: (B, M, 3) 12 | // xyz: (B, N, 3) 13 | // output: 14 | // idx: (B, M, nsample) 15 | int bs_idx = blockIdx.y; 16 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 17 | if (bs_idx >= b || pt_idx >= m) return; 18 | 19 | new_xyz += bs_idx * m * 3 + pt_idx * 3; 20 | xyz += bs_idx * n * 3; 21 | idx += bs_idx * m * nsample + pt_idx * nsample; 22 | 23 | float radius2 = radius * radius; 24 | float new_x = 
new_xyz[0]; 25 | float new_y = new_xyz[1]; 26 | float new_z = new_xyz[2]; 27 | 28 | int cnt = 0; 29 | for (int k = 0; k < n; ++k) { 30 | float x = xyz[k * 3 + 0]; 31 | float y = xyz[k * 3 + 1]; 32 | float z = xyz[k * 3 + 2]; 33 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); 34 | if (d2 < radius2){ 35 | if (cnt == 0){ 36 | for (int l = 0; l < nsample; ++l) { 37 | idx[l] = k; 38 | } 39 | } 40 | idx[cnt] = k; 41 | ++cnt; 42 | if (cnt >= nsample) break; 43 | } 44 | } 45 | } 46 | 47 | 48 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \ 49 | const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) { 50 | // new_xyz: (B, M, 3) 51 | // xyz: (B, N, 3) 52 | // output: 53 | // idx: (B, M, nsample) 54 | 55 | cudaError_t err; 56 | 57 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 58 | dim3 threads(THREADS_PER_BLOCK); 59 | 60 | ball_query_kernel_fast<<>>(b, n, m, radius, nsample, new_xyz, xyz, idx); 61 | // cudaDeviceSynchronize(); // for using printf in kernel function 62 | err = cudaGetLastError(); 63 | if (cudaSuccess != err) { 64 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 65 | exit(-1); 66 | } 67 | } -------------------------------------------------------------------------------- /model/lib/pointnet2/src/ball_query_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _BALL_QUERY_GPU_H 2 | #define _BALL_QUERY_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 10 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor); 11 | 12 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, 13 | const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream); 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /model/lib/pointnet2/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | 6 | #define TOTAL_THREADS 1024 7 | #define THREADS_PER_BLOCK 256 8 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 9 | 10 | inline int opt_n_threads(int work_size) { 11 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 12 | 13 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 14 | } 15 | #endif 16 | -------------------------------------------------------------------------------- /model/lib/pointnet2/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "group_points_gpu.h" 6 | #include 7 | #include 8 | 9 | 10 | 11 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 12 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { 13 | 14 | float *grad_points = grad_points_tensor.data(); 15 | const int *idx = idx_tensor.data(); 16 | const float *grad_out = grad_out_tensor.data(); 17 | 18 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); 19 | group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream); 20 | return 1; 21 | } 22 | 23 | 24 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 25 | at::Tensor points_tensor, at::Tensor 
idx_tensor, at::Tensor out_tensor) { 26 | 27 | const float *points = points_tensor.data(); 28 | const int *idx = idx_tensor.data(); 29 | float *out = out_tensor.data(); 30 | 31 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); 32 | group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream); 33 | return 1; 34 | } 35 | -------------------------------------------------------------------------------- /model/lib/pointnet2/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | #include "group_points_gpu.h" 6 | 7 | 8 | __global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample, 9 | const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) { 10 | // grad_out: (B, C, npoints, nsample) 11 | // idx: (B, npoints, nsample) 12 | // output: 13 | // grad_points: (B, C, N) 14 | int bs_idx = blockIdx.z; 15 | int c_idx = blockIdx.y; 16 | int index = blockIdx.x * blockDim.x + threadIdx.x; 17 | int pt_idx = index / nsample; 18 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; 19 | 20 | int sample_idx = index % nsample; 21 | grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; 22 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 23 | 24 | atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]); 25 | } 26 | 27 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 28 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { 29 | // grad_out: (B, C, npoints, nsample) 30 | // idx: (B, npoints, nsample) 31 | // output: 32 | // grad_points: (B, C, N) 33 | cudaError_t err; 34 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 35 | dim3 threads(THREADS_PER_BLOCK); 36 | 37 | group_points_grad_kernel_fast<<>>(b, c, n, npoints, nsample, grad_out, idx, grad_points); 38 | 39 | err = cudaGetLastError(); 40 | if (cudaSuccess != err) { 41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 42 | exit(-1); 43 | } 44 | } 45 | 46 | 47 | __global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample, 48 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { 49 | // points: (B, C, N) 50 | // idx: (B, npoints, nsample) 51 | // output: 52 | // out: (B, C, npoints, nsample) 53 | int bs_idx = blockIdx.z; 54 | int c_idx = blockIdx.y; 55 | int index = blockIdx.x * blockDim.x + threadIdx.x; 56 | int pt_idx = index / nsample; 57 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; 58 | 59 | int sample_idx = index % nsample; 60 | 61 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 62 | int in_idx = bs_idx * c * n + c_idx * n + idx[0]; 63 | int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; 64 | 65 | out[out_idx] = points[in_idx]; 66 | } 67 | 68 | 69 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 70 | const float *points, const int *idx, float *out, cudaStream_t stream) { 71 | // points: (B, C, N) 72 | // idx: (B, npoints, nsample) 73 | // output: 74 | // out: (B, C, npoints, nsample) 75 | cudaError_t err; 76 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), 
blockIdx.y(row) 77 | dim3 threads(THREADS_PER_BLOCK); 78 | 79 | group_points_kernel_fast<<>>(b, c, n, npoints, nsample, points, idx, out); 80 | // cudaDeviceSynchronize(); // for using printf in kernel function 81 | err = cudaGetLastError(); 82 | if (cudaSuccess != err) { 83 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 84 | exit(-1); 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /model/lib/pointnet2/src/group_points_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _GROUP_POINTS_GPU_H 2 | #define _GROUP_POINTS_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 11 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 12 | 13 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 14 | const float *points, const int *idx, float *out, cudaStream_t stream); 15 | 16 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 18 | 19 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); 21 | 22 | #endif 23 | -------------------------------------------------------------------------------- /model/lib/pointnet2/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "interpolate_gpu.h" 9 | #include 10 | #include 11 | 12 | 13 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 14 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) { 15 | const float *unknown = unknown_tensor.data(); 16 | const float *known = known_tensor.data(); 17 | float *dist2 = dist2_tensor.data(); 18 | int *idx = idx_tensor.data(); 19 | 20 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); 21 | three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream); 22 | } 23 | 24 | 25 | void three_interpolate_wrapper_fast(int b, int c, int m, int n, 26 | at::Tensor points_tensor, 27 | at::Tensor idx_tensor, 28 | at::Tensor weight_tensor, 29 | at::Tensor out_tensor) { 30 | 31 | const float *points = points_tensor.data(); 32 | const float *weight = weight_tensor.data(); 33 | float *out = out_tensor.data(); 34 | const int *idx = idx_tensor.data(); 35 | 36 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); 37 | three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream); 38 | } 39 | 40 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, 41 | at::Tensor grad_out_tensor, 42 | at::Tensor idx_tensor, 43 | at::Tensor weight_tensor, 44 | at::Tensor grad_points_tensor) { 45 | 46 | const float *grad_out = grad_out_tensor.data(); 47 | const float *weight = weight_tensor.data(); 48 | float *grad_points = grad_points_tensor.data(); 49 | const int *idx = idx_tensor.data(); 50 | 51 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); 52 | three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream); 53 | } 54 | -------------------------------------------------------------------------------- 
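The three_nn / three_interpolate wrappers bound in interpolate.cpp are driven from Python through the ThreeNN and ThreeInterpolate autograd Functions in model/lib/pointnet2/pointnet2_utils.py. A minimal usage sketch follows, assuming the pointnet2_cuda extension has been built, the script runs from model/lib/pointnet2 so that pointnet2_utils is importable, and a CUDA device is available; the batch size and point counts are illustrative only, not taken from the training code:

import torch
from pointnet2_utils import three_nn, three_interpolate

unknown = torch.rand(2, 4096, 3).cuda()        # (B, n, 3) points to receive features
known = torch.rand(2, 1024, 3).cuda()          # (B, m, 3) points with known features
known_feats = torch.rand(2, 64, 1024).cuda()   # (B, C, m) feature descriptors

dist, idx = three_nn(unknown, known)           # (B, n, 3) distances and neighbor indices
dist_recip = 1.0 / (dist + 1e-8)               # inverse-distance weighting
weight = dist_recip / dist_recip.sum(dim=2, keepdim=True)
new_feats = three_interpolate(known_feats, idx, weight)  # (B, C, n) interpolated features

This is the same three_nn + inverse-distance-weight + three_interpolate sequence that PointnetFPModule performs before applying its shared MLP.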
/model/lib/pointnet2/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | #include "interpolate_gpu.h" 7 | 8 | 9 | __global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown, 10 | const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) { 11 | // unknown: (B, N, 3) 12 | // known: (B, M, 3) 13 | // output: 14 | // dist2: (B, N, 3) 15 | // idx: (B, N, 3) 16 | 17 | int bs_idx = blockIdx.y; 18 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 19 | if (bs_idx >= b || pt_idx >= n) return; 20 | 21 | unknown += bs_idx * n * 3 + pt_idx * 3; 22 | known += bs_idx * m * 3; 23 | dist2 += bs_idx * n * 3 + pt_idx * 3; 24 | idx += bs_idx * n * 3 + pt_idx * 3; 25 | 26 | float ux = unknown[0]; 27 | float uy = unknown[1]; 28 | float uz = unknown[2]; 29 | 30 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 31 | int besti1 = 0, besti2 = 0, besti3 = 0; 32 | for (int k = 0; k < m; ++k) { 33 | float x = known[k * 3 + 0]; 34 | float y = known[k * 3 + 1]; 35 | float z = known[k * 3 + 2]; 36 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 37 | if (d < best1) { 38 | best3 = best2; besti3 = besti2; 39 | best2 = best1; besti2 = besti1; 40 | best1 = d; besti1 = k; 41 | } 42 | else if (d < best2) { 43 | best3 = best2; besti3 = besti2; 44 | best2 = d; besti2 = k; 45 | } 46 | else if (d < best3) { 47 | best3 = d; besti3 = k; 48 | } 49 | } 50 | dist2[0] = best1; dist2[1] = best2; dist2[2] = best3; 51 | idx[0] = besti1; idx[1] = besti2; idx[2] = besti3; 52 | } 53 | 54 | 55 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, 56 | const float *known, float *dist2, int *idx, cudaStream_t stream) { 57 | // unknown: (B, N, 3) 58 | // known: (B, M, 3) 59 | // output: 60 | // dist2: (B, N, 3) 61 | // idx: (B, N, 3) 62 | 63 | cudaError_t err; 64 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 65 | dim3 threads(THREADS_PER_BLOCK); 66 | 67 | three_nn_kernel_fast<<>>(b, n, m, unknown, known, dist2, idx); 68 | 69 | err = cudaGetLastError(); 70 | if (cudaSuccess != err) { 71 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 72 | exit(-1); 73 | } 74 | } 75 | 76 | 77 | __global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points, 78 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) { 79 | // points: (B, C, M) 80 | // idx: (B, N, 3) 81 | // weight: (B, N, 3) 82 | // output: 83 | // out: (B, C, N) 84 | 85 | int bs_idx = blockIdx.z; 86 | int c_idx = blockIdx.y; 87 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 88 | 89 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; 90 | 91 | weight += bs_idx * n * 3 + pt_idx * 3; 92 | points += bs_idx * c * m + c_idx * m; 93 | idx += bs_idx * n * 3 + pt_idx * 3; 94 | out += bs_idx * c * n + c_idx * n; 95 | 96 | out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]]; 97 | } 98 | 99 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 100 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) { 101 | // points: (B, C, M) 102 | // idx: (B, N, 3) 103 | // weight: (B, N, 3) 104 | // output: 105 | // out: (B, C, N) 106 | 107 | cudaError_t err; 108 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // 
109 |     dim3 threads(THREADS_PER_BLOCK);
110 |     three_interpolate_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, m, n, points, idx, weight, out);
111 | 
112 |     err = cudaGetLastError();
113 |     if (cudaSuccess != err) {
114 |         fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
115 |         exit(-1);
116 |     }
117 | }
118 | 
119 | 
120 | __global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
121 |     const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) {
122 |     // grad_out: (B, C, N)
123 |     // weight: (B, N, 3)
124 |     // output:
125 |     //      grad_points: (B, C, M)
126 | 
127 |     int bs_idx = blockIdx.z;
128 |     int c_idx = blockIdx.y;
129 |     int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
130 | 
131 |     if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
132 | 
133 |     grad_out += bs_idx * c * n + c_idx * n + pt_idx;
134 |     weight += bs_idx * n * 3 + pt_idx * 3;
135 |     grad_points += bs_idx * c * m + c_idx * m;
136 |     idx += bs_idx * n * 3 + pt_idx * 3;
137 | 
138 | 
139 |     atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
140 |     atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
141 |     atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
142 | }
143 | 
144 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
145 |     const int *idx, const float *weight, float *grad_points, cudaStream_t stream) {
146 |     // grad_out: (B, C, N)
147 |     // weight: (B, N, 3)
148 |     // output:
149 |     //      grad_points: (B, C, M)
150 | 
151 |     cudaError_t err;
152 |     dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
153 |     dim3 threads(THREADS_PER_BLOCK);
154 |     three_interpolate_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, m, grad_out, idx, weight, grad_points);
155 | 
156 |     err = cudaGetLastError();
157 |     if (cudaSuccess != err) {
158 |         fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
159 |         exit(-1);
160 |     }
161 | }
--------------------------------------------------------------------------------
/model/lib/pointnet2/src/interpolate_gpu.h:
--------------------------------------------------------------------------------
 1 | #ifndef _INTERPOLATE_GPU_H
 2 | #define _INTERPOLATE_GPU_H
 3 | 
 4 | #include <torch/serialize/tensor.h>
 5 | #include <cuda.h>
 6 | #include <cuda_runtime_api.h>
 7 | #include <vector>
 8 | 
 9 | 
10 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
11 |     at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);
12 | 
13 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
14 |     const float *known, float *dist2, int *idx, cudaStream_t stream);
15 | 
16 | 
17 | void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor,
18 |     at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);
19 | 
20 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
21 |     const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream);
22 | 
23 | 
24 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor,
25 |     at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor);
26 | 
27 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
28 |     const int *idx, const float *weight, float *grad_points, cudaStream_t stream);
29 | 
30 | #endif
31 | 
--------------------------------------------------------------------------------
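As a cross-check on the backward pass, the accumulation that `three_interpolate_grad_kernel_fast` performs with `atomicAdd` can be written with `scatter_add_`. This is an illustrative sketch only (not part of the repository; names are hypothetical):

import torch

def three_interpolate_grad_torch(grad_out, idx, weight, M):
    # grad_out: (B, C, N), idx/weight: (B, N, 3) -> grad_points: (B, C, M)
    B, C, N = grad_out.shape
    contrib = grad_out.unsqueeze(-1) * weight.unsqueeze(1)   # (B, C, N, 3)
    idx_exp = idx.long().unsqueeze(1).expand(B, C, N, 3)
    grad_points = grad_out.new_zeros(B, C, M)
    # each of the three neighbours accumulates weight * upstream gradient,
    # mirroring the three atomicAdd calls in the kernel
    grad_points.scatter_add_(2, idx_exp.reshape(B, C, -1), contrib.reshape(B, C, -1))
    return grad_points

Atomic (or scatter-add) accumulation is needed because many interpolated points can index the same source point, so their gradient contributions must be summed rather than overwritten.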
/model/lib/pointnet2/src/pointnet2_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "ball_query_gpu.h" 5 | #include "group_points_gpu.h" 6 | #include "sampling_gpu.h" 7 | #include "interpolate_gpu.h" 8 | 9 | 10 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 11 | m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast"); 12 | 13 | m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast"); 14 | m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast"); 15 | 16 | m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast"); 17 | m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast"); 18 | 19 | m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper"); 20 | 21 | m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast"); 22 | m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast"); 23 | m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast"); 24 | } 25 | -------------------------------------------------------------------------------- /model/lib/pointnet2/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "sampling_gpu.h" 7 | 8 | 9 | 10 | int gather_points_wrapper_fast(int b, int c, int n, int npoints, 11 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){ 12 | const float *points = points_tensor.data(); 13 | const int *idx = idx_tensor.data(); 14 | float *out = out_tensor.data(); 15 | 16 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); 17 | gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream); 18 | return 1; 19 | } 20 | 21 | 22 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 23 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { 24 | 25 | const float *grad_out = grad_out_tensor.data(); 26 | const int *idx = idx_tensor.data(); 27 | float *grad_points = grad_points_tensor.data(); 28 | 29 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); 30 | gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream); 31 | return 1; 32 | } 33 | 34 | 35 | int furthest_point_sampling_wrapper(int b, int n, int m, 36 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { 37 | 38 | const float *points = points_tensor.data(); 39 | float *temp = temp_tensor.data(); 40 | int *idx = idx_tensor.data(); 41 | 42 | cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); 43 | furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream); 44 | return 1; 45 | } 46 | -------------------------------------------------------------------------------- /model/lib/pointnet2/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | #include "sampling_gpu.h" 6 | 7 | 8 | __global__ void gather_points_kernel_fast(int b, int c, int n, int m, 9 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { 10 | // points: (B, C, N) 11 | // idx: 
(B, M) 12 | // output: 13 | // out: (B, C, M) 14 | 15 | int bs_idx = blockIdx.z; 16 | int c_idx = blockIdx.y; 17 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 18 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; 19 | 20 | out += bs_idx * c * m + c_idx * m + pt_idx; 21 | idx += bs_idx * m + pt_idx; 22 | points += bs_idx * c * n + c_idx * n; 23 | out[0] = points[idx[0]]; 24 | } 25 | 26 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 27 | const float *points, const int *idx, float *out, cudaStream_t stream) { 28 | // points: (B, C, N) 29 | // idx: (B, npoints) 30 | // output: 31 | // out: (B, C, npoints) 32 | 33 | cudaError_t err; 34 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 35 | dim3 threads(THREADS_PER_BLOCK); 36 | 37 | gather_points_kernel_fast<<>>(b, c, n, npoints, points, idx, out); 38 | 39 | err = cudaGetLastError(); 40 | if (cudaSuccess != err) { 41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 42 | exit(-1); 43 | } 44 | } 45 | 46 | __global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, 47 | const int *__restrict__ idx, float *__restrict__ grad_points) { 48 | // grad_out: (B, C, M) 49 | // idx: (B, M) 50 | // output: 51 | // grad_points: (B, C, N) 52 | 53 | int bs_idx = blockIdx.z; 54 | int c_idx = blockIdx.y; 55 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 56 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; 57 | 58 | grad_out += bs_idx * c * m + c_idx * m + pt_idx; 59 | idx += bs_idx * m + pt_idx; 60 | grad_points += bs_idx * c * n + c_idx * n; 61 | 62 | atomicAdd(grad_points + idx[0], grad_out[0]); 63 | } 64 | 65 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 66 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { 67 | // grad_out: (B, C, npoints) 68 | // idx: (B, npoints) 69 | // output: 70 | // grad_points: (B, C, N) 71 | 72 | cudaError_t err; 73 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 74 | dim3 threads(THREADS_PER_BLOCK); 75 | 76 | gather_points_grad_kernel_fast<<>>(b, c, n, npoints, grad_out, idx, grad_points); 77 | 78 | err = cudaGetLastError(); 79 | if (cudaSuccess != err) { 80 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 81 | exit(-1); 82 | } 83 | } 84 | 85 | 86 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){ 87 | const float v1 = dists[idx1], v2 = dists[idx2]; 88 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 89 | dists[idx1] = max(v1, v2); 90 | dists_i[idx1] = v2 > v1 ? 
i2 : i1; 91 | } 92 | 93 | template 94 | __global__ void furthest_point_sampling_kernel(int b, int n, int m, 95 | const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) { 96 | // dataset: (B, N, 3) 97 | // tmp: (B, N) 98 | // output: 99 | // idx: (B, M) 100 | 101 | if (m <= 0) return; 102 | __shared__ float dists[block_size]; 103 | __shared__ int dists_i[block_size]; 104 | 105 | int batch_index = blockIdx.x; 106 | dataset += batch_index * n * 3; 107 | temp += batch_index * n; 108 | idxs += batch_index * m; 109 | 110 | int tid = threadIdx.x; 111 | const int stride = block_size; 112 | 113 | int old = 0; 114 | if (threadIdx.x == 0) 115 | idxs[0] = old; 116 | 117 | __syncthreads(); 118 | for (int j = 1; j < m; j++) { 119 | int besti = 0; 120 | float best = -1; 121 | float x1 = dataset[old * 3 + 0]; 122 | float y1 = dataset[old * 3 + 1]; 123 | float z1 = dataset[old * 3 + 2]; 124 | for (int k = tid; k < n; k += stride) { 125 | float x2, y2, z2; 126 | x2 = dataset[k * 3 + 0]; 127 | y2 = dataset[k * 3 + 1]; 128 | z2 = dataset[k * 3 + 2]; 129 | // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 130 | // if (mag <= 1e-3) 131 | // continue; 132 | 133 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 134 | float d2 = min(d, temp[k]); 135 | temp[k] = d2; 136 | besti = d2 > best ? k : besti; 137 | best = d2 > best ? d2 : best; 138 | } 139 | dists[tid] = best; 140 | dists_i[tid] = besti; 141 | __syncthreads(); 142 | 143 | if (block_size >= 1024) { 144 | if (tid < 512) { 145 | __update(dists, dists_i, tid, tid + 512); 146 | } 147 | __syncthreads(); 148 | } 149 | 150 | if (block_size >= 512) { 151 | if (tid < 256) { 152 | __update(dists, dists_i, tid, tid + 256); 153 | } 154 | __syncthreads(); 155 | } 156 | if (block_size >= 256) { 157 | if (tid < 128) { 158 | __update(dists, dists_i, tid, tid + 128); 159 | } 160 | __syncthreads(); 161 | } 162 | if (block_size >= 128) { 163 | if (tid < 64) { 164 | __update(dists, dists_i, tid, tid + 64); 165 | } 166 | __syncthreads(); 167 | } 168 | if (block_size >= 64) { 169 | if (tid < 32) { 170 | __update(dists, dists_i, tid, tid + 32); 171 | } 172 | __syncthreads(); 173 | } 174 | if (block_size >= 32) { 175 | if (tid < 16) { 176 | __update(dists, dists_i, tid, tid + 16); 177 | } 178 | __syncthreads(); 179 | } 180 | if (block_size >= 16) { 181 | if (tid < 8) { 182 | __update(dists, dists_i, tid, tid + 8); 183 | } 184 | __syncthreads(); 185 | } 186 | if (block_size >= 8) { 187 | if (tid < 4) { 188 | __update(dists, dists_i, tid, tid + 4); 189 | } 190 | __syncthreads(); 191 | } 192 | if (block_size >= 4) { 193 | if (tid < 2) { 194 | __update(dists, dists_i, tid, tid + 2); 195 | } 196 | __syncthreads(); 197 | } 198 | if (block_size >= 2) { 199 | if (tid < 1) { 200 | __update(dists, dists_i, tid, tid + 1); 201 | } 202 | __syncthreads(); 203 | } 204 | 205 | old = dists_i[0]; 206 | if (tid == 0) 207 | idxs[j] = old; 208 | } 209 | } 210 | 211 | void furthest_point_sampling_kernel_launcher(int b, int n, int m, 212 | const float *dataset, float *temp, int *idxs, cudaStream_t stream) { 213 | // dataset: (B, N, 3) 214 | // tmp: (B, N) 215 | // output: 216 | // idx: (B, M) 217 | 218 | cudaError_t err; 219 | unsigned int n_threads = opt_n_threads(n); 220 | 221 | switch (n_threads) { 222 | case 1024: 223 | furthest_point_sampling_kernel<1024><<>>(b, n, m, dataset, temp, idxs); break; 224 | case 512: 225 | furthest_point_sampling_kernel<512><<>>(b, n, m, dataset, temp, idxs); break; 226 | case 256: 227 | 
furthest_point_sampling_kernel<256><<>>(b, n, m, dataset, temp, idxs); break; 228 | case 128: 229 | furthest_point_sampling_kernel<128><<>>(b, n, m, dataset, temp, idxs); break; 230 | case 64: 231 | furthest_point_sampling_kernel<64><<>>(b, n, m, dataset, temp, idxs); break; 232 | case 32: 233 | furthest_point_sampling_kernel<32><<>>(b, n, m, dataset, temp, idxs); break; 234 | case 16: 235 | furthest_point_sampling_kernel<16><<>>(b, n, m, dataset, temp, idxs); break; 236 | case 8: 237 | furthest_point_sampling_kernel<8><<>>(b, n, m, dataset, temp, idxs); break; 238 | case 4: 239 | furthest_point_sampling_kernel<4><<>>(b, n, m, dataset, temp, idxs); break; 240 | case 2: 241 | furthest_point_sampling_kernel<2><<>>(b, n, m, dataset, temp, idxs); break; 242 | case 1: 243 | furthest_point_sampling_kernel<1><<>>(b, n, m, dataset, temp, idxs); break; 244 | default: 245 | furthest_point_sampling_kernel<512><<>>(b, n, m, dataset, temp, idxs); 246 | } 247 | 248 | err = cudaGetLastError(); 249 | if (cudaSuccess != err) { 250 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 251 | exit(-1); 252 | } 253 | } 254 | -------------------------------------------------------------------------------- /model/lib/pointnet2/src/sampling_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_GPU_H 2 | #define _SAMPLING_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | 9 | int gather_points_wrapper_fast(int b, int c, int n, int npoints, 10 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 11 | 12 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 13 | const float *points, const int *idx, float *out, cudaStream_t stream); 14 | 15 | 16 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 18 | 19 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); 21 | 22 | 23 | int furthest_point_sampling_wrapper(int b, int n, int m, 24 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); 25 | 26 | void furthest_point_sampling_kernel_launcher(int b, int n, int m, 27 | const float *dataset, float *temp, int *idxs, cudaStream_t stream); 28 | 29 | #endif 30 | -------------------------------------------------------------------------------- /model/meta.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | import torch.nn as nn 4 | import json 5 | import os 6 | from .tokenizer import Tokenizer 7 | from . 
import LLM 8 | 9 | from fairscale.nn.model_parallel import initialize as fs_init 10 | 11 | 12 | class MetaModel(nn.Module): 13 | 14 | def __init__(self, llama_type, llama_config, llama_ckpt_dir=None, tokenizer_path=None): 15 | super().__init__() 16 | 17 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0) 18 | 19 | ModelArgs = LLM.__dict__[llama_type].ModelArgs 20 | Transformer = LLM.__dict__[llama_type].Transformer 21 | 22 | with open(llama_config, "r") as f: 23 | params = json.loads(f.read()) 24 | model_args: ModelArgs = ModelArgs( 25 | max_seq_len=2048, max_batch_size=32, **params 26 | ) 27 | self.tokenizer = Tokenizer(model_path=tokenizer_path) 28 | model_args.vocab_size = self.tokenizer.n_words 29 | 30 | model = Transformer(model_args) 31 | mp_rank = fs_init.get_model_parallel_rank() 32 | if llama_ckpt_dir is not None: 33 | ckpt_path = os.path.join(llama_ckpt_dir, f"consolidated.{mp_rank:02d}.pth") 34 | if os.path.exists(ckpt_path): 35 | checkpoint = torch.load(ckpt_path, map_location="cpu") 36 | msg = model.load_state_dict(checkpoint, strict=False) 37 | print(msg) 38 | else: 39 | print(f'Checkpoint not found at {ckpt_path}') 40 | self.llma = model 41 | for name, param in self.named_parameters(): 42 | if param.requires_grad: 43 | print(f"Trainable param: {name}, {param.shape}, {param.dtype}") 44 | count = sum(p.numel() for p in self.parameters() if p.requires_grad) 45 | print(f"Parameter count : {count}") 46 | 47 | def forward(self, examples, labels, image=None, modal='image'): 48 | output = self.llma(examples, image=image, modal=modal) 49 | output = output[:, :-1, :] 50 | labels = labels[:, 1:] 51 | 52 | if labels.sum() == 0: 53 | c_loss = output.mean() * 0 54 | else: 55 | c_loss = self.criterion(output.reshape(-1, 32000), labels.flatten()) 56 | 57 | return c_loss 58 | 59 | def generate( 60 | self, 61 | prompts: List[str], 62 | images, 63 | max_gen_len: int, 64 | temperature: float = 0.8, 65 | top_p: float = 0.95, 66 | modal = ['image'], 67 | ) -> List[str]: 68 | bsz = len(prompts) 69 | params = self.llma.params 70 | assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) 71 | 72 | prompt_tokens = [self.tokenizer.encode( 73 | x, bos=True, eos=False) for x in prompts] 74 | 75 | min_prompt_size = min([len(t) for t in prompt_tokens]) 76 | max_prompt_size = max([len(t) for t in prompt_tokens]) 77 | 78 | total_len = min(params.max_seq_len, max_gen_len + max_prompt_size) 79 | 80 | tokens = torch.full( 81 | (bsz, total_len), self.tokenizer.pad_id).cuda().long() 82 | for k, t in enumerate(prompt_tokens): 83 | tokens[k, : len(t)] = torch.tensor(t).long() 84 | input_text_mask = tokens != self.tokenizer.pad_id 85 | start_pos = min_prompt_size 86 | prev_pos = 0 87 | for cur_pos in range(start_pos, total_len): 88 | logits = self.llma.forward_inference(tokens[:, prev_pos:cur_pos], prev_pos, images if prev_pos == 0 else None, modal=modal) 89 | if temperature > 0: 90 | probs = torch.softmax(logits / temperature, dim=-1) 91 | next_token = self.sample_top_p(probs, top_p) 92 | else: 93 | next_token = torch.argmax(logits, dim=-1) 94 | next_token = next_token.reshape(-1) 95 | # only replace token if prompt has already been generated 96 | next_token = torch.where( 97 | input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token 98 | ) 99 | tokens[:, cur_pos] = next_token 100 | prev_pos = cur_pos 101 | 102 | decoded = [] 103 | for i, t in enumerate(tokens.tolist()): 104 | # cut to max gen len 105 | t = t[: len(prompt_tokens[i]) + max_gen_len] 106 | # cut to eos tok if any 107 | try: 108 | 
t = t[: t.index(self.tokenizer.eos_id)] 109 | except ValueError: 110 | pass 111 | decoded.append(self.tokenizer.decode(t)) 112 | return decoded 113 | 114 | @torch.inference_mode() 115 | def stream_generate( 116 | self, 117 | prompt: str, 118 | images, 119 | max_gen_len: int, 120 | temperature: float = 0.8, 121 | top_p: float = 0.95, 122 | modal = ['image'], 123 | ): 124 | params = self.llma.params 125 | 126 | prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False) 127 | # truncate from the left. leave some space for generation. 128 | max_seq_len = params.max_seq_len 129 | if images is not None: 130 | max_seq_len -= self.llma.image_words 131 | 132 | max_prompt_size = max_seq_len - max_gen_len 133 | prompt_tokens = prompt_tokens[-max_prompt_size:] 134 | 135 | prompt_size = len(prompt_tokens) 136 | 137 | total_len = min(max_seq_len, max_gen_len + prompt_size) 138 | 139 | tokens = torch.full([total_len], 0).cuda().long() 140 | 141 | tokens[:len(prompt_tokens)] = torch.tensor(prompt_tokens).long() 142 | start_pos = prompt_size 143 | prev_pos = 0 144 | generate_until = start_pos 145 | for cur_pos in range(start_pos, total_len): 146 | logits = self.llma.forward_inference(tokens[None, prev_pos:cur_pos], prev_pos, images if prev_pos == 0 else None, modal = modal) 147 | if temperature > 0: 148 | probs = torch.softmax(logits / temperature, dim=-1) 149 | next_token = self.sample_top_p(probs, top_p) 150 | else: 151 | next_token = torch.argmax(logits, dim=-1) 152 | next_token = next_token.item() 153 | 154 | if next_token == self.tokenizer.eos_id: 155 | break 156 | 157 | tokens[cur_pos] = next_token 158 | prev_pos = cur_pos 159 | generate_until = cur_pos + 1 160 | yield {"text": self.tokenizer.decode(tokens[start_pos:generate_until].tolist()), "end_of_content": False} 161 | 162 | yield {"text": self.tokenizer.decode(tokens[start_pos:generate_until].tolist()), "end_of_content": True} 163 | 164 | def sample_top_p(self, probs, p): 165 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 166 | probs_sum = torch.cumsum(probs_sort, dim=-1) 167 | mask = probs_sum - probs_sort > p 168 | probs_sort[mask] = 0.0 169 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 170 | next_token = torch.multinomial(probs_sort, num_samples=1) 171 | next_token = torch.gather(probs_idx, -1, next_token) 172 | return next_token 173 | 174 | def get_image_words(self): 175 | return self.llma.image_words -------------------------------------------------------------------------------- /model/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 
3 | 4 | from sentencepiece import SentencePieceProcessor 5 | from logging import getLogger 6 | from typing import List 7 | import os 8 | 9 | 10 | logger = getLogger() 11 | 12 | 13 | class Tokenizer: 14 | def __init__(self, model_path: str): 15 | # reload tokenizer 16 | assert os.path.isfile(model_path), model_path 17 | self.sp_model = SentencePieceProcessor(model_file=model_path) 18 | logger.info(f"Reloaded SentencePiece model from {model_path}") 19 | 20 | # BOS / EOS token IDs 21 | self.n_words: int = self.sp_model.vocab_size() 22 | self.bos_id: int = self.sp_model.bos_id() 23 | self.eos_id: int = self.sp_model.eos_id() 24 | self.pad_id: int = self.sp_model.pad_id() 25 | logger.info( 26 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" 27 | ) 28 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() 29 | 30 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]: 31 | assert type(s) is str 32 | t = self.sp_model.encode(s) 33 | if bos: 34 | t = [self.bos_id] + t 35 | if eos: 36 | t = t + [self.eos_id] 37 | return t 38 | 39 | def decode(self, t: List[int]) -> str: 40 | return self.sp_model.decode(t) 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu117 2 | torch==2.0.0+cu117 3 | packaging 4 | fairscale 5 | sentencepiece 6 | Pillow 7 | huggingface_hub 8 | open_clip_torch==2.23.0 9 | decord 10 | pytorchvideo==0.1.5 11 | torchaudio 12 | matplotlib 13 | flash-attn 14 | gradio 15 | pandas 16 | -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, it, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if it < args.warmup_iters: # 1) linear warmup for warmup_iters steps 12 | lr = args.lr * it / args.warmup_iters 13 | elif it > args.lr_decay_iters: # 2) if it > lr_decay_iters, return min learning rate 14 | lr = args.min_lr 15 | else: # 3) in between, use cosine decay down to min learning rate 16 | decay_ratio = (it - args.warmup_iters) / (args.lr_decay_iters - args.warmup_iters) 17 | assert 0 <= decay_ratio <= 1 18 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 19 | lr = args.min_lr + (args.lr - args.min_lr) * coeff 20 | 21 | for param_group in optimizer.param_groups: 22 | if "lr_scale" in param_group: 23 | param_group["lr"] = lr * param_group["lr_scale"] 24 | else: 25 | param_group["lr"] = lr 26 | return lr 27 | 28 | 29 | def adjust_learning_rate_epoch(optimizer, epoch, args): 30 | """Decay the learning rate with half-cycle cosine after warmup""" 31 | if epoch < args.warmup_epochs: 32 | lr = args.lr * epoch / args.warmup_epochs 33 | else: 34 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 35 | (1. 
+ math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 36 | for param_group in optimizer.param_groups: 37 | if "lr_scale" in param_group: 38 | param_group["lr"] = lr * param_group["lr_scale"] 39 | else: 40 | param_group["lr"] = lr 41 | return lr 42 | 43 | -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | import torch 12 | 13 | # -------------------------------------------------------- 14 | # 2D sine-cosine position embedding 15 | # References: 16 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 17 | # MoCo v3: https://github.com/facebookresearch/moco-v3 18 | # -------------------------------------------------------- 19 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 20 | """ 21 | grid_size: int of the grid height and width 22 | return: 23 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 24 | """ 25 | grid_h = np.arange(grid_size, dtype=np.float32) 26 | grid_w = np.arange(grid_size, dtype=np.float32) 27 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 28 | grid = np.stack(grid, axis=0) 29 | 30 | grid = grid.reshape([2, 1, grid_size, grid_size]) 31 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 32 | if cls_token: 33 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 34 | return pos_embed 35 | 36 | 37 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 38 | assert embed_dim % 2 == 0 39 | 40 | # use half of dimensions to encode grid_h 41 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 42 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 43 | 44 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 45 | return emb 46 | 47 | 48 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 49 | """ 50 | embed_dim: output dimension for each position 51 | pos: a list of positions to be encoded: size (M,) 52 | out: (M, D) 53 | """ 54 | assert embed_dim % 2 == 0 55 | omega = np.arange(embed_dim // 2, dtype=np.float) 56 | omega /= embed_dim / 2. 57 | omega = 1. 
/ 10000**omega # (D/2,) 58 | 59 | pos = pos.reshape(-1) # (M,) 60 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 61 | 62 | emb_sin = np.sin(out) # (M, D/2) 63 | emb_cos = np.cos(out) # (M, D/2) 64 | 65 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 66 | return emb 67 | 68 | 69 | # -------------------------------------------------------- 70 | # Interpolate position embeddings for high-resolution 71 | # References: 72 | # DeiT: https://github.com/facebookresearch/deit 73 | # -------------------------------------------------------- 74 | def interpolate_pos_embed(model, checkpoint_model): 75 | if 'pos_embed' in checkpoint_model: 76 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 77 | embedding_size = pos_embed_checkpoint.shape[-1] 78 | num_patches = model.patch_embed.num_patches 79 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 80 | # height (== width) for the checkpoint position embedding 81 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 82 | # height (== width) for the new position embedding 83 | new_size = int(num_patches ** 0.5) 84 | # class_token and dist_token are kept unchanged 85 | if orig_size != new_size: 86 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 87 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 88 | # only the position tokens are interpolated 89 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 90 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 91 | pos_tokens = torch.nn.functional.interpolate( 92 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 93 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 94 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 95 | checkpoint_model['pos_embed'] = new_pos_embed 96 | 97 | 98 | def interpolate_pos_embed_online( 99 | pos_embed, orig_size, new_size, num_extra_tokens: int 100 | ): 101 | # [257, 1024] 102 | extra_tokens = pos_embed[:num_extra_tokens] 103 | pos_tokens = pos_embed[num_extra_tokens:] 104 | embedding_size = pos_tokens.shape[1] 105 | pos_tokens = pos_tokens.reshape( 106 | -1, orig_size[0], orig_size[1], embedding_size 107 | ).permute(0, 3, 1, 2) 108 | pos_tokens = torch.nn.functional.interpolate( 109 | pos_tokens, size=new_size, mode="bicubic", align_corners=False, 110 | ) 111 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size) 112 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0) 113 | return new_pos_embed --------------------------------------------------------------------------------
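To make the two position-embedding utilities concrete, here is a small illustrative snippet (not part of the repository). It builds a 2D sine-cosine table for a 16x16 patch grid with a CLS slot and then resizes it to a 24x24 grid with `interpolate_pos_embed_online`, as a higher-resolution input would require. Note that `get_1d_sincos_pos_embed_from_grid` uses the deprecated `np.float` alias (removed in NumPy 1.24), so the first call only works with NumPy versions that still accept it.

import torch
from util.pos_embed import get_2d_sincos_pos_embed, interpolate_pos_embed_online

# (1 + 16*16, 1024) = (257, 1024): a zero CLS row followed by the sin-cos grid
pe = get_2d_sincos_pos_embed(embed_dim=1024, grid_size=16, cls_token=True)

pos_embed = torch.from_numpy(pe).float()                  # [257, 1024]
resized = interpolate_pos_embed_online(
    pos_embed, orig_size=(16, 16), new_size=(24, 24), num_extra_tokens=1)
print(resized.shape)                                      # torch.Size([577, 1024])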