├── README.md ├── finetune_llmseg.py ├── imgs ├── llmseg_exp.png ├── llmseg_overview.png ├── reasonseg_overview.png └── reasonseg_results_final_small.drawio.png ├── model ├── LISA.py ├── __init__.py ├── llava │ ├── __init__.py │ ├── constants.py │ ├── conversation.py │ ├── mm_utils.py │ ├── model │ │ ├── __init__.py │ │ ├── apply_delta.py │ │ ├── builder.py │ │ ├── consolidate.py │ │ ├── language_model │ │ │ ├── llava_llama.py │ │ │ ├── llava_mpt.py │ │ │ └── mpt │ │ │ │ ├── adapt_tokenizer.py │ │ │ │ ├── attention.py │ │ │ │ ├── blocks.py │ │ │ │ ├── configuration_mpt.py │ │ │ │ ├── custom_embedding.py │ │ │ │ ├── flash_attn_triton.py │ │ │ │ ├── hf_prefixlm_converter.py │ │ │ │ ├── meta_init_context.py │ │ │ │ ├── modeling_mpt.py │ │ │ │ ├── norm.py │ │ │ │ └── param_init_fns.py │ │ ├── llava_arch.py │ │ ├── make_delta.py │ │ ├── multimodal_encoder │ │ │ ├── builder.py │ │ │ └── clip_encoder.py │ │ └── utils.py │ ├── train │ │ ├── llama_flash_attn_monkey_patch.py │ │ ├── llava_trainer.py │ │ ├── train.py │ │ └── train_mem.py │ └── utils.py ├── loss.py ├── segment_anything │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── build_sam.py │ ├── modeling │ │ ├── __init__.py │ │ ├── common.py │ │ ├── image_encoder.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ ├── sam.py │ │ └── transformer.py │ ├── predictor.py │ └── utils │ │ ├── __init__.py │ │ ├── amg.py │ │ ├── onnx.py │ │ └── transforms.py └── transformer.py ├── prepare_datasets ├── __init__.py ├── convert_h5_to_json.py ├── generate_index_reasonseg.py ├── prepare_ReasonSeg.py ├── prepare_ade20k.py ├── prepare_coco.py ├── prepare_egoobjects.py ├── prepare_mapillary.py ├── prepare_saiapr.py ├── prepare_voc2010.py └── split_coco.py ├── requirements.txt ├── scripts ├── finetune_llmseg.sh ├── train_10epoch.sh ├── train_20epoch.sh ├── train_zero_shot.sh ├── validate_llmseg40k.sh └── validate_visualize.sh ├── training.py ├── utils ├── ade20k_classes.json ├── cocostuff_classes.txt ├── conversation.py ├── data_processing.py ├── dataset.py ├── grefcoco.py ├── grefer.py ├── llm_seg_dataset.py ├── reason_seg_dataset.py ├── refer.py ├── refer_seg_dataset.py ├── sam_mask_reader.py ├── sem_seg_dataset.py ├── utils.py └── vqa_dataset.py ├── validate_llmseg.py └── validation.py /README.md: -------------------------------------------------------------------------------- 1 | # LLM-Seg: Bridging Image Segmentation and Large Language Model Reasoning 2 | 3 | This is the official repository for the paper [\[arxiv\]](https://arxiv.org/abs/2404.08767): "LLM-Seg: Bridging Image Segmentation and Large Language Model Reasoning" (CVPR Workshop 2024). 4 | 5 | Our project is based on [LISA](https://github.com/dvlab-research/LISA). We thank the authors for their great work. 6 | 7 | ## Overview 8 | LLM-Seg is a reasoning segmentation model that combines SAM and LLaVA. We also release our proposed LLM-Seg40K dataset, which is a new reasoning segmentation dataset that is generated by ChatGPT. 9 | 10 | ### Reasoning Segmentation 11 | ![image](imgs/reasonseg_overview.png) 12 | 13 | ### Model Architecture 14 | ![image](imgs/llmseg_overview.png) 15 | 16 | 17 | ## Experiment Results 18 | The table below shows the performance of LLM-Seg on ReasonSeg validation set. 19 | ![image](imgs/llmseg_exp.png) 20 | 21 | ![image](imgs/reasonseg_results_final_small.drawio.png) 22 | 23 | 24 | 25 | ## Prepare the environment 26 | We recommend using conda to create a virtual environment and install the dependencies. 
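For example, a minimal conda setup could look like the sketch below before installing the requirements. The environment name and Python version here are illustrative assumptions; the repository does not pin a specific Python version.

```bash
# Illustrative only: environment name and Python version are assumptions.
conda create -n llmseg python=3.9 -y
conda activate llmseg
```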
```bash
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
```


## Preparing the dataset
Please first refer to the [LISA](https://github.com/dvlab-research/LISA?tab=readme-ov-file#training-data-preparation) repository to download all the datasets.

After downloading the datasets, you can use the Python scripts in the `prepare_datasets` directory to preprocess each dataset. Each script extracts SAM Everything masks and saves them as h5 files.

```bash
python prepare_datasets/prepare_<dataset>.py   # e.g. prepare_coco.py, prepare_ade20k.py
```

After preprocessing all datasets, run the following script to convert the h5 files to json files.

```bash
python prepare_datasets/convert_h5_to_json.py
```

## Prepare the pretrained models
Please follow the [instruction](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md) to merge the LLaVA delta weights. We use the `LLaVA-lightning-7B-v1` checkpoint.

For the SAM checkpoint, please use the following [link](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth).


## Training the model
We provide some of the training scripts under the `scripts` directory. You can modify the scripts to train the model with different configurations.

For example, to train the model for 10 epochs with 2 GPUs, you can use the `train_10epoch.sh` script. You should use your own paths for the dataset, pretrained models, and log directory.

```bash
#! /bin/bash
llava_path="./pretrained_weights/LLaVA-lightning-7B-v1/"
sam_path="./pretrained_weights/SAM/sam_vit_h_4b8939.pth"
dataset_path="./lisa_dataset"
sam_masks_path="./processed_data"
log_path="./runs"

deepspeed --include localhost:6,7 \
    --master_port=24374 training.py \
    --version="$llava_path" \
    --dataset_dir="$dataset_path" \
    --sam_masks_dir="$sam_masks_path" \
    --vision_pretrained="$sam_path" \
    --dataset="sem_seg||refer_seg||reason_seg" \
    --sample_rates="9,3,1" \
    --exp_name="10epoch" \
    --log_base_dir="$log_path" \
    --lr=0.0001 \
    --epochs=10 \
    --batch_size=1
```

## Evaluation on ReasonSeg
To evaluate the trained model, please modify the `scripts/validate_visualize.sh` script with your own paths and run the script. The visualization results will be saved in the log directory.

```bash
#! /bin/bash
llava_path="./pretrained_weights/LLaVA-lightning-7B-v1/"
vision_path="./pretrained_weights/SAM/sam_vit_h_4b8939.pth"
dataset_path="./lisa_dataset"
sam_masks_path="./processed_data"
log_path="./runs"

deepspeed --include localhost:2,3 \
    --master_port=24353 training_debug.py \
    --version="$llava_path" \
    --dataset_dir="$dataset_path" \
    --sam_masks_dir="$sam_masks_path" \
    --vision_pretrained="$vision_path" \
    --dataset="reason_seg" \
    --sample_rates="1" \
    --exp_name="10epoch" \
    --log_base_dir="$log_path" \
    --batch_size=1 \
    --eval_only \
    --val_dataset="ReasonSeg|val" \
    --visualize
```

We also provide the trained checkpoint for evaluation. You can download it from Hugging Face. Please note that the checkpoint is in DeepSpeed format, not Hugging Face format.
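If you have `git-lfs` installed, one way to fetch a released checkpoint is to clone its Hugging Face model repository directly; the sketch below uses the 20-epoch repository linked underneath and is only one of several possible download methods.

```bash
# Assumes git-lfs is available; the repository name is taken from the links below.
git lfs install
git clone https://huggingface.co/JCdesu/LLM-Seg-deepspeed
```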

Checkpoint for 20 epochs: [llmseg-20epoch](https://huggingface.co/JCdesu/LLM-Seg-deepspeed)

Checkpoint for 10 epochs: [llmseg-10epoch](https://huggingface.co/JCdesu/LLM-Seg-deepspeed-10epoch)

The DeepSpeed checkpoint has the same format as your own trained model. You can directly replace the checkpoint files in the log directory and run the evaluation script.

If you do not train your own model, we suggest creating a new directory that mimics the log directory structure and stores the checkpoint files. The directory structure should look like the following:

```
- resume_dir
  - ckpt_model
    - latest
    - global_step5000
      - mp_rank_00_model_states.pt
      - bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
      - bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
```

The `latest` file is just a text file that contains the folder name of the checkpoint to load. In the case above, the `latest` file should contain `global_step5000`.


## Finetuning and Evaluation on LLM-Seg40K dataset
We also provide our proposed LLM-Seg40K dataset. You can download the dataset from the following [link](https://drive.google.com/drive/folders/19MyzJN9hkvTSUr2lUFlMwmS2TSRRKmUX?usp=sharing). Besides the annotation files, you should also download the COCO2017 training images and the EgoObjects images. You can put them together with the other datasets. After downloading the dataset, you can use the `prepare_datasets/prepare_egoobjects.py` script to extract SAM masks for the dataset.

### Finetuning the model

You can use the `finetune_llmseg.py` file to finetune and evaluate the model on the LLM-Seg40K dataset. Please modify the `init_validation_dataset` and `init_training_dataset` functions to correctly set the paths.

For finetuning the model, you can use your own trained checkpoint or our provided checkpoint. Please modify the `scripts/finetune_llmseg.sh` script to set the correct paths and run the script.

```bash
#! /bin/bash
llava_path="./pretrained_weights/LLaVA-lightning-7B-v1/"
vision_path="./pretrained_weights/SAM/sam_vit_h_4b8939.pth"
dataset_path="./lisa_dataset"
sam_masks_path="./processed_data"
log_path="./runs"
resume_path="./runs/10epoch/ckpt_model"

deepspeed --include localhost:2,3 \
    --master_port=24374 finetune_llmseg.py \
    --version="$llava_path" \
    --dataset_dir="$dataset_path" \
    --sam_masks_dir="$sam_masks_path" \
    --vision_pretrained="$vision_path" \
    --dataset="sem_seg||refer_seg||reason_seg" \
    --sample_rates="9,3,1" \
    --exp_name="finetune_llmseg" \
    --log_base_dir="$log_path" \
    --steps_per_epoch=500 \
    --lr=1e-5 \
    --epochs=5 \
    --batch_size=1 \
    --resume="$resume_path"
```

### Evaluation on LLM-Seg40K dataset
After getting the finetuned model, you can evaluate it on the LLM-Seg40K dataset with the `scripts/validate_llmseg40k.sh` script. Please modify the script to set the correct paths and run the script.

```bash
#! /bin/bash
llava_path="./pretrained_weights/LLaVA-lightning-7B-v1/"
vision_path="./pretrained_weights/SAM/sam_vit_h_4b8939.pth"
dataset_path="./lisa_dataset"
sam_masks_path="./processed_data"
log_path="./runs"

deepspeed --include localhost:0,1 \
    --master_port=24353 validate_llmseg.py \
    --version="$llava_path" \
    --dataset_dir="$dataset_path" \
    --vision_pretrained="$vision_path" \
    --dataset="reason_seg" \
    --sample_rates="1" \
    --exp_name="finetune_llmseg" \
    --log_base_dir="$log_path" \
    --batch_size=1 \
    --eval_only \
    --visualize
```

We also provide our finetuned checkpoint for evaluation. You can download it from Hugging Face. The checkpoint is also in DeepSpeed format. [Link](https://huggingface.co/JCdesu/llmseg_finetuned)


## Acknowledgement
Our project is based on the following repositories:
- [LISA](https://github.com/dvlab-research/LISA)
- [LLaVA](https://github.com/haotian-liu/LLaVA)
- [SAM](https://github.com/facebookresearch/segment-anything)
- [DINOv2](https://github.com/facebookresearch/dinov2)

We thank the authors for their great work. Please refer to their repositories for more details.

## Citation
```
@article{wang2024llmseg,
  title={LLM-Seg: Bridging Image Segmentation and Large Language Model Reasoning},
  author={Wang, Junchi and Ke, Lei},
  journal={arXiv preprint arXiv:2404.08767},
  year={2024}
}
```
--------------------------------------------------------------------------------
/imgs/llmseg_exp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjunchi/LLMSeg/65147a98cc49c2dac9b7a3633d9d3d336f1ac023/imgs/llmseg_exp.png
--------------------------------------------------------------------------------
/imgs/llmseg_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjunchi/LLMSeg/65147a98cc49c2dac9b7a3633d9d3d336f1ac023/imgs/llmseg_overview.png
--------------------------------------------------------------------------------
/imgs/reasonseg_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjunchi/LLMSeg/65147a98cc49c2dac9b7a3633d9d3d336f1ac023/imgs/reasonseg_overview.png
--------------------------------------------------------------------------------
/imgs/reasonseg_results_final_small.drawio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjunchi/LLMSeg/65147a98cc49c2dac9b7a3633d9d3d336f1ac023/imgs/reasonseg_results_final_small.drawio.png
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjunchi/LLMSeg/65147a98cc49c2dac9b7a3633d9d3d336f1ac023/model/__init__.py
--------------------------------------------------------------------------------
/model/llava/__init__.py:
--------------------------------------------------------------------------------
from .model import LlavaLlamaForCausalLM
--------------------------------------------------------------------------------
/model/llava/constants.py:
--------------------------------------------------------------------------------
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
--------------------------------------------------------------------------------
/model/llava/mm_utils.py:
--------------------------------------------------------------------------------
import base64
from io import BytesIO

import torch
from PIL import Image
from transformers import StoppingCriteria

from .constants import IMAGE_TOKEN_INDEX


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def process_images(images, image_processor, model_cfg):
    return image_processor(images, return_tensors="pt")["pixel_values"]


def tokenizer_image_token(
    prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
):
    # Tokenize the text around each "<image>" placeholder separately.
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if (
        len(prompt_chunks) > 0
        and len(prompt_chunks[0]) > 0
        and prompt_chunks[0][0] == tokenizer.bos_token_id
    ):
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f"Unsupported tensor type: {return_tensors}")
    return input_ids


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith("checkpoint-"):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if (
                len(cur_keyword_ids) > 1
                and cur_keyword_ids[0] == tokenizer.bos_token_id
            ):
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, 3)
        self.keyword_ids = [
            keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
        ]
        for keyword_id in self.keyword_ids:
            if output_ids[0, -keyword_id.shape[0] :] == keyword_id:
                return True
        outputs = self.tokenizer.batch_decode(
            output_ids[:, -offset:], skip_special_tokens=True
        )[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False
--------------------------------------------------------------------------------
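A rough usage sketch of `tokenizer_image_token` above: the checkpoint path follows the README example, while the import path and the prompt are illustrative assumptions rather than values taken from this repository.

```python
# Illustrative sketch only: the prompt text is an assumption.
from transformers import AutoTokenizer

from model.llava.mm_utils import tokenizer_image_token

tokenizer = AutoTokenizer.from_pretrained(
    "./pretrained_weights/LLaVA-lightning-7B-v1/", use_fast=False
)
prompt = "<image>\nWhat object in this image should be segmented?"
# The prompt is split on "<image>" and IMAGE_TOKEN_INDEX (-200) is spliced
# between the tokenized text chunks as a placeholder for the image features.
input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
```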
/model/llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaConfig, LlavaLlamaForCausalLM 2 | from .language_model.llava_mpt import LlavaMPTConfig, LlavaMPTForCausalLM 3 | -------------------------------------------------------------------------------- /model/llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from llava import LlavaLlamaForCausalLM 9 | from tqdm import tqdm 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 17 | ) 18 | 19 | print("Loading delta") 20 | delta = LlavaLlamaForCausalLM.from_pretrained( 21 | delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 22 | ) 23 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 24 | 25 | print("Applying delta") 26 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 27 | if name not in base.state_dict(): 28 | assert name in [ 29 | "model.mm_projector.weight", 30 | "model.mm_projector.bias", 31 | ], f"{name} not in base model" 32 | continue 33 | if param.data.shape == base.state_dict()[name].shape: 34 | param.data += base.state_dict()[name] 35 | else: 36 | assert name in [ 37 | "model.embed_tokens.weight", 38 | "lm_head.weight", 39 | ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 40 | bparam = base.state_dict()[name] 41 | param.data[: bparam.shape[0], : bparam.shape[1]] += bparam 42 | 43 | print("Saving target model") 44 | delta.save_pretrained(target_model_path) 45 | delta_tokenizer.save_pretrained(target_model_path) 46 | 47 | 48 | if __name__ == "__main__": 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--base-model-path", type=str, required=True) 51 | parser.add_argument("--target-model-path", type=str, required=True) 52 | parser.add_argument("--delta-path", type=str, required=True) 53 | 54 | args = parser.parse_args() 55 | 56 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 57 | -------------------------------------------------------------------------------- /model/llava/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
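# load_pretrained_model() below returns (tokenizer, model, image_processor, context_len).
# For LLaVA checkpoints it covers three loading paths: a LoRA checkpoint merged onto a
# base model, a projector-only checkpoint applied on top of a base model, and a full
# standalone checkpoint; plain language models fall back to AutoModelForCausalLM.
# Optional 8-bit / 4-bit loading is configured through BitsAndBytesConfig.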
14 | 15 | 16 | import os 17 | import shutil 18 | 19 | import torch 20 | from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, 21 | DEFAULT_IMAGE_PATCH_TOKEN) 22 | from llava.model import * 23 | from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, 24 | BitsAndBytesConfig) 25 | 26 | 27 | def load_pretrained_model( 28 | model_path, 29 | model_base, 30 | model_name, 31 | load_8bit=False, 32 | load_4bit=False, 33 | device_map="auto", 34 | ): 35 | kwargs = {"device_map": device_map} 36 | 37 | if load_8bit: 38 | kwargs["load_in_8bit"] = True 39 | elif load_4bit: 40 | kwargs["load_in_4bit"] = True 41 | kwargs["quantization_config"] = BitsAndBytesConfig( 42 | load_in_4bit=True, 43 | bnb_4bit_compute_dtype=torch.float16, 44 | bnb_4bit_use_double_quant=True, 45 | bnb_4bit_quant_type="nf4", 46 | ) 47 | else: 48 | kwargs["torch_dtype"] = torch.float16 49 | 50 | if "llava" in model_name.lower(): 51 | # Load LLaVA model 52 | if "lora" in model_name.lower() and model_base is not None: 53 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) 54 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 55 | print("Loading LLaVA from base model...") 56 | model = LlavaLlamaForCausalLM.from_pretrained( 57 | model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs 58 | ) 59 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 60 | if model.lm_head.weight.shape[0] != token_num: 61 | model.lm_head.weight = torch.nn.Parameter( 62 | torch.empty( 63 | token_num, tokem_dim, device=model.device, dtype=model.dtype 64 | ) 65 | ) 66 | model.model.embed_tokens.weight = torch.nn.Parameter( 67 | torch.empty( 68 | token_num, tokem_dim, device=model.device, dtype=model.dtype 69 | ) 70 | ) 71 | 72 | print("Loading additional LLaVA weights...") 73 | if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")): 74 | non_lora_trainables = torch.load( 75 | os.path.join(model_path, "non_lora_trainables.bin"), 76 | map_location="cpu", 77 | ) 78 | else: 79 | # this is probably from HF Hub 80 | from huggingface_hub import hf_hub_download 81 | 82 | def load_from_hf(repo_id, filename, subfolder=None): 83 | cache_file = hf_hub_download( 84 | repo_id=repo_id, filename=filename, subfolder=subfolder 85 | ) 86 | return torch.load(cache_file, map_location="cpu") 87 | 88 | non_lora_trainables = load_from_hf( 89 | model_path, "non_lora_trainables.bin" 90 | ) 91 | non_lora_trainables = { 92 | (k[11:] if k.startswith("base_model.") else k): v 93 | for k, v in non_lora_trainables.items() 94 | } 95 | if any(k.startswith("model.model.") for k in non_lora_trainables): 96 | non_lora_trainables = { 97 | (k[6:] if k.startswith("model.") else k): v 98 | for k, v in non_lora_trainables.items() 99 | } 100 | model.load_state_dict(non_lora_trainables, strict=False) 101 | 102 | from peft import PeftModel 103 | 104 | print("Loading LoRA weights...") 105 | model = PeftModel.from_pretrained(model, model_path) 106 | print("Merging LoRA weights...") 107 | model = model.merge_and_unload() 108 | print("Model is loaded...") 109 | elif model_base is not None: 110 | # this may be mm projector only 111 | print("Loading LLaVA from base model...") 112 | if "mpt" in model_name.lower(): 113 | if not os.path.isfile(os.path.join(model_path, "configuration_mpt.py")): 114 | shutil.copyfile( 115 | os.path.join(model_base, "configuration_mpt.py"), 116 | os.path.join(model_path, "configuration_mpt.py"), 117 | ) 118 | tokenizer = 
AutoTokenizer.from_pretrained(model_base, use_fast=True) 119 | cfg_pretrained = AutoConfig.from_pretrained( 120 | model_path, trust_remote_code=True 121 | ) 122 | model = LlavaMPTForCausalLM.from_pretrained( 123 | model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs 124 | ) 125 | else: 126 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 127 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 128 | model = LlavaLlamaForCausalLM.from_pretrained( 129 | model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs 130 | ) 131 | 132 | mm_projector_weights = torch.load( 133 | os.path.join(model_path, "mm_projector.bin"), map_location="cpu" 134 | ) 135 | mm_projector_weights = { 136 | k: v.to(torch.float16) for k, v in mm_projector_weights.items() 137 | } 138 | model.load_state_dict(mm_projector_weights, strict=False) 139 | else: 140 | if "mpt" in model_name.lower(): 141 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 142 | model = LlavaMPTForCausalLM.from_pretrained( 143 | model_path, low_cpu_mem_usage=True, **kwargs 144 | ) 145 | else: 146 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 147 | model = LlavaLlamaForCausalLM.from_pretrained( 148 | model_path, low_cpu_mem_usage=True, **kwargs 149 | ) 150 | else: 151 | # Load language model 152 | if model_base is not None: 153 | # PEFT model 154 | from peft import PeftModel 155 | 156 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 157 | model = AutoModelForCausalLM.from_pretrained( 158 | model_base, 159 | torch_dtype=torch.float16, 160 | low_cpu_mem_usage=True, 161 | device_map="auto", 162 | ) 163 | print(f"Loading LoRA weights from {model_path}") 164 | model = PeftModel.from_pretrained(model, model_path) 165 | print(f"Merging weights") 166 | model = model.merge_and_unload() 167 | print("Convert to FP16...") 168 | model.to(torch.float16) 169 | else: 170 | use_fast = False 171 | if "mpt" in model_name.lower(): 172 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 173 | model = AutoModelForCausalLM.from_pretrained( 174 | model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs 175 | ) 176 | else: 177 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 178 | model = AutoModelForCausalLM.from_pretrained( 179 | model_path, low_cpu_mem_usage=True, **kwargs 180 | ) 181 | 182 | image_processor = None 183 | 184 | if "llava" in model_name.lower(): 185 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 186 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 187 | if mm_use_im_patch_token: 188 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 189 | if mm_use_im_start_end: 190 | tokenizer.add_tokens( 191 | [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True 192 | ) 193 | model.resize_token_embeddings(len(tokenizer)) 194 | 195 | vision_tower = model.get_vision_tower() 196 | if not vision_tower.is_loaded: 197 | vision_tower.load_model() 198 | vision_tower.to(device="cuda", dtype=torch.float16) 199 | image_processor = vision_tower.image_processor 200 | 201 | if hasattr(model.config, "max_sequence_length"): 202 | context_len = model.config.max_sequence_length 203 | else: 204 | context_len = 2048 205 | 206 | return tokenizer, model, image_processor, context_len 207 | -------------------------------------------------------------------------------- /model/llava/model/consolidate.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from llava.model import * 9 | from llava.model.utils import auto_upgrade 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained( 17 | src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 18 | ) 19 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 20 | src_model.save_pretrained(dst_path) 21 | src_tokenizer.save_pretrained(dst_path) 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--src", type=str, required=True) 27 | parser.add_argument("--dst", type=str, required=True) 28 | 29 | args = parser.parse_args() 30 | 31 | consolidate_ckpt(args.src, args.dst) 32 | -------------------------------------------------------------------------------- /model/llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
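# LLaVA on a LLaMA backbone. LlavaLlamaForCausalLM.forward() first calls
# prepare_inputs_labels_for_multimodal() (inherited from LlavaMetaForCausalLM) to splice
# the encoded image features into the input embeddings, then runs the LLaMA decoder and
# computes the standard shifted cross-entropy loss when labels are provided.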
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | from transformers import (AutoConfig, AutoModelForCausalLM, LlamaConfig, 22 | LlamaForCausalLM, LlamaModel) 23 | from transformers.modeling_outputs import CausalLMOutputWithPast 24 | 25 | from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel 26 | 27 | 28 | class LlavaConfig(LlamaConfig): 29 | model_type = "llava" 30 | 31 | 32 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 33 | config_class = LlavaConfig 34 | 35 | def __init__(self, config: LlamaConfig): 36 | super(LlavaLlamaModel, self).__init__(config) 37 | 38 | 39 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 40 | config_class = LlavaConfig 41 | 42 | def __init__(self, config): 43 | super(LlamaForCausalLM, self).__init__(config) 44 | 45 | self.model = LlavaLlamaModel(config) 46 | 47 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 48 | 49 | # Initialize weights and apply final processing 50 | self.post_init() 51 | 52 | def get_model(self): 53 | return self.model 54 | 55 | def forward( 56 | self, 57 | input_ids: torch.LongTensor = None, 58 | attention_mask: Optional[torch.Tensor] = None, 59 | past_key_values: Optional[List[torch.FloatTensor]] = None, 60 | inputs_embeds: Optional[torch.FloatTensor] = None, 61 | labels: Optional[torch.LongTensor] = None, 62 | use_cache: Optional[bool] = None, 63 | output_attentions: Optional[bool] = None, 64 | output_hidden_states: Optional[bool] = None, 65 | images: Optional[torch.FloatTensor] = None, 66 | return_dict: Optional[bool] = None, 67 | ) -> Union[Tuple, CausalLMOutputWithPast]: 68 | output_attentions = ( 69 | output_attentions 70 | if output_attentions is not None 71 | else self.config.output_attentions 72 | ) 73 | output_hidden_states = ( 74 | output_hidden_states 75 | if output_hidden_states is not None 76 | else self.config.output_hidden_states 77 | ) 78 | return_dict = ( 79 | return_dict if return_dict is not None else self.config.use_return_dict 80 | ) 81 | 82 | ( 83 | input_ids, 84 | attention_mask, 85 | past_key_values, 86 | inputs_embeds, 87 | labels, 88 | ) = self.prepare_inputs_labels_for_multimodal( 89 | input_ids, attention_mask, past_key_values, labels, images 90 | ) 91 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 92 | 93 | outputs = self.model( 94 | input_ids=input_ids, 95 | attention_mask=attention_mask, 96 | past_key_values=past_key_values, 97 | inputs_embeds=inputs_embeds, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict, 102 | ) 103 | 104 | hidden_states = outputs[0] 105 | logits = self.lm_head(hidden_states) 106 | 107 | loss = None 108 | if labels is not None: 109 | # Shift so that tokens < n predict n 110 | shift_logits = logits[..., :-1, :].contiguous() 111 | shift_labels = labels[..., 1:].contiguous() 112 | # Flatten the tokens 113 | loss_fct = CrossEntropyLoss() 114 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 115 | shift_labels = shift_labels.view(-1) 116 | # Enable model/pipeline parallelism 117 | shift_labels = shift_labels.to(shift_logits.device) 118 | loss = loss_fct(shift_logits, shift_labels) 119 | 120 | if not return_dict: 121 | output = (logits,) + outputs[1:] 122 | return (loss,) + output if loss is not None else output 123 | 124 | if self.training: 125 | output_hidden_states = 
outputs.hidden_states 126 | else: 127 | output_hidden_states = hidden_states 128 | 129 | return CausalLMOutputWithPast( 130 | loss=loss, 131 | logits=logits, 132 | past_key_values=outputs.past_key_values, 133 | hidden_states=output_hidden_states, # outputs.hidden_states, 134 | attentions=outputs.attentions, 135 | ) 136 | 137 | def prepare_inputs_for_generation( 138 | self, 139 | input_ids, 140 | past_key_values=None, 141 | attention_mask=None, 142 | inputs_embeds=None, 143 | images=None, 144 | **kwargs 145 | ): 146 | if past_key_values: 147 | input_ids = input_ids[:, -1:] 148 | 149 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 150 | if inputs_embeds is not None and past_key_values is None: 151 | model_inputs = {"inputs_embeds": inputs_embeds} 152 | else: 153 | model_inputs = {"input_ids": input_ids} 154 | 155 | model_inputs.update( 156 | { 157 | "past_key_values": past_key_values, 158 | "use_cache": kwargs.get("use_cache"), 159 | "attention_mask": attention_mask, 160 | "images": images, 161 | } 162 | ) 163 | return model_inputs 164 | 165 | 166 | AutoConfig.register("llava", LlavaConfig) 167 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 168 | -------------------------------------------------------------------------------- /model/llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
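# LLaVA on an MPT backbone. The model requires tied word embeddings and computes logits
# by projecting the decoder output onto the shared token-embedding matrix; image features
# are injected through prepare_inputs_labels_for_multimodal() before the MPT decoder runs.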
14 | 15 | 16 | import math 17 | import warnings 18 | from typing import List, Optional, Tuple 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | from transformers import AutoConfig, AutoModelForCausalLM 23 | from transformers.modeling_outputs import CausalLMOutputWithPast 24 | 25 | from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel 26 | from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel 27 | 28 | 29 | class LlavaMPTConfig(MPTConfig): 30 | model_type = "llava_mpt" 31 | 32 | 33 | class LlavaMPTModel(LlavaMetaModel, MPTModel): 34 | config_class = LlavaMPTConfig 35 | 36 | def __init__(self, config: MPTConfig): 37 | config.hidden_size = config.d_model 38 | super(LlavaMPTModel, self).__init__(config) 39 | 40 | def embed_tokens(self, x): 41 | return self.wte(x) 42 | 43 | 44 | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM): 45 | config_class = LlavaMPTConfig 46 | supports_gradient_checkpointing = True 47 | 48 | def __init__(self, config): 49 | super(MPTForCausalLM, self).__init__(config) 50 | 51 | if not config.tie_word_embeddings: 52 | raise ValueError("MPTForCausalLM only supports tied word embeddings") 53 | self.transformer = LlavaMPTModel(config) 54 | self.logit_scale = None 55 | if config.logit_scale is not None: 56 | logit_scale = config.logit_scale 57 | if isinstance(logit_scale, str): 58 | if logit_scale == "inv_sqrt_d_model": 59 | logit_scale = 1 / math.sqrt(config.d_model) 60 | else: 61 | raise ValueError( 62 | f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." 63 | ) 64 | self.logit_scale = logit_scale 65 | 66 | def get_model(self): 67 | return self.transformer 68 | 69 | def _set_gradient_checkpointing(self, module, value=False): 70 | if isinstance(module, LlavaMPTModel): 71 | module.gradient_checkpointing = value 72 | 73 | def forward( 74 | self, 75 | input_ids: torch.LongTensor, 76 | past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, 77 | attention_mask: Optional[torch.ByteTensor] = None, 78 | prefix_mask: Optional[torch.ByteTensor] = None, 79 | sequence_id: Optional[torch.LongTensor] = None, 80 | labels: Optional[torch.LongTensor] = None, 81 | return_dict: Optional[bool] = None, 82 | output_attentions: Optional[bool] = None, 83 | output_hidden_states: Optional[bool] = None, 84 | use_cache: Optional[bool] = None, 85 | images=None, 86 | ): 87 | return_dict = ( 88 | return_dict if return_dict is not None else self.config.return_dict 89 | ) 90 | use_cache = use_cache if use_cache is not None else self.config.use_cache 91 | 92 | ( 93 | input_ids, 94 | attention_mask, 95 | past_key_values, 96 | inputs_embeds, 97 | labels, 98 | ) = self.prepare_inputs_labels_for_multimodal( 99 | input_ids, attention_mask, past_key_values, labels, images 100 | ) 101 | outputs = self.transformer( 102 | input_ids=input_ids, 103 | inputs_embeds=inputs_embeds, 104 | past_key_values=past_key_values, 105 | attention_mask=attention_mask, 106 | prefix_mask=prefix_mask, 107 | sequence_id=sequence_id, 108 | return_dict=return_dict, 109 | output_attentions=output_attentions, 110 | output_hidden_states=output_hidden_states, 111 | use_cache=use_cache, 112 | ) 113 | # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338 114 | logits = F.linear( 115 | outputs.last_hidden_state.to(self.transformer.wte.weight.device), 116 | self.transformer.wte.weight, 117 | ) 118 | if self.logit_scale is not None: 119 | if self.logit_scale == 0: 120 | 
warnings.warn( 121 | f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs." 122 | ) 123 | logits *= self.logit_scale 124 | loss = None 125 | if labels is not None: 126 | labels = torch.roll(labels, shifts=-1) 127 | labels[:, -1] = -100 128 | loss = F.cross_entropy( 129 | logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) 130 | ) 131 | return CausalLMOutputWithPast( 132 | loss=loss, 133 | logits=logits, 134 | past_key_values=outputs.past_key_values, 135 | hidden_states=outputs.hidden_states, 136 | ) 137 | 138 | def prepare_inputs_for_generation( 139 | self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs 140 | ): 141 | if inputs_embeds is not None: 142 | raise NotImplementedError("inputs_embeds is not implemented for MPT yet") 143 | attention_mask = kwargs["attention_mask"].bool() 144 | if attention_mask[:, -1].sum() != attention_mask.shape[0]: 145 | raise NotImplementedError( 146 | "MPT does not support generation with right padding." 147 | ) 148 | if self.transformer.attn_uses_sequence_id and self.training: 149 | sequence_id = torch.zeros_like(input_ids[:1]) 150 | else: 151 | sequence_id = None 152 | if past_key_values is not None: 153 | input_ids = input_ids[:, -1].unsqueeze(-1) 154 | if self.transformer.prefix_lm: 155 | prefix_mask = torch.ones_like(attention_mask) 156 | if kwargs.get("use_cache") == False: 157 | raise NotImplementedError( 158 | "MPT with prefix_lm=True does not support use_cache=False." 159 | ) 160 | else: 161 | prefix_mask = None 162 | return { 163 | "input_ids": input_ids, 164 | "attention_mask": attention_mask, 165 | "prefix_mask": prefix_mask, 166 | "sequence_id": sequence_id, 167 | "past_key_values": past_key_values, 168 | "use_cache": kwargs.get("use_cache", True), 169 | "images": kwargs.get("images", None), 170 | } 171 | 172 | 173 | AutoConfig.register("llava_mpt", LlavaMPTConfig) 174 | AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM) 175 | -------------------------------------------------------------------------------- /model/llava/model/language_model/mpt/adapt_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from transformers import (AutoTokenizer, PreTrainedTokenizer, 4 | PreTrainedTokenizerFast) 5 | 6 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 7 | NUM_SENTINEL_TOKENS: int = 100 8 | 9 | 10 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 11 | """Adds sentinel tokens and padding token (if missing). 12 | 13 | Expands the tokenizer vocabulary to include sentinel tokens 14 | used in mixture-of-denoiser tasks as well as a padding token. 15 | 16 | All added tokens are added as special tokens. No tokens are 17 | added if sentinel tokens and padding token already exist. 18 | """ 19 | sentinels_to_add = [f"" for i in range(NUM_SENTINEL_TOKENS)] 20 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 21 | if tokenizer.pad_token is None: 22 | tokenizer.add_tokens("", special_tokens=True) 23 | tokenizer.pad_token = "" 24 | assert tokenizer.pad_token_id is not None 25 | sentinels = "".join([f"" for i in range(NUM_SENTINEL_TOKENS)]) 26 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 27 | tokenizer.sentinel_token_ids = _sentinel_token_ids 28 | 29 | 30 | class AutoTokenizerForMOD(AutoTokenizer): 31 | """AutoTokenizer + Adaptation for MOD. 
32 | 33 | A simple wrapper around AutoTokenizer to make instantiating 34 | an MOD-adapted tokenizer a bit easier. 35 | 36 | MOD-adapted tokenizers have sentinel tokens (e.g., ), 37 | a padding token, and a property to get the token ids of the 38 | sentinel tokens. 39 | """ 40 | 41 | @classmethod 42 | def from_pretrained(cls, *args, **kwargs): 43 | """See `AutoTokenizer.from_pretrained` docstring.""" 44 | tokenizer = super().from_pretrained(*args, **kwargs) 45 | adapt_tokenizer_for_denoising(tokenizer) 46 | return tokenizer 47 | -------------------------------------------------------------------------------- /model/llava/model/language_model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .attention import ATTN_CLASS_REGISTRY 8 | from .norm import NORM_CLASS_REGISTRY 9 | 10 | 11 | class MPTMLP(nn.Module): 12 | def __init__( 13 | self, d_model: int, expansion_ratio: int, device: Optional[str] = None 14 | ): 15 | super().__init__() 16 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 17 | self.act = nn.GELU(approximate="none") 18 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 19 | self.down_proj._is_residual = True 20 | 21 | def forward(self, x): 22 | return self.down_proj(self.act(self.up_proj(x))) 23 | 24 | 25 | class MPTBlock(nn.Module): 26 | def __init__( 27 | self, 28 | d_model: int, 29 | n_heads: int, 30 | expansion_ratio: int, 31 | attn_config: Dict = { 32 | "attn_type": "multihead_attention", 33 | "attn_pdrop": 0.0, 34 | "attn_impl": "triton", 35 | "qk_ln": False, 36 | "clip_qkv": None, 37 | "softmax_scale": None, 38 | "prefix_lm": False, 39 | "attn_uses_sequence_id": False, 40 | "alibi": False, 41 | "alibi_bias_max": 8, 42 | }, 43 | resid_pdrop: float = 0.0, 44 | norm_type: str = "low_precision_layernorm", 45 | verbose: int = 0, 46 | device: Optional[str] = None, 47 | **kwargs 48 | ): 49 | del kwargs 50 | super().__init__() 51 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 52 | attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]] 53 | self.norm_1 = norm_class(d_model, device=device) 54 | self.attn = attn_class( 55 | attn_impl=attn_config["attn_impl"], 56 | clip_qkv=attn_config["clip_qkv"], 57 | qk_ln=attn_config["qk_ln"], 58 | softmax_scale=attn_config["softmax_scale"], 59 | attn_pdrop=attn_config["attn_pdrop"], 60 | d_model=d_model, 61 | n_heads=n_heads, 62 | verbose=verbose, 63 | device=device, 64 | ) 65 | self.norm_2 = norm_class(d_model, device=device) 66 | self.ffn = MPTMLP( 67 | d_model=d_model, expansion_ratio=expansion_ratio, device=device 68 | ) 69 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 70 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 71 | 72 | def forward( 73 | self, 74 | x: torch.Tensor, 75 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 76 | attn_bias: Optional[torch.Tensor] = None, 77 | attention_mask: Optional[torch.ByteTensor] = None, 78 | is_causal: bool = True, 79 | ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 80 | a = self.norm_1(x) 81 | (b, attn_weights, past_key_value) = self.attn( 82 | a, 83 | past_key_value=past_key_value, 84 | attn_bias=attn_bias, 85 | attention_mask=attention_mask, 86 | is_causal=is_causal, 87 | ) 88 | x = x + self.resid_attn_dropout(b) 89 | m = self.norm_2(x) 90 | n = self.ffn(m) 91 | x = x + self.resid_ffn_dropout(n) 92 | return (x, 
attn_weights, past_key_value) 93 | -------------------------------------------------------------------------------- /model/llava/model/language_model/mpt/custom_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | 7 | class SharedEmbedding(nn.Embedding): 8 | def forward(self, input: Tensor, unembed: bool = False) -> Tensor: 9 | if unembed: 10 | return F.linear(input, self.weight) 11 | return super().forward(input) 12 | -------------------------------------------------------------------------------- /model/llava/model/language_model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | @contextmanager 8 | def init_empty_weights(include_buffers: bool = False): 9 | """Meta initialization context manager. 10 | 11 | A context manager under which models are initialized with all parameters 12 | on the meta device, therefore creating an empty model. Useful when just 13 | initializing the model would blow the available RAM. 14 | 15 | Args: 16 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 17 | not to also put all buffers on the meta device while initializing. 18 | 19 | Example: 20 | ```python 21 | import torch.nn as nn 22 | 23 | # Initialize a model with 100 billions parameters in no time and without using any RAM. 24 | with init_empty_weights(): 25 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 26 | ``` 27 | 28 | 29 | 30 | Any model created under this context manager has no weights. As such you can't do something like 31 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 32 | 33 | 34 | """ 35 | with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f: 36 | yield f 37 | 38 | 39 | @contextmanager 40 | def init_on_device(device: torch.device, include_buffers: bool = False): 41 | """Device initialization context manager. 42 | 43 | A context manager under which models are initialized with all parameters 44 | on the specified device. 45 | 46 | Args: 47 | device (`torch.device`): Device to initialize all parameters on. 48 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 49 | not to also put all buffers on the meta device while initializing. 
50 | 51 | Example: 52 | ```python 53 | import torch.nn as nn 54 | 55 | with init_on_device(device=torch.device("cuda")): 56 | tst = nn.Liner(100, 100) # on `cuda` device 57 | ``` 58 | """ 59 | old_register_parameter = nn.Module.register_parameter 60 | if include_buffers: 61 | old_register_buffer = nn.Module.register_buffer 62 | 63 | def register_empty_parameter(module, name, param): 64 | old_register_parameter(module, name, param) 65 | if param is not None: 66 | param_cls = type(module._parameters[name]) 67 | kwargs = module._parameters[name].__dict__ 68 | module._parameters[name] = param_cls( 69 | module._parameters[name].to(device), **kwargs 70 | ) 71 | 72 | def register_empty_buffer(module, name, buffer): 73 | old_register_buffer(module, name, buffer) 74 | if buffer is not None: 75 | module._buffers[name] = module._buffers[name].to(device) 76 | 77 | if include_buffers: 78 | tensor_constructors_to_patch = { 79 | torch_function_name: getattr(torch, torch_function_name) 80 | for torch_function_name in ["empty", "zeros", "ones", "full"] 81 | } 82 | else: 83 | tensor_constructors_to_patch = {} 84 | 85 | def patch_tensor_constructor(fn): 86 | def wrapper(*args, **kwargs): 87 | kwargs["device"] = device 88 | return fn(*args, **kwargs) 89 | 90 | return wrapper 91 | 92 | try: 93 | nn.Module.register_parameter = register_empty_parameter 94 | if include_buffers: 95 | nn.Module.register_buffer = register_empty_buffer 96 | for torch_function_name in tensor_constructors_to_patch.keys(): 97 | setattr( 98 | torch, 99 | torch_function_name, 100 | patch_tensor_constructor(getattr(torch, torch_function_name)), 101 | ) 102 | yield 103 | finally: 104 | nn.Module.register_parameter = old_register_parameter 105 | if include_buffers: 106 | nn.Module.register_buffer = old_register_buffer 107 | for ( 108 | torch_function_name, 109 | old_torch_function, 110 | ) in tensor_constructors_to_patch.items(): 111 | setattr(torch, torch_function_name, old_torch_function) 112 | -------------------------------------------------------------------------------- /model/llava/model/language_model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _cast_if_autocast_enabled(tensor): 5 | if torch.is_autocast_enabled(): 6 | if tensor.device.type == "cuda": 7 | dtype = torch.get_autocast_gpu_dtype() 8 | elif tensor.device.type == "cpu": 9 | dtype = torch.get_autocast_cpu_dtype() 10 | else: 11 | raise NotImplementedError() 12 | return tensor.to(dtype=dtype) 13 | return tensor 14 | 15 | 16 | class LPLayerNorm(torch.nn.LayerNorm): 17 | def __init__( 18 | self, 19 | normalized_shape, 20 | eps=1e-05, 21 | elementwise_affine=True, 22 | device=None, 23 | dtype=None, 24 | ): 25 | super().__init__( 26 | normalized_shape=normalized_shape, 27 | eps=eps, 28 | elementwise_affine=elementwise_affine, 29 | device=device, 30 | dtype=dtype, 31 | ) 32 | 33 | def forward(self, x): 34 | module_device = x.device 35 | downcast_x = _cast_if_autocast_enabled(x) 36 | downcast_weight = ( 37 | _cast_if_autocast_enabled(self.weight) 38 | if self.weight is not None 39 | else self.weight 40 | ) 41 | downcast_bias = ( 42 | _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 43 | ) 44 | with torch.autocast(enabled=False, device_type=module_device.type): 45 | return torch.nn.functional.layer_norm( 46 | downcast_x, 47 | self.normalized_shape, 48 | downcast_weight, 49 | downcast_bias, 50 | self.eps, 51 | ) 52 | 53 | 54 | def rms_norm(x, weight=None, eps=1e-05): 55 | 
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 56 | if weight is not None: 57 | return output * weight 58 | return output 59 | 60 | 61 | class RMSNorm(torch.nn.Module): 62 | def __init__( 63 | self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None 64 | ): 65 | super().__init__() 66 | self.eps = eps 67 | if weight: 68 | self.weight = torch.nn.Parameter( 69 | torch.ones(normalized_shape, dtype=dtype, device=device) 70 | ) 71 | else: 72 | self.register_parameter("weight", None) 73 | 74 | def forward(self, x): 75 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 76 | 77 | 78 | class LPRMSNorm(RMSNorm): 79 | def __init__( 80 | self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None 81 | ): 82 | super().__init__( 83 | normalized_shape=normalized_shape, 84 | eps=eps, 85 | weight=weight, 86 | dtype=dtype, 87 | device=device, 88 | ) 89 | 90 | def forward(self, x): 91 | downcast_x = _cast_if_autocast_enabled(x) 92 | downcast_weight = ( 93 | _cast_if_autocast_enabled(self.weight) 94 | if self.weight is not None 95 | else self.weight 96 | ) 97 | with torch.autocast(enabled=False, device_type=x.device.type): 98 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 99 | 100 | 101 | NORM_CLASS_REGISTRY = { 102 | "layernorm": torch.nn.LayerNorm, 103 | "low_precision_layernorm": LPLayerNorm, 104 | "rmsnorm": RMSNorm, 105 | "low_precision_rmsnorm": LPRMSNorm, 106 | } 107 | -------------------------------------------------------------------------------- /model/llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from llava.model.utils import auto_upgrade 9 | from tqdm import tqdm 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 17 | ) 18 | 19 | print("Loading target model") 20 | auto_upgrade(target_model_path) 21 | target = AutoModelForCausalLM.from_pretrained( 22 | target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 23 | ) 24 | 25 | print("Calculating delta") 26 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 27 | if name not in base.state_dict(): 28 | assert name in [ 29 | "model.mm_projector.weight", 30 | "model.mm_projector.bias", 31 | ], f"{name} not in base model" 32 | continue 33 | if param.data.shape == base.state_dict()[name].shape: 34 | param.data -= base.state_dict()[name] 35 | else: 36 | assert name in [ 37 | "model.embed_tokens.weight", 38 | "lm_head.weight", 39 | ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 40 | bparam = base.state_dict()[name] 41 | param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam 42 | 43 | print("Saving delta") 44 | if hub_repo_id: 45 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 46 | else: 47 | kwargs = {} 48 | target.save_pretrained(delta_path, **kwargs) 49 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 50 | target_tokenizer.save_pretrained(delta_path, **kwargs) 51 | 52 | 53 | if __name__ 
== "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--base-model-path", type=str, required=True) 56 | parser.add_argument("--target-model-path", type=str, required=True) 57 | parser.add_argument("--delta-path", type=str, required=True) 58 | parser.add_argument("--hub-repo-id", type=str, default=None) 59 | args = parser.parse_args() 60 | 61 | make_delta( 62 | args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id 63 | ) 64 | -------------------------------------------------------------------------------- /model/llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | from .clip_encoder import CLIPVisionTower 2 | 3 | 4 | def build_vision_tower(vision_tower_cfg, **kwargs): 5 | vision_tower = getattr( 6 | vision_tower_cfg, 7 | "mm_vision_tower", 8 | getattr(vision_tower_cfg, "vision_tower", None), 9 | ) 10 | if ( 11 | vision_tower.startswith("openai") 12 | or vision_tower.startswith("laion") 13 | or "clip" in vision_tower 14 | ): 15 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 16 | 17 | raise ValueError(f"Unknown vision tower: {vision_tower}") 18 | -------------------------------------------------------------------------------- /model/llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel 4 | 5 | 6 | class CLIPVisionTower(nn.Module): 7 | def __init__(self, vision_tower, args, delay_load=False): 8 | super().__init__() 9 | 10 | self.is_loaded = False 11 | 12 | self.vision_tower_name = vision_tower 13 | self.select_layer = args.mm_vision_select_layer 14 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 15 | 16 | if not delay_load: 17 | self.load_model() 18 | else: 19 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 20 | 21 | def load_model(self): 22 | self.image_processor = CLIPImageProcessor.from_pretrained( 23 | self.vision_tower_name 24 | ) 25 | self.vision_tower = CLIPVisionModel.from_pretrained( 26 | self.vision_tower_name, low_cpu_mem_usage=True 27 | ) 28 | self.vision_tower.requires_grad_(False) 29 | self.is_loaded = True 30 | 31 | def feature_select(self, image_forward_outs): 32 | image_features = image_forward_outs.hidden_states[self.select_layer] 33 | if self.select_feature == "patch": 34 | image_features = image_features[:, 1:] 35 | elif self.select_feature == "cls_patch": 36 | image_features = image_features 37 | else: 38 | raise ValueError(f"Unexpected select feature: {self.select_feature}") 39 | return image_features 40 | 41 | @torch.no_grad() 42 | def forward(self, images): 43 | if type(images) is list: 44 | image_features = [] 45 | for image in images: 46 | image_forward_out = self.vision_tower( 47 | image.to(device=self.device, dtype=self.dtype).unsqueeze(0), 48 | output_hidden_states=True, 49 | ) 50 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 51 | image_features.append(image_feature) 52 | else: 53 | image_forward_outs = self.vision_tower( 54 | images.to(device=self.device, dtype=self.dtype), 55 | output_hidden_states=True, 56 | ) 57 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 58 | 59 | torch.cuda.empty_cache() 60 | return image_features 61 | 62 | @property 63 | def dummy_feature(self): 64 | return torch.zeros(1, 
self.hidden_size, device=self.device, dtype=self.dtype) 65 | 66 | @property 67 | def dtype(self): 68 | return self.vision_tower.dtype 69 | 70 | @property 71 | def device(self): 72 | return self.vision_tower.device 73 | 74 | @property 75 | def config(self): 76 | if self.is_loaded: 77 | return self.vision_tower.config 78 | else: 79 | return self.cfg_only 80 | 81 | @property 82 | def hidden_size(self): 83 | return self.config.hidden_size 84 | 85 | @property 86 | def num_patches(self): 87 | return (self.config.image_size // self.config.patch_size) ** 2 88 | -------------------------------------------------------------------------------- /model/llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if "llava" in config and "llava" not in cfg.model_type: 7 | assert cfg.model_type == "llama" 8 | print( 9 | "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base." 10 | ) 11 | print( 12 | "You must upgrade the checkpoint to the new code base (this can be done automatically)." 13 | ) 14 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 15 | if confirm.lower() in ["y", "yes"]: 16 | print("Upgrading checkpoint...") 17 | assert len(cfg.architectures) == 1 18 | setattr(cfg.__class__, "model_type", "llava") 19 | cfg.architectures[0] = "LlavaLlamaForCausalLM" 20 | cfg.save_pretrained(config) 21 | print("Checkpoint upgraded.") 22 | else: 23 | print("Checkpoint upgrade aborted.") 24 | exit(1) 25 | -------------------------------------------------------------------------------- /model/llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional, Tuple 3 | 4 | import torch 5 | import transformers 6 | from einops import rearrange 7 | from torch import nn 8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 9 | 10 | try: 11 | from flash_attn.flash_attn_interface import \ 12 | flash_attn_unpadded_qkvpacked_func 13 | except ImportError: 14 | from flash_attn.flash_attn_interface import ( 15 | flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func, 16 | ) 17 | 18 | from flash_attn.bert_padding import pad_input, unpad_input 19 | 20 | 21 | def forward( 22 | self, 23 | hidden_states: torch.Tensor, 24 | attention_mask: Optional[torch.Tensor] = None, 25 | position_ids: Optional[torch.Tensor] = None, 26 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 27 | output_attentions: bool = False, 28 | use_cache: bool = False, 29 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 30 | """Input shape: Batch x Time x Channel 31 | 32 | attention_mask: [bsz, q_len] 33 | """ 34 | bsz, q_len, _ = hidden_states.size() 35 | 36 | query_states = ( 37 | self.q_proj(hidden_states) 38 | .view(bsz, q_len, self.num_heads, self.head_dim) 39 | .transpose(1, 2) 40 | ) 41 | key_states = ( 42 | self.k_proj(hidden_states) 43 | .view(bsz, q_len, self.num_heads, self.head_dim) 44 | .transpose(1, 2) 45 | ) 46 | value_states = ( 47 | self.v_proj(hidden_states) 48 | .view(bsz, q_len, self.num_heads, self.head_dim) 49 | .transpose(1, 2) 50 | ) 51 | # [bsz, q_len, nh, hd] 52 | # [bsz, nh, q_len, hd] 53 | 54 | kv_seq_len = key_states.shape[-2] 55 | assert past_key_value is None, "past_key_value is not supported" 56 | 57 | 
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 58 | query_states, key_states = apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | assert not output_attentions, "output_attentions is not supported" 63 | assert not use_cache, "use_cache is not supported" 64 | 65 | # Flash attention codes from 66 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 67 | 68 | # transform the data into the format required by flash attention 69 | qkv = torch.stack( 70 | [query_states, key_states, value_states], dim=2 71 | ) # [bsz, nh, 3, q_len, hd] 72 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 73 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 74 | # the attention_mask should be the same as the key_padding_mask 75 | key_padding_mask = attention_mask 76 | 77 | if key_padding_mask is None: 78 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 79 | max_s = q_len 80 | cu_q_lens = torch.arange( 81 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 82 | ) 83 | output = flash_attn_unpadded_qkvpacked_func( 84 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 85 | ) 86 | output = rearrange(output, "(b s) ... -> b s ...", b=bsz) 87 | else: 88 | nheads = qkv.shape[-2] 89 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 90 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 91 | x_unpad = rearrange( 92 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 93 | ) 94 | output_unpad = flash_attn_unpadded_qkvpacked_func( 95 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 96 | ) 97 | output = rearrange( 98 | pad_input( 99 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 100 | ), 101 | "b s (h d) -> b s h d", 102 | h=nheads, 103 | ) 104 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None 105 | 106 | 107 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 108 | # requires the attention mask to be the same as the key_padding_mask 109 | def _prepare_decoder_attention_mask( 110 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 111 | ): 112 | # [bsz, seq_len] 113 | return attention_mask 114 | 115 | 116 | def replace_llama_attn_with_flash_attn(): 117 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 118 | if cuda_major < 8: 119 | logging.warning( 120 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
121 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 122 | ) 123 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 124 | _prepare_decoder_attention_mask 125 | ) 126 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 127 | -------------------------------------------------------------------------------- /model/llava/train/llava_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | import torch 5 | from transformers import Trainer 6 | 7 | 8 | def maybe_zero_3(param, ignore_status=False, name=None): 9 | from deepspeed import zero 10 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 11 | 12 | if hasattr(param, "ds_id"): 13 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 14 | if not ignore_status: 15 | print(name, "no ignore status") 16 | with zero.GatheredParameters([param]): 17 | param = param.data.detach().cpu().clone() 18 | else: 19 | param = param.detach().cpu().clone() 20 | return param 21 | 22 | 23 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 24 | to_return = { 25 | k: t 26 | for k, t in named_params 27 | if any(key_match in k for key_match in keys_to_match) 28 | } 29 | to_return = { 30 | k: maybe_zero_3(v, ignore_status=True, name=k).cpu() 31 | for k, v in to_return.items() 32 | } 33 | return to_return 34 | 35 | 36 | class LLaVATrainer(Trainer): 37 | def _save_checkpoint(self, model, trial, metrics=None): 38 | if getattr(self.args, "tune_mm_mlp_adapter", False): 39 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 40 | 41 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 42 | 43 | run_dir = self._get_output_dir(trial=trial) 44 | output_dir = os.path.join(run_dir, checkpoint_folder) 45 | 46 | # Only save Adapter 47 | keys_to_match = ["mm_projector"] 48 | if getattr(self.args, "use_im_start_end", False): 49 | keys_to_match.extend(["embed_tokens", "embed_in"]) 50 | 51 | weight_to_save = get_mm_adapter_state_maybe_zero_3( 52 | self.model.named_parameters(), keys_to_match 53 | ) 54 | 55 | if self.args.local_rank == 0 or self.args.local_rank == -1: 56 | self.model.config.save_pretrained(output_dir) 57 | torch.save( 58 | weight_to_save, os.path.join(output_dir, f"mm_projector.bin") 59 | ) 60 | else: 61 | super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) 62 | 63 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 64 | if getattr(self.args, "tune_mm_mlp_adapter", False): 65 | pass 66 | else: 67 | super(LLaVATrainer, self)._save(output_dir, state_dict) 68 | -------------------------------------------------------------------------------- /model/llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 
6 | from llava.train.llama_flash_attn_monkey_patch import \ 7 | replace_llama_attn_with_flash_attn 8 | 9 | replace_llama_attn_with_flash_attn() 10 | 11 | from llava.train.train import train 12 | 13 | if __name__ == "__main__": 14 | train() 15 | -------------------------------------------------------------------------------- /model/llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | from llava.constants import LOGDIR 9 | 10 | server_error_msg = ( 11 | "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | ) 13 | moderation_msg = ( 14 | "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 15 | ) 16 | 17 | handler = None 18 | 19 | 20 | def build_logger(logger_name, logger_filename): 21 | global handler 22 | 23 | formatter = logging.Formatter( 24 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 25 | datefmt="%Y-%m-%d %H:%M:%S", 26 | ) 27 | 28 | # Set the format of root handlers 29 | if not logging.getLogger().handlers: 30 | logging.basicConfig(level=logging.INFO) 31 | logging.getLogger().handlers[0].setFormatter(formatter) 32 | 33 | # Redirect stdout and stderr to loggers 34 | stdout_logger = logging.getLogger("stdout") 35 | stdout_logger.setLevel(logging.INFO) 36 | sl = StreamToLogger(stdout_logger, logging.INFO) 37 | sys.stdout = sl 38 | 39 | stderr_logger = logging.getLogger("stderr") 40 | stderr_logger.setLevel(logging.ERROR) 41 | sl = StreamToLogger(stderr_logger, logging.ERROR) 42 | sys.stderr = sl 43 | 44 | # Get logger 45 | logger = logging.getLogger(logger_name) 46 | logger.setLevel(logging.INFO) 47 | 48 | # Add a file handler for all loggers 49 | if handler is None: 50 | os.makedirs(LOGDIR, exist_ok=True) 51 | filename = os.path.join(LOGDIR, logger_filename) 52 | handler = logging.handlers.TimedRotatingFileHandler( 53 | filename, when="D", utc=True 54 | ) 55 | handler.setFormatter(formatter) 56 | 57 | for name, item in logging.root.manager.loggerDict.items(): 58 | if isinstance(item, logging.Logger): 59 | item.addHandler(handler) 60 | 61 | return logger 62 | 63 | 64 | class StreamToLogger(object): 65 | """ 66 | Fake file-like stream object that redirects writes to a logger instance. 67 | """ 68 | 69 | def __init__(self, logger, log_level=logging.INFO): 70 | self.terminal = sys.stdout 71 | self.logger = logger 72 | self.log_level = log_level 73 | self.linebuf = "" 74 | 75 | def __getattr__(self, attr): 76 | return getattr(self.terminal, attr) 77 | 78 | def write(self, buf): 79 | temp_linebuf = self.linebuf + buf 80 | self.linebuf = "" 81 | for line in temp_linebuf.splitlines(True): 82 | # From the io.TextIOWrapper docs: 83 | # On output, if newline is None, any '\n' characters written 84 | # are translated to the system default line separator. 85 | # By default sys.stdout.write() expects '\n' newlines and then 86 | # translates them so this is still cross platform. 87 | if line[-1] == "\n": 88 | self.logger.log(self.log_level, line.rstrip()) 89 | else: 90 | self.linebuf += line 91 | 92 | def flush(self): 93 | if self.linebuf != "": 94 | self.logger.log(self.log_level, self.linebuf.rstrip()) 95 | self.linebuf = "" 96 | 97 | 98 | def disable_torch_init(): 99 | """ 100 | Disable the redundant torch default initialization to accelerate model creation. 
101 | """ 102 | import torch 103 | 104 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 105 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 106 | 107 | 108 | def violates_moderation(text): 109 | """ 110 | Check whether the text violates OpenAI moderation API. 111 | """ 112 | url = "https://api.openai.com/v1/moderations" 113 | headers = { 114 | "Content-Type": "application/json", 115 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"], 116 | } 117 | text = text.replace("\n", "") 118 | data = "{" + '"input": ' + f'"{text}"' + "}" 119 | data = data.encode("utf-8") 120 | try: 121 | ret = requests.post(url, headers=headers, data=data, timeout=5) 122 | flagged = ret.json()["results"][0]["flagged"] 123 | except requests.exceptions.RequestException as e: 124 | flagged = False 125 | except KeyError as e: 126 | flagged = False 127 | 128 | return flagged 129 | 130 | 131 | def pretty_print_semaphore(semaphore): 132 | if semaphore is None: 133 | return "None" 134 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 135 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def dice_loss( 5 | inputs: torch.Tensor, 6 | targets: torch.Tensor, 7 | num_masks: float, 8 | scale=1000, # 100000.0, 9 | eps=1e-6, 10 | ): 11 | """ 12 | Compute the DICE loss, similar to generalized IOU for masks 13 | Args: 14 | inputs: A float tensor of arbitrary shape. 15 | The predictions for each example. 16 | targets: A float tensor with the same shape as inputs. Stores the binary 17 | classification label for each element in inputs 18 | (0 for the negative class and 1 for the positive class). 19 | """ 20 | inputs = inputs.sigmoid() 21 | inputs = inputs.flatten(1, 2) 22 | targets = targets.flatten(1, 2) 23 | numerator = 2 * (inputs / scale * targets).sum(-1) 24 | denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1) 25 | loss = 1 - (numerator + eps) / (denominator + eps) 26 | loss = loss.sum() / (num_masks + 1e-8) 27 | return loss 28 | 29 | 30 | def sigmoid_ce_loss( 31 | inputs: torch.Tensor, 32 | targets: torch.Tensor, 33 | num_masks: float, 34 | ): 35 | """ 36 | Args: 37 | inputs: A float tensor of arbitrary shape. 38 | The predictions for each example. 39 | targets: A float tensor with the same shape as inputs. Stores the binary 40 | classification label for each element in inputs 41 | (0 for the negative class and 1 for the positive class). 42 | Returns: 43 | Loss tensor 44 | """ 45 | loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 46 | loss = loss.flatten(1, 2).mean(1).sum() / (num_masks + 1e-8) 47 | return loss 48 | 49 | 50 | def softmax_align_loss(proposal_embeds: torch.tensor, target_embed: torch.tensor, gt_ious: torch.tensor,temperature: float = 0.05): 51 | """ 52 | Align the similarity with the ground truth iou, because iou is not integers, we 53 | actually have soft labels for the similarity. Instead of computing the cross entropy 54 | loss, we can compute the KL divergence between the similarity and the ground truth. 55 | 56 | The loss is based on RegionClip. 
57 | https://github.com/microsoft/RegionCLIP/blob/4b8513b56e24827e3d6468e1f2105869f35c2d0b/detectron2/modeling/meta_arch/clip_rcnn.py#L587 58 | 59 | proposal_embeds: (K, D) 60 | target_embed: (1, D) 61 | gt_ious: (K, 1) 62 | """ 63 | 64 | # normalize the proposal_embeds and target_embed 65 | proposal_embeds = proposal_embeds / proposal_embeds.norm(dim=-1, keepdim=True) 66 | target_embed = target_embed / target_embed.norm(dim=-1, keepdim=True) 67 | 68 | # compute similarity scores 69 | sim_scores = proposal_embeds @ target_embed.t() # (K, 1) 70 | sim_scores_temp = sim_scores / temperature 71 | gt_iou_temp = gt_ious / temperature 72 | # normalize to distribution 73 | sim_dis = F.softmax(sim_scores_temp, dim=0) # (K, 1) 74 | gt_dis = F.softmax(gt_iou_temp, dim=0) # (K, 1) 75 | 76 | # for KL divergence, the input should be log-probabilities and the target probabilities (when log_target==False) 77 | # Be careful: the loss is dominated by negative samples; we currently use sum instead of mean 78 | loss = F.kl_div(sim_dis.log(), gt_dis, reduction="sum") 79 | 80 | return loss 81 | 82 | def iou_regression_loss(pred_ious: torch.Tensor, gt_ious: torch.Tensor, weighted: bool = True): 83 | """ 84 | pred_ious: (K, 1) 85 | gt_ious: (K, 1) 86 | """ 87 | if not weighted: 88 | loss = F.mse_loss(pred_ious, gt_ious, reduction="sum") 89 | else: 90 | loss = F.mse_loss(pred_ious, gt_ious, reduction="none") 91 | weight = torch.exp(gt_ious - 1.0) 92 | loss = loss * weight 93 | loss = loss.mean() * 50.0 # scale the loss as if every sample had 50 proposals 94 | return loss 95 | 96 | 97 | def sigmoid_align_loss(proposal_embeds: torch.Tensor, target_embed: torch.Tensor, gt_ious: torch.Tensor, temperature: torch.Tensor = 0.1, bias: torch.Tensor = 0.0): 98 | """ 99 | Sigmoid loss for contrastive learning.
100 | From the paper: Sigmoid Loss for Language Image Pre-Training (https://arxiv.org/abs/2303.15343) 101 | 102 | proposal_embeds: (K, D) 103 | target_embed: (1, D) 104 | gt_ious: (K, 1) 105 | 106 | temperature and bias are learnable parameters 107 | """ 108 | 109 | sigmoid_layer = torch.nn.Sigmoid() 110 | 111 | t = torch.exp(temperature) 112 | b = bias 113 | 114 | # normalize the proposal_embeds and target_embed 115 | proposal_embeds = proposal_embeds / proposal_embeds.norm(dim=-1, keepdim=True) 116 | target_embed = target_embed / target_embed.norm(dim=-1, keepdim=True) 117 | 118 | logits = proposal_embeds @ target_embed.t() * t + b # (K, 1) 119 | # logits = sigmoid_layer(logits) 120 | 121 | # labels range from -1 to 1: we treat iou=0 as a pure negative sample, iou=1 as a pure positive sample, and iou=0.5 as neutral 122 | # treating iou=0.5 as neutral may not be a good idea (maybe too high), but we can try 123 | labels = gt_ious * 2 - 1.0 124 | 125 | loss = -1.0 * torch.log(sigmoid_layer(logits * labels) + 1e-8).sum() 126 | 127 | # loss = F.binary_cross_entropy_with_logits(logits, gt_ious, reduction="sum") 128 | 129 | return loss 130 | 131 | 132 | def too_simple_to_believe_align_loss(proposal_embeds: torch.Tensor, target_embed: torch.Tensor, gt_ious: torch.Tensor): 133 | """ 134 | proposal_embeds: (K, D) 135 | target_embed: (1, D) 136 | gt_ious: (K, 1) 137 | """ 138 | 139 | # scale to [-1, 1] 140 | label = gt_ious * 2.0 - 1.0 141 | 142 | # normalize the proposal_embeds and target_embed 143 | proposal_embeds = proposal_embeds / proposal_embeds.norm(dim=-1, keepdim=True) 144 | target_embed = target_embed / target_embed.norm(dim=-1, keepdim=True) 145 | 146 | # the range of cosine similarity is [-1, 1] 147 | similarity = proposal_embeds @ target_embed.t() # (K, 1) 148 | 149 | loss = F.l1_loss(similarity, label, reduction="sum") 150 | 151 | return loss 152 | 153 | if __name__ == "__main__": 154 | # test contrastive_align_loss 155 | pass 156 | -------------------------------------------------------------------------------- /model/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .automatic_mask_generator import SamAutomaticMaskGenerator 8 | from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h, 9 | build_sam_vit_l, sam_model_registry) 10 | from .predictor import SamPredictor 11 | -------------------------------------------------------------------------------- /model/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
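# The builders below assemble the full SAM model (ViT-H/L/B image encoder, prompt encoder, mask decoder); _build_sam fixes the 1024x1024 padded input resolution and the 64x64 image-embedding grid (1024 / patch size 16).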
6 | 7 | from functools import partial 8 | 9 | import torch 10 | 11 | from .modeling import (ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, 12 | TwoWayTransformer) 13 | 14 | 15 | def build_sam_vit_h(checkpoint=None): 16 | return _build_sam( 17 | encoder_embed_dim=1280, 18 | encoder_depth=32, 19 | encoder_num_heads=16, 20 | encoder_global_attn_indexes=[7, 15, 23, 31], 21 | checkpoint=checkpoint, 22 | ) 23 | 24 | 25 | build_sam = build_sam_vit_h 26 | 27 | 28 | def build_sam_vit_l(checkpoint=None): 29 | return _build_sam( 30 | encoder_embed_dim=1024, 31 | encoder_depth=24, 32 | encoder_num_heads=16, 33 | encoder_global_attn_indexes=[5, 11, 17, 23], 34 | checkpoint=checkpoint, 35 | ) 36 | 37 | 38 | def build_sam_vit_b(checkpoint=None): 39 | return _build_sam( 40 | encoder_embed_dim=768, 41 | encoder_depth=12, 42 | encoder_num_heads=12, 43 | encoder_global_attn_indexes=[2, 5, 8, 11], 44 | checkpoint=checkpoint, 45 | ) 46 | 47 | 48 | sam_model_registry = { 49 | "default": build_sam_vit_h, 50 | "vit_h": build_sam_vit_h, 51 | "vit_l": build_sam_vit_l, 52 | "vit_b": build_sam_vit_b, 53 | } 54 | 55 | 56 | def _build_sam( 57 | encoder_embed_dim, 58 | encoder_depth, 59 | encoder_num_heads, 60 | encoder_global_attn_indexes, 61 | checkpoint=None, 62 | ): 63 | prompt_embed_dim = 256 64 | image_size = 1024 65 | vit_patch_size = 16 66 | image_embedding_size = image_size // vit_patch_size 67 | sam = Sam( 68 | image_encoder=ImageEncoderViT( 69 | depth=encoder_depth, 70 | embed_dim=encoder_embed_dim, 71 | img_size=image_size, 72 | mlp_ratio=4, 73 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 74 | num_heads=encoder_num_heads, 75 | patch_size=vit_patch_size, 76 | qkv_bias=True, 77 | use_rel_pos=True, 78 | global_attn_indexes=encoder_global_attn_indexes, 79 | window_size=14, 80 | out_chans=prompt_embed_dim, 81 | ), 82 | prompt_encoder=PromptEncoder( 83 | embed_dim=prompt_embed_dim, 84 | image_embedding_size=(image_embedding_size, image_embedding_size), 85 | input_image_size=(image_size, image_size), 86 | mask_in_chans=16, 87 | ), 88 | mask_decoder=MaskDecoder( 89 | num_multimask_outputs=3, 90 | transformer=TwoWayTransformer( 91 | depth=2, 92 | embedding_dim=prompt_embed_dim, 93 | mlp_dim=2048, 94 | num_heads=8, 95 | ), 96 | transformer_dim=prompt_embed_dim, 97 | iou_head_depth=3, 98 | iou_head_hidden_dim=256, 99 | ), 100 | pixel_mean=[123.675, 116.28, 103.53], 101 | pixel_std=[58.395, 57.12, 57.375], 102 | ) 103 | sam.eval() 104 | if checkpoint is not None: 105 | with open(checkpoint, "rb") as f: 106 | state_dict = torch.load(f) 107 | sam.load_state_dict(state_dict, strict=False) 108 | return sam 109 | -------------------------------------------------------------------------------- /model/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .image_encoder import ImageEncoderViT 8 | from .mask_decoder import MaskDecoder 9 | from .prompt_encoder import PromptEncoder 10 | from .sam import Sam 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /model/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Type 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /model/segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d( 55 | transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 56 | ), 57 | LayerNorm2d(transformer_dim // 4), 58 | activation(), 59 | nn.ConvTranspose2d( 60 | transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 61 | ), 62 | activation(), 63 | ) 64 | self.output_hypernetworks_mlps = nn.ModuleList( 65 | [ 66 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 67 | for i in range(self.num_mask_tokens) 68 | ] 69 | ) 70 | 71 | self.iou_prediction_head = MLP( 72 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 73 | ) 74 | 75 | def forward( 76 | self, 77 | image_embeddings: torch.Tensor, 78 | image_pe: torch.Tensor, 79 | sparse_prompt_embeddings: torch.Tensor, 80 | dense_prompt_embeddings: torch.Tensor, 81 | multimask_output: bool, 82 | ) -> Tuple[torch.Tensor, torch.Tensor]: 83 | """ 84 | Predict masks given image and prompt embeddings. 85 | 86 | Arguments: 87 | image_embeddings (torch.Tensor): the embeddings from the image encoder 88 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 89 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 90 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 91 | multimask_output (bool): Whether to return multiple masks or a single 92 | mask. 93 | 94 | Returns: 95 | torch.Tensor: batched predicted masks 96 | torch.Tensor: batched predictions of mask quality 97 | """ 98 | masks, iou_pred = self.predict_masks( 99 | image_embeddings=image_embeddings, 100 | image_pe=image_pe, 101 | sparse_prompt_embeddings=sparse_prompt_embeddings, 102 | dense_prompt_embeddings=dense_prompt_embeddings, 103 | ) 104 | 105 | # Select the correct mask or masks for output 106 | if multimask_output: 107 | mask_slice = slice(1, None) 108 | else: 109 | mask_slice = slice(0, 1) 110 | masks = masks[:, mask_slice, :, :] 111 | iou_pred = iou_pred[:, mask_slice] 112 | 113 | # Prepare output 114 | return masks, iou_pred 115 | 116 | def predict_masks( 117 | self, 118 | image_embeddings: torch.Tensor, 119 | image_pe: torch.Tensor, 120 | sparse_prompt_embeddings: torch.Tensor, 121 | dense_prompt_embeddings: torch.Tensor, 122 | ) -> Tuple[torch.Tensor, torch.Tensor]: 123 | """Predicts masks. 
See 'forward' for more details.""" 124 | # Concatenate output tokens 125 | output_tokens = torch.cat( 126 | [self.iou_token.weight, self.mask_tokens.weight], dim=0 127 | ) 128 | output_tokens = output_tokens.unsqueeze(0).expand( 129 | sparse_prompt_embeddings.size(0), -1, -1 130 | ) 131 | 132 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 133 | 134 | # image_embeddings: [1, C, H, W], tokens: [B, N, C] 135 | # dense_prompt_embeddings: [B, C, H, W] 136 | # Expand per-image data in batch direction to be per-mask 137 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 138 | src = src + dense_prompt_embeddings 139 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 140 | b, c, h, w = src.shape 141 | 142 | # Run the transformer 143 | hs, src = self.transformer(src, pos_src, tokens) 144 | iou_token_out = hs[:, 0, :] 145 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 146 | 147 | # Upscale mask embeddings and predict masks using the mask tokens 148 | src = src.transpose(1, 2).view(b, c, h, w) 149 | upscaled_embedding = self.output_upscaling(src) 150 | hyper_in_list: List[torch.Tensor] = [] 151 | for i in range(self.num_mask_tokens): 152 | hyper_in_list.append( 153 | self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) 154 | ) 155 | hyper_in = torch.stack(hyper_in_list, dim=1) 156 | b, c, h, w = upscaled_embedding.shape 157 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view( 158 | b, self.num_mask_tokens, h, w 159 | ) 160 | 161 | # Generate mask quality predictions 162 | iou_pred = self.iou_prediction_head(iou_token_out) 163 | 164 | return masks, iou_pred 165 | 166 | 167 | # Lightly adapted from 168 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 169 | class MLP(nn.Module): 170 | def __init__( 171 | self, 172 | input_dim: int, 173 | hidden_dim: int, 174 | output_dim: int, 175 | num_layers: int, 176 | sigmoid_output: bool = False, 177 | ) -> None: 178 | super().__init__() 179 | self.num_layers = num_layers 180 | h = [hidden_dim] * (num_layers - 1) 181 | self.layers = nn.ModuleList( 182 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 183 | ) 184 | self.sigmoid_output = sigmoid_output 185 | 186 | def forward(self, x): 187 | for i, layer in enumerate(self.layers): 188 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 189 | if self.sigmoid_output: 190 | x = F.sigmoid(x) 191 | return x 192 | -------------------------------------------------------------------------------- /model/segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Optional, Tuple, Type 8 | 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 
27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [ 47 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) 48 | ] 49 | self.point_embeddings = nn.ModuleList(point_embeddings) 50 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 51 | 52 | self.mask_input_size = ( 53 | 4 * image_embedding_size[0], 54 | 4 * image_embedding_size[1], 55 | ) 56 | self.mask_downscaling = nn.Sequential( 57 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 58 | LayerNorm2d(mask_in_chans // 4), 59 | activation(), 60 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 61 | LayerNorm2d(mask_in_chans), 62 | activation(), 63 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 64 | ) 65 | self.no_mask_embed = nn.Embedding(1, embed_dim) 66 | 67 | def get_dense_pe(self) -> torch.Tensor: 68 | """ 69 | Returns the positional encoding used to encode point prompts, 70 | applied to a dense set of points the shape of the image encoding. 71 | 72 | Returns: 73 | torch.Tensor: Positional encoding with shape 74 | 1x(embed_dim)x(embedding_h)x(embedding_w) 75 | """ 76 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 77 | 78 | def _embed_points( 79 | self, 80 | points: torch.Tensor, 81 | labels: torch.Tensor, 82 | pad: bool, 83 | ) -> torch.Tensor: 84 | """Embeds point prompts.""" 85 | points = points + 0.5 # Shift to center of pixel 86 | if pad: 87 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 88 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 89 | points = torch.cat([points, padding_point], dim=1) 90 | labels = torch.cat([labels, padding_label], dim=1) 91 | point_embedding = self.pe_layer.forward_with_coords( 92 | points, self.input_image_size 93 | ) 94 | point_embedding[labels == -1] = 0.0 95 | point_embedding[labels == -1] += self.not_a_point_embed.weight 96 | point_embedding[labels == 0] += self.point_embeddings[0].weight 97 | point_embedding[labels == 1] += self.point_embeddings[1].weight 98 | return point_embedding 99 | 100 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 101 | """Embeds box prompts.""" 102 | boxes = boxes + 0.5 # Shift to center of pixel 103 | coords = boxes.reshape(-1, 2, 2) 104 | corner_embedding = self.pe_layer.forward_with_coords( 105 | coords, self.input_image_size 106 | ) 107 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 108 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 109 | return corner_embedding 110 | 111 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 112 | """Embeds mask inputs.""" 113 | mask_embedding = self.mask_downscaling(masks) 114 | return mask_embedding 115 | 116 | def _get_batch_size( 117 | self, 118 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 119 | boxes: 
Optional[torch.Tensor], 120 | masks: Optional[torch.Tensor], 121 | text_embeds: Optional[torch.Tensor], 122 | ) -> int: 123 | """ 124 | Gets the batch size of the output given the batch size of the input prompts. 125 | """ 126 | if points is not None: 127 | return points[0].shape[0] 128 | elif boxes is not None: 129 | return boxes.shape[0] 130 | elif masks is not None: 131 | return masks.shape[0] 132 | elif text_embeds is not None: 133 | return text_embeds.shape[0] 134 | else: 135 | return 1 136 | 137 | def _get_device(self) -> torch.device: 138 | return self.point_embeddings[0].weight.device 139 | 140 | def forward( 141 | self, 142 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 143 | boxes: Optional[torch.Tensor], 144 | masks: Optional[torch.Tensor], 145 | text_embeds: Optional[torch.Tensor], 146 | ) -> Tuple[torch.Tensor, torch.Tensor]: 147 | """ 148 | Embeds different types of prompts, returning both sparse and dense 149 | embeddings. 150 | 151 | Arguments: 152 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 153 | and labels to embed. 154 | boxes (torch.Tensor or none): boxes to embed 155 | masks (torch.Tensor or none): masks to embed 156 | 157 | Returns: 158 | torch.Tensor: sparse embeddings for the points and boxes, with shape 159 | BxNx(embed_dim), where N is determined by the number of input points 160 | and boxes. 161 | torch.Tensor: dense embeddings for the masks, in the shape 162 | Bx(embed_dim)x(embed_H)x(embed_W) 163 | """ 164 | bs = self._get_batch_size(points, boxes, masks, text_embeds) 165 | sparse_embeddings = torch.empty( 166 | (bs, 0, self.embed_dim), device=self._get_device() 167 | ) 168 | if points is not None: 169 | coords, labels = points 170 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 171 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 172 | if boxes is not None: 173 | box_embeddings = self._embed_boxes(boxes) 174 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 175 | 176 | if text_embeds is not None: 177 | sparse_embeddings = torch.cat([sparse_embeddings, text_embeds], dim=1) 178 | 179 | if masks is not None: 180 | dense_embeddings = self._embed_masks(masks) 181 | else: 182 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 183 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 184 | ) 185 | 186 | return sparse_embeddings, dense_embeddings 187 | 188 | 189 | class PositionEmbeddingRandom(nn.Module): 190 | """ 191 | Positional encoding using random spatial frequencies. 192 | """ 193 | 194 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 195 | super().__init__() 196 | if scale is None or scale <= 0.0: 197 | scale = 1.0 198 | self.register_buffer( 199 | "positional_encoding_gaussian_matrix", 200 | scale * torch.randn((2, num_pos_feats)), 201 | ) 202 | 203 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 204 | """Positionally encode points that are normalized to [0,1].""" 205 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 206 | coords = 2 * coords - 1 207 | 208 | if coords.dtype != self.positional_encoding_gaussian_matrix.dtype: 209 | coords = coords.to(self.positional_encoding_gaussian_matrix.dtype) 210 | 211 | coords = coords @ self.positional_encoding_gaussian_matrix 212 | coords = 2 * np.pi * coords 213 | # outputs d_1 x ... 
x d_n x C shape 214 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 215 | 216 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 217 | """Generate positional encoding for a grid of the specified size.""" 218 | h, w = size 219 | device: Any = self.positional_encoding_gaussian_matrix.device 220 | grid = torch.ones( 221 | (h, w), device=device, dtype=self.positional_encoding_gaussian_matrix.dtype 222 | ) 223 | y_embed = grid.cumsum(dim=0) - 0.5 224 | x_embed = grid.cumsum(dim=1) - 0.5 225 | y_embed = y_embed / h 226 | x_embed = x_embed / w 227 | 228 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 229 | return pe.permute(2, 0, 1) # C x H x W 230 | 231 | def forward_with_coords( 232 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 233 | ) -> torch.Tensor: 234 | """Positionally encode points that are not normalized to [0,1].""" 235 | coords = coords_input.clone() 236 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 237 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 238 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 239 | -------------------------------------------------------------------------------- /model/segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Dict, List, Tuple 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer( 47 | "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False 48 | ) 49 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 50 | 51 | @property 52 | def device(self) -> Any: 53 | return self.pixel_mean.device 54 | 55 | @torch.no_grad() 56 | def forward( 57 | self, 58 | batched_input: List[Dict[str, Any]], 59 | multimask_output: bool, 60 | ) -> List[Dict[str, torch.Tensor]]: 61 | """ 62 | Predicts masks end-to-end from provided images and prompts. 63 | If prompts are not known in advance, using SamPredictor is 64 | recommended over calling the model directly. 
65 | 66 | Arguments: 67 | batched_input (list(dict)): A list over input images, each a 68 | dictionary with the following keys. A prompt key can be 69 | excluded if it is not present. 70 | 'image': The image as a torch tensor in 3xHxW format, 71 | already transformed for input to the model. 72 | 'original_size': (tuple(int, int)) The original size of 73 | the image before transformation, as (H, W). 74 | 'point_coords': (torch.Tensor) Batched point prompts for 75 | this image, with shape BxNx2. Already transformed to the 76 | input frame of the model. 77 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 78 | with shape BxN. 79 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 80 | Already transformed to the input frame of the model. 81 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 82 | in the form Bx1xHxW. 83 | multimask_output (bool): Whether the model should predict multiple 84 | disambiguating masks, or return a single mask. 85 | 86 | Returns: 87 | (list(dict)): A list over input images, where each element is 88 | as dictionary with the following keys. 89 | 'masks': (torch.Tensor) Batched binary mask predictions, 90 | with shape BxCxHxW, where B is the number of input prompts, 91 | C is determined by multimask_output, and (H, W) is the 92 | original size of the image. 93 | 'iou_predictions': (torch.Tensor) The model's predictions 94 | of mask quality, in shape BxC. 95 | 'low_res_logits': (torch.Tensor) Low resolution logits with 96 | shape BxCxHxW, where H=W=256. Can be passed as mask input 97 | to subsequent iterations of prediction. 98 | """ 99 | input_images = torch.stack( 100 | [self.preprocess(x["image"]) for x in batched_input], dim=0 101 | ) 102 | image_embeddings = self.image_encoder(input_images) 103 | 104 | outputs = [] 105 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 106 | if "point_coords" in image_record: 107 | points = (image_record["point_coords"], image_record["point_labels"]) 108 | else: 109 | points = None 110 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 111 | points=points, 112 | boxes=image_record.get("boxes", None), 113 | masks=image_record.get("mask_inputs", None), 114 | ) 115 | low_res_masks, iou_predictions = self.mask_decoder( 116 | image_embeddings=curr_embedding.unsqueeze(0), 117 | image_pe=self.prompt_encoder.get_dense_pe(), 118 | sparse_prompt_embeddings=sparse_embeddings, 119 | dense_prompt_embeddings=dense_embeddings, 120 | multimask_output=multimask_output, 121 | ) 122 | masks = self.postprocess_masks( 123 | low_res_masks, 124 | input_size=image_record["image"].shape[-2:], 125 | original_size=image_record["original_size"], 126 | ) 127 | masks = masks > self.mask_threshold 128 | outputs.append( 129 | { 130 | "masks": masks, 131 | "iou_predictions": iou_predictions, 132 | "low_res_logits": low_res_masks, 133 | } 134 | ) 135 | return outputs 136 | 137 | def postprocess_masks( 138 | self, 139 | masks: torch.Tensor, 140 | input_size: Tuple[int, ...], 141 | original_size: Tuple[int, ...], 142 | ) -> torch.Tensor: 143 | """ 144 | Remove padding and upscale masks to the original image size. 145 | 146 | Arguments: 147 | masks (torch.Tensor): Batched masks from the mask_decoder, 148 | in BxCxHxW format. 149 | input_size (tuple(int, int)): The size of the image input to the 150 | model, in (H, W) format. Used to remove padding. 151 | original_size (tuple(int, int)): The original size of the image 152 | before resizing for input to the model, in (H, W) format. 
153 | 154 | Returns: 155 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 156 | is given by original_size. 157 | """ 158 | 159 | dtype = masks.dtype 160 | 161 | masks = F.interpolate( 162 | masks.float(), 163 | (self.image_encoder.img_size, self.image_encoder.img_size), 164 | mode="bilinear", 165 | align_corners=False, 166 | ) 167 | # masks = masks.to(dtype) 168 | masks = masks[..., : input_size[0], : input_size[1]] 169 | masks = F.interpolate( 170 | masks, original_size, mode="bilinear", align_corners=False 171 | ) 172 | return masks 173 | 174 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 175 | """Normalize pixel values and pad to a square input.""" 176 | # Normalize colors 177 | x = (x - self.pixel_mean) / self.pixel_std 178 | 179 | # Pad 180 | h, w = x.shape[-2:] 181 | padh = self.image_encoder.img_size - h 182 | padw = self.image_encoder.img_size - w 183 | x = F.pad(x, (0, padw, 0, padh)) 184 | return x 185 | -------------------------------------------------------------------------------- /model/segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple, Type 9 | 10 | import torch 11 | from torch import Tensor, nn 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert ( 202 | self.internal_dim % num_heads == 0 203 | ), "num_heads must divide embedding_dim." 204 | 205 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 207 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 208 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 209 | 210 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 211 | b, n, c = x.shape 212 | x = x.reshape(b, n, num_heads, c // num_heads) 213 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 214 | 215 | def _recombine_heads(self, x: Tensor) -> Tensor: 216 | b, n_heads, n_tokens, c_per_head = x.shape 217 | x = x.transpose(1, 2) 218 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 219 | 220 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 221 | # Input projections 222 | q = self.q_proj(q) 223 | k = self.k_proj(k) 224 | v = self.v_proj(v) 225 | 226 | # Separate into heads 227 | q = self._separate_heads(q, self.num_heads) 228 | k = self._separate_heads(k, self.num_heads) 229 | v = self._separate_heads(v, self.num_heads) 230 | 231 | # Attention 232 | _, _, _, c_per_head = q.shape 233 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 234 | attn = attn / math.sqrt(c_per_head) 235 | attn = torch.softmax(attn, dim=-1) 236 | 237 | # Get output 238 | out = attn @ v 239 | out = self._recombine_heads(out) 240 | out = self.out_proj(out) 241 | 242 | return out 243 | -------------------------------------------------------------------------------- /model/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /model/segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Tuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import functional as F 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 
23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points( 52 | self, point_coords: torch.Tensor, point_labels: torch.Tensor 53 | ) -> torch.Tensor: 54 | point_coords = point_coords + 0.5 55 | point_coords = point_coords / self.img_size 56 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 57 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 58 | 59 | point_embedding = point_embedding * (point_labels != -1) 60 | point_embedding = ( 61 | point_embedding 62 | + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1) 63 | ) 64 | 65 | for i in range(self.model.prompt_encoder.num_point_embeddings): 66 | point_embedding = ( 67 | point_embedding 68 | + self.model.prompt_encoder.point_embeddings[i].weight 69 | * (point_labels == i) 70 | ) 71 | 72 | return point_embedding 73 | 74 | def _embed_masks( 75 | self, input_mask: torch.Tensor, has_mask_input: torch.Tensor 76 | ) -> torch.Tensor: 77 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling( 78 | input_mask 79 | ) 80 | mask_embedding = mask_embedding + ( 81 | 1 - has_mask_input 82 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 83 | return mask_embedding 84 | 85 | def mask_postprocessing( 86 | self, masks: torch.Tensor, orig_im_size: torch.Tensor 87 | ) -> torch.Tensor: 88 | masks = F.interpolate( 89 | masks, 90 | size=(self.img_size, self.img_size), 91 | mode="bilinear", 92 | align_corners=False, 93 | ) 94 | 95 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to( 96 | torch.int64 97 | ) 98 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 99 | 100 | orig_im_size = orig_im_size.to(torch.int64) 101 | h, w = orig_im_size[0], orig_im_size[1] 102 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 103 | return masks 104 | 105 | def select_masks( 106 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 107 | ) -> Tuple[torch.Tensor, torch.Tensor]: 108 | # Determine if we should return the multiclick mask or not from the number of points. 109 | # The reweighting is used to avoid control flow. 
110 | score_reweight = torch.tensor( 111 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 112 | ).to(iou_preds.device) 113 | score = iou_preds + (num_points - 2.5) * score_reweight 114 | best_idx = torch.argmax(score, dim=1) 115 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 116 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 117 | 118 | return masks, iou_preds 119 | 120 | @torch.no_grad() 121 | def forward( 122 | self, 123 | image_embeddings: torch.Tensor, 124 | point_coords: torch.Tensor, 125 | point_labels: torch.Tensor, 126 | mask_input: torch.Tensor, 127 | has_mask_input: torch.Tensor, 128 | orig_im_size: torch.Tensor, 129 | ): 130 | sparse_embedding = self._embed_points(point_coords, point_labels) 131 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 132 | 133 | masks, scores = self.model.mask_decoder.predict_masks( 134 | image_embeddings=image_embeddings, 135 | image_pe=self.model.prompt_encoder.get_dense_pe(), 136 | sparse_prompt_embeddings=sparse_embedding, 137 | dense_prompt_embeddings=dense_embedding, 138 | ) 139 | 140 | if self.use_stability_score: 141 | scores = calculate_stability_score( 142 | masks, self.model.mask_threshold, self.stability_score_offset 143 | ) 144 | 145 | if self.return_single_mask: 146 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 147 | 148 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 149 | 150 | if self.return_extra_metrics: 151 | stability_scores = calculate_stability_score( 152 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 153 | ) 154 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 155 | return upscaled_masks, scores, stability_scores, areas, masks 156 | 157 | return upscaled_masks, scores, masks 158 | -------------------------------------------------------------------------------- /model/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from copy import deepcopy 8 | from typing import Tuple 9 | 10 | import numpy as np 11 | import torch 12 | from torch.nn import functional as F 13 | from torchvision.transforms.functional import resize # type: ignore 14 | from torchvision.transforms.functional import to_pil_image 15 | 16 | 17 | class ResizeLongestSide: 18 | """ 19 | Resizes images to the longest side 'target_length', as well as provides 20 | methods for resizing coordinates and boxes. Provides methods for 21 | transforming both numpy array and batched torch tensors. 22 | """ 23 | 24 | def __init__(self, target_length: int) -> None: 25 | self.target_length = target_length 26 | 27 | def apply_image(self, image: np.ndarray) -> np.ndarray: 28 | """ 29 | Expects a numpy array with shape HxWxC in uint8 format. 30 | """ 31 | target_size = self.get_preprocess_shape( 32 | image.shape[0], image.shape[1], self.target_length 33 | ) 34 | return np.array(resize(to_pil_image(image), target_size)) 35 | 36 | def apply_coords( 37 | self, coords: np.ndarray, original_size: Tuple[int, ...] 38 | ) -> np.ndarray: 39 | """ 40 | Expects a numpy array of length 2 in the final dimension. Requires the 41 | original image size in (H, W) format. 
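        For example, with target_length=1024 and original_size=(768, 1536), both sides
        are scaled by 2/3, so a point (x=1536, y=768) maps to (1024, 512).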
42 | """ 43 | old_h, old_w = original_size 44 | new_h, new_w = self.get_preprocess_shape( 45 | original_size[0], original_size[1], self.target_length 46 | ) 47 | coords = deepcopy(coords).astype(float) 48 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 49 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 50 | return coords 51 | 52 | def apply_boxes( 53 | self, boxes: np.ndarray, original_size: Tuple[int, ...] 54 | ) -> np.ndarray: 55 | """ 56 | Expects a numpy array shape Bx4. Requires the original image size 57 | in (H, W) format. 58 | """ 59 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 60 | return boxes.reshape(-1, 4) 61 | 62 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 63 | """ 64 | Expects batched images with shape BxCxHxW and float format. This 65 | transformation may not exactly match apply_image. apply_image is 66 | the transformation expected by the model. 67 | """ 68 | # Expects an image in BCHW format. May not exactly match apply_image. 69 | target_size = self.get_preprocess_shape( 70 | image.shape[0], image.shape[1], self.target_length 71 | ) 72 | return F.interpolate( 73 | image, target_size, mode="bilinear", align_corners=False, antialias=True 74 | ) 75 | 76 | def apply_coords_torch( 77 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 78 | ) -> torch.Tensor: 79 | """ 80 | Expects a torch tensor with length 2 in the last dimension. Requires the 81 | original image size in (H, W) format. 82 | """ 83 | old_h, old_w = original_size 84 | new_h, new_w = self.get_preprocess_shape( 85 | original_size[0], original_size[1], self.target_length 86 | ) 87 | coords = deepcopy(coords).to(torch.float) 88 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 89 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 90 | return coords 91 | 92 | def apply_boxes_torch( 93 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 94 | ) -> torch.Tensor: 95 | """ 96 | Expects a torch tensor with shape Bx4. Requires the original image 97 | size in (H, W) format. 98 | """ 99 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 100 | return boxes.reshape(-1, 4) 101 | 102 | @staticmethod 103 | def get_preprocess_shape( 104 | oldh: int, oldw: int, long_side_length: int 105 | ) -> Tuple[int, int]: 106 | """ 107 | Compute the output size given input size and target long side length. 
108 | """ 109 | scale = long_side_length * 1.0 / max(oldh, oldw) 110 | newh, neww = oldh * scale, oldw * scale 111 | neww = int(neww + 0.5) 112 | newh = int(newh + 0.5) 113 | return (newh, neww) 114 | -------------------------------------------------------------------------------- /prepare_datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjunchi/LLMSeg/65147a98cc49c2dac9b7a3633d9d3d336f1ac023/prepare_datasets/__init__.py -------------------------------------------------------------------------------- /prepare_datasets/convert_h5_to_json.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import h5py 4 | 5 | 6 | def convert_h5_to_json_standard(h5_path, json_save_path): 7 | # read h5 file 8 | h5_file = h5py.File(h5_path, "r") 9 | # get keys 10 | keys = list(h5_file.keys()) 11 | assert 'masks' in keys, "masks not in keys" 12 | 13 | data = h5_file['masks'] 14 | 15 | print("parsing h5 file...") 16 | dataset = [eval(data[i].decode('utf-8')) for i in range(len(data))] 17 | 18 | # convert dataset to json compatitible 19 | for sample in dataset: 20 | masks = sample['masks'] 21 | for mask in masks: 22 | seg = mask['segmentation'] 23 | seg['counts'] = seg['counts'].decode() 24 | 25 | # save to json 26 | print("saving to json...") 27 | with open(json_save_path, 'w') as f: 28 | json.dump(dataset, f) 29 | 30 | print("saved to {}".format(json_save_path)) 31 | 32 | def convert_h5_to_json_reasonSeg(base_dir, json_save_path): 33 | splits = ['train', 'val'] 34 | for split in splits: 35 | h5_path = os.path.join(base_dir, "ReasonSeg", split, "masks.h5") 36 | json_save_path = os.path.join(base_dir, "ReasonSeg", split, "masks.json") 37 | convert_h5_to_json_standard(h5_path, json_save_path) 38 | 39 | def convert_h5_to_json_coco(h5_dir, json_save_path): 40 | # coco have 8 splits 41 | dataset = [] 42 | for i in range(8): 43 | # if i == 2: 44 | # break 45 | h5_path = os.path.join(h5_dir, "coco_split{}.h5".format(i)) 46 | # read h5 file 47 | h5_file = h5py.File(h5_path, "r") 48 | # get keys 49 | keys = list(h5_file.keys()) 50 | assert 'masks' in keys, "masks not in keys" 51 | data = h5_file['masks'] 52 | 53 | print("Parsing split {}...".format(i)) 54 | dataset_split = [eval(data[i].decode('utf-8')) for i in range(len(data))] 55 | dataset.extend(dataset_split) 56 | 57 | 58 | # convert dataset to json compatitible 59 | for sample in dataset: 60 | masks = sample['masks'] 61 | for mask in masks: 62 | seg = mask['segmentation'] 63 | seg['counts'] = seg['counts'].decode() 64 | 65 | # save to json 66 | print("saving to json...") 67 | with open(json_save_path, 'w') as f: 68 | json.dump(dataset, f) 69 | 70 | 71 | 72 | 73 | def main(): 74 | base_dir = "/cluster/home/leikel/junchi/processed_data/" 75 | 76 | dataset_names = ["ade20k", "coco2014", "coco2017", "reason_seg", "saiapr", "voc2010"] 77 | # dataset_names = ["reason_seg"] 78 | # dataset_names = ["ade20k"] 79 | # dataset_names = ["coco2014", "coco2017"] 80 | 81 | for dataset_name in dataset_names: 82 | print("processing {}".format(dataset_name)) 83 | dataset_dir = os.path.join(base_dir, dataset_name) 84 | h5_path = os.path.join(dataset_dir, "masks.h5") 85 | json_save_path = os.path.join(dataset_dir, "masks.json") 86 | if 'coco' in dataset_name: 87 | convert_h5_to_json_coco(dataset_dir, json_save_path) 88 | elif dataset_name == "reason_seg": 89 | convert_h5_to_json_reasonSeg(dataset_dir, json_save_path) 90 | else: 91 
| convert_h5_to_json_standard(h5_path, json_save_path) 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | 97 | -------------------------------------------------------------------------------- /prepare_datasets/generate_index_reasonseg.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import os 3 | import pickle 4 | 5 | from typing import List, Dict 6 | 7 | 8 | def read_mask_list(path: str): 9 | h5file = h5py.File(path, "r") 10 | dataset = [h5file["masks"][i].decode("utf-8") for i in range(len(h5file["masks"]))] 11 | dataset_restored = [eval(data) for data in dataset] 12 | 13 | return dataset_restored 14 | 15 | def build_sam_mask_dict(mask_list: List[Dict]): 16 | sam_mask_dict = {} 17 | for i, sample in enumerate(mask_list): 18 | sample_name = sample["image"] 19 | sam_mask_dict[sample_name] = i 20 | 21 | return sam_mask_dict 22 | 23 | def main(): 24 | mask_dir = "/cluster/home/leikel/junchi/processed_data/reason_seg/ReasonSeg/" 25 | split = ["train", "val"] 26 | 27 | for s in split: 28 | mask_file = os.path.join(mask_dir, s, "masks.h5") 29 | mask_list = read_mask_list(mask_file) 30 | sam_mask_dict = build_sam_mask_dict(mask_list) 31 | 32 | print("len(sam_mask_dict): ", len(sam_mask_dict)) 33 | 34 | # save sam_mask_dict as pickle file 35 | with open(os.path.join(mask_dir, s, "sam_mask_index_dict.pkl"), "wb") as f: 36 | pickle.dump(sam_mask_dict, f) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() -------------------------------------------------------------------------------- /prepare_datasets/prepare_ReasonSeg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | from typing import List 5 | import json 6 | import h5py 7 | from pycocotools import mask as mask_utils 8 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator 9 | 10 | 11 | dataset_root = "/cluster/scratch/leikel/junchi/lisa_dataset/reason_seg/ReasonSeg/" 12 | output_root = "/cluster/home/leikel/junchi/processed_data/reason_seg/ReasonSeg/" 13 | sam_checkpoint = "/cluster/home/leikel/junchi/segment-anything/checkpoints/sam_vit_h_4b8939.pth" 14 | 15 | 16 | # test set not available for ReasonSeg 17 | available_dataset_type = ["train", "val"] 18 | 19 | def get_all_samples(dataset_path) -> List[str]: 20 | 21 | all_samples = [] 22 | files = os.listdir(dataset_path) 23 | 24 | for file in files: 25 | if file.endswith(".jpg"): 26 | name = file.split('.')[0] 27 | all_samples.append(name) 28 | 29 | return all_samples 30 | 31 | 32 | def preprocess_images(image: np.ndarray) -> np.ndarray: 33 | # scale the large side to 1024 34 | H, W, _ = image.shape 35 | if max(W, H) > 1024: 36 | # scale 37 | scale_factor = 1024.0 / max(W, H) 38 | image = cv2.resize(image, (int(W*scale_factor), int(H*scale_factor)), interpolation = cv2.INTER_AREA) 39 | 40 | return image 41 | 42 | 43 | def init_SAM_everything(model_type: str, sam_checkpoint: str) -> SamAutomaticMaskGenerator: 44 | # check if cuda is available 45 | device = "cuda" 46 | 47 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 48 | sam.to(device=device) 49 | 50 | mask_generator = SamAutomaticMaskGenerator(sam) 51 | 52 | return mask_generator 53 | 54 | 55 | def process_dataset(dataset_type: str = "train") -> None: 56 | assert dataset_type in available_dataset_type, "dataset_type must be one of {}".format(available_dataset_type) 57 | print("Processing dataset {}".format(dataset_type)) 58 | 59 | input_path = 
os.path.join(dataset_root, dataset_type) 60 | output_path = os.path.join(output_root, dataset_type) 61 | 62 | if not os.path.exists(output_path): 63 | os.makedirs(output_path, exist_ok=True) 64 | 65 | sample_list = get_all_samples(os.path.join(dataset_root, dataset_type)) 66 | 67 | mask_generator = init_SAM_everything(model_type="vit_h", sam_checkpoint=sam_checkpoint) 68 | 69 | # create a json file to store the masks 70 | masks_all_samples = [] 71 | 72 | for idx, sample in enumerate(sample_list): 73 | img_file = sample + ".jpg" 74 | 75 | print("Processing sample {} / {}".format(idx, len(sample_list))) 76 | 77 | image_path = os.path.join(input_path, img_file) 78 | assert os.path.exists(image_path), "Image file {} does not exist".format(image_path) 79 | 80 | image = cv2.imread(os.path.join(input_path, img_file)) 81 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 82 | 83 | image = preprocess_images(image) 84 | 85 | masks = mask_generator.generate(image) 86 | 87 | sample_dict = {} 88 | 89 | sample_dict["image"] = img_file 90 | sample_dict["target_size"] = [image.shape[0], image.shape[1]] 91 | 92 | # convert masks to coco format 93 | masks_coco = [] 94 | for mask in masks: 95 | binary_mask = mask['segmentation'] 96 | mask_rle = mask_utils.encode(np.asfortranarray(binary_mask.astype(np.uint8))) 97 | mask['segmentation'] = mask_rle 98 | masks_coco.append(mask) 99 | 100 | sample_dict["masks"] = masks_coco 101 | 102 | masks_all_samples.append(sample_dict) 103 | 104 | # for large dataset, it is more reasonable to save it as the h5 file 105 | # save masks as the h5 file 106 | h5_save_path = os.path.join(output_path, dataset_type + "_masks.h5") 107 | 108 | # convert dict to string 109 | # https://stackoverflow.com/questions/16494669/how-to-store-dictionary-in-hdf5-dataset 110 | masks_all_samples_str = [] 111 | for sample in masks_all_samples: 112 | masks_all_samples_str.append(str(sample)) 113 | 114 | with h5py.File(h5_save_path, 'w') as f: 115 | f.create_dataset('masks', data=masks_all_samples_str) 116 | 117 | 118 | 119 | def main(): 120 | for dataset_type in available_dataset_type: 121 | process_dataset(dataset_type=dataset_type) 122 | 123 | 124 | if __name__ == "__main__": 125 | main() -------------------------------------------------------------------------------- /prepare_datasets/prepare_ade20k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from typing import List 4 | import cv2 5 | import h5py 6 | import argparse 7 | from pycocotools import mask as mask_utils 8 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator 9 | 10 | 11 | # accept split number as argument 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--input_path", type=str, default="/cluster/scratch/leikel/junchi/lisa_dataset/ade20k/images/training", 14 | help="path to coco dataset") 15 | parser.add_argument("--output_path", type=str, default="/cluster/home/leikel/junchi/processed_data/ade20k", 16 | help="path to split directory") 17 | parser.add_argument("--sam_checkpoint", type=str, default="/cluster/home/leikel/junchi/segment-anything/checkpoints/sam_vit_h_4b8939.pth", 18 | help="path to sam checkpoint") 19 | 20 | def get_all_samples(dataset_path) -> List[str]: 21 | 22 | all_samples = [] 23 | files = os.listdir(dataset_path) 24 | 25 | for file in files: 26 | if file.endswith(".jpg"): 27 | name = file.split('.')[0] 28 | all_samples.append(name) 29 | 30 | return all_samples 31 | 32 | 33 | def 
preprocess_images(image: np.ndarray) -> np.ndarray: 34 | # scale the large side to 1024 35 | H, W, _ = image.shape 36 | if max(W, H) > 1024: 37 | # scale 38 | scale_factor = 1024.0 / max(W, H) 39 | image = cv2.resize(image, (int(W*scale_factor), int(H*scale_factor)), interpolation = cv2.INTER_AREA) 40 | 41 | return image 42 | 43 | 44 | def init_SAM_everything(model_type: str, sam_checkpoint: str) -> SamAutomaticMaskGenerator: 45 | # check if cuda is available 46 | device = "cuda" 47 | 48 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 49 | sam.to(device=device) 50 | 51 | mask_generator = SamAutomaticMaskGenerator(sam) 52 | 53 | return mask_generator 54 | 55 | 56 | def read_split_file(split_file: str) -> List[str]: 57 | with open(split_file, "r") as f: 58 | lines = f.readlines() 59 | images = [line.strip() for line in lines] 60 | # check if it is a valid image fil 61 | for image in images: 62 | assert image.endswith(".jpg"), "Invalid image file: {}".format(image) 63 | 64 | return images 65 | 66 | 67 | def process_dataset(input_path: str, output_path: str, sam_checkpoint: str) -> None: 68 | 69 | if not os.path.exists(output_path): 70 | os.makedirs(output_path, exist_ok=True) 71 | 72 | sample_list = get_all_samples(input_path) 73 | 74 | mask_generator = init_SAM_everything(model_type="vit_h", sam_checkpoint=sam_checkpoint) 75 | 76 | # create a json file to store the masks 77 | masks_all_samples = [] 78 | 79 | for idx, sample in enumerate(sample_list): 80 | img_file = sample + ".jpg" 81 | 82 | if idx % 10 == 0: 83 | print("Processing sample {} / {}".format(idx, len(sample_list))) 84 | 85 | image_path = os.path.join(input_path, img_file) 86 | assert os.path.exists(image_path), "Image file {} does not exist".format(image_path) 87 | 88 | image = cv2.imread(os.path.join(input_path, img_file)) 89 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 90 | 91 | image = preprocess_images(image) 92 | 93 | masks = mask_generator.generate(image) 94 | 95 | sample_dict = {} 96 | 97 | sample_dict["image"] = img_file 98 | sample_dict["target_size"] = [image.shape[0], image.shape[1]] 99 | 100 | # convert masks to coco format 101 | masks_coco = [] 102 | for mask in masks: 103 | binary_mask = mask['segmentation'] 104 | mask_rle = mask_utils.encode(np.asfortranarray(binary_mask.astype(np.uint8))) 105 | mask['segmentation'] = mask_rle 106 | masks_coco.append(mask) 107 | 108 | sample_dict["masks"] = masks_coco 109 | 110 | masks_all_samples.append(sample_dict) 111 | 112 | # # debug 113 | # if idx == 10: 114 | # break 115 | 116 | # for large dataset, it is more reasonable to save it as the h5 file 117 | # save masks as the h5 file 118 | h5_save_path = os.path.join(output_path, "masks.h5") 119 | 120 | # convert dict to string 121 | # https://stackoverflow.com/questions/16494669/how-to-store-dictionary-in-hdf5-dataset 122 | masks_all_samples_str = [] 123 | for sample in masks_all_samples: 124 | masks_all_samples_str.append(str(sample)) 125 | 126 | with h5py.File(h5_save_path, 'w') as f: 127 | f.create_dataset('masks', data=masks_all_samples_str) 128 | 129 | 130 | 131 | def main(): 132 | args = parser.parse_args() 133 | input_path = args.input_path 134 | output_path = args.output_path 135 | sam_checkpoint = args.sam_checkpoint 136 | 137 | process_dataset(input_path=input_path, output_path=output_path, sam_checkpoint=sam_checkpoint) 138 | 139 | 140 | 141 | if __name__ == "__main__": 142 | main() -------------------------------------------------------------------------------- 
/prepare_datasets/prepare_coco.py: -------------------------------------------------------------------------------- 1 | ''' 2 | COCO training set: 118,287 images, if we naively use 1 GPU to process 1 image, it will take almost a week to process the whole dataset. 3 | 4 | A more efficient and simple way is to split the dataset into 8 parts, and we can submit 8 jobs to the cluster to process the dataset in parallel. 5 | Ideally, it will take 1 day to process the whole dataset. 6 | 7 | ''' 8 | 9 | import os 10 | import numpy as np 11 | from typing import List 12 | import cv2 13 | import h5py 14 | import argparse 15 | from pycocotools import mask as mask_utils 16 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator 17 | 18 | 19 | # accept split number as argument 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--split", type=int, required=True, help="split number") 22 | parser.add_argument("--dataset_path", type=str, required=True, help="path to coco dataset") 23 | parser.add_argument("--split_dir", type=str, required=True, help="path to split directory") 24 | parser.add_argument("--sam_checkpoint", type=str, default="/cluster/home/leikel/junchi/segment-anything/checkpoints/sam_vit_h_4b8939.pth", 25 | help="path to sam checkpoint") 26 | 27 | def preprocess_images(image: np.ndarray) -> np.ndarray: 28 | # scale the large side to 1024 29 | H, W, _ = image.shape 30 | if max(W, H) > 1024: 31 | # scale 32 | scale_factor = 1024.0 / max(W, H) 33 | image = cv2.resize(image, (int(W*scale_factor), int(H*scale_factor)), interpolation = cv2.INTER_AREA) 34 | 35 | return image 36 | 37 | 38 | def init_SAM_everything(model_type: str, sam_checkpoint: str) -> SamAutomaticMaskGenerator: 39 | # check if cuda is available 40 | device = "cuda" 41 | 42 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 43 | sam.to(device=device) 44 | 45 | mask_generator = SamAutomaticMaskGenerator(sam) 46 | 47 | return mask_generator 48 | 49 | 50 | def read_split_file(split_file: str) -> List[str]: 51 | with open(split_file, "r") as f: 52 | lines = f.readlines() 53 | images = [line.strip() for line in lines] 54 | # check if it is a valid image fil 55 | for image in images: 56 | assert image.endswith(".jpg"), "Invalid image file: {}".format(image) 57 | 58 | return images 59 | 60 | 61 | def process_split(split_file: str, dataset_path: str, sam_checkpoint: str, args) -> None: 62 | # read split file 63 | images = read_split_file(split_file) 64 | 65 | split_num = int(split_file.split("_")[-1].split(".")[0]) 66 | print("Processing split {} with {} images".format(split_num, len(images))) 67 | 68 | # init SAM 69 | mask_generator = init_SAM_everything("vit_h", sam_checkpoint) 70 | 71 | masks_all_samples = [] 72 | 73 | # process images 74 | for idx, image_file in enumerate(images): 75 | # read image 76 | image_path = os.path.join(dataset_path, image_file) 77 | image = cv2.imread(image_path) 78 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 79 | image = preprocess_images(image) 80 | 81 | # generate mask 82 | masks = mask_generator.generate(image) 83 | 84 | sample_dict = {} 85 | 86 | sample_dict["image"] = image_file 87 | sample_dict["target_size"] = [image.shape[0], image.shape[1]] 88 | 89 | # convert masks to coco format 90 | masks_coco = [] 91 | for mask in masks: 92 | binary_mask = mask['segmentation'] 93 | mask_rel = mask_utils.encode(np.asfortranarray(binary_mask.astype(np.uint8))) 94 | mask['segmentation'] = mask_rel 95 | masks_coco.append(mask) 96 | 97 | sample_dict["masks"] = 
masks_coco 98 | 99 | masks_all_samples.append(sample_dict) 100 | 101 | 102 | if idx % 10 == 0: 103 | print("Processing image {}/{}".format(idx, len(images))) 104 | 105 | # # debug 106 | # if idx == 10: 107 | # break 108 | 109 | h5_save_path = os.path.join(args.split_dir, "coco_split{}.h5".format(split_num)) 110 | 111 | # convert dict to string 112 | # https://stackoverflow.com/questions/16494669/how-to-store-dictionary-in-hdf5-dataset 113 | masks_all_samples_str = [] 114 | for sample in masks_all_samples: 115 | masks_all_samples_str.append(str(sample)) 116 | 117 | with h5py.File(h5_save_path, 'w') as f: 118 | f.create_dataset('masks', data=masks_all_samples_str) 119 | 120 | 121 | def main(): 122 | args = parser.parse_args() 123 | split_num = args.split 124 | dataset_path = args.dataset_path 125 | split_dir = args.split_dir 126 | sam_checkpoint = args.sam_checkpoint 127 | 128 | assert split_num >= 0 and split_num < 8, "Invalid split number: {}".format(split_num) 129 | 130 | split_file = os.path.join(split_dir, "part_{}.txt".format(split_num)) 131 | process_split(split_file, dataset_path, sam_checkpoint, args) 132 | 133 | print("Done.") 134 | 135 | 136 | if __name__ == "__main__": 137 | main() -------------------------------------------------------------------------------- /prepare_datasets/prepare_egoobjects.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | from typing import List 5 | import json 6 | import h5py 7 | from pycocotools import mask as mask_utils 8 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator 9 | from tqdm import tqdm 10 | 11 | dataset_root = "/home/leikel/junchi/lisa_dataset/ego_objects" 12 | output_root = "/home/leikel/junchi/lisa_dataset/ego_objects" 13 | sam_checkpoint = "/home/leikel/junchi/pretrained_weights/SAM/sam_vit_h_4b8939.pth" 14 | 15 | 16 | # test set not available for ReasonSeg 17 | available_dataset_type = ["train", "validation", "test"] 18 | 19 | def get_all_samples() -> List[str]: 20 | 21 | json_path = "/home/leikel/junchi/ReasonCOCO/post_processing/split" 22 | 23 | all_samples = [] 24 | 25 | for dataset_type in available_dataset_type: 26 | with open(os.path.join(json_path, dataset_type + ".json"), "r") as f: 27 | data = json.load(f) 28 | images = data.keys() 29 | for image in images: 30 | sample = data[image] 31 | if sample['from_dataset'] == "ego_objects": 32 | all_samples.append(image) 33 | 34 | return all_samples 35 | 36 | 37 | def preprocess_images(image: np.ndarray) -> np.ndarray: 38 | # scale the large side to 1024 39 | H, W, _ = image.shape 40 | if max(W, H) > 1024: 41 | # scale 42 | scale_factor = 1024.0 / max(W, H) 43 | image = cv2.resize(image, (int(W*scale_factor), int(H*scale_factor)), interpolation = cv2.INTER_AREA) 44 | 45 | return image 46 | 47 | 48 | def init_SAM_everything(model_type: str, sam_checkpoint: str) -> SamAutomaticMaskGenerator: 49 | # check if cuda is available 50 | device = "cuda" 51 | 52 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 53 | sam.to(device=device) 54 | 55 | mask_generator = SamAutomaticMaskGenerator(sam) 56 | 57 | return mask_generator 58 | 59 | 60 | def process_dataset() -> None: 61 | 62 | 63 | input_path = os.path.join(dataset_root, "images") 64 | output_path = output_root 65 | 66 | if not os.path.exists(output_path): 67 | os.makedirs(output_path, exist_ok=True) 68 | 69 | sample_list = get_all_samples() 70 | 71 | print("we have {} samples".format(len(sample_list))) 72 | 73 | 
mask_generator = init_SAM_everything(model_type="vit_h", sam_checkpoint=sam_checkpoint) 74 | 75 | # create a json file to store the masks 76 | masks_all_samples = [] 77 | 78 | for sample in tqdm(sample_list[:]): 79 | img_file = sample 80 | 81 | # print("Processing sample {} / {}".format(idx, len(sample_list))) 82 | 83 | image_path = os.path.join(input_path, img_file) 84 | assert os.path.exists(image_path), "Image file {} does not exist".format(image_path) 85 | 86 | image = cv2.imread(os.path.join(input_path, img_file)) 87 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 88 | 89 | image = preprocess_images(image) 90 | 91 | masks = mask_generator.generate(image) 92 | 93 | sample_dict = {} 94 | 95 | sample_dict["image"] = img_file 96 | sample_dict["target_size"] = [image.shape[0], image.shape[1]] 97 | 98 | # convert masks to coco format 99 | masks_coco = [] 100 | for mask in masks: 101 | binary_mask = mask['segmentation'] 102 | mask_rle = mask_utils.encode(np.asfortranarray(binary_mask.astype(np.uint8))) 103 | mask_rle['counts'] = mask_rle['counts'].decode() 104 | mask['segmentation'] = mask_rle 105 | masks_coco.append(mask) 106 | 107 | sample_dict["masks"] = masks_coco 108 | 109 | masks_all_samples.append(sample_dict) 110 | 111 | # save to json 112 | json_save_path = os.path.join(output_path, "masks.json") 113 | 114 | with open(json_save_path, 'w') as f: 115 | json.dump(masks_all_samples, f) 116 | 117 | 118 | def main(): 119 | process_dataset() 120 | 121 | 122 | if __name__ == "__main__": 123 | main() -------------------------------------------------------------------------------- /prepare_datasets/prepare_mapillary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from typing import List 4 | import cv2 5 | import h5py 6 | import json 7 | import argparse 8 | from pycocotools import mask as mask_utils 9 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator 10 | 11 | 12 | # accept split number as argument 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--input_path", type=str, default="/cluster/scratch/leikel/junchi/lisa_dataset/mapillary/training/images", 15 | help="path to coco dataset") 16 | parser.add_argument("--output_path", type=str, default="/cluster/home/leikel/junchi/processed_data/mapillary", 17 | help="path to split directory") 18 | parser.add_argument("--sam_checkpoint", type=str, default="/cluster/home/leikel/junchi/segment-anything/checkpoints/sam_vit_h_4b8939.pth", 19 | help="path to sam checkpoint") 20 | 21 | def get_all_samples(dataset_path) -> List[str]: 22 | 23 | all_samples = [] 24 | files = os.listdir(dataset_path) 25 | 26 | for file in files: 27 | if file.endswith(".jpg"): 28 | name = file.split('.')[0] 29 | all_samples.append(name) 30 | 31 | return all_samples 32 | 33 | 34 | def preprocess_images(image: np.ndarray) -> np.ndarray: 35 | # scale the large side to 1024 36 | H, W, _ = image.shape 37 | if max(W, H) > 1024: 38 | # scale 39 | scale_factor = 1024.0 / max(W, H) 40 | image = cv2.resize(image, (int(W*scale_factor), int(H*scale_factor)), interpolation = cv2.INTER_AREA) 41 | 42 | return image 43 | 44 | 45 | def init_SAM_everything(model_type: str, sam_checkpoint: str) -> SamAutomaticMaskGenerator: 46 | # check if cuda is available 47 | device = "cuda" 48 | 49 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 50 | sam.to(device=device) 51 | 52 | mask_generator = SamAutomaticMaskGenerator(sam) 53 | 54 | return mask_generator 55 | 56 | 57 
| def read_split_file(split_file: str) -> List[str]: 58 | with open(split_file, "r") as f: 59 | lines = f.readlines() 60 | images = [line.strip() for line in lines] 61 | # check if it is a valid image fil 62 | for image in images: 63 | assert image.endswith(".jpg"), "Invalid image file: {}".format(image) 64 | 65 | return images 66 | 67 | 68 | def process_dataset(input_path: str, output_path: str, sam_checkpoint: str) -> None: 69 | 70 | if not os.path.exists(output_path): 71 | os.makedirs(output_path, exist_ok=True) 72 | 73 | sample_list = get_all_samples(input_path) 74 | 75 | mask_generator = init_SAM_everything(model_type="vit_h", sam_checkpoint=sam_checkpoint) 76 | 77 | # create a json file to store the masks 78 | masks_all_samples = [] 79 | 80 | for idx, sample in enumerate(sample_list): 81 | img_file = sample + ".jpg" 82 | 83 | if idx % 10 == 0: 84 | print("Processing sample {} / {}".format(idx, len(sample_list))) 85 | 86 | image_path = os.path.join(input_path, img_file) 87 | assert os.path.exists(image_path), "Image file {} does not exist".format(image_path) 88 | 89 | image = cv2.imread(os.path.join(input_path, img_file)) 90 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 91 | 92 | image = preprocess_images(image) 93 | 94 | masks = mask_generator.generate(image) 95 | 96 | sample_dict = {} 97 | 98 | sample_dict["image"] = img_file 99 | sample_dict["target_size"] = [image.shape[0], image.shape[1]] 100 | 101 | # convert masks to coco format 102 | masks_coco = [] 103 | for mask in masks: 104 | binary_mask = mask['segmentation'] 105 | mask_rle = mask_utils.encode(np.asfortranarray(binary_mask.astype(np.uint8))) 106 | mask_rle['counts'] = mask_rle['counts'].decode() 107 | mask['segmentation'] = mask_rle 108 | masks_coco.append(mask) 109 | 110 | sample_dict["masks"] = masks_coco 111 | 112 | masks_all_samples.append(sample_dict) 113 | 114 | # # debug 115 | # if idx == 10: 116 | # break 117 | 118 | # # for large dataset, it is more reasonable to save it as the h5 file 119 | # # save masks as the h5 file 120 | # h5_save_path = os.path.join(output_path, "masks.h5") 121 | 122 | # save as json 123 | json_save_path = os.path.join(output_path, "masks.json") 124 | 125 | # # convert dict to string 126 | # # https://stackoverflow.com/questions/16494669/how-to-store-dictionary-in-hdf5-dataset 127 | # masks_all_samples_str = [] 128 | # for sample in masks_all_samples: 129 | # masks_all_samples_str.append(str(sample)) 130 | 131 | # with h5py.File(h5_save_path, 'w') as f: 132 | # f.create_dataset('masks', data=masks_all_samples_str) 133 | 134 | with open(json_save_path, 'w') as f: 135 | json.dump(masks_all_samples, f) 136 | 137 | 138 | 139 | def main(): 140 | args = parser.parse_args() 141 | input_path = args.input_path 142 | output_path = args.output_path 143 | sam_checkpoint = args.sam_checkpoint 144 | 145 | process_dataset(input_path=input_path, output_path=output_path, sam_checkpoint=sam_checkpoint) 146 | 147 | 148 | 149 | if __name__ == "__main__": 150 | main() -------------------------------------------------------------------------------- /prepare_datasets/prepare_saiapr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from typing import List 4 | import cv2 5 | import h5py 6 | import argparse 7 | from pycocotools import mask as mask_utils 8 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator 9 | 10 | 11 | # accept split number as argument 12 | parser = argparse.ArgumentParser() 13 | 
parser.add_argument("--input_path", type=str, default="/cluster/scratch/leikel/junchi/lisa_dataset/refer_seg/images/saiapr_tc-12", 14 | help="path to coco dataset") 15 | parser.add_argument("--output_path", type=str, default="/cluster/home/leikel/junchi/processed_data/saiapr", 16 | help="path to split directory") 17 | parser.add_argument("--sam_checkpoint", type=str, default="/cluster/home/leikel/junchi/segment-anything/checkpoints/sam_vit_h_4b8939.pth", 18 | help="path to sam checkpoint") 19 | 20 | def get_all_samples_saiapr(dataset_path) -> List[str]: 21 | 22 | all_samples = [] 23 | folder_list = os.listdir(dataset_path) 24 | 25 | # folder name should be 00 to 40 26 | assert len(folder_list) == 41, "The number of folders is not 41" 27 | 28 | for folder in folder_list: 29 | image_foler = os.path.join(dataset_path, folder, "images") 30 | # skip if image folder does not exist 31 | if not os.path.exists(image_foler): 32 | continue 33 | files = os.listdir(image_foler) 34 | for file in files: 35 | if file.endswith(".jpg"): 36 | # add folder name to the image name 37 | name = folder + "/images/" + file 38 | all_samples.append(name) 39 | 40 | return all_samples 41 | 42 | 43 | def preprocess_images(image: np.ndarray) -> np.ndarray: 44 | # scale the large side to 1024 45 | H, W, _ = image.shape 46 | if max(W, H) > 1024: 47 | # scale 48 | scale_factor = 1024.0 / max(W, H) 49 | image = cv2.resize(image, (int(W*scale_factor), int(H*scale_factor)), interpolation = cv2.INTER_AREA) 50 | 51 | return image 52 | 53 | 54 | def init_SAM_everything(model_type: str, sam_checkpoint: str) -> SamAutomaticMaskGenerator: 55 | # check if cuda is available 56 | device = "cuda" 57 | 58 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 59 | sam.to(device=device) 60 | 61 | mask_generator = SamAutomaticMaskGenerator(sam) 62 | 63 | return mask_generator 64 | 65 | 66 | def process_dataset(input_path: str, output_path: str, sam_checkpoint: str) -> None: 67 | 68 | if not os.path.exists(output_path): 69 | os.makedirs(output_path, exist_ok=True) 70 | 71 | sample_list = get_all_samples_saiapr(input_path) 72 | 73 | mask_generator = init_SAM_everything(model_type="vit_h", sam_checkpoint=sam_checkpoint) 74 | 75 | # create a json file to store the masks 76 | masks_all_samples = [] 77 | 78 | for idx, sample in enumerate(sample_list): 79 | img_file = sample 80 | 81 | if idx % 10 == 0: 82 | print("Processing sample {} / {}".format(idx, len(sample_list))) 83 | 84 | image_path = os.path.join(input_path, img_file) 85 | 86 | image = cv2.imread(os.path.join(input_path, img_file)) 87 | 88 | # some image are corrupted, skip them 89 | if image is None: 90 | print("Image file {} cannot be loaded".format(image_path)) 91 | continue 92 | 93 | # check if image is loaded 94 | assert image is not None, "Image file {} cannot be loaded".format(image_path) 95 | 96 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 97 | 98 | image = preprocess_images(image) 99 | 100 | masks = mask_generator.generate(image) 101 | 102 | sample_dict = {} 103 | 104 | sample_dict["image"] = img_file 105 | sample_dict["target_size"] = [image.shape[0], image.shape[1]] 106 | 107 | # convert masks to coco format 108 | masks_coco = [] 109 | for mask in masks: 110 | binary_mask = mask['segmentation'] 111 | mask_rle = mask_utils.encode(np.asfortranarray(binary_mask.astype(np.uint8))) 112 | mask['segmentation'] = mask_rle 113 | masks_coco.append(mask) 114 | 115 | sample_dict["masks"] = masks_coco 116 | 117 | masks_all_samples.append(sample_dict) 118 | 119 | # # debug 
120 | # if idx == 10: 121 | # break 122 | 123 | # for large dataset, it is more reasonable to save it as the h5 file 124 | # save masks as the h5 file 125 | h5_save_path = os.path.join(output_path, "masks.h5") 126 | 127 | # convert dict to string 128 | # https://stackoverflow.com/questions/16494669/how-to-store-dictionary-in-hdf5-dataset 129 | masks_all_samples_str = [] 130 | for sample in masks_all_samples: 131 | masks_all_samples_str.append(str(sample)) 132 | 133 | with h5py.File(h5_save_path, 'w') as f: 134 | f.create_dataset('masks', data=masks_all_samples_str) 135 | 136 | 137 | 138 | def main(): 139 | args = parser.parse_args() 140 | input_path = args.input_path 141 | output_path = args.output_path 142 | sam_checkpoint = args.sam_checkpoint 143 | 144 | process_dataset(input_path=input_path, output_path=output_path, sam_checkpoint=sam_checkpoint) 145 | 146 | 147 | 148 | if __name__ == "__main__": 149 | main() -------------------------------------------------------------------------------- /prepare_datasets/prepare_voc2010.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from typing import List 4 | import cv2 5 | import h5py 6 | import argparse 7 | from pycocotools import mask as mask_utils 8 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator 9 | 10 | 11 | # accept split number as argument 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--input_path", type=str, default="/cluster/scratch/leikel/junchi/lisa_dataset/vlpart/pascal_part/VOCdevkit/VOC2010/JPEGImages", 14 | help="path to coco dataset") 15 | parser.add_argument("--output_path", type=str, default="/cluster/home/leikel/junchi/processed_data/voc2010", 16 | help="path to split directory") 17 | parser.add_argument("--sam_checkpoint", type=str, default="/cluster/home/leikel/junchi/segment-anything/checkpoints/sam_vit_h_4b8939.pth", 18 | help="path to sam checkpoint") 19 | 20 | def get_all_samples(dataset_path) -> List[str]: 21 | 22 | all_samples = [] 23 | files = os.listdir(dataset_path) 24 | 25 | for file in files: 26 | if file.endswith(".jpg"): 27 | name = file.split('.')[0] 28 | all_samples.append(name) 29 | 30 | return all_samples 31 | 32 | 33 | def preprocess_images(image: np.ndarray) -> np.ndarray: 34 | # scale the large side to 1024 35 | H, W, _ = image.shape 36 | if max(W, H) > 1024: 37 | # scale 38 | scale_factor = 1024.0 / max(W, H) 39 | image = cv2.resize(image, (int(W*scale_factor), int(H*scale_factor)), interpolation = cv2.INTER_AREA) 40 | 41 | return image 42 | 43 | 44 | def init_SAM_everything(model_type: str, sam_checkpoint: str) -> SamAutomaticMaskGenerator: 45 | # check if cuda is available 46 | device = "cuda" 47 | 48 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 49 | sam.to(device=device) 50 | 51 | mask_generator = SamAutomaticMaskGenerator(sam) 52 | 53 | return mask_generator 54 | 55 | 56 | def read_split_file(split_file: str) -> List[str]: 57 | with open(split_file, "r") as f: 58 | lines = f.readlines() 59 | images = [line.strip() for line in lines] 60 | # check if it is a valid image fil 61 | for image in images: 62 | assert image.endswith(".jpg"), "Invalid image file: {}".format(image) 63 | 64 | return images 65 | 66 | 67 | def process_dataset(input_path: str, output_path: str, sam_checkpoint: str) -> None: 68 | 69 | if not os.path.exists(output_path): 70 | os.makedirs(output_path, exist_ok=True) 71 | 72 | sample_list = get_all_samples(input_path) 73 | 74 | mask_generator 
= init_SAM_everything(model_type="vit_h", sam_checkpoint=sam_checkpoint) 75 | 76 | # create a json file to store the masks 77 | masks_all_samples = [] 78 | 79 | for idx, sample in enumerate(sample_list): 80 | img_file = sample + ".jpg" 81 | 82 | if idx % 10 == 0: 83 | print("Processing sample {} / {}".format(idx, len(sample_list))) 84 | 85 | image_path = os.path.join(input_path, img_file) 86 | assert os.path.exists(image_path), "Image file {} does not exist".format(image_path) 87 | 88 | image = cv2.imread(os.path.join(input_path, img_file)) 89 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 90 | 91 | image = preprocess_images(image) 92 | 93 | masks = mask_generator.generate(image) 94 | 95 | sample_dict = {} 96 | 97 | sample_dict["image"] = img_file 98 | sample_dict["target_size"] = [image.shape[0], image.shape[1]] 99 | 100 | # convert masks to coco format 101 | masks_coco = [] 102 | for mask in masks: 103 | binary_mask = mask['segmentation'] 104 | mask_rle = mask_utils.encode(np.asfortranarray(binary_mask.astype(np.uint8))) 105 | mask['segmentation'] = mask_rle 106 | masks_coco.append(mask) 107 | 108 | sample_dict["masks"] = masks_coco 109 | 110 | masks_all_samples.append(sample_dict) 111 | 112 | # # debug 113 | # if idx == 10: 114 | # break 115 | 116 | # for large dataset, it is more reasonable to save it as the h5 file 117 | # save masks as the h5 file 118 | h5_save_path = os.path.join(output_path, "masks.h5") 119 | 120 | # convert dict to string 121 | # https://stackoverflow.com/questions/16494669/how-to-store-dictionary-in-hdf5-dataset 122 | masks_all_samples_str = [] 123 | for sample in masks_all_samples: 124 | masks_all_samples_str.append(str(sample)) 125 | 126 | with h5py.File(h5_save_path, 'w') as f: 127 | f.create_dataset('masks', data=masks_all_samples_str) 128 | 129 | 130 | 131 | def main(): 132 | args = parser.parse_args() 133 | input_path = args.input_path 134 | output_path = args.output_path 135 | sam_checkpoint = args.sam_checkpoint 136 | 137 | process_dataset(input_path=input_path, output_path=output_path, sam_checkpoint=sam_checkpoint) 138 | 139 | 140 | 141 | if __name__ == "__main__": 142 | main() -------------------------------------------------------------------------------- /prepare_datasets/split_coco.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Read the file list of coco training set, and split it into 8 parts. 3 | ''' 4 | 5 | import os 6 | 7 | # dataset_path = "/cluster/scratch/leikel/junchi/lisa_dataset/coco/train2017" 8 | dataset_path = "/cluster/scratch/leikel/junchi/lisa_dataset/refer_seg/images/mscoco/images/train2014" 9 | output_path = "/cluster/home/leikel/junchi/processed_data/coco2014" 10 | 11 | if not os.path.exists(output_path): 12 | os.makedirs(output_path, exist_ok=True) 13 | 14 | files = os.listdir(dataset_path) 15 | # filter out non-image files 16 | files = [file for file in files if file.endswith(".jpg")] 17 | 18 | print("Total number of files: {}".format(len(files))) 19 | 20 | # split into 8 parts 21 | num_parts = 8 22 | 23 | part_cnt = [0 for i in range(num_parts)] 24 | 25 | for i in range(num_parts): 26 | part_files = files[i::num_parts] 27 | part_files.sort() 28 | part_cnt[i] = len(part_files) 29 | with open(os.path.join(output_path, "part_{}.txt".format(i)), "w") as f: 30 | for file in part_files: 31 | f.write(file + "\n") 32 | 33 | assert sum(part_cnt) == len(files), "Number of files in parts does not match total number of files." 
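# Sanity note on the slicing above: files[i::num_parts] assigns every num_parts-th file
# to part i (round-robin), e.g. with 20 files and 8 parts, part 0 holds files 0, 8, 16
# and part 7 holds files 7, 15. Each resulting part_<i>.txt is then consumed by
# prepare_coco.py, for example (placeholder paths):
#   python prepare_datasets/prepare_coco.py --split 0 \
#       --dataset_path <path/to/train2014> --split_dir <this output_path>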
34 | 35 | print("Done.") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _openmp_mutex=5.1=1_gnu 6 | absl-py=2.1.0=pypi_0 7 | accelerate=0.21.0=pypi_0 8 | addict=2.4.0=pypi_0 9 | aiofiles=23.2.1=pypi_0 10 | aiohttp=3.9.3=pypi_0 11 | aiosignal=1.3.1=pypi_0 12 | altair=5.0.1=pypi_0 13 | antlr4-python3-runtime=4.9.3=pypi_0 14 | anyio=4.2.0=pypi_0 15 | appdirs=1.4.4=pypi_0 16 | attrs=23.2.0=pypi_0 17 | beautifulsoup4=4.12.3=pypi_0 18 | bitsandbytes=0.42.0=pypi_0 19 | black=24.1.1=pypi_0 20 | blas=1.0=mkl 21 | blis=0.7.11=pypi_0 22 | brotli-python=1.0.9=py311h6a678d5_7 23 | bzip2=1.0.8=h7b6447c_0 24 | ca-certificates=2023.12.12=h06a4308_0 25 | cachetools=5.3.2=pypi_0 26 | catalogue=2.0.10=pypi_0 27 | certifi=2023.11.17=py311h06a4308_0 28 | cffi=1.16.0=py311h5eee18b_0 29 | charset-normalizer=2.0.4=pyhd3eb1b0_0 30 | click=8.1.7=pypi_0 31 | cloudpathlib=0.16.0=pypi_0 32 | cloudpickle=3.0.0=pypi_0 33 | confection=0.1.4=pypi_0 34 | contourpy=1.2.0=pypi_0 35 | cryptography=41.0.7=py311hdda0065_0 36 | cuda=11.8.0=0 37 | cuda-cccl=11.8.89=0 38 | cuda-command-line-tools=11.8.0=0 39 | cuda-compiler=11.8.0=0 40 | cuda-cudart=11.8.89=0 41 | cuda-cudart-dev=11.8.89=0 42 | cuda-cuobjdump=11.8.86=0 43 | cuda-cupti=11.8.87=0 44 | cuda-cuxxfilt=11.8.86=0 45 | cuda-demo-suite=11.8.86=0 46 | cuda-documentation=11.8.86=0 47 | cuda-driver-dev=11.8.89=0 48 | cuda-gdb=11.8.86=0 49 | cuda-libraries=11.8.0=0 50 | cuda-libraries-dev=11.8.0=0 51 | cuda-memcheck=11.8.86=0 52 | cuda-nsight=11.8.86=0 53 | cuda-nsight-compute=11.8.0=0 54 | cuda-nvcc=11.8.89=0 55 | cuda-nvdisasm=11.8.86=0 56 | cuda-nvml-dev=11.8.86=0 57 | cuda-nvprof=11.8.87=0 58 | cuda-nvprune=11.8.86=0 59 | cuda-nvrtc=11.8.89=0 60 | cuda-nvrtc-dev=11.8.89=0 61 | cuda-nvtx=11.8.86=0 62 | cuda-nvvp=11.8.87=0 63 | cuda-profiler-api=11.8.86=0 64 | cuda-runtime=11.8.0=0 65 | cuda-sanitizer-api=11.8.86=0 66 | cuda-toolkit=11.8.0=0 67 | cuda-tools=11.8.0=0 68 | cuda-visual-tools=11.8.0=0 69 | cycler=0.12.1=pypi_0 70 | cymem=2.0.8=pypi_0 71 | cython=3.0.8=pypi_0 72 | deepspeed=0.10.0=pypi_0 73 | defusedxml=0.7.1=pypi_0 74 | detectron2=0.6=dev_0 75 | docker-pycreds=0.4.0=pypi_0 76 | einops=0.6.1=pypi_0 77 | einops-exts=0.0.4=pypi_0 78 | fastapi=0.100.1=pypi_0 79 | ffmpeg=4.3=hf484d3e_0 80 | ffmpy=0.3.1=pypi_0 81 | filelock=3.13.1=py311h06a4308_0 82 | flash-attn=2.5.2=pypi_0 83 | fonttools=4.47.2=pypi_0 84 | freetype=2.12.1=h4a9f257_0 85 | frozenlist=1.4.1=pypi_0 86 | fsspec=2023.12.2=pypi_0 87 | fvcore=0.1.5.post20221221=pypi_0 88 | gdown=5.1.0=pypi_0 89 | gds-tools=1.4.0.31=0 90 | giflib=5.2.1=h5eee18b_3 91 | gitdb=4.0.11=pypi_0 92 | gitpython=3.1.41=pypi_0 93 | gmp=6.2.1=h295c915_3 94 | gmpy2=2.1.2=py311hc9b5ff0_0 95 | gnutls=3.6.15=he1e5248_0 96 | google-auth=2.27.0=pypi_0 97 | google-auth-oauthlib=1.2.0=pypi_0 98 | gradio=3.35.2=pypi_0 99 | gradio-client=0.2.9=pypi_0 100 | groundingdino=0.1.0=dev_0 101 | grpcio=1.60.1=pypi_0 102 | h11=0.14.0=pypi_0 103 | h5py=3.10.0=pypi_0 104 | hjson=3.1.0=pypi_0 105 | httpcore=0.17.3=pypi_0 106 | httpx=0.24.0=pypi_0 107 | huggingface-hub=0.20.3=pypi_0 108 | hydra-core=1.3.2=pypi_0 109 | idna=3.4=py311h06a4308_0 110 | imageio=2.33.1=pypi_0 111 | importlib-metadata=7.0.2=pypi_0 112 | intel-openmp=2023.1.0=hdb19cb5_46306 113 | iopath=0.1.9=pypi_0 114 
| jinja2=3.1.3=py311h06a4308_0 115 | joblib=1.3.2=pypi_0 116 | jpeg=9e=h5eee18b_1 117 | jsonschema=4.21.1=pypi_0 118 | jsonschema-specifications=2023.12.1=pypi_0 119 | kiwisolver=1.4.5=pypi_0 120 | lame=3.100=h7b6447c_0 121 | langcodes=3.3.0=pypi_0 122 | lazy-loader=0.3=pypi_0 123 | lcms2=2.12=h3be6417_0 124 | ld_impl_linux-64=2.38=h1181459_1 125 | lerc=3.0=h295c915_0 126 | libcublas=11.11.3.6=0 127 | libcublas-dev=11.11.3.6=0 128 | libcufft=10.9.0.58=0 129 | libcufft-dev=10.9.0.58=0 130 | libcufile=1.4.0.31=0 131 | libcufile-dev=1.4.0.31=0 132 | libcurand=10.3.0.86=0 133 | libcurand-dev=10.3.0.86=0 134 | libcusolver=11.4.1.48=0 135 | libcusolver-dev=11.4.1.48=0 136 | libcusparse=11.7.5.86=0 137 | libcusparse-dev=11.7.5.86=0 138 | libdeflate=1.17=h5eee18b_1 139 | libffi=3.4.4=h6a678d5_0 140 | libgcc-ng=11.2.0=h1234567_1 141 | libgomp=11.2.0=h1234567_1 142 | libiconv=1.16=h7f8727e_2 143 | libidn2=2.3.4=h5eee18b_0 144 | libjpeg-turbo=2.0.0=h9bf148f_0 145 | libnpp=11.8.0.86=0 146 | libnpp-dev=11.8.0.86=0 147 | libnvjpeg=11.9.0.86=0 148 | libnvjpeg-dev=11.9.0.86=0 149 | libpng=1.6.39=h5eee18b_0 150 | libstdcxx-ng=11.2.0=h1234567_1 151 | libtasn1=4.19.0=h5eee18b_0 152 | libtiff=4.5.1=h6a678d5_0 153 | libunistring=0.9.10=h27cfd23_0 154 | libuuid=1.41.5=h5eee18b_0 155 | libwebp=1.3.2=h11a3e52_0 156 | libwebp-base=1.3.2=h5eee18b_0 157 | linkify-it-py=2.0.2=pypi_0 158 | llava=1.1.1=pypi_0 159 | llvm-openmp=14.0.6=h9e868ea_0 160 | lz4-c=1.9.4=h6a678d5_0 161 | markdown=3.5.2=pypi_0 162 | markdown-it-py=2.2.0=pypi_0 163 | markdown2=2.4.10=pypi_0 164 | markupsafe=2.1.3=py311h5eee18b_0 165 | matplotlib=3.8.2=pypi_0 166 | mdit-py-plugins=0.3.3=pypi_0 167 | mdurl=0.1.2=pypi_0 168 | mkl=2023.1.0=h213fc3f_46344 169 | mkl-service=2.4.0=py311h5eee18b_1 170 | mkl_fft=1.3.8=py311h5eee18b_0 171 | mkl_random=1.2.4=py311hdb19cb5_0 172 | mpc=1.1.0=h10f8cd9_1 173 | mpfr=4.0.2=hb69a4c5_1 174 | mpmath=1.3.0=py311h06a4308_0 175 | msgpack=1.0.7=pypi_0 176 | multidict=6.0.5=pypi_0 177 | multiscaledeformableattention=1.0=pypi_0 178 | murmurhash=1.0.10=pypi_0 179 | mypy-extensions=1.0.0=pypi_0 180 | ncurses=6.4=h6a678d5_0 181 | nettle=3.7.3=hbbd107a_1 182 | networkx=3.1=py311h06a4308_0 183 | ninja=1.11.1=pypi_0 184 | nsight-compute=2022.3.0.22=0 185 | numpy=1.24.2=pypi_0 186 | oauthlib=3.2.2=pypi_0 187 | omegaconf=2.3.0=pypi_0 188 | openai=0.27.8=pypi_0 189 | opencv-python=4.8.0.74=pypi_0 190 | opencv-python-headless=4.9.0.80=pypi_0 191 | openh264=2.1.1=h4ff587b_0 192 | openjpeg=2.4.0=h3ad879b_0 193 | openssl=3.0.12=h7f8727e_0 194 | orjson=3.9.5=pypi_0 195 | packaging=23.2=pypi_0 196 | pandas=2.2.0=pypi_0 197 | pathspec=0.12.1=pypi_0 198 | peft=0.4.0=pypi_0 199 | pillow=9.4.0=pypi_0 200 | pip=23.3.1=py311h06a4308_0 201 | platformdirs=4.2.0=pypi_0 202 | portalocker=2.8.2=pypi_0 203 | preshed=3.0.9=pypi_0 204 | protobuf=4.23.4=pypi_0 205 | psutil=5.9.8=pypi_0 206 | py-cpuinfo=9.0.0=pypi_0 207 | pyasn1=0.5.1=pypi_0 208 | pyasn1-modules=0.3.0=pypi_0 209 | pycocotools=2.0.7=pypi_0 210 | pycparser=2.21=pyhd3eb1b0_0 211 | pydantic=1.10.14=pypi_0 212 | pydub=0.25.1=pypi_0 213 | pygments=2.17.2=pypi_0 214 | pyopenssl=23.2.0=py311h06a4308_0 215 | pyparsing=3.1.1=pypi_0 216 | pysocks=1.7.1=py311h06a4308_0 217 | python=3.11.7=h955ad1f_0 218 | python-dateutil=2.8.2=pypi_0 219 | python-multipart=0.0.7=pypi_0 220 | pytorch=2.2.0=py3.11_cuda11.8_cudnn8.7.0_0 221 | pytorch-cuda=11.8=h7e8668a_5 222 | pytorch-mutex=1.0=cuda 223 | pytz=2024.1=pypi_0 224 | pyyaml=6.0.1=py311h5eee18b_0 225 | ray=2.6.1=pypi_0 226 | readline=8.2=h5eee18b_0 227 | 
referencing=0.33.0=pypi_0 228 | regex=2023.12.25=pypi_0 229 | requests=2.31.0=py311h06a4308_0 230 | requests-oauthlib=1.3.1=pypi_0 231 | rpds-py=0.17.1=pypi_0 232 | rsa=4.9=pypi_0 233 | safetensors=0.3.2=pypi_0 234 | scikit-image=0.22.0=pypi_0 235 | scikit-learn=1.4.1.post1=pypi_0 236 | scipy=1.12.0=pypi_0 237 | segment-anything=1.0=dev_0 238 | semantic-version=2.10.0=pypi_0 239 | sentencepiece=0.1.99=pypi_0 240 | sentry-sdk=1.40.0=pypi_0 241 | setproctitle=1.3.3=pypi_0 242 | setuptools=68.2.2=py311h06a4308_0 243 | shapely=2.0.2=pypi_0 244 | shortuuid=1.0.12=pypi_0 245 | six=1.16.0=pypi_0 246 | smart-open=6.4.0=pypi_0 247 | smmap=5.0.1=pypi_0 248 | sniffio=1.3.0=pypi_0 249 | soupsieve=2.5=pypi_0 250 | spacy=3.7.4=pypi_0 251 | spacy-legacy=3.0.12=pypi_0 252 | spacy-loggers=1.0.5=pypi_0 253 | sqlite=3.41.2=h5eee18b_0 254 | srsly=2.4.8=pypi_0 255 | starlette=0.27.0=pypi_0 256 | supervision=0.18.0=pypi_0 257 | svgwrite=1.4.3=pypi_0 258 | sympy=1.12=py311h06a4308_0 259 | tabulate=0.9.0=pypi_0 260 | tbb=2021.8.0=hdb19cb5_0 261 | tensorboard=2.15.1=pypi_0 262 | tensorboard-data-server=0.7.2=pypi_0 263 | termcolor=2.4.0=pypi_0 264 | thinc=8.2.3=pypi_0 265 | threadpoolctl=3.3.0=pypi_0 266 | tifffile=2024.1.30=pypi_0 267 | timm=0.6.13=pypi_0 268 | tk=8.6.12=h1ccaba5_0 269 | tokenizers=0.13.3=pypi_0 270 | tomli=2.0.1=pypi_0 271 | toolz=0.12.1=pypi_0 272 | torchaudio=2.2.0=py311_cu118 273 | torchtriton=2.2.0=py311 274 | torchvision=0.17.0=py311_cu118 275 | tqdm=4.64.1=pypi_0 276 | transformers=4.29.0=pypi_0 277 | typer=0.9.0=pypi_0 278 | typing_extensions=4.9.0=py311h06a4308_1 279 | tzdata=2023.4=pypi_0 280 | uc-micro-py=1.0.2=pypi_0 281 | urllib3=1.26.18=py311h06a4308_0 282 | uvicorn=0.23.2=pypi_0 283 | wandb=0.16.2=pypi_0 284 | wasabi=1.1.2=pypi_0 285 | wavedrom=2.0.3.post3=pypi_0 286 | weasel=0.3.4=pypi_0 287 | websockets=11.0.3=pypi_0 288 | werkzeug=3.0.1=pypi_0 289 | wheel=0.41.2=py311h06a4308_0 290 | wordcloud=1.9.3=pypi_0 291 | xz=5.4.5=h5eee18b_0 292 | yacs=0.1.8=pypi_0 293 | yaml=0.2.5=h7b6447c_0 294 | yapf=0.40.2=pypi_0 295 | yarl=1.9.4=pypi_0 296 | zipp=3.17.0=pypi_0 297 | zlib=1.2.13=h5eee18b_0 298 | zstd=1.5.5=hc292b87_0 299 | -------------------------------------------------------------------------------- /scripts/finetune_llmseg.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | llava_path="./pretrained_weights/LLaVA-lightning-7B-v1/" 5 | vision_path="./pretrained_weights/SAM/sam_vit_h_4b8939.pth" 6 | dataset_path="./lisa_dataset" 7 | sam_masks_path="./processed_data" 8 | log_path="./runs" 9 | resume_path="./runs/10epoch/ckpt_model" 10 | 11 | deepspeed --include localhost:2,3 \ 12 | --master_port=24374 finetune_llmseg.py \ 13 | --version="$llava_path" \ 14 | --dataset_dir="$dataset_path" \ 15 | --sam_masks_dir="$sam_masks_path" \ 16 | --vision_pretrained="$vision_path" \ 17 | --dataset="sem_seg||refer_seg||reason_seg" \ 18 | --sample_rates="9,3,1" \ 19 | --exp_name="finetune_llmseg" \ 20 | --log_base_dir="$log_path" \ 21 | --steps_per_epoch=500 \ 22 | --lr=1e-5 \ 23 | --epochs=5 \ 24 | --batch_size=1 \ 25 | --resume="$resume_path" \ 26 | -------------------------------------------------------------------------------- /scripts/train_10epoch.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | 3 | 4 | llava_path="./pretrained_weights/LLaVA-lightning-7B-v1/" 5 | sam_path="./pretrained_weights/SAM/sam_vit_h_4b8939.pth" 6 | dataset_path="./lisa_dataset" 7 | sam_masks_path="./processed_data" 8 | log_path="./runs" 9 | 10 | deepspeed --include localhost:6,7 \ 11 | --master_port=24374 training.py \ 12 | --version="$llava_path" \ 13 | --dataset_dir="$dataset_path" \ 14 | --sam_masks_dir="$sam_masks_path" \ 15 | --vision_pretrained="$sam_path" \ 16 | --dataset="sem_seg||refer_seg||reason_seg" \ 17 | --sample_rates="9,3,1" \ 18 | --exp_name="10epoch" \ 19 | --log_base_dir="$log_path" \ 20 | --lr=0.0001 \ 21 | --epochs=10 \ 22 | --batch_size=1 \ 23 | -------------------------------------------------------------------------------- /scripts/train_20epoch.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | llava_path="./pretrained_weights/LLaVA-lightning-7B-v1/" 5 | vision_path="./pretrained_weights/SAM/sam_vit_h_4b8939.pth" 6 | dataset_path="./lisa_dataset" 7 | sam_masks_path="./processed_data" 8 | log_path="./lisa_dataset/new_runs" 9 | 10 | deepspeed --include localhost:6,7 \ 11 | --master_port=24374 training_debug.py \ 12 | --version="$llava_path" \ 13 | --dataset_dir="$dataset_path" \ 14 | --sam_masks_dir="$sam_masks_path" \ 15 | --vision_pretrained="$vision_path" \ 16 | --dataset="sem_seg||refer_seg||reason_seg" \ 17 | --sample_rates="9,3,1" \ 18 | --exp_name="20epoch" \ 19 | --log_base_dir="$log_path" \ 20 | --lr=0.0001 \ 21 | --epochs=20 \ 22 | --batch_size=1 \ 23 | -------------------------------------------------------------------------------- /scripts/train_zero_shot.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | llava_path="./pretrained_weights/LLaVA-lightning-7B-v1/" 5 | vision_path="./pretrained_weights/SAM/sam_vit_h_4b8939.pth" 6 | dataset_path="./lisa_dataset" 7 | sam_masks_path="./processed_data" 8 | log_path="./runs" 9 | 10 | deepspeed --include localhost:0,1 \ 11 | --master_port=24371 training_debug.py \ 12 | --version="$llava_path" \ 13 | --dataset_dir="$dataset_path" \ 14 | --sam_masks_dir="$sam_masks_path" \ 15 | --vision_pretrained="$vision_path" \ 16 | --dataset="sem_seg||refer_seg" \ 17 | --sample_rates="9,3" \ 18 | --exp_name="zeroshot" \ 19 | --log_base_dir="$log_path" \ 20 | --lr=0.0001 \ 21 | --epochs=10 \ 22 | --batch_size=1 \ 23 | -------------------------------------------------------------------------------- /scripts/validate_llmseg40k.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | llava_path="./pretrained_weights/LLaVA-lightning-7B-v1/" 5 | vision_path="./pretrained_weights/SAM/sam_vit_h_4b8939.pth" 6 | dataset_path="./lisa_dataset" 7 | sam_masks_path="./processed_data" 8 | log_path="./runs" 9 | 10 | deepspeed --include localhost:0,1 \ 11 | --master_port=24353 validate_llmseg.py \ 12 | --version="$llava_path" \ 13 | --dataset_dir="$dataset_path" \ 14 | --vision_pretrained="$vision_path" \ 15 | --dataset="reason_seg" \ 16 | --sample_rates="1" \ 17 | --exp_name="finetune_llmseg" \ 18 | --log_base_dir="$log_path" \ 19 | --batch_size=1 \ 20 | --eval_only \ 21 | --visualize \ -------------------------------------------------------------------------------- /scripts/validate_visualize.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | 4 | llava_path="./pretrained_weights/LLaVA-lightning-7B-v1/" 5 | vision_path="./pretrained_weights/SAM/sam_vit_h_4b8939.pth" 6 | dataset_path="./lisa_dataset" 7 | sam_masks_path="./processed_data" 8 | log_path="./runs" 9 | 10 | 11 | deepspeed --include localhost:2,3 \ 12 | --master_port=24353 training_debug.py \ 13 | --version="$llava_path" \ 14 | --dataset_dir="$dataset_path" \ 15 | --sam_masks_dir="$sam_masks_path" \ 16 | --vision_pretrained="$vision_path" \ 17 | --dataset="reason_seg" \ 18 | --sample_rates="1" \ 19 | --exp_name="10epoch" \ 20 | --log_base_dir="$log_path" \ 21 | --batch_size=1 \ 22 | --eval_only \ 23 | --val_dataset="ReasonSeg|val" \ 24 | --visualize -------------------------------------------------------------------------------- /utils/ade20k_classes.json: -------------------------------------------------------------------------------- 1 | [ 2 | "wall", "building", "sky", "floor", "tree", "ceiling", "road", 3 | "bed", "windowpane", "grass", "cabinet", "sidewalk", 4 | "person", "earth", "door", "table", "mountain", "plant", 5 | "curtain", "chair", "car", "water", "painting", "sofa", 6 | "shelf", "house", "sea", "mirror", "rug", "field", "armchair", 7 | "seat", "fence", "desk", "rock", "wardrobe", "lamp", 8 | "bathtub", "railing", "cushion", "base", "box", "column", 9 | "signboard", "chest of drawers", "counter", "sand", "sink", 10 | "skyscraper", "fireplace", "refrigerator", "grandstand", 11 | "path", "stairs", "runway", "case", "pool table", "pillow", 12 | "screen door", "stairway", "river", "bridge", "bookcase", 13 | "blind", "coffee table", "toilet", "flower", "book", "hill", 14 | "bench", "countertop", "stove", "palm", "kitchen island", 15 | "computer", "swivel chair", "boat", "bar", "arcade machine", 16 | "hovel", "bus", "towel", "light", "truck", "tower", 17 | "chandelier", "awning", "streetlight", "booth", 18 | "television receiver", "airplane", "dirt track", "apparel", 19 | "pole", "land", "bannister", "escalator", "ottoman", "bottle", 20 | "buffet", "poster", "stage", "van", "ship", "fountain", 21 | "conveyer belt", "canopy", "washer", "plaything", 22 | "swimming pool", "stool", "barrel", "basket", "waterfall", 23 | "tent", "bag", "minibike", "cradle", "oven", "ball", "food", 24 | "step", "tank", "trade name", "microwave", "pot", "animal", 25 | "bicycle", "lake", "dishwasher", "screen", "blanket", 26 | "sculpture", "hood", "sconce", "vase", "traffic light", 27 | "tray", "ashcan", "fan", "pier", "crt screen", "plate", 28 | "monitor", "bulletin board", "shower", "radiator", "glass", 29 | "clock", "flag" 30 | ] -------------------------------------------------------------------------------- /utils/cocostuff_classes.txt: -------------------------------------------------------------------------------- 1 | 0: unlabeled 2 | 1: person 3 | 2: bicycle 4 | 3: car 5 | 4: motorcycle 6 | 5: airplane 7 | 6: bus 8 | 7: train 9 | 8: truck 10 | 9: boat 11 | 10: traffic light 12 | 11: fire hydrant 13 | 12: street sign 14 | 13: stop sign 15 | 14: parking meter 16 | 15: bench 17 | 16: bird 18 | 17: cat 19 | 18: dog 20 | 19: horse 21 | 20: sheep 22 | 21: cow 23 | 22: elephant 24 | 23: bear 25 | 24: zebra 26 | 25: giraffe 27 | 26: hat 28 | 27: backpack 29 | 28: umbrella 30 | 29: shoe 31 | 30: eye glasses 32 | 31: handbag 33 | 32: tie 34 | 33: suitcase 35 | 34: frisbee 36 | 35: skis 37 | 36: snowboard 38 | 37: sports ball 39 | 38: kite 40 | 39: baseball bat 41 | 40: baseball glove 42 | 41: skateboard 43 | 42: surfboard 44 | 43: tennis racket 45 | 44: bottle 46 | 45: 
plate 47 | 46: wine glass 48 | 47: cup 49 | 48: fork 50 | 49: knife 51 | 50: spoon 52 | 51: bowl 53 | 52: banana 54 | 53: apple 55 | 54: sandwich 56 | 55: orange 57 | 56: broccoli 58 | 57: carrot 59 | 58: hot dog 60 | 59: pizza 61 | 60: donut 62 | 61: cake 63 | 62: chair 64 | 63: couch 65 | 64: potted plant 66 | 65: bed 67 | 66: mirror 68 | 67: dining table 69 | 68: window 70 | 69: desk 71 | 70: toilet 72 | 71: door 73 | 72: tv 74 | 73: laptop 75 | 74: mouse 76 | 75: remote 77 | 76: keyboard 78 | 77: cell phone 79 | 78: microwave 80 | 79: oven 81 | 80: toaster 82 | 81: sink 83 | 82: refrigerator 84 | 83: blender 85 | 84: book 86 | 85: clock 87 | 86: vase 88 | 87: scissors 89 | 88: teddy bear 90 | 89: hair drier 91 | 90: toothbrush 92 | 91: hair brush 93 | 92: banner 94 | 93: blanket 95 | 94: branch 96 | 95: bridge 97 | 96: building-other 98 | 97: bush 99 | 98: cabinet 100 | 99: cage 101 | 100: cardboard 102 | 101: carpet 103 | 102: ceiling-other 104 | 103: ceiling-tile 105 | 104: cloth 106 | 105: clothes 107 | 106: clouds 108 | 107: counter 109 | 108: cupboard 110 | 109: curtain 111 | 110: desk-stuff 112 | 111: dirt 113 | 112: door-stuff 114 | 113: fence 115 | 114: floor-marble 116 | 115: floor-other 117 | 116: floor-stone 118 | 117: floor-tile 119 | 118: floor-wood 120 | 119: flower 121 | 120: fog 122 | 121: food-other 123 | 122: fruit 124 | 123: furniture-other 125 | 124: grass 126 | 125: gravel 127 | 126: ground-other 128 | 127: hill 129 | 128: house 130 | 129: leaves 131 | 130: light 132 | 131: mat 133 | 132: metal 134 | 133: mirror-stuff 135 | 134: moss 136 | 135: mountain 137 | 136: mud 138 | 137: napkin 139 | 138: net 140 | 139: paper 141 | 140: pavement 142 | 141: pillow 143 | 142: plant-other 144 | 143: plastic 145 | 144: platform 146 | 145: playingfield 147 | 146: railing 148 | 147: railroad 149 | 148: river 150 | 149: road 151 | 150: rock 152 | 151: roof 153 | 152: rug 154 | 153: salad 155 | 154: sand 156 | 155: sea 157 | 156: shelf 158 | 157: sky 159 | 158: skyscraper 160 | 159: snow 161 | 160: solid-other 162 | 161: stairs 163 | 162: stone 164 | 163: straw 165 | 164: structural-other 166 | 165: table 167 | 166: tent 168 | 167: textile-other 169 | 168: towel 170 | 169: tree 171 | 170: vegetable 172 | 171: wall-brick 173 | 172: wall-concrete 174 | 173: wall-other 175 | 174: wall-panel 176 | 175: wall-stone 177 | 176: wall-tile 178 | 177: wall-wood 179 | 178: water-other 180 | 179: waterdrops 181 | 180: window-blind 182 | 181: window-other 183 | 182: wood 184 | -------------------------------------------------------------------------------- /utils/data_processing.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | 9 | def get_mask_from_json(json_path, img): 10 | try: 11 | with open(json_path, "r") as r: 12 | anno = json.loads(r.read()) 13 | except: 14 | with open(json_path, "r", encoding="cp1252") as r: 15 | anno = json.loads(r.read()) 16 | 17 | inform = anno["shapes"] 18 | comments = anno["text"] 19 | is_sentence = anno["is_sentence"] 20 | 21 | height, width = img.shape[:2] 22 | 23 | ### sort polies by area 24 | area_list = [] 25 | valid_poly_list = [] 26 | for i in inform: 27 | label_id = i["label"] 28 | points = i["points"] 29 | if "flag" == label_id.lower(): ## meaningless deprecated annotations 30 | continue 31 | 32 | tmp_mask = np.zeros((height, width), dtype=np.uint8) 33 | cv2.polylines(tmp_mask, np.array([points], dtype=np.int32), True, 1, 1) 34 
| cv2.fillPoly(tmp_mask, np.array([points], dtype=np.int32), 1) 35 | tmp_area = tmp_mask.sum() 36 | 37 | area_list.append(tmp_area) 38 | valid_poly_list.append(i) 39 | 40 | ### ground-truth mask 41 | sort_index = np.argsort(area_list)[::-1].astype(np.int32) 42 | sort_index = list(sort_index) 43 | sort_inform = [] 44 | for s_idx in sort_index: 45 | sort_inform.append(valid_poly_list[s_idx]) 46 | 47 | mask = np.zeros((height, width), dtype=np.uint8) 48 | for i in sort_inform: 49 | label_id = i["label"] 50 | points = i["points"] 51 | 52 | if "ignore" in label_id.lower(): 53 | label_value = 255 # ignored during evaluation 54 | else: 55 | label_value = 1 # target 56 | 57 | cv2.polylines(mask, np.array([points], dtype=np.int32), True, label_value, 1) 58 | cv2.fillPoly(mask, np.array([points], dtype=np.int32), label_value) 59 | 60 | return mask, comments, is_sentence 61 | 62 | 63 | if __name__ == "__main__": 64 | data_dir = "./train" 65 | vis_dir = "./vis" 66 | 67 | if not os.path.exists(vis_dir): 68 | os.makedirs(vis_dir) 69 | 70 | json_path_list = sorted(glob.glob(data_dir + "/*.json")) 71 | for json_path in json_path_list: 72 | img_path = json_path.replace(".json", ".jpg") 73 | img = cv2.imread(img_path)[:, :, ::-1] 74 | 75 | # In generated mask, value 1 denotes valid target region, and value 255 stands for region ignored during evaluaiton. 76 | mask, comments, is_sentence = get_mask_from_json(json_path, img) 77 | 78 | ## visualization. Green for target, and red for ignore. 79 | valid_mask = (mask == 1).astype(np.float32)[:, :, None] 80 | ignore_mask = (mask == 255).astype(np.float32)[:, :, None] 81 | vis_img = img * (1 - valid_mask) * (1 - ignore_mask) + ( 82 | (np.array([0, 255, 0]) * 0.6 + img * 0.4) * valid_mask 83 | + (np.array([255, 0, 0]) * 0.6 + img * 0.4) * ignore_mask 84 | ) 85 | vis_img = np.concatenate([img, vis_img], 1) 86 | vis_path = os.path.join( 87 | vis_dir, json_path.split("/")[-1].replace(".json", ".jpg") 88 | ) 89 | cv2.imwrite(vis_path, vis_img[:, :, ::-1]) 90 | print("Visualization has been saved to: ", vis_path) 91 | -------------------------------------------------------------------------------- /utils/grefcoco.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import copy 3 | import io 4 | import logging 5 | import os 6 | import random 7 | 8 | import numpy as np 9 | import pycocotools.mask as mask_util 10 | from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes 11 | from detectron2.utils.file_io import PathManager 12 | from fvcore.common.timer import Timer 13 | from PIL import Image 14 | 15 | """ 16 | This file contains functions to parse RefCOCO-format annotations into dicts in "Detectron2 format". 
17 | """ 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | __all__ = ["load_refcoco_json"] 23 | 24 | 25 | def load_grefcoco_json( 26 | refer_root, 27 | dataset_name, 28 | splitby, 29 | split, 30 | image_root, 31 | extra_annotation_keys=None, 32 | extra_refer_keys=None, 33 | ): 34 | if dataset_name == "refcocop": 35 | dataset_name = "refcoco+" 36 | if dataset_name == "refcoco" or dataset_name == "refcoco+": 37 | splitby == "unc" 38 | if dataset_name == "refcocog": 39 | assert splitby == "umd" or splitby == "google" 40 | 41 | dataset_id = "_".join([dataset_name, splitby, split]) 42 | 43 | from .grefer import G_REFER 44 | 45 | logger.info("Loading dataset {} ({}-{}) ...".format(dataset_name, splitby, split)) 46 | logger.info("Refcoco root: {}".format(refer_root)) 47 | timer = Timer() 48 | refer_root = PathManager.get_local_path(refer_root) 49 | with contextlib.redirect_stdout(io.StringIO()): 50 | refer_api = G_REFER(data_root=refer_root, dataset=dataset_name, splitBy=splitby) 51 | if timer.seconds() > 1: 52 | logger.info( 53 | "Loading {} takes {:.2f} seconds.".format(dataset_id, timer.seconds()) 54 | ) 55 | 56 | ref_ids = refer_api.getRefIds(split=split) 57 | img_ids = refer_api.getImgIds(ref_ids) 58 | refs = refer_api.loadRefs(ref_ids) 59 | imgs = [refer_api.loadImgs(ref["image_id"])[0] for ref in refs] 60 | anns = [refer_api.loadAnns(ref["ann_id"]) for ref in refs] 61 | imgs_refs_anns = list(zip(imgs, refs, anns)) 62 | 63 | logger.info( 64 | "Loaded {} images, {} referring object sets in G_RefCOCO format from {}".format( 65 | len(img_ids), len(ref_ids), dataset_id 66 | ) 67 | ) 68 | 69 | dataset_dicts = [] 70 | 71 | ann_keys = ["iscrowd", "bbox", "category_id"] + (extra_annotation_keys or []) 72 | ref_keys = ["raw", "sent_id"] + (extra_refer_keys or []) 73 | 74 | ann_lib = {} 75 | 76 | NT_count = 0 77 | MT_count = 0 78 | 79 | for img_dict, ref_dict, anno_dicts in imgs_refs_anns: 80 | record = {} 81 | record["source"] = "grefcoco" 82 | record["file_name"] = os.path.join(image_root, img_dict["file_name"]) 83 | record["height"] = img_dict["height"] 84 | record["width"] = img_dict["width"] 85 | image_id = record["image_id"] = img_dict["id"] 86 | 87 | # Check that information of image, ann and ref match each other 88 | # This fails only when the data parsing logic or the annotation file is buggy. 
89 | assert ref_dict["image_id"] == image_id 90 | assert ref_dict["split"] == split 91 | if not isinstance(ref_dict["ann_id"], list): 92 | ref_dict["ann_id"] = [ref_dict["ann_id"]] 93 | 94 | # No target samples 95 | if None in anno_dicts: 96 | assert anno_dicts == [None] 97 | assert ref_dict["ann_id"] == [-1] 98 | record["empty"] = True 99 | obj = {key: None for key in ann_keys if key in ann_keys} 100 | obj["bbox_mode"] = BoxMode.XYWH_ABS 101 | obj["empty"] = True 102 | obj = [obj] 103 | 104 | # Multi target samples 105 | else: 106 | record["empty"] = False 107 | obj = [] 108 | for anno_dict in anno_dicts: 109 | ann_id = anno_dict["id"] 110 | if anno_dict["iscrowd"]: 111 | continue 112 | assert anno_dict["image_id"] == image_id 113 | assert ann_id in ref_dict["ann_id"] 114 | 115 | if ann_id in ann_lib: 116 | ann = ann_lib[ann_id] 117 | else: 118 | ann = {key: anno_dict[key] for key in ann_keys if key in anno_dict} 119 | ann["bbox_mode"] = BoxMode.XYWH_ABS 120 | ann["empty"] = False 121 | 122 | segm = anno_dict.get("segmentation", None) 123 | assert segm # either list[list[float]] or dict(RLE) 124 | if isinstance(segm, dict): 125 | if isinstance(segm["counts"], list): 126 | # convert to compressed RLE 127 | segm = mask_util.frPyObjects(segm, *segm["size"]) 128 | else: 129 | # filter out invalid polygons (< 3 points) 130 | segm = [ 131 | poly 132 | for poly in segm 133 | if len(poly) % 2 == 0 and len(poly) >= 6 134 | ] 135 | if len(segm) == 0: 136 | num_instances_without_valid_segmentation += 1 137 | continue # ignore this instance 138 | ann["segmentation"] = segm 139 | ann_lib[ann_id] = ann 140 | 141 | obj.append(ann) 142 | 143 | record["annotations"] = obj 144 | 145 | # Process referring expressions 146 | sents = ref_dict["sentences"] 147 | for sent in sents: 148 | ref_record = record.copy() 149 | ref = {key: sent[key] for key in ref_keys if key in sent} 150 | ref["ref_id"] = ref_dict["ref_id"] 151 | ref_record["sentence"] = ref 152 | dataset_dicts.append(ref_record) 153 | # if ref_record['empty']: 154 | # NT_count += 1 155 | # else: 156 | # MT_count += 1 157 | 158 | # logger.info("NT samples: %d, MT samples: %d", NT_count, MT_count) 159 | 160 | # Debug mode 161 | # return dataset_dicts[:100] 162 | 163 | return dataset_dicts 164 | 165 | 166 | if __name__ == "__main__": 167 | """ 168 | Test the COCO json dataset loader. 
169 | 170 | Usage: 171 | python -m detectron2.data.datasets.coco \ 172 | path/to/json path/to/image_root dataset_name 173 | 174 | "dataset_name" can be "coco_2014_minival_100", or other 175 | pre-registered ones 176 | """ 177 | import sys 178 | 179 | import detectron2.data.datasets # noqa # add pre-defined metadata 180 | from detectron2.utils.logger import setup_logger 181 | from detectron2.utils.visualizer import Visualizer 182 | 183 | REFCOCO_PATH = "/mnt/lustre/hhding/code/ReLA/datasets" 184 | COCO_TRAIN_2014_IMAGE_ROOT = "/mnt/lustre/hhding/code/ReLA/datasets/images" 185 | REFCOCO_DATASET = "grefcoco" 186 | REFCOCO_SPLITBY = "unc" 187 | REFCOCO_SPLIT = "train" 188 | 189 | logger = setup_logger(name=__name__) 190 | 191 | dicts = load_grefcoco_json( 192 | REFCOCO_PATH, 193 | REFCOCO_DATASET, 194 | REFCOCO_SPLITBY, 195 | REFCOCO_SPLIT, 196 | COCO_TRAIN_2014_IMAGE_ROOT, 197 | ) 198 | logger.info("Done loading {} samples.".format(len(dicts))) 199 | -------------------------------------------------------------------------------- /utils/sam_mask_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import List, Dict 4 | 5 | import numpy as np 6 | import torch 7 | import pycocotools.mask as mask_util 8 | from skimage.transform import resize 9 | import cv2 10 | import time 11 | class SAM_Mask_Reader: 12 | 13 | def __init__(self, json_dir) -> None: 14 | self.json_dir = json_dir 15 | 16 | print("reading sam mask json: ", json_dir) 17 | start_time = time.time() 18 | self.mask_list = self.read_mask_json(json_dir) 19 | end_time = time.time() 20 | print(f"read sam mask json takes {end_time - start_time} seconds") 21 | 22 | print("building sam mask index") 23 | start_time = time.time() 24 | self.sam_mask_index = self.build_sam_mask_index() 25 | end_time = time.time() 26 | print(f"build sam mask index takes {end_time - start_time} seconds") 27 | 28 | print("sam_mask_list: ", len(self.mask_list)) 29 | print("sam_mask_index: ", len(self.sam_mask_index)) 30 | 31 | def read_mask_json(self, path: str): 32 | with open(path, "r") as f: 33 | mask_list = json.load(f) 34 | return mask_list 35 | 36 | def build_sam_mask_index(self): 37 | sam_mask_index = {} 38 | for i, sample in enumerate(self.mask_list): 39 | sample_name = sample["image"] 40 | sam_mask_index[sample_name] = i 41 | 42 | return sam_mask_index 43 | 44 | def get_sam_mask_index(self, image_name: str): 45 | if image_name not in self.sam_mask_index: 46 | raise ValueError(f"image_name: {image_name} not in sam_mask_index") 47 | return self.sam_mask_index[image_name] 48 | 49 | def preprocess_mask(self, masks: np.ndarray): 50 | # masks: (H, W, K) 51 | 52 | # convert mask to float 53 | masks = masks.astype(np.float64) 54 | # padding to square 55 | h, w, _ = masks.shape 56 | padh = max(h, w) - h 57 | padw = max(h, w) - w 58 | masks = np.pad(masks, ((0, padh), (0, padw), (0, 0)), mode="constant", constant_values=0) 59 | 60 | assert masks.shape[0] == masks.shape[1] 61 | assert masks.shape[0] == max(h, w) 62 | 63 | # # resize to 64x64 64 | # mask = resize(mask, (64, 64), anti_aliasing=True) 65 | 66 | return masks 67 | 68 | 69 | def extract_sam_segs(self, image_name: str): 70 | 71 | index = self.get_sam_mask_index(image_name) 72 | sam_masks = self.mask_list[index] 73 | 74 | seg_list = [] 75 | seg_np_list_large = [] 76 | # extract binary seg 77 | 78 | # sort sam_masks by area 79 | masks = sam_masks['masks'] 80 | masks_sorted = sorted(masks, key=lambda x: x['area'], reverse=True) 81 
| 82 | if len(masks_sorted) > 50: 83 | masks_sorted = masks_sorted[:50] 84 | # print("Warning: too many sam masks, only use the top 50, image_name: ", image_name) 85 | 86 | rle_segs = [mask['segmentation'] for mask in masks_sorted] 87 | segs_origin = mask_util.decode(rle_segs) # (H, W, K) 88 | 89 | segs_square = self.preprocess_mask(segs_origin) 90 | 91 | bbox = [mask['bbox'] for mask in masks_sorted] 92 | 93 | # for mask in sam_masks['masks']: 94 | # seg_rle = mask['segmentation'] 95 | # # decode coco rle format 96 | # seg = mask_util.decode(seg_rle) 97 | # seg_np_list_large.append(seg) 98 | # # padding to square and resize to 64x64 99 | # seg_small = self.preprocess_mask(seg) 100 | # # to tensor 101 | # seg_small = torch.from_numpy(seg_small).float() 102 | 103 | # seg_list.append(seg_small.unsqueeze(0)) 104 | 105 | # # concat seg_list 106 | # segs_tensor_small = torch.cat(seg_list, dim=0) 107 | # return segs_square, segs 108 | 109 | return { 110 | "segs_square": segs_square, 111 | "segs_origin": segs_origin, 112 | "bbox": bbox 113 | } 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import numpy as np 4 | import torch 5 | import torch.distributed as dist 6 | import cv2 7 | from typing import List 8 | from skimage.transform import resize 9 | 10 | IGNORE_INDEX = -100 11 | IMAGE_TOKEN_INDEX = -200 12 | DEFAULT_IMAGE_TOKEN = "<image>" 13 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 14 | DEFAULT_IM_START_TOKEN = "<im_start>" 15 | DEFAULT_IM_END_TOKEN = "<im_end>" 16 | 17 | SHORT_QUESTION_LIST = [ 18 | DEFAULT_IMAGE_TOKEN + "\n" + "Can you segment the {class_name} in this image?", 19 | DEFAULT_IMAGE_TOKEN + "\n" + "Please segment the {class_name} in this image.", 20 | DEFAULT_IMAGE_TOKEN 21 | + "\n" 22 | + "What is {class_name} in this image? Please respond with segmentation mask.", 23 | DEFAULT_IMAGE_TOKEN 24 | + "\n" 25 | + "What is {class_name} in this image? 
Please output segmentation mask.", 26 | ] 27 | 28 | LONG_QUESTION_LIST = [ 29 | DEFAULT_IMAGE_TOKEN + "\n" + "{sent} Please respond with segmentation mask.", 30 | DEFAULT_IMAGE_TOKEN + "\n" + "{sent} Please output segmentation mask.", 31 | ] 32 | 33 | EXPLANATORY_QUESTION_LIST = [ 34 | "Please output segmentation mask and explain why.", 35 | "Please output segmentation mask and explain the reason.", 36 | "Please output segmentation mask and give some explaination.", 37 | ] 38 | 39 | ANSWER_LIST = [ 40 | "It is [SEG].", 41 | "Sure, [SEG].", 42 | "Sure, it is [SEG].", 43 | "Sure, the segmentation result is [SEG].", 44 | "[SEG].", 45 | ] 46 | 47 | 48 | class Summary(Enum): 49 | NONE = 0 50 | AVERAGE = 1 51 | SUM = 2 52 | COUNT = 3 53 | 54 | 55 | class AverageMeter(object): 56 | """Computes and stores the average and current value""" 57 | 58 | def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): 59 | self.name = name 60 | self.fmt = fmt 61 | self.summary_type = summary_type 62 | self.reset() 63 | 64 | def reset(self): 65 | self.val = 0 66 | self.avg = 0 67 | self.sum = 0 68 | self.count = 0 69 | 70 | def update(self, val, n=1): 71 | self.val = val 72 | self.sum += val * n 73 | self.count += n 74 | self.avg = self.sum / self.count 75 | 76 | def all_reduce(self): 77 | device = "cuda" if torch.cuda.is_available() else "cpu" 78 | if isinstance(self.sum, np.ndarray): 79 | total = torch.tensor( 80 | self.sum.tolist() 81 | + [ 82 | self.count, 83 | ], 84 | dtype=torch.float32, 85 | device=device, 86 | ) 87 | else: 88 | total = torch.tensor( 89 | [self.sum, self.count], dtype=torch.float32, device=device 90 | ) 91 | 92 | dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False) 93 | if total.shape[0] > 2: 94 | self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item() 95 | else: 96 | self.sum, self.count = total.tolist() 97 | self.avg = self.sum / (self.count + 1e-5) 98 | 99 | def __str__(self): 100 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 101 | return fmtstr.format(**self.__dict__) 102 | 103 | def summary(self): 104 | fmtstr = "" 105 | if self.summary_type is Summary.NONE: 106 | fmtstr = "" 107 | elif self.summary_type is Summary.AVERAGE: 108 | fmtstr = "{name} {avg:.3f}" 109 | elif self.summary_type is Summary.SUM: 110 | fmtstr = "{name} {sum:.3f}" 111 | elif self.summary_type is Summary.COUNT: 112 | fmtstr = "{name} {count:.3f}" 113 | else: 114 | raise ValueError("invalid summary type %r" % self.summary_type) 115 | 116 | return fmtstr.format(**self.__dict__) 117 | 118 | 119 | def intersectionAndUnionGPU(output, target, K, ignore_index=255): 120 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 
121 | 122 | assert output.dim() in [1, 2, 3] 123 | assert output.shape == target.shape 124 | output = output.view(-1) 125 | target = target.view(-1) 126 | output[target == ignore_index] = ignore_index 127 | intersection = output[output == target] 128 | area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1) 129 | area_output = torch.histc(output, bins=K, min=0, max=K - 1) 130 | area_target = torch.histc(target, bins=K, min=0, max=K - 1) 131 | area_union = area_output + area_target - area_intersection 132 | return area_intersection, area_union, area_target 133 | 134 | 135 | class ProgressMeter(object): 136 | def __init__(self, num_batches, meters, prefix=""): 137 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 138 | self.meters = meters 139 | self.prefix = prefix 140 | 141 | def display(self, batch): 142 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 143 | entries += [str(meter) for meter in self.meters] 144 | print("\t".join(entries)) 145 | 146 | def display_summary(self): 147 | entries = [" *"] 148 | entries += [meter.summary() for meter in self.meters] 149 | print(" ".join(entries)) 150 | 151 | def _get_batch_fmtstr(self, num_batches): 152 | num_digits = len(str(num_batches // 1)) 153 | fmt = "{:" + str(num_digits) + "d}" 154 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 155 | 156 | 157 | def dict_to_cuda(input_dict, torch_dtype=torch.bfloat16): 158 | for k, v in input_dict.items(): 159 | if isinstance(input_dict[k], torch.Tensor): 160 | input_dict[k] = v.cuda(non_blocking=True) 161 | if k == "images" or k == "images_clip": 162 | input_dict[k] = input_dict[k].to(dtype=torch_dtype) 163 | elif ( 164 | isinstance(input_dict[k], list) 165 | and len(input_dict[k]) > 0 166 | and isinstance(input_dict[k][0], torch.Tensor) 167 | ): 168 | input_dict[k] = [ele.cuda(non_blocking=True) for ele in v] 169 | if k == "sam_segs_list": 170 | input_dict[k] = [ele.to(dtype=torch_dtype) for ele in input_dict[k]] 171 | return input_dict 172 | 173 | 174 | def compute_iop(seg: np.ndarray, gt: np.ndarray): 175 | # seg and gt are both binary masks 176 | assert seg.shape == gt.shape 177 | 178 | if seg.max() > 1 or gt.max() > 1: 179 | raise ValueError("seg and gt should be binary masks") 180 | 181 | # comput iou with all connected components in gt 182 | intersection = np.logical_and(seg, gt) 183 | union = np.logical_or(seg, gt) 184 | # iou = np.sum(intersection) / np.sum(union) 185 | iop = np.sum(intersection) / np.sum(seg) 186 | 187 | max_iop = iop 188 | 189 | return max_iop 190 | 191 | 192 | def compute_iou(seg: np.ndarray, gt: np.ndarray): 193 | # seg and gt are both binary masks 194 | assert seg.shape == gt.shape 195 | 196 | if seg.max() > 1 or gt.max() > 1: 197 | raise ValueError("seg and gt should be binary masks") 198 | 199 | # comput iou with all connected components in gt 200 | intersection = np.logical_and(seg, gt) 201 | union = np.logical_or(seg, gt) 202 | iou = np.sum(intersection) / np.sum(union) 203 | 204 | max_iou = iou 205 | 206 | # # Find connected components in the ground truth mask 207 | # num_labels, labels = cv2.connectedComponents(gt) 208 | 209 | 210 | # # Compute IoU for each component 211 | # for i in range(1, num_labels): # Start from 1 to ignore background 212 | # component_mask = (labels == i).astype(np.uint8) 213 | # # compute the iou between seg and component_mask 214 | # intersection = np.logical_and(seg, component_mask) 215 | # union = np.logical_or(seg, component_mask) 216 | # iou = np.sum(intersection) / np.sum(union) 217 | 218 | # 
max_iou = max(max_iou, iou) 219 | 220 | return max_iou 221 | 222 | # def compute_iou(seg: np.ndarray, gt: np.ndarray): 223 | # # seg and gt are both binary masks 224 | # assert seg.shape == gt.shape 225 | # if seg.max() > 1 or gt.max() > 1: 226 | # raise ValueError("seg and gt should be binary masks") 227 | 228 | # intersection = np.logical_and(seg, gt) 229 | # union = np.logical_or(seg, gt) 230 | # iou = np.sum(intersection) / np.sum(union) 231 | 232 | # return iou 233 | 234 | def compute_all_iou(segs: List[np.ndarray], gt: np.ndarray): 235 | # compute the iou between segs and gt 236 | # segs: list of (H, W) : may be resized to 1024 if the original size is too large 237 | # gt: (H', W') original size of the image 238 | H, W, K = segs.shape 239 | 240 | gt = resize(gt, (H, W), anti_aliasing=False, preserve_range=True, order=0) 241 | 242 | ious = [] 243 | for i in range(K): 244 | seg_i = segs[:, :, i] 245 | 246 | assert seg_i.shape == gt.shape 247 | 248 | iou = compute_iou(seg_i, gt) 249 | 250 | ious.append(iou) 251 | 252 | return np.array(ious) 253 | 254 | 255 | def compute_all_iop(segs: List[np.ndarray], gt: np.ndarray): 256 | # compute the iop between segs and gt 257 | # segs: list of (H, W) : may be resized to 1024 if the original size is too large 258 | # gt: (H', W') original size of the image 259 | H, W, K = segs.shape 260 | 261 | gt = resize(gt, (H, W), anti_aliasing=False, preserve_range=True, order=0) 262 | 263 | iops = [] 264 | for i in range(K): 265 | seg_i = segs[:, :, i] 266 | 267 | assert seg_i.shape == gt.shape 268 | 269 | iop = compute_iop(seg_i, gt) 270 | 271 | iops.append(iop) 272 | 273 | return np.array(iops) -------------------------------------------------------------------------------- /utils/vqa_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import time 5 | 6 | import cv2 7 | import torch 8 | import torch.nn.functional as F 9 | from transformers import CLIPImageProcessor 10 | 11 | from model.llava import conversation as conversation_lib 12 | from model.segment_anything.utils.transforms import ResizeLongestSide 13 | 14 | from .utils import DEFAULT_IMAGE_TOKEN 15 | from .sam_mask_reader import SAM_Mask_Reader 16 | 17 | def preprocess_multimodal(source, mm_use_im_start_end): 18 | for sentence in source: 19 | if DEFAULT_IMAGE_TOKEN in sentence["value"]: 20 | sentence["value"] = ( 21 | sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip() 22 | ) 23 | sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"] 24 | sentence["value"] = sentence["value"].strip() 25 | if "mmtag" in conversation_lib.default_conversation.version: 26 | sentence["value"] = sentence["value"].replace( 27 | DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>" 28 | ) 29 | return source 30 | 31 | 32 | class VQADataset(torch.utils.data.Dataset): 33 | pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) 34 | pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) 35 | img_size = 896 36 | ignore_label = 255 37 | 38 | def __init__( 39 | self, 40 | base_image_dir, 41 | tokenizer, 42 | vision_tower, 43 | samples_per_epoch=500 * 8 * 2 * 10, 44 | precision: str = "fp32", 45 | image_size: int = 224, 46 | num_classes_per_sample: int = 3, 47 | exclude_val=False, 48 | vqa_data="llava_instruct_150k", 49 | coco2017_sam_mask_helper=None 50 | ): 51 | self.exclude_val = exclude_val 52 | self.samples_per_epoch = samples_per_epoch 53 | self.num_classes_per_sample = num_classes_per_sample 54 
| 55 | self.base_image_dir = base_image_dir 56 | self.image_size = image_size 57 | self.tokenizer = tokenizer 58 | self.precision = precision 59 | self.transform = ResizeLongestSide(image_size) 60 | self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower) 61 | 62 | DATA_DIR = os.path.join(base_image_dir, "llava_dataset") 63 | self.vqa_image_root = os.path.join(base_image_dir, "coco/train2017") 64 | with open(os.path.join(DATA_DIR, "{}.json".format(vqa_data))) as f: 65 | vqa_data = json.load(f) 66 | self.vqa_data = vqa_data 67 | 68 | print("vqa_data: ", len(self.vqa_data)) 69 | 70 | # sam mask reader 71 | self.sam_mask_helper = coco2017_sam_mask_helper 72 | assert self.sam_mask_helper is not None 73 | 74 | 75 | def __len__(self): 76 | return self.samples_per_epoch 77 | 78 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 79 | """Normalize pixel values and pad to a square input.""" 80 | # Normalize colors 81 | x = (x - self.pixel_mean) / self.pixel_std 82 | 83 | # Pad 84 | h, w = x.shape[-2:] 85 | padh = self.img_size - h 86 | padw = self.img_size - w 87 | x = F.pad(x, (0, padw, 0, padh)) 88 | return x 89 | 90 | def __getitem__(self, idx): 91 | idx = random.randint(0, len(self.vqa_data) - 1) 92 | item = self.vqa_data[idx] 93 | image_path = os.path.join(self.vqa_image_root, item["image"]) 94 | image = cv2.imread(image_path) 95 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 96 | ori_size = image.shape[:2] 97 | image_clip = self.clip_image_processor.preprocess(image, return_tensors="pt")[ 98 | "pixel_values" 99 | ][ 100 | 0 101 | ] # preprocess image for clip 102 | 103 | image = self.transform.apply_image(image) # preprocess image for sam 104 | resize = image.shape[:2] 105 | 106 | # read sam seg 107 | segs_dict = self.sam_mask_helper.extract_sam_segs(item["image"]) 108 | 109 | segs_square = segs_dict["segs_square"] 110 | segs_origin = segs_dict["segs_origin"] 111 | 112 | conv = conversation_lib.default_conversation.copy() 113 | source = item["conversations"] 114 | source = preprocess_multimodal( 115 | source, 116 | mm_use_im_start_end=conv.sep_style == conversation_lib.SeparatorStyle.TWO, 117 | ) 118 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 119 | conversations = [] 120 | if roles[source[0]["from"]] != conv.roles[0]: 121 | # Skip the first one if it is not from human 122 | source = source[1:] 123 | conv.messages = [] 124 | for j, sentence in enumerate(source): 125 | role = roles[sentence["from"]] 126 | assert role == conv.roles[j % 2], f"{i}" 127 | conv.append_message(role, sentence["value"]) 128 | conversations.append(conv.get_prompt()) 129 | 130 | questions = conversations 131 | sampled_classes = conversations 132 | 133 | image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous()) 134 | 135 | masks = torch.rand(0, *ori_size) 136 | label = torch.ones(ori_size) * self.ignore_label 137 | ious = torch.rand((0, segs_origin.shape[-1])) 138 | iops = torch.rand((0, segs_origin.shape[-1])) 139 | 140 | # convert segs_square to tensor and resize to 64x64 141 | segs_square = torch.from_numpy(segs_square).permute(2, 0, 1).contiguous() # (K, H, W) 142 | segs = F.interpolate(segs_square.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False) # (1, K, 64, 64) 143 | segs = segs.squeeze(0) # (K, 64, 64) 144 | 145 | 146 | # return ( 147 | # image_path, 148 | # image, 149 | # image_clip, 150 | # conversations, 151 | # masks, 152 | # label, 153 | # resize, 154 | # questions, 155 | # sampled_classes, 156 | # segs, 157 | # ious, 158 | # ) 
159 | 160 | return { 161 | 'image_path': image_path, 162 | 'images': image, 163 | 'images_clip': image_clip, 164 | 'conversations': conversations, 165 | 'masks': masks, 166 | 'label': label, 167 | 'resize': resize, 168 | 'questions': questions, 169 | 'sampled_classes': sampled_classes, 170 | 'segs': segs, 171 | 'ious': ious, 172 | 'iops': iops, 173 | 'segs_origin': None, 174 | 'bbox': None, 175 | 'inference': False, 176 | } 177 | --------------------------------------------------------------------------------
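
As a quick end-to-end sanity check of the utilities above, the sketch below reads SAM Everything proposals with `SAM_Mask_Reader`, loads a ReasonSeg ground-truth mask with `get_mask_from_json`, and scores each proposal with `compute_all_iou` / `compute_all_iop` from `utils/utils.py`. This is an illustrative sketch only and is not shipped with the repository: the JSON and image paths are hypothetical placeholders, and the exact key expected by `extract_sam_segs` (basename vs. relative path) depends on how the `prepare_datasets` scripts stored the image names.

```python
# Minimal sketch (not part of the repo): score SAM Everything proposals against a
# ReasonSeg ground-truth mask using the helpers defined in utils/.
import cv2
import numpy as np

from utils.data_processing import get_mask_from_json
from utils.sam_mask_reader import SAM_Mask_Reader
from utils.utils import compute_all_iop, compute_all_iou

# Hypothetical paths -- point these at your own preprocessed SAM-mask json and ReasonSeg data.
sam_mask_json = "./processed_data/reason_seg_train_sam_masks.json"
image_path = "./lisa_dataset/reason_seg/ReasonSeg/train/example.jpg"
anno_path = image_path.replace(".jpg", ".json")

reader = SAM_Mask_Reader(sam_mask_json)

img = cv2.imread(image_path)[:, :, ::-1]  # BGR -> RGB, same convention as utils/data_processing.py
gt_mask, comments, is_sentence = get_mask_from_json(anno_path, img)
gt_binary = (gt_mask == 1).astype(np.uint8)  # keep target pixels only; drop the 255 "ignore" label

# The reader indexes masks by the image name stored during preprocessing; using the
# basename here is an assumption about that key format.
segs = reader.extract_sam_segs(image_path.split("/")[-1])
segs_origin = segs["segs_origin"]  # (H, W, K) binary proposals at original resolution

ious = compute_all_iou(segs_origin, gt_binary)  # IoU of each proposal against the GT mask
iops = compute_all_iop(segs_origin, gt_binary)  # intersection-over-prediction per proposal
best = int(ious.argmax())
print(f"best proposal {best}: IoU={ious[best]:.3f}, IoP={iops[best]:.3f}")
```

If the best proposal's IoU is consistently low across many images, the preprocessed masks and the images are most likely out of sync (wrong split, resized images, or a mismatched image-name key).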