├── Dataset.md
├── README.md
├── RUN_VideoGLaMM.md
├── Training.md
├── VideoGLaMM
├── .DS_Store
├── .gitignore
├── LICENSE
├── chat.py
├── eval_anet_entities_infer.py
├── eval_gcg_infer.py
├── eval_gcg_metrics.py
├── eval_grounding.py
├── eval_mevis.py
├── eval_referdavis_infer.py
├── eval_referdavis_metrics.py
├── gcg_data_gen
│ ├── .DS_Store
│ ├── anet_entities_gcg
│ │ ├── 1_dev_anet_entities_for_gcg.py
│ │ ├── 2_anet_entities_gcg_openai_refine.py
│ │ └── 3_anet_entities_gcg_extract_masks.py
│ ├── burst_ytvis_gcg
│ │ ├── README.md
│ │ ├── generate_annotations.py
│ │ ├── generation.py
│ │ ├── merge_b_y.py
│ │ └── requirements.txt
│ ├── dev_dataset_visualize.py
│ ├── hcstvg_gcg
│ │ ├── dev_hcstvg_2_gcg_captions.py
│ │ └── dev_hcstvg_2_mask_gen.py
│ ├── mevis_gcg
│ │ └── dev_mevis_gcg.py
│ ├── vidstg_gcg
│ │ ├── dev_vidstg_gcg_captions.py
│ │ └── dev_vidstg_gcg_mask_gen.py
│ └── ytvos_gcg
│ │ └── dev_ytvos_gcg.py
├── model
│ ├── .DS_Store
│ ├── VideoGLaMM.py
│ ├── chatunivi
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── conversation.py
│ │ ├── mm_utils.py
│ │ ├── model
│ │ │ ├── __init__.py
│ │ │ ├── arch.py
│ │ │ ├── builder.py
│ │ │ ├── cluster.py
│ │ │ ├── language_model
│ │ │ │ └── llama.py
│ │ │ └── multimodal_encoder
│ │ │ │ ├── builder.py
│ │ │ │ └── clip_encoder.py
│ │ └── utils.py
│ ├── llava
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── conversation.py
│ │ ├── mm_utils.py
│ │ ├── model
│ │ │ ├── __init__.py
│ │ │ ├── apply_delta.py
│ │ │ ├── builder.py
│ │ │ ├── consolidate.py
│ │ │ ├── language_model
│ │ │ │ ├── llava_llama.py
│ │ │ │ ├── llava_mpt.py
│ │ │ │ └── mpt
│ │ │ │ │ ├── adapt_tokenizer.py
│ │ │ │ │ ├── attention.py
│ │ │ │ │ ├── blocks.py
│ │ │ │ │ ├── configuration_mpt.py
│ │ │ │ │ ├── custom_embedding.py
│ │ │ │ │ ├── flash_attn_triton.py
│ │ │ │ │ ├── hf_prefixlm_converter.py
│ │ │ │ │ ├── meta_init_context.py
│ │ │ │ │ ├── modeling_mpt.py
│ │ │ │ │ ├── norm.py
│ │ │ │ │ └── param_init_fns.py
│ │ │ ├── llava_arch.py
│ │ │ ├── make_delta.py
│ │ │ ├── multimodal_encoder
│ │ │ │ ├── builder.py
│ │ │ │ └── clip_encoder.py
│ │ │ └── utils.py
│ │ ├── train
│ │ │ ├── llama_flash_attn_monkey_patch.py
│ │ │ ├── llava_trainer.py
│ │ │ ├── train.py
│ │ │ └── train_mem.py
│ │ └── utils.py
│ ├── segment_anything
│ │ ├── __init__.py
│ │ ├── automatic_mask_generator.py
│ │ ├── build_sam.py
│ │ ├── modeling
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── image_encoder.py
│ │ │ ├── mask_decoder.py
│ │ │ ├── prompt_encoder.py
│ │ │ ├── sam.py
│ │ │ └── transformer.py
│ │ ├── predictor.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── amg.py
│ │ │ ├── onnx.py
│ │ │ └── transforms.py
│ ├── segment_anything_2
│ │ ├── sam2
│ │ │ ├── __init__.py
│ │ │ ├── automatic_mask_generator.py
│ │ │ ├── build_sam.py
│ │ │ ├── csrc
│ │ │ │ └── connected_components.cu
│ │ │ ├── modeling
│ │ │ │ ├── __init__.py
│ │ │ │ ├── backbones
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── hieradet.py
│ │ │ │ │ ├── image_encoder.py
│ │ │ │ │ └── utils.py
│ │ │ │ ├── memory_attention.py
│ │ │ │ ├── memory_encoder.py
│ │ │ │ ├── position_encoding.py
│ │ │ │ ├── sam
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── mask_decoder.py
│ │ │ │ │ ├── prompt_encoder.py
│ │ │ │ │ └── transformer.py
│ │ │ │ ├── sam2_base.py
│ │ │ │ └── sam2_utils.py
│ │ │ ├── sam2_image_predictor.py
│ │ │ ├── sam2_video_predictor.py
│ │ │ └── utils
│ │ │ │ ├── __init__.py
│ │ │ │ ├── amg.py
│ │ │ │ ├── misc.py
│ │ │ │ └── transforms.py
│ │ ├── sam2_configs
│ │ │ ├── __init__.py
│ │ │ ├── sam2_hiera_b+.yaml
│ │ │ ├── sam2_hiera_l.yaml
│ │ │ ├── sam2_hiera_s.yaml
│ │ │ └── sam2_hiera_t.yaml
│ │ └── setup.py
│ └── videogpt_plus
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── conversation.py
│ │ ├── mm_utils.py
│ │ └── model
│ │ ├── __init__.py
│ │ ├── arch.py
│ │ ├── builder.py
│ │ ├── dataloader.py
│ │ ├── internvideo
│ │ ├── build_internvideo.py
│ │ ├── config.py
│ │ ├── easydict.py
│ │ ├── flash_attention_class.py
│ │ ├── internvideo2.py
│ │ ├── internvideo2_stage2_config_vision.py
│ │ ├── pos_embed.py
│ │ └── utils.py
│ │ ├── language_model
│ │ ├── llama3_1.py
│ │ └── phi3.py
│ │ ├── multimodal_encoder
│ │ ├── builder.py
│ │ ├── clip_encoder.py
│ │ └── processor.py
│ │ └── multimodal_projector
│ │ └── builder.py
├── requirements.txt
├── train_ds_with_videogptplus.py
└── utils
│ ├── .DS_Store
│ ├── __init__.py
│ ├── ade20k_classes.json
│ ├── clair.py
│ ├── cocostuff_classes.txt
│ ├── conv_generator.py
│ ├── conversation.py
│ ├── data_processing.py
│ ├── dataset.py
│ ├── enc_preprocessors.py
│ ├── grandf_dataset.py
│ ├── grefcoco.py
│ ├── grefer.py
│ ├── grounded_video_qa.py
│ ├── grounding_utils
│ ├── __init__.py
│ ├── box_ops.py
│ ├── image_transforms.py
│ └── misc.py
│ ├── hcstvg_dataset.py
│ ├── itm_transforms.py
│ ├── mevis_dataset.py
│ ├── mevis_gcg.py
│ ├── misc.py
│ ├── ordered_datasets
│ ├── ordered_mevis.py
│ └── ordered_rvos.py
│ ├── preproc_hcstvgv2.py
│ ├── preproc_vidstg.py
│ ├── reason_seg_dataset.py
│ ├── refer.py
│ ├── refer_datasets
│ ├── __init__.py
│ ├── a2d.py
│ ├── box_ops.py
│ ├── davis.py
│ ├── jhmdb.py
│ ├── mevis.py
│ ├── new
│ │ ├── davis17.py
│ │ └── ytvos.py
│ ├── transforms.py
│ └── ytvos.py
│ ├── refer_seg_dataset.py
│ ├── refer_vos_dataset.py
│ ├── sam_transforms.py
│ ├── sem_seg_dataset.py
│ ├── temporal_grounding_datasets.py
│ ├── trainer.py
│ ├── utils.py
│ ├── video_gcg_anet.py
│ ├── video_gcg_dataset.py
│ ├── video_vqa_dataset.py
│ ├── vidstg_dataset.py
│ ├── vidstg_hcstvg_gcg.py
│ ├── vqa_dataset.py
│ └── ytvos_gcg.py
└── docs
└── images
├── .DS_Store
├── figures
├── cvpr25-teaser.png
├── cvpr25_main_block_diagram-jpg.jpg
├── cvpr25_qualitative.png
└── videoglamm_annotation_pipeline.png
└── logos
├── IVAL_logo.png
├── MBZUAI_logo.png
├── Oryx_logo.png
└── logo-videoglamm.png
/Dataset.md:
--------------------------------------------------------------------------------
1 | ### **Datasets used for training VideoGLaMM**
2 |
3 | - **LISA datasets**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/Ed6NO_HzOtxHuLUtwU5llQoBNTKW-hWsat_ADhMPBhdrVA?e=eVFlLu)
4 | - **GranDf dataset**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EX4whuRa1NdEihJGAIL1j0MBvF8xvx22tX9D3g3lNhW_VQ?e=Y23PDA)
5 | - **Video datasets**:
6 | - **ActivityNet**: [Link](https://drive.google.com/file/d/1qW5bLQtOienpMjO7vkaghA3pCPXc1qj9/view?usp=sharing)
7 | - **ActivityNet Captions**: [Link](https://drive.google.com/file/d/13tlKsXTA7YLwN62h_OjxmNulWBjbweNF/view?usp=sharing)
8 | - **ActivityNet Entities**: [Link](https://drive.google.com/file/d/13uYeYjXNW9mvsLpuLcNCpp2Zle-HuaUS/view?usp=sharing)
9 | - **BURST**: [Link](https://drive.google.com/file/d/119syWknOhxX9HGedkerk9kQo6MHBBQVQ/view?usp=sharing)
10 | - **HC-STVG**: [Link](https://drive.google.com/file/d/1pzK3aP4bMfpUA1dzSXC9GCxypHgZA1aL/view?usp=sharing)
11 | - **MeViS**: [Link](https://drive.google.com/file/d/1uuE2IcD4UGpkFdD2MWrIlVVN48PVoXRT/view?usp=sharing)
12 | - **Processed**: [Link](https://drive.google.com/file/d/1Z16c1WgmoqsUa557ILIG2QBy0hit5Nhr/view?usp=sharing)
13 | - ActivityNet Entities
14 | - HC-STVG
15 | - Referring DAVIS
16 | - VideoInstruct100K
17 | - VidSTG
18 | - **Refer DAVIS**: [Link](https://drive.google.com/file/d/1B4uHyt3_KZIFs9bQobowkPg1y0tG0IuO/view?usp=sharing)
19 | - DAVIS 16
20 | - DAVIS 17
21 | - **Refer YouTube-VOS**: [Link](https://drive.google.com/file/d/1zApsra2fqGX8b3diSvhIfe7bBjZW9tJI/view?usp=sharing)
22 | - **VideoInstruct100K**: [Link](https://drive.google.com/file/d/1l6XKWbX40tGIG1K05iBW8q4QFfDA-2_1/view?usp=sharing)
23 | - **VidSTG**: [Link](https://drive.google.com/file/d/12INPWw_FAQcXkeIgGdm61vJ35tkJeGAF/view?usp=sharing)
24 | - **YTVIS**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EQIEnJmu1yhIhjkwRD8fVcYBRHaQFI9CmZDOoLRkmCXOBw?e=cbDr28)
25 |
26 | - **GCG Datasets**:
27 | - **ActivityNet Entities GCG**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EaG0sNQ--y1CjVf7WeYdahEBy2l6LQOvo_shVZqY22YRHg?e=r9itQ5)
28 | - **Burst-YTVIS GCG**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EThjSLl_aMhIka6S1KxiqEEBf9rUCKNbX9LVyg60rw6Urg?e=wyhMsC)
29 | - **Refer-YTVOS GCG**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EYy0FJi1PCxBiyudE-z9M6UBu-Ceae-mpjQ8w7aQ7c6KAA?e=VKUcoP)
30 | - **VidSTG GCG**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EQhWEdhvCX1OkSSYPcnX9KsBrlw1AeTSffUtiD8K7wsc8w?e=FIEqEA)
31 | - **HC-STVG GCG**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EaVtsayKs9ZIg83K36F9YC8B-7HiPa-SW3AXDT3-28m_Zw?e=H3dWQK)
32 | - **MeViS GCG**: [Link](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/shehan_munasinghe_mbzuai_ac_ae/EcMipuuIMx9AofwShTCzAB8BVtLiRDJoFjPDTNnY48gv8Q?e=6tTTBx)
33 |
34 |
35 | ### **File Structure**
36 |
37 | Extract **LISA dataset** and **GranDf dataset** under `./dataset`:
38 |
39 | dataset/
40 | ├── ade20k
41 | ├── coco
42 | ├── cocostuff
43 | ├── grandf_dataset
44 | ├── llava_dataset
45 | ├── mapillary
46 | ├── other
47 | ├── reason_seg
48 | ├── refer_seg
49 | └── vlpart
50 |
51 | Extract **other datasets** under `./video_dataset`:
52 |
53 | video_dataset/
54 | ├── activitynet
55 | ├── activitynet_captions
56 | ├── activitynet_entities
57 | ├── activitynet_entities_gcg
58 | ├── burst
59 | ├── hcstvg
60 | ├── hcstvg_gcg
61 | ├── mevis
62 | ├── mevis_gcg
63 | ├── processed
64 | ├── refer_davis
65 | ├── refer_youtube_vos
66 | ├── video_gcg
67 | ├── video_instruct_100k
68 | ├── vidstg
69 | ├── vidstg_gcg
70 | ├── ytvis
71 | └── ytvos_gcg
72 |
73 |
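As an optional sanity check, a small script like the following (a sketch only, using the directory names listed above) can be run from the repository root to verify that the expected folders are in place before training:

```python
from pathlib import Path

DATASET_DIRS = ["ade20k", "coco", "cocostuff", "grandf_dataset", "llava_dataset",
                "mapillary", "other", "reason_seg", "refer_seg", "vlpart"]
VIDEO_DATASET_DIRS = ["activitynet", "activitynet_captions", "activitynet_entities",
                      "activitynet_entities_gcg", "burst", "hcstvg", "hcstvg_gcg",
                      "mevis", "mevis_gcg", "processed", "refer_davis", "refer_youtube_vos",
                      "video_gcg", "video_instruct_100k", "vidstg", "vidstg_gcg",
                      "ytvis", "ytvos_gcg"]

def check(root: str, names: list[str]) -> None:
    # Report any expected sub-directory that is missing under `root`.
    missing = [n for n in names if not (Path(root) / n).is_dir()]
    print(f"{root}: {'OK' if not missing else 'missing: ' + ', '.join(missing)}")

check("./dataset", DATASET_DIRS)
check("./video_dataset", VIDEO_DATASET_DIRS)
```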
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VideoGLaMM: A Large Multimodal Model for Pixel-Level Visual Grounding in Videos [CVPR 2025🔥]
2 | 
3 |
4 | [Shehan Munasinghe](https://github.com/shehanmunasinghe), [Hanan Gani](https://github.com/hananshafi), [Wenqi Zhu](#), [Jiale Cao](https://jialecao001.github.io/), [Eric Xing](https://www.cs.cmu.edu/~epxing/), [Fahad Shahbaz Khan](https://scholar.google.es/citations?user=zvaeYnUAAAAJ&hl=en), and [Salman Khan](https://salman-h-khan.github.io/)
5 |
6 | **Mohamed bin Zayed University of Artificial Intelligence, Tianjin University,
7 | Linköping University, Australian National University, Carnegie Mellon University**
8 |
9 | [Project Page](https://mbzuai-oryx.github.io/VideoGLaMM/)
10 | [arXiv Paper](https://arxiv.org/abs/2411.04923)
11 |
12 | ---
13 |
14 | ## 📢 Latest Updates
15 |
16 | - **Feb-2025:** VideoGLaMM is accepted at CVPR 2025! 🎊🎊
17 |
18 | ---
19 |
20 | ## Overview
21 |
22 |
23 |
24 |
25 |
26 | VideoGLaMM is a large multimodal video model capable of pixel-level visual grounding. The model responds to natural language queries from the user and intertwines spatio-temporal object masks in its generated textual responses to provide a detailed understanding of video content. VideoGLaMM seamlessly connects three key components: a Large Language Model (LLM), dual vision encoders, and a spatio-temporal pixel decoder. The dual vision encoders extract spatial and temporal features separately, which are jointly passed to the LLM to output responses rich in both spatial and temporal cues. This is facilitated by end-to-end training on our proposed benchmark Grounded Conversation Generation (GCG) dataset featuring 38k video-QA triplets with 87k objects and 671k fine-grained masks.
27 |
28 | ---
29 | ## 🏆 Highlights
30 | 1. We introduce the Video Grounded Large Multimodal Model (VideoGLaMM), a video LMM capable of pixel-level visual grounding that features an end-to-end alignment mechanism.
31 |
32 | 2. To achieve fine-grained spatio-temporal alignment, we introduce a benchmark grounded conversation generation (GCG) dataset consisting of 38k grounded video-QA triplets, 83k objects, and roughly 671k fine-grained spatio-temporal masks.
33 |
34 | 3. We assess the performance of VideoGLaMM across diverse tasks spanning grounded conversation generation, visual grounding, and referring video segmentation, where it achieves state-of-the-art performance.
35 |
36 | ---
37 |
38 | ## Architecture
39 |
40 |
41 |
42 |
43 |
44 | VideoGLaMM consists of the following key components: (i) a Spatio-Temporal Dual Encoder, (ii) Dual Alignment V-L Adapters for image and video features, (iii) a Large Language Model (LLM), (iv) an L-V Adapter, and (v) a Promptable Pixel Decoder.
45 |
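For intuition, here is a minimal PyTorch sketch of how these components are wired together, with toy linear layers standing in for the real encoders, LLM, and decoder (all module names and dimensions are illustrative assumptions, not the implementation in `model/VideoGLaMM.py`):

```python
import torch
import torch.nn as nn

class VideoGLaMMSketch(nn.Module):
    """Toy stand-in showing the component wiring, not the real model."""
    def __init__(self, patch_dim=3 * 16 * 16, img_dim=1024, vid_dim=768, llm_dim=2048, dec_dim=256):
        super().__init__()
        self.image_encoder = nn.Linear(patch_dim, img_dim)   # spatial branch of the dual encoder
        self.video_encoder = nn.Linear(patch_dim, vid_dim)   # temporal branch of the dual encoder
        self.image_adapter = nn.Linear(img_dim, llm_dim)     # V-L adapter: image features -> LLM space
        self.video_adapter = nn.Linear(vid_dim, llm_dim)     # V-L adapter: video features -> LLM space
        self.llm = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(llm_dim, nhead=8, batch_first=True), num_layers=2)
        self.lv_adapter = nn.Linear(llm_dim, dec_dim)        # L-V adapter for segmentation-token states
        self.pixel_decoder = nn.Linear(dec_dim, 16 * 16)     # decoder stub: one coarse mask per token

    def forward(self, patches, text_embeds):
        # patches: (B, N, patch_dim) flattened video patches; text_embeds: (B, L, llm_dim)
        spatial = self.image_adapter(self.image_encoder(patches))
        temporal = self.video_adapter(self.video_encoder(patches))
        hidden = self.llm(torch.cat([spatial, temporal, text_embeds], dim=1))
        seg_state = self.lv_adapter(hidden[:, -1:])          # pretend the last token is a [SEG]-style token
        masks = self.pixel_decoder(seg_state).view(-1, 1, 16, 16)
        return hidden, masks

model = VideoGLaMMSketch()
hidden, masks = model(torch.randn(1, 8, 3 * 16 * 16), torch.randn(1, 4, 2048))
print(hidden.shape, masks.shape)  # torch.Size([1, 20, 2048]) torch.Size([1, 1, 16, 16])
```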
46 | ---
47 | ## Benchmark and Annotation Pipeline
48 |
49 |
50 |
51 |
52 |
53 | We propose a semi-automatic annotation pipeline for creating a grounded conversation generation (GCG) dataset for videos.
54 |
55 |
56 | ---
57 | ## Examples 🔍
58 |
59 | Given user queries, VideoGLaMM generates textual responses and grounds objects and phrases using pixel-level masks, showing its detailed understanding of the video.
60 |
61 |
62 |
63 |
64 |
65 | ---
66 |
67 | ## Running VideoGLaMM 🔧
68 |
69 | ### Environment setup
70 |
71 | conda create --name=videoglamm python=3.11
72 |
73 | conda activate videoglamm
74 |
75 | pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121
76 | pip install transformers==4.41.0
77 | DS_BUILD_FUSED_ADAM=1 pip install deepspeed==0.14.0
78 |
79 | pip install -r VideoGLaMM/requirements_sam2_cluster.txt
80 |
81 | cd VideoGLaMM/model/segment_anything_2
82 | python setup.py build_ext --inplace
83 | cd ../../..
84 |
85 | ### Training and Evaluation
86 |
87 | Please refer to the instructions [here](RUN_VideoGLaMM.md).
88 |
89 |
90 | ## Citation 📜
91 |
92 | ```bibtex
93 | @article{munasinghe2024videoglamm,
94 | title={VideoGLaMM: A Large Multimodal Model for Pixel-Level Visual Grounding in Videos},
95 | author={Shehan Munasinghe and Hanan Gani and Wenqi Zhu and Jiale Cao and Eric Xing and Fahad Khan and Salman Khan},
96 | journal={ArXiv},
97 | year={2024},
98 | url={https://arxiv.org/abs/2411.04923}
99 | }
100 | ```
101 |
102 | ---
103 |
104 | [IVAL](https://www.ival-mbzuai.com)
105 | [Oryx](https://github.com/mbzuai-oryx)
106 | [MBZUAI](https://mbzuai.ac.ae)
--------------------------------------------------------------------------------
/RUN_VideoGLaMM.md:
--------------------------------------------------------------------------------
1 | # Checkpoints
2 |
3 | * Download SAM2 checkpoints from [here](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)
4 |
5 | * Download InternVideo2 checkpoints from [here](https://huggingface.co/OpenGVLab/InternVideo2-Stage2_1B-224p-f4)
6 |
7 | * Download VideoGLaMM checkpoints from [here](https://mbzuaiac-my.sharepoint.com/:f:/g/personal/shehan_munasinghe_mbzuai_ac_ae/Etucj3LuqdRDocrle_8eJbcB8C11u-020AX7fwIYWJh-dg?e=uPanYM)
8 |
9 | # Command Line Demo
10 |
11 | python chat.py \
12 | --llava_version_or_path="" \
13 | --use_sam2_video_branch \
14 | --base_model_type="vgpt|phi3"
15 |
16 | # Evaluation
17 |
18 |
19 | ## GCG Task
20 |
21 | python eval_gcg_infer.py \
22 | --llava_version_or_path="" \
23 | --use_sam2_video_branch \
24 | --base_model_type="vgpt|phi3" \
25 | --dataset_name='video_gcg'\
26 | --vis_save_path="./vis_output_path"
27 |
28 | export OPENAI_API_KEY=''
29 |
30 | python eval_gcg_metrics.py \
31 | --vis_save_path="" \
32 | --eval_miou --eval_recall --eval_caption --use_clair
33 |
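For reference, the `--eval_miou` part of GCG evaluation reduces to intersection-over-union between predicted and ground-truth masks, averaged over matched object pairs. Below is a minimal NumPy sketch of that computation (an illustration only, not the exact logic of `eval_gcg_metrics.py`):

```python
import numpy as np

def mask_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    # pred, gt: boolean masks with identical shape, e.g. (T, H, W) for a video object
    pred, gt = pred.astype(bool), gt.astype(bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 0.0
    return float(np.logical_and(pred, gt).sum()) / float(union)

def mean_iou(pairs):
    # pairs: iterable of (predicted_mask, ground_truth_mask) matched per object/phrase
    ious = [mask_iou(p, g) for p, g in pairs]
    return sum(ious) / len(ious) if ious else 0.0
```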
34 | ## MeViS
35 |
36 | python eval_mevis.py \
37 | --llava_version_or_path="" \
38 | --use_sam2_video_branch \
39 | --base_model_type="vgpt|phi3" \
40 | --dataset_name="MEVIS|valid"\
41 | --vis_save_path="./vis_output_path"
42 |
43 | You can use the following command to prepare the .zip submission file:
44 |
45 | cd [vis_output_path]
46 | zip -r ../mevis_out.zip *
47 |
48 | ## VidSTG
49 |
50 | python eval_grounding.py \
51 | --llava_version_or_path="" \
52 | --use_sam2_video_branch \
53 | --base_model_type="vgpt|phi3" \
54 | --dataset_name="vidstg"\
55 | --vis_save_path="./vis_output_path"
56 |
57 |
58 | ## HCSTVG
59 |
60 |
61 | python eval_grounding.py \
62 | --llava_version_or_path="" \
63 | --use_sam2_video_branch \
64 | --base_model_type="vgpt|phi3" \
65 | --dataset_name="hcstvg"\
66 | --vis_save_path="./vis_output_path"
67 |
68 | ## ReferYTVOS
69 |
70 | python eval_mevis.py \
71 | --llava_version_or_path="" \
72 | --use_sam2_video_branch \
73 | --base_model_type="vgpt|phi3" \
74 | --dataset_name="ReferYouTubeVOS|valid" \
75 | --vis_save_path="./vis_output_path"
76 |
77 |
78 | ## ReferDAVIS17
79 |
80 |
81 | python eval_referdavis_infer.py \
82 | --llava_version_or_path="" \
83 | --use_sam2_video_branch \
84 | --base_model_type="vgpt|phi3" \
85 | --dataset_name="ReferDAVIS|valid" \
86 | --vis_save_path="./vis_output_path"
87 |
88 | python eval_referdavis_metrics.py --output_dir \
89 | "./vis_output_path"
--------------------------------------------------------------------------------
/Training.md:
--------------------------------------------------------------------------------
1 | # Training Instructions
2 |
3 |
4 | * Initial training
5 |
6 |
7 | deepspeed --master_port=29504 --num_gpus=4 train_ds_with_videogptplus.py \
8 | --videogptplus_path="./checkpoints_hf/MBZUAI/VideoGPT-plus_Phi3-mini-4k/mvbench" \
9 | --vision_tower="./OpenGVLab/InternVideo2-Stage2_1B-224p-f4/InternVideo2-stage2_1b-224p-f4.pt" \
10 | --image_vision_tower="openai/clip-vit-large-patch14-336" \
11 | --dataset_dir='./dataset' \
12 | --video_dataset_dir='./video_dataset' \
13 | --sam_pretrained_path="./checkpoints/sam2/sam2_hiera_large.pt" \
14 | --exp_name="sam2_videogptplusphi3" \
15 | --logs_base_dir "./runs/logs" \
16 | --ckpt_base_dir "./runs/ckpts" \
17 | --dataset="sem_seg||refer_seg||vqa||reason_seg||grandf||refer_vos||mevis||vidstg||video_vqa" \
18 | --sample_rates_for_datasets="9,3,3,1,1,10,10,10,10,10" \
19 | --train_mask_decoder=False \
20 | --tune_mm_mlp_adapter=True \
21 | --use_sam_version='v2' \
22 | --precision='fp16' \
23 | --num_frames_for_sam=8 \
24 | --batch_size=1 \
25 | --grad_accumulation_steps=10 \
26 | --epochs=20 \
27 | --auto_resume
28 |
29 | * Finetuning with video-GCG data
30 |
31 | deepspeed --master_port=29504 --num_gpus=4 train_ds_with_videogptplus.py \
32 | --videogptplus_path="./checkpoints_hf/MBZUAI/VideoGPT-plus_Phi3-mini-4k/mvbench" \
33 | --vision_tower="./OpenGVLab/InternVideo2-Stage2_1B-224p-f4/InternVideo2-stage2_1b-224p-f4.pt" \
34 | --image_vision_tower="openai/clip-vit-large-patch14-336" \
35 | --dataset_dir='./dataset' \
36 | --video_dataset_dir='./video_dataset' \
37 | --sam_pretrained_path="./checkpoints/sam2/sam2_hiera_large.pt" \
38 | --exp_name="sam2_videogptplusphi3" \
39 | --logs_base_dir "./runs/logs" \
40 | --ckpt_base_dir "./runs/ckpts" \
41 | --dataset="sem_seg||refer_seg||vqa||reason_seg||grandf||refer_vos||mevis||vidstg||video_vqa||anet_gcg||video_gcg||mevis_gcg||vidstg_gcg||hcstvg_gcg" \
42 | --sample_rates_for_datasets="1,1,1,1,20,1,1,1,1,1,20,5,20,20,10" \
43 | --train_mask_decoder=False \
44 | --tune_mm_mlp_adapter=True \
45 | --use_sam_version='v2' \
46 | --precision='fp16' \
47 | --num_frames_for_sam=8 \
48 | --batch_size=1 \
49 | --grad_accumulation_steps=10 \
50 | --epochs=30 \
51 | --auto_resume
52 |
53 |
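Here, `--dataset` lists the sub-datasets separated by `||` and `--sample_rates_for_datasets` weights how often each one is sampled during training. The snippet below sketches this kind of weighted sampling with hypothetical names and rates; it is an illustration of the idea, not the repository's actual data-loading code:

```python
import numpy as np

# Hypothetical values for illustration; in practice these come from the CLI flags above.
dataset_names = "sem_seg||refer_seg||vqa||reason_seg".split("||")
sample_rates = np.array([9.0, 3.0, 3.0, 1.0])
probs = sample_rates / sample_rates.sum()

rng = np.random.default_rng(0)

def pick_dataset() -> str:
    # Each training sample is drawn from one sub-dataset with probability
    # proportional to its sample rate.
    return str(rng.choice(dataset_names, p=probs))

print([pick_dataset() for _ in range(5)])
```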
--------------------------------------------------------------------------------
/VideoGLaMM/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbzuai-oryx/VideoGLaMM/f8351c5f1fcda715c6afff66fa4a777e75d08e1b/VideoGLaMM/.DS_Store
--------------------------------------------------------------------------------
/VideoGLaMM/.gitignore:
--------------------------------------------------------------------------------
1 | !aws_backup/
2 |
3 | **/__pycache__
4 | archive/
5 | .runs/
6 | runs
7 | .vscode/
8 |
9 | dataset
10 | video_dataset
11 | checkpoints
12 | checkpoints/
13 | checkpoints_hf
14 | checkpoints_hf/
15 |
16 | .ipynb_checkpoints
17 | */.ipynb_checkpoints/*
18 |
19 | *.ipynb
20 |
21 | vis_output/
22 | slurm_outputs/
23 |
24 | scripts/
25 |
26 |
27 | model/segment_anything_2/build/lib.linux-x86_64-cpython-310/sam2/_C.so
28 | model/segment_anything_2/build/lib.linux-x86_64-cpython-311/sam2/_C.so
29 | model/segment_anything_2/build/temp.linux-x86_64-cpython-310/.ninja_deps
30 | model/segment_anything_2/build/temp.linux-x86_64-cpython-310/.ninja_log
31 | model/segment_anything_2/build/temp.linux-x86_64-cpython-310/build.ninja
32 | model/segment_anything_2/build/temp.linux-x86_64-cpython-310/sam2/csrc/connected_components.o
33 | model/segment_anything_2/build/temp.linux-x86_64-cpython-311/.ninja_deps
34 | model/segment_anything_2/build/temp.linux-x86_64-cpython-311/.ninja_log
35 | model/segment_anything_2/build/temp.linux-x86_64-cpython-311/build.ninja
36 | model/segment_anything_2/build/temp.linux-x86_64-cpython-311/sam2/csrc/connected_components.o
37 | model/segment_anything_2/sam2/_C.so
38 | *.so
39 | *.ninja_log
40 | *.ninja_deps
41 | *.ninja
42 | *.o
43 |
44 |
45 | .gradio
46 |
47 |
48 | tools/stanford-corenlp-full-2018-02-27
--------------------------------------------------------------------------------
/VideoGLaMM/gcg_data_gen/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbzuai-oryx/VideoGLaMM/f8351c5f1fcda715c6afff66fa4a777e75d08e1b/VideoGLaMM/gcg_data_gen/.DS_Store
--------------------------------------------------------------------------------
/VideoGLaMM/gcg_data_gen/burst_ytvis_gcg/README.md:
--------------------------------------------------------------------------------
1 |
2 | **Run the demo**
3 | Take YouTubeVIS2019 as an example:
4 | ```bash
5 | # Generation step1 --rough description of each object
6 | python generation.py --video_path video_data/youtube2019/train/JPEGImages --question These are frames from a video that I want to upload. What does the look like and what is the doing? --ann_path video_data/youtube2019/train.json --output_file generated_step1.txt --step step1
7 |
8 |
9 | # Generation step2 --corrected description of the object
10 | python generation.py --video_path video_data/youtube2019/train/JPEGImages --question These are frames from a video that I want to upload. Please modify this caption: The instance in the video is surrounded by a rectangular box with color number . The output caption must include what the looks like and what the is doing. Please do not mention any information about the bbox in the output. --ann_path video_data/youtube2019/train.json --output_file generated_step2.txt --step step2 --caption_file output/generated_step1.json
11 |
12 | # Generation step3 --comprehensive description of the video
13 | python generation.py --video_path video_data/youtube2019/train/JPEGImages --question These are frames from a video that I want to upload. In the video, the ID number of the box is on the top left of the box. There are some instance captions: '' Generate a dense caption that describes the video in detail based on the video and instance captions, including all of the instances mentioned in the instance captions and other instances in the video. Ensure that each instance mentioned in the instance caption appears exactly once in the dense caption, followed by the format {obj_} to indicate which instance caption the mentioned instance corresponds to. The {obj_} must directly follow the noun representing the instance.Please do not mention any information about the bbox in the output. --ann_path video_data/youtube2019/train.json --output_file generated_step3.txt --step step3 --caption_file output/generated_step2.json
14 |
15 | # Manually review the {obj_id} in the generated video captions based on the video content
16 |
17 | # Generate annotation file with caption
18 | python generate_annotations.py --ann_file video_data/youtube2019/train.json --obj_cap output/generated_step2.json --dense_cap output/manual_generated_step3.json --out_ann_file generated_annotation.json
19 |
20 |
21 | # Merge BURST and YouTubeVIS2019 annotation files
22 | python merge_b_y.py --burst_train video_data/burst/train/b2y_train_add_cap_del_filtered_ann.json --burst_val video_data/burst/val/b2y_val_add_cap_del_filtered_ann.json --yt19_train video_data/ytvis_2019/train_add_cap_filtered_ann.json --hq_ann_file video_data/ytvis_2019/ --out_ann_path output
23 | ```
24 |
--------------------------------------------------------------------------------
/VideoGLaMM/gcg_data_gen/burst_ytvis_gcg/generate_annotations.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | import argparse
4 |
5 |
6 | def get_arguments():
7 | parser = argparse.ArgumentParser(description="Inference parameters")
8 |
9 | parser.add_argument("--ann_file", type=str, help="path to annotations")
10 | parser.add_argument("--obj_cap", type=str, help="the generated object-level caption")
11 | parser.add_argument("--dense_cap", type=str, help="the generated video-level caption")
12 | parser.add_argument("--out_ann_file", type=str, help="path to the final annotations file")
13 |
14 | return parser.parse_args()
15 |
16 | if __name__ == "__main__":
17 | args=get_arguments()
18 | ann_file= json.load(open(args.ann_file))
19 | obj_cap=json.load(open(args.obj_cap))
20 | dense_cap=json.load(open(args.dense_cap))
21 | out_ann_file= args.out_ann_file
22 |
23 |
24 | for ann in ann_file['annotations']:
25 | if str(ann['id']) in obj_cap.keys():
26 | ann['cap']=obj_cap[str(ann['id'])]
27 | else:
28 | ann['cap']=None
29 |
30 | video_cap={}
31 | for ann in ann_file['annotations']:
32 | if ann['video_id'] not in video_cap.keys():
33 | video_cap[ann['video_id']]=[]
34 | if str(ann['id']) in obj_cap.keys():
35 | video_cap[ann['video_id']].append(dict(cls_id=ann['category_id'],seg=ann['segmentations'],bboxes=ann['bboxes'],cap=obj_cap[str(ann['id'])],obj_id=ann['id'],ann_id=len(video_cap[ann['video_id']])))
36 | for vid,video in enumerate(ann_file['videos']):
37 | if str(vid) in dense_cap.keys():
38 | if len(video_cap[video['id']])==0:
39 | video['dense_cap']={}
40 | video['dense_cap']['v_id2o_id']=None
41 | video['dense_cap']['token_pos']=None
42 | video['dense_cap']['mask_id']=None
43 | video['dense_cap']['caption']=None
44 | else:
45 | video['dense_cap']={}
46 | video['dense_cap']['v_id2o_id']={}
47 | video['dense_cap']['token_pos']=[]
48 | video['dense_cap']['mask_id']=[]
49 | for an in video_cap[video['id']]:
50 | video['dense_cap']['v_id2o_id'][an['ann_id']]=an['obj_id']
51 | spl_dense_cap=dense_cap[str(vid)].split(' ')
52 | me_cap=[]
53 | for wid,word in enumerate(spl_dense_cap):
54 | if '{obj_' in word :
55 | video['dense_cap']['token_pos'].append(len(me_cap)-1)
56 | m_id=int(re.findall(r'\d+',word)[0])
57 | if m_id < ...
--------------------------------------------------------------------------------
/VideoGLaMM/model/chatunivi/mm_utils.py:
--------------------------------------------------------------------------------
19 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
20 |
21 | def insert_separator(X, sep):
22 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
23 |
24 | input_ids = []
25 | offset = 0
26 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
27 | offset = 1
28 | input_ids.append(prompt_chunks[0][0])
29 |
30 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
31 | input_ids.extend(x[offset:])
32 |
33 | if return_tensors is not None:
34 | if return_tensors == 'pt':
35 | return torch.tensor(input_ids, dtype=torch.long)
36 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
37 |
38 | return input_ids
39 |
40 |
41 | def get_model_name_from_path(model_path):
42 | model_path = model_path.strip("/")
43 | model_paths = model_path.split("/")
44 | if model_paths[-1].startswith('checkpoint-'):
45 | return model_paths[-2] + "_" + model_paths[-1]
46 | else:
47 | return model_paths[-1]
48 |
49 |
50 | class KeywordsStoppingCriteria(StoppingCriteria):
51 | def __init__(self, keywords, tokenizer, input_ids):
52 | self.keywords = keywords
53 | self.keyword_ids = []
54 | for keyword in keywords:
55 | cur_keyword_ids = tokenizer(keyword).input_ids
56 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
57 | cur_keyword_ids = cur_keyword_ids[1:]
58 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
59 | self.tokenizer = tokenizer
60 | self.start_len = input_ids.shape[1]
61 |
62 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
63 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
64 | offset = min(output_ids.shape[1] - self.start_len, 3)
65 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
66 | for keyword_id in self.keyword_ids:
67 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
68 | return True
69 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
70 | for keyword in self.keywords:
71 | if keyword in outputs:
72 | return True
73 | return False
74 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/chatunivi/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbzuai-oryx/VideoGLaMM/f8351c5f1fcda715c6afff66fa4a777e75d08e1b/VideoGLaMM/model/chatunivi/model/__init__.py
--------------------------------------------------------------------------------
/VideoGLaMM/model/chatunivi/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | from .clip_encoder import CLIPVisionTower
2 | # from .eva_encoder import EVAVisionTower
3 |
4 |
5 | def build_vision_tower(vision_tower_cfg, **kwargs):
6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
7 | if vision_tower.startswith("openai") or vision_tower.startswith("laion"):
8 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
9 |
10 | # elif vision_tower.startswith("eva_vit_g"):
11 | # return EVAVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
12 |
13 | raise ValueError(f'Unknown vision tower: {vision_tower}')
14 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/chatunivi/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5 |
6 |
7 | class CLIPVisionTower(nn.Module):
8 | def __init__(self, vision_tower, args=None, delay_load=False):
9 | super().__init__()
10 |
11 | self.is_loaded = False
12 |
13 | self.vision_tower_name = vision_tower
14 | if args is None:
15 | self.select_layer = -2
16 | self.select_feature = 'patch'
17 | else:
18 | self.select_layer = args.mm_vision_select_layer
19 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
20 |
21 | if not delay_load:
22 | self.load_model()
23 | else:
24 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
25 |
26 | def load_model(self):
27 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
28 | self.image_eval_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
29 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
30 | self.vision_tower.requires_grad_(False)
31 |
32 | self.is_loaded = True
33 |
34 | def feature_select(self, image_forward_outs, select_feature='patch'):
35 | image_features = image_forward_outs.hidden_states[self.select_layer]
36 | if select_feature == 'patch':
37 | image_features = image_features[:, 1:]
38 | elif select_feature == 'cls_patch':
39 | image_features = image_features
40 | else:
41 | raise ValueError(f'Unexpected select feature: {select_feature}')
42 | return image_features
43 |
44 | @torch.no_grad()
45 | def forward(self, images, select_feature='patch'):
46 | if type(images) is list:
47 | image_features = []
48 | for image in images:
49 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
50 | image_feature = self.feature_select(image_forward_out, select_feature).to(image.dtype)
51 | image_features.append(image_feature)
52 | else:
53 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
54 | image_features = self.feature_select(image_forward_outs, select_feature).to(images.dtype)
55 |
56 | return image_features
57 |
58 | @property
59 | def dummy_feature(self):
60 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
61 |
62 | @property
63 | def dtype(self):
64 | return self.vision_tower.dtype
65 |
66 | @property
67 | def device(self):
68 | return self.vision_tower.device
69 |
70 | @property
71 | def config(self):
72 | if self.is_loaded:
73 | return self.vision_tower.config
74 | else:
75 | return self.cfg_only
76 |
77 | @property
78 | def hidden_size(self):
79 | return self.config.hidden_size
80 |
81 | @property
82 | def num_patches(self):
83 | return (self.config.image_size // self.config.patch_size) ** 2
84 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/chatunivi/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from .constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True)
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 | Check whether the text violates OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 | text = text.replace("\n", "")
110 | data = "{" + '"input": ' + f'"{text}"' + "}"
111 | data = data.encode("utf-8")
112 | try:
113 | ret = requests.post(url, headers=headers, data=data, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 | except requests.exceptions.RequestException as e:
116 | flagged = False
117 | except KeyError as e:
118 | flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/mm_utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | from io import BytesIO
3 |
4 | import torch
5 | from PIL import Image
6 | from transformers import StoppingCriteria
7 |
8 | from .constants import IMAGE_TOKEN_INDEX
9 |
10 |
11 | def load_image_from_base64(image):
12 | return Image.open(BytesIO(base64.b64decode(image)))
13 |
14 |
15 | def process_images(images, image_processor, model_cfg):
16 | return image_processor(images, return_tensors="pt")["pixel_values"]
17 |
18 |
19 | def tokenizer_image_token(
20 | prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
21 | ):
22 | '''
23 | This function is used to tokenize a prompt string containing the <image> token.
24 |
25 | - The input string is tokenized into a list of token IDs.
26 | - The image_token_index is inserted between the chunks of the prompt where <image> was present.
27 | - The result is either a list of token IDs or a PyTorch tensor, depending on the specified return_tensors parameter.
28 |
29 | '''
30 | # splits the prompt string into chunks based on the <image> token and tokenizes each chunk separately using the provided tokenizer.
31 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
32 |
33 |
34 | # helper function that takes a list X and a separator sep, and it interleaves the elements of X with the separator.
35 | # It is used to insert the image_token_index between the tokenized chunks of the prompt.
36 | def insert_separator(X, sep):
37 | return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
38 |
39 | input_ids = []
40 | offset = 0
41 | # If the first chunk of the prompt starts with the tokenizer's beginning-of-sequence (BOS) token,
42 | # it increments the offset and appends the BOS token to the input_ids list.
43 | if (
44 | len(prompt_chunks) > 0
45 | and len(prompt_chunks[0]) > 0
46 | and prompt_chunks[0][0] == tokenizer.bos_token_id
47 | ):
48 | offset = 1
49 | input_ids.append(prompt_chunks[0][0])
50 |
51 | # uses the insert_separator function to insert the image_token_index between the tokenized chunks of the prompt.
52 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
53 | input_ids.extend(x[offset:])
54 |
55 | # The resulting input_ids list contains the tokenized prompt with the image token index appropriately inserted.
56 |
57 | if return_tensors is not None:
58 | if return_tensors == "pt":
59 | return torch.tensor(input_ids, dtype=torch.long)
60 | raise ValueError(f"Unsupported tensor type: {return_tensors}")
61 | return input_ids
62 |
63 |
64 | def get_model_name_from_path(model_path):
65 | model_path = model_path.strip("/")
66 | model_paths = model_path.split("/")
67 | if model_paths[-1].startswith("checkpoint-"):
68 | return model_paths[-2] + "_" + model_paths[-1]
69 | else:
70 | return model_paths[-1]
71 |
72 |
73 | class KeywordsStoppingCriteria(StoppingCriteria):
74 | def __init__(self, keywords, tokenizer, input_ids):
75 | self.keywords = keywords
76 | self.keyword_ids = []
77 | for keyword in keywords:
78 | cur_keyword_ids = tokenizer(keyword).input_ids
79 | if (
80 | len(cur_keyword_ids) > 1
81 | and cur_keyword_ids[0] == tokenizer.bos_token_id
82 | ):
83 | cur_keyword_ids = cur_keyword_ids[1:]
84 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
85 | self.tokenizer = tokenizer
86 | self.start_len = input_ids.shape[1]
87 |
88 | def __call__(
89 | self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
90 | ) -> bool:
91 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
92 | offset = min(output_ids.shape[1] - self.start_len, 3)
93 | self.keyword_ids = [
94 | keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
95 | ]
96 | for keyword_id in self.keyword_ids:
97 | if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
98 | return True
99 | outputs = self.tokenizer.batch_decode(
100 | output_ids[:, -offset:], skip_special_tokens=True
101 | )[0]
102 | for keyword in self.keywords:
103 | if keyword in outputs:
104 | return True
105 | return False
106 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .language_model.llava_llama import LlavaConfig, LlavaLlamaForCausalLM
2 | from .language_model.llava_mpt import LlavaMPTConfig, LlavaMPTForCausalLM
3 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from llava import LlavaLlamaForCausalLM
9 | from tqdm import tqdm
10 | from transformers import AutoModelForCausalLM, AutoTokenizer
11 |
12 |
13 | def apply_delta(base_model_path, target_model_path, delta_path):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
17 | )
18 |
19 | print("Loading delta")
20 | delta = LlavaLlamaForCausalLM.from_pretrained(
21 | delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
22 | )
23 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
24 |
25 | print("Applying delta")
26 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
27 | if name not in base.state_dict():
28 | assert name in [
29 | "model.mm_projector.weight",
30 | "model.mm_projector.bias",
31 | ], f"{name} not in base model"
32 | continue
33 | if param.data.shape == base.state_dict()[name].shape:
34 | param.data += base.state_dict()[name]
35 | else:
36 | assert name in [
37 | "model.embed_tokens.weight",
38 | "lm_head.weight",
39 | ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
40 | bparam = base.state_dict()[name]
41 | param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
42 |
43 | print("Saving target model")
44 | delta.save_pretrained(target_model_path)
45 | delta_tokenizer.save_pretrained(target_model_path)
46 |
47 |
48 | if __name__ == "__main__":
49 | parser = argparse.ArgumentParser()
50 | parser.add_argument("--base-model-path", type=str, required=True)
51 | parser.add_argument("--target-model-path", type=str, required=True)
52 | parser.add_argument("--delta-path", type=str, required=True)
53 |
54 | args = parser.parse_args()
55 |
56 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
57 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 | import argparse
6 |
7 | import torch
8 | from llava.model import *
9 | from llava.model.utils import auto_upgrade
10 | from transformers import AutoModelForCausalLM, AutoTokenizer
11 |
12 |
13 | def consolidate_ckpt(src_path, dst_path):
14 | print("Loading model")
15 | auto_upgrade(src_path)
16 | src_model = AutoModelForCausalLM.from_pretrained(
17 | src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
18 | )
19 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
20 | src_model.save_pretrained(dst_path)
21 | src_tokenizer.save_pretrained(dst_path)
22 |
23 |
24 | if __name__ == "__main__":
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--src", type=str, required=True)
27 | parser.add_argument("--dst", type=str, required=True)
28 |
29 | args = parser.parse_args()
30 |
31 | consolidate_ckpt(args.src, args.dst)
32 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/model/language_model/mpt/adapt_tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from transformers import (AutoTokenizer, PreTrainedTokenizer,
4 | PreTrainedTokenizerFast)
5 |
6 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
7 | NUM_SENTINEL_TOKENS: int = 100
8 |
9 |
10 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
11 | """Adds sentinel tokens and padding token (if missing).
12 |
13 | Expands the tokenizer vocabulary to include sentinel tokens
14 | used in mixture-of-denoiser tasks as well as a padding token.
15 |
16 | All added tokens are added as special tokens. No tokens are
17 | added if sentinel tokens and padding token already exist.
18 | """
19 | sentinels_to_add = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]
20 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
21 | if tokenizer.pad_token is None:
22 | tokenizer.add_tokens("<pad>", special_tokens=True)
23 | tokenizer.pad_token = "<pad>"
24 | assert tokenizer.pad_token_id is not None
25 | sentinels = "".join([f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)])
26 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
27 | tokenizer.sentinel_token_ids = _sentinel_token_ids
28 |
29 |
30 | class AutoTokenizerForMOD(AutoTokenizer):
31 | """AutoTokenizer + Adaptation for MOD.
32 |
33 | A simple wrapper around AutoTokenizer to make instantiating
34 | an MOD-adapted tokenizer a bit easier.
35 |
36 | MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
37 | a padding token, and a property to get the token ids of the
38 | sentinel tokens.
39 | """
40 |
41 | @classmethod
42 | def from_pretrained(cls, *args, **kwargs):
43 | """See `AutoTokenizer.from_pretrained` docstring."""
44 | tokenizer = super().from_pretrained(*args, **kwargs)
45 | adapt_tokenizer_for_denoising(tokenizer)
46 | return tokenizer
47 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/model/language_model/mpt/blocks.py:
--------------------------------------------------------------------------------
1 | """GPT Blocks used for the GPT Model."""
2 | from typing import Dict, Optional, Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from .attention import ATTN_CLASS_REGISTRY
8 | from .norm import NORM_CLASS_REGISTRY
9 |
10 |
11 | class MPTMLP(nn.Module):
12 | def __init__(
13 | self, d_model: int, expansion_ratio: int, device: Optional[str] = None
14 | ):
15 | super().__init__()
16 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
17 | self.act = nn.GELU(approximate="none")
18 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
19 | self.down_proj._is_residual = True
20 |
21 | def forward(self, x):
22 | return self.down_proj(self.act(self.up_proj(x)))
23 |
24 |
25 | class MPTBlock(nn.Module):
26 | def __init__(
27 | self,
28 | d_model: int,
29 | n_heads: int,
30 | expansion_ratio: int,
31 | attn_config: Dict = {
32 | "attn_type": "multihead_attention",
33 | "attn_pdrop": 0.0,
34 | "attn_impl": "triton",
35 | "qk_ln": False,
36 | "clip_qkv": None,
37 | "softmax_scale": None,
38 | "prefix_lm": False,
39 | "attn_uses_sequence_id": False,
40 | "alibi": False,
41 | "alibi_bias_max": 8,
42 | },
43 | resid_pdrop: float = 0.0,
44 | norm_type: str = "low_precision_layernorm",
45 | verbose: int = 0,
46 | device: Optional[str] = None,
47 | **kwargs
48 | ):
49 | del kwargs
50 | super().__init__()
51 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
52 | attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
53 | self.norm_1 = norm_class(d_model, device=device)
54 | self.attn = attn_class(
55 | attn_impl=attn_config["attn_impl"],
56 | clip_qkv=attn_config["clip_qkv"],
57 | qk_ln=attn_config["qk_ln"],
58 | softmax_scale=attn_config["softmax_scale"],
59 | attn_pdrop=attn_config["attn_pdrop"],
60 | d_model=d_model,
61 | n_heads=n_heads,
62 | verbose=verbose,
63 | device=device,
64 | )
65 | self.norm_2 = norm_class(d_model, device=device)
66 | self.ffn = MPTMLP(
67 | d_model=d_model, expansion_ratio=expansion_ratio, device=device
68 | )
69 | self.resid_attn_dropout = nn.Dropout(resid_pdrop)
70 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
71 |
72 | def forward(
73 | self,
74 | x: torch.Tensor,
75 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
76 | attn_bias: Optional[torch.Tensor] = None,
77 | attention_mask: Optional[torch.ByteTensor] = None,
78 | is_causal: bool = True,
79 | ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
80 | a = self.norm_1(x)
81 | (b, attn_weights, past_key_value) = self.attn(
82 | a,
83 | past_key_value=past_key_value,
84 | attn_bias=attn_bias,
85 | attention_mask=attention_mask,
86 | is_causal=is_causal,
87 | )
88 | x = x + self.resid_attn_dropout(b)
89 | m = self.norm_2(x)
90 | n = self.ffn(m)
91 | x = x + self.resid_ffn_dropout(n)
92 | return (x, attn_weights, past_key_value)
93 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/model/language_model/mpt/custom_embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch import Tensor
5 |
6 |
7 | class SharedEmbedding(nn.Embedding):
8 | def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
9 | if unembed:
10 | return F.linear(input, self.weight)
11 | return super().forward(input)
12 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/model/language_model/mpt/meta_init_context.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | @contextmanager
8 | def init_empty_weights(include_buffers: bool = False):
9 | """Meta initialization context manager.
10 |
11 | A context manager under which models are initialized with all parameters
12 | on the meta device, therefore creating an empty model. Useful when just
13 | initializing the model would blow the available RAM.
14 |
15 | Args:
16 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or
17 | not to also put all buffers on the meta device while initializing.
18 |
19 | Example:
20 | ```python
21 | import torch.nn as nn
22 |
23 | # Initialize a model with 100 billion parameters in no time and without using any RAM.
24 | with init_empty_weights():
25 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
26 | ```
27 |
28 |
29 |
30 | Any model created under this context manager has no weights. As such you can't do something like
31 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
32 |
33 |
34 | """
35 | with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
36 | yield f
37 |
38 |
39 | @contextmanager
40 | def init_on_device(device: torch.device, include_buffers: bool = False):
41 | """Device initialization context manager.
42 |
43 | A context manager under which models are initialized with all parameters
44 | on the specified device.
45 |
46 | Args:
47 | device (`torch.device`): Device to initialize all parameters on.
48 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or
49 | not to also put all buffers on the meta device while initializing.
50 |
51 | Example:
52 | ```python
53 | import torch.nn as nn
54 |
55 | with init_on_device(device=torch.device("cuda")):
56 | tst = nn.Linear(100, 100) # on `cuda` device
57 | ```
58 | """
59 | old_register_parameter = nn.Module.register_parameter
60 | if include_buffers:
61 | old_register_buffer = nn.Module.register_buffer
62 |
63 | def register_empty_parameter(module, name, param):
64 | old_register_parameter(module, name, param)
65 | if param is not None:
66 | param_cls = type(module._parameters[name])
67 | kwargs = module._parameters[name].__dict__
68 | module._parameters[name] = param_cls(
69 | module._parameters[name].to(device), **kwargs
70 | )
71 |
72 | def register_empty_buffer(module, name, buffer):
73 | old_register_buffer(module, name, buffer)
74 | if buffer is not None:
75 | module._buffers[name] = module._buffers[name].to(device)
76 |
77 | if include_buffers:
78 | tensor_constructors_to_patch = {
79 | torch_function_name: getattr(torch, torch_function_name)
80 | for torch_function_name in ["empty", "zeros", "ones", "full"]
81 | }
82 | else:
83 | tensor_constructors_to_patch = {}
84 |
85 | def patch_tensor_constructor(fn):
86 | def wrapper(*args, **kwargs):
87 | kwargs["device"] = device
88 | return fn(*args, **kwargs)
89 |
90 | return wrapper
91 |
92 | try:
93 | nn.Module.register_parameter = register_empty_parameter
94 | if include_buffers:
95 | nn.Module.register_buffer = register_empty_buffer
96 | for torch_function_name in tensor_constructors_to_patch.keys():
97 | setattr(
98 | torch,
99 | torch_function_name,
100 | patch_tensor_constructor(getattr(torch, torch_function_name)),
101 | )
102 | yield
103 | finally:
104 | nn.Module.register_parameter = old_register_parameter
105 | if include_buffers:
106 | nn.Module.register_buffer = old_register_buffer
107 | for (
108 | torch_function_name,
109 | old_torch_function,
110 | ) in tensor_constructors_to_patch.items():
111 | setattr(torch, torch_function_name, old_torch_function)
112 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/model/language_model/mpt/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def _cast_if_autocast_enabled(tensor):
5 | if torch.is_autocast_enabled():
6 | if tensor.device.type == "cuda":
7 | dtype = torch.get_autocast_gpu_dtype()
8 | elif tensor.device.type == "cpu":
9 | dtype = torch.get_autocast_cpu_dtype()
10 | else:
11 | raise NotImplementedError()
12 | return tensor.to(dtype=dtype)
13 | return tensor
14 |
15 |
16 | class LPLayerNorm(torch.nn.LayerNorm):
17 | def __init__(
18 | self,
19 | normalized_shape,
20 | eps=1e-05,
21 | elementwise_affine=True,
22 | device=None,
23 | dtype=None,
24 | ):
25 | super().__init__(
26 | normalized_shape=normalized_shape,
27 | eps=eps,
28 | elementwise_affine=elementwise_affine,
29 | device=device,
30 | dtype=dtype,
31 | )
32 |
33 | def forward(self, x):
34 | module_device = x.device
35 | downcast_x = _cast_if_autocast_enabled(x)
36 | downcast_weight = (
37 | _cast_if_autocast_enabled(self.weight)
38 | if self.weight is not None
39 | else self.weight
40 | )
41 | downcast_bias = (
42 | _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
43 | )
44 | with torch.autocast(enabled=False, device_type=module_device.type):
45 | return torch.nn.functional.layer_norm(
46 | downcast_x,
47 | self.normalized_shape,
48 | downcast_weight,
49 | downcast_bias,
50 | self.eps,
51 | )
52 |
53 |
54 | def rms_norm(x, weight=None, eps=1e-05):
55 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
56 | if weight is not None:
57 | return output * weight
58 | return output
59 |
60 |
61 | class RMSNorm(torch.nn.Module):
62 | def __init__(
63 | self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
64 | ):
65 | super().__init__()
66 | self.eps = eps
67 | if weight:
68 | self.weight = torch.nn.Parameter(
69 | torch.ones(normalized_shape, dtype=dtype, device=device)
70 | )
71 | else:
72 | self.register_parameter("weight", None)
73 |
74 | def forward(self, x):
75 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
76 |
77 |
78 | class LPRMSNorm(RMSNorm):
79 | def __init__(
80 | self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
81 | ):
82 | super().__init__(
83 | normalized_shape=normalized_shape,
84 | eps=eps,
85 | weight=weight,
86 | dtype=dtype,
87 | device=device,
88 | )
89 |
90 | def forward(self, x):
91 | downcast_x = _cast_if_autocast_enabled(x)
92 | downcast_weight = (
93 | _cast_if_autocast_enabled(self.weight)
94 | if self.weight is not None
95 | else self.weight
96 | )
97 | with torch.autocast(enabled=False, device_type=x.device.type):
98 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
99 |
100 |
101 | NORM_CLASS_REGISTRY = {
102 | "layernorm": torch.nn.LayerNorm,
103 | "low_precision_layernorm": LPLayerNorm,
104 | "rmsnorm": RMSNorm,
105 | "low_precision_rmsnorm": LPRMSNorm,
106 | }
107 |
--------------------------------------------------------------------------------
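For reference, a small sketch of picking a norm implementation by name from the registry above. The lookup-by-string pattern mirrors how a config-driven caller would use it (an assumption); the keys are exactly those defined in `NORM_CLASS_REGISTRY`:

    import torch
    # Assumes the llava package is importable as elsewhere in the repo.
    from llava.model.language_model.mpt.norm import NORM_CLASS_REGISTRY

    norm_cls = NORM_CLASS_REGISTRY["low_precision_layernorm"]  # -> LPLayerNorm
    norm = norm_cls(normalized_shape=768)

    x = torch.randn(2, 16, 768)
    print(norm(x).shape)  # torch.Size([2, 16, 768]); weights are downcast only under autocast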
/VideoGLaMM/model/llava/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
 3 | python3 -m llava.model.make_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/llava-7b --delta-path ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from llava.model.utils import auto_upgrade
9 | from tqdm import tqdm
10 | from transformers import AutoModelForCausalLM, AutoTokenizer
11 |
12 |
13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
17 | )
18 |
19 | print("Loading target model")
20 | auto_upgrade(target_model_path)
21 | target = AutoModelForCausalLM.from_pretrained(
22 | target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
23 | )
24 |
25 | print("Calculating delta")
26 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
27 | if name not in base.state_dict():
28 | assert name in [
29 | "model.mm_projector.weight",
30 | "model.mm_projector.bias",
31 | ], f"{name} not in base model"
32 | continue
33 | if param.data.shape == base.state_dict()[name].shape:
34 | param.data -= base.state_dict()[name]
35 | else:
36 | assert name in [
37 | "model.embed_tokens.weight",
38 | "lm_head.weight",
39 | ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
40 | bparam = base.state_dict()[name]
41 | param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
42 |
43 | print("Saving delta")
44 | if hub_repo_id:
45 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
46 | else:
47 | kwargs = {}
48 | target.save_pretrained(delta_path, **kwargs)
49 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
50 | target_tokenizer.save_pretrained(delta_path, **kwargs)
51 |
52 |
53 | if __name__ == "__main__":
54 | parser = argparse.ArgumentParser()
55 | parser.add_argument("--base-model-path", type=str, required=True)
56 | parser.add_argument("--target-model-path", type=str, required=True)
57 | parser.add_argument("--delta-path", type=str, required=True)
58 | parser.add_argument("--hub-repo-id", type=str, default=None)
59 | args = parser.parse_args()
60 |
61 | make_delta(
62 | args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id
63 | )
64 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | from .clip_encoder import CLIPVisionTower
2 |
3 |
4 | def build_vision_tower(vision_tower_cfg, **kwargs):
5 | vision_tower = getattr(
6 | vision_tower_cfg,
7 | "mm_vision_tower",
8 | getattr(vision_tower_cfg, "vision_tower", None),
9 | )
10 | if (
11 | vision_tower.startswith("openai")
12 | or vision_tower.startswith("laion")
13 | or "clip" in vision_tower
14 | ):
15 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
16 |
17 | raise ValueError(f"Unknown vision tower: {vision_tower}")
18 |
--------------------------------------------------------------------------------
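A hedged sketch of how `build_vision_tower` dispatches on the tower name. The config object below is hypothetical; its attribute names follow what the builder and `CLIPVisionTower` read (`mm_vision_tower`, `mm_vision_select_layer`, `mm_vision_select_feature`):

    from types import SimpleNamespace
    # Assumes the llava package is importable as elsewhere in the repo.
    from llava.model.multimodal_encoder.builder import build_vision_tower

    cfg = SimpleNamespace(
        mm_vision_tower="openai/clip-vit-large-patch14",
        mm_vision_select_layer=-2,
        mm_vision_select_feature="patch",
    )
    # delay_load=True fetches only the CLIP config, not the full weights.
    tower = build_vision_tower(cfg, delay_load=True)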
/VideoGLaMM/model/llava/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
4 |
5 |
6 | class CLIPVisionTower(nn.Module):
7 | def __init__(self, vision_tower, args, delay_load=False):
8 | super().__init__()
9 |
10 | self.is_loaded = False
11 |
12 | self.vision_tower_name = vision_tower
13 | self.select_layer = args.mm_vision_select_layer
14 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
15 |
16 | if not delay_load:
17 | self.load_model()
18 | else:
19 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
20 |
21 | def load_model(self):
22 | self.image_processor = CLIPImageProcessor.from_pretrained(
23 | self.vision_tower_name
24 | )
25 | self.vision_tower = CLIPVisionModel.from_pretrained(
26 | self.vision_tower_name, low_cpu_mem_usage=True
27 | )
28 | self.vision_tower.requires_grad_(False)
29 | self.is_loaded = True
30 |
31 | def feature_select(self, image_forward_outs):
32 | image_features = image_forward_outs.hidden_states[self.select_layer]
33 | if self.select_feature == "patch":
34 | image_features = image_features[:, 1:]
35 | elif self.select_feature == "cls_patch":
36 | image_features = image_features
37 | else:
38 | raise ValueError(f"Unexpected select feature: {self.select_feature}")
39 | return image_features
40 |
41 | @torch.no_grad()
42 | def forward(self, images):
43 | if type(images) is list:
44 | image_features = []
45 | for image in images:
46 | image_forward_out = self.vision_tower(
47 | image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
48 | output_hidden_states=True,
49 | )
50 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
51 | image_features.append(image_feature)
52 | else:
53 | image_forward_outs = self.vision_tower(
54 | images.to(device=self.device, dtype=self.dtype),
55 | output_hidden_states=True,
56 | )
57 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
58 |
59 | torch.cuda.empty_cache()
60 | return image_features
61 |
62 | @property
63 | def dummy_feature(self):
64 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
65 |
66 | @property
67 | def dtype(self):
68 | return self.vision_tower.dtype
69 |
70 | @property
71 | def device(self):
72 | return self.vision_tower.device
73 |
74 | @property
75 | def config(self):
76 | if self.is_loaded:
77 | return self.vision_tower.config
78 | else:
79 | return self.cfg_only
80 |
81 | @property
82 | def hidden_size(self):
83 | return self.config.hidden_size
84 |
85 | @property
86 | def num_patches(self):
87 | return (self.config.image_size // self.config.patch_size) ** 2
88 |
--------------------------------------------------------------------------------
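A small shape check for `CLIPVisionTower` (model name and input size are illustrative; this downloads the CLIP ViT-L/14 weights). With `mm_vision_select_feature="patch"` the CLS token is dropped, so a 224x224 input with patch size 14 yields 256 patch tokens:

    import torch
    from types import SimpleNamespace
    # Assumes the llava package is importable as elsewhere in the repo.
    from llava.model.multimodal_encoder.clip_encoder import CLIPVisionTower

    cfg = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature="patch")
    tower = CLIPVisionTower("openai/clip-vit-large-patch14", args=cfg)

    feats = tower(torch.zeros(1, 3, 224, 224))
    print(feats.shape)  # torch.Size([1, 256, 1024]) for CLIP ViT-L/14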
/VideoGLaMM/model/llava/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if "llava" in config and "llava" not in cfg.model_type:
7 | assert cfg.model_type == "llama"
8 | print(
9 | "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
10 | )
11 | print(
12 | "You must upgrade the checkpoint to the new code base (this can be done automatically)."
13 | )
14 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
15 | if confirm.lower() in ["y", "yes"]:
16 | print("Upgrading checkpoint...")
17 | assert len(cfg.architectures) == 1
18 | setattr(cfg.__class__, "model_type", "llava")
19 | cfg.architectures[0] = "LlavaLlamaForCausalLM"
20 | cfg.save_pretrained(config)
21 | print("Checkpoint upgraded.")
22 | else:
23 | print("Checkpoint upgrade aborted.")
24 | exit(1)
25 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List, Optional, Tuple
3 |
4 | import torch
5 | import transformers
6 | from einops import rearrange
7 | from torch import nn
8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
9 |
10 | try:
11 | from flash_attn.flash_attn_interface import \
12 | flash_attn_unpadded_qkvpacked_func
13 | except ImportError:
14 | from flash_attn.flash_attn_interface import (
15 | flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func,
16 | )
17 |
18 | from flash_attn.bert_padding import pad_input, unpad_input
19 |
20 |
21 | def forward(
22 | self,
23 | hidden_states: torch.Tensor,
24 | attention_mask: Optional[torch.Tensor] = None,
25 | position_ids: Optional[torch.Tensor] = None,
26 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
27 | output_attentions: bool = False,
28 | use_cache: bool = False,
29 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
30 | """Input shape: Batch x Time x Channel
31 |
32 | attention_mask: [bsz, q_len]
33 | """
34 | bsz, q_len, _ = hidden_states.size()
35 |
36 | query_states = (
37 | self.q_proj(hidden_states)
38 | .view(bsz, q_len, self.num_heads, self.head_dim)
39 | .transpose(1, 2)
40 | )
41 | key_states = (
42 | self.k_proj(hidden_states)
43 | .view(bsz, q_len, self.num_heads, self.head_dim)
44 | .transpose(1, 2)
45 | )
46 | value_states = (
47 | self.v_proj(hidden_states)
48 | .view(bsz, q_len, self.num_heads, self.head_dim)
49 | .transpose(1, 2)
50 | )
51 | # [bsz, q_len, nh, hd]
52 | # [bsz, nh, q_len, hd]
53 |
54 | kv_seq_len = key_states.shape[-2]
55 | assert past_key_value is None, "past_key_value is not supported"
56 |
57 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
58 | query_states, key_states = apply_rotary_pos_emb(
59 | query_states, key_states, cos, sin, position_ids
60 | )
61 | # [bsz, nh, t, hd]
62 | assert not output_attentions, "output_attentions is not supported"
63 | assert not use_cache, "use_cache is not supported"
64 |
65 | # Flash attention codes from
66 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
67 |
68 | # transform the data into the format required by flash attention
69 | qkv = torch.stack(
70 | [query_states, key_states, value_states], dim=2
71 | ) # [bsz, nh, 3, q_len, hd]
72 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
73 | # We have disabled _prepare_decoder_attention_mask in LlamaModel
74 | # the attention_mask should be the same as the key_padding_mask
75 | key_padding_mask = attention_mask
76 |
77 | if key_padding_mask is None:
78 | qkv = rearrange(qkv, "b s ... -> (b s) ...")
79 | max_s = q_len
80 | cu_q_lens = torch.arange(
81 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
82 | )
83 | output = flash_attn_unpadded_qkvpacked_func(
84 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
85 | )
86 | output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
87 | else:
88 | nheads = qkv.shape[-2]
89 | x = rearrange(qkv, "b s three h d -> b s (three h d)")
90 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
91 | x_unpad = rearrange(
92 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
93 | )
94 | output_unpad = flash_attn_unpadded_qkvpacked_func(
95 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
96 | )
97 | output = rearrange(
98 | pad_input(
99 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
100 | ),
101 | "b s (h d) -> b s h d",
102 | h=nheads,
103 | )
104 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
105 |
106 |
107 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
108 | # requires the attention mask to be the same as the key_padding_mask
109 | def _prepare_decoder_attention_mask(
110 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
111 | ):
112 | # [bsz, seq_len]
113 | return attention_mask
114 |
115 |
116 | def replace_llama_attn_with_flash_attn():
117 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
118 | if cuda_major < 8:
119 | logging.warning(
120 |             "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. "
121 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
122 | )
123 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
124 | _prepare_decoder_attention_mask
125 | )
126 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
127 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/train/llava_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 |
4 | import torch
5 | from transformers import Trainer
6 |
7 |
8 | def maybe_zero_3(param, ignore_status=False, name=None):
9 | from deepspeed import zero
10 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
11 |
12 | if hasattr(param, "ds_id"):
13 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
14 | if not ignore_status:
15 | print(name, "no ignore status")
16 | with zero.GatheredParameters([param]):
17 | param = param.data.detach().cpu().clone()
18 | else:
19 | param = param.detach().cpu().clone()
20 | return param
21 |
22 |
23 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
24 | to_return = {
25 | k: t
26 | for k, t in named_params
27 | if any(key_match in k for key_match in keys_to_match)
28 | }
29 | to_return = {
30 | k: maybe_zero_3(v, ignore_status=True, name=k).cpu()
31 | for k, v in to_return.items()
32 | }
33 | return to_return
34 |
35 |
36 | class LLaVATrainer(Trainer):
37 | def _save_checkpoint(self, model, trial, metrics=None):
38 | if getattr(self.args, "tune_mm_mlp_adapter", False):
39 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
40 |
41 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
42 |
43 | run_dir = self._get_output_dir(trial=trial)
44 | output_dir = os.path.join(run_dir, checkpoint_folder)
45 |
46 | # Only save Adapter
47 | keys_to_match = ["mm_projector"]
48 | if getattr(self.args, "use_im_start_end", False):
49 | keys_to_match.extend(["embed_tokens", "embed_in"])
50 |
51 | weight_to_save = get_mm_adapter_state_maybe_zero_3(
52 | self.model.named_parameters(), keys_to_match
53 | )
54 |
55 | if self.args.local_rank == 0 or self.args.local_rank == -1:
56 | self.model.config.save_pretrained(output_dir)
57 | torch.save(
58 | weight_to_save, os.path.join(output_dir, f"mm_projector.bin")
59 | )
60 | else:
61 | super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
62 |
63 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
64 | if getattr(self.args, "tune_mm_mlp_adapter", False):
65 | pass
66 | else:
67 | super(LLaVATrainer, self)._save(output_dir, state_dict)
68 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/train/train_mem.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4 |
5 | # Need to call this before importing transformers.
6 | from llava.train.llama_flash_attn_monkey_patch import \
7 | replace_llama_attn_with_flash_attn
8 |
9 | replace_llama_attn_with_flash_attn()
10 |
11 | from llava.train.train import train
12 |
13 | if __name__ == "__main__":
14 | train()
15 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/llava/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 | from llava.constants import LOGDIR
9 |
10 | server_error_msg = (
11 | "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | )
13 | moderation_msg = (
14 | "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
15 | )
16 |
17 | handler = None
18 |
19 |
20 | def build_logger(logger_name, logger_filename):
21 | global handler
22 |
23 | formatter = logging.Formatter(
24 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
25 | datefmt="%Y-%m-%d %H:%M:%S",
26 | )
27 |
28 | # Set the format of root handlers
29 | if not logging.getLogger().handlers:
30 | logging.basicConfig(level=logging.INFO)
31 | logging.getLogger().handlers[0].setFormatter(formatter)
32 |
33 | # Redirect stdout and stderr to loggers
34 | stdout_logger = logging.getLogger("stdout")
35 | stdout_logger.setLevel(logging.INFO)
36 | sl = StreamToLogger(stdout_logger, logging.INFO)
37 | sys.stdout = sl
38 |
39 | stderr_logger = logging.getLogger("stderr")
40 | stderr_logger.setLevel(logging.ERROR)
41 | sl = StreamToLogger(stderr_logger, logging.ERROR)
42 | sys.stderr = sl
43 |
44 | # Get logger
45 | logger = logging.getLogger(logger_name)
46 | logger.setLevel(logging.INFO)
47 |
48 | # Add a file handler for all loggers
49 | if handler is None:
50 | os.makedirs(LOGDIR, exist_ok=True)
51 | filename = os.path.join(LOGDIR, logger_filename)
52 | handler = logging.handlers.TimedRotatingFileHandler(
53 | filename, when="D", utc=True
54 | )
55 | handler.setFormatter(formatter)
56 |
57 | for name, item in logging.root.manager.loggerDict.items():
58 | if isinstance(item, logging.Logger):
59 | item.addHandler(handler)
60 |
61 | return logger
62 |
63 |
64 | class StreamToLogger(object):
65 | """
66 | Fake file-like stream object that redirects writes to a logger instance.
67 | """
68 |
69 | def __init__(self, logger, log_level=logging.INFO):
70 | self.terminal = sys.stdout
71 | self.logger = logger
72 | self.log_level = log_level
73 | self.linebuf = ""
74 |
75 | def __getattr__(self, attr):
76 | return getattr(self.terminal, attr)
77 |
78 | def write(self, buf):
79 | temp_linebuf = self.linebuf + buf
80 | self.linebuf = ""
81 | for line in temp_linebuf.splitlines(True):
82 | # From the io.TextIOWrapper docs:
83 | # On output, if newline is None, any '\n' characters written
84 | # are translated to the system default line separator.
85 | # By default sys.stdout.write() expects '\n' newlines and then
86 | # translates them so this is still cross platform.
87 | if line[-1] == "\n":
88 | self.logger.log(self.log_level, line.rstrip())
89 | else:
90 | self.linebuf += line
91 |
92 | def flush(self):
93 | if self.linebuf != "":
94 | self.logger.log(self.log_level, self.linebuf.rstrip())
95 | self.linebuf = ""
96 |
97 |
98 | def disable_torch_init():
99 | """
100 | Disable the redundant torch default initialization to accelerate model creation.
101 | """
102 | import torch
103 |
104 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
105 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
106 |
107 |
108 | def violates_moderation(text):
109 | """
110 | Check whether the text violates OpenAI moderation API.
111 | """
112 | url = "https://api.openai.com/v1/moderations"
113 | headers = {
114 | "Content-Type": "application/json",
115 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
116 | }
117 | text = text.replace("\n", "")
118 | data = "{" + '"input": ' + f'"{text}"' + "}"
119 | data = data.encode("utf-8")
120 | try:
121 | ret = requests.post(url, headers=headers, data=data, timeout=5)
122 | flagged = ret.json()["results"][0]["flagged"]
123 | except requests.exceptions.RequestException as e:
124 | flagged = False
125 | except KeyError as e:
126 | flagged = False
127 |
128 | return flagged
129 |
130 |
131 | def pretty_print_semaphore(semaphore):
132 | if semaphore is None:
133 | return "None"
134 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
135 |
--------------------------------------------------------------------------------
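A brief, hedged usage sketch for the helpers above. The logger name and filename are placeholders; note that `build_logger` also redirects stdout/stderr to loggers and writes rotating log files under `LOGDIR`:

    # Assumes the llava package is importable as elsewhere in the repo.
    from llava.utils import build_logger, disable_torch_init

    logger = build_logger("controller", "controller.log")  # hypothetical names
    logger.info("service started")

    # Skip torch's default weight init when weights will be overwritten by a checkpoint anyway.
    disable_torch_init()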
/VideoGLaMM/model/segment_anything/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .automatic_mask_generator import SamAutomaticMaskGenerator
8 | from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h,
9 | build_sam_vit_l, sam_model_registry)
10 | from .predictor import SamPredictor
11 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from functools import partial
8 |
9 | import torch
10 |
11 | from .modeling import (ImageEncoderViT, MaskDecoder, PromptEncoder, Sam,
12 | TwoWayTransformer,
13 | CustomMaskDecoder)
14 |
15 |
16 | def build_sam_vit_h(checkpoint=None, with_itm=False):
17 | if with_itm:
18 | return _build_sam(
19 | encoder_embed_dim=1280,
20 | encoder_depth=32,
21 | encoder_num_heads=16,
22 | encoder_global_attn_indexes=[7, 15, 23, 31],
23 | checkpoint=checkpoint,
24 | with_itm=True
25 | )
26 | return _build_sam(
27 | encoder_embed_dim=1280,
28 | encoder_depth=32,
29 | encoder_num_heads=16,
30 | encoder_global_attn_indexes=[7, 15, 23, 31],
31 | checkpoint=checkpoint,
32 | )
33 |
34 |
35 | build_sam = build_sam_vit_h
36 |
37 |
38 | def build_sam_vit_l(checkpoint=None):
39 | return _build_sam(
40 | encoder_embed_dim=1024,
41 | encoder_depth=24,
42 | encoder_num_heads=16,
43 | encoder_global_attn_indexes=[5, 11, 17, 23],
44 | checkpoint=checkpoint,
45 | )
46 |
47 |
48 | def build_sam_vit_b(checkpoint=None):
49 | return _build_sam(
50 | encoder_embed_dim=768,
51 | encoder_depth=12,
52 | encoder_num_heads=12,
53 | encoder_global_attn_indexes=[2, 5, 8, 11],
54 | checkpoint=checkpoint,
55 | )
56 |
57 |
58 | sam_model_registry = {
59 | "default": build_sam_vit_h,
60 | "vit_h": build_sam_vit_h,
61 | "vit_l": build_sam_vit_l,
62 | "vit_b": build_sam_vit_b,
63 | }
64 |
65 |
66 | def _build_sam(
67 | encoder_embed_dim,
68 | encoder_depth,
69 | encoder_num_heads,
70 | encoder_global_attn_indexes,
71 | checkpoint=None,
72 | with_itm=False,
73 | ):
74 | prompt_embed_dim = 256
75 | image_size = 1024
76 | vit_patch_size = 16
77 | image_embedding_size = image_size // vit_patch_size
78 | sam = Sam(
79 | image_encoder=ImageEncoderViT(
80 | depth=encoder_depth,
81 | embed_dim=encoder_embed_dim,
82 | img_size=image_size,
83 | mlp_ratio=4,
84 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
85 | num_heads=encoder_num_heads,
86 | patch_size=vit_patch_size,
87 | qkv_bias=True,
88 | use_rel_pos=True,
89 | global_attn_indexes=encoder_global_attn_indexes,
90 | window_size=14,
91 | out_chans=prompt_embed_dim,
92 | ),
93 | prompt_encoder=PromptEncoder(
94 | embed_dim=prompt_embed_dim,
95 | image_embedding_size=(image_embedding_size, image_embedding_size),
96 | input_image_size=(image_size, image_size),
97 | mask_in_chans=16,
98 | ),
99 |
100 |         mask_decoder=CustomMaskDecoder(  # NOTE: CustomMaskDecoder replaces the default MaskDecoder when with_itm=True
101 | num_multimask_outputs=3,
102 | transformer=TwoWayTransformer(
103 | depth=2,
104 | embedding_dim=prompt_embed_dim,
105 | mlp_dim=2048,
106 | num_heads=8,
107 | ),
108 | transformer_dim=prompt_embed_dim,
109 | iou_head_depth=3,
110 | iou_head_hidden_dim=256,
111 | ) if with_itm else
112 | MaskDecoder(
113 | num_multimask_outputs=3,
114 | transformer=TwoWayTransformer(
115 | depth=2,
116 | embedding_dim=prompt_embed_dim,
117 | mlp_dim=2048,
118 | num_heads=8,
119 | ),
120 | transformer_dim=prompt_embed_dim,
121 | iou_head_depth=3,
122 | iou_head_hidden_dim=256,
123 | ),
124 |
125 | pixel_mean=[123.675, 116.28, 103.53],
126 | pixel_std=[58.395, 57.12, 57.375],
127 | )
128 | sam.eval()
129 | if checkpoint is not None:
130 | with open(checkpoint, "rb") as f:
131 | state_dict = torch.load(f)
132 | sam.load_state_dict(state_dict, strict=False)
133 | return sam
134 |
--------------------------------------------------------------------------------
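A hedged usage sketch of the registry above (the checkpoint path is illustrative, and the import assumes the repo root is on PYTHONPATH). `with_itm=True` is only wired up for the ViT-H builder and swaps in `CustomMaskDecoder`:

    from model.segment_anything import SamPredictor, sam_model_registry

    sam = sam_model_registry["vit_h"](checkpoint="./checkpoints/sam_vit_h_4b8939.pth")
    sam = sam.to("cuda")  # assumes a GPU is available

    predictor = SamPredictor(sam)  # standard SAM predictor interface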
/VideoGLaMM/model/segment_anything/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .image_encoder import ImageEncoderViT
8 | from .mask_decoder import MaskDecoder, CustomMaskDecoder
9 | from .prompt_encoder import PromptEncoder
10 | from .sam import Sam
11 | from .transformer import TwoWayTransformer
12 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Type
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
--------------------------------------------------------------------------------
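A quick sketch contrasting `LayerNorm2d` with `torch.nn.LayerNorm`: it normalizes over the channel dimension of an NCHW feature map at every spatial location (the tensor sizes below are arbitrary):

    import torch
    from model.segment_anything.modeling.common import LayerNorm2d

    ln2d = LayerNorm2d(num_channels=64)
    x = torch.randn(2, 64, 32, 32)   # N, C, H, W
    y = ln2d(x)                      # per-pixel normalization across the 64 channels
    print(y.shape)                   # torch.Size([2, 64, 32, 32])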
/VideoGLaMM/model/segment_anything/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from copy import deepcopy
8 | from typing import Tuple
9 |
10 | import numpy as np
11 | import torch
12 | from torch.nn import functional as F
13 | from torchvision.transforms.functional import resize # type: ignore
14 | from torchvision.transforms.functional import to_pil_image
15 |
16 |
17 | class ResizeLongestSide:
18 | """
19 |     Resizes images so that their longest side equals 'target_length', and
20 |     provides matching methods for resizing coordinates and boxes, for both
21 |     numpy arrays and batched torch tensors.
22 | """
23 |
24 | def __init__(self, target_length: int) -> None:
25 | self.target_length = target_length
26 |
27 | def apply_image(self, image: np.ndarray) -> np.ndarray:
28 | """
29 | Expects a numpy array with shape HxWxC in uint8 format.
30 | """
31 | target_size = self.get_preprocess_shape(
32 | image.shape[0], image.shape[1], self.target_length
33 | )
34 | return np.array(resize(to_pil_image(image), target_size))
35 |
36 | def apply_coords(
37 | self, coords: np.ndarray, original_size: Tuple[int, ...]
38 | ) -> np.ndarray:
39 | """
40 | Expects a numpy array of length 2 in the final dimension. Requires the
41 | original image size in (H, W) format.
42 | """
43 | old_h, old_w = original_size
44 | new_h, new_w = self.get_preprocess_shape(
45 | original_size[0], original_size[1], self.target_length
46 | )
47 | coords = deepcopy(coords).astype(float)
48 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
49 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
50 | return coords
51 |
52 | def apply_boxes(
53 | self, boxes: np.ndarray, original_size: Tuple[int, ...]
54 | ) -> np.ndarray:
55 | """
56 | Expects a numpy array shape Bx4. Requires the original image size
57 | in (H, W) format.
58 | """
59 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
60 | return boxes.reshape(-1, 4)
61 |
62 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
63 | """
64 | Expects batched images with shape BxCxHxW and float format. This
65 | transformation may not exactly match apply_image. apply_image is
66 | the transformation expected by the model.
67 | """
68 | # Expects an image in BCHW format. May not exactly match apply_image.
69 | target_size = self.get_preprocess_shape(
70 |             image.shape[2], image.shape[3], self.target_length
71 | )
72 | return F.interpolate(
73 | image, target_size, mode="bilinear", align_corners=False, antialias=True
74 | )
75 |
76 | def apply_coords_torch(
77 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
78 | ) -> torch.Tensor:
79 | """
80 | Expects a torch tensor with length 2 in the last dimension. Requires the
81 | original image size in (H, W) format.
82 | """
83 | old_h, old_w = original_size
84 | new_h, new_w = self.get_preprocess_shape(
85 | original_size[0], original_size[1], self.target_length
86 | )
87 | coords = deepcopy(coords).to(torch.float)
88 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
89 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
90 | return coords
91 |
92 | def apply_boxes_torch(
93 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
94 | ) -> torch.Tensor:
95 | """
96 | Expects a torch tensor with shape Bx4. Requires the original image
97 | size in (H, W) format.
98 | """
99 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
100 | return boxes.reshape(-1, 4)
101 |
102 | @staticmethod
103 | def get_preprocess_shape(
104 | oldh: int, oldw: int, long_side_length: int
105 | ) -> Tuple[int, int]:
106 | """
107 | Compute the output size given input size and target long side length.
108 | """
109 | scale = long_side_length * 1.0 / max(oldh, oldw)
110 | newh, neww = oldh * scale, oldw * scale
111 | neww = int(neww + 0.5)
112 | newh = int(newh + 0.5)
113 | return (newh, neww)
114 |
--------------------------------------------------------------------------------
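A short worked example of `ResizeLongestSide` (the image size is chosen for illustration): a 480x640 image is scaled by 1024/640 = 1.6, so the image and any prompt coordinates are multiplied by the same factor:

    import numpy as np
    from model.segment_anything.utils.transforms import ResizeLongestSide

    transform = ResizeLongestSide(target_length=1024)

    image = np.zeros((480, 640, 3), dtype=np.uint8)          # HxWxC, uint8
    print(transform.apply_image(image).shape)                # (768, 1024, 3)

    coords = np.array([[320.0, 240.0]])                      # (x, y) in the original image
    print(transform.apply_coords(coords, image.shape[:2]))   # [[512. 384.]]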
/VideoGLaMM/model/segment_anything_2/sam2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | from hydra import initialize_config_module
9 |
10 | initialize_config_module("model/segment_anything_2/sam2_configs", version_base="1.2")
11 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/sam2/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 |
9 | import torch
10 | from hydra import compose
11 | from hydra.utils import instantiate
12 | from omegaconf import OmegaConf
13 |
14 | def build_sam2(
15 | config_file,
16 | ckpt_path=None,
17 | device="cuda",
18 | mode="eval",
19 | hydra_overrides_extra=[],
20 | apply_postprocessing=True,
21 | ):
22 |
23 | if apply_postprocessing:
24 | hydra_overrides_extra = hydra_overrides_extra.copy()
25 | hydra_overrides_extra += [
26 | # dynamically fall back to multi-mask if the single mask is not stable
27 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
28 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
29 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
30 | ]
31 | # Read config and init model
32 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
33 | OmegaConf.resolve(cfg)
34 | model = instantiate(cfg.model, _recursive_=True)
35 | _load_checkpoint(model, ckpt_path)
36 | if device:
37 | model = model.to(device)
38 | if mode == "eval":
39 | model.eval()
40 | return model
41 |
42 |
43 | def build_sam2_video_predictor(
44 | config_file,
45 | ckpt_path=None,
46 | device="cuda",
47 | mode="eval",
48 | hydra_overrides_extra=[],
49 | apply_postprocessing=True,
50 | ):
51 | hydra_overrides = [
52 | "++model._target_=model.segment_anything_2.sam2.sam2_video_predictor.SAM2VideoPredictor",
53 | ]
54 | if apply_postprocessing:
55 | hydra_overrides_extra = hydra_overrides_extra.copy()
56 | hydra_overrides_extra += [
57 | # dynamically fall back to multi-mask if the single mask is not stable
58 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
59 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
60 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
61 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
62 | "++model.binarize_mask_from_pts_for_mem_enc=true",
63 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
64 | "++model.fill_hole_area=8",
65 | ]
66 | hydra_overrides.extend(hydra_overrides_extra)
67 |
68 | # Read config and init model
69 | cfg = compose(config_name=config_file, overrides=hydra_overrides)
70 | OmegaConf.resolve(cfg)
71 | model = instantiate(cfg.model, _recursive_=True)
72 | _load_checkpoint(model, ckpt_path)
73 | if device:
74 | model = model.to(device)
75 | if mode == "eval":
76 | model.eval()
77 | return model
78 |
79 |
80 | # def _load_checkpoint(model, ckpt_path):
81 | # if ckpt_path is not None:
82 | # sd = torch.load(ckpt_path, map_location="cpu")["model"]
83 | # missing_keys, unexpected_keys = model.load_state_dict(sd)
84 | # if missing_keys:
85 | # logging.error(missing_keys)
86 | # raise RuntimeError()
87 | # if unexpected_keys:
88 | # logging.error(unexpected_keys)
89 | # raise RuntimeError()
90 | #         logging.info("Loaded checkpoint successfully")
91 |
92 | def _load_checkpoint(model, ckpt_path):
93 | '''
94 | load checkpoint from ckpt_path to model, while renaming 'gamma' to 'weight' in the state dict
95 | '''
96 | if ckpt_path is not None:
97 | sd = torch.load(ckpt_path, map_location="cpu")["model"]
98 |
99 | # Rename 'gamma' to 'weight' in the state dict
100 | sd = {key.replace('.gamma', '.weight'): value for key, value in sd.items()}
101 |
102 | missing_keys, unexpected_keys = model.load_state_dict(sd)
103 |
104 | if missing_keys:
105 | logging.error(f"Missing keys: {missing_keys}")
106 | raise RuntimeError("Missing keys found in the state dict.")
107 |
108 | if unexpected_keys:
109 | logging.error(f"Unexpected keys: {unexpected_keys}")
110 | raise RuntimeError("Unexpected keys found in the state dict.")
111 |
112 | logging.info("Loaded checkpoint successfully.")
113 |
--------------------------------------------------------------------------------
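A hedged sketch of building the SAM2 video predictor defined above. The config name refers to the YAML files under `sam2_configs` (registered by `initialize_config_module` in `sam2/__init__.py`); the checkpoint path is illustrative and a GPU is assumed:

    from model.segment_anything_2.sam2.build_sam import build_sam2_video_predictor

    predictor = build_sam2_video_predictor(
        config_file="sam2_hiera_l.yaml",
        ckpt_path="./checkpoints/sam2_hiera_large.pt",
        device="cuda",
    )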
/VideoGLaMM/model/segment_anything_2/sam2/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/sam2/modeling/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/sam2/modeling/backbones/image_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List, Optional
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | class ImageEncoder(nn.Module):
15 | def __init__(
16 | self,
17 | trunk: nn.Module,
18 | neck: nn.Module,
19 | scalp: int = 0,
20 | ):
21 | super().__init__()
22 | self.trunk = trunk
23 | self.neck = neck
24 | self.scalp = scalp
25 | assert (
26 | self.trunk.channel_list == self.neck.backbone_channel_list
27 | ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
28 |
29 | def forward(self, sample: torch.Tensor):
30 | # Forward through backbone
31 | features, pos = self.neck(self.trunk(sample))
32 | if self.scalp > 0:
33 | # Discard the lowest resolution features
34 | features, pos = features[: -self.scalp], pos[: -self.scalp]
35 |
36 | src = features[-1]
37 | output = {
38 | "vision_features": src,
39 | "vision_pos_enc": pos,
40 | "backbone_fpn": features,
41 | }
42 | return output
43 |
44 |
45 | class FpnNeck(nn.Module):
46 | """
47 | A modified variant of Feature Pyramid Network (FPN) neck
48 | (we remove output conv and also do bicubic interpolation similar to ViT
49 | pos embed interpolation)
50 | """
51 |
52 | def __init__(
53 | self,
54 | position_encoding: nn.Module,
55 | d_model: int,
56 | backbone_channel_list: List[int],
57 | kernel_size: int = 1,
58 | stride: int = 1,
59 | padding: int = 0,
60 | fpn_interp_model: str = "bilinear",
61 | fuse_type: str = "sum",
62 | fpn_top_down_levels: Optional[List[int]] = None,
63 | ):
64 |         """Initialize the neck
65 |         :param position_encoding: the positional encoding module applied to each output level
66 |         :param d_model: the output channel dimension of the lateral convs
67 |         :param backbone_channel_list: channel dims of the backbone feature maps fed to the lateral convs
68 |         :param fpn_top_down_levels: output levels that receive top-down features (defaults to all levels)
69 |         """
70 | super().__init__()
71 | self.position_encoding = position_encoding
72 | self.convs = nn.ModuleList()
73 | self.backbone_channel_list = backbone_channel_list
74 | for dim in backbone_channel_list:
75 | current = nn.Sequential()
76 | current.add_module(
77 | "conv",
78 | nn.Conv2d(
79 | in_channels=dim,
80 | out_channels=d_model,
81 | kernel_size=kernel_size,
82 | stride=stride,
83 | padding=padding,
84 | ),
85 | )
86 |
87 | self.convs.append(current)
88 | self.fpn_interp_model = fpn_interp_model
89 | assert fuse_type in ["sum", "avg"]
90 | self.fuse_type = fuse_type
91 |
92 | # levels to have top-down features in its outputs
93 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
94 | # have top-down propagation, while outputs of level 0 and level 1 have only
95 | # lateral features from the same backbone level.
96 | if fpn_top_down_levels is None:
97 | # default is to have top-down features on all levels
98 | fpn_top_down_levels = range(len(self.convs))
99 | self.fpn_top_down_levels = list(fpn_top_down_levels)
100 |
101 | def forward(self, xs: List[torch.Tensor]):
102 |
103 | out = [None] * len(self.convs)
104 | pos = [None] * len(self.convs)
105 | assert len(xs) == len(self.convs)
106 | # fpn forward pass
107 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
108 | prev_features = None
109 | # forward in top-down order (from low to high resolution)
110 | n = len(self.convs) - 1
111 | for i in range(n, -1, -1):
112 | x = xs[i]
113 | lateral_features = self.convs[n - i](x)
114 | if i in self.fpn_top_down_levels and prev_features is not None:
115 | top_down_features = F.interpolate(
116 | prev_features.to(dtype=torch.float32),
117 | scale_factor=2.0,
118 | mode=self.fpn_interp_model,
119 | align_corners=(
120 | None if self.fpn_interp_model == "nearest" else False
121 | ),
122 | antialias=False,
123 | )
124 | prev_features = lateral_features + top_down_features
125 | if self.fuse_type == "avg":
126 | prev_features /= 2
127 | else:
128 | prev_features = lateral_features
129 | x_out = prev_features
130 | out[i] = x_out
131 | pos[i] = self.position_encoding(x_out).to(x_out.dtype)
132 |
133 | return out, pos
134 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/sam2/modeling/backbones/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Some utilities for backbones, in particular for windowing"""
8 |
9 | from typing import Tuple
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 |
16 | def window_partition(x, window_size):
17 | """
18 | Partition into non-overlapping windows with padding if needed.
19 | Args:
20 | x (tensor): input tokens with [B, H, W, C].
21 | window_size (int): window size.
22 | Returns:
23 | windows: windows after partition with [B * num_windows, window_size, window_size, C].
24 | (Hp, Wp): padded height and width before partition
25 | """
26 | B, H, W, C = x.shape
27 |
28 | pad_h = (window_size - H % window_size) % window_size
29 | pad_w = (window_size - W % window_size) % window_size
30 | if pad_h > 0 or pad_w > 0:
31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
32 | Hp, Wp = H + pad_h, W + pad_w
33 |
34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
35 | windows = (
36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37 | )
38 | return windows, (Hp, Wp)
39 |
40 |
41 | def window_unpartition(windows, window_size, pad_hw, hw):
42 | """
43 | Window unpartition into original sequences and removing padding.
44 | Args:
45 |         windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
46 | window_size (int): window size.
47 | pad_hw (Tuple): padded height and width (Hp, Wp).
48 | hw (Tuple): original height and width (H, W) before padding.
49 | Returns:
50 | x: unpartitioned sequences with [B, H, W, C].
51 | """
52 | Hp, Wp = pad_hw
53 | H, W = hw
54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
55 | x = windows.view(
56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1
57 | )
58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
59 |
60 | if Hp > H or Wp > W:
61 | x = x[:, :H, :W, :].contiguous()
62 | return x
63 |
64 |
65 | class PatchEmbed(nn.Module):
66 | """
67 | Image to Patch Embedding.
68 | """
69 |
70 | def __init__(
71 | self,
72 | kernel_size: Tuple[int, ...] = (7, 7),
73 | stride: Tuple[int, ...] = (4, 4),
74 | padding: Tuple[int, ...] = (3, 3),
75 | in_chans: int = 3,
76 | embed_dim: int = 768,
77 | ):
78 | """
79 | Args:
80 | kernel_size (Tuple): kernel size of the projection layer.
81 | stride (Tuple): stride of the projection layer.
82 | padding (Tuple): padding size of the projection layer.
83 | in_chans (int): Number of input image channels.
84 |             embed_dim (int): Patch embedding dimension.
85 | """
86 | super().__init__()
87 | self.proj = nn.Conv2d(
88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
89 | )
90 |
91 | def forward(self, x: torch.Tensor) -> torch.Tensor:
92 | x = self.proj(x)
93 | # B C H W -> B H W C
94 | x = x.permute(0, 2, 3, 1)
95 | return x
96 |
--------------------------------------------------------------------------------
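A round-trip check for the windowing helpers above (sizes are arbitrary): `window_partition` pads 30x45 up to 35x49 for a window size of 7, and `window_unpartition` crops the padding back off, recovering the input exactly:

    import torch
    from model.segment_anything_2.sam2.modeling.backbones.utils import (
        window_partition, window_unpartition,
    )

    B, H, W, C = 1, 30, 45, 96
    x = torch.randn(B, H, W, C)

    windows, (Hp, Wp) = window_partition(x, window_size=7)
    print(windows.shape, (Hp, Wp))   # torch.Size([35, 7, 7, 96]) (35, 49)

    x_restored = window_unpartition(windows, 7, (Hp, Wp), (H, W))
    assert torch.equal(x_restored, x)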
/VideoGLaMM/model/segment_anything_2/sam2/modeling/sam/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/sam2/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/sam2/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torchvision.transforms import Normalize, Resize, ToTensor
11 |
12 |
13 | class SAM2Transforms(nn.Module):
14 | def __init__(
15 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
16 | ):
17 | """
18 | Transforms for SAM2.
19 | """
20 | super().__init__()
21 | self.resolution = resolution
22 | self.mask_threshold = mask_threshold
23 | self.max_hole_area = max_hole_area
24 | self.max_sprinkle_area = max_sprinkle_area
25 | self.mean = [0.485, 0.456, 0.406]
26 | self.std = [0.229, 0.224, 0.225]
27 | self.to_tensor = ToTensor()
28 | self.transforms = torch.jit.script(
29 | nn.Sequential(
30 | Resize((self.resolution, self.resolution)),
31 | Normalize(self.mean, self.std),
32 | )
33 | )
34 |
35 | def __call__(self, x):
36 | x = self.to_tensor(x)
37 | return self.transforms(x)
38 |
39 | def forward_batch(self, img_list):
40 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
41 | img_batch = torch.stack(img_batch, dim=0)
42 | return img_batch
43 |
44 | def transform_coords(
45 | self, coords: torch.Tensor, normalize=False, orig_hw=None
46 | ) -> torch.Tensor:
47 | """
48 |         Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates;
49 |         if the coords are in absolute image coordinates, normalize should be set to True and the original image size is required.
50 | 
51 |         Returns
52 |             Coordinates rescaled to the model input resolution (self.resolution), as expected by the SAM2 model.
53 | """
54 | if normalize:
55 | assert orig_hw is not None
56 | h, w = orig_hw
57 | coords = coords.clone()
58 | coords[..., 0] = coords[..., 0] / w
59 | coords[..., 1] = coords[..., 1] / h
60 |
61 | coords = coords * self.resolution # unnormalize coords
62 | return coords
63 |
64 | def transform_boxes(
65 | self, boxes: torch.Tensor, normalize=False, orig_hw=None
66 | ) -> torch.Tensor:
67 | """
68 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates,
69 | if the coords are in absolute image coordinates, normalize should be set to True and original image size is required.
70 | """
71 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
72 | return boxes
73 |
74 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
75 | """
76 | Perform PostProcessing on output masks.
77 | """
78 | from model.segment_anything_2.sam2.utils.misc import get_connected_components
79 |
80 | masks = masks.float()
81 | if self.max_hole_area > 0:
82 |             # Holes are those connected components in background with area <= self.max_hole_area
83 | # (background regions are those with mask scores <= self.mask_threshold)
84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
85 | labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
86 | is_hole = (labels > 0) & (areas <= self.max_hole_area)
87 | is_hole = is_hole.reshape_as(masks)
88 | # We fill holes with a small positive mask score (10.0) to change them to foreground.
89 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
90 |
91 | if self.max_sprinkle_area > 0:
92 |             labels, areas = get_connected_components(masks.flatten(0, 1).unsqueeze(1) > self.mask_threshold)  # flatten inline so this branch also works when max_hole_area == 0
93 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
94 | is_hole = is_hole.reshape_as(masks)
95 | # We fill holes with negative mask score (-10.0) to change them to background.
96 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
97 |
98 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
99 | return masks
100 |
--------------------------------------------------------------------------------
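A brief sketch of `SAM2Transforms` on a single frame (sizes are illustrative): the frame is resized to the square model resolution, and prompt coordinates given in original pixels are mapped into that 1024x1024 space:

    import numpy as np
    import torch
    from model.segment_anything_2.sam2.utils.transforms import SAM2Transforms

    tfm = SAM2Transforms(resolution=1024, mask_threshold=0.0)

    frame = np.zeros((480, 640, 3), dtype=np.uint8)   # HxWxC, uint8
    print(tfm(frame).shape)                           # torch.Size([3, 1024, 1024])

    pts = torch.tensor([[[320.0, 240.0]]])            # (x, y) in original pixels
    print(tfm.transform_coords(pts, normalize=True, orig_hw=(480, 640)))  # -> (512., 512.)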
/VideoGLaMM/model/segment_anything_2/sam2_configs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/sam2_configs/sam2_hiera_b+.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: model.segment_anything_2.sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: model.segment_anything_2.sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 112
12 | num_heads: 2
13 | neck:
14 | _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.FpnNeck
15 | position_encoding:
16 | _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine
17 | num_pos_feats: 256
18 | normalize: true
19 | scale: null
20 | temperature: 10000
21 | d_model: 256
22 | backbone_channel_list: [896, 448, 224, 112]
23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24 | fpn_interp_model: nearest
25 |
26 | memory_attention:
27 | _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttention
28 | d_model: 256
29 | pos_enc_at_input: true
30 | layer:
31 | _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttentionLayer
32 | activation: relu
33 | dim_feedforward: 2048
34 | dropout: 0.1
35 | pos_enc_at_attn: false
36 | self_attention:
37 | _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention
38 | rope_theta: 10000.0
39 | feat_sizes: [32, 32]
40 | embedding_dim: 256
41 | num_heads: 1
42 | downsample_rate: 1
43 | dropout: 0.1
44 | d_model: 256
45 | pos_enc_at_cross_attn_keys: true
46 | pos_enc_at_cross_attn_queries: false
47 | cross_attention:
48 | _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention
49 | rope_theta: 10000.0
50 | feat_sizes: [32, 32]
51 | rope_k_repeat: True
52 | embedding_dim: 256
53 | num_heads: 1
54 | downsample_rate: 1
55 | dropout: 0.1
56 | kv_in_dim: 64
57 | num_layers: 4
58 |
59 | memory_encoder:
60 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MemoryEncoder
61 | out_dim: 64
62 | position_encoding:
63 | _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine
64 | num_pos_feats: 64
65 | normalize: true
66 | scale: null
67 | temperature: 10000
68 | mask_downsampler:
69 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MaskDownSampler
70 | kernel_size: 3
71 | stride: 2
72 | padding: 1
73 | fuser:
74 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.Fuser
75 | layer:
76 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.CXBlock
77 | dim: 256
78 | kernel_size: 7
79 | padding: 3
80 | layer_scale_init_value: 1e-6
81 | use_dwconv: True # depth-wise convs
82 | num_layers: 2
83 |
84 | num_maskmem: 7
85 | image_size: 1024
86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87 | sigmoid_scale_for_mem_enc: 20.0
88 | sigmoid_bias_for_mem_enc: -10.0
89 | use_mask_input_as_output_without_sam: true
90 | # Memory
91 | directly_add_no_mem_embed: true
92 | # use high-resolution feature map in the SAM mask decoder
93 | use_high_res_features_in_sam: true
94 | # output 3 masks on the first click on initial conditioning frames
95 | multimask_output_in_sam: true
96 | # SAM heads
97 | iou_prediction_use_sigmoid: True
98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99 | use_obj_ptrs_in_encoder: true
100 | add_tpos_enc_to_obj_ptrs: false
101 | only_obj_ptrs_in_the_past_for_eval: true
102 | # object occlusion prediction
103 | pred_obj_scores: true
104 | pred_obj_scores_mlp: true
105 | fixed_no_obj_ptr: true
106 | # multimask tracking settings
107 | multimask_output_for_tracking: true
108 | use_multimask_token_for_obj_ptr: true
109 | multimask_min_pt_num: 0
110 | multimask_max_pt_num: 1
111 | use_mlp_for_obj_ptr_proj: true
112 | # Compilation flag
113 | compile_image_encoder: False
114 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/sam2_configs/sam2_hiera_l.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: model.segment_anything_2.sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: model.segment_anything_2.sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 144
12 | num_heads: 2
13 | stages: [2, 6, 36, 4]
14 | global_att_blocks: [23, 33, 43]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | window_spec: [8, 4, 16, 8]
17 | neck:
18 | _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.FpnNeck
19 | position_encoding:
20 | _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine
21 | num_pos_feats: 256
22 | normalize: true
23 | scale: null
24 | temperature: 10000
25 | d_model: 256
26 | backbone_channel_list: [1152, 576, 288, 144]
27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28 | fpn_interp_model: nearest
29 |
30 | memory_attention:
31 | _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttention
32 | d_model: 256
33 | pos_enc_at_input: true
34 | layer:
35 | _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttentionLayer
36 | activation: relu
37 | dim_feedforward: 2048
38 | dropout: 0.1
39 | pos_enc_at_attn: false
40 | self_attention:
41 | _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention
42 | rope_theta: 10000.0
43 | feat_sizes: [32, 32]
44 | embedding_dim: 256
45 | num_heads: 1
46 | downsample_rate: 1
47 | dropout: 0.1
48 | d_model: 256
49 | pos_enc_at_cross_attn_keys: true
50 | pos_enc_at_cross_attn_queries: false
51 | cross_attention:
52 | _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention
53 | rope_theta: 10000.0
54 | feat_sizes: [32, 32]
55 | rope_k_repeat: True
56 | embedding_dim: 256
57 | num_heads: 1
58 | downsample_rate: 1
59 | dropout: 0.1
60 | kv_in_dim: 64
61 | num_layers: 4
62 |
63 | memory_encoder:
64 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MemoryEncoder
65 | out_dim: 64
66 | position_encoding:
67 | _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine
68 | num_pos_feats: 64
69 | normalize: true
70 | scale: null
71 | temperature: 10000
72 | mask_downsampler:
73 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MaskDownSampler
74 | kernel_size: 3
75 | stride: 2
76 | padding: 1
77 | fuser:
78 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.Fuser
79 | layer:
80 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.CXBlock
81 | dim: 256
82 | kernel_size: 7
83 | padding: 3
84 | layer_scale_init_value: 1e-6
85 | use_dwconv: True # depth-wise convs
86 | num_layers: 2
87 |
88 | num_maskmem: 7
89 | image_size: 1024
90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: false
105 | only_obj_ptrs_in_the_past_for_eval: true
106 | # object occlusion prediction
107 | pred_obj_scores: true
108 | pred_obj_scores_mlp: true
109 | fixed_no_obj_ptr: true
110 | # multimask tracking settings
111 | multimask_output_for_tracking: true
112 | use_multimask_token_for_obj_ptr: true
113 | multimask_min_pt_num: 0
114 | multimask_max_pt_num: 1
115 | use_mlp_for_obj_ptr_proj: true
116 | # Compilation flag
117 | compile_image_encoder: False
118 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/sam2_configs/sam2_hiera_s.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: model.segment_anything_2.sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: model.segment_anything_2.sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 11, 2]
14 | global_att_blocks: [7, 10, 13]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | sigmoid_scale_for_mem_enc: 20.0
91 | sigmoid_bias_for_mem_enc: -10.0
92 | use_mask_input_as_output_without_sam: true
93 | # Memory
94 | directly_add_no_mem_embed: true
95 | # use high-resolution feature map in the SAM mask decoder
96 | use_high_res_features_in_sam: true
97 | # output 3 masks on the first click on initial conditioning frames
98 | multimask_output_in_sam: true
99 | # SAM heads
100 | iou_prediction_use_sigmoid: True
101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102 | use_obj_ptrs_in_encoder: true
103 | add_tpos_enc_to_obj_ptrs: false
104 | only_obj_ptrs_in_the_past_for_eval: true
105 | # object occlusion prediction
106 | pred_obj_scores: true
107 | pred_obj_scores_mlp: true
108 | fixed_no_obj_ptr: true
109 | # multimask tracking settings
110 | multimask_output_for_tracking: true
111 | use_multimask_token_for_obj_ptr: true
112 | multimask_min_pt_num: 0
113 | multimask_max_pt_num: 1
114 | use_mlp_for_obj_ptr_proj: true
115 | # Compilation flag
116 | compile_image_encoder: False
117 |
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/sam2_configs/sam2_hiera_t.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: model.segment_anything_2.sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: model.segment_anything_2.sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 7, 2]
14 | global_att_blocks: [5, 7, 9]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: model.segment_anything_2.sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: model.segment_anything_2.sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: model.segment_anything_2.sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: model.segment_anything_2.sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: model.segment_anything_2.sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | # SAM decoder
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: false
105 | only_obj_ptrs_in_the_past_for_eval: true
106 | # object occlusion prediction
107 | pred_obj_scores: true
108 | pred_obj_scores_mlp: true
109 | fixed_no_obj_ptr: true
110 | # multimask tracking settings
111 | multimask_output_for_tracking: true
112 | use_multimask_token_for_obj_ptr: true
113 | multimask_min_pt_num: 0
114 | multimask_max_pt_num: 1
115 | use_mlp_for_obj_ptr_proj: true
116 | # Compilation flag
117 | # HieraT does not currently support compilation, should always be set to False
118 | compile_image_encoder: False
119 |
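Note: the four configs differ only in the Hiera trunk and matching FPN channels; the memory attention, memory encoder, and tracking flags are identical. In the files above, the large variant uses embed_dim 144 with stages [2, 6, 36, 4] and backbone_channel_list [1152, 576, 288, 144], the small variant uses embed_dim 96 with stages [1, 2, 11, 2], and the tiny variant uses embed_dim 96 with stages [1, 2, 7, 2], both over backbone_channel_list [768, 384, 192, 96]. A hedged loading sketch, assuming the vendored `build_sam.py` keeps the upstream `build_sam2(config_file, ckpt_path, device, mode)` signature and that a matching SAM 2 checkpoint has been downloaded locally (the path below is hypothetical):

    # Sketch, not the repo's exact entry point.
    from model.segment_anything_2.sam2.build_sam import build_sam2

    sam2 = build_sam2(
        config_file="sam2_hiera_l.yaml",                # or sam2_hiera_s.yaml / sam2_hiera_t.yaml
        ckpt_path="./checkpoints/sam2_hiera_large.pt",  # hypothetical local path
        device="cuda",
        mode="eval",
    )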
--------------------------------------------------------------------------------
/VideoGLaMM/model/segment_anything_2/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from setuptools import find_packages, setup
8 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
9 |
10 | def get_extensions():
11 | srcs = ["sam2/csrc/connected_components.cu"]
12 | compile_args = {
13 | "cxx": [],
14 | "nvcc": [
15 | "-DCUDA_HAS_FP16=1",
16 | "-D__CUDA_NO_HALF_OPERATORS__",
17 | "-D__CUDA_NO_HALF_CONVERSIONS__",
18 | "-D__CUDA_NO_HALF2_OPERATORS__",
19 | ],
20 | }
21 | ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
22 | return ext_modules
23 |
24 |
25 | # Setup configuration
26 | setup(
27 | ext_modules=get_extensions(),
28 | cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
29 | )
30 |
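Note: this trimmed setup.py builds only the optional `sam2._C` CUDA extension compiled from `sam2/csrc/connected_components.cu`, used for GPU-side mask post-processing; it carries no package metadata. A conventional in-place build would be `python setup.py build_ext --inplace` run from `model/segment_anything_2/`. A minimal availability check, under the assumption that the compiled shared object ends up inside the vendored `sam2/` package and the script runs from the repo root:

    # Minimal sanity check for the optional CUDA extension (assumptions noted above).
    try:
        from model.segment_anything_2.sam2 import _C  # built from sam2/csrc/connected_components.cu
        print("sam2._C extension is available")
    except ImportError:
        print("sam2._C not built; GPU post-processing that depends on it is unavailable")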
--------------------------------------------------------------------------------
/VideoGLaMM/model/videogpt_plus/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import VideoGPTPlusPhi3ForCausalLM
2 | from .model import VideoGPTPlusLlamaForCausalLM
3 |
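Note: the package re-exports the two language-model backbones used elsewhere in the repo. A minimal import sketch, assuming these classes expose the usual Hugging Face `from_pretrained` interface (the checkpoint path is a placeholder, not a real artifact):

    # Hypothetical usage sketch; replace the path with an actual compatible checkpoint.
    from model.videogpt_plus import VideoGPTPlusPhi3ForCausalLM

    lm = VideoGPTPlusPhi3ForCausalLM.from_pretrained("path/to/a/compatible/checkpoint")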
--------------------------------------------------------------------------------
/VideoGLaMM/model/videogpt_plus/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | from distutils.util import strtobool
3 |
4 | # Configuration Constants
5 | # TODO: Change the chunk size accordingly if you use any other video encoder
6 | CHUNK_SIZE = 4 # Video chunk size for InternVideo2-Stage2_1B-224p-f4 which is trained using 4 frames per video
7 | NUM_FRAMES = int(os.environ.get("NUM_FRAMES", 16)) # Number of video frames (if using video)
8 | NUM_CONTEXT_IMAGES = int(os.environ.get("NUM_CONTEXT_IMAGES", 16)) # Number of context images for video
9 |
10 | # Model Constants
11 | IGNORE_INDEX = -100
12 | IMAGE_TOKEN_INDEX = -200
13 | DEFAULT_IMAGE_TOKEN = "<image>"
14 | DEFAULT_VIDEO_TOKEN = "