├── Dataset.md ├── README.md ├── RUN_VideoGLaMM.md ├── Training.md ├── VideoGLaMM ├── .DS_Store ├── .gitignore ├── LICENSE ├── chat.py ├── eval_anet_entities_infer.py ├── eval_gcg_infer.py ├── eval_gcg_metrics.py ├── eval_grounding.py ├── eval_mevis.py ├── eval_referdavis_infer.py ├── eval_referdavis_metrics.py ├── gcg_data_gen │ ├── .DS_Store │ ├── anet_entities_gcg │ │ ├── 1_dev_anet_entities_for_gcg.py │ │ ├── 2_anet_entities_gcg_openai_refine.py │ │ └── 3_anet_entities_gcg_extract_masks.py │ ├── burst_ytvis_gcg │ │ ├── README.md │ │ ├── generate_annotations.py │ │ ├── generation.py │ │ ├── merge_b_y.py │ │ └── requirements.txt │ ├── dev_dataset_visualize.py │ ├── hcstvg_gcg │ │ ├── dev_hcstvg_2_gcg_captions.py │ │ └── dev_hcstvg_2_mask_gen.py │ ├── mevis_gcg │ │ └── dev_mevis_gcg.py │ ├── vidstg_gcg │ │ ├── dev_vidstg_gcg_captions.py │ │ └── dev_vidstg_gcg_mask_gen.py │ └── ytvos_gcg │ │ └── dev_ytvos_gcg.py ├── model │ ├── .DS_Store │ ├── VideoGLaMM.py │ ├── chatunivi │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── conversation.py │ │ ├── mm_utils.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── arch.py │ │ │ ├── builder.py │ │ │ ├── cluster.py │ │ │ ├── language_model │ │ │ │ └── llama.py │ │ │ └── multimodal_encoder │ │ │ │ ├── builder.py │ │ │ │ └── clip_encoder.py │ │ └── utils.py │ ├── llava │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── conversation.py │ │ ├── mm_utils.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── apply_delta.py │ │ │ ├── builder.py │ │ │ ├── consolidate.py │ │ │ ├── language_model │ │ │ │ ├── llava_llama.py │ │ │ │ ├── llava_mpt.py │ │ │ │ └── mpt │ │ │ │ │ ├── adapt_tokenizer.py │ │ │ │ │ ├── attention.py │ │ │ │ │ ├── blocks.py │ │ │ │ │ ├── configuration_mpt.py │ │ │ │ │ ├── custom_embedding.py │ │ │ │ │ ├── flash_attn_triton.py │ │ │ │ │ ├── hf_prefixlm_converter.py │ │ │ │ │ ├── meta_init_context.py │ │ │ │ │ ├── modeling_mpt.py │ │ │ │ │ ├── norm.py │ │ │ │ │ └── param_init_fns.py │ │ │ ├── llava_arch.py │ │ │ ├── make_delta.py │ │ │ ├── multimodal_encoder │ │ │ │ ├── builder.py │ │ │ │ └── clip_encoder.py │ │ │ └── utils.py │ │ ├── train │ │ │ ├── llama_flash_attn_monkey_patch.py │ │ │ ├── llava_trainer.py │ │ │ ├── train.py │ │ │ └── train_mem.py │ │ └── utils.py │ ├── segment_anything │ │ ├── __init__.py │ │ ├── automatic_mask_generator.py │ │ ├── build_sam.py │ │ ├── modeling │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── image_encoder.py │ │ │ ├── mask_decoder.py │ │ │ ├── prompt_encoder.py │ │ │ ├── sam.py │ │ │ └── transformer.py │ │ ├── predictor.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── amg.py │ │ │ ├── onnx.py │ │ │ └── transforms.py │ ├── segment_anything_2 │ │ ├── sam2 │ │ │ ├── __init__.py │ │ │ ├── automatic_mask_generator.py │ │ │ ├── build_sam.py │ │ │ ├── csrc │ │ │ │ └── connected_components.cu │ │ │ ├── modeling │ │ │ │ ├── __init__.py │ │ │ │ ├── backbones │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── hieradet.py │ │ │ │ │ ├── image_encoder.py │ │ │ │ │ └── utils.py │ │ │ │ ├── memory_attention.py │ │ │ │ ├── memory_encoder.py │ │ │ │ ├── position_encoding.py │ │ │ │ ├── sam │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── mask_decoder.py │ │ │ │ │ ├── prompt_encoder.py │ │ │ │ │ └── transformer.py │ │ │ │ ├── sam2_base.py │ │ │ │ └── sam2_utils.py │ │ │ ├── sam2_image_predictor.py │ │ │ ├── sam2_video_predictor.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── amg.py │ │ │ │ ├── misc.py │ │ │ │ └── transforms.py │ │ ├── sam2_configs │ │ │ ├── __init__.py │ │ │ ├── sam2_hiera_b+.yaml │ │ │ ├── sam2_hiera_l.yaml │ │ │ ├── 
sam2_hiera_s.yaml │ │ │ └── sam2_hiera_t.yaml │ │ └── setup.py │ └── videogpt_plus │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── conversation.py │ │ ├── mm_utils.py │ │ └── model │ │ ├── __init__.py │ │ ├── arch.py │ │ ├── builder.py │ │ ├── dataloader.py │ │ ├── internvideo │ │ ├── build_internvideo.py │ │ ├── config.py │ │ ├── easydict.py │ │ ├── flash_attention_class.py │ │ ├── internvideo2.py │ │ ├── internvideo2_stage2_config_vision.py │ │ ├── pos_embed.py │ │ └── utils.py │ │ ├── language_model │ │ ├── llama3_1.py │ │ └── phi3.py │ │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ └── processor.py │ │ └── multimodal_projector │ │ └── builder.py ├── requirements.txt ├── train_ds_with_videogptplus.py └── utils │ ├── .DS_Store │ ├── __init__.py │ ├── ade20k_classes.json │ ├── clair.py │ ├── cocostuff_classes.txt │ ├── conv_generator.py │ ├── conversation.py │ ├── data_processing.py │ ├── dataset.py │ ├── enc_preprocessors.py │ ├── grandf_dataset.py │ ├── grefcoco.py │ ├── grefer.py │ ├── grounded_video_qa.py │ ├── grounding_utils │ ├── __init__.py │ ├── box_ops.py │ ├── image_transforms.py │ └── misc.py │ ├── hcstvg_dataset.py │ ├── itm_transforms.py │ ├── mevis_dataset.py │ ├── mevis_gcg.py │ ├── misc.py │ ├── ordered_datasets │ ├── ordered_mevis.py │ └── ordered_rvos.py │ ├── preproc_hcstvgv2.py │ ├── preproc_vidstg.py │ ├── reason_seg_dataset.py │ ├── refer.py │ ├── refer_datasets │ ├── __init__.py │ ├── a2d.py │ ├── box_ops.py │ ├── davis.py │ ├── jhmdb.py │ ├── mevis.py │ ├── new │ │ ├── davis17.py │ │ └── ytvos.py │ ├── transforms.py │ └── ytvos.py │ ├── refer_seg_dataset.py │ ├── refer_vos_dataset.py │ ├── sam_transforms.py │ ├── sem_seg_dataset.py │ ├── temporal_grounding_datasets.py │ ├── trainer.py │ ├── utils.py │ ├── video_gcg_anet.py │ ├── video_gcg_dataset.py │ ├── video_vqa_dataset.py │ ├── vidstg_dataset.py │ ├── vidstg_hcstvg_gcg.py │ ├── vqa_dataset.py │ └── ytvos_gcg.py └── docs └── images ├── .DS_Store ├── figures ├── cvpr25-teaser.png ├── cvpr25_main_block_diagram-jpg.jpg ├── cvpr25_qualitative.png └── videoglamm_annotation_pipeline.png └── logos ├── IVAL_logo.png ├── MBZUAI_logo.png ├── Oryx_logo.png └── logo-videoglamm.png /Dataset.md: -------------------------------------------------------------------------------- 1 | ### **Datasets used for training VideoGLaMM** 2 | 3 | - **LISA datasets**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/Ed6NO_HzOtxHuLUtwU5llQoBNTKW-hWsat_ADhMPBhdrVA?e=eVFlLu) 4 | - **GranDf dataset**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EX4whuRa1NdEihJGAIL1j0MBvF8xvx22tX9D3g3lNhW_VQ?e=Y23PDA) 5 | - **Video datasets**: 6 | - **ActivityNet**: [Link](https://drive.google.com/file/d/1qW5bLQtOienpMjO7vkaghA3pCPXc1qj9/view?usp=sharing) 7 | - **ActivityNet Captions**: [Link](https://drive.google.com/file/d/13tlKsXTA7YLwN62h_OjxmNulWBjbweNF/view?usp=sharing) 8 | - **ActivityNet Entities**: [Link](https://drive.google.com/file/d/13uYeYjXNW9mvsLpuLcNCpp2Zle-HuaUS/view?usp=sharing) 9 | - **BURST**: [Link](https://drive.google.com/file/d/119syWknOhxX9HGedkerk9kQo6MHBBQVQ/view?usp=sharing) 10 | - **HC-STVG**: [Link](https://drive.google.com/file/d/1pzK3aP4bMfpUA1dzSXC9GCxypHgZA1aL/view?usp=sharing) 11 | - **MeViS**: [Link](https://drive.google.com/file/d/1uuE2IcD4UGpkFdD2MWrIlVVN48PVoXRT/view?usp=sharing) 12 | - **Processed**: [Link](https://drive.google.com/file/d/1Z16c1WgmoqsUa557ILIG2QBy0hit5Nhr/view?usp=sharing) 13 | - ActivityNet 
Entities 14 | - HC-STVG 15 | - Referring DAVIS 16 | - VideoInstruct100K 17 | - VidSTG 18 | - **Refer DAVIS**: [Link](https://drive.google.com/file/d/1B4uHyt3_KZIFs9bQobowkPg1y0tG0IuO/view?usp=sharing) 19 | - DAVIS 16 20 | - DAVIS 17 21 | - **Refer YouTube-VOS**: [Link](https://drive.google.com/file/d/1zApsra2fqGX8b3diSvhIfe7bBjZW9tJI/view?usp=sharing) 22 | - **VideoInstruct100K**: [Link](https://drive.google.com/file/d/1l6XKWbX40tGIG1K05iBW8q4QFfDA-2_1/view?usp=sharing) 23 | - **VidSTG**: [Link](https://drive.google.com/file/d/12INPWw_FAQcXkeIgGdm61vJ35tkJeGAF/view?usp=sharing) 24 | - **YTVIS**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EQIEnJmu1yhIhjkwRD8fVcYBRHaQFI9CmZDOoLRkmCXOBw?e=cbDr28) 25 | 26 | - **GCG Datasets**: 27 | - **ActivityNet Entities GCG**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EaG0sNQ--y1CjVf7WeYdahEBy2l6LQOvo_shVZqY22YRHg?e=r9itQ5) 28 | - **Burst-YTVIS GCG**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EThjSLl_aMhIka6S1KxiqEEBf9rUCKNbX9LVyg60rw6Urg?e=wyhMsC) 29 | - **Refer-YTVOS GCG**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EYy0FJi1PCxBiyudE-z9M6UBu-Ceae-mpjQ8w7aQ7c6KAA?e=VKUcoP) 30 | - **VidSTG GCG**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EQhWEdhvCX1OkSSYPcnX9KsBrlw1AeTSffUtiD8K7wsc8w?e=FIEqEA) 31 | - **HC-STVG GCG**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EaVtsayKs9ZIg83K36F9YC8B-7HiPa-SW3AXDT3-28m_Zw?e=H3dWQK) 32 | - **MeViS GCG**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EcMipuuIMx9AofwShTCzAB8BVtLiRDJoFjPDTNnY48gv8Q?e=6tTTBx) 33 | 34 | 35 | ### **File Structure** 36 | 37 | Extract **LISA dataset** and **GranDf dataset** under `./dataset`: 38 | 39 | dataset/ 40 | ├── ade20k 41 | ├── coco 42 | ├── cocostuff 43 | ├── grandf_dataset 44 | ├── llava_dataset 45 | ├── mapillary 46 | ├── other 47 | ├── reason_seg 48 | ├── refer_seg 49 | └── vlpart 50 | 51 | Extract **other datasets** under `./video_dataset`: 52 | 53 | video_dataset/ 54 | ├── activitynet 55 | ├── activitynet_captions 56 | ├── activitynet_entities 57 | ├── activitynet_entities_gcg 58 | ├── burst 59 | ├── hcstvg 60 | ├── hcstvg_gcg 61 | ├── mevis 62 | ├── mevis_gcg 63 | ├── processed 64 | ├── refer_davis 65 | ├── refer_youtube_vos 66 | ├── video_gcg 67 | ├── video_instruct_100k 68 | ├── vidstg 69 | ├── vidstg_gcg 70 | ├── ytvis 71 | └── ytvos_gcg 72 | 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VideoGLaMM: A Large Multimodal Model for Pixel-Level Visual Grounding in Videos [CVPR 2025🔥] 2 | ![](https://i.imgur.com/waxVImv.png) 3 | 4 | [Shehan Munasinghe](https://github.com/shehanmunasinghe) , [Hanan Gani](https://github.com/hananshafi) , [Wenqi Zhu](#) , [Jiale Cao](https://jialecao001.github.io/), [Eric Xing](https://www.cs.cmu.edu/~epxing/), [Fahad Shahbaz Khan](https://scholar.google.es/citations?user=zvaeYnUAAAAJ&hl=en). 
[Salman Khan](https://salman-h-khan.github.io/), 5 | 6 | **Mohamed bin Zayed University of Artificial Intelligence, Tianjin University, 7 | Linköping University, Australian National University, Carnegie Mellon University** 8 | 9 | [![Website](https://img.shields.io/badge/Project-Website-87CEEB)](https://mbzuai-oryx.github.io/VideoGLaMM/) 10 | [![paper](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2411.04923) 11 | 12 | --- 13 | 14 | ## 📢 Latest Updates 15 | 16 | - **Feb-2025:** Video-GLaMM is accepted at CVPR 2025! 🎊🎊 17 | 18 | --- 19 | 20 | ## Overview 21 | 22 |

23 | *Figure: VideoGLaMM Architectural Overview* 24 |

25 | 26 | VideoGLaMM is a large multimodal video model capable of pixel-level visual grounding. The model responds to natural language queries from the user and intertwines spatio-temporal object masks in its generated textual responses to provide a detailed understanding of video content. VideoGLaMM seamlessly connects three key components: a Large Language Model (LLM); dual vision encoders; and a spatio-temporal pixel decoder. The dual vision encoders extract spatial and temporal features separately, which are jointly passed to the LLM to output responses rich in both spatial and temporal cues. This is facilitated by end-to-end training on our proposed benchmark Grounded Conversation Generation (GCG) dataset, featuring 38k video-QA triplets with 83k objects and 671k fine-grained masks. 27 | 28 | --- 29 | ## 🏆 Highlights 30 | 1. We introduce the Video Grounded Large Multi-modal Model (VideoGLaMM), a large multimodal video model capable of pixel-level visual grounding, featuring an end-to-end alignment mechanism. 31 | 32 | 2. To achieve fine-grained spatio-temporal alignment, we introduce a benchmark Grounded Conversation Generation (GCG) dataset consisting of 38k grounded video-QA triplets, 83k objects, and roughly 671k fine-grained spatio-temporal masks. 33 | 34 | 3. We evaluate VideoGLaMM across diverse tasks spanning grounded conversation generation, visual grounding, and referring video segmentation, where it achieves state-of-the-art performance. 35 | 36 | --- 37 | 38 | ## Architecture 39 | 40 |

41 | *Figure: VideoGLaMM Architecture* 42 |

43 | 44 | VideoGLaMM consists of the following key components: (i) Spatio-Temporal Dual Encoder, (ii) Dual Alignment V-L Adapters for image and video features, (iii) Large Language Model (LLM), (iv) L-V Adapter, and (v) Promptable Pixel Decoder. 45 | 46 | --- 47 | ## Benchmark and Annotation Pipeline 48 | 49 |

50 | *Figure: Annotation Pipeline* 51 |

52 | 53 | We propose a semi-automatic annotation pipeline for creating a grounded conversation generation (GCG) dataset for videos. 54 | 55 | 56 | --- 57 | ## Examples 🔍 58 | 59 | Given user queries, VideoGLaMM generates textual responses and grounds objects and phrases using pixel-level masks, showing its detailed understanding of the video. 60 | 61 |

62 | *Figure: VideoGLaMM qualitative examples* 63 |

64 | 65 | --- 66 | 67 | ## Running VideoGLaMM 🔧 68 | 69 | ### Environment setup 70 | 71 | conda create --name=videoglamm python=3.11 72 | 73 | conda activate videoglamm 74 | 75 | pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121 76 | pip install transformers==4.41.0 77 | DS_BUILD_FUSED_ADAM=1 pip install deepspeed==0.14.0 78 | 79 | pip install -r VideoGLaMM/requirements_sam2_cluster.txt 80 | 81 | cd VideoGLaMM/model/segment_anything_2 82 | python setup.py build_ext --inplace 83 | cd ../../.. 84 | 85 | ### Training and Evaluation 86 | 87 | Please refer [here](RUN_VideoGLaMM.md) for instructions 88 | 89 | 90 | ## Citation 📜 91 | 92 | ```bibtex 93 | @article{munasinghe2024videoglamm, 94 | title={VideoGLaMM: A Large Multimodal Model for Pixel-Level Visual Grounding in Videos}, 95 | author={Shehan Munasinghe and Hanan Gani and Wenqi Zhu and Jiale Cao and Eric Xing and Fahad Khan and Salman Khan}, 96 | journal={ArXiv}, 97 | year={2024}, 98 | url={https://arxiv.org/abs/2411.04923} 99 | } 100 | ``` 101 | 102 | --- 103 | 104 | [](https://www.ival-mbzuai.com) 105 | [](https://github.com/mbzuai-oryx) 106 | [](https://mbzuai.ac.ae) -------------------------------------------------------------------------------- /RUN_VideoGLaMM.md: -------------------------------------------------------------------------------- 1 | # Checkpoints 2 | 3 | * Download SAM2 checkpoints from [here](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt) 4 | 5 | * Download InternVideo2 checkpoints from [here](https://huggingface.co/OpenGVLab/InternVideo2-Stage2_1B-224p-f4) 6 | 7 | * Download VideoGLaMM checkpoints from [here](https://mbzuaiac-my.sharepoint.com/:f:/g/personal/shehan_munasinghe_mbzuai_ac_ae/Etucj3LuqdRDocrle_8eJbcB8C11u-020AX7fwIYWJh-dg?e=uPanYM) 8 | 9 | # Command Line Demo 10 | 11 | python chat.py \ 12 | --llava_version_or_path="" \ 13 | --use_sam2_video_branch \ 14 | --base_model_type="vgpt|phi3" 15 | 16 | # Evaluation 17 | 18 | 19 | ## GCG Task 20 | 21 | python eval_gcg_infer.py \ 22 | --llava_version_or_path="" \ 23 | --use_sam2_video_branch \ 24 | --base_model_type="vgpt|phi3" \ 25 | --dataset_name='video_gcg'\ 26 | --vis_save_path="./vis_output_path" 27 | 28 | export OPENAI_API_KEY='' 29 | 30 | python eval_gcg_metrics.py \ 31 | --vis_save_path="" \ 32 | --eval_miou --eval_recall --eval_caption --use_clair 33 | 34 | ## MeViS 35 | 36 | python eval_mevis.py \ 37 | --llava_version_or_path="" \ 38 | --use_sam2_video_branch \ 39 | --base_model_type="vgpt|phi3" \ 40 | --dataset_name="MEVIS|valid"\ 41 | --vis_save_path="./vis_output_path" 42 | 43 | You can use following command to prepare .zip submission file 44 | 45 | cd [vis_output_path] 46 | zip -r ../mevis_out.zip * 47 | 48 | ## VidSTG 49 | 50 | python eval_grounding.py \ 51 | --llava_version_or_path="" \ 52 | --use_sam2_video_branch \ 53 | --base_model_type="vgpt|phi3" \ 54 | --dataset_name="vidstg"\ 55 | --vis_save_path="./vis_output_path" 56 | 57 | 58 | ## HCSTVG 59 | 60 | 61 | python eval_grounding.py \ 62 | --llava_version_or_path="" \ 63 | --use_sam2_video_branch \ 64 | --base_model_type="vgpt|phi3" \ 65 | --dataset_name="hcstvg"\ 66 | --vis_save_path="./vis_output_path" 67 | 68 | ## ReferYTVOS 69 | 70 | python eval_mevis.py \ 71 | --llava_version_or_path="" \ 72 | --use_sam2_video_branch \ 73 | --base_model_type="vgpt|phi3" \ 74 | --dataset_name="ReferYouTubeVOS|valid" \ 75 | --vis_save_path="./vis_output_path" 76 | 77 | 78 | ## ReferDAVIS17 79 | 80 | 81 | python 
eval_referdavis_infer.py \ 82 | --llava_version_or_path="" \ 83 | --use_sam2_video_branch \ 84 | --base_model_type="vgpt|phi3" \ 85 | --dataset_name="ReferDAVIS|valid" \ 86 | --vis_save_path="./vis_output_path" 87 | 88 | python eval_referdavis_metrics.py --output_dir \ 89 | "./vis_output_path" -------------------------------------------------------------------------------- /Training.md: -------------------------------------------------------------------------------- 1 | # Training Instructions 2 | 3 | 4 | * Initial training 5 | 6 | 7 | deepspeed --master_port=29504 --num_gpus=4 train_ds_with_videogptplus.py \ 8 | --videogptplus_path="./checkpoints_hf/MBZUAI/VideoGPT-plus_Phi3-mini-4k/mvbench" \ 9 | --vision_tower="./OpenGVLab/InternVideo2-Stage2_1B-224p-f4/InternVideo2-stage2_1b-224p-f4.pt" \ 10 | --image_vision_tower="openai/clip-vit-large-patch14-336" \ 11 | --dataset_dir='./dataset' \ 12 | --video_dataset_dir='./video_dataset' \ 13 | --sam_pretrained_path="./checkpoints/sam2/sam2_hiera_large.pt" \ 14 | --exp_name="sam2_videogptplusphi3" \ 15 | --logs_base_dir "./runs/logs" \ 16 | --ckpt_base_dir "./runs/ckpts" \ 17 | --dataset="sem_seg||refer_seg||vqa||reason_seg||grandf||refer_vos||mevis||vidstg||video_vqa" \ 18 | --sample_rates_for_datasets="9,3,3,1,1,10,10,10,10,10" \ 19 | --train_mask_decoder=False \ 20 | --tune_mm_mlp_adapter=True \ 21 | --use_sam_version='v2' \ 22 | --precision='fp16' \ 23 | --num_frames_for_sam=8 \ 24 | --batch_size=1 \ 25 | --grad_accumulation_steps=10 \ 26 | --epochs=20 \ 27 | --auto_resume 28 | 29 | * Finetuning with video-GCG data 30 | 31 | deepspeed --master_port=29504 --num_gpus=4 train_ds_with_videogptplus.py \ 32 | --videogptplus_path="./checkpoints_hf/MBZUAI/VideoGPT-plus_Phi3-mini-4k/mvbench" \ 33 | --vision_tower="./OpenGVLab/InternVideo2-Stage2_1B-224p-f4/InternVideo2-stage2_1b-224p-f4.pt" \ 34 | --image_vision_tower="openai/clip-vit-large-patch14-336" \ 35 | --dataset_dir='./dataset' \ 36 | --video_dataset_dir='./video_dataset' \ 37 | --sam_pretrained_path="./checkpoints/sam2/sam2_hiera_large.pt" \ 38 | --exp_name="sam2_videogptplusphi3" \ 39 | --logs_base_dir "./runs/logs" \ 40 | --ckpt_base_dir "./runs/ckpts" \ 41 | --dataset="sem_seg||refer_seg||vqa||reason_seg||grandf||refer_vos||mevis||vidstg||video_vqa||anet_gcg||video_gcg||mevis_gcg||vidstg_gcg||hcstvg_gcg" \ 42 | --sample_rates_for_datasets="1,1,1,1,20,1,1,1,1,1,20,5,20,20,10" \ 43 | --train_mask_decoder=False \ 44 | --tune_mm_mlp_adapter=True \ 45 | --use_sam_version='v2' \ 46 | --precision='fp16' \ 47 | --num_frames_for_sam=8 \ 48 | --batch_size=1 \ 49 | --grad_accumulation_steps=10 \ 50 | --epochs=30 \ 51 | --auto_resume 52 | 53 | -------------------------------------------------------------------------------- /VideoGLaMM/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbzuai-oryx/VideoGLaMM/f8351c5f1fcda715c6afff66fa4a777e75d08e1b/VideoGLaMM/.DS_Store -------------------------------------------------------------------------------- /VideoGLaMM/.gitignore: -------------------------------------------------------------------------------- 1 | !aws_backup/ 2 | 3 | **/__pycache__ 4 | archive/ 5 | .runs/ 6 | runs 7 | .vscode/ 8 | 9 | dataset 10 | video_dataset 11 | checkpoints 12 | checkpoints/ 13 | checkpoints_hf 14 | checkpoints_hf/ 15 | 16 | .ipynb_checkpoints 17 | */.ipynb_checkpoints/* 18 | 19 | *.ipynb 20 | 21 | vis_output/ 22 | slurm_outputs/ 23 | 24 | scripts/ 25 | 26 | 27 | 
model/segment_anything_2/build/lib.linux-x86_64-cpython-310/sam2/_C.so 28 | model/segment_anything_2/build/lib.linux-x86_64-cpython-311/sam2/_C.so 29 | model/segment_anything_2/build/temp.linux-x86_64-cpython-310/.ninja_deps 30 | model/segment_anything_2/build/temp.linux-x86_64-cpython-310/.ninja_log 31 | model/segment_anything_2/build/temp.linux-x86_64-cpython-310/build.ninja 32 | model/segment_anything_2/build/temp.linux-x86_64-cpython-310/sam2/csrc/connected_components.o 33 | model/segment_anything_2/build/temp.linux-x86_64-cpython-311/.ninja_deps 34 | model/segment_anything_2/build/temp.linux-x86_64-cpython-311/.ninja_log 35 | model/segment_anything_2/build/temp.linux-x86_64-cpython-311/build.ninja 36 | model/segment_anything_2/build/temp.linux-x86_64-cpython-311/sam2/csrc/connected_components.o 37 | model/segment_anything_2/sam2/_C.so 38 | *.so 39 | *.ninja_log 40 | *.ninja_deps 41 | *.ninja 42 | *.o 43 | 44 | 45 | .gradio 46 | 47 | 48 | tools/stanford-corenlp-full-2018-02-27 -------------------------------------------------------------------------------- /VideoGLaMM/gcg_data_gen/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbzuai-oryx/VideoGLaMM/f8351c5f1fcda715c6afff66fa4a777e75d08e1b/VideoGLaMM/gcg_data_gen/.DS_Store -------------------------------------------------------------------------------- /VideoGLaMM/gcg_data_gen/burst_ytvis_gcg/README.md: -------------------------------------------------------------------------------- 1 | 2 | **Run the demo**
3 | Take YouTubeVIS2019 for example 4 | ```bash 5 | # Generation step1 --rough description of each object 6 | python generation.py --video_path video_data/youtube2019/train/JPEGImages --question These are frames from a video that I want to upload. What does the look like and what is the doing? --ann_path video_data/youtube2019/train.json --output_file generated_step1.txt --step step1 7 | 8 | 9 | # Generation step2 --corrected description of the object 10 | python generation.py --video_path video_data/youtube2019/train/JPEGImages --question These are frames from a video that I want to upload. Please modify this caption: The instance in the video is surrounded by a rectangular box with color number . The output caption must include what the looks like and what the is doing. Please do not mention any information about the bbox in the output. --ann_path video_data/youtube2019/train.json --output_file generated_step2.txt --step step2 --caption_file output/generated_step1.json 11 | 12 | # Generation step3 --comprehensive description of the video 13 | python generation.py --video_path video_data/youtube2019/train/JPEGImages --question These are frames from a video that I want to upload. In the video, the ID number of the box is on the top left of the box. There are some instance captions: '' Generate a dense caption that describes the video in detail based on the video and instance captions, including all of the instances mentioned in the instance captions and other instances in the video. Ensure that each instance mentioned in the instance caption appears exactly once in the dense caption, followed by the format {obj_} to indicate which instance caption the mentioned instance corresponds to. The {obj_} must directly follow the noun representing the instance.Please do not mention any information about the bbox in the output. 
--ann_path video_data/youtube2019/train.json --output_file generated_step3.txt --step step3 --caption_file output/generated_step2.json 14 | 15 | # Manually review the {obj_id} in the generated video captions based on the video content 16 | 17 | # Generate annotation file with caption 18 | python generate_annotations.py --ann_file video_data/youtube2019/train.json --obj_cap output/generated_step2.json --dense_cap output/manual_generated_step3.json --out_ann_file generated_annotation.json 19 | 20 | 21 | # Merge BURST and YouTubeVIS2019 annotation files 22 | python merge_b_y.py --burst_train video_data/burst/train/b2y_train_add_cap_del_filtered_ann.json --burst_val video_data/burst/val/b2y_val_add_cap_del_filtered_ann.json --yt19_train video_data/ytvis_2019/train_add_cap_filtered_ann.json --hq_ann_file video_data/ytvis_2019/ --out_ann_path output 23 | ``` 24 | -------------------------------------------------------------------------------- /VideoGLaMM/gcg_data_gen/burst_ytvis_gcg/generate_annotations.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import argparse 4 | 5 | 6 | def get_arguments(): 7 | parser = argparse.ArgumentParser(description="Inference parameters") 8 | 9 | parser.add_argument("--ann_file", type=str, help="path to annotations") 10 | parser.add_argument("--obj_cap", type=str, help="the generated object-level caption") 11 | parser.add_argument("--dense_cap", type=str, help="the generated video-level caption") 12 | parser.add_argument("--out_ann_file", type=str, help="path to the final annotations file") 13 | 14 | return parser.parse_args() 15 | 16 | if __name__ == "__main__": 17 | args=get_arguments() 18 | ann_file= json.load(open(args.ann_file)) 19 | obj_cap=json.load(open(args.obj_cap)) 20 | dense_cap=json.load(open(args.dense_cap)) 21 | out_ann_file= args.out_ann_file 22 | 23 | 24 | for ann in ann_file['annotations']: 25 | if str(ann['id']) in obj_cap.keys(): 26 | ann['cap']=obj_cap[str(ann['id'])] 27 | else: 28 | ann['cap']=None 29 | 30 | video_cap={} 31 | for ann in ann_file['annotations']: 32 | if ann['video_id'] not in video_cap.keys(): 33 | video_cap[ann['video_id']]=[] 34 | if str(ann['id']) in obj_cap.keys(): 35 | video_cap[ann['video_id']].append(dict(cls_id=ann['category_id'],seg=ann['segmentations'],bboxes=ann['bboxes'],cap=obj_cap[str(ann['id'])],obj_id=ann['id'],ann_id=len(video_cap[ann['video_id']]))) 36 | for vid,video in enumerate(ann_file['videos']): 37 | if str(vid) in dense_cap.keys(): 38 | if len(video_cap[video['id']])==0: 39 | video['dense_cap']={} 40 | video['dense_cap']['v_id2o_id']=None 41 | video['dense_cap']['token_pos']=None 42 | video['dense_cap']['mask_id']=None 43 | video['dense_cap']['caption']=None 44 | else: 45 | video['dense_cap']={} 46 | video['dense_cap']['v_id2o_id']={} 47 | video['dense_cap']['token_pos']=[] 48 | video['dense_cap']['mask_id']=[] 49 | for an in video_cap[video['id']]: 50 | video['dense_cap']['v_id2o_id'][an['ann_id']]=an['obj_id'] 51 | spl_dense_cap=dense_cap[str(vid)].split(' ') 52 | me_cap=[] 53 | for wid,word in enumerate(spl_dense_cap): 54 | if '{obj_' in word : 55 | video['dense_cap']['token_pos'].append(len(me_cap)-1) 56 | m_id=int(re.findall(r'\d+',word)[0]) 57 | if m_id')] 20 | 21 | def insert_separator(X, sep): 22 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 23 | 24 | input_ids = [] 25 | offset = 0 26 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 27 | 
offset = 1 28 | input_ids.append(prompt_chunks[0][0]) 29 | 30 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 31 | input_ids.extend(x[offset:]) 32 | 33 | if return_tensors is not None: 34 | if return_tensors == 'pt': 35 | return torch.tensor(input_ids, dtype=torch.long) 36 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 37 | 38 | return input_ids 39 | 40 | 41 | def get_model_name_from_path(model_path): 42 | model_path = model_path.strip("/") 43 | model_paths = model_path.split("/") 44 | if model_paths[-1].startswith('checkpoint-'): 45 | return model_paths[-2] + "_" + model_paths[-1] 46 | else: 47 | return model_paths[-1] 48 | 49 | 50 | class KeywordsStoppingCriteria(StoppingCriteria): 51 | def __init__(self, keywords, tokenizer, input_ids): 52 | self.keywords = keywords 53 | self.keyword_ids = [] 54 | for keyword in keywords: 55 | cur_keyword_ids = tokenizer(keyword).input_ids 56 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 57 | cur_keyword_ids = cur_keyword_ids[1:] 58 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 59 | self.tokenizer = tokenizer 60 | self.start_len = input_ids.shape[1] 61 | 62 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 63 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO 64 | offset = min(output_ids.shape[1] - self.start_len, 3) 65 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 66 | for keyword_id in self.keyword_ids: 67 | if output_ids[0, -keyword_id.shape[0]:] == keyword_id: 68 | return True 69 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 70 | for keyword in self.keywords: 71 | if keyword in outputs: 72 | return True 73 | return False 74 | -------------------------------------------------------------------------------- /VideoGLaMM/model/chatunivi/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbzuai-oryx/VideoGLaMM/f8351c5f1fcda715c6afff66fa4a777e75d08e1b/VideoGLaMM/model/chatunivi/model/__init__.py -------------------------------------------------------------------------------- /VideoGLaMM/model/chatunivi/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | from .clip_encoder import CLIPVisionTower 2 | # from .eva_encoder import EVAVisionTower 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | if vision_tower.startswith("openai") or vision_tower.startswith("laion"): 8 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 9 | 10 | # elif vision_tower.startswith("eva_vit_g"): 11 | # return EVAVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 12 | 13 | raise ValueError(f'Unknown vision tower: {vision_tower}') 14 | -------------------------------------------------------------------------------- /VideoGLaMM/model/chatunivi/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args=None, delay_load=False): 9 | super().__init__() 
10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | if args is None: 15 | self.select_layer = -2 16 | self.select_feature = 'patch' 17 | else: 18 | self.select_layer = args.mm_vision_select_layer 19 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 20 | 21 | if not delay_load: 22 | self.load_model() 23 | else: 24 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 25 | 26 | def load_model(self): 27 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 28 | self.image_eval_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 29 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 30 | self.vision_tower.requires_grad_(False) 31 | 32 | self.is_loaded = True 33 | 34 | def feature_select(self, image_forward_outs, select_feature='patch'): 35 | image_features = image_forward_outs.hidden_states[self.select_layer] 36 | if select_feature == 'patch': 37 | image_features = image_features[:, 1:] 38 | elif select_feature == 'cls_patch': 39 | image_features = image_features 40 | else: 41 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 42 | return image_features 43 | 44 | @torch.no_grad() 45 | def forward(self, images, select_feature='patch'): 46 | if type(images) is list: 47 | image_features = [] 48 | for image in images: 49 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 50 | image_feature = self.feature_select(image_forward_out, select_feature).to(image.dtype) 51 | image_features.append(image_feature) 52 | else: 53 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 54 | image_features = self.feature_select(image_forward_outs, select_feature).to(images.dtype) 55 | 56 | return image_features 57 | 58 | @property 59 | def dummy_feature(self): 60 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 61 | 62 | @property 63 | def dtype(self): 64 | return self.vision_tower.dtype 65 | 66 | @property 67 | def device(self): 68 | return self.vision_tower.device 69 | 70 | @property 71 | def config(self): 72 | if self.is_loaded: 73 | return self.vision_tower.config 74 | else: 75 | return self.cfg_only 76 | 77 | @property 78 | def hidden_size(self): 79 | return self.config.hidden_size 80 | 81 | @property 82 | def num_patches(self): 83 | return (self.config.image_size // self.config.patch_size) ** 2 84 | -------------------------------------------------------------------------------- /VideoGLaMM/model/chatunivi/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from ChatUniVi.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True) 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from io import BytesIO 3 | 4 | import torch 5 | from PIL import Image 6 | from transformers import StoppingCriteria 7 | 8 | from .constants import IMAGE_TOKEN_INDEX 9 | 10 | 11 | def load_image_from_base64(image): 12 | return Image.open(BytesIO(base64.b64decode(image))) 13 | 14 | 15 | def process_images(images, image_processor, model_cfg): 16 | return image_processor(images, return_tensors="pt")["pixel_values"] 17 | 18 | 19 | def tokenizer_image_token( 20 | prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None 21 | ): 22 | ''' 23 | This function is used to tokenize a prompt string containing the token. 24 | 25 | - The input string is tokenized into a list of token IDs. 26 | - The image_token_index is inserted between the chunks of the prompt where was present. 27 | - The resulting result is either a list of token IDs or a PyTorch tensor, depending on the specified return_tensors parameter. 28 | 29 | ''' 30 | # splits the prompt string into chunks based on the token and tokenizes each chunk separately using the provided tokenizer. 31 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("")] 32 | 33 | 34 | # helper function that takes a list X and a separator sep, and it interleaves the elements of X with the separator. 35 | # It is used to insert the image_token_index between the tokenized chunks of the prompt. 36 | def insert_separator(X, sep): 37 | return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] 38 | 39 | input_ids = [] 40 | offset = 0 41 | # If the first chunk of the prompt starts with the tokenizer's beginning-of-sequence (BOS) token, 42 | # it increments the offset and appends the BOS token to the input_ids list. 
43 | if ( 44 | len(prompt_chunks) > 0 45 | and len(prompt_chunks[0]) > 0 46 | and prompt_chunks[0][0] == tokenizer.bos_token_id 47 | ): 48 | offset = 1 49 | input_ids.append(prompt_chunks[0][0]) 50 | 51 | # uses the insert_separator function to insert the image_token_index between the tokenized chunks of the prompt. 52 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 53 | input_ids.extend(x[offset:]) 54 | 55 | # The resulting input_ids list contains the tokenized prompt with the image token index appropriately inserted. 56 | 57 | if return_tensors is not None: 58 | if return_tensors == "pt": 59 | return torch.tensor(input_ids, dtype=torch.long) 60 | raise ValueError(f"Unsupported tensor type: {return_tensors}") 61 | return input_ids 62 | 63 | 64 | def get_model_name_from_path(model_path): 65 | model_path = model_path.strip("/") 66 | model_paths = model_path.split("/") 67 | if model_paths[-1].startswith("checkpoint-"): 68 | return model_paths[-2] + "_" + model_paths[-1] 69 | else: 70 | return model_paths[-1] 71 | 72 | 73 | class KeywordsStoppingCriteria(StoppingCriteria): 74 | def __init__(self, keywords, tokenizer, input_ids): 75 | self.keywords = keywords 76 | self.keyword_ids = [] 77 | for keyword in keywords: 78 | cur_keyword_ids = tokenizer(keyword).input_ids 79 | if ( 80 | len(cur_keyword_ids) > 1 81 | and cur_keyword_ids[0] == tokenizer.bos_token_id 82 | ): 83 | cur_keyword_ids = cur_keyword_ids[1:] 84 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 85 | self.tokenizer = tokenizer 86 | self.start_len = input_ids.shape[1] 87 | 88 | def __call__( 89 | self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs 90 | ) -> bool: 91 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO 92 | offset = min(output_ids.shape[1] - self.start_len, 3) 93 | self.keyword_ids = [ 94 | keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids 95 | ] 96 | for keyword_id in self.keyword_ids: 97 | if output_ids[0, -keyword_id.shape[0] :] == keyword_id: 98 | return True 99 | outputs = self.tokenizer.batch_decode( 100 | output_ids[:, -offset:], skip_special_tokens=True 101 | )[0] 102 | for keyword in self.keywords: 103 | if keyword in outputs: 104 | return True 105 | return False 106 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaConfig, LlavaLlamaForCausalLM 2 | from .language_model.llava_mpt import LlavaMPTConfig, LlavaMPTForCausalLM 3 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from llava import LlavaLlamaForCausalLM 9 | from tqdm import tqdm 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 17 | ) 18 | 19 | print("Loading delta") 20 | delta = LlavaLlamaForCausalLM.from_pretrained( 21 | delta_path, 
torch_dtype=torch.float16, low_cpu_mem_usage=True 22 | ) 23 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 24 | 25 | print("Applying delta") 26 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 27 | if name not in base.state_dict(): 28 | assert name in [ 29 | "model.mm_projector.weight", 30 | "model.mm_projector.bias", 31 | ], f"{name} not in base model" 32 | continue 33 | if param.data.shape == base.state_dict()[name].shape: 34 | param.data += base.state_dict()[name] 35 | else: 36 | assert name in [ 37 | "model.embed_tokens.weight", 38 | "lm_head.weight", 39 | ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 40 | bparam = base.state_dict()[name] 41 | param.data[: bparam.shape[0], : bparam.shape[1]] += bparam 42 | 43 | print("Saving target model") 44 | delta.save_pretrained(target_model_path) 45 | delta_tokenizer.save_pretrained(target_model_path) 46 | 47 | 48 | if __name__ == "__main__": 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--base-model-path", type=str, required=True) 51 | parser.add_argument("--target-model-path", type=str, required=True) 52 | parser.add_argument("--delta-path", type=str, required=True) 53 | 54 | args = parser.parse_args() 55 | 56 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 57 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from llava.model import * 9 | from llava.model.utils import auto_upgrade 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained( 17 | src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 18 | ) 19 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 20 | src_model.save_pretrained(dst_path) 21 | src_tokenizer.save_pretrained(dst_path) 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--src", type=str, required=True) 27 | parser.add_argument("--dst", type=str, required=True) 28 | 29 | args = parser.parse_args() 30 | 31 | consolidate_ckpt(args.src, args.dst) 32 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/model/language_model/mpt/adapt_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from transformers import (AutoTokenizer, PreTrainedTokenizer, 4 | PreTrainedTokenizerFast) 5 | 6 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 7 | NUM_SENTINEL_TOKENS: int = 100 8 | 9 | 10 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 11 | """Adds sentinel tokens and padding token (if missing). 12 | 13 | Expands the tokenizer vocabulary to include sentinel tokens 14 | used in mixture-of-denoiser tasks as well as a padding token. 15 | 16 | All added tokens are added as special tokens. No tokens are 17 | added if sentinel tokens and padding token already exist. 
18 | """ 19 | sentinels_to_add = [f"" for i in range(NUM_SENTINEL_TOKENS)] 20 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 21 | if tokenizer.pad_token is None: 22 | tokenizer.add_tokens("", special_tokens=True) 23 | tokenizer.pad_token = "" 24 | assert tokenizer.pad_token_id is not None 25 | sentinels = "".join([f"" for i in range(NUM_SENTINEL_TOKENS)]) 26 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 27 | tokenizer.sentinel_token_ids = _sentinel_token_ids 28 | 29 | 30 | class AutoTokenizerForMOD(AutoTokenizer): 31 | """AutoTokenizer + Adaptation for MOD. 32 | 33 | A simple wrapper around AutoTokenizer to make instantiating 34 | an MOD-adapted tokenizer a bit easier. 35 | 36 | MOD-adapted tokenizers have sentinel tokens (e.g., ), 37 | a padding token, and a property to get the token ids of the 38 | sentinel tokens. 39 | """ 40 | 41 | @classmethod 42 | def from_pretrained(cls, *args, **kwargs): 43 | """See `AutoTokenizer.from_pretrained` docstring.""" 44 | tokenizer = super().from_pretrained(*args, **kwargs) 45 | adapt_tokenizer_for_denoising(tokenizer) 46 | return tokenizer 47 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/model/language_model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .attention import ATTN_CLASS_REGISTRY 8 | from .norm import NORM_CLASS_REGISTRY 9 | 10 | 11 | class MPTMLP(nn.Module): 12 | def __init__( 13 | self, d_model: int, expansion_ratio: int, device: Optional[str] = None 14 | ): 15 | super().__init__() 16 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 17 | self.act = nn.GELU(approximate="none") 18 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 19 | self.down_proj._is_residual = True 20 | 21 | def forward(self, x): 22 | return self.down_proj(self.act(self.up_proj(x))) 23 | 24 | 25 | class MPTBlock(nn.Module): 26 | def __init__( 27 | self, 28 | d_model: int, 29 | n_heads: int, 30 | expansion_ratio: int, 31 | attn_config: Dict = { 32 | "attn_type": "multihead_attention", 33 | "attn_pdrop": 0.0, 34 | "attn_impl": "triton", 35 | "qk_ln": False, 36 | "clip_qkv": None, 37 | "softmax_scale": None, 38 | "prefix_lm": False, 39 | "attn_uses_sequence_id": False, 40 | "alibi": False, 41 | "alibi_bias_max": 8, 42 | }, 43 | resid_pdrop: float = 0.0, 44 | norm_type: str = "low_precision_layernorm", 45 | verbose: int = 0, 46 | device: Optional[str] = None, 47 | **kwargs 48 | ): 49 | del kwargs 50 | super().__init__() 51 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 52 | attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]] 53 | self.norm_1 = norm_class(d_model, device=device) 54 | self.attn = attn_class( 55 | attn_impl=attn_config["attn_impl"], 56 | clip_qkv=attn_config["clip_qkv"], 57 | qk_ln=attn_config["qk_ln"], 58 | softmax_scale=attn_config["softmax_scale"], 59 | attn_pdrop=attn_config["attn_pdrop"], 60 | d_model=d_model, 61 | n_heads=n_heads, 62 | verbose=verbose, 63 | device=device, 64 | ) 65 | self.norm_2 = norm_class(d_model, device=device) 66 | self.ffn = MPTMLP( 67 | d_model=d_model, expansion_ratio=expansion_ratio, device=device 68 | ) 69 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 70 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 71 | 72 | def forward( 73 | self, 74 | x: 
torch.Tensor, 75 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 76 | attn_bias: Optional[torch.Tensor] = None, 77 | attention_mask: Optional[torch.ByteTensor] = None, 78 | is_causal: bool = True, 79 | ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 80 | a = self.norm_1(x) 81 | (b, attn_weights, past_key_value) = self.attn( 82 | a, 83 | past_key_value=past_key_value, 84 | attn_bias=attn_bias, 85 | attention_mask=attention_mask, 86 | is_causal=is_causal, 87 | ) 88 | x = x + self.resid_attn_dropout(b) 89 | m = self.norm_2(x) 90 | n = self.ffn(m) 91 | x = x + self.resid_ffn_dropout(n) 92 | return (x, attn_weights, past_key_value) 93 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/model/language_model/mpt/custom_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | 7 | class SharedEmbedding(nn.Embedding): 8 | def forward(self, input: Tensor, unembed: bool = False) -> Tensor: 9 | if unembed: 10 | return F.linear(input, self.weight) 11 | return super().forward(input) 12 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/model/language_model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | @contextmanager 8 | def init_empty_weights(include_buffers: bool = False): 9 | """Meta initialization context manager. 10 | 11 | A context manager under which models are initialized with all parameters 12 | on the meta device, therefore creating an empty model. Useful when just 13 | initializing the model would blow the available RAM. 14 | 15 | Args: 16 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 17 | not to also put all buffers on the meta device while initializing. 18 | 19 | Example: 20 | ```python 21 | import torch.nn as nn 22 | 23 | # Initialize a model with 100 billions parameters in no time and without using any RAM. 24 | with init_empty_weights(): 25 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 26 | ``` 27 | 28 | 29 | 30 | Any model created under this context manager has no weights. As such you can't do something like 31 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 32 | 33 | 34 | """ 35 | with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f: 36 | yield f 37 | 38 | 39 | @contextmanager 40 | def init_on_device(device: torch.device, include_buffers: bool = False): 41 | """Device initialization context manager. 42 | 43 | A context manager under which models are initialized with all parameters 44 | on the specified device. 45 | 46 | Args: 47 | device (`torch.device`): Device to initialize all parameters on. 48 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 49 | not to also put all buffers on the meta device while initializing. 
50 | 51 | Example: 52 | ```python 53 | import torch.nn as nn 54 | 55 | with init_on_device(device=torch.device("cuda")): 56 | tst = nn.Liner(100, 100) # on `cuda` device 57 | ``` 58 | """ 59 | old_register_parameter = nn.Module.register_parameter 60 | if include_buffers: 61 | old_register_buffer = nn.Module.register_buffer 62 | 63 | def register_empty_parameter(module, name, param): 64 | old_register_parameter(module, name, param) 65 | if param is not None: 66 | param_cls = type(module._parameters[name]) 67 | kwargs = module._parameters[name].__dict__ 68 | module._parameters[name] = param_cls( 69 | module._parameters[name].to(device), **kwargs 70 | ) 71 | 72 | def register_empty_buffer(module, name, buffer): 73 | old_register_buffer(module, name, buffer) 74 | if buffer is not None: 75 | module._buffers[name] = module._buffers[name].to(device) 76 | 77 | if include_buffers: 78 | tensor_constructors_to_patch = { 79 | torch_function_name: getattr(torch, torch_function_name) 80 | for torch_function_name in ["empty", "zeros", "ones", "full"] 81 | } 82 | else: 83 | tensor_constructors_to_patch = {} 84 | 85 | def patch_tensor_constructor(fn): 86 | def wrapper(*args, **kwargs): 87 | kwargs["device"] = device 88 | return fn(*args, **kwargs) 89 | 90 | return wrapper 91 | 92 | try: 93 | nn.Module.register_parameter = register_empty_parameter 94 | if include_buffers: 95 | nn.Module.register_buffer = register_empty_buffer 96 | for torch_function_name in tensor_constructors_to_patch.keys(): 97 | setattr( 98 | torch, 99 | torch_function_name, 100 | patch_tensor_constructor(getattr(torch, torch_function_name)), 101 | ) 102 | yield 103 | finally: 104 | nn.Module.register_parameter = old_register_parameter 105 | if include_buffers: 106 | nn.Module.register_buffer = old_register_buffer 107 | for ( 108 | torch_function_name, 109 | old_torch_function, 110 | ) in tensor_constructors_to_patch.items(): 111 | setattr(torch, torch_function_name, old_torch_function) 112 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/model/language_model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _cast_if_autocast_enabled(tensor): 5 | if torch.is_autocast_enabled(): 6 | if tensor.device.type == "cuda": 7 | dtype = torch.get_autocast_gpu_dtype() 8 | elif tensor.device.type == "cpu": 9 | dtype = torch.get_autocast_cpu_dtype() 10 | else: 11 | raise NotImplementedError() 12 | return tensor.to(dtype=dtype) 13 | return tensor 14 | 15 | 16 | class LPLayerNorm(torch.nn.LayerNorm): 17 | def __init__( 18 | self, 19 | normalized_shape, 20 | eps=1e-05, 21 | elementwise_affine=True, 22 | device=None, 23 | dtype=None, 24 | ): 25 | super().__init__( 26 | normalized_shape=normalized_shape, 27 | eps=eps, 28 | elementwise_affine=elementwise_affine, 29 | device=device, 30 | dtype=dtype, 31 | ) 32 | 33 | def forward(self, x): 34 | module_device = x.device 35 | downcast_x = _cast_if_autocast_enabled(x) 36 | downcast_weight = ( 37 | _cast_if_autocast_enabled(self.weight) 38 | if self.weight is not None 39 | else self.weight 40 | ) 41 | downcast_bias = ( 42 | _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 43 | ) 44 | with torch.autocast(enabled=False, device_type=module_device.type): 45 | return torch.nn.functional.layer_norm( 46 | downcast_x, 47 | self.normalized_shape, 48 | downcast_weight, 49 | downcast_bias, 50 | self.eps, 51 | ) 52 | 53 | 54 | def rms_norm(x, weight=None, 
eps=1e-05): 55 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 56 | if weight is not None: 57 | return output * weight 58 | return output 59 | 60 | 61 | class RMSNorm(torch.nn.Module): 62 | def __init__( 63 | self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None 64 | ): 65 | super().__init__() 66 | self.eps = eps 67 | if weight: 68 | self.weight = torch.nn.Parameter( 69 | torch.ones(normalized_shape, dtype=dtype, device=device) 70 | ) 71 | else: 72 | self.register_parameter("weight", None) 73 | 74 | def forward(self, x): 75 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 76 | 77 | 78 | class LPRMSNorm(RMSNorm): 79 | def __init__( 80 | self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None 81 | ): 82 | super().__init__( 83 | normalized_shape=normalized_shape, 84 | eps=eps, 85 | weight=weight, 86 | dtype=dtype, 87 | device=device, 88 | ) 89 | 90 | def forward(self, x): 91 | downcast_x = _cast_if_autocast_enabled(x) 92 | downcast_weight = ( 93 | _cast_if_autocast_enabled(self.weight) 94 | if self.weight is not None 95 | else self.weight 96 | ) 97 | with torch.autocast(enabled=False, device_type=x.device.type): 98 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 99 | 100 | 101 | NORM_CLASS_REGISTRY = { 102 | "layernorm": torch.nn.LayerNorm, 103 | "low_precision_layernorm": LPLayerNorm, 104 | "rmsnorm": RMSNorm, 105 | "low_precision_rmsnorm": LPRMSNorm, 106 | } 107 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from llava.model.utils import auto_upgrade 9 | from tqdm import tqdm 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 17 | ) 18 | 19 | print("Loading target model") 20 | auto_upgrade(target_model_path) 21 | target = AutoModelForCausalLM.from_pretrained( 22 | target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 23 | ) 24 | 25 | print("Calculating delta") 26 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 27 | if name not in base.state_dict(): 28 | assert name in [ 29 | "model.mm_projector.weight", 30 | "model.mm_projector.bias", 31 | ], f"{name} not in base model" 32 | continue 33 | if param.data.shape == base.state_dict()[name].shape: 34 | param.data -= base.state_dict()[name] 35 | else: 36 | assert name in [ 37 | "model.embed_tokens.weight", 38 | "lm_head.weight", 39 | ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 40 | bparam = base.state_dict()[name] 41 | param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam 42 | 43 | print("Saving delta") 44 | if hub_repo_id: 45 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 46 | else: 47 | kwargs = {} 48 | target.save_pretrained(delta_path, **kwargs) 49 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 50 | target_tokenizer.save_pretrained(delta_path, 
**kwargs) 51 | 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--base-model-path", type=str, required=True) 56 | parser.add_argument("--target-model-path", type=str, required=True) 57 | parser.add_argument("--delta-path", type=str, required=True) 58 | parser.add_argument("--hub-repo-id", type=str, default=None) 59 | args = parser.parse_args() 60 | 61 | make_delta( 62 | args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id 63 | ) 64 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | from .clip_encoder import CLIPVisionTower 2 | 3 | 4 | def build_vision_tower(vision_tower_cfg, **kwargs): 5 | vision_tower = getattr( 6 | vision_tower_cfg, 7 | "mm_vision_tower", 8 | getattr(vision_tower_cfg, "vision_tower", None), 9 | ) 10 | if ( 11 | vision_tower.startswith("openai") 12 | or vision_tower.startswith("laion") 13 | or "clip" in vision_tower 14 | ): 15 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 16 | 17 | raise ValueError(f"Unknown vision tower: {vision_tower}") 18 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel 4 | 5 | 6 | class CLIPVisionTower(nn.Module): 7 | def __init__(self, vision_tower, args, delay_load=False): 8 | super().__init__() 9 | 10 | self.is_loaded = False 11 | 12 | self.vision_tower_name = vision_tower 13 | self.select_layer = args.mm_vision_select_layer 14 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 15 | 16 | if not delay_load: 17 | self.load_model() 18 | else: 19 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 20 | 21 | def load_model(self): 22 | self.image_processor = CLIPImageProcessor.from_pretrained( 23 | self.vision_tower_name 24 | ) 25 | self.vision_tower = CLIPVisionModel.from_pretrained( 26 | self.vision_tower_name, low_cpu_mem_usage=True 27 | ) 28 | self.vision_tower.requires_grad_(False) 29 | self.is_loaded = True 30 | 31 | def feature_select(self, image_forward_outs): 32 | image_features = image_forward_outs.hidden_states[self.select_layer] 33 | if self.select_feature == "patch": 34 | image_features = image_features[:, 1:] 35 | elif self.select_feature == "cls_patch": 36 | image_features = image_features 37 | else: 38 | raise ValueError(f"Unexpected select feature: {self.select_feature}") 39 | return image_features 40 | 41 | @torch.no_grad() 42 | def forward(self, images): 43 | if type(images) is list: 44 | image_features = [] 45 | for image in images: 46 | image_forward_out = self.vision_tower( 47 | image.to(device=self.device, dtype=self.dtype).unsqueeze(0), 48 | output_hidden_states=True, 49 | ) 50 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 51 | image_features.append(image_feature) 52 | else: 53 | image_forward_outs = self.vision_tower( 54 | images.to(device=self.device, dtype=self.dtype), 55 | output_hidden_states=True, 56 | ) 57 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 58 | 59 | torch.cuda.empty_cache() 60 | return image_features 61 | 62 | @property 63 | def 
dummy_feature(self): 64 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 65 | 66 | @property 67 | def dtype(self): 68 | return self.vision_tower.dtype 69 | 70 | @property 71 | def device(self): 72 | return self.vision_tower.device 73 | 74 | @property 75 | def config(self): 76 | if self.is_loaded: 77 | return self.vision_tower.config 78 | else: 79 | return self.cfg_only 80 | 81 | @property 82 | def hidden_size(self): 83 | return self.config.hidden_size 84 | 85 | @property 86 | def num_patches(self): 87 | return (self.config.image_size // self.config.patch_size) ** 2 88 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if "llava" in config and "llava" not in cfg.model_type: 7 | assert cfg.model_type == "llama" 8 | print( 9 | "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base." 10 | ) 11 | print( 12 | "You must upgrade the checkpoint to the new code base (this can be done automatically)." 13 | ) 14 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 15 | if confirm.lower() in ["y", "yes"]: 16 | print("Upgrading checkpoint...") 17 | assert len(cfg.architectures) == 1 18 | setattr(cfg.__class__, "model_type", "llava") 19 | cfg.architectures[0] = "LlavaLlamaForCausalLM" 20 | cfg.save_pretrained(config) 21 | print("Checkpoint upgraded.") 22 | else: 23 | print("Checkpoint upgrade aborted.") 24 | exit(1) 25 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional, Tuple 3 | 4 | import torch 5 | import transformers 6 | from einops import rearrange 7 | from torch import nn 8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 9 | 10 | try: 11 | from flash_attn.flash_attn_interface import \ 12 | flash_attn_unpadded_qkvpacked_func 13 | except ImportError: 14 | from flash_attn.flash_attn_interface import ( 15 | flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func, 16 | ) 17 | 18 | from flash_attn.bert_padding import pad_input, unpad_input 19 | 20 | 21 | def forward( 22 | self, 23 | hidden_states: torch.Tensor, 24 | attention_mask: Optional[torch.Tensor] = None, 25 | position_ids: Optional[torch.Tensor] = None, 26 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 27 | output_attentions: bool = False, 28 | use_cache: bool = False, 29 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 30 | """Input shape: Batch x Time x Channel 31 | 32 | attention_mask: [bsz, q_len] 33 | """ 34 | bsz, q_len, _ = hidden_states.size() 35 | 36 | query_states = ( 37 | self.q_proj(hidden_states) 38 | .view(bsz, q_len, self.num_heads, self.head_dim) 39 | .transpose(1, 2) 40 | ) 41 | key_states = ( 42 | self.k_proj(hidden_states) 43 | .view(bsz, q_len, self.num_heads, self.head_dim) 44 | .transpose(1, 2) 45 | ) 46 | value_states = ( 47 | self.v_proj(hidden_states) 48 | .view(bsz, q_len, self.num_heads, self.head_dim) 49 | .transpose(1, 2) 50 | ) 51 | # [bsz, q_len, nh, hd] 52 | # [bsz, nh, q_len, hd] 53 | 54 | kv_seq_len = key_states.shape[-2] 55 | 
assert past_key_value is None, "past_key_value is not supported" 56 | 57 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 58 | query_states, key_states = apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | assert not output_attentions, "output_attentions is not supported" 63 | assert not use_cache, "use_cache is not supported" 64 | 65 | # Flash attention codes from 66 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 67 | 68 | # transform the data into the format required by flash attention 69 | qkv = torch.stack( 70 | [query_states, key_states, value_states], dim=2 71 | ) # [bsz, nh, 3, q_len, hd] 72 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 73 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 74 | # the attention_mask should be the same as the key_padding_mask 75 | key_padding_mask = attention_mask 76 | 77 | if key_padding_mask is None: 78 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 79 | max_s = q_len 80 | cu_q_lens = torch.arange( 81 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 82 | ) 83 | output = flash_attn_unpadded_qkvpacked_func( 84 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 85 | ) 86 | output = rearrange(output, "(b s) ... -> b s ...", b=bsz) 87 | else: 88 | nheads = qkv.shape[-2] 89 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 90 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 91 | x_unpad = rearrange( 92 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 93 | ) 94 | output_unpad = flash_attn_unpadded_qkvpacked_func( 95 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 96 | ) 97 | output = rearrange( 98 | pad_input( 99 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 100 | ), 101 | "b s (h d) -> b s h d", 102 | h=nheads, 103 | ) 104 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None 105 | 106 | 107 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 108 | # requires the attention mask to be the same as the key_padding_mask 109 | def _prepare_decoder_attention_mask( 110 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 111 | ): 112 | # [bsz, seq_len] 113 | return attention_mask 114 | 115 | 116 | def replace_llama_attn_with_flash_attn(): 117 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 118 | if cuda_major < 8: 119 | logging.warning( 120 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
121 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 122 | ) 123 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 124 | _prepare_decoder_attention_mask 125 | ) 126 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 127 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/train/llava_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | import torch 5 | from transformers import Trainer 6 | 7 | 8 | def maybe_zero_3(param, ignore_status=False, name=None): 9 | from deepspeed import zero 10 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 11 | 12 | if hasattr(param, "ds_id"): 13 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 14 | if not ignore_status: 15 | print(name, "no ignore status") 16 | with zero.GatheredParameters([param]): 17 | param = param.data.detach().cpu().clone() 18 | else: 19 | param = param.detach().cpu().clone() 20 | return param 21 | 22 | 23 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 24 | to_return = { 25 | k: t 26 | for k, t in named_params 27 | if any(key_match in k for key_match in keys_to_match) 28 | } 29 | to_return = { 30 | k: maybe_zero_3(v, ignore_status=True, name=k).cpu() 31 | for k, v in to_return.items() 32 | } 33 | return to_return 34 | 35 | 36 | class LLaVATrainer(Trainer): 37 | def _save_checkpoint(self, model, trial, metrics=None): 38 | if getattr(self.args, "tune_mm_mlp_adapter", False): 39 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 40 | 41 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 42 | 43 | run_dir = self._get_output_dir(trial=trial) 44 | output_dir = os.path.join(run_dir, checkpoint_folder) 45 | 46 | # Only save Adapter 47 | keys_to_match = ["mm_projector"] 48 | if getattr(self.args, "use_im_start_end", False): 49 | keys_to_match.extend(["embed_tokens", "embed_in"]) 50 | 51 | weight_to_save = get_mm_adapter_state_maybe_zero_3( 52 | self.model.named_parameters(), keys_to_match 53 | ) 54 | 55 | if self.args.local_rank == 0 or self.args.local_rank == -1: 56 | self.model.config.save_pretrained(output_dir) 57 | torch.save( 58 | weight_to_save, os.path.join(output_dir, f"mm_projector.bin") 59 | ) 60 | else: 61 | super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) 62 | 63 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 64 | if getattr(self.args, "tune_mm_mlp_adapter", False): 65 | pass 66 | else: 67 | super(LLaVATrainer, self)._save(output_dir, state_dict) 68 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 
6 | from llava.train.llama_flash_attn_monkey_patch import \ 7 | replace_llama_attn_with_flash_attn 8 | 9 | replace_llama_attn_with_flash_attn() 10 | 11 | from llava.train.train import train 12 | 13 | if __name__ == "__main__": 14 | train() 15 | -------------------------------------------------------------------------------- /VideoGLaMM/model/llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | from llava.constants import LOGDIR 9 | 10 | server_error_msg = ( 11 | "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | ) 13 | moderation_msg = ( 14 | "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 15 | ) 16 | 17 | handler = None 18 | 19 | 20 | def build_logger(logger_name, logger_filename): 21 | global handler 22 | 23 | formatter = logging.Formatter( 24 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 25 | datefmt="%Y-%m-%d %H:%M:%S", 26 | ) 27 | 28 | # Set the format of root handlers 29 | if not logging.getLogger().handlers: 30 | logging.basicConfig(level=logging.INFO) 31 | logging.getLogger().handlers[0].setFormatter(formatter) 32 | 33 | # Redirect stdout and stderr to loggers 34 | stdout_logger = logging.getLogger("stdout") 35 | stdout_logger.setLevel(logging.INFO) 36 | sl = StreamToLogger(stdout_logger, logging.INFO) 37 | sys.stdout = sl 38 | 39 | stderr_logger = logging.getLogger("stderr") 40 | stderr_logger.setLevel(logging.ERROR) 41 | sl = StreamToLogger(stderr_logger, logging.ERROR) 42 | sys.stderr = sl 43 | 44 | # Get logger 45 | logger = logging.getLogger(logger_name) 46 | logger.setLevel(logging.INFO) 47 | 48 | # Add a file handler for all loggers 49 | if handler is None: 50 | os.makedirs(LOGDIR, exist_ok=True) 51 | filename = os.path.join(LOGDIR, logger_filename) 52 | handler = logging.handlers.TimedRotatingFileHandler( 53 | filename, when="D", utc=True 54 | ) 55 | handler.setFormatter(formatter) 56 | 57 | for name, item in logging.root.manager.loggerDict.items(): 58 | if isinstance(item, logging.Logger): 59 | item.addHandler(handler) 60 | 61 | return logger 62 | 63 | 64 | class StreamToLogger(object): 65 | """ 66 | Fake file-like stream object that redirects writes to a logger instance. 67 | """ 68 | 69 | def __init__(self, logger, log_level=logging.INFO): 70 | self.terminal = sys.stdout 71 | self.logger = logger 72 | self.log_level = log_level 73 | self.linebuf = "" 74 | 75 | def __getattr__(self, attr): 76 | return getattr(self.terminal, attr) 77 | 78 | def write(self, buf): 79 | temp_linebuf = self.linebuf + buf 80 | self.linebuf = "" 81 | for line in temp_linebuf.splitlines(True): 82 | # From the io.TextIOWrapper docs: 83 | # On output, if newline is None, any '\n' characters written 84 | # are translated to the system default line separator. 85 | # By default sys.stdout.write() expects '\n' newlines and then 86 | # translates them so this is still cross platform. 87 | if line[-1] == "\n": 88 | self.logger.log(self.log_level, line.rstrip()) 89 | else: 90 | self.linebuf += line 91 | 92 | def flush(self): 93 | if self.linebuf != "": 94 | self.logger.log(self.log_level, self.linebuf.rstrip()) 95 | self.linebuf = "" 96 | 97 | 98 | def disable_torch_init(): 99 | """ 100 | Disable the redundant torch default initialization to accelerate model creation. 
101 | """ 102 | import torch 103 | 104 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 105 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 106 | 107 | 108 | def violates_moderation(text): 109 | """ 110 | Check whether the text violates OpenAI moderation API. 111 | """ 112 | url = "https://api.openai.com/v1/moderations" 113 | headers = { 114 | "Content-Type": "application/json", 115 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"], 116 | } 117 | text = text.replace("\n", "") 118 | data = "{" + '"input": ' + f'"{text}"' + "}" 119 | data = data.encode("utf-8") 120 | try: 121 | ret = requests.post(url, headers=headers, data=data, timeout=5) 122 | flagged = ret.json()["results"][0]["flagged"] 123 | except requests.exceptions.RequestException as e: 124 | flagged = False 125 | except KeyError as e: 126 | flagged = False 127 | 128 | return flagged 129 | 130 | 131 | def pretty_print_semaphore(semaphore): 132 | if semaphore is None: 133 | return "None" 134 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 135 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .automatic_mask_generator import SamAutomaticMaskGenerator 8 | from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h, 9 | build_sam_vit_l, sam_model_registry) 10 | from .predictor import SamPredictor 11 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from functools import partial 8 | 9 | import torch 10 | 11 | from .modeling import (ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, 12 | TwoWayTransformer, 13 | CustomMaskDecoder) 14 | 15 | 16 | def build_sam_vit_h(checkpoint=None, with_itm=False): 17 | if with_itm: 18 | return _build_sam( 19 | encoder_embed_dim=1280, 20 | encoder_depth=32, 21 | encoder_num_heads=16, 22 | encoder_global_attn_indexes=[7, 15, 23, 31], 23 | checkpoint=checkpoint, 24 | with_itm=True 25 | ) 26 | return _build_sam( 27 | encoder_embed_dim=1280, 28 | encoder_depth=32, 29 | encoder_num_heads=16, 30 | encoder_global_attn_indexes=[7, 15, 23, 31], 31 | checkpoint=checkpoint, 32 | ) 33 | 34 | 35 | build_sam = build_sam_vit_h 36 | 37 | 38 | def build_sam_vit_l(checkpoint=None): 39 | return _build_sam( 40 | encoder_embed_dim=1024, 41 | encoder_depth=24, 42 | encoder_num_heads=16, 43 | encoder_global_attn_indexes=[5, 11, 17, 23], 44 | checkpoint=checkpoint, 45 | ) 46 | 47 | 48 | def build_sam_vit_b(checkpoint=None): 49 | return _build_sam( 50 | encoder_embed_dim=768, 51 | encoder_depth=12, 52 | encoder_num_heads=12, 53 | encoder_global_attn_indexes=[2, 5, 8, 11], 54 | checkpoint=checkpoint, 55 | ) 56 | 57 | 58 | sam_model_registry = { 59 | "default": build_sam_vit_h, 60 | "vit_h": build_sam_vit_h, 61 | "vit_l": build_sam_vit_l, 62 | "vit_b": build_sam_vit_b, 63 | } 64 | 65 | 66 | def _build_sam( 67 | encoder_embed_dim, 68 | encoder_depth, 69 | encoder_num_heads, 70 | encoder_global_attn_indexes, 71 | checkpoint=None, 72 | with_itm=False, 73 | ): 74 | prompt_embed_dim = 256 75 | image_size = 1024 76 | vit_patch_size = 16 77 | image_embedding_size = image_size // vit_patch_size 78 | sam = Sam( 79 | image_encoder=ImageEncoderViT( 80 | depth=encoder_depth, 81 | embed_dim=encoder_embed_dim, 82 | img_size=image_size, 83 | mlp_ratio=4, 84 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 85 | num_heads=encoder_num_heads, 86 | patch_size=vit_patch_size, 87 | qkv_bias=True, 88 | use_rel_pos=True, 89 | global_attn_indexes=encoder_global_attn_indexes, 90 | window_size=14, 91 | out_chans=prompt_embed_dim, 92 | ), 93 | prompt_encoder=PromptEncoder( 94 | embed_dim=prompt_embed_dim, 95 | image_embedding_size=(image_embedding_size, image_embedding_size), 96 | input_image_size=(image_size, image_size), 97 | mask_in_chans=16, 98 | ), 99 | 100 | mask_decoder= CustomMaskDecoder( #NOTE: replace with CustomMaskDecoder 101 | num_multimask_outputs=3, 102 | transformer=TwoWayTransformer( 103 | depth=2, 104 | embedding_dim=prompt_embed_dim, 105 | mlp_dim=2048, 106 | num_heads=8, 107 | ), 108 | transformer_dim=prompt_embed_dim, 109 | iou_head_depth=3, 110 | iou_head_hidden_dim=256, 111 | ) if with_itm else 112 | MaskDecoder( 113 | num_multimask_outputs=3, 114 | transformer=TwoWayTransformer( 115 | depth=2, 116 | embedding_dim=prompt_embed_dim, 117 | mlp_dim=2048, 118 | num_heads=8, 119 | ), 120 | transformer_dim=prompt_embed_dim, 121 | iou_head_depth=3, 122 | iou_head_hidden_dim=256, 123 | ), 124 | 125 | pixel_mean=[123.675, 116.28, 103.53], 126 | pixel_std=[58.395, 57.12, 57.375], 127 | ) 128 | sam.eval() 129 | if checkpoint is not None: 130 | with open(checkpoint, "rb") as f: 131 | state_dict = torch.load(f) 132 | sam.load_state_dict(state_dict, strict=False) 133 | return sam 134 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, 
Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .image_encoder import ImageEncoderViT 8 | from .mask_decoder import MaskDecoder, CustomMaskDecoder 9 | from .prompt_encoder import PromptEncoder 10 | from .sam import Sam 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Type 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from copy import deepcopy 8 | from typing import Tuple 9 | 10 | import numpy as np 11 | import torch 12 | from torch.nn import functional as F 13 | from torchvision.transforms.functional import resize # type: ignore 14 | from torchvision.transforms.functional import to_pil_image 15 | 16 | 17 | class ResizeLongestSide: 18 | """ 19 | Resizes images to the longest side 'target_length', as well as provides 20 | methods for resizing coordinates and boxes. 
Provides methods for 21 | transforming both numpy array and batched torch tensors. 22 | """ 23 | 24 | def __init__(self, target_length: int) -> None: 25 | self.target_length = target_length 26 | 27 | def apply_image(self, image: np.ndarray) -> np.ndarray: 28 | """ 29 | Expects a numpy array with shape HxWxC in uint8 format. 30 | """ 31 | target_size = self.get_preprocess_shape( 32 | image.shape[0], image.shape[1], self.target_length 33 | ) 34 | return np.array(resize(to_pil_image(image), target_size)) 35 | 36 | def apply_coords( 37 | self, coords: np.ndarray, original_size: Tuple[int, ...] 38 | ) -> np.ndarray: 39 | """ 40 | Expects a numpy array of length 2 in the final dimension. Requires the 41 | original image size in (H, W) format. 42 | """ 43 | old_h, old_w = original_size 44 | new_h, new_w = self.get_preprocess_shape( 45 | original_size[0], original_size[1], self.target_length 46 | ) 47 | coords = deepcopy(coords).astype(float) 48 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 49 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 50 | return coords 51 | 52 | def apply_boxes( 53 | self, boxes: np.ndarray, original_size: Tuple[int, ...] 54 | ) -> np.ndarray: 55 | """ 56 | Expects a numpy array shape Bx4. Requires the original image size 57 | in (H, W) format. 58 | """ 59 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 60 | return boxes.reshape(-1, 4) 61 | 62 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 63 | """ 64 | Expects batched images with shape BxCxHxW and float format. This 65 | transformation may not exactly match apply_image. apply_image is 66 | the transformation expected by the model. 67 | """ 68 | # Expects an image in BCHW format. May not exactly match apply_image. 69 | target_size = self.get_preprocess_shape( 70 | image.shape[0], image.shape[1], self.target_length 71 | ) 72 | return F.interpolate( 73 | image, target_size, mode="bilinear", align_corners=False, antialias=True 74 | ) 75 | 76 | def apply_coords_torch( 77 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 78 | ) -> torch.Tensor: 79 | """ 80 | Expects a torch tensor with length 2 in the last dimension. Requires the 81 | original image size in (H, W) format. 82 | """ 83 | old_h, old_w = original_size 84 | new_h, new_w = self.get_preprocess_shape( 85 | original_size[0], original_size[1], self.target_length 86 | ) 87 | coords = deepcopy(coords).to(torch.float) 88 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 89 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 90 | return coords 91 | 92 | def apply_boxes_torch( 93 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 94 | ) -> torch.Tensor: 95 | """ 96 | Expects a torch tensor with shape Bx4. Requires the original image 97 | size in (H, W) format. 98 | """ 99 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 100 | return boxes.reshape(-1, 4) 101 | 102 | @staticmethod 103 | def get_preprocess_shape( 104 | oldh: int, oldw: int, long_side_length: int 105 | ) -> Tuple[int, int]: 106 | """ 107 | Compute the output size given input size and target long side length. 
108 | """ 109 | scale = long_side_length * 1.0 / max(oldh, oldw) 110 | newh, neww = oldh * scale, oldw * scale 111 | neww = int(neww + 0.5) 112 | newh = int(newh + 0.5) 113 | return (newh, neww) 114 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything_2/sam2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from hydra import initialize_config_module 9 | 10 | initialize_config_module("model/segment_anything_2/sam2_configs", version_base="1.2") 11 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything_2/sam2/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | 9 | import torch 10 | from hydra import compose 11 | from hydra.utils import instantiate 12 | from omegaconf import OmegaConf 13 | 14 | def build_sam2( 15 | config_file, 16 | ckpt_path=None, 17 | device="cuda", 18 | mode="eval", 19 | hydra_overrides_extra=[], 20 | apply_postprocessing=True, 21 | ): 22 | 23 | if apply_postprocessing: 24 | hydra_overrides_extra = hydra_overrides_extra.copy() 25 | hydra_overrides_extra += [ 26 | # dynamically fall back to multi-mask if the single mask is not stable 27 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 28 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 29 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 30 | ] 31 | # Read config and init model 32 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) 33 | OmegaConf.resolve(cfg) 34 | model = instantiate(cfg.model, _recursive_=True) 35 | _load_checkpoint(model, ckpt_path) 36 | if device: 37 | model = model.to(device) 38 | if mode == "eval": 39 | model.eval() 40 | return model 41 | 42 | 43 | def build_sam2_video_predictor( 44 | config_file, 45 | ckpt_path=None, 46 | device="cuda", 47 | mode="eval", 48 | hydra_overrides_extra=[], 49 | apply_postprocessing=True, 50 | ): 51 | hydra_overrides = [ 52 | "++model._target_=model.segment_anything_2.sam2.sam2_video_predictor.SAM2VideoPredictor", 53 | ] 54 | if apply_postprocessing: 55 | hydra_overrides_extra = hydra_overrides_extra.copy() 56 | hydra_overrides_extra += [ 57 | # dynamically fall back to multi-mask if the single mask is not stable 58 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 59 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 60 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 61 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking 62 | "++model.binarize_mask_from_pts_for_mem_enc=true", 63 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) 64 | "++model.fill_hole_area=8", 65 | ] 66 | hydra_overrides.extend(hydra_overrides_extra) 67 | 68 | # Read 
config and init model 69 | cfg = compose(config_name=config_file, overrides=hydra_overrides) 70 | OmegaConf.resolve(cfg) 71 | model = instantiate(cfg.model, _recursive_=True) 72 | _load_checkpoint(model, ckpt_path) 73 | if device: 74 | model = model.to(device) 75 | if mode == "eval": 76 | model.eval() 77 | return model 78 | 79 | 80 | # def _load_checkpoint(model, ckpt_path): 81 | # if ckpt_path is not None: 82 | # sd = torch.load(ckpt_path, map_location="cpu")["model"] 83 | # missing_keys, unexpected_keys = model.load_state_dict(sd) 84 | # if missing_keys: 85 | # logging.error(missing_keys) 86 | # raise RuntimeError() 87 | # if unexpected_keys: 88 | # logging.error(unexpected_keys) 89 | # raise RuntimeError() 90 | # logging.info("Loaded checkpoint sucessfully") 91 | 92 | def _load_checkpoint(model, ckpt_path): 93 | ''' 94 | load checkpoint from ckpt_path to model, while renaming 'gamma' to 'weight' in the state dict 95 | ''' 96 | if ckpt_path is not None: 97 | sd = torch.load(ckpt_path, map_location="cpu")["model"] 98 | 99 | # Rename 'gamma' to 'weight' in the state dict 100 | sd = {key.replace('.gamma', '.weight'): value for key, value in sd.items()} 101 | 102 | missing_keys, unexpected_keys = model.load_state_dict(sd) 103 | 104 | if missing_keys: 105 | logging.error(f"Missing keys: {missing_keys}") 106 | raise RuntimeError("Missing keys found in the state dict.") 107 | 108 | if unexpected_keys: 109 | logging.error(f"Unexpected keys: {unexpected_keys}") 110 | raise RuntimeError("Unexpected keys found in the state dict.") 111 | 112 | logging.info("Loaded checkpoint successfully.") 113 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything_2/sam2/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything_2/sam2/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything_2/sam2/modeling/backbones/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class ImageEncoder(nn.Module): 15 | def __init__( 16 | self, 17 | trunk: nn.Module, 18 | neck: nn.Module, 19 | scalp: int = 0, 20 | ): 21 | super().__init__() 22 | self.trunk = trunk 23 | self.neck = neck 24 | self.scalp = scalp 25 | assert ( 26 | self.trunk.channel_list == self.neck.backbone_channel_list 27 | ), f"Channel dims of trunk and neck do not match. 
Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" 28 | 29 | def forward(self, sample: torch.Tensor): 30 | # Forward through backbone 31 | features, pos = self.neck(self.trunk(sample)) 32 | if self.scalp > 0: 33 | # Discard the lowest resolution features 34 | features, pos = features[: -self.scalp], pos[: -self.scalp] 35 | 36 | src = features[-1] 37 | output = { 38 | "vision_features": src, 39 | "vision_pos_enc": pos, 40 | "backbone_fpn": features, 41 | } 42 | return output 43 | 44 | 45 | class FpnNeck(nn.Module): 46 | """ 47 | A modified variant of Feature Pyramid Network (FPN) neck 48 | (we remove output conv and also do bicubic interpolation similar to ViT 49 | pos embed interpolation) 50 | """ 51 | 52 | def __init__( 53 | self, 54 | position_encoding: nn.Module, 55 | d_model: int, 56 | backbone_channel_list: List[int], 57 | kernel_size: int = 1, 58 | stride: int = 1, 59 | padding: int = 0, 60 | fpn_interp_model: str = "bilinear", 61 | fuse_type: str = "sum", 62 | fpn_top_down_levels: Optional[List[int]] = None, 63 | ): 64 | """Initialize the neck 65 | :param trunk: the backbone 66 | :param position_encoding: the positional encoding to use 67 | :param d_model: the dimension of the model 68 | :param neck_norm: the normalization to use 69 | """ 70 | super().__init__() 71 | self.position_encoding = position_encoding 72 | self.convs = nn.ModuleList() 73 | self.backbone_channel_list = backbone_channel_list 74 | for dim in backbone_channel_list: 75 | current = nn.Sequential() 76 | current.add_module( 77 | "conv", 78 | nn.Conv2d( 79 | in_channels=dim, 80 | out_channels=d_model, 81 | kernel_size=kernel_size, 82 | stride=stride, 83 | padding=padding, 84 | ), 85 | ) 86 | 87 | self.convs.append(current) 88 | self.fpn_interp_model = fpn_interp_model 89 | assert fuse_type in ["sum", "avg"] 90 | self.fuse_type = fuse_type 91 | 92 | # levels to have top-down features in its outputs 93 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 94 | # have top-down propagation, while outputs of level 0 and level 1 have only 95 | # lateral features from the same backbone level. 
96 | if fpn_top_down_levels is None: 97 | # default is to have top-down features on all levels 98 | fpn_top_down_levels = range(len(self.convs)) 99 | self.fpn_top_down_levels = list(fpn_top_down_levels) 100 | 101 | def forward(self, xs: List[torch.Tensor]): 102 | 103 | out = [None] * len(self.convs) 104 | pos = [None] * len(self.convs) 105 | assert len(xs) == len(self.convs) 106 | # fpn forward pass 107 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py 108 | prev_features = None 109 | # forward in top-down order (from low to high resolution) 110 | n = len(self.convs) - 1 111 | for i in range(n, -1, -1): 112 | x = xs[i] 113 | lateral_features = self.convs[n - i](x) 114 | if i in self.fpn_top_down_levels and prev_features is not None: 115 | top_down_features = F.interpolate( 116 | prev_features.to(dtype=torch.float32), 117 | scale_factor=2.0, 118 | mode=self.fpn_interp_model, 119 | align_corners=( 120 | None if self.fpn_interp_model == "nearest" else False 121 | ), 122 | antialias=False, 123 | ) 124 | prev_features = lateral_features + top_down_features 125 | if self.fuse_type == "avg": 126 | prev_features /= 2 127 | else: 128 | prev_features = lateral_features 129 | x_out = prev_features 130 | out[i] = x_out 131 | pos[i] = self.position_encoding(x_out).to(x_out.dtype) 132 | 133 | return out, pos 134 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything_2/sam2/modeling/backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some utilities for backbones, in particular for windowing""" 8 | 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def window_partition(x, window_size): 17 | """ 18 | Partition into non-overlapping windows with padding if needed. 19 | Args: 20 | x (tensor): input tokens with [B, H, W, C]. 21 | window_size (int): window size. 22 | Returns: 23 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 24 | (Hp, Wp): padded height and width before partition 25 | """ 26 | B, H, W, C = x.shape 27 | 28 | pad_h = (window_size - H % window_size) % window_size 29 | pad_w = (window_size - W % window_size) % window_size 30 | if pad_h > 0 or pad_w > 0: 31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 32 | Hp, Wp = H + pad_h, W + pad_w 33 | 34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 35 | windows = ( 36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 37 | ) 38 | return windows, (Hp, Wp) 39 | 40 | 41 | def window_unpartition(windows, window_size, pad_hw, hw): 42 | """ 43 | Window unpartition into original sequences and removing padding. 44 | Args: 45 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 46 | window_size (int): window size. 47 | pad_hw (Tuple): padded height and width (Hp, Wp). 48 | hw (Tuple): original height and width (H, W) before padding. 49 | Returns: 50 | x: unpartitioned sequences with [B, H, W, C]. 
51 | """ 52 | Hp, Wp = pad_hw 53 | H, W = hw 54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 55 | x = windows.view( 56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1 57 | ) 58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 59 | 60 | if Hp > H or Wp > W: 61 | x = x[:, :H, :W, :].contiguous() 62 | return x 63 | 64 | 65 | class PatchEmbed(nn.Module): 66 | """ 67 | Image to Patch Embedding. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | kernel_size: Tuple[int, ...] = (7, 7), 73 | stride: Tuple[int, ...] = (4, 4), 74 | padding: Tuple[int, ...] = (3, 3), 75 | in_chans: int = 3, 76 | embed_dim: int = 768, 77 | ): 78 | """ 79 | Args: 80 | kernel_size (Tuple): kernel size of the projection layer. 81 | stride (Tuple): stride of the projection layer. 82 | padding (Tuple): padding size of the projection layer. 83 | in_chans (int): Number of input image channels. 84 | embed_dim (int): embed_dim (int): Patch embedding dimension. 85 | """ 86 | super().__init__() 87 | self.proj = nn.Conv2d( 88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 89 | ) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | x = self.proj(x) 93 | # B C H W -> B H W C 94 | x = x.permute(0, 2, 3, 1) 95 | return x 96 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything_2/sam2/modeling/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything_2/sam2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything_2/sam2/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torchvision.transforms import Normalize, Resize, ToTensor 11 | 12 | 13 | class SAM2Transforms(nn.Module): 14 | def __init__( 15 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 16 | ): 17 | """ 18 | Transforms for SAM2. 
19 | """ 20 | super().__init__() 21 | self.resolution = resolution 22 | self.mask_threshold = mask_threshold 23 | self.max_hole_area = max_hole_area 24 | self.max_sprinkle_area = max_sprinkle_area 25 | self.mean = [0.485, 0.456, 0.406] 26 | self.std = [0.229, 0.224, 0.225] 27 | self.to_tensor = ToTensor() 28 | self.transforms = torch.jit.script( 29 | nn.Sequential( 30 | Resize((self.resolution, self.resolution)), 31 | Normalize(self.mean, self.std), 32 | ) 33 | ) 34 | 35 | def __call__(self, x): 36 | x = self.to_tensor(x) 37 | return self.transforms(x) 38 | 39 | def forward_batch(self, img_list): 40 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] 41 | img_batch = torch.stack(img_batch, dim=0) 42 | return img_batch 43 | 44 | def transform_coords( 45 | self, coords: torch.Tensor, normalize=False, orig_hw=None 46 | ) -> torch.Tensor: 47 | """ 48 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, 49 | If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 50 | 51 | Returns 52 | Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. 53 | """ 54 | if normalize: 55 | assert orig_hw is not None 56 | h, w = orig_hw 57 | coords = coords.clone() 58 | coords[..., 0] = coords[..., 0] / w 59 | coords[..., 1] = coords[..., 1] / h 60 | 61 | coords = coords * self.resolution # unnormalize coords 62 | return coords 63 | 64 | def transform_boxes( 65 | self, boxes: torch.Tensor, normalize=False, orig_hw=None 66 | ) -> torch.Tensor: 67 | """ 68 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, 69 | if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 70 | """ 71 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) 72 | return boxes 73 | 74 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: 75 | """ 76 | Perform PostProcessing on output masks. 77 | """ 78 | from model.segment_anything_2.sam2.utils.misc import get_connected_components 79 | 80 | masks = masks.float() 81 | if self.max_hole_area > 0: 82 | # Holes are those connected components in background with area <= self.fill_hole_area 83 | # (background regions are those with mask scores <= self.mask_threshold) 84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image 85 | labels, areas = get_connected_components(mask_flat <= self.mask_threshold) 86 | is_hole = (labels > 0) & (areas <= self.max_hole_area) 87 | is_hole = is_hole.reshape_as(masks) 88 | # We fill holes with a small positive mask score (10.0) to change them to foreground. 89 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) 90 | 91 | if self.max_sprinkle_area > 0: 92 | labels, areas = get_connected_components(mask_flat > self.mask_threshold) 93 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) 94 | is_hole = is_hole.reshape_as(masks) 95 | # We fill holes with negative mask score (-10.0) to change them to background. 
96 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) 97 | 98 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) 99 | return masks 100 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything_2/sam2_configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /VideoGLaMM/model/segment_anything_2/sam2_configs/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: model.segment_anything_2.sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: model.segment_anything_2.sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [32, 32] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [32, 32] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 
85 |   image_size: 1024
86 |   # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87 |   sigmoid_scale_for_mem_enc: 20.0
88 |   sigmoid_bias_for_mem_enc: -10.0
89 |   use_mask_input_as_output_without_sam: true
90 |   # Memory
91 |   directly_add_no_mem_embed: true
92 |   # use high-resolution feature map in the SAM mask decoder
93 |   use_high_res_features_in_sam: true
94 |   # output 3 masks on the first click on initial conditioning frames
95 |   multimask_output_in_sam: true
96 |   # SAM heads
97 |   iou_prediction_use_sigmoid: True
98 |   # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99 |   use_obj_ptrs_in_encoder: true
100 |   add_tpos_enc_to_obj_ptrs: false
101 |   only_obj_ptrs_in_the_past_for_eval: true
102 |   # object occlusion prediction
103 |   pred_obj_scores: true
104 |   pred_obj_scores_mlp: true
105 |   fixed_no_obj_ptr: true
106 |   # multimask tracking settings
107 |   multimask_output_for_tracking: true
108 |   use_multimask_token_for_obj_ptr: true
109 |   multimask_min_pt_num: 0
110 |   multimask_max_pt_num: 1
111 |   use_mlp_for_obj_ptr_proj: true
112 |   # Compilation flag
113 |   compile_image_encoder: False
114 | 
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/sam2_configs/sam2_hiera_l.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | 
3 | # Model
4 | model:
5 |   _target_: model.segment_anything_2.sam2.modeling.sam2_base.SAM2Base
6 |   image_encoder:
7 |     _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.ImageEncoder
8 |     scalp: 1
9 |     trunk:
10 |       _target_: model.segment_anything_2.sam2.modeling.backbones.hieradet.Hiera
11 |       embed_dim: 144
12 |       num_heads: 2
13 |       stages: [2, 6, 36, 4]
14 |       global_att_blocks: [23, 33, 43]
15 |       window_pos_embed_bkg_spatial_size: [7, 7]
16 |       window_spec: [8, 4, 16, 8]
17 |     neck:
18 |       _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.FpnNeck
19 |       position_encoding:
20 |         _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine
21 |         num_pos_feats: 256
22 |         normalize: true
23 |         scale: null
24 |         temperature: 10000
25 |       d_model: 256
26 |       backbone_channel_list: [1152, 576, 288, 144]
27 |       fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28 |       fpn_interp_model: nearest
29 | 
30 |   memory_attention:
31 |     _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttention
32 |     d_model: 256
33 |     pos_enc_at_input: true
34 |     layer:
35 |       _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttentionLayer
36 |       activation: relu
37 |       dim_feedforward: 2048
38 |       dropout: 0.1
39 |       pos_enc_at_attn: false
40 |       self_attention:
41 |         _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention
42 |         rope_theta: 10000.0
43 |         feat_sizes: [32, 32]
44 |         embedding_dim: 256
45 |         num_heads: 1
46 |         downsample_rate: 1
47 |         dropout: 0.1
48 |       d_model: 256
49 |       pos_enc_at_cross_attn_keys: true
50 |       pos_enc_at_cross_attn_queries: false
51 |       cross_attention:
52 |         _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention
53 |         rope_theta: 10000.0
54 |         feat_sizes: [32, 32]
55 |         rope_k_repeat: True
56 |         embedding_dim: 256
57 |         num_heads: 1
58 |         downsample_rate: 1
59 |         dropout: 0.1
60 |         kv_in_dim: 64
61 |     num_layers: 4
62 | 
63 |   memory_encoder:
64 |       _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MemoryEncoder
65 |       out_dim: 64
66 |       position_encoding:
67 |         _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine
68 |         num_pos_feats: 64
69 |         normalize: true
70 |         scale: null
71 |         temperature: 10000
72 |       mask_downsampler:
73 |         _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MaskDownSampler
74 |         kernel_size: 3
75 |         stride: 2
76 |         padding: 1
77 |       fuser:
78 |         _target_: model.segment_anything_2.sam2.modeling.memory_encoder.Fuser
79 |         layer:
80 |           _target_: model.segment_anything_2.sam2.modeling.memory_encoder.CXBlock
81 |           dim: 256
82 |           kernel_size: 7
83 |           padding: 3
84 |           layer_scale_init_value: 1e-6
85 |           use_dwconv: True # depth-wise convs
86 |         num_layers: 2
87 | 
88 |   num_maskmem: 7
89 |   image_size: 1024
90 |   # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91 |   sigmoid_scale_for_mem_enc: 20.0
92 |   sigmoid_bias_for_mem_enc: -10.0
93 |   use_mask_input_as_output_without_sam: true
94 |   # Memory
95 |   directly_add_no_mem_embed: true
96 |   # use high-resolution feature map in the SAM mask decoder
97 |   use_high_res_features_in_sam: true
98 |   # output 3 masks on the first click on initial conditioning frames
99 |   multimask_output_in_sam: true
100 |   # SAM heads
101 |   iou_prediction_use_sigmoid: True
102 |   # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 |   use_obj_ptrs_in_encoder: true
104 |   add_tpos_enc_to_obj_ptrs: false
105 |   only_obj_ptrs_in_the_past_for_eval: true
106 |   # object occlusion prediction
107 |   pred_obj_scores: true
108 |   pred_obj_scores_mlp: true
109 |   fixed_no_obj_ptr: true
110 |   # multimask tracking settings
111 |   multimask_output_for_tracking: true
112 |   use_multimask_token_for_obj_ptr: true
113 |   multimask_min_pt_num: 0
114 |   multimask_max_pt_num: 1
115 |   use_mlp_for_obj_ptr_proj: true
116 |   # Compilation flag
117 |   compile_image_encoder: False
118 | 
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/sam2_configs/sam2_hiera_s.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | 
3 | # Model
4 | model:
5 |   _target_: model.segment_anything_2.sam2.modeling.sam2_base.SAM2Base
6 |   image_encoder:
7 |     _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.ImageEncoder
8 |     scalp: 1
9 |     trunk:
10 |       _target_: model.segment_anything_2.sam2.modeling.backbones.hieradet.Hiera
11 |       embed_dim: 96
12 |       num_heads: 1
13 |       stages: [1, 2, 11, 2]
14 |       global_att_blocks: [7, 10, 13]
15 |       window_pos_embed_bkg_spatial_size: [7, 7]
16 |     neck:
17 |       _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.FpnNeck
18 |       position_encoding:
19 |         _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine
20 |         num_pos_feats: 256
21 |         normalize: true
22 |         scale: null
23 |         temperature: 10000
24 |       d_model: 256
25 |       backbone_channel_list: [768, 384, 192, 96]
26 |       fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 |       fpn_interp_model: nearest
28 | 
29 |   memory_attention:
30 |     _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttention
31 |     d_model: 256
32 |     pos_enc_at_input: true
33 |     layer:
34 |       _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttentionLayer
35 |       activation: relu
36 |       dim_feedforward: 2048
37 |       dropout: 0.1
38 |       pos_enc_at_attn: false
39 |       self_attention:
40 |         _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention
41 |         rope_theta: 10000.0
42 |         feat_sizes: [32, 32]
43 |         embedding_dim: 256
44 |         num_heads: 1
45 |         downsample_rate: 1
46 |         dropout: 0.1
47 |       d_model: 256
48 |       pos_enc_at_cross_attn_keys: true
49 |       pos_enc_at_cross_attn_queries: false
50 |       cross_attention:
51 |         _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention
52 |         rope_theta: 10000.0
53 |         feat_sizes: [32, 32]
54 |         rope_k_repeat: True
55 |         embedding_dim: 256
56 |         num_heads: 1
57 |         downsample_rate: 1
58 |         dropout: 0.1
59 |         kv_in_dim: 64
60 |     num_layers: 4
61 | 
62 |   memory_encoder:
63 |       _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MemoryEncoder
64 |       out_dim: 64
65 |       position_encoding:
66 |         _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine
67 |         num_pos_feats: 64
68 |         normalize: true
69 |         scale: null
70 |         temperature: 10000
71 |       mask_downsampler:
72 |         _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MaskDownSampler
73 |         kernel_size: 3
74 |         stride: 2
75 |         padding: 1
76 |       fuser:
77 |         _target_: model.segment_anything_2.sam2.modeling.memory_encoder.Fuser
78 |         layer:
79 |           _target_: model.segment_anything_2.sam2.modeling.memory_encoder.CXBlock
80 |           dim: 256
81 |           kernel_size: 7
82 |           padding: 3
83 |           layer_scale_init_value: 1e-6
84 |           use_dwconv: True # depth-wise convs
85 |         num_layers: 2
86 | 
87 |   num_maskmem: 7
88 |   image_size: 1024
89 |   # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 |   sigmoid_scale_for_mem_enc: 20.0
91 |   sigmoid_bias_for_mem_enc: -10.0
92 |   use_mask_input_as_output_without_sam: true
93 |   # Memory
94 |   directly_add_no_mem_embed: true
95 |   # use high-resolution feature map in the SAM mask decoder
96 |   use_high_res_features_in_sam: true
97 |   # output 3 masks on the first click on initial conditioning frames
98 |   multimask_output_in_sam: true
99 |   # SAM heads
100 |   iou_prediction_use_sigmoid: True
101 |   # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102 |   use_obj_ptrs_in_encoder: true
103 |   add_tpos_enc_to_obj_ptrs: false
104 |   only_obj_ptrs_in_the_past_for_eval: true
105 |   # object occlusion prediction
106 |   pred_obj_scores: true
107 |   pred_obj_scores_mlp: true
108 |   fixed_no_obj_ptr: true
109 |   # multimask tracking settings
110 |   multimask_output_for_tracking: true
111 |   use_multimask_token_for_obj_ptr: true
112 |   multimask_min_pt_num: 0
113 |   multimask_max_pt_num: 1
114 |   use_mlp_for_obj_ptr_proj: true
115 |   # Compilation flag
116 |   compile_image_encoder: False
117 | 
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/sam2_configs/sam2_hiera_t.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | 
3 | # Model
4 | model:
5 |   _target_: model.segment_anything_2.sam2.modeling.sam2_base.SAM2Base
6 |   image_encoder:
7 |     _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.ImageEncoder
8 |     scalp: 1
9 |     trunk:
10 |       _target_: model.segment_anything_2.sam2.modeling.backbones.hieradet.Hiera
11 |       embed_dim: 96
12 |       num_heads: 1
13 |       stages: [1, 2, 7, 2]
14 |       global_att_blocks: [5, 7, 9]
15 |       window_pos_embed_bkg_spatial_size: [7, 7]
16 |     neck:
17 |       _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.FpnNeck
18 |       position_encoding:
19 |         _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine
20 |         num_pos_feats: 256
21 |         normalize: true
22 |         scale: null
23 |         temperature: 10000
24 |       d_model: 256
25 |       backbone_channel_list: [768, 384, 192, 96]
26 |       fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 |       fpn_interp_model: nearest
28 | 
29 |   memory_attention:
30 |     _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttention
31 |     d_model: 256
32 |     pos_enc_at_input: true
33 |     layer:
34 |       _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttentionLayer
35 |       activation: relu
36 |       dim_feedforward: 2048
37 |       dropout: 0.1
38 |       pos_enc_at_attn: false
39 |       self_attention:
40 |         _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention
41 |         rope_theta: 10000.0
42 |         feat_sizes: [32, 32]
43 |         embedding_dim: 256
44 |         num_heads: 1
45 |         downsample_rate: 1
46 |         dropout: 0.1
47 |       d_model: 256
48 |       pos_enc_at_cross_attn_keys: true
49 |       pos_enc_at_cross_attn_queries: false
50 |       cross_attention:
51 |         _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention
52 |         rope_theta: 10000.0
53 |         feat_sizes: [32, 32]
54 |         rope_k_repeat: True
55 |         embedding_dim: 256
56 |         num_heads: 1
57 |         downsample_rate: 1
58 |         dropout: 0.1
59 |         kv_in_dim: 64
60 |     num_layers: 4
61 | 
62 |   memory_encoder:
63 |       _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MemoryEncoder
64 |       out_dim: 64
65 |       position_encoding:
66 |         _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine
67 |         num_pos_feats: 64
68 |         normalize: true
69 |         scale: null
70 |         temperature: 10000
71 |       mask_downsampler:
72 |         _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MaskDownSampler
73 |         kernel_size: 3
74 |         stride: 2
75 |         padding: 1
76 |       fuser:
77 |         _target_: model.segment_anything_2.sam2.modeling.memory_encoder.Fuser
78 |         layer:
79 |           _target_: model.segment_anything_2.sam2.modeling.memory_encoder.CXBlock
80 |           dim: 256
81 |           kernel_size: 7
82 |           padding: 3
83 |           layer_scale_init_value: 1e-6
84 |           use_dwconv: True # depth-wise convs
85 |         num_layers: 2
86 | 
87 |   num_maskmem: 7
88 |   image_size: 1024
89 |   # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 |   # SAM decoder
91 |   sigmoid_scale_for_mem_enc: 20.0
92 |   sigmoid_bias_for_mem_enc: -10.0
93 |   use_mask_input_as_output_without_sam: true
94 |   # Memory
95 |   directly_add_no_mem_embed: true
96 |   # use high-resolution feature map in the SAM mask decoder
97 |   use_high_res_features_in_sam: true
98 |   # output 3 masks on the first click on initial conditioning frames
99 |   multimask_output_in_sam: true
100 |   # SAM heads
101 |   iou_prediction_use_sigmoid: True
102 |   # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 |   use_obj_ptrs_in_encoder: true
104 |   add_tpos_enc_to_obj_ptrs: false
105 |   only_obj_ptrs_in_the_past_for_eval: true
106 |   # object occlusion prediction
107 |   pred_obj_scores: true
108 |   pred_obj_scores_mlp: true
109 |   fixed_no_obj_ptr: true
110 |   # multimask tracking settings
111 |   multimask_output_for_tracking: true
112 |   use_multimask_token_for_obj_ptr: true
113 |   multimask_min_pt_num: 0
114 |   multimask_max_pt_num: 1
115 |   use_mlp_for_obj_ptr_proj: true
116 |   # Compilation flag
117 |   # HieraT does not currently support compilation, should always be set to False
118 |   compile_image_encoder: False
119 | 
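The four `sam2_hiera_*.yaml` configs above differ only in the Hiera trunk (its `embed_dim`, `num_heads`, `stages`, global-attention blocks, and window sizes) and the matching FPN `backbone_channel_list`; the memory attention, memory encoder, and SAM-head settings are identical across variants. Each `_target_` entry is a Hydra target, so building a model is recursive instantiation of the `model` node. The sketch below illustrates that flow only; the config-module path, checkpoint filename, and `"model"` checkpoint key are assumptions, and in practice the builder under `model/segment_anything_2/sam2/build_sam.py` should be preferred.

```python
# Illustrative sketch (not the repo's own API): resolve one of the sam2_configs
# YAML files into a SAM2Base instance via Hydra. The config-module path and
# checkpoint name below are assumptions for demonstration.
import torch
from hydra import compose, initialize_config_module
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Make the sam2_configs package (it ships an __init__.py) visible to Hydra.
with initialize_config_module(
    config_module="model.segment_anything_2.sam2_configs", version_base="1.2"
):
    cfg = compose(config_name="sam2_hiera_t.yaml")  # tiny variant as an example

OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True)  # builds SAM2Base and all submodules

# Optionally load pretrained weights (checkpoint path/key are hypothetical here).
state = torch.load("checkpoints/sam2_hiera_tiny.pt", map_location="cpu")
model.load_state_dict(state["model"], strict=False)
model.eval()
```

Because Hydra passes every sibling key of `_target_` straight to the named class as a keyword argument, each YAML block mirrors the constructor signature of the class it instantiates (e.g. `scalp: 1` becomes `ImageEncoder(scalp=1, ...)`).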
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from setuptools import find_packages, setup
8 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
9 | 
10 | def get_extensions():
11 |     srcs = ["sam2/csrc/connected_components.cu"]
12 |     compile_args = {
13 |         "cxx": [],
14 |         "nvcc": [
15 |             "-DCUDA_HAS_FP16=1",
16 |             "-D__CUDA_NO_HALF_OPERATORS__",
17 |             "-D__CUDA_NO_HALF_CONVERSIONS__",
18 |             "-D__CUDA_NO_HALF2_OPERATORS__",
19 |         ],
20 |     }
21 |     ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
22 |     return ext_modules
23 | 
24 | 
25 | # Setup configuration
26 | setup(
27 |     ext_modules=get_extensions(),
28 |     cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
29 | )
30 | 
--------------------------------------------------------------------------------
/VideoGLaMM/model/videogpt_plus/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import VideoGPTPlusPhi3ForCausalLM
2 | from .model import VideoGPTPlusLlamaForCausalLM
3 | 
--------------------------------------------------------------------------------
/VideoGLaMM/model/videogpt_plus/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | from distutils.util import strtobool
3 | 
4 | # Configuration Constants
5 | # TODO: Change the chunk size if you use any other video encoder accordingly
6 | CHUNK_SIZE = 4  # Video chunk size for InternVideo2-Stage2_1B-224p-f4 which is trained using 4 frames per video
7 | NUM_FRAMES = int(os.environ.get("NUM_FRAMES", 16))  # Number of video frames (if using video)
8 | NUM_CONTEXT_IMAGES = int(os.environ.get("NUM_CONTEXT_IMAGES", 16))  # Number of context images for video
9 | 
10 | # Model Constants
11 | IGNORE_INDEX = -100
12 | IMAGE_TOKEN_INDEX = -200
13 | DEFAULT_IMAGE_TOKEN = "<image>"
14 | DEFAULT_VIDEO_TOKEN = "