├── LICENSE
├── README.md
├── app.py
├── assets
│   ├── example_image1.jpg
│   ├── example_image2.jpg
│   ├── example_video1.mp4
│   ├── example_video2.mp4
│   ├── method1.jpg
│   └── method2.jpg
├── environment_setup.sh
├── inference.py
├── pyproject.toml
├── scripts
│   ├── eval
│   │   └── lmms.sh
│   ├── train
│   │   ├── pretrain.sh
│   │   └── sft.sh
│   └── zero2.json
└── vila_u
    ├── __init__.py
    ├── cli
    │   └── eval.py
    ├── constants.py
    ├── conversation.py
    ├── data
    │   ├── __init__.py
    │   ├── dataset.py
    │   ├── datasets_mixture.py
    │   └── simple_vila_webdataset.py
    ├── entry.py
    ├── eval
    │   ├── __init__.py
    │   ├── lmms
    │   │   ├── models
    │   │   │   ├── __init__.py
    │   │   │   └── vila_u.py
    │   │   └── tasks
    │   │       └── __init__.py
    │   └── registry.yaml
    ├── media.py
    ├── mm_utils.py
    ├── model
    │   ├── __init__.py
    │   ├── builder.py
    │   ├── configuration_vila_u.py
    │   ├── language_model
    │   │   ├── builder.py
    │   │   └── vila_u_llama.py
    │   ├── multimodal_encoder
    │   │   ├── builder.py
    │   │   ├── rqvaesigliptransformer
    │   │   │   ├── __init__.py
    │   │   │   ├── configuration_rqvaesigliptransformer.py
    │   │   │   ├── modeling_rqvaesigliptransformer.py
    │   │   │   ├── rqtransformer
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── attention.py
    │   │   │   │   ├── configuration_rqtransformer.py
    │   │   │   │   └── modeling_rqtransformer.py
    │   │   │   └── rqvaesiglip
    │   │   │       ├── __init__.py
    │   │   │       ├── configuration_rqvaesiglip.py
    │   │   │       ├── modeling_rqvaesiglip.py
    │   │   │       ├── modules.py
    │   │   │       ├── quantizations.py
    │   │   │       └── siglip
    │   │   │           ├── __init__.py
    │   │   │           ├── configuration_siglip.py
    │   │   │           ├── convert_siglip_to_hf.py
    │   │   │           ├── image_processing_siglip.py
    │   │   │           ├── modeling_siglip.py
    │   │   │           ├── processing_siglip.py
    │   │   │           └── tokenization_siglip.py
    │   │   └── rqvaesigliptransformer_encoder.py
    │   ├── multimodal_projector
    │   │   ├── base_projector.py
    │   │   └── builder.py
    │   ├── utils.py
    │   └── vila_u_arch.py
    ├── train
    │   ├── args.py
    │   ├── callbacks
    │   │   └── autoresume_callback.py
    │   ├── train.py
    │   ├── train_mem.py
    │   ├── transformer_normalize_monkey_patch.py
    │   ├── transformers_replace
    │   │   ├── generation
    │   │   │   └── utils.py
    │   │   └── models
    │   │       └── llama
    │   │           ├── configuring_llama.py
    │   │           ├── modeling_llama.py
    │   │           └── tokenization_llama.py
    │   ├── utils.py
    │   └── vila_u_trainer.py
    ├── utils
    │   ├── __init__.py
    │   ├── distributed.py
    │   ├── io.py
    │   ├── logging.py
    │   ├── media.py
    │   ├── tokenizer.py
    │   └── utils.py
    └── wids
        ├── __init__.py
        ├── wids.py
        ├── wids_bench.py
        ├── wids_cleanup.py
        ├── wids_dir.py
        ├── wids_dl.py
        ├── wids_index.py
        ├── wids_lru.py
        ├── wids_mmtar.py
        ├── wids_specs.py
        └── wids_tar.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 MIT HAN Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VILA-U: a Unified Foundation Model Integrating Visual Understanding and Generation
2 |
3 | \[[Online Demo](https://vila-u.hanlab.ai)\] \[[Paper](https://arxiv.org/abs/2409.04429#)\] \[[Project](https://hanlab.mit.edu/projects/vila-u)\] \[[Models](https://huggingface.co/collections/mit-han-lab/vila-u-7b-6716f7dd5331e4bdf944ffa6)\]
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | Figure 1: Multi-token in, Multi-token out Training and Inference.
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | Figure 2: Unified Foundation Vision Tower.
20 |
21 |
22 |
23 | ## News
24 |
25 | - \[2025/01\] 🎉 VILA-U has been accepted to ICLR 2025!
26 | - \[2024/10\] Online demo of VILA-U is available: [https://vila-u.hanlab.ai](https://vila-u.hanlab.ai). Have a try!
27 | - \[2024/10\] We release the code and [models](https://huggingface.co/collections/mit-han-lab/vila-u-7b-6716f7dd5331e4bdf944ffa6) for VILA-U!
28 |
29 | ## Abstract
30 |
31 | **VILA-U** is a **U**nified foundation model that integrates **V**ideo, **I**mage, **La**nguage understanding and generation. Traditional visual language models (VLMs) use separate modules for understanding and generating visual content, which can lead to misalignment and increased complexity. In contrast, VILA-U employs a single autoregressive next-token prediction framework for both tasks, eliminating the need for additional components such as diffusion models. This approach not only simplifies the model but also achieves near state-of-the-art performance in visual language understanding and generation. The success of VILA-U is attributed to two main factors: the unified vision tower, which aligns discrete visual tokens with textual inputs during pretraining and thereby enhances visual perception, and the finding that autoregressive image generation can reach quality similar to diffusion models when trained on high-quality data. As a result, VILA-U performs comparably to more complex models while using a fully token-based autoregressive framework.
32 |
33 | ## Preparation
34 |
35 | ### Environment Setup
36 |
37 | ```bash
38 | git clone https://github.com/mit-han-lab/vila-u
39 | cd vila-u
40 | ./environment_setup.sh vila-u
41 | ```
42 |
43 | ### Download Models
44 |
45 | Please download our [models](https://huggingface.co/collections/mit-han-lab/vila-u-7b-6716f7dd5331e4bdf944ffa6) from HuggingFace.
46 |
47 | ```bash
48 | git lfs install
49 | git clone https://huggingface.co/mit-han-lab/vila-u-7b-256
50 | ```
51 |
52 | ## Usage
53 |
54 | ### Gradio Demo
55 |
56 | Run the following command to launch a local gradio demo:
57 | ```bash
58 | CUDA_VISIBLE_DEVICES=0 python app.py --model_path path/to/your_downloaded_model
59 | ```
60 |
61 | ### Command Line Inference
62 |
63 | ```bash
64 | # Image Understanding
65 | CUDA_VISIBLE_DEVICES=0 python inference.py --model_path path/to/your_downloaded_model --image_path assets/example_image1.jpg --query "Can you describe what is happening?"
66 | ```
67 |
68 | ```bash
69 | # Video Understanding
70 | CUDA_VISIBLE_DEVICES=0 python inference.py --model_path path/to/your_downloaded_model --video_path assets/example_video1.mp4 --query "Elaborate on the visual and narrative elements of the video in detail."
71 | ```
72 |
73 | ```bash
74 | # Image Generation
75 | CUDA_VISIBLE_DEVICES=0 python inference.py --model_path path/to/your_downloaded_model --prompt "A snowy mountain." --save_path path/to/save_images --generation_nums 8
76 | ```
77 |
78 | ```bash
79 | # Video Generation
80 | CUDA_VISIBLE_DEVICES=0 python inference.py --model_path path/to/your_downloaded_model --prompt "Fireworks in the air." --video_generation True --save_path path/to/save_videos
81 | ```
82 |
83 | ### Evaluation
84 |
85 | Evaluate VILA-U on visual language benchmarks with the following command:
86 | ```bash
87 | vila_u-eval -m path/to/model -c vicuna_v1 -ti local
88 | ```
89 | Please refer to `vila_u/cli/eval.py` for more argument details.
90 |
91 | ### Training
92 |
93 | Note: Please prepare data before training. Data preparation details are in the file `vila_u/data/datasets_mixture.py`.
94 |
95 | ```bash
96 | # Pretrain
97 | srun -p your_slurm_partition -N 8 -t 04:00:00 -A your_slurm_account -J vila-u:pretrain --gpus-per-node 8 --exclusive --dependency singleton bash scripts/train/pretrain.sh &
98 | ```
99 |
100 | ```bash
101 | # SFT
102 | srun -p your_slurm_partition -N 8 -t 04:00:00 -A your_slurm_account -J vila-u:sft --gpus-per-node 8 --exclusive --dependency singleton bash scripts/train/sft.sh &
103 | ```
104 |
105 | ## Acknowledgment
106 |
107 | We thank Zhijian Liu from NVIDIA for his assistance with the evaluation setup.
108 |
109 | ## Citation
110 |
111 | If you find VILA-U useful or relevant to your project and research, please kindly cite our paper:
112 |
113 | ```bibtex
114 | @article{wu2024vila,
115 | title={Vila-u: a unified foundation model integrating visual understanding and generation},
116 | author={Wu, Yecheng and Zhang, Zhuoyang and Chen, Junyu and Tang, Haotian and Li, Dacheng and Fang, Yunhao and Zhu, Ligeng and Xie, Enze and Yin, Hongxu and Yi, Li and others},
117 | journal={arXiv preprint arXiv:2409.04429},
118 | year={2024}
119 | }
120 | ```
121 |
122 |
--------------------------------------------------------------------------------
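The README's CLI examples map onto a small Python API that `app.py` and `inference.py` (below) both use. The following is a minimal sketch of that programmatic usage; the checkpoint path is a placeholder, and the sampling values mirror the defaults in `app.py`.

```python
# Minimal sketch of programmatic usage, mirroring inference.py/app.py.
# "path/to/vila-u-7b-256" is a placeholder for a locally downloaded checkpoint.
import vila_u

model = vila_u.load("path/to/vila-u-7b-256")

# Visual understanding: pair an Image (or Video) with a text query.
generation_config = model.default_generation_config
generation_config.temperature = 0.9
generation_config.top_p = 0.6
response = model.generate_content([vila_u.Image("assets/example_image1.jpg"),
                                   "Can you describe what is happening?"])
print(response)

# Image generation: a text prompt, a classifier-free guidance scale, and a sample count.
# The result is a batch of image tensors shaped (N, C, H, W), as consumed by save_image() in inference.py.
images = model.generate_image_content("A snowy mountain.", 3.0, 4)
```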
/app.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import gradio as gr
4 | import imghdr
5 | import numpy as np
6 | import os
7 | import shutil
8 | import signal
9 | import sys
10 | import torch
11 | import uuid
12 | import vila_u
13 |
14 | CFG = 3.0
15 | TEMPERATURE = 0.9
16 | TOP_P = 0.6
17 |
18 |
19 | def is_image_file(filepath):
20 | return imghdr.what(filepath) is not None
21 |
22 |
23 | def generate_response(image, video, query, chat_history):
24 | if query is not None and image is None and video is None:
25 | response = model.generate_image_content(prompt=query, cfg=CFG)[0]
26 | out = response.permute(1, 2, 0)
27 | out = out.cpu().numpy().astype(np.uint8)
28 | out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
29 | image_filename = f"{uuid.uuid4()}.png"
30 | image_path = os.path.join(temp_dir, image_filename)
31 | cv2.imwrite(image_path, out)
32 |
33 | return chat_history + [(query, "Here is the image generated:"), (None, (image_path,))]
34 | elif image is not None:
35 | generation_config = model.default_generation_config
36 | generation_config.temperature = TEMPERATURE
37 | generation_config.top_p = TOP_P
38 | answer = model.generate_content([vila_u.Image(image), query], generation_config)
39 | media_display = image
40 | elif video is not None:
41 | generation_config = model.default_generation_config
42 | generation_config.temperature = TEMPERATURE
43 | generation_config.top_p = TOP_P
44 | answer = model.generate_content([vila_u.Video(video), query], generation_config)
45 | media_display = video
46 | else:
47 | return chat_history + [(None, "No input!")]
48 |
49 | return chat_history + [((media_display,), None), (query, answer)]
50 |
51 |
52 | def clear_chat():
53 | return None, None, None, []
54 |
55 |
56 | def regenerate_last_answer(chat_history):
57 | if len(chat_history) < 1:
58 | return chat_history
59 |
60 | last_query, last_answer = chat_history[-1]
61 | if last_query is None:
62 | if last_answer == "No input!":
63 | return chat_history
64 | else:
65 | return generate_response(None, None, chat_history[-2][0], chat_history[:-2])
66 | else:
67 | last_media = chat_history[-2][0][0]
68 | if is_image_file(last_media):
69 | return generate_response(last_media, None, last_query, chat_history[:-2])
70 | else:
71 | return generate_response(None, last_media, last_query, chat_history[:-2])
72 |
73 |
74 | def cleanup():
75 | if os.path.exists(temp_dir):
76 | shutil.rmtree(temp_dir)
77 |
78 |
79 | if __name__ == "__main__":
80 | parser = argparse.ArgumentParser()
81 | parser.add_argument("--model_path", type=str)
82 | args = parser.parse_args()
83 |
84 | if torch.cuda.is_available():
85 | device = 'cuda'
86 | else:
87 | raise ValueError("CUDA is not available on this machine. Please use a CUDA-enabled machine to run this demo.")
88 | model = vila_u.load(args.model_path).to(device)
89 |
90 | temp_dir = 'temp/'
91 | os.makedirs(temp_dir, exist_ok=True)
92 |
93 | signal.signal(signal.SIGINT, lambda s, f: (cleanup(), sys.exit()))
94 |
95 | with gr.Blocks(title='VILA-U') as demo:
96 | gr.Markdown("# VILA-U: a Unified Foundation Model Integrating Visual Understanding and Generation")
97 | websites = (
98 | """
99 | [[Paper](https://arxiv.org/abs/2409.04429)]
100 | [[Project](https://hanlab.mit.edu/projects/vila-u)]
101 | [[GitHub](https://github.com/mit-han-lab/vila-u)]
102 | [[Models](https://huggingface.co/collections/mit-han-lab/vila-u-7b-6716f7dd5331e4bdf944ffa6)]
103 | """
104 | )
105 | gr.Markdown(websites)
106 |
107 | with gr.Row():
108 | with gr.Column(scale=2):
109 | image_input = gr.Image(label="Upload Image", type="filepath")
110 | video_input = gr.Video(label="Upload Video", type="filepath")
111 |
112 | with gr.Column(scale=4):
113 | output_container = gr.Chatbot(
114 | label="VILA-U Chatbot",
115 | height=400,
116 | layout="panel",
117 | )
118 |
119 | with gr.Row():
120 | question_input = gr.Textbox(show_label=False, \
121 | placeholder="Submit a question along with visual input, or provide an image generation prompt alone.", container=False, scale=6)
122 |
123 | submit_button = gr.Button("Submit", variant="primary", scale=1)
124 | clear_button = gr.Button(value="🗑️ Clear", scale=1)
125 | retry_button = gr.Button(value="🔄 Retry", scale=1)
126 |
127 | with gr.Row():
128 | gr.Examples(examples=[
129 | ["assets/example_image1.jpg", "Can you describe what is happening?"],
130 | ["assets/example_image2.jpg", "What is the brand of the silver car in the image?"],
131 | ], inputs=[image_input, question_input], cache_examples=False, label="Image Understanding Examples.")
132 |
133 | gr.Examples(examples=[
134 | ["assets/example_video1.mp4", "Elaborate on the visual and narrative elements of the video in detail."],
135 | ["assets/example_video2.mp4", "What is the man putting on the plate?"],
136 | ], inputs=[video_input, question_input], cache_examples=False, label="Video Understanding Examples.")
137 |
138 | gr.Examples(examples=[
139 | ["An elephant walking in the water."],
140 | ["A melting apple."],
141 | ["An astronaut riding a horse on the moon, oil painting by Van Gogh."],
142 | ["New England fall with leaves, house and river."],
143 | ["An old man with white beard."],
144 | ["A crystal tree shimmering under a starry sky."],
145 | ["A deep forest clearing with a mirrored pond reflecting a galaxy-filled night sky."],
146 | ["Happy dreamy owl monster sitting on a tree branch, colorful glittering particles, forest background, detailed feathers."]
147 | ], inputs=[question_input], cache_examples=False, label="Image Generation Examples.")
148 |
149 | submit_button.click(generate_response, inputs=[image_input, video_input, question_input, output_container], outputs=output_container)
150 | clear_button.click(clear_chat, outputs=[image_input, video_input, question_input, output_container])
151 | retry_button.click(regenerate_last_answer, inputs=output_container, outputs=output_container)
152 |
153 | try:
154 | demo.launch(share=True)
155 | finally:
156 | cleanup()
--------------------------------------------------------------------------------
/assets/example_image1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/vila-u/4ea42ead2ce22b035ee74c7966eed245bca3e927/assets/example_image1.jpg
--------------------------------------------------------------------------------
/assets/example_image2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/vila-u/4ea42ead2ce22b035ee74c7966eed245bca3e927/assets/example_image2.jpg
--------------------------------------------------------------------------------
/assets/example_video1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/vila-u/4ea42ead2ce22b035ee74c7966eed245bca3e927/assets/example_video1.mp4
--------------------------------------------------------------------------------
/assets/example_video2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/vila-u/4ea42ead2ce22b035ee74c7966eed245bca3e927/assets/example_video2.mp4
--------------------------------------------------------------------------------
/assets/method1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/vila-u/4ea42ead2ce22b035ee74c7966eed245bca3e927/assets/method1.jpg
--------------------------------------------------------------------------------
/assets/method2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/vila-u/4ea42ead2ce22b035ee74c7966eed245bca3e927/assets/method2.jpg
--------------------------------------------------------------------------------
/environment_setup.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 |
4 | CONDA_ENV=${1:-""}
5 | if [ -n "$CONDA_ENV" ]; then
6 | # This is required to activate conda environment
7 | eval "$(conda shell.bash hook)"
8 |
9 | conda create -n $CONDA_ENV python=3.10.14 -y
10 | conda activate $CONDA_ENV
11 | # This is optional; skip it if you prefer to use an existing system nvcc
12 | conda install -c nvidia cuda-toolkit -y
13 | else
14 | echo "Skipping conda environment creation. Make sure you have the correct environment activated."
15 | fi
16 |
17 | # This is required to enable PEP 660 support
18 | pip install --upgrade pip setuptools
19 |
20 | # Install FlashAttention2
21 | pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
22 |
23 | # Install VILA
24 | pip install -e ".[train,eval]"
25 |
26 | pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git
27 |
28 | pip install git+https://github.com/huggingface/transformers@v4.36.2
29 |
30 | # Replace transformers files with the patched versions shipped in this repo
31 | site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])')
32 | cp -rv ./vila_u/train/transformers_replace/* $site_pkg_path/transformers/
33 | # Avoid a confusing warning from lmms-eval
34 | rm -rf $site_pkg_path/lmms_eval/models/mplug_owl_video/modeling_mplug_owl.py
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | import os
5 | import vila_u
6 |
7 |
8 | def save_image(response, path):
9 | os.makedirs(path, exist_ok=True)
10 | for i in range(response.shape[0]):
11 | image = response[i].permute(1, 2, 0)
12 | image = image.cpu().numpy().astype(np.uint8)
13 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
14 | cv2.imwrite(os.path.join(path, f"image_{i}.png"), image)
15 |
16 |
17 | def save_video(response, path):
18 | os.makedirs(path, exist_ok=True)
19 | for i in range(response.shape[0]):
20 | video = response[i].permute(0, 2, 3, 1)
21 | video = video.cpu().numpy().astype(np.uint8)
22 | video = np.concatenate(video, axis=1)
23 | video = cv2.cvtColor(video, cv2.COLOR_RGB2BGR)
24 | cv2.imwrite(os.path.join(path, f"video_{i}.png"), video)
25 |
26 |
27 | if __name__ == "__main__":
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument("--model_path", type=str, required=True)
30 | ### image/video understanding arguments
31 | parser.add_argument("--image_path", type=str, default=None)
32 | parser.add_argument("--video_path", type=str, default=None)
33 | parser.add_argument("--query", type=str, default=None)
34 | parser.add_argument("--temperature", type=float, default=0.9, help="The value of temperature for text generation.")
35 | parser.add_argument("--top_p", type=float, default=0.6, help="The value of top-p for text generation.")
36 | ### image and video generation arguments
37 | parser.add_argument("--prompt", type=str, default=None)
38 | parser.add_argument("--video_generation", type=bool, default=False)
39 | parser.add_argument("--cfg", type=float, default=3.0, help="The value of the classifier free guidance for image generation.")
40 | parser.add_argument("--save_path", type=str, default="generated_images/")
41 | parser.add_argument("--generation_nums", type=int, default=1)
42 | args = parser.parse_args()
43 |
44 | if args.model_path is not None:
45 | model = vila_u.load(args.model_path)
46 | else:
47 | raise ValueError("No model path provided!")
48 |
49 | if args.query is not None:
50 | generation_config = model.default_generation_config
51 | generation_config.temperature = args.temperature
52 | generation_config.top_p = args.top_p
53 | if args.image_path is not None:
54 | image = vila_u.Image(args.image_path)
55 | response = model.generate_content([image, args.query])
56 | print("\033[1;32mResponse:\033[0m", response)
57 | exit()
58 | elif args.video_path is not None:
59 | video = vila_u.Video(args.video_path)
60 | response = model.generate_content([video, args.query])
61 | print("\033[1;32mResponse:\033[0m", response)
62 | exit()
63 | else:
64 | raise ValueError("No visual content input!")
65 | elif args.prompt is not None:
66 | if args.video_generation:
67 | response = model.generate_video_content(args.prompt, args.cfg, args.generation_nums)
68 | save_video(response, args.save_path)
69 | exit()
70 | else:
71 | response = model.generate_image_content(args.prompt, args.cfg, args.generation_nums)
72 | save_image(response, args.save_path)
73 | exit()
74 | else:
75 | raise ValueError("No query or prompt provided!")
--------------------------------------------------------------------------------
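Note that `save_video` above writes each generated clip as a single horizontal filmstrip PNG (all frames concatenated along the width). If an actual video file is preferred, below is a small sketch using OpenCV's `VideoWriter`; the frame rate and codec are assumptions, not values taken from the repository.

```python
import os

import cv2
import numpy as np


def save_video_mp4(response, path, fps=8):
    """Alternative to save_video(): write each generated clip as an .mp4 (fps/codec are assumed)."""
    os.makedirs(path, exist_ok=True)
    for i in range(response.shape[0]):
        # (T, C, H, W) -> (T, H, W, C), RGB uint8 frames, as in save_video() above
        frames = response[i].permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
        height, width = frames.shape[1:3]
        writer = cv2.VideoWriter(
            os.path.join(path, f"video_{i}.mp4"),
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps,
            (width, height),
        )
        for frame in frames:
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        writer.release()
```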
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "vila-u"
7 | version = "1.0.0"
8 | description = "VILA-U: a Unified Foundation Model Integrating Visual Understanding and Generation"
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "torch==2.3.0", "torchvision==0.18.0",
17 | "reka-api", "google-generativeai", "anthropic",
18 | "tokenizers>=0.15.2", "sentencepiece==0.1.99", "shortuuid",
19 | "accelerate==0.34.2", "peft>=0.9.0", "bitsandbytes==0.41.0",
20 | "pydantic<2,>=1", "markdown2[all]", "numpy==1.26.4", "scikit-learn==1.2.2",
21 | "gradio==3.35.2", "gradio_client==0.2.9",
22 | "requests", "httpx", "uvicorn", "fastapi", "fire",
23 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.9.12",
24 | "openpyxl==3.1.2", "pytorchvideo==0.1.5", "decord==0.6.0",
25 | "datasets==2.16.1", "openai==1.8.0", "webdataset==0.2.86",
26 | "nltk==3.3", "pywsd==1.2.4", "opencv-python-headless==4.8.0.76",
27 | "tyro", "pytest", "pre-commit", "loguru", "hydra-core"
28 | ]
29 |
30 | [project.scripts]
31 | vila_u-eval = "vila_u.cli.eval:main"
32 |
33 | [project.optional-dependencies]
34 | train = ["deepspeed==0.9.5", "ninja", "wandb"]
35 | eval = ["mmengine", "word2number", "Levenshtein", "nltk", "pywsd"]
36 |
37 | [project.urls]
38 | "Homepage" = "https://github.com/mit-han-lab/vila-u"
39 |
40 | [tool.setuptools.packages.find]
41 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
42 |
43 | [tool.wheel]
44 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
45 |
--------------------------------------------------------------------------------
/scripts/eval/lmms.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 | TASK=$1
6 | MODEL_PATH=$2
7 | CONV_MODE=$3
8 | if [[ "$TASK" =~ videomme ]]; then
9 | NUM_VIDEO_FRAMES=$(echo "$TASK" | cut -d'-' -f2-)
10 | IFS='-' read -ra segments <<< "$TASK"
11 | unset segments[${#segments[@]}-1]
12 | TASK=$(IFS=-; echo "${segments[*]}")
13 | else
14 | NUM_VIDEO_FRAMES=8
15 | fi
16 |
17 | MODEL_NAME=$(basename $MODEL_PATH)
18 | OUTPUT_DIR=${OUTPUT_DIR:-"runs/eval/$MODEL_NAME/lmms-$TASK"}
19 |
20 | NPROC_PER_NODE=${NPROC_PER_NODE:-$(nvidia-smi -L | wc -l)}
21 |
22 | export LMMS_EVAL_PLUGINS=vila_u.eval.lmms
23 | export HF_HOME=$HOME/.cache/huggingface
24 | export CACHE_DIR=$OUTPUT_DIR/cache
25 |
26 | torchrun --nproc_per_node=$NPROC_PER_NODE \
27 | -m lmms_eval \
28 | --model vila_u \
29 | --model_args model_path=$MODEL_PATH,conv_mode=$CONV_MODE,num_video_frames=$NUM_VIDEO_FRAMES \
30 | --tasks $TASK \
31 | --log_samples \
32 | --output_path $OUTPUT_DIR
33 |
34 | mv $OUTPUT_DIR/*_$MODEL_NAME/*_results.json $OUTPUT_DIR/results.json || true
35 | mv $OUTPUT_DIR/*_$MODEL_NAME/*_samples_*.jsonl $OUTPUT_DIR/samples.jsonl || true
36 | mv $OUTPUT_DIR/*_$MODEL_NAME/* $OUTPUT_DIR || true
37 | rm -r $OUTPUT_DIR/*_$MODEL_NAME || true
38 |
39 | mv $OUTPUT_DIR/*_vila_u_*/* $OUTPUT_DIR || true
40 | rm -r $OUTPUT_DIR/*_vila_u_* || true
41 | rm -r $OUTPUT_DIR/rank*_metric_eval_done.txt || true
42 |
--------------------------------------------------------------------------------
/scripts/train/pretrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export NCCL_IB_SL=1
4 | export CUDA_DEVICE_MAX_CONNECTIONS=1
5 | export NCCL_ASYNC_ERROR_HANDLING=1
6 |
7 | source activate vila-u
8 |
9 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
10 | export MASTER_ADDR=${master_addr:-"127.0.0.1"}
11 | export CURRENT_RANK=${SLURM_PROCID:-"0"}
12 | worker_list=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | tr '\n' ' ')
13 | n_node=${SLURM_JOB_NUM_NODES:-1}
14 |
15 | echo "MASTER_ADDR="$MASTER_ADDR
16 | echo "JobID: $SLURM_JOB_ID | Full list: $worker_list"
17 |
18 | global_bs=${BATCH_SIZE:-512}
19 | acc_step=${ACC_STEP:-1}
20 | bs=$((global_bs / n_node / acc_step))
21 |
22 | echo "PER_DEVICE_TRAIN_BATCH_SIZE="$bs
23 |
24 | torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=25001 \
25 | --master_addr $MASTER_ADDR --node_rank=$SLURM_PROCID \
26 | vila_u/train/train_mem.py \
27 | --deepspeed ./scripts/zero2.json \
28 | --model_name_or_path meta-llama/Llama-2-7b \
29 | --version v1 \
30 | --data_mixture sharegpt4v_pretrain+mmc4core+internal_generation+openvid_generation \
31 | --vision_tower path/to/prepared_vision_tower \
32 | --chunk_sampler True \
33 | --mm_projector mlp2x_gelu \
34 | --tune_mm_projector True \
35 | --tune_language_model True \
36 | --mm_vision_select_layer -2 \
37 | --mm_use_im_start_end True \
38 | --mm_use_vi_start_end True \
39 | --mm_use_im_patch_token False \
40 | --image_aspect_ratio resize \
41 | --bf16 True \
42 | --output_dir ./checkpoints/vila-u-pretrain \
43 | --num_train_epochs 1 \
44 | --per_device_train_batch_size $bs \
45 | --per_device_eval_batch_size 4 \
46 | --gradient_accumulation_steps $acc_step \
47 | --evaluation_strategy "no" \
48 | --save_strategy "steps" \
49 | --save_steps 100 \
50 | --save_total_limit 1 \
51 | --learning_rate 5e-5 \
52 | --weight_decay 0. \
53 | --warmup_ratio 0.03 \
54 | --lr_scheduler_type "cosine" \
55 | --logging_steps 1 \
56 | --tf32 True \
57 | --model_max_length 8192 \
58 | --gradient_checkpointing True \
59 | --dataloader_num_workers 8 \
60 | --lazy_preprocess True \
61 | --report_to wandb \
--------------------------------------------------------------------------------
/scripts/train/sft.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export NCCL_IB_SL=1
4 | export CUDA_DEVICE_MAX_CONNECTIONS=1
5 | export NCCL_ASYNC_ERROR_HANDLING=1
6 |
7 | source activate vila-u
8 |
9 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
10 | export MASTER_ADDR=${master_addr:-"127.0.0.1"}
11 | export CURRENT_RANK=${SLURM_PROCID:-"0"}
12 | worker_list=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | tr '\n' ' ')
13 | n_node=${SLURM_JOB_NUM_NODES:-1}
14 |
15 | echo "MASTER_ADDR="$MASTER_ADDR
16 | echo "JobID: $SLURM_JOB_ID | Full list: $worker_list"
17 |
18 | global_bs=${BATCH_SIZE:-512}
19 | acc_step=${ACC_STEP:-1}
20 | bs=$((global_bs / n_node / acc_step))
21 |
22 | echo "PER_DEVICE_TRAIN_BATCH_SIZE="$bs
23 |
24 | torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=25001 \
25 | --master_addr $MASTER_ADDR --node_rank=$SLURM_PROCID \
26 | vila_u/train/train_mem.py \
27 | --deepspeed ./scripts/zero2.json \
28 | --model_name_or_path path/to/vila-u-pretrain \
29 | --version v1 \
30 | --data_mixture sharegpt4v_sft+vflan+shot2story_shotonly+video_chatgpt+youcook2+vatex+sharegpt_video+scienceqa+wit_subset+sherlock+internal_generation+openvid_generation \
31 | --chunk_sampler True \
32 | --mm_projector mlp2x_gelu \
33 | --tune_mm_projector True \
34 | --tune_language_model True \
35 | --mm_vision_select_layer -2 \
36 | --mm_use_im_start_end True \
37 | --mm_use_vi_start_end True \
38 | --mm_use_im_patch_token False \
39 | --image_aspect_ratio resize \
40 | --bf16 True \
41 | --output_dir ./checkpoints/vila-u-sft \
42 | --num_train_epochs 1 \
43 | --per_device_train_batch_size $bs \
44 | --per_device_eval_batch_size 4 \
45 | --gradient_accumulation_steps $acc_step \
46 | --evaluation_strategy "no" \
47 | --save_strategy "steps" \
48 | --save_steps 100 \
49 | --save_total_limit 1 \
50 | --learning_rate 1e-4 \
51 | --weight_decay 0. \
52 | --warmup_ratio 0.03 \
53 | --lr_scheduler_type "cosine" \
54 | --logging_steps 1 \
55 | --tf32 True \
56 | --model_max_length 8192 \
57 | --gradient_checkpointing True \
58 | --dataloader_num_workers 4 \
59 | --lazy_preprocess True \
60 | --vflan_no_system_prompt True \
61 | --report_to wandb \
--------------------------------------------------------------------------------
/scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 2,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto"
22 | }
23 | }
--------------------------------------------------------------------------------
/vila_u/__init__.py:
--------------------------------------------------------------------------------
1 | from .entry import *
2 | from .media import *
--------------------------------------------------------------------------------
/vila_u/cli/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import subprocess
4 | import time
5 |
6 | from argparse import ArgumentParser
7 | from collections import deque
8 | from tabulate import tabulate
9 | from typing import Dict, List, Optional
10 |
11 | from vila_u.eval import EVAL_ROOT, TASKS
12 | from vila_u.utils import io
13 | from vila_u.utils.logging import logger
14 |
15 |
16 | def lstr(s: Optional[str]) -> Optional[List[str]]:
17 | if s is not None:
18 | s = s.split(",") if "," in s else [s]
19 | return s
20 |
21 |
22 | def _load_results(output_dir: str, task: str) -> Optional[Dict]:
23 | for fname in ["results.json", "metrics.json"]:
24 | if os.path.exists(os.path.join(output_dir, task, fname)):
25 | return io.load(os.path.join(output_dir, task, fname))
26 | return None
27 |
28 |
29 | def main() -> int:
30 | parser = ArgumentParser()
31 | parser.add_argument("--model-path", "-m", type=str, required=True)
32 | parser.add_argument("--model-name", type=str, default=None)
33 | parser.add_argument("--conv-mode", "-c", type=str, required=True)
34 | parser.add_argument("--nproc-per-node", "-n", type=int, default=8)
35 | parser.add_argument("--tasks", "-t", type=lstr)
36 | parser.add_argument("--tags-include", "-ti", type=lstr)
37 | parser.add_argument("--tags-exclude", "-te", type=lstr)
38 | parser.add_argument("--num_video_frames", "-nf", type=str, default="8")
39 | parser.add_argument("--output-dir", type=str, default=None)
40 | parser.add_argument("--report-to", "-r", choices=["wandb", None], default=None)
41 | args = parser.parse_args()
42 |
43 | # Get the model name and output directory
44 | model_name = os.path.basename(os.path.normpath(args.model_path)).lower()
45 | if args.model_name is not None:
46 | model_name = args.model_name
47 | output_dir = os.path.join("runs", "eval", model_name)
48 | num_video_frames = args.num_video_frames
49 | if args.output_dir is not None:
50 | output_dir = osp.expanduser(args.output_dir)
51 |
52 | # Filter tasks based on name and tags
53 | tasks = []
54 | for task, metainfo in TASKS.items():
55 | tags = set(metainfo.get("tags", []))
56 | if args.tasks is not None and task not in args.tasks:
57 | continue
58 | if args.tags_include is not None and tags.isdisjoint(args.tags_include):
59 | continue
60 | if args.tags_exclude is not None and tags.intersection(args.tags_exclude):
61 | continue
62 | tasks.append(task)
63 | logger.info(f"Running evaluation for '{model_name}' on {len(tasks)} tasks: {tasks}")
64 |
65 | # Prepare the evaluation commands
66 | cmds = {}
67 | for task in tasks:
68 | if _load_results(output_dir, task=task):
69 | logger.warning(f"Skipping '{task}' as it has already been evaluated.")
70 | continue
71 |
72 | cmd = []
73 | if task.startswith("lmms-"):
74 | cmd += [f"{EVAL_ROOT}/lmms.sh", task.replace("lmms-", "")]
75 | elif "_" in task:
76 | name, split = task.split("_")
77 | cmd += [f"{EVAL_ROOT}/{name}.sh", split]
78 | else:
79 | cmd += [f"{EVAL_ROOT}/{task}.sh"]
80 | cmd += [args.model_path, args.conv_mode]
81 |
82 | concurrency = 1
83 | final_cmd = cmd
84 | cmds[task] = " ".join(final_cmd)
85 |
86 | # Prepare the environment variables
87 | env = os.environ.copy()
88 | env["NPROC_PER_NODE"] = str(args.nproc_per_node)
89 |
90 | # Run the commands with the specified concurrency
91 | remaining = deque(cmds.keys())
92 | processes, returncodes = {}, {}
93 | try:
94 | while remaining or processes:
95 | while remaining and len(processes) < concurrency:
96 | task = remaining.popleft()
97 | logger.info(f"Running '{cmds[task]}'")
98 | processes[task] = subprocess.Popen(
99 | cmds[task],
100 | stdout=subprocess.DEVNULL if concurrency > 1 else None,
101 | stderr=subprocess.DEVNULL if concurrency > 1 else None,
102 | shell=True,
103 | env=env,
104 | )
105 |
106 | for task, process in processes.items():
107 | if process.poll() is not None:
108 | returncodes[task] = process.returncode
109 | processes.pop(task)
110 | break
111 |
112 | time.sleep(1)
113 | except KeyboardInterrupt:
114 | logger.warning("Terminating all processes...")
115 | for _, process in processes.items():
116 | process.terminate()
117 | for _, process in processes.items():
118 | process.wait()
119 |
120 | final_return_code = 0
121 | # Check the return codes
122 | for task, returncode in returncodes.items():
123 | if returncode != 0:
124 | logger.error(f"Error running '{task}' evaluation (return code: {returncode})")
125 | final_return_code = returncode
126 |
127 | if args.report_to == "wandb":
128 | import wandb
129 |
130 | wandb_project = os.environ.get("WANDB_PROJECT", "vila-u-eval")
131 | wandb_entity = os.environ.get("WANDB_ENTITY", None)
132 | wandb_name = os.environ.get("WANDB_NAME", model_name)
133 | logger.info(f"initiating wandb run for '{wandb_project}/{wandb_name}'")
134 | wandb.init(
135 | project=wandb_project,
136 | entity=wandb_entity,
137 | name=wandb_name,
138 | config={
139 | "model_path": args.model_path,
140 | "conv_mode": args.conv_mode,
141 | },
142 | )
143 |
144 | # Collect the results and save them
145 | metrics = {}
146 | for task in tasks:
147 | results = _load_results(output_dir, task=task)
148 | if results is None:
149 | continue
150 | for name, path in TASKS[task].get("metrics", {}).items():
151 | val = results
152 | for key in path.split("/") if "/" in path else [path]:
153 | val = val[key]
154 | metrics[f"{task}/{name}"] = val
155 |
156 | if args.report_to == "wandb":
157 | logger.info(f"Logging '{task}/{name}' to wandb")
158 | wandb.log({f"{task}/{name}": val})
159 | io.save(os.path.join(output_dir, "metrics.json"), metrics, indent=4)
160 | logger.info(f"Saved all metrics to '{output_dir}/metrics.json'")
161 | if args.report_to == "wandb":
162 | logger.info(f"Saved wandb url to '{output_dir}/wandb.txt'")
163 | io.save(os.path.join(output_dir, "wandb.txt"), wandb.run.get_url())
164 |
165 | # Print the metrics in a tabular format
166 | logger.info("Results:\n" + tabulate(metrics.items(), tablefmt="simple_outline", headers=["Metric", "Value"]))
167 |
168 | return final_return_code
169 |
170 |
171 | if __name__ == "__main__":
172 | main()
--------------------------------------------------------------------------------
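For reference, the tag filtering behind the README's `-ti local` flag boils down to simple set operations over each task's `tags` from `registry.yaml`. A toy illustration (task names and tags copied from the registry shown later in this dump):

```python
# Toy illustration of the task filtering in main(); tags mirror vila_u/eval/registry.yaml.
TASKS = {
    "lmms-gqa": {"tags": ["core", "local", "regression"]},
    "lmms-mmvet": {"tags": ["core", "openai"]},
    "lmms-vqav2_test": {"tags": ["core", "submission"]},
}
tags_include, tags_exclude = ["local"], ["openai"]

selected = [
    task
    for task, meta in TASKS.items()
    if not set(meta.get("tags", [])).isdisjoint(tags_include)
    and not set(meta.get("tags", [])).intersection(tags_exclude)
]
print(selected)  # ['lmms-gqa']
```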
/vila_u/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | IGNORE_INDEX = -100
7 | IMAGE_TOKEN_INDEX = -200
8 | DEFAULT_IMAGE_TOKEN = "<image>"
9 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
10 | DEFAULT_IM_START_TOKEN = "<im_start>"
11 | DEFAULT_IM_END_TOKEN = "<im_end>"
12 | DEFAULT_VI_START_TOKEN = "<vi_start>"
13 | DEFAULT_VI_END_TOKEN = "<vi_end>"
14 | IMAGE_PLACEHOLDER = "<image-placeholder>"
15 |
16 | SENTINEL_TOKEN = "<vila/sentinel>"
--------------------------------------------------------------------------------
/vila_u/conversation.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | from enum import auto, Enum
4 | from typing import List, Tuple
5 |
6 | from vila_u.utils.logging import logger
7 |
8 |
9 | class SeparatorStyle(Enum):
10 | """Different separator style."""
11 | SINGLE = auto()
12 | TWO = auto()
13 |
14 |
15 | @dataclasses.dataclass
16 | class Conversation:
17 | """A class that keeps all conversation history."""
18 | system: str
19 | roles: List[str]
20 | messages: List[List[str]]
21 | offset: int
22 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE
23 | sep: str = "###"
24 | sep2: str = None
25 | version: str = "Unknown"
26 |
27 | skip_next: bool = False
28 |
29 | def get_prompt(self):
30 | messages = self.messages
31 | if len(messages) > 0 and type(messages[0][1]) is tuple:
32 | messages = self.messages.copy()
33 | init_role, init_msg = messages[0].copy()
34 | init_msg = init_msg[0].replace("<image>", "").strip()
35 | messages[0] = (init_role, "<image>\n" + init_msg)
36 |
37 | if self.sep_style == SeparatorStyle.SINGLE:
38 | ret = self.system + self.sep
39 | for role, message in messages:
40 | if message:
41 | if type(message) is tuple:
42 | message, _, _ = message
43 | ret += role + ": " + message + self.sep
44 | else:
45 | ret += role + ":"
46 | elif self.sep_style == SeparatorStyle.TWO:
47 | seps = [self.sep, self.sep2]
48 | ret = self.system + seps[0]
49 | for i, (role, message) in enumerate(messages):
50 | if message:
51 | if type(message) is tuple:
52 | message, _, _ = message
53 | ret += role + ": " + message + seps[i % 2]
54 | else:
55 | ret += role + ":"
56 | else:
57 | raise ValueError(f"Invalid style: {self.sep_style}")
58 |
59 | return ret
60 |
61 | def append_message(self, role, message):
62 | self.messages.append([role, message])
63 |
64 | def copy(self):
65 | return Conversation(
66 | system=self.system,
67 | roles=self.roles,
68 | messages=[[x, y] for x, y in self.messages],
69 | offset=self.offset,
70 | sep_style=self.sep_style,
71 | sep=self.sep,
72 | sep2=self.sep2,
73 | version=self.version)
74 |
75 | conv_vicuna_v0 = Conversation(
76 | system="A chat between a curious human and an artificial intelligence assistant. "
77 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
78 | roles=("Human", "Assistant"),
79 | messages=(
80 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
81 | ("Assistant",
82 | "Renewable energy sources are those that can be replenished naturally in a relatively "
83 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
84 | "Non-renewable energy sources, on the other hand, are finite and will eventually be "
85 | "depleted, such as coal, oil, and natural gas. Here are some key differences between "
86 | "renewable and non-renewable energy sources:\n"
87 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
88 | "energy sources are finite and will eventually run out.\n"
89 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
90 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
91 | "and other negative effects.\n"
92 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
93 | "have lower operational costs than non-renewable sources.\n"
94 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
95 | "locations than non-renewable sources.\n"
96 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
97 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
98 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
99 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
100 | ),
101 | offset=2,
102 | sep_style=SeparatorStyle.SINGLE,
103 | sep="###",
104 | )
105 |
106 | conv_vicuna_v1 = Conversation(
107 | system="A chat between a curious user and an artificial intelligence assistant. "
108 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
109 | roles=("USER", "ASSISTANT"),
110 | version="v1",
111 | messages=(),
112 | offset=0,
113 | sep_style=SeparatorStyle.TWO,
114 | sep=" ",
115 | sep2="</s>",
116 | )
117 |
118 | default_conversation = conv_vicuna_v1
119 |
120 | conv_templates = {
121 | "v0": conv_vicuna_v0,
122 | "v1": conv_vicuna_v1,
123 | "vicuna_v1": conv_vicuna_v1,
124 | }
125 |
126 | CONVERSATION_MODE_MAPPING = {
127 | "vila-u-7b-256": "vicuna_v1",
128 | "vila-u-7b-384": "vicuna_v1",
129 | }
130 |
131 | def auto_set_conversation_mode(model_name_or_path: str) -> None:
132 | global default_conversation
133 | for k, v in CONVERSATION_MODE_MAPPING.items():
134 | if k in model_name_or_path.lower():
135 | logger.info(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.")
136 | default_conversation = conv_templates[v]
137 | return
138 |
--------------------------------------------------------------------------------
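As a quick illustration, this is how the `vicuna_v1` template assembles a prompt; the user message is made up for the example.

```python
# Toy example of Conversation.get_prompt() with the vicuna_v1 template.
from vila_u.conversation import conv_vicuna_v1

conv = conv_vicuna_v1.copy()
conv.append_message(conv.roles[0], "Hello!")
conv.append_message(conv.roles[1], None)  # empty assistant turn -> prompt ends with "ASSISTANT:"
print(repr(conv.get_prompt()))
# "A chat between a curious user and an artificial intelligence assistant. The assistant
#  gives helpful, detailed, and polite answers to the user's questions. USER: Hello! ASSISTANT:"
```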
/vila_u/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import *
2 | from .datasets_mixture import *
3 | from .simple_vila_webdataset import VILAWebDataset
4 |
--------------------------------------------------------------------------------
/vila_u/data/datasets_mixture.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | from dataclasses import dataclass, field
4 |
5 |
6 | @dataclass
7 | class Dataset:
8 | dataset_name: str
9 | dataset_type: str = field(default="torch")
10 | data_path: str = field(default=None, metadata={"help": "Path to the training data."})
11 | meta_path: str = field(default=None, metadata={"help": "Path to the meta data for webdataset."})
12 | image_path: str = field(default=None, metadata={"help": "Path to the training image data."})
13 | description: str = field(
14 | default=None,
15 | metadata={
16 | "help": "Detailed desciption of where the data is from, how it is labelled, intended use case and the size of the dataset."
17 | },
18 | )
19 | test_script: str = None
20 | maintainer: str = None
21 |
22 |
23 | DATASETS = {}
24 |
25 |
26 | def add_dataset(dataset):
27 | if dataset.dataset_name in DATASETS:
28 | warnings.warn(f"{dataset.dataset_name} already existed in DATASETS. Make sure the name is unique.")
29 | assert "+" not in dataset.dataset_name, "Dataset name cannot include symbol '+'."
30 |
31 | DATASETS.update({dataset.dataset_name: dataset})
32 |
33 |
34 | def register_datasets_mixtures():
35 | internal_generation = Dataset(
36 | dataset_name="internal_generation",
37 | dataset_type="internal-generation",
38 | data_path="",
39 | meta_path="",
40 | )
41 | add_dataset(internal_generation)
42 |
43 |
44 | # Please download data from https://github.com/NJU-PCALab/OpenVid-1M to prepare openvid_generation.
45 | openvid_generation = Dataset(
46 | dataset_name="openvid_generation",
47 | dataset_type="openvid-generation",
48 | data_path="",
49 | )
50 | add_dataset(openvid_generation)
51 |
52 |
53 | # Please download data from https://sharegpt4v.github.io/ to prepare sharegpt4v.
54 | sharegpt4v_pretrain = Dataset(
55 | dataset_name="sharegpt4v_pretrain",
56 | dataset_type="torch",
57 | data_path="",
58 | image_path="",
59 | description="Original data source: https://sharegpt4v.github.io/ ~1M long Image - Text pair generated by ShareGPT4V captioner.",
60 | )
61 | add_dataset(sharegpt4v_pretrain)
62 |
63 |
64 | sharegpt4v_sft = Dataset(
65 | dataset_name="sharegpt4v_sft",
66 | dataset_type="torch",
67 | data_path="",
68 | image_path="",
69 | description="Original data source: https://sharegpt4v.github.io/ 655K llava_1_5_sft data relablled w/ ShareGPT4V captioner.",
70 | )
71 | add_dataset(sharegpt4v_sft)
72 |
73 |
74 | # Please refer to https://github.com/NVlabs/VILA/tree/main/data_prepare to prepare the following datasets.
75 | mmc4core = Dataset(
76 | dataset_name="mmc4core",
77 | dataset_type="mmc4",
78 | data_path="",
79 | )
80 | add_dataset(mmc4core)
81 |
82 |
83 | vflan = Dataset(
84 | dataset_name="vflan",
85 | dataset_type="vflan",
86 | data_path="",
87 | )
88 | add_dataset(vflan)
89 |
90 |
91 | shot2story_shotonly = Dataset(
92 | dataset_name="shot2story_shotonly",
93 | dataset_type="torch",
94 | data_path="",
95 | image_path="",
96 | )
97 | add_dataset(shot2story_shotonly)
98 |
99 |
100 | video_chatgpt = Dataset(
101 | dataset_name="video_chatgpt",
102 | dataset_type="torch",
103 | data_path="",
104 | image_path="",
105 | )
106 | add_dataset(video_chatgpt)
107 |
108 |
109 | youcook2 = Dataset(
110 | dataset_name="youcook2",
111 | dataset_type="torch",
112 | data_path="",
113 | image_path="",
114 | )
115 | add_dataset(youcook2)
116 |
117 |
118 | vatex = Dataset(
119 | dataset_name="vatex",
120 | dataset_type="torch",
121 | data_path="",
122 | image_path="",
123 | )
124 | add_dataset(vatex)
125 |
126 |
127 | sharegpt_video = Dataset(
128 | dataset_name="sharegpt_video",
129 | dataset_type="torch",
130 | data_path="",
131 | image_path="",
132 | )
133 | add_dataset(sharegpt_video)
134 |
135 |
136 | scienceqa = Dataset(
137 | dataset_name="scienceqa",
138 | dataset_type="torch",
139 | data_path="",
140 | image_path="",
141 | )
142 | add_dataset(scienceqa)
143 |
144 |
145 | wit_subset = Dataset(
146 | dataset_name="wit_subset",
147 | dataset_type="torch",
148 | data_path="",
149 | image_path=""
150 | )
151 | add_dataset(wit_subset)
152 |
153 |
154 | sherlock = Dataset(
155 | dataset_name="sherlock",
156 | dataset_type="torch",
157 | data_path="",
158 | image_path="",
159 | )
160 | add_dataset(sherlock)
--------------------------------------------------------------------------------
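To register an additional dataset, the pattern above can be extended inside `register_datasets_mixtures()`. A hypothetical sketch follows; the dataset name and both paths are placeholders, not entries from the repository.

```python
# Hypothetical dataset registration; "my_caption_data" and both paths are placeholders.
my_caption_data = Dataset(
    dataset_name="my_caption_data",  # must not contain '+', which separates mixture entries
    dataset_type="torch",
    data_path="/path/to/annotations.json",
    image_path="/path/to/images",
    description="Placeholder caption dataset used only to illustrate registration.",
)
add_dataset(my_caption_data)

# The registered name can then be referenced from the training scripts, e.g.
# --data_mixture sharegpt4v_pretrain+my_caption_data
```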
/vila_u/data/simple_vila_webdataset.py:
--------------------------------------------------------------------------------
1 | import getpass
2 | import hashlib
3 | import os.path as osp
4 |
5 | from functools import reduce
6 | from vila_u.wids import ShardListDataset
7 | from torch.utils.data import Dataset
8 |
9 |
10 | class VILAWebDataset(Dataset):
11 | def __init__(
12 | self,
13 | data_path=None,
14 | meta_path=None,
15 | cache_dir=None,
16 | max_shards_to_load=None,
17 | ):
18 | self.data_path = osp.expanduser(data_path)
19 | self.meta_path = osp.expanduser(meta_path) if meta_path is not None else None
20 |
21 | _local_meta_path = osp.join(self.data_path, "wids-meta.json")
22 | if meta_path is None and osp.exists(_local_meta_path):
23 | print(f"loading from {_local_meta_path}")
24 | self.meta_path = meta_path = _local_meta_path
25 |
26 | if meta_path is None:
27 | self.meta_path = osp.join(
28 | osp.expanduser(cache_dir),
29 | self.data_path.replace("/", "--") + f".max_shards:{max_shards_to_load}" + ".wdsmeta.json",
30 | )
31 |
32 | assert osp.exists(
33 | self.meta_path
34 | ), f"meta path not found in [{self.meta_path}] or [{_local_meta_path}]"
35 | print(f"[SimplyCoyo] Loading meta infomation {self.meta_path}", flush=True)
36 |
37 | uuid = hashlib.sha256(self.meta_path.encode()).hexdigest()[:8]
38 | self.dataset = ShardListDataset(
39 | self.meta_path,
40 | cache_dir=osp.expanduser(f"~/.cache/_wids_cache/{getpass.getuser()}-{uuid}"),
41 | )
42 |
43 | def __getitem__(self, idx):
44 | return self.dataset[idx]
45 |
46 | def __len__(self):
47 | return len(self.dataset)
48 |
49 | @staticmethod
50 | def simple_collate(batch):
51 | batched_data = {}
52 | for data in batch:
53 | for k, v in data.items():
54 | if k not in batched_data:
55 | batched_data[k] = []
56 | batched_data[k].append(v)
57 | return dict(batched_data)
58 |
59 | @staticmethod
60 | def custom_collate(batch):
61 | def transform2list(a: dict):
62 | for k, v in a.items():
63 | if isinstance(v, dict):
64 | a[k] = transform2list(v)
65 | else:
66 | a[k] = [
67 | v,
68 | ]
69 | return a
70 |
71 | def merge(a: dict, b: dict, path=[], strict=False):
72 | c = {}
73 | keys = set(a.keys()).union(b.keys())
74 | for key in keys:
75 | if key in a and key in b:
76 | if isinstance(a[key], dict) and isinstance(b[key], dict):
77 | c[key] = merge(a[key], b[key], path + [str(key)], strict=strict)
78 | else:
79 | c[key] = a[key] + b[key]
80 | else:
81 | if strict:
82 | raise Exception("Conflict at " + ".".join(path + [str(key)]))
83 | c[key] = a[key] if key in a else b[key]
84 | return c
85 |
86 | tasks = (transform2list(_) for _ in batch)
87 | return reduce(merge, tasks)
--------------------------------------------------------------------------------
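For clarity, `custom_collate` recursively turns leaf values into lists and then merges samples key by key. A toy example with made-up samples:

```python
# Toy illustration of VILAWebDataset.custom_collate; the sample contents are made up.
from vila_u.data.simple_vila_webdataset import VILAWebDataset

batch = [
    {"__key__": "000001", "json": {"caption": "a cat"}},
    {"__key__": "000002", "json": {"caption": "a dog"}},
]
print(VILAWebDataset.custom_collate(batch))
# -> {'__key__': ['000001', '000002'], 'json': {'caption': ['a cat', 'a dog']}} (key order may vary)
```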
/vila_u/entry.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import typing
4 |
5 | from typing import List, Optional
6 |
7 | if typing.TYPE_CHECKING:
8 | from transformers import PreTrainedModel
9 | else:
10 | PreTrainedModel = None
11 |
12 | from vila_u.conversation import auto_set_conversation_mode
13 | from vila_u.model.builder import load_pretrained_model
14 |
15 | __all__ = ["load"]
16 |
17 |
18 | def load(
19 | model_path: str,
20 | devices: Optional[List[int]] = None,
21 | **kwargs,
22 | ) -> PreTrainedModel:
23 | auto_set_conversation_mode(model_path)
24 |
25 | model_path = os.path.expanduser(model_path)
26 | if os.path.exists(os.path.join(model_path, "model")):
27 | model_path = os.path.join(model_path, "model")
28 |
29 | # Set `max_memory` to constrain which GPUs to use
30 | if devices is not None:
31 | assert "max_memory" not in kwargs, "`max_memory` should not be set when `devices` is set"
32 | kwargs.update(max_memory={device: torch.cuda.get_device_properties(device).total_memory for device in devices})
33 |
34 | model = load_pretrained_model(model_path, **kwargs)[1]
35 |
36 | return model
--------------------------------------------------------------------------------
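The optional `devices` argument restricts loading to specific GPUs by building a `max_memory` map, as done in the lmms-eval wrapper later in this dump. A short sketch; the device indices and checkpoint path are placeholders.

```python
# Sketch: constrain vila_u.load() to GPUs 0 and 1 (indices and path are placeholders).
import vila_u

model = vila_u.load("path/to/vila-u-7b-256", devices=[0, 1])
```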
/vila_u/eval/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from vila_u.utils import io
4 |
5 | __all__ = ["EVAL_ROOT", "TASKS"]
6 |
7 |
8 | EVAL_ROOT = "scripts/eval"
9 | TASKS = io.load(os.path.join(os.path.dirname(__file__), "registry.yaml"))
10 |
--------------------------------------------------------------------------------
/vila_u/eval/lmms/models/__init__.py:
--------------------------------------------------------------------------------
1 | AVAILABLE_MODELS = {
2 | "vila_u": "VILA_U",
3 | }
4 |
--------------------------------------------------------------------------------
/vila_u/eval/lmms/models/vila_u.py:
--------------------------------------------------------------------------------
1 | import accelerate
2 | import os
3 | import requests
4 | import torch
5 |
6 | from lmms_eval.api.instance import Instance
7 | from lmms_eval.api.model import lmms
8 | from lmms_eval.api.registry import register_model
9 | from tqdm import tqdm
10 | from typing import List, Tuple
11 |
12 | import vila_u
13 | from vila_u import conversation as conversation_lib
14 | from vila_u.media import Video
15 | from vila_u.utils import distributed as dist
16 | from vila_u.utils import io
17 |
18 |
19 | @register_model("vila_u")
20 | class VILA_U(lmms):
21 | def __init__(
22 | self, model_path: str, conv_mode: str, num_video_frames: int = 8, batch_size: int = 1
23 | ) -> None:
24 | super().__init__()
25 | assert batch_size == 1, "VILA-U only supports batch size of 1 at the moment."
26 | self._update_gpt_eval_model()
27 |
28 | devices = range(dist.local_rank(), torch.cuda.device_count(), dist.local_size())
29 | torch.cuda.set_device(devices[0])
30 |
31 | self.model = vila_u.load(model_path, devices=devices)
32 | self.model.config.num_video_frames = num_video_frames
33 | context_length = num_video_frames * 512
34 | self.model.config.model_max_length = context_length
35 | self.model.config.tokenizer_model_max_length = context_length
36 | self.model.llm.config.model_max_length = context_length
37 | self.model.llm.config.tokenizer_model_max_length = context_length
38 | self.model.tokenizer.model_max_length = context_length
39 |
40 | conversation_lib.default_conversation = conversation_lib.conv_templates[conv_mode].copy()
41 |
42 | self.accelerator = accelerate.Accelerator()
43 | self.device = torch.device(f"cuda:{devices[0]}")
44 | self._world_size = dist.size()
45 | self._rank = dist.rank()
46 |
47 | def _update_gpt_eval_model(self) -> None:
48 | _unpatched_post = requests.post
49 |
50 | def _patched_post(url, json, **kwargs):
51 | if json is not None and "model" in json:
52 | if json["model"] == "gpt-3.5-turbo-0613":
53 | json["model"] = "gpt-4o-mini"
54 | return _unpatched_post(url, json=json, **kwargs)
55 |
56 | requests.post = _patched_post
57 |
58 | def generate_until(self, requests: List[Instance]) -> List[str]:
59 | responses = []
60 | for request in tqdm(requests, disable=not dist.is_main()):
61 | prompt, generation_kwargs, doc_to_visual, doc_id, task, split = self._patch(request.args)
62 | doc = self.task_dict[task][split][doc_id]
63 |
64 | # Generate multimodal prompt
65 | medias = []
66 | for media in doc_to_visual(doc):
67 | if isinstance(media, str):
68 | if any(media.endswith(ext) for ext in [".mp4", ".mkv", ".webm"]):
69 | media = Video(media)
70 | else:
71 | raise NotImplementedError(f"Unsupported media type: {media}")
72 | medias.append(media)
73 | prompt = medias + [prompt]
74 |
75 | # Override generation config
76 | generation_config = self.model.default_generation_config
77 | generation_config.update(**generation_kwargs)
78 |
79 | # Generate and cache response
80 | cache_path = None
81 | if "CACHE_DIR" in os.environ:
82 | cache_path = os.path.join(os.environ["CACHE_DIR"], f"{task}_{split}_{doc_id}.txt")
83 |
84 | if cache_path is not None and os.path.exists(cache_path):
85 | response = io.load(cache_path)
86 | else:
87 | response = self.model.generate_content(prompt, generation_config=generation_config)
88 | if cache_path is not None:
89 | io.save(cache_path, response)
90 | responses.append(response)
91 |
92 | print("Prompt:", prompt)
93 | print("Response:", response)
94 | return responses
95 |
96 | def _patch(self, args: Tuple) -> Tuple:
97 | prompt, generation_kwargs, doc_to_visual, doc_id, task, split = args
98 | doc = self.task_dict[task][split][doc_id]
99 |
100 | return prompt, generation_kwargs, doc_to_visual, doc_id, task, split
101 |
102 | def generate_until_multi_round(self, requests: List[Instance]) -> List[str]:
103 | raise NotImplementedError
104 |
105 | def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
106 | raise NotImplementedError
107 |
--------------------------------------------------------------------------------
/vila_u/eval/lmms/tasks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/vila-u/4ea42ead2ce22b035ee74c7966eed245bca3e927/vila_u/eval/lmms/tasks/__init__.py
--------------------------------------------------------------------------------
/vila_u/eval/registry.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | lmms-gqa:
3 | tags:
4 | - core
5 | - local
6 | - regression
7 | metrics:
8 | accuracy: results/gqa/exact_match,none
9 |
10 |
11 | lmms-mme:
12 | tags:
13 | - core
14 | - local
15 | - regression
16 | metrics:
17 | cognition: results/mme/mme_cognition_score,none
18 | perception: results/mme/mme_percetion_score,none
19 |
20 |
21 | lmms-mmvet:
22 | tags:
23 | - core
24 | - openai
25 |
26 |
27 | lmms-pope:
28 | tags:
29 | - core
30 | - local
31 | - regression
32 | metrics:
33 | accuracy: results/pope/pope_accuracy,none
34 | precision: results/pope/pope_precision,none
35 | recall: results/pope/pope_recall,none
36 | f1: results/pope/pope_f1_score,none
37 |
38 |
39 | lmms-seedbench:
40 | tags:
41 | - core
42 | - local
43 | - regression
44 | metrics:
45 | all: results/seedbench/seed_all,none
46 | image: results/seedbench/seed_image,none
47 | video: results/seedbench/seed_video,none
48 |
49 |
50 | lmms-textvqa_test:
51 | tags:
52 | - submission
53 |
54 |
55 | lmms-vqav2_test:
56 | tags:
57 | - core
58 | - submission
59 |
60 |
61 | lmms-activitynetqa:
62 | tags:
63 | - openai
64 | metrics:
65 | accuracy: results/activitynetqa/gpt_eval_accuracy,none
66 | score: results/activitynetqa/gpt_eval_score,none
67 |
--------------------------------------------------------------------------------
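
Each metric entry in `registry.yaml` above is a slash-separated path into the nested results dictionary produced by an lmms-eval run, with the leaf key kept verbatim (including the `,none` suffix). A minimal sketch of resolving such a path, assuming that nesting; the toy dictionary is illustrative only:

def resolve_metric(results: dict, path: str) -> float:
    # e.g. path = "results/gqa/exact_match,none"
    node = results
    for key in path.split("/"):
        node = node[key]
    return node

fake_output = {"results": {"gqa": {"exact_match,none": 0.61}}}
print(resolve_metric(fake_output, "results/gqa/exact_match,none"))  # 0.61
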
/vila_u/media.py:
--------------------------------------------------------------------------------
1 | __all__ = ["Media", "File", "Image", "Video"]
2 |
3 |
4 | class Media:
5 | pass
6 |
7 |
8 | class File(Media):
9 | def __init__(self, path: str) -> None:
10 | self.path = path
11 |
12 |
13 | class Image(File):
14 | pass
15 |
16 |
17 | class Video(File):
18 | pass
--------------------------------------------------------------------------------
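
The classes in `media.py` above are thin path wrappers whose only job is to tag prompt elements by modality, so downstream code can dispatch on type rather than on file extension. A small usage sketch; the video path points at one of the repository's example assets:

from vila_u.media import Media, Video

prompt = [
    Video("assets/example_video1.mp4"),
    "Describe what happens in this video.",
]

for item in prompt:
    if isinstance(item, Media):
        print("media path:", item.path)
    else:
        print("text:", item)
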
/vila_u/mm_utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import cv2
3 | import numpy as np
4 | import os
5 | import re
6 | import torch
7 | import tempfile
8 |
9 | from io import BytesIO
10 | from PIL import Image
11 | from torchvision.transforms import CenterCrop
12 | from transformers import StoppingCriteria
13 |
14 | from vila_u.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
15 |
16 |
17 | def get_frame_from_vcap(vidcap, num_frames=10, fps=None, frame_count=None):
18 |     if fps is None or frame_count is None:
19 | fps = vidcap.get(cv2.CAP_PROP_FPS)
20 | frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
21 |
22 | if fps == 0 or frame_count == 0:
23 |         print("Video file could not be read. Returning empty images.")
24 | return [
25 | Image.new("RGB", (720, 720)),
26 | ] * num_frames
27 |
28 | duration = frame_count / fps
29 | frame_interval = frame_count // num_frames
30 |
31 | if frame_interval == 0 and frame_count <= 1:
32 |         print("frame_interval is 0 (video has too few frames). Returning empty images.")
33 | return [
34 | Image.new("RGB", (720, 720)),
35 | ] * num_frames
36 |
37 | images = []
38 | count = 0
39 | success = True
40 | frame_indices = np.linspace(0, frame_count - 2, num_frames, dtype=int)
41 |
42 | while success:
43 | if frame_count >= num_frames:
44 | success, frame = vidcap.read()
45 | if success:
46 | if count in frame_indices:
47 | img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
48 | im_pil = Image.fromarray(img)
49 | images.append(im_pil)
50 | if len(images) >= num_frames:
51 | return images
52 | count += 1
53 | else:
54 | break
55 | else:
56 | success, frame = vidcap.read()
57 | if success:
58 | img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
59 | im_pil = Image.fromarray(img)
60 | images.append(im_pil)
61 | count += 1
62 | elif count >= 1:
63 | width, height = images[-1].size
64 |                 print("padding frames:", (num_frames - len(images)))
65 |                 images = [Image.new("RGB", (width, height))] * (num_frames - len(images)) + images
66 | return images
67 | else:
68 | break
69 |
70 |     print("Failed to extract frames from the video. Returning empty images.")
71 | images = [Image.new("RGB", (720, 720))] * num_frames
72 |
73 | return images
74 |
75 |
76 | def opencv_extract_frames(vpath_or_bytesio, frames=6, fps=None, frame_count=None):
77 | """
78 | Extract frames from a video using OpenCV.
79 |
80 | Args:
81 | vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
82 | frames (int): Number of frames to extract from the video.
83 |
84 | Returns:
85 | list: List of PIL Images extracted from the video.
86 |
87 | Raises:
88 | NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
89 | """
90 |
91 | if isinstance(vpath_or_bytesio, str):
92 | vidcap = cv2.VideoCapture(vpath_or_bytesio)
93 | return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count)
94 | elif isinstance(vpath_or_bytesio, (BytesIO,)):
95 | with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
96 | temp_video.write(vpath_or_bytesio.read())
97 | temp_video_name = temp_video.name
98 | vidcap = cv2.VideoCapture(temp_video_name)
99 | return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count)
100 | else:
101 | raise NotImplementedError(type(vpath_or_bytesio))
102 |
103 |
104 | def load_image_from_base64(image):
105 | return Image.open(BytesIO(base64.b64decode(image)))
106 |
107 |
108 | def expand2square(pil_img, background_color):
109 | """
110 | Expand the given PIL image to a square shape by adding padding.
111 |
112 | Parameters:
113 | - pil_img: The PIL image to be expanded.
114 | - background_color: The color of the padding to be added.
115 |
116 | Returns:
117 | - The expanded PIL image.
118 |
119 | If the image is already square, it is returned as is.
120 | If the image is wider than it is tall, padding is added to the top and bottom.
121 | If the image is taller than it is wide, padding is added to the left and right.
122 | """
123 | width, height = pil_img.size
124 | if pil_img.mode == 'L':
125 | background_color = background_color[0]
126 | if width == height:
127 | return pil_img
128 | elif width > height:
129 | result = Image.new(pil_img.mode, (width, width), background_color)
130 | result.paste(pil_img, (0, (width - height) // 2))
131 | return result
132 | else:
133 | result = Image.new(pil_img.mode, (height, height), background_color)
134 | result.paste(pil_img, ((height - width) // 2, 0))
135 | return result
136 |
137 |
138 | def process_image(image_file, data_args, image_folder, generation_mode=False):
139 | processor = data_args.image_processor
140 | if isinstance(image_file, str):
141 | if image_folder is not None:
142 | image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
143 | else:
144 | image = Image.open(image_file).convert("RGB")
145 | elif isinstance(image_file, BytesIO):
146 | image = Image.open(image_file).convert("RGB")
147 | else:
148 | image = image_file
149 |
150 | if generation_mode:
151 | if image.size[0] < image.size[1]:
152 | image = image.crop((0, 0, min(image.size), min(image.size)))
153 | else:
154 | ccrop = CenterCrop(min(image.size))
155 | image = ccrop(image)
156 | elif data_args.image_aspect_ratio == "resize":
157 | if hasattr(data_args.image_processor, "crop_size"):
158 | crop_size = data_args.image_processor.crop_size
159 | else:
160 | assert hasattr(data_args.image_processor, "size")
161 | crop_size = data_args.image_processor.size
162 | image = image.resize((crop_size["height"], crop_size["width"]))
163 | elif data_args.image_aspect_ratio == "pad":
164 | image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
165 | else:
166 | raise NotImplementedError()
167 |
168 | image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
169 |
170 | return image
171 |
172 |
173 | def process_images(images, image_processor, model_cfg):
174 |
175 | model_cfg.image_processor = image_processor
176 | new_images = [process_image(image, model_cfg, None) for image in images]
177 |
178 | if all(x.shape == new_images[0].shape for x in new_images):
179 | new_images = torch.stack(new_images, dim=0)
180 | return new_images
181 |
182 |
183 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
184 | prompt_chunks = re.split(f"({DEFAULT_IMAGE_TOKEN})", prompt)
185 | input_ids = [tokenizer.bos_token_id]
186 | for chunk in prompt_chunks:
187 | if chunk == DEFAULT_IMAGE_TOKEN:
188 | input_ids.append(image_token_index)
189 | else:
190 | input_ids.extend(tokenizer(chunk).input_ids[1:])
191 |
192 | if return_tensors is not None:
193 | if return_tensors == "pt":
194 | return torch.tensor(input_ids, dtype=torch.long)
195 | raise ValueError(f"Unsupported tensor type: {return_tensors}")
196 |
197 | return input_ids
198 |
199 |
200 | def get_model_name_from_path(model_path):
201 | model_path = model_path.strip("/")
202 | model_paths = model_path.split("/")
203 | if model_paths[-1].startswith("checkpoint-"):
204 | return model_paths[-2] + "_" + model_paths[-1]
205 | else:
206 | return model_paths[-1]
207 |
208 |
209 | class KeywordsStoppingCriteria(StoppingCriteria):
210 | def __init__(self, keywords, tokenizer, input_ids):
211 | self.keywords = keywords
212 | self.keyword_ids = []
213 | self.max_keyword_len = 0
214 | for keyword in keywords:
215 | cur_keyword_ids = tokenizer(keyword).input_ids
216 | if (
217 | len(cur_keyword_ids) > 1
218 | and cur_keyword_ids[0] == tokenizer.bos_token_id
219 | ):
220 | cur_keyword_ids = cur_keyword_ids[1:]
221 | if len(cur_keyword_ids) > self.max_keyword_len:
222 | self.max_keyword_len = len(cur_keyword_ids)
223 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
224 | self.tokenizer = tokenizer
225 | self.start_len = input_ids.shape[1]
226 |
227 | def call_for_batch(
228 | self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
229 | ) -> bool:
230 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
231 | self.keyword_ids = [
232 | keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
233 | ]
234 | for keyword_id in self.keyword_ids:
235 | if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
236 | return True
237 | outputs = self.tokenizer.batch_decode(
238 | output_ids[:, -offset:], skip_special_tokens=True
239 | )[0]
240 | for keyword in self.keywords:
241 | if keyword in outputs:
242 | return True
243 | return False
244 |
245 | def __call__(
246 | self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
247 | ) -> bool:
248 | outputs = []
249 | for i in range(output_ids.shape[0]):
250 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
251 | return all(outputs)
252 |
--------------------------------------------------------------------------------
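
Two helpers in `mm_utils.py` above are easiest to understand from a small example: `expand2square` pads a non-square PIL image symmetrically with a background color, and `tokenizer_image_token` splices `IMAGE_TOKEN_INDEX` into the tokenized prompt wherever `DEFAULT_IMAGE_TOKEN` occurs. A minimal sketch; the tokenizer call is left commented because it needs a loaded checkpoint:

from PIL import Image

from vila_u.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from vila_u.mm_utils import expand2square, tokenizer_image_token

# Pad a 640x480 image to 640x640 with a gray background.
img = Image.new("RGB", (640, 480), (127, 127, 127))
square = expand2square(img, (127, 127, 127))
assert square.size == (640, 640)

# With a real tokenizer (e.g. from load_pretrained_model), the prompt below would
# contain exactly one IMAGE_TOKEN_INDEX entry:
# ids = tokenizer_image_token(f"{DEFAULT_IMAGE_TOKEN}\nDescribe the image.", tokenizer, return_tensors="pt")
# assert (ids == IMAGE_TOKEN_INDEX).sum() == 1
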
/vila_u/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .language_model.vila_u_llama import VILAULlamaModel, VILAULlamaConfig
--------------------------------------------------------------------------------
/vila_u/model/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoConfig
3 |
4 | from vila_u.model import VILAULlamaModel
5 | from vila_u.constants import (
6 | DEFAULT_IMAGE_PATCH_TOKEN,
7 | DEFAULT_IM_START_TOKEN,
8 | DEFAULT_IM_END_TOKEN,
9 | DEFAULT_VI_START_TOKEN,
10 | DEFAULT_VI_END_TOKEN,
11 | )
12 |
13 |
14 | def load_pretrained_model(
15 | model_path,
16 | model_dtype=torch.bfloat16,
17 | device_map="auto",
18 | device="cuda",
19 | **kwargs,
20 | ):
21 | kwargs = {"device_map": device_map, **kwargs}
22 |
23 | if device != "cuda":
24 | kwargs["device_map"] = {"": device}
25 |
26 | config = AutoConfig.from_pretrained(model_path)
27 | config.resume_path = model_path
28 | config.model_dtype = model_dtype.__str__()
29 |
30 | model = VILAULlamaModel(
31 | config=config,
32 | low_cpu_mem_usage=True,
33 | **kwargs
34 | )
35 | tokenizer = model.tokenizer
36 |
37 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
38 | mm_use_vi_start_end = getattr(model.config, "mm_use_vi_start_end", False)
39 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
40 | if mm_use_im_patch_token:
41 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
42 | if mm_use_im_start_end and mm_use_vi_start_end:
43 | tokenizer.add_tokens(
44 | [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VI_START_TOKEN, DEFAULT_VI_END_TOKEN], special_tokens=True
45 | )
46 | elif mm_use_im_start_end:
47 | tokenizer.add_tokens(
48 | [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
49 | )
50 | model.resize_token_embeddings(len(tokenizer))
51 | model.eval()
52 |
53 | vision_tower = model.get_vision_tower()
54 | vision_tower.to(device=device, dtype=model_dtype)
55 |
56 | mm_projector = model.get_mm_projector()
57 | mm_projector.to(device=device, dtype=model_dtype)
58 |
59 | image_processor = vision_tower.image_processor
60 |
61 | if hasattr(model.llm.config, "max_sequence_length"):
62 |         context_len = model.llm.config.max_sequence_length
63 | else:
64 | context_len = 2048
65 |
66 | return tokenizer, model, image_processor, context_len
--------------------------------------------------------------------------------
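
`load_pretrained_model` above returns the tokenizer, the assembled `VILAULlamaModel`, the vision tower's image processor, and the context length. A minimal sketch of calling it directly; the checkpoint path is a placeholder, and in practice the higher-level `vila_u.load()` entry point wraps this builder:

import torch

from vila_u.model.builder import load_pretrained_model

# "path/to/vila-u-checkpoint" is a placeholder for a downloaded checkpoint directory.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    "path/to/vila-u-checkpoint",
    model_dtype=torch.bfloat16,
    device="cuda",
)
print(type(model).__name__, "context length:", context_len)
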
/vila_u/model/configuration_vila_u.py:
--------------------------------------------------------------------------------
1 | from transformers import PretrainedConfig
2 |
3 |
4 | class VILAUConfig(PretrainedConfig):
5 | model_type = "vila_u"
6 |
7 | def __init__(
8 | self,
9 | llm_cfg=None,
10 | vision_tower_cfg=None,
11 | mm_projector_cfg=None,
12 | architectures=None,
13 | resume_path=None,
14 | hidden_size=None,
15 | mm_hidden_size=None,
16 | image_aspect_ratio=None,
17 | num_video_frames=None,
18 | mm_use_im_start_end=False,
19 | mm_use_vi_start_end=False,
20 | mm_use_im_patch_token=True,
21 | **kwargs
22 | ):
23 | super().__init__()
24 |
25 | self.llm_cfg = llm_cfg
26 | self.vision_tower_cfg = vision_tower_cfg
27 | self.mm_projector_cfg = mm_projector_cfg
28 | self.architectures = architectures
29 | self.resume_path = resume_path
30 | self.hidden_size = hidden_size
31 | self.mm_hidden_size = mm_hidden_size
32 | self.image_aspect_ratio = image_aspect_ratio
33 | self.num_video_frames = num_video_frames
34 | self.mm_use_im_start_end = mm_use_im_start_end
35 | self.mm_use_vi_start_end = mm_use_vi_start_end
36 | self.mm_use_im_patch_token = mm_use_im_patch_token
--------------------------------------------------------------------------------
/vila_u/model/language_model/builder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | from transformers import (
5 | AutoTokenizer,
6 | AutoModelForCausalLM,
7 | AutoConfig,
8 | PretrainedConfig,
9 | PreTrainedModel,
10 | )
11 |
12 |
13 | def context_length_extension(config):
14 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
15 | model_max_length = getattr(config, "model_max_length", None)
16 | if orig_ctx_len and model_max_length > orig_ctx_len:
17 | print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
18 | scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
19 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
20 | return config
21 |
22 |
23 | def build_llm_and_tokenizer(
24 | model_name_or_path: str,
25 | config: PretrainedConfig,
26 | attn_implementation=None,
27 | model_max_length=None,
28 | *args,
29 | **kwargs,
30 | ) -> PreTrainedModel:
31 | llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
32 | llm_cfg._attn_implementation = attn_implementation
33 | llm_cfg.model_max_length = model_max_length
34 | if model_max_length is not None:
35 | context_length_extension(llm_cfg)
36 |
37 | llm = AutoModelForCausalLM.from_pretrained(
38 | model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
39 | )
40 |
41 | tokenizer = AutoTokenizer.from_pretrained(
42 | model_name_or_path,
43 | model_max_length=llm_cfg.model_max_length,
44 | padding_side="right",
45 | use_fast=False,
46 | legacy=False,
47 | )
48 |
49 | config.hidden_size = llm.config.hidden_size
50 | return llm, tokenizer
--------------------------------------------------------------------------------
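
`context_length_extension` above only modifies the config when the requested `model_max_length` exceeds the base model's `max_position_embeddings`, and the linear RoPE factor is that ratio rounded up to an integer. A small worked example, assuming a base context of 4096 tokens:

import math

orig_ctx_len = 4096        # assumed max_position_embeddings of the base LLM
model_max_length = 6144    # requested context, e.g. num_video_frames * 512

scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
print(scaling_factor)  # 2.0 -> config.rope_scaling = {"type": "linear", "factor": 2.0}
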
/vila_u/model/language_model/vila_u_llama.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from torch.nn import CrossEntropyLoss
5 | from typing import List, Optional, Tuple, Union
6 | from transformers import (
7 | AutoConfig,
8 | AutoModel,
9 | PreTrainedModel,
10 | PretrainedConfig,
11 | )
12 | from transformers.modeling_outputs import CausalLMOutputWithPast
13 |
14 | from ..configuration_vila_u import VILAUConfig
15 | from ..vila_u_arch import VILAUMetaModel, VILAUMetaForCausalLM
16 |
17 |
18 | class VILAULlamaConfig(VILAUConfig):
19 | model_type = "vila_u_llama"
20 |
21 |
22 | class VILAULlamaModel(VILAUMetaModel, VILAUMetaForCausalLM, PreTrainedModel):
23 | config_class = VILAULlamaConfig
24 | main_input_name = "input_embeds"
25 | supports_gradient_checkpointing = True
26 |
27 | def __init__(self, config: VILAULlamaConfig = None, *args, **kwargs) -> None:
28 | super().__init__(config)
29 |
30 | return self.init_vlm(config=config, *args, **kwargs)
31 |
32 | @classmethod
33 | def from_pretrained(
34 | cls,
35 | pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
36 | *model_args,
37 | config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
38 | cache_dir: Optional[Union[str, os.PathLike]] = None,
39 | ignore_mismatched_sizes: bool = False,
40 | force_download: bool = False,
41 | local_files_only: bool = False,
42 | token: Optional[Union[str, bool]] = None,
43 | revision: str = "main",
44 | use_safetensors: bool = None,
45 | **kwargs,
46 | ):
47 | if hasattr(cls, "load_pretrained"):
48 | return cls.load_pretrained(
49 | pretrained_model_name_or_path,
50 | *model_args,
51 | config=config,
52 | cache_dir=cache_dir,
53 | ignore_mismatched_sizes=ignore_mismatched_sizes,
54 | force_download=force_download,
55 | local_files_only=local_files_only,
56 | token=token,
57 | revision=revision,
58 | use_safetensors=use_safetensors,
59 | **kwargs,
60 | )
61 |
62 |         return super(VILAULlamaModel, cls).from_pretrained(
63 | pretrained_model_name_or_path,
64 | *model_args,
65 | config=config,
66 | cache_dir=cache_dir,
67 | ignore_mismatched_sizes=ignore_mismatched_sizes,
68 | force_download=force_download,
69 | local_files_only=local_files_only,
70 | token=token,
71 | revision=revision,
72 | use_safetensors=use_safetensors,
73 | **kwargs,
74 | )
75 |
76 | def forward(
77 | self,
78 | input_ids: torch.LongTensor = None,
79 | images: Optional[torch.FloatTensor] = None,
80 | attention_mask: Optional[torch.Tensor] = None,
81 | position_ids: Optional[torch.LongTensor] = None,
82 | past_key_values: Optional[List[torch.FloatTensor]] = None,
83 | inputs_embeds: Optional[torch.FloatTensor] = None,
84 | labels: Optional[torch.LongTensor] = None,
85 | use_cache: Optional[bool] = None,
86 | output_attentions: Optional[bool] = None,
87 | output_hidden_states: Optional[bool] = None,
88 | return_dict: Optional[bool] = None,
89 | ) -> Union[Tuple, CausalLMOutputWithPast]:
90 | if inputs_embeds is None:
91 | (
92 | input_ids,
93 | position_ids,
94 | attention_mask,
95 | past_key_values,
96 | inputs_embeds,
97 | labels,
98 | ) = self.prepare_inputs_labels_for_multimodal(
99 | input_ids,
100 | position_ids,
101 | attention_mask,
102 | past_key_values,
103 | labels,
104 | images,
105 | )
106 |
107 | if self.training:
108 | (
109 | _,
110 | new_position_ids,
111 | new_attention_mask,
112 | _,
113 | new_inputs_embeds,
114 | new_labels,
115 | sorted_seqlens_in_batch,
116 | ) = self.repack_multimodal_data(
117 | input_ids,
118 | position_ids,
119 | attention_mask,
120 | past_key_values,
121 | inputs_embeds,
122 | labels,
123 | )
124 | new_input_ids = None
125 | past_key_values = None
126 | else:
127 | new_attention_mask = attention_mask
128 | new_position_ids = position_ids
129 | new_inputs_embeds = inputs_embeds
130 | new_labels = labels
131 | sorted_seqlens_in_batch = attention_mask.sum(-1).int()
132 | new_input_ids = input_ids
133 |
134 | output_attentions = output_attentions if output_attentions is not None else self.llm.config.output_attentions
135 | output_hidden_states = (
136 | output_hidden_states if output_hidden_states is not None else self.llm.config.output_hidden_states
137 | )
138 | return_dict = return_dict if return_dict is not None else self.llm.config.use_return_dict
139 |
140 | outputs = self.llm.model(
141 | input_ids=new_input_ids,
142 | attention_mask=new_attention_mask,
143 | position_ids=new_position_ids,
144 | past_key_values=past_key_values,
145 | inputs_embeds=new_inputs_embeds,
146 | use_cache=use_cache,
147 | output_attentions=output_attentions,
148 | output_hidden_states=output_hidden_states,
149 | return_dict=return_dict,
150 | seqlens_in_batch=sorted_seqlens_in_batch,
151 | )
152 |
153 | hidden_states = outputs[0]
154 |
155 | image_hidden_states = []
156 | image_labels = []
157 | noimage_labels = []
158 |
159 | for i in range(hidden_states.shape[0]):
160 | label = new_labels[i]
161 | hidden_state = hidden_states[i]
162 | label_zero = label[:, 0].clone()
163 |
164 | if self.config.mm_use_vi_start_end:
165 | image_start_index = torch.nonzero(torch.eq(label_zero, self.llm.vocab_size - 4)).squeeze(1)
166 | image_end_index = torch.nonzero(torch.eq(label_zero, self.llm.vocab_size - 3)).squeeze(1)
167 | video_start_index = torch.nonzero(torch.eq(label_zero, self.llm.vocab_size - 2)).squeeze(1)
168 | video_end_index = torch.nonzero(torch.eq(label_zero, self.llm.vocab_size - 1)).squeeze(1)
169 | image_start_index = torch.cat([image_start_index, video_start_index])
170 | image_end_index = torch.cat([image_end_index, video_end_index])
171 | else:
172 | image_start_index = torch.nonzero(torch.eq(label_zero, self.llm.vocab_size - 2)).squeeze(1)
173 | image_end_index = torch.nonzero(torch.eq(label_zero, self.llm.vocab_size - 1)).squeeze(1)
174 |
175 | assert len(image_start_index) == len(image_end_index), f"length of image_start_index is {len(image_start_index)}, length of image_end_index is {len(image_end_index)}"
176 |
177 | if len(image_start_index) > 0:
178 | for start_idx, end_idx in zip(image_start_index, image_end_index):
179 | image_label = label[start_idx+1:end_idx, :]
180 | image_labels.append(image_label)
181 | image_hidden_state = hidden_state[start_idx:end_idx-1, :]
182 | image_hidden_states.append(image_hidden_state)
183 | label_zero[start_idx+1:end_idx] = -100
184 |
185 | noimage_labels.append(label_zero)
186 |
187 | # For video
188 | image_hidden_states_aux = []
189 | image_labels_aux = []
190 | image_hidden_states_length = [img.shape[0] for img in image_hidden_states]
191 | image_hidden_states_length_relative = [img // min(image_hidden_states_length) for img in image_hidden_states_length]
192 | for l in range(len(image_hidden_states_length_relative)):
193 | if image_hidden_states_length_relative[l] > 1:
194 | image_hidden_states_aux += torch.split(image_hidden_states[l], min(image_hidden_states_length), dim=0)
195 | image_labels_aux += torch.split(image_labels[l], min(image_hidden_states_length), dim=0)
196 | else:
197 | image_hidden_states_aux.append(image_hidden_states[l])
198 | image_labels_aux.append(image_labels[l])
199 |
200 | if len(image_hidden_states_aux) > 0:
201 | image_hidden_states = torch.stack(image_hidden_states_aux, 0)
202 | image_labels = torch.stack(image_labels_aux, 0)
203 |
204 | noimage_labels = torch.stack(noimage_labels, 0)
205 |
206 | logits = self.llm.lm_head(hidden_states)
207 |
208 | loss_fct = CrossEntropyLoss()
209 |
210 | image_loss = None
211 | if torch.is_tensor(image_hidden_states):
212 | if hasattr(self.vision_tower.vision_tower, "rqvaesiglip"):
213 | outs = self.vision_tower.vision_tower.rqtransformer(image_hidden_states, image_labels - self.llm.vocab_size, self.vision_tower.vision_tower.rqvaesiglip)
214 | else:
215 | raise NotImplementedError()
216 | B, seq_len, D, C = outs.shape
217 | image_logits = outs.reshape(B*seq_len*D, C).contiguous()
218 | image_labels = image_labels.reshape(B*seq_len*D).contiguous() - self.llm.vocab_size
219 | image_loss = loss_fct(image_logits, image_labels)
220 |
221 | loss = None
222 | shift_logits = logits[..., :-1, :].contiguous()
223 | shift_labels = noimage_labels[..., 1:].contiguous()
224 | shift_logits = shift_logits.view(-1, self.llm.config.vocab_size)
225 | shift_labels = shift_labels.view(-1)
226 | shift_labels = shift_labels.to(shift_logits.device)
227 | loss = loss_fct(shift_logits, shift_labels)
228 |
229 | if image_loss is not None:
230 | loss = loss + image_loss
231 |
232 | return CausalLMOutputWithPast(
233 | loss=loss,
234 | logits=logits,
235 | past_key_values=outputs.past_key_values,
236 | hidden_states=outputs.hidden_states,
237 | attentions=outputs.attentions,
238 | )
239 |
240 |
241 | AutoConfig.register("vila_u_llama", VILAULlamaConfig)
242 | AutoModel.register(VILAULlamaConfig, VILAULlamaModel)
--------------------------------------------------------------------------------
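
The `forward` above optimizes two objectives jointly: a next-token cross-entropy over text positions (the image/video spans are masked to -100 in `noimage_labels`) and a cross-entropy over the depth-wise code predictions that `rqtransformer` makes for those spans, with the two losses simply summed. A toy sketch of that combination; all shapes and sizes here are illustrative, not the model's:

import torch
from torch.nn import CrossEntropyLoss

loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100

# Text branch: flattened (positions, vocab) logits vs. shifted labels.
text_logits = torch.randn(8, 32000)
text_labels = torch.randint(0, 32000, (8,))
text_labels[:2] = -100  # masked image-span positions contribute no text loss

# Image branch: flattened (B * seq_len * D, codebook_size) logits vs. vision codes.
image_logits = torch.randn(16, 16384)
image_labels = torch.randint(0, 16384, (16,))

loss = loss_fct(text_logits, text_labels) + loss_fct(image_logits, image_labels)
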
/vila_u/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
4 |
5 | from .rqvaesigliptransformer_encoder import RQVAESIGLIPTransformerVisionTower
6 |
7 |
8 | def build_vision_tower(
9 | model_name_or_path: str, config: PretrainedConfig
10 | ) -> PreTrainedModel:
11 | if model_name_or_path is None:
12 | return None
13 |
14 | vision_tower_arch = None
15 | if config.resume_path:
16 | assert os.path.exists(
17 | model_name_or_path
18 | ), f"Resume vision tower path {model_name_or_path} does not exist!"
19 | vision_tower_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
20 | vision_tower_arch = vision_tower_cfg.architectures[0].lower()
21 | vision_tower_name = (
22 | vision_tower_arch if vision_tower_arch is not None else model_name_or_path
23 | )
24 |
25 | vision_tower = RQVAESIGLIPTransformerVisionTower(model_name_or_path, config)
26 |
27 | config.mm_hidden_size = vision_tower.config.hidden_size
28 |
29 | return vision_tower
--------------------------------------------------------------------------------
/vila_u/model/multimodal_encoder/rqvaesigliptransformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration_rqvaesigliptransformer import RQVAESIGLIPTransformerConfig
2 | from .modeling_rqvaesigliptransformer import RQVAESIGLIPTransformer
--------------------------------------------------------------------------------
/vila_u/model/multimodal_encoder/rqvaesigliptransformer/configuration_rqvaesigliptransformer.py:
--------------------------------------------------------------------------------
1 | from transformers import PretrainedConfig
2 |
3 |
4 | class RQVAESIGLIPTransformerConfig(PretrainedConfig):
5 | model_type = "rqvaesigliptransformer_model"
6 | def __init__(
7 | self,
8 | rqvaesiglip=None,
9 | rqtransformer=None,
10 | hidden_size=None,
11 | architectures=None,
12 | **kwargs,
13 | ):
14 | super().__init__()
15 |
16 | self.rqvaesiglip = rqvaesiglip
17 | self.rqtransformer = rqtransformer
18 | self.hidden_size = hidden_size
19 | self.architectures = architectures
--------------------------------------------------------------------------------
/vila_u/model/multimodal_encoder/rqvaesigliptransformer/modeling_rqvaesigliptransformer.py:
--------------------------------------------------------------------------------
1 | from transformers import PreTrainedModel, AutoConfig, AutoModel
2 |
3 | from .configuration_rqvaesigliptransformer import RQVAESIGLIPTransformerConfig
4 | from .rqvaesiglip import RQVAESiglipModel
5 | from .rqtransformer import RQTransformer
6 |
7 |
8 | class RQVAESIGLIPTransformer(PreTrainedModel):
9 | config_class = RQVAESIGLIPTransformerConfig
10 | def __init__(self, config: RQVAESIGLIPTransformerConfig):
11 | super().__init__(config)
12 |
13 | rqvaesiglip_config = RQVAESiglipModel.config_class.from_dict(config.rqvaesiglip)
14 | rqtransformer_config = RQTransformer.config_class.from_dict(config.rqtransformer)
15 |
16 | self.rqvaesiglip = RQVAESiglipModel._from_config(rqvaesiglip_config)
17 | self.rqtransformer = RQTransformer._from_config(rqtransformer_config)
18 |
19 |
20 | AutoConfig.register("rqvaesigliptransformer_model", RQVAESIGLIPTransformerConfig)
21 | AutoModel.register(RQVAESIGLIPTransformerConfig, RQVAESIGLIPTransformer)
--------------------------------------------------------------------------------
/vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqtransformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling_rqtransformer import RQTransformer
--------------------------------------------------------------------------------
/vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqtransformer/attention.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqtransformer/attentions.py.
3 | """
4 |
5 | import math
6 | import torch
7 |
8 | from torch import nn
9 | from torch.nn import functional as F
10 | from typing import Iterable
11 |
12 | from .configuration_rqtransformer import AttentionBlockConfig, AttentionStackConfig
13 |
14 |
15 | class MultiSelfAttention(nn.Module):
16 | """
17 | Optimized by batched matmul operations
18 | """
19 |
20 | def __init__(self, config: AttentionBlockConfig, mask=True):
21 | super().__init__()
22 | assert config.embed_dim % config.n_head == 0
23 |
24 | self.key = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias)
25 | self.query = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias)
26 | self.value = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias)
27 |
28 | self.attn_drop = nn.Dropout(config.attn_pdrop, inplace=False)
29 | self.resid_drop = nn.Dropout(config.resid_pdrop, inplace=True)
30 |
31 | self.proj = nn.Linear(config.embed_dim, config.embed_dim, config.attn_bias)
32 |
33 | self.n_head = config.n_head
34 | self.mask = mask
35 |
36 | def forward(self, x, caching=False, past_kv=None):
37 | (B, T, C) = x.shape
38 |
39 | if not caching:
40 | assert past_kv is None
41 |
42 | x = x.transpose(0, 1).contiguous()
43 |
44 | k = self.key(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1)
45 | q = self.query(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1)
46 | v = self.value(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1)
47 |
48 | if past_kv is not None:
49 | past_key, past_value = past_kv
50 | k = torch.cat([past_key, k], dim=-2)
51 | v = torch.cat([past_value, v], dim=-2)
52 | T_past = past_key.shape[1]
53 | else:
54 | T_past = 0
55 |
56 | if caching:
57 | present = torch.stack([k, v])
58 | else:
59 | present = None
60 |
61 | att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
62 | if self.mask:
63 | mask = torch.tril(torch.ones(T_past+T, T_past+T, device=x.device, dtype=torch.bool))
64 | mask = mask.view(1, T_past+T, T_past+T)
65 | att = att.masked_fill(~mask[:, T_past:T_past+T, :T_past+T], float('-inf'))
66 | att = F.softmax(att, dim=-1)
67 | att = self.attn_drop(att)
68 |
69 | y = torch.bmm(att, v)
70 | y = y.transpose(0, 1).contiguous().view(T, B, C)
71 |
72 | y = self.resid_drop(self.proj(y))
73 |
74 | if caching:
75 | return y.transpose(0, 1).contiguous(), present
76 | else:
77 | return y.transpose(0, 1).contiguous()
78 |
79 |
80 | class AttentionBlock(nn.Module):
81 | """ an unassuming Transformer block """
82 |
83 | def __init__(self, config: AttentionBlockConfig):
84 | super().__init__()
85 |
86 | self.ln1 = nn.LayerNorm(config.embed_dim)
87 | self.ln2 = nn.LayerNorm(config.embed_dim)
88 |
89 | self.attn = MultiSelfAttention(config, mask=True)
90 | self.mlp = nn.Sequential(
91 | nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=config.mlp_bias),
92 | nn.GELU(),
93 | nn.Linear(4 * config.embed_dim, config.embed_dim, bias=config.mlp_bias),
94 | nn.Dropout(config.resid_pdrop, inplace=True),
95 | )
96 | self._cache = None
97 |
98 | def forward(self, x):
99 |
100 | attn = self.attn(self.ln1(x))
101 |
102 | x = x + attn
103 | x = x + self.mlp(self.ln2(x))
104 |
105 | return x
106 |
107 | def cached_forward(self, x_present):
108 |
109 | attn, present = self.attn(self.ln1(x_present), caching=True, past_kv=self._cache['past_kv'])
110 | self._cache['past_kv'] = present
111 |
112 | x_present = x_present + attn
113 | x_present = x_present + self.mlp(self.ln2(x_present))
114 |
115 | return x_present
116 |
117 | def init_cache(self):
118 | self._cache = {'past_kv': None}
119 |
120 |
121 | class AttentionStack(nn.Module):
122 |
123 | blocks: Iterable[AttentionBlock]
124 |
125 | def __init__(self, config: AttentionStackConfig):
126 | super().__init__()
127 |
128 | self.blocks = nn.ModuleList([AttentionBlock(config.block) for _ in range(config.n_layer)])
129 |
130 | def forward(self, x):
131 | for block in self.blocks:
132 | x = block(x)
133 |
134 | return x
135 |
136 | def cached_forward(self, x_present):
137 | for block in self.blocks:
138 | x_present = block.cached_forward(x_present)
139 |
140 | return x_present
141 |
142 | def init_cache(self):
143 | for block in self.blocks:
144 | block.init_cache()
--------------------------------------------------------------------------------
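
The `AttentionStack` above supports both a full forward pass and incremental decoding through the per-block KV cache (`init_cache` followed by `cached_forward`). A small sketch checking that the two paths agree in eval mode; the dimensions are arbitrary:

import torch

from vila_u.model.multimodal_encoder.rqvaesigliptransformer.rqtransformer.attention import AttentionStack
from vila_u.model.multimodal_encoder.rqvaesigliptransformer.rqtransformer.configuration_rqtransformer import (
    AttentionBlockConfig,
    AttentionStackConfig,
)

torch.manual_seed(0)
block_cfg = AttentionBlockConfig(embed_dim=64, n_head=4)
stack = AttentionStack(AttentionStackConfig(n_layer=2, block=block_cfg)).eval()

x = torch.randn(1, 5, 64)
with torch.no_grad():
    full = stack(x)  # all positions at once
    stack.init_cache()
    steps = [stack.cached_forward(x[:, t:t + 1, :]) for t in range(5)]
print(torch.allclose(full, torch.cat(steps, dim=1), atol=1e-5))  # expected: True
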
/vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqtransformer/configuration_rqtransformer.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from transformers import PretrainedConfig
3 |
4 |
5 | @dataclass
6 | class AttentionBlockConfig:
7 | embed_dim: int = 2560
8 | n_head: int = 40
9 | mlp_bias: bool = True
10 | attn_bias: bool = True
11 | attn_pdrop: float = 0.0
12 | resid_pdrop: float = 0.1
13 |
14 |
15 | @dataclass
16 | class AttentionStackConfig:
17 | n_layer: int = 6
18 | block: AttentionBlockConfig = AttentionBlockConfig()
19 |
20 |
21 | class RQTransformerConfig(PretrainedConfig):
22 | model_type = "rqtransformer_model"
23 | def __init__(
24 | self,
25 | block_size=None,
26 | input_embed_dim_1=None,
27 | input_embed_dim_2=None,
28 | embed_dim=None,
29 | vocab_size=None,
30 | head=None,
31 | architectures=None,
32 | **kwargs,
33 | ):
34 | super().__init__()
35 |
36 | self.block_size = block_size
37 | self.input_embed_dim_1 = input_embed_dim_1
38 | self.input_embed_dim_2 = input_embed_dim_2
39 | self.embed_dim = embed_dim
40 | self.vocab_size = vocab_size
41 | self.head = head
42 | self.architectures = architectures
--------------------------------------------------------------------------------
/vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqtransformer/modeling_rqtransformer.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqtransformer/transformers.py.
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from collections import OrderedDict
9 | from torch.nn import functional as F
10 | from transformers import PreTrainedModel, AutoConfig, AutoModel
11 |
12 | from .attention import AttentionStack
13 | from .configuration_rqtransformer import RQTransformerConfig, AttentionStackConfig, AttentionBlockConfig
14 |
15 |
16 | def top_k_logits(logits, k):
17 | v, ix = torch.topk(logits, k)
18 | out = logits.clone()
19 | out[out < v[:, [-1]]] = -float('Inf')
20 |
21 | return out
22 |
23 |
24 | def top_p_probs(probs, p):
25 | sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
26 | cum_probs = torch.cumsum(sorted_probs, dim=-1)
27 |
28 | sorted_idx_remove_cond = cum_probs >= p
29 |
30 | sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
31 | sorted_idx_remove_cond[..., 0] = 0
32 |
33 | indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
34 | probs = probs.masked_fill(indices_to_remove, 0.0)
35 | norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True)
36 |
37 | return norm_probs
38 |
39 |
40 | def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None):
41 | """Take a 2-dim tensor, apply softmax along each row, and sample from
42 | each multinomial distribution defined by the rows.
43 |
44 | Args:
45 | logits: 2-dim tensor of shape (n_samples, logit_dim)
46 | temperature (float): softmax temperature
47 | top_k (Optional[int]): if given, sample only using `top_k` logits
48 | top_p (Optional[float]): if given, sample only using `top_p` logits
49 |
50 | Returns:
51 | samples: 1-dim integer tensor of shape (n_samples,)
52 | """
53 |
54 | logits = logits.to(dtype=torch.float32)
55 | logits = logits / temperature
56 |
57 | if top_k is not None:
58 | logits = top_k_logits(logits, top_k)
59 |
60 | if torch.sum(torch.isnan(logits)):
61 | print('WARNING... NaN observed')
62 | logits[torch.isnan(logits)] = -float('Inf')
63 |
64 | probs = F.softmax(logits, dim=-1)
65 |
66 | if top_p is not None:
67 | probs = top_p_probs(probs, top_p)
68 |
69 | try:
70 | samples = torch.multinomial(probs, num_samples=1)
71 |     except Exception as e:
72 |         raise RuntimeError("Failed to sample from the multinomial distribution.") from e
73 |
74 | return samples.view(-1)
75 |
76 |
77 | class RQTransformer(PreTrainedModel):
78 | config_class = RQTransformerConfig
79 | def __init__(self, config: RQTransformerConfig):
80 | super().__init__(config)
81 | self.in_mlp_1 = nn.Linear(config.input_embed_dim_1, config.embed_dim)
82 | self.in_mlp_2 = nn.Linear(config.input_embed_dim_2, config.embed_dim)
83 |
84 | blockconfig = AttentionBlockConfig(embed_dim=config.embed_dim, n_head=config.head["block"]["n_head"])
85 | stackconfig = AttentionStackConfig(n_layer=config.head["n_layer"], block=blockconfig)
86 | self.head_transformer = AttentionStack(stackconfig)
87 |
88 | self.pos_emb_d = nn.Parameter(torch.zeros(1, config.block_size[2], config.embed_dim))
89 | self.pos_emb_d.data.normal_(mean=0.0, std=0.02)
90 |
91 | self.classifier_mlp = nn.Sequential(OrderedDict([
92 | ('layer_norm', nn.LayerNorm(config.embed_dim)),
93 | ('linear', nn.Linear(config.embed_dim, config.vocab_size)),
94 | ]))
95 |
96 | def embed_with_model_aux(self, code, model_aux):
97 | xs_emb, _ = model_aux.get_code_emb_with_depth(code)
98 | return xs_emb
99 |
100 | def forward(self, embed_from_body, code, model_aux=None):
101 | B, seq_len, D = code.shape
102 |
103 | depth_ctx = self.embed_with_model_aux(code, model_aux)
104 | depth_ctx = torch.cumsum(depth_ctx, dim=-2)
105 | depth_ctx = self.in_mlp_1(depth_ctx)
106 |
107 | embed_from_body = self.in_mlp_2(embed_from_body)
108 |
109 | depth_ctx_full = torch.cat(
110 | [
111 | embed_from_body.view(B, seq_len, 1, -1),
112 | depth_ctx[:, :, :-1, :],
113 | ],
114 | dim=-2,
115 | )
116 |
117 | depth_ctx_full = depth_ctx_full.reshape(B * seq_len, D, -1)
118 | depth_ctx_full = depth_ctx_full + self.pos_emb_d[:, :D, :]
119 |
120 | head_outputs = self.head_transformer(depth_ctx_full)
121 | head_outputs = head_outputs.reshape(B, seq_len, D, -1)
122 | head_outputs = self.classifier_mlp(head_outputs)
123 |
124 | return head_outputs
125 |
126 | def generate(self, embed_from_body, model_aux=None, cfg=3.0):
127 | generate_idx = 1
128 | B, seq_len, _ = embed_from_body.shape
129 |
130 | embed_from_body = self.in_mlp_2(embed_from_body)
131 |
132 | depth_ctx_full = embed_from_body.view(B, seq_len, 1, -1)
133 | depth_ctx_full = depth_ctx_full.reshape(B * seq_len, generate_idx, -1)
134 | depth_ctx_full = depth_ctx_full + self.pos_emb_d[:, :generate_idx, :]
135 |
136 | head_outputs = self.head_transformer(depth_ctx_full)
137 | head_outputs = head_outputs.reshape(B, -1)
138 |
139 | logits = self.classifier_mlp(head_outputs)
140 |
141 | logits = logits[B//2:, :] + cfg * (logits[:B//2, :] - logits[B//2:, :])
142 | code = sample_from_logits(logits, temperature=1.0, top_p=0.96, top_k=900)
143 | code = code.reshape(B//2, seq_len, 1).repeat(2, 1, self.pos_emb_d.shape[1])
144 |
145 | for i in range(self.pos_emb_d.shape[1]-1):
146 | generate_idx += 1
147 | depth_ctx = self.embed_with_model_aux(code, model_aux)
148 | depth_ctx = torch.cumsum(depth_ctx, dim=-2)[:, :, :i+1, :]
149 | if len(depth_ctx.shape) == 3:
150 | depth_ctx = depth_ctx.unsqueeze(2)
151 | depth_ctx = self.in_mlp_1(depth_ctx)
152 |
153 | depth_ctx_full = torch.cat(
154 | [
155 | embed_from_body.view(B, seq_len, 1, -1),
156 | depth_ctx,
157 | ],
158 | dim=-2,
159 | )
160 |
161 | depth_ctx_full = depth_ctx_full.reshape(B * seq_len, generate_idx, -1)
162 | depth_ctx_full = depth_ctx_full + self.pos_emb_d[:, :generate_idx, :]
163 |
164 | head_outputs = self.head_transformer(depth_ctx_full)
165 | head_outputs = head_outputs[:, -1, :]
166 |
167 | logits = self.classifier_mlp(head_outputs)
168 |
169 | logits = logits[B//2:, :] + cfg * (logits[:B//2, :] - logits[B//2:, :])
170 | code_generate = sample_from_logits(logits, temperature=1.0, top_p=0.96, top_k=900)
171 | code_generate = code_generate.reshape(B//2, seq_len).repeat(2, 1)
172 | code[:, :, i+1] = code_generate
173 |
174 | out_features = self.embed_with_model_aux(code, model_aux)
175 | out_features = torch.cumsum(out_features, dim=-2)[:, :, -1, :]
176 |
177 | return out_features, code
178 |
179 |
180 | AutoConfig.register("rqtransformer_model", RQTransformerConfig)
181 | AutoModel.register(RQTransformerConfig, RQTransformer)
--------------------------------------------------------------------------------
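
`sample_from_logits` above applies temperature, optional top-k truncation, and nucleus (top-p) filtering before a multinomial draw, while `generate` treats the second half of the batch as the baseline logits and pushes them toward the first half by a factor `cfg` (classifier-free guidance). A small sketch of both pieces in isolation; the vocabulary size, guidance scale, and assumed conditional/unconditional batch layout are illustrative only:

import torch

from vila_u.model.multimodal_encoder.rqvaesigliptransformer.rqtransformer.modeling_rqtransformer import sample_from_logits

torch.manual_seed(0)

# Classifier-free guidance: push toward the conditional logits (assumed to be the first half).
cond = torch.randn(4, 1024)    # assumed conditional half of the batch
uncond = torch.randn(4, 1024)  # assumed unconditional half of the batch
cfg = 3.0
guided = uncond + cfg * (cond - uncond)

codes = sample_from_logits(guided, temperature=1.0, top_k=900, top_p=0.96)
print(codes.shape)  # torch.Size([4])
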
/vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqvaesiglip/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling_rqvaesiglip import RQVAESiglipModel
--------------------------------------------------------------------------------
/vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqvaesiglip/configuration_rqvaesiglip.py:
--------------------------------------------------------------------------------
1 | from transformers import PretrainedConfig
2 |
3 |
4 | class RQVAESiglipConfig(PretrainedConfig):
5 | model_type = "rqvaesiglip_model"
6 | def __init__(
7 | self,
8 | embed_dim=None,
9 | n_embed=None,
10 | latent_shape=None,
11 | code_shape=None,
12 | shared_codebook=None,
13 | restart_unused_codes=None,
14 | ddconfig=None,
15 | decay=0.99,
16 | latent_loss_weight=0.25,
17 | architectures=None,
18 | decoder_latent_shape=None,
19 | pretrained_model="google/siglip-large-patch16-256",
20 | **kwargs,
21 | ):
22 | super().__init__()
23 |
24 | self.embed_dim = embed_dim
25 | self.n_embed = n_embed
26 | self.latent_shape = latent_shape
27 | self.code_shape = code_shape
28 | self.shared_codebook = shared_codebook
29 | self.restart_unused_codes = restart_unused_codes
30 | self.ddconfig = ddconfig
31 | self.decay = decay
32 | self.latent_loss_weight = latent_loss_weight
33 | self.architectures = architectures
34 | self.decoder_latent_shape = decoder_latent_shape
35 | self.pretrained_model = pretrained_model
--------------------------------------------------------------------------------
/vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqvaesiglip/modeling_rqvaesiglip.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from transformers import PreTrainedModel, AutoConfig, AutoModel
6 | from typing import Optional
7 |
8 | from .configuration_rqvaesiglip import RQVAESiglipConfig
9 | from .modules import Decoder
10 | from .quantizations import RQBottleneck
11 | from .siglip import SiglipModel
12 |
13 |
14 | class RQVAESiglipModel(PreTrainedModel):
15 | config_class = RQVAESiglipConfig
16 | def __init__(self, config: RQVAESiglipConfig):
17 | super().__init__(config)
18 |
19 | siglip_config = SiglipModel.config_class.from_pretrained(config.pretrained_model)
20 | self.siglip_model = SiglipModel._from_config(siglip_config)
21 |
22 | self.quantizer = RQBottleneck(
23 | latent_shape=config.latent_shape,
24 | code_shape=config.code_shape,
25 | n_embed=config.n_embed,
26 | decay=config.decay,
27 | shared_codebook=config.shared_codebook,
28 | restart_unused_codes=config.restart_unused_codes,
29 | )
30 | self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.ddconfig["z_channels"], 1)
31 |
32 | self.decoder = Decoder(**config.ddconfig)
33 |
34 | try:
35 | self.decoder_latent_shape = config.decoder_latent_shape
36 |         except AttributeError:
37 | self.decoder_latent_shape = None
38 |
39 | self.logit_scale = self.siglip_model.logit_scale
40 | self.logit_bias = self.siglip_model.logit_bias
41 |
42 | def encode_image(self, image):
43 | vision_model = self.siglip_model.vision_model
44 | hidden_states = vision_model.embeddings(image)
45 |
46 | attention_mask = None
47 | output_attentions = None
48 | for i, encoder_layer in enumerate(vision_model.encoder.layers):
49 | if vision_model.encoder.gradient_checkpointing and vision_model.encoder.training:
50 | layer_outputs = vision_model.encoder._gradient_checkpointing_func(
51 | encoder_layer.__call__,
52 | hidden_states,
53 | attention_mask,
54 | output_attentions,
55 | )
56 | else:
57 | layer_outputs = encoder_layer(
58 | hidden_states,
59 | attention_mask,
60 | output_attentions=output_attentions,
61 | )
62 | hidden_states = layer_outputs[0]
63 | if i == len(vision_model.encoder.layers) - 2:
64 | B, L, C = hidden_states.shape
65 | hidden_states = hidden_states.reshape(B, int(L**0.5), int(L**0.5), C)
66 | z_q, quant_loss, code = self.quantizer(hidden_states)
67 |
68 | return code, z_q
69 |
70 | def decode(self, z_q):
71 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
72 |
73 | if self.decoder_latent_shape is not None:
74 | z_q = F.interpolate(z_q.to(torch.float32), size=tuple(self.decoder_latent_shape), mode='bilinear').to(torch.bfloat16)
75 |
76 | z_q = self.post_quant_conv(z_q)
77 | out = self.decoder(z_q)
78 |
79 | return out
80 |
81 | @torch.no_grad()
82 | def get_code_emb_with_depth(self, code):
83 | return self.quantizer.embed_code_with_depth(code)
84 |
85 |
86 | AutoConfig.register("rqvaesiglip_model", RQVAESiglipConfig)
87 | AutoModel.register(RQVAESiglipConfig, RQVAESiglipModel)
--------------------------------------------------------------------------------
/vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqvaesiglip/modules.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/modules.py.
3 | """
4 |
5 | import torch
6 |
7 | from torch import nn
8 | from torch.nn import functional as F
9 | from torch.utils.checkpoint import checkpoint
10 |
11 |
12 | def nonlinearity(x):
13 | return F.silu(x, inplace=True)
14 |
15 |
16 | def Normalize(in_channels):
17 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
18 |
19 |
20 | class Upsample(nn.Module):
21 | def __init__(self, in_channels, with_conv):
22 | super().__init__()
23 | self.with_conv = with_conv
24 | if self.with_conv:
25 | self.conv = torch.nn.Conv2d(in_channels,
26 | in_channels,
27 | kernel_size=3,
28 | stride=1,
29 | padding=1)
30 |
31 | def forward(self, x):
32 | x = torch.nn.functional.interpolate(x.to(torch.float32), scale_factor=2.0, mode="nearest").to(torch.bfloat16)
33 | if self.with_conv:
34 | x = self.conv(x)
35 | return x
36 |
37 |
38 | class ResnetBlock(nn.Module):
39 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
40 | dropout, temb_channels=512):
41 | super().__init__()
42 | self.in_channels = in_channels
43 | out_channels = in_channels if out_channels is None else out_channels
44 | self.out_channels = out_channels
45 | self.use_conv_shortcut = conv_shortcut
46 | self.checkpointing = False
47 |
48 | self.norm1 = Normalize(in_channels)
49 | self.conv1 = torch.nn.Conv2d(in_channels,
50 | out_channels,
51 | kernel_size=3,
52 | stride=1,
53 | padding=1)
54 | if temb_channels > 0:
55 | self.temb_proj = torch.nn.Linear(temb_channels,
56 | out_channels)
57 | self.norm2 = Normalize(out_channels)
58 | self.dropout = torch.nn.Dropout(dropout, inplace=True)
59 | self.conv2 = torch.nn.Conv2d(out_channels,
60 | out_channels,
61 | kernel_size=3,
62 | stride=1,
63 | padding=1)
64 | if self.in_channels != self.out_channels:
65 | if self.use_conv_shortcut:
66 | self.conv_shortcut = torch.nn.Conv2d(in_channels,
67 | out_channels,
68 | kernel_size=3,
69 | stride=1,
70 | padding=1)
71 | else:
72 | self.nin_shortcut = torch.nn.Conv2d(in_channels,
73 | out_channels,
74 | kernel_size=1,
75 | stride=1,
76 | padding=0)
77 |
78 | def _forward(self, x, temb):
79 | h = x
80 | h = self.norm1(h)
81 | h = nonlinearity(h)
82 | h = self.conv1(h)
83 |
84 | if temb is not None:
85 | h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
86 |
87 | h = self.norm2(h)
88 | h = nonlinearity(h)
89 | h = self.dropout(h)
90 | h = self.conv2(h)
91 |
92 | if self.in_channels != self.out_channels:
93 | if self.use_conv_shortcut:
94 | x = self.conv_shortcut(x)
95 | else:
96 | x = self.nin_shortcut(x)
97 |
98 | return x+h
99 |
100 | def forward(self, x, temb):
101 | if self.checkpointing and self.training:
102 | out = checkpoint(self._forward, x, temb)
103 | else:
104 | out = self._forward(x, temb)
105 | return out
106 |
107 |
108 | class AttnBlock(nn.Module):
109 | def __init__(self, in_channels):
110 | super().__init__()
111 | self.in_channels = in_channels
112 |
113 | self.norm = Normalize(in_channels)
114 | self.q = torch.nn.Conv2d(in_channels,
115 | in_channels,
116 | kernel_size=1,
117 | stride=1,
118 | padding=0)
119 | self.k = torch.nn.Conv2d(in_channels,
120 | in_channels,
121 | kernel_size=1,
122 | stride=1,
123 | padding=0)
124 | self.v = torch.nn.Conv2d(in_channels,
125 | in_channels,
126 | kernel_size=1,
127 | stride=1,
128 | padding=0)
129 | self.proj_out = torch.nn.Conv2d(in_channels,
130 | in_channels,
131 | kernel_size=1,
132 | stride=1,
133 | padding=0)
134 |
135 |
136 | def forward(self, x):
137 | h_ = x
138 | h_ = self.norm(h_)
139 | q = self.q(h_)
140 | k = self.k(h_)
141 | v = self.v(h_)
142 |
143 | b,c,h,w = q.shape
144 | q = q.reshape(b,c,h*w)
145 | q = q.permute(0,2,1)
146 | k = k.reshape(b,c,h*w)
147 | w_ = torch.bmm(q,k)
148 | w_ = w_ * (int(c)**(-0.5))
149 | w_ = torch.nn.functional.softmax(w_, dim=2)
150 |
151 | v = v.reshape(b,c,h*w)
152 | w_ = w_.permute(0,2,1)
153 | h_ = torch.bmm(v,w_)
154 | h_ = h_.reshape(b,c,h,w)
155 |
156 | h_ = self.proj_out(h_)
157 |
158 | return x+h_
159 |
160 |
161 | class Decoder(nn.Module):
162 | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
163 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
164 | resolution, z_channels, give_pre_end=False, **ignorekwargs):
165 | super().__init__()
166 | self.ch = ch
167 | self.temb_ch = 0
168 | self.num_resolutions = len(ch_mult)
169 | self.num_res_blocks = num_res_blocks
170 | self.resolution = resolution
171 | self.in_channels = in_channels
172 | self.give_pre_end = give_pre_end
173 |
174 | in_ch_mult = (1,)+tuple(ch_mult)
175 | block_in = ch*ch_mult[self.num_resolutions-1]
176 | curr_res = resolution // 2**(self.num_resolutions-1)
177 | self.z_shape = (1, z_channels, curr_res, curr_res)
178 |
179 | self.conv_in = torch.nn.Conv2d(z_channels,
180 | block_in,
181 | kernel_size=3,
182 | stride=1,
183 | padding=1)
184 |
185 | self.mid = nn.Module()
186 | self.mid.block_1 = ResnetBlock(in_channels=block_in,
187 | out_channels=block_in,
188 | temb_channels=self.temb_ch,
189 | dropout=dropout)
190 | self.mid.attn_1 = AttnBlock(block_in)
191 | self.mid.block_2 = ResnetBlock(in_channels=block_in,
192 | out_channels=block_in,
193 | temb_channels=self.temb_ch,
194 | dropout=dropout)
195 |
196 | self.up = nn.ModuleList()
197 | for i_level in reversed(range(self.num_resolutions)):
198 | block = nn.ModuleList()
199 | attn = nn.ModuleList()
200 | block_out = ch*ch_mult[i_level]
201 | for i_block in range(self.num_res_blocks+1):
202 | block.append(ResnetBlock(in_channels=block_in,
203 | out_channels=block_out,
204 | temb_channels=self.temb_ch,
205 | dropout=dropout))
206 | block_in = block_out
207 | if curr_res in attn_resolutions:
208 | attn.append(AttnBlock(block_in))
209 | up = nn.Module()
210 | up.block = block
211 | up.attn = attn
212 | if i_level != 0:
213 | up.upsample = Upsample(block_in, resamp_with_conv)
214 | curr_res = curr_res * 2
215 | self.up.insert(0, up)
216 |
217 | self.norm_out = Normalize(block_in)
218 | self.conv_out = torch.nn.Conv2d(block_in,
219 | out_ch,
220 | kernel_size=3,
221 | stride=1,
222 | padding=1)
223 |
224 | def forward(self, z):
225 | self.last_z_shape = z.shape
226 |
227 | temb = None
228 |
229 | h = self.conv_in(z)
230 |
231 | h = self.mid.block_1(h, temb)
232 | h = self.mid.attn_1(h)
233 | h = self.mid.block_2(h, temb)
234 |
235 | for i_level in reversed(range(self.num_resolutions)):
236 | for i_block in range(self.num_res_blocks+1):
237 | h = self.up[i_level].block[i_block](h, temb)
238 | if len(self.up[i_level].attn) > 0:
239 | h = self.up[i_level].attn[i_block](h)
240 | if i_level != 0:
241 | h = self.up[i_level].upsample(h)
242 |
243 | if self.give_pre_end:
244 | return h
245 |
246 | h = self.norm_out(h)
247 | h = nonlinearity(h)
248 | h = self.conv_out(h)
249 | return h
--------------------------------------------------------------------------------
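
The `Decoder` above starts from a latent whose spatial side is `resolution // 2**(num_resolutions - 1)` and doubles it at every level except the lowest, so the output side matches `resolution`. A quick arithmetic sketch with assumed values (`resolution=256`, `ch_mult=(1, 2, 4, 8)`), not taken from a shipped config:

resolution = 256
ch_mult = (1, 2, 4, 8)
num_resolutions = len(ch_mult)

curr_res = resolution // 2 ** (num_resolutions - 1)
print("latent spatial size:", curr_res)  # 32
for i_level in reversed(range(num_resolutions)):
    if i_level != 0:
        curr_res *= 2  # Upsample after every level except i_level == 0
print("output spatial size:", curr_res)  # 256
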
/vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqvaesiglip/quantizations.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py.
3 | """
4 |
5 | import numpy as np
6 | import torch
7 | import torch.distributed as dist
8 |
9 | from torch import nn
10 | from torch.nn import functional as F
11 | from typing import Iterable
12 |
13 |
14 | class VQEmbedding(nn.Embedding):
15 | """VQ embedding module with ema update."""
16 |
17 | def __init__(self, n_embed, embed_dim, ema=True, decay=0.99, restart_unused_codes=True, eps=1e-5):
18 | super().__init__(n_embed + 1, embed_dim, padding_idx=n_embed)
19 |
20 | self.ema = ema
21 | self.decay = decay
22 | self.eps = eps
23 | self.restart_unused_codes = restart_unused_codes
24 | self.n_embed = n_embed
25 |
26 | if self.ema:
27 | _ = [p.requires_grad_(False) for p in self.parameters()]
28 |
29 | self.register_buffer('cluster_size_ema', torch.zeros(n_embed))
30 | self.register_buffer('embed_ema', self.weight[:-1, :].detach().clone())
31 |
32 | @torch.no_grad()
33 | def compute_distances(self, inputs):
34 | codebook_t = self.weight[:-1, :].t()
35 |
36 | (embed_dim, _) = codebook_t.shape
37 | inputs_shape = inputs.shape
38 | assert inputs_shape[-1] == embed_dim
39 |
40 | inputs_flat = inputs.reshape(-1, embed_dim)
41 |
42 | inputs_norm_sq = inputs_flat.pow(2.).sum(dim=1, keepdim=True)
43 | codebook_t_norm_sq = codebook_t.pow(2.).sum(dim=0, keepdim=True)
44 | distances = torch.addmm(
45 | inputs_norm_sq + codebook_t_norm_sq,
46 | inputs_flat,
47 | codebook_t,
48 | alpha=-2.0,
49 | )
50 | distances = distances.reshape(*inputs_shape[:-1], -1)
51 | return distances
52 |
53 | @torch.no_grad()
54 | def find_nearest_embedding(self, inputs):
55 | distances = self.compute_distances(inputs)
56 | embed_idxs = distances.argmin(dim=-1)
57 |
58 | return embed_idxs
59 |
60 | @torch.no_grad()
61 | def _tile_with_noise(self, x, target_n):
62 | B, embed_dim = x.shape
63 | n_repeats = (target_n + B -1) // B
64 | std = x.new_ones(embed_dim) * 0.01 / np.sqrt(embed_dim)
65 | x = x.repeat(n_repeats, 1)
66 | x = x + torch.rand_like(x) * std
67 | return x
68 |
69 | @torch.no_grad()
70 | def _update_buffers(self, vectors, idxs):
71 |
72 | n_embed, embed_dim = self.weight.shape[0]-1, self.weight.shape[-1]
73 |
74 | vectors = vectors.reshape(-1, embed_dim)
75 | idxs = idxs.reshape(-1)
76 |
77 | n_vectors = vectors.shape[0]
78 | n_total_embed = n_embed
79 |
80 | one_hot_idxs = vectors.new_zeros(n_total_embed, n_vectors)
81 | one_hot_idxs.scatter_(dim=0,
82 | index=idxs.unsqueeze(0),
83 | src=vectors.new_ones(1, n_vectors)
84 | )
85 |
86 | cluster_size = one_hot_idxs.sum(dim=1)
87 | vectors_sum_per_cluster = one_hot_idxs @ vectors
88 |
89 | if dist.is_initialized():
90 | dist.all_reduce(vectors_sum_per_cluster, op=dist.ReduceOp.SUM)
91 | dist.all_reduce(cluster_size, op=dist.ReduceOp.SUM)
92 |
93 | self.cluster_size_ema.mul_(self.decay).add_(cluster_size, alpha=1 - self.decay)
94 | self.embed_ema.mul_(self.decay).add_(vectors_sum_per_cluster, alpha=1 - self.decay)
95 |
96 | if self.restart_unused_codes:
97 | if n_vectors < n_embed:
98 | vectors = self._tile_with_noise(vectors, n_embed)
99 | n_vectors = vectors.shape[0]
100 | _vectors_random = vectors[torch.randperm(n_vectors, device=vectors.device)][:n_embed]
101 |
102 | if dist.is_initialized():
103 | dist.broadcast(_vectors_random, 0)
104 |
105 | usage = (self.cluster_size_ema.view(-1, 1) >= 1).float()
106 | self.embed_ema.mul_(usage).add_(_vectors_random * (1-usage))
107 | self.cluster_size_ema.mul_(usage.view(-1))
108 | self.cluster_size_ema.add_(torch.ones_like(self.cluster_size_ema) * (1-usage).view(-1))
109 |
110 | @torch.no_grad()
111 | def _update_embedding(self):
112 |
113 | n_embed = self.weight.shape[0] - 1
114 | n = self.cluster_size_ema.sum()
115 | normalized_cluster_size = (
116 | n * (self.cluster_size_ema + self.eps) / (n + n_embed * self.eps)
117 | )
118 | self.weight[:-1, :] = self.embed_ema / normalized_cluster_size.reshape(-1, 1)
119 |
120 | def forward(self, inputs):
121 | embed_idxs = self.find_nearest_embedding(inputs)
122 | if self.training:
123 | if self.ema:
124 | self._update_buffers(inputs, embed_idxs)
125 |
126 | embeds = self.embed(embed_idxs)
127 |
128 | if self.ema and self.training:
129 | self._update_embedding()
130 |
131 | return embeds, embed_idxs
132 |
133 | def embed(self, idxs):
134 | embeds = super().forward(idxs)
135 | return embeds
136 |
137 |
138 | class RQBottleneck(nn.Module):
139 | """
140 | Quantization bottleneck via Residual Quantization.
141 |
142 | Arguments:
143 | latent_shape (Tuple[int, int, int]): the shape of latents, denoted (H, W, D)
144 | code_shape (Tuple[int, int, int]): the shape of codes, denoted (h, w, d)
145 | n_embed (int, List, or Tuple): the number of embeddings (i.e., the size of codebook)
146 | If isinstance(n_embed, int), the sizes of all codebooks are the same.
147 | shared_codebook (bool): If True, codebooks are shared across all locations. If False,
148 | separate codebooks are used along the ``depth'' dimension. (default: False)
149 | restart_unused_codes (bool): If True, a feature vector from the current batch is randomly
150 | assigned as the new embedding of unused codes during training. (default: True)
151 | """
152 |
153 | def __init__(self,
154 | latent_shape,
155 | code_shape,
156 | n_embed,
157 | decay=0.99,
158 | shared_codebook=False,
159 | restart_unused_codes=True,
160 | commitment_loss='cumsum'
161 | ):
162 | super().__init__()
163 |
164 | if not len(code_shape) == len(latent_shape) == 3:
165 | raise ValueError("incompatible code shape or latent shape")
166 | if any([y % x != 0 for x, y in zip(code_shape[:2], latent_shape[:2])]):
167 | raise ValueError("incompatible code shape or latent shape")
168 |
169 | embed_dim = np.prod(latent_shape[:2]) // np.prod(code_shape[:2]) * latent_shape[2]
170 |
171 | self.latent_shape = torch.Size(latent_shape)
172 | self.code_shape = torch.Size(code_shape)
173 | self.shape_divisor = torch.Size([latent_shape[i] // code_shape[i] for i in range(len(latent_shape))])
174 |
175 | self.shared_codebook = shared_codebook
176 | if self.shared_codebook:
177 | if isinstance(n_embed, Iterable) or isinstance(decay, Iterable):
178 | raise ValueError("Shared codebooks are incompatible \
179 | with list types of n_embed or decay; pass an int instead")
180 |
181 | self.restart_unused_codes = restart_unused_codes
182 | self.n_embed = n_embed if isinstance(n_embed, Iterable) else [n_embed for _ in range(self.code_shape[-1])]
183 | self.decay = decay if isinstance(decay, Iterable) else [decay for _ in range(self.code_shape[-1])]
184 | assert len(self.n_embed) == self.code_shape[-1]
185 | assert len(self.decay) == self.code_shape[-1]
186 |
187 | if self.shared_codebook:
188 | codebook0 = VQEmbedding(self.n_embed[0],
189 | embed_dim,
190 | decay=self.decay[0],
191 | restart_unused_codes=restart_unused_codes,
192 | )
193 | self.codebooks = nn.ModuleList([codebook0 for _ in range(self.code_shape[-1])])
194 | else:
195 | codebooks = [VQEmbedding(self.n_embed[idx],
196 | embed_dim,
197 | decay=self.decay[idx],
198 | restart_unused_codes=restart_unused_codes,
199 | ) for idx in range(self.code_shape[-1])]
200 | self.codebooks = nn.ModuleList(codebooks)
201 |
202 | self.commitment_loss = commitment_loss
203 |
204 | def to_code_shape(self, x):
205 | (B, H, W, D) = x.shape
206 | (rH, rW, _) = self.shape_divisor
207 |
208 | x = x.reshape(B, H//rH, rH, W//rW, rW, D)
209 | x = x.permute(0, 1, 3, 2, 4, 5)
210 | x = x.reshape(B, H//rH, W//rW, -1)
211 |
212 | return x
213 |
214 | def to_latent_shape(self, x):
215 | (B, h, w, _) = x.shape
216 | (_, _, D) = self.latent_shape
217 | (rH, rW, _) = self.shape_divisor
218 |
219 | x = x.reshape(B, h, w, rH, rW, D)
220 | x = x.permute(0, 1, 3, 2, 4, 5)
221 | x = x.reshape(B, h*rH, w*rW, D)
222 |
223 | return x
224 |
225 | def quantize(self, x):
226 | r"""
227 | Return list of quantized features and the selected codewords by the residual quantization.
228 | The code is selected by the residuals between x and quantized features by the previous codebooks.
229 |
230 | Arguments:
231 | x (Tensor): bottleneck feature maps to quantize.
232 |
233 | Returns:
234 | quant_list (list): list of sequentially aggregated and quantized feature maps by codebooks.
235 | codes (LongTensor): codewords index, corresponding to quants.
236 |
237 | Shape:
238 | - x: (B, h, w, embed_dim)
239 | - quant_list[i]: (B, h, w, embed_dim)
240 | - codes: (B, h, w, d)
241 | """
242 | B, h, w, embed_dim = x.shape
243 |
244 | residual_feature = x.detach().clone()
245 |
246 | quant_list = []
247 | code_list = []
248 | aggregated_quants = torch.zeros_like(x)
249 | for i in range(self.code_shape[-1]):
250 | quant, code = self.codebooks[i](residual_feature)
251 |
252 | residual_feature.sub_(quant)
253 | aggregated_quants.add_(quant)
254 |
255 | quant_list.append(aggregated_quants.clone())
256 | code_list.append(code.unsqueeze(-1))
257 |
258 | codes = torch.cat(code_list, dim=-1)
259 | return quant_list, codes
260 |
261 | def forward(self, x):
262 | x_reshaped = self.to_code_shape(x)
263 | quant_list, codes = self.quantize(x_reshaped)
264 |
265 | commitment_loss = self.compute_commitment_loss(x_reshaped, quant_list)
266 | quants_trunc = self.to_latent_shape(quant_list[-1])
267 | quants_trunc = x + (quants_trunc - x).detach()
268 |
269 | return quants_trunc, commitment_loss, codes
270 |
271 | def compute_commitment_loss(self, x, quant_list):
272 | r"""
273 | Compute the commitment loss for the residual quantization.
274 | The loss is iteratively computed by aggregating quantized features.
275 | """
276 | loss_list = []
277 |
278 | for idx, quant in enumerate(quant_list):
279 | partial_loss = (x-quant.detach()).pow(2.0).mean()
280 | loss_list.append(partial_loss)
281 |
282 | commitment_loss = torch.mean(torch.stack(loss_list))
283 | return commitment_loss
284 |
285 | @torch.no_grad()
286 | def embed_code(self, code):
287 | assert code.shape[1:] == self.code_shape
288 |
289 | code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1)
290 |
291 | if self.shared_codebook:
292 | embeds = [self.codebooks[0].embed(code_slice) for i, code_slice in enumerate(code_slices)]
293 | else:
294 | embeds = [self.codebooks[i].embed(code_slice) for i, code_slice in enumerate(code_slices)]
295 |
296 | embeds = torch.cat(embeds, dim=-2).sum(-2)
297 | embeds = self.to_latent_shape(embeds)
298 |
299 | return embeds
300 |
301 | @torch.no_grad()
302 | def embed_code_with_depth(self, code, to_latent_shape=False):
303 | assert code.shape[-1] == self.code_shape[-1]
304 |
305 | code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1)
306 |
307 | if self.shared_codebook:
308 | embeds = [self.codebooks[0].embed(code_slice) for i, code_slice in enumerate(code_slices)]
309 | else:
310 | embeds = [self.codebooks[i].embed(code_slice) for i, code_slice in enumerate(code_slices)]
311 |
312 | if to_latent_shape:
313 | embeds = [self.to_latent_shape(embed.squeeze(-2)).unsqueeze(-2) for embed in embeds]
314 | embeds = torch.cat(embeds, dim=-2)
315 |
316 | return embeds, None
317 |
--------------------------------------------------------------------------------
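A minimal sketch of the residual-quantization loop that `RQBottleneck.quantize` implements above: at each depth the codeword nearest to the current residual is selected, accumulated, and subtracted before moving on to the next codebook. The toy codebooks and shapes below are illustrative only, not values from the repo's checkpoints.

import torch

def residual_quantize(x, codebooks):
    """x: (N, D) features; codebooks: list of (K, D) tensors, one per depth."""
    residual = x.clone()
    aggregated = torch.zeros_like(x)
    codes = []
    for codebook in codebooks:
        # pick the nearest codeword under Euclidean distance
        idx = torch.cdist(residual, codebook).argmin(dim=-1)
        quant = codebook[idx]
        residual = residual - quant        # quantize what is left over
        aggregated = aggregated + quant    # running sum of quantized parts
        codes.append(idx)
    return aggregated, torch.stack(codes, dim=-1)

x = torch.randn(4, 8)
codebooks = [torch.randn(16, 8) for _ in range(3)]   # depth d = 3
quants, codes = residual_quantize(x, codebooks)
print(quants.shape, codes.shape)   # torch.Size([4, 8]) torch.Size([4, 3])
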
/vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqvaesiglip/siglip/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import TYPE_CHECKING
15 |
16 | from transformers.utils import (
17 | OptionalDependencyNotAvailable,
18 | _LazyModule,
19 | is_torch_available,
20 | is_vision_available,
21 | )
22 |
23 |
24 | _import_structure = {
25 | "configuration_siglip": [
26 | "SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP",
27 | "SiglipConfig",
28 | "SiglipTextConfig",
29 | "SiglipVisionConfig",
30 | ],
31 | "processing_siglip": ["SiglipProcessor"],
32 | "tokenization_siglip": ["SiglipTokenizer"],
33 | }
34 |
35 | try:
36 | if not is_vision_available():
37 | raise OptionalDependencyNotAvailable()
38 | except OptionalDependencyNotAvailable:
39 | pass
40 | else:
41 | _import_structure["image_processing_siglip"] = ["SiglipImageProcessor"]
42 |
43 | try:
44 | if not is_torch_available():
45 | raise OptionalDependencyNotAvailable()
46 | except OptionalDependencyNotAvailable:
47 | pass
48 | else:
49 | _import_structure["modeling_siglip"] = [
50 | "SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
51 | "SiglipModel",
52 | "SiglipPreTrainedModel",
53 | "SiglipTextModel",
54 | "SiglipVisionModel",
55 | ]
56 |
57 |
58 | if TYPE_CHECKING:
59 | from .configuration_siglip import (
60 | SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
61 | SiglipConfig,
62 | SiglipTextConfig,
63 | SiglipVisionConfig,
64 | )
65 | from .processing_siglip import SiglipProcessor
66 | from .tokenization_siglip import SiglipTokenizer
67 |
68 | try:
69 | if not is_vision_available():
70 | raise OptionalDependencyNotAvailable()
71 | except OptionalDependencyNotAvailable:
72 | pass
73 | else:
74 | from .image_processing_siglip import SiglipImageProcessor
75 |
76 | try:
77 | if not is_torch_available():
78 | raise OptionalDependencyNotAvailable()
79 | except OptionalDependencyNotAvailable:
80 | pass
81 | else:
82 | from .modeling_siglip import (
83 | SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
84 | SiglipModel,
85 | SiglipPreTrainedModel,
86 | SiglipTextModel,
87 | SiglipVisionModel,
88 | )
89 |
90 |
91 | else:
92 | import sys
93 |
94 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
95 |
--------------------------------------------------------------------------------
/vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqvaesiglip/siglip/processing_siglip.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | Image/Text processor class for SigLIP.
17 | """
18 |
19 | from typing import List, Optional, Union
20 |
21 | from transformers.feature_extraction_utils import BatchFeature
22 | from transformers.image_utils import ImageInput
23 | from transformers.processing_utils import ProcessorMixin
24 | from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
25 | from transformers.utils import TensorType
26 |
27 |
28 | class SiglipProcessor(ProcessorMixin):
29 | r"""
30 | Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.
31 |
32 | [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the
33 | [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information.
34 |
35 | Args:
36 | image_processor ([`SiglipImageProcessor`]):
37 | The image processor is a required input.
38 | tokenizer ([`SiglipTokenizer`]):
39 | The tokenizer is a required input.
40 | """
41 |
42 | attributes = ["image_processor", "tokenizer"]
43 | image_processor_class = "SiglipImageProcessor"
44 | tokenizer_class = "SiglipTokenizer"
45 |
46 | def __init__(self, image_processor, tokenizer):
47 | super().__init__(image_processor, tokenizer)
48 |
49 | def __call__(
50 | self,
51 | text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
52 | images: ImageInput = None,
53 | padding: Union[bool, str, PaddingStrategy] = "max_length",
54 | truncation: Union[bool, str, TruncationStrategy] = None,
55 | max_length=None,
56 | return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
57 | ) -> BatchFeature:
58 | """
59 | Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
60 | and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode
61 | the text. To prepare the image(s), this method forwards the `images` argument to
62 | SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
63 | of the above two methods for more information.
64 |
65 | Args:
66 | text (`str`, `List[str]`, `List[List[str]]`):
67 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
68 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
69 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
70 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
71 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
72 | tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
73 | number of channels, H and W are image height and width.
74 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `max_length`):
75 | Select a strategy to pad the returned sequences (according to the model's padding side and padding
76 | index) among:
77 | - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
78 | sequence is provided).
79 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
80 | acceptable input length for the model if that argument is not provided.
81 | - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
82 | lengths).
83 | max_length (`int`, *optional*):
84 | Maximum length of the returned list and optionally padding length (see above).
85 | truncation (`bool`, *optional*):
86 | Activates truncation to cut input sequences longer than `max_length` to `max_length`.
87 | return_tensors (`str` or [`~utils.TensorType`], *optional*):
88 | If set, will return tensors of a particular framework. Acceptable values are:
89 |
90 | - `'tf'`: Return TensorFlow `tf.constant` objects.
91 | - `'pt'`: Return PyTorch `torch.Tensor` objects.
92 | - `'np'`: Return NumPy `np.ndarray` objects.
93 | - `'jax'`: Return JAX `jnp.ndarray` objects.
94 |
95 | Returns:
96 | [`BatchFeature`]: A [`BatchFeature`] with the following fields:
97 |
98 | - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
99 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
100 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
101 | `None`).
102 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
103 | """
104 |
105 | if text is None and images is None:
106 | raise ValueError("You have to specify either text or images. Both cannot be none.")
107 |
108 | if text is not None:
109 | encoding = self.tokenizer(
110 | text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
111 | )
112 |
113 | if images is not None:
114 | image_features = self.image_processor(images, return_tensors=return_tensors)
115 |
116 | if text is not None and images is not None:
117 | encoding["pixel_values"] = image_features.pixel_values
118 | return encoding
119 | elif text is not None:
120 | return encoding
121 | else:
122 | return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)
123 |
124 | def decode(self, *args, **kwargs):
125 | """
126 | This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
127 | the docstring of this method for more information.
128 | """
129 | return self.tokenizer.decode(*args, **kwargs)
130 |
131 | def batch_decode(self, *args, **kwargs):
132 | """
133 | This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
134 | refer to the docstring of this method for more information.
135 | """
136 | return self.tokenizer.batch_decode(*args, **kwargs)
137 |
138 | @property
139 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip
140 | def model_input_names(self):
141 | tokenizer_input_names = self.tokenizer.model_input_names
142 | image_processor_input_names = self.image_processor.model_input_names
143 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
144 |
--------------------------------------------------------------------------------
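A usage sketch for this kind of processor (tokenizer plus image processor behind one `__call__`), shown with the upstream `transformers` SiglipProcessor and the public `google/siglip-base-patch16-224` checkpoint rather than the vendored copy above; the checkpoint choice is illustrative.

from PIL import Image
from transformers import SiglipProcessor

processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-224")
image = Image.new("RGB", (224, 224))
inputs = processor(
    text=["a photo of a cat"],
    images=image,
    padding="max_length",
    return_tensors="pt",
)
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)
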
/vila_u/model/multimodal_encoder/rqvaesigliptransformer_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import CLIPImageProcessor, PreTrainedModel, PretrainedConfig
5 | from transformers.image_processing_utils import BaseImageProcessor
6 |
7 | from .rqvaesigliptransformer import RQVAESIGLIPTransformerConfig, RQVAESIGLIPTransformer
8 |
9 |
10 | class RQVAESIGLIPTransformerVisionTower(nn.Module):
11 | def __init__(self, model_name_or_path, config: PretrainedConfig):
12 | super().__init__()
13 | self.config = RQVAESIGLIPTransformerConfig.from_pretrained(model_name_or_path)
14 | self.vision_tower = RQVAESIGLIPTransformer.from_pretrained(model_name_or_path, torch_dtype=eval(config.model_dtype))
15 | self.is_loaded = True
16 |
17 | if self.config.hidden_size == 1152:
18 | self.image_processor = CLIPImageProcessor(
19 | size={"height": 384, "width": 384},
20 | crop_size={"height": 384, "width": 384},
21 | image_mean=[0.5, 0.5, 0.5],
22 | image_std=[0.5, 0.5, 0.5]
23 | )
24 | self.image_tokens = 729
25 | elif self.config.hidden_size == 1024:
26 | self.image_processor = CLIPImageProcessor(
27 | size={"height": 256, "width": 256},
28 | crop_size={"height": 256, "width": 256},
29 | image_mean=[0.5, 0.5, 0.5],
30 | image_std=[0.5, 0.5, 0.5]
31 | )
32 | self.image_tokens = 256
33 | else:
34 | raise NotImplementedError()
35 |
36 | def forward(self, images, text_vocab_size):
37 | output = self.vision_tower.rqvaesiglip.encode_image(images)
38 | image_features, tokens = output[-1], output[-2]
39 |
40 | bs, patch_size, _, dim = image_features.shape
41 | image_features = torch.reshape(image_features, [bs, patch_size**2, dim])
42 | tokens = torch.add(torch.reshape(tokens, [bs, patch_size**2, -1]), text_vocab_size)
43 |
44 | return image_features, tokens
--------------------------------------------------------------------------------
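A small sketch of the id-offset convention used in `forward` above: the discrete vision codes returned by the encoder are shifted by `text_vocab_size` so that text token ids and vision token ids share one unified vocabulary. The numbers below are illustrative.

import torch

text_vocab_size = 32000                         # illustrative text vocab size
vision_codes = torch.tensor([[5, 17, 42]])      # raw codebook indices
unified_ids = vision_codes + text_vocab_size    # -> tensor([[32005, 32017, 32042]])
print(unified_ids)
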
/vila_u/model/multimodal_projector/base_projector.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch.nn as nn
3 | import torch
4 |
5 | from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
6 |
7 |
8 | class IdentityMap(nn.Module):
9 | def __init__(self):
10 | super().__init__()
11 |
12 | def forward(self, x, *args, **kwargs):
13 | return x
14 |
15 | @property
16 | def config(self):
17 | return {"mm_projector_type": "identity"}
18 |
19 |
20 | class MultimodalProjectorConfig(PretrainedConfig):
21 | model_type = "v2l_projector"
22 |
23 | def __init__(self, mm_projector_type: str=None, **kwargs):
24 | super().__init__()
25 |
26 | self.mm_projector_type = mm_projector_type
27 |
28 |
29 | class MultimodalProjector(PreTrainedModel):
30 | config_class = MultimodalProjectorConfig
31 |
32 | def __init__(
33 | self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig
34 | ):
35 | super().__init__(mm_projector_cfg)
36 | mm_projector_type = mm_projector_cfg.mm_projector_type
37 | if mm_projector_type == "identity":
38 | self.layers = IdentityMap()
39 | elif mm_projector_type == "linear":
40 | self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
41 | else:
42 | mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type)
43 | if mlp_gelu_match:
44 | mlp_depth = int(mlp_gelu_match.group(1))
45 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
46 | for _ in range(1, mlp_depth):
47 | modules.append(nn.GELU())
48 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
49 | self.layers = nn.Sequential(*modules)
50 | else:
51 | raise ValueError(f"Unknown projector type: {mm_projector_type}")
52 |
53 | def forward(self, x, *args, **kwargs):
54 | return self.layers(x)
55 |
56 |
57 | AutoConfig.register("v2l_projector", MultimodalProjectorConfig)
58 | AutoModel.register(MultimodalProjectorConfig, MultimodalProjector)
--------------------------------------------------------------------------------
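A sketch of what an "mlp2x_gelu" projector spec expands to under the regex branch above (Linear -> GELU -> Linear); the hidden sizes are illustrative placeholders, not values read from a checkpoint.

import torch
import torch.nn as nn

mm_hidden_size, hidden_size = 1152, 4096
projector = nn.Sequential(
    nn.Linear(mm_hidden_size, hidden_size),
    nn.GELU(),
    nn.Linear(hidden_size, hidden_size),
)
features = torch.randn(1, 729, mm_hidden_size)   # e.g. one image worth of tokens
print(projector(features).shape)                 # torch.Size([1, 729, 4096])
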
/vila_u/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from transformers import PretrainedConfig, PreTrainedModel
5 |
6 | from .base_projector import MultimodalProjectorConfig, MultimodalProjector
7 |
8 |
9 | def build_mm_projector(
10 | model_type_or_path: str, config: PretrainedConfig
11 | ) -> PreTrainedModel:
12 | if model_type_or_path is None:
13 | return None
14 |
15 | if config.resume_path:
16 | assert os.path.exists(
17 | model_type_or_path
18 | ), f"Resume mm projector path {model_type_or_path} does not exist!"
19 | return MultimodalProjector.from_pretrained(
20 | model_type_or_path, config, torch_dtype=eval(config.model_dtype)
21 | )
22 | else:
23 | mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
24 | mm_projector = MultimodalProjector(mm_projector_cfg, config).to(
25 | eval(config.model_dtype)
26 | )
27 | return mm_projector
--------------------------------------------------------------------------------
/vila_u/model/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | from huggingface_hub import snapshot_download, repo_exists
5 | from huggingface_hub.utils import HFValidationError
6 | from transformers import PretrainedConfig
7 |
8 |
9 | def get_model_config(config):
10 | default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]
11 |
12 | if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
13 | root_path = config._name_or_path
14 | else:
15 | root_path = config.resume_path
16 |
17 | if root_path is not None and not osp.exists(root_path):
18 | try:
19 | valid_hf_repo = repo_exists(root_path)
20 | except HFValidationError as e:
21 | valid_hf_repo = False
22 | if valid_hf_repo:
23 | root_path = snapshot_download(root_path)
24 |
25 | return_list = []
26 | for key in default_keys:
27 | cfg = getattr(config, key, None)
28 | if isinstance(cfg, dict):
29 | try:
30 | return_list.append(os.path.join(root_path, key[:-4]))
31 | except:
32 | raise ValueError(f"Cannot find resume path in config for {key}!")
33 | elif isinstance(cfg, PretrainedConfig):
34 | return_list.append(os.path.join(root_path, key[:-4]))
35 | elif isinstance(cfg, str):
36 | return_list.append(cfg)
37 |
38 | return return_list
--------------------------------------------------------------------------------
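A sketch of the naming convention `get_model_config` relies on: each `*_cfg` key maps to a subdirectory of the checkpoint root obtained by stripping the `_cfg` suffix (`key[:-4]`). The root path below is a hypothetical placeholder.

import os

root_path = "/path/to/checkpoint"   # hypothetical
for key in ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]:
    print(os.path.join(root_path, key[:-4]))
# /path/to/checkpoint/llm
# /path/to/checkpoint/vision_tower
# /path/to/checkpoint/mm_projector
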
/vila_u/train/args.py:
--------------------------------------------------------------------------------
1 | import transformers
2 |
3 | from dataclasses import dataclass, field
4 | from typing import Dict, Optional, Sequence, List
5 |
6 |
7 | @dataclass
8 | class DataArguments:
9 | data_mixture: str = "llava_1_5_mm_align"
10 | image_aspect_ratio: str = "square"
11 | lazy_preprocess: bool = False
12 | vflan_no_system_prompt: bool = False
13 | num_video_frames: int = 8
14 |
15 |
16 | @dataclass
17 | class ModelArguments:
18 | version: Optional[str] = field(default="v0")
19 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
20 | vision_tower: Optional[str] = field(default="google/siglip-so400m-patch14-384")
21 | mm_projector: Optional[str] = field(default="mlp2x_gelu")
22 | mm_use_im_start_end: bool = field(default=False)
23 | mm_use_vi_start_end: bool = field(default=False)
24 | mm_use_im_patch_token: bool = field(default=True)
25 | mm_vision_select_layer: Optional[int] = field(default=-1)
26 | interpolate_mode: Optional[str] = field(default="linear")
27 | drop_path_rate: Optional[float] = field(default=0.)
28 |
29 |
30 | @dataclass
31 | class TrainingArguments(transformers.TrainingArguments):
32 | cache_dir: Optional[str] = field(default=None)
33 | remove_unused_columns: bool = field(default=False)
34 | tune_vision_tower: bool = field(default=False)
35 | tune_language_model: bool = field(default=False)
36 | tune_mm_projector: bool = field(default=False)
37 | chunk_sampler: bool = field(default=False)
38 | model_dtype: str = field(default="torch.bfloat16")
39 | model_max_length: int = field(
40 | default=512,
41 | metadata={
42 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
43 | },
44 | )
45 | mm_projector_lr: Optional[float] = None
46 | group_by_modality_length: bool = field(default=False)
47 | total_time_limit: int = field(
48 | default=-1, metadata={"help": "Timeout limit for this job (in minutes)."}
49 | )
50 | pre_terminate_time: int = field(
51 | default=10,
52 | metadata={
53 | "help": "Time to terminate the task inadvance (minutes), saveing checkpoints needs time."
54 | },
55 | )
--------------------------------------------------------------------------------
/vila_u/train/callbacks/autoresume_callback.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import transformers
5 |
6 | from transformers.utils import logging
7 |
8 | logger = logging.get_logger("transformers")
9 |
10 |
11 | def rank_print(*s):
12 | if not torch.distributed.is_initialized():
13 | rank = 0
14 | else:
15 | rank = torch.distributed.get_rank()
16 | print(rank, *s)
17 |
18 | sys.path.append(os.environ.get("SUBMIT_SCRIPTS", "."))
19 | try:
20 | logger.info("Importing AutoResume lib...")
21 | from userlib.auto_resume import AutoResume
22 |
23 | AutoResume.init()
24 | logger.info("Found AutoResume SDK!")
25 | except:
26 | logger.warn("Did not find AutoResume SDK!")
27 | AutoResume = None
28 |
29 |
30 | class AutoResumeCallback(transformers.TrainerCallback):
31 | """
32 | A [`TrainerCallback`] that handles autoresume.
33 |
34 | Args:
35 | interval: interval (in number of iterations) between checks as to
36 | whether to suspend.
37 | """
38 |
39 | def __init__(self, interval: int = 50):
40 | self.interval = interval
41 |
42 | def on_step_end(self, args, state, control, **kwargs):
43 | if state.global_step % self.interval == 0:
44 | rank_print("AutoResumeHook: Checking whether to suspend...")
45 |
46 | # Check whether to suspend the job.
47 | should_preempt = AutoResume is not None and AutoResume.termination_requested()
48 |
49 | if should_preempt:
50 | if state.is_local_process_zero:
51 | logger.warn(f"AutoResumeHook: Request resume...")
52 | if AutoResume is not None:
53 | AutoResume.request_resume()
54 | control.should_training_stop = True
55 | control.should_save = True
56 |
--------------------------------------------------------------------------------
/vila_u/train/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import torch
4 | import transformers
5 |
6 | from torch.utils.data import Dataset
7 | from transformers import HfArgumentParser, AutoConfig
8 | from transformers import set_seed
9 | from typing import Dict, Tuple, cast
10 |
11 | from vila_u import conversation as conversation_lib
12 | from vila_u.data import make_supervised_data_module
13 | from vila_u.model import VILAULlamaModel, VILAULlamaConfig
14 | from vila_u.model.multimodal_encoder.rqvaesigliptransformer_encoder import RQVAESIGLIPTransformerVisionTower
15 | from vila_u.train.vila_u_trainer import VILAUTrainer
16 | from vila_u.train.args import TrainingArguments, ModelArguments, DataArguments
17 | from vila_u.train.callbacks.autoresume_callback import AutoResumeCallback
18 | from vila_u.train.utils import (
19 | get_checkpoint_path,
20 | prepare_config_for_training,
21 | mprint,
22 | )
23 |
24 | local_rank = None
25 |
26 | if "WANDB_PROJECT" not in os.environ:
27 | os.environ["WANDB_PROJECT"] = "VILA-U"
28 |
29 |
30 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
31 | """Collects the state dict and dump to disk."""
32 | if trainer.deepspeed:
33 | torch.cuda.synchronize()
34 | trainer.save_model(output_dir, _internal_call=True)
35 | return
36 |
37 | state_dict = trainer.model.state_dict()
38 | if trainer.args.should_save:
39 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
40 | del state_dict
41 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
42 |
43 |
44 | def smart_tokenizer_and_embedding_resize(
45 | special_tokens_dict: Dict,
46 | tokenizer: transformers.PreTrainedTokenizer,
47 | model: transformers.PreTrainedModel,
48 | ):
49 | """Resize tokenizer and embedding.
50 |
51 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
52 | """
53 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
54 | model.resize_token_embeddings(len(tokenizer))
55 |
56 | if num_new_tokens > 0:
57 | input_embeddings = model.get_input_embeddings().weight.data
58 | output_embeddings = model.get_output_embeddings().weight.data
59 |
60 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
61 | dim=0, keepdim=True
62 | )
63 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
64 | dim=0, keepdim=True
65 | )
66 |
67 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
68 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
69 |
70 |
71 | def train():
72 | global local_rank
73 |
74 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
75 | model_args, data_args, training_args = cast(Tuple[ModelArguments, DataArguments, TrainingArguments], parser.parse_args_into_dataclasses())
76 | training_args.run_name = training_args.output_dir.split("/")[-1]
77 | local_rank = training_args.local_rank
78 | compute_dtype = (
79 | torch.float16
80 | if training_args.fp16
81 | else (torch.bfloat16 if training_args.bf16 else torch.float32)
82 | )
83 |
84 | set_seed(training_args.seed)
85 |
86 | resume_path, continue_training = get_checkpoint_path(training_args.output_dir)
87 |
88 | if not continue_training:
89 | print(f"Models has been ready under {training_args.output_dir}. Skipp training")
90 | exit(0)
91 |
92 | if resume_path:
93 | resume_from_checkpoint = True
94 | config = AutoConfig.from_pretrained(resume_path, trust_remote_code=True)
95 | config.resume_path = resume_path
96 | model_cls = eval(config.architectures[0])
97 | else:
98 | resume_from_checkpoint = False
99 | model_cls = VILAULlamaModel
100 | config = VILAULlamaConfig.from_pretrained(
101 | model_args.model_name_or_path,
102 | resume=resume_from_checkpoint
103 | )
104 | if getattr(config, "resume_path", None) is not None:
105 | config.resume_path = model_args.model_name_or_path
106 |
107 | prepare_config_for_training(config, model_args, training_args, data_args)
108 |
109 | model = model_cls(
110 | config=config,
111 | attn_implementation="flash_attention_2",
112 | model_max_length=training_args.model_max_length,
113 | cache_dir=training_args.cache_dir,
114 | )
115 |
116 | mprint(model)
117 |
118 | model.llm.config.use_cache = False
119 | model.get_llm().requires_grad_(training_args.tune_language_model)
120 | mprint(f"Tunable parameters:\nlanguage model {training_args.tune_language_model}")
121 |
122 | if model.get_vision_tower():
123 | model.get_vision_tower().requires_grad_(training_args.tune_vision_tower)
124 | model.get_mm_projector().requires_grad_(training_args.tune_mm_projector)
125 | if isinstance(model.get_vision_tower(), RQVAESIGLIPTransformerVisionTower):
126 | model.get_vision_tower().vision_tower.rqvaesiglip.eval()
127 | model.get_vision_tower().vision_tower.rqtransformer.requires_grad_(True)
128 | else:
129 | raise NotImplementedError()
130 | print(f"vision tower {training_args.tune_vision_tower}")
131 | print(f"mm projector {training_args.tune_mm_projector}")
132 |
133 | if not any([training_args.tune_language_model, training_args.tune_vision_tower, training_args.tune_mm_projector]):
134 | logging.warning(
135 | "You are not tuning any part of the model. Please check if this is intended."
136 | )
137 |
138 | def need_to_modify_do_sample(generation_config):
139 | if generation_config.do_sample is False:
140 | if (
141 | generation_config.temperature is not None
142 | and generation_config.temperature != 1.0
143 | ):
144 | return True
145 | if generation_config.top_p is not None and generation_config.top_p != 1.0:
146 | return True
147 | return False
148 |
149 | if need_to_modify_do_sample(model.llm.generation_config):
150 | model.llm.generation_config.do_sample = True
151 |
152 | if training_args.gradient_checkpointing:
153 | if hasattr(model.llm, "enable_input_require_grads"):
154 | model.llm.enable_input_require_grads()
155 | else:
156 |
157 | def make_inputs_require_grad(module, input, output):
158 | output.requires_grad_(True)
159 |
160 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
161 |
162 | tokenizer = model.tokenizer
163 | if model_args.version == "v0":
164 | if tokenizer.pad_token is None:
165 | smart_tokenizer_and_embedding_resize(
166 | special_tokens_dict=dict(pad_token="[PAD]"),
167 | tokenizer=tokenizer,
168 | model=model.llm,
169 | )
170 | else:
171 | tokenizer.pad_token = tokenizer.unk_token
172 | if tokenizer.pad_token is None:
173 | smart_tokenizer_and_embedding_resize(
174 | special_tokens_dict=dict(pad_token="[PAD]"),
175 | tokenizer=tokenizer,
176 | model=model.llm,
177 | )
178 | if model_args.version in conversation_lib.conv_templates:
179 | conversation_lib.default_conversation = conversation_lib.conv_templates[
180 | model_args.version
181 | ]
182 | else:
183 | conversation_lib.default_conversation = conversation_lib.conv_templates[
184 | "vicuna_v1"
185 | ]
186 |
187 | model.llm.pad_token_id = tokenizer.pad_token_id
188 | model.llm.config.tokenizer_padding_side = tokenizer.padding_side
189 | model.llm.config.tokenizer_model_max_length = tokenizer.model_max_length
190 |
191 | vision_tower = model.get_vision_tower()
192 | if vision_tower is not None:
193 | data_args.image_processor = vision_tower.image_processor
194 | data_args.is_multimodal = True
195 |
196 | model.config.num_video_frames = data_args.num_video_frames
197 | model.config.image_aspect_ratio = data_args.image_aspect_ratio
198 | model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = (
199 | model_args.mm_use_im_start_end
200 | )
201 | model.config.mm_use_vi_start_end = data_args.mm_use_vi_start_end = (
202 | model_args.mm_use_vi_start_end
203 | )
204 | model.config.mm_projector_lr = training_args.mm_projector_lr
205 | training_args.use_im_start_end = model_args.mm_use_im_start_end
206 | training_args.use_vi_start_end = model_args.mm_use_vi_start_end
207 | model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
208 | model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
209 |
210 | data_module = make_supervised_data_module(
211 | tokenizer=tokenizer,
212 | data_args=data_args,
213 | training_args=training_args,
214 | )
215 | callbacks = [AutoResumeCallback()]
216 | trainer = VILAUTrainer(
217 | model=model,
218 | tokenizer=tokenizer,
219 | args=training_args,
220 | callbacks=callbacks,
221 | **data_module,
222 | )
223 |
224 | print(
225 | "length of dataloader:",
226 | len(trainer.get_train_dataloader()),
227 | len(trainer.train_dataset),
228 | flush=True,
229 | )
230 | print(
231 | "[GPU memory] before trainer",
232 | torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
233 | flush=True,
234 | )
235 |
236 | trainer.train(resume_from_checkpoint=resume_from_checkpoint)
237 | trainer.save_state()
238 |
239 | model.llm.config.use_cache = True
240 | model.config.resume_path = model.config._name_or_path = training_args.output_dir
241 | safe_save_model_for_hf_trainer(
242 | trainer=trainer, output_dir=training_args.output_dir
243 | )
244 |
245 | if __name__ == "__main__":
246 | train()
--------------------------------------------------------------------------------
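A sketch of the mean-initialization used by `smart_tokenizer_and_embedding_resize` above: embedding rows for newly added special tokens are set to the average of the pre-existing rows, which tends to be a safer starting point than random initialization. Shapes are illustrative.

import torch

old_embeddings = torch.randn(100, 16)        # existing vocab of 100 tokens
num_new_tokens = 2
mean_row = old_embeddings.mean(dim=0, keepdim=True)
new_rows = mean_row.repeat(num_new_tokens, 1)
embeddings = torch.cat([old_embeddings, new_rows], dim=0)
print(embeddings.shape)                      # torch.Size([102, 16])
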
/vila_u/train/train_mem.py:
--------------------------------------------------------------------------------
1 | from unittest import mock
2 | from vila_u.train.train import train
3 | from vila_u.train.transformer_normalize_monkey_patch import patched_normalize
4 |
5 |
6 | def __len__(self):
7 | return len(self.batch_sampler)
8 |
9 |
10 | def __iter__(self):
11 | return self.batch_sampler.__iter__()
12 |
13 | if __name__ == "__main__":
14 | with (
15 | mock.patch('transformers.image_processing_utils.normalize', new=patched_normalize),
16 | mock.patch('accelerate.data_loader.BatchSamplerShard.__len__', new=__len__),
17 | mock.patch('accelerate.data_loader.BatchSamplerShard.__iter__', new=__iter__)
18 | ):
19 | train()
20 |
--------------------------------------------------------------------------------
/vila_u/train/transformer_normalize_monkey_patch.py:
--------------------------------------------------------------------------------
1 | import transformers
2 | from transformers.image_transforms import np, Union, Iterable, Optional, ChannelDimension, \
3 | infer_channel_dimension_format, get_channel_dimension_axis, to_channel_dimension_format
4 |
5 | def patched_normalize(
6 | image: np.ndarray,
7 | mean: Union[float, Iterable[float]],
8 | std: Union[float, Iterable[float]],
9 | data_format: Optional[ChannelDimension] = None,
10 | input_data_format: Optional[Union[str, ChannelDimension]] = None,
11 | ) -> np.ndarray:
12 | """
13 | Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.
14 |
15 | image = (image - mean) / std
16 |
17 | Args:
18 | image (`np.ndarray`):
19 | The image to normalize.
20 | mean (`float` or `Iterable[float]`):
21 | The mean to use for normalization.
22 | std (`float` or `Iterable[float]`):
23 | The standard deviation to use for normalization.
24 | data_format (`ChannelDimension`, *optional*):
25 | The channel dimension format of the output image. If unset, will use the inferred format from the input.
26 | """
27 | if not isinstance(image, np.ndarray):
28 | raise ValueError("image must be a numpy array")
29 |
30 | input_data_format = infer_channel_dimension_format(image)
31 | channel_axis = get_channel_dimension_axis(image)
32 | num_channels = image.shape[channel_axis]
33 |
34 | if isinstance(mean, Iterable):
35 | if len(mean) != num_channels:
36 | if num_channels == 1:
37 | num_channels = 3
38 | image = np.concatenate([image, image, image], axis=channel_axis)
39 | else:
40 | raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
41 | else:
42 | mean = [mean] * num_channels
43 | mean = np.array(mean, dtype=image.dtype)
44 |
45 | if isinstance(std, Iterable):
46 | if len(std) != num_channels:
47 | raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
48 | else:
49 | std = [std] * num_channels
50 | std = np.array(std, dtype=image.dtype)
51 |
52 | if input_data_format == ChannelDimension.LAST:
53 | image = (image - mean) / std
54 | else:
55 | image = ((image.T - mean) / std).T
56 |
57 | image = to_channel_dimension_format(image, data_format) if data_format is not None else image
58 | return image
59 |
60 |
61 | def patch_normalize_preprocess():
62 | transformers.image_transforms.normalize = patched_normalize
--------------------------------------------------------------------------------
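A sketch of the behaviour `patched_normalize` adds on top of the standard `image = (image - mean) / std`: a single-channel image is expanded to three channels when a 3-element mean is supplied, and channel-first inputs are normalized through the transposed path. Values are illustrative.

import numpy as np

image = np.ones((1, 4, 4), dtype=np.float32)            # (C=1, H, W)
image = np.concatenate([image, image, image], axis=0)   # grayscale -> 3 channels
mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
normalized = ((image.T - mean) / std).T                 # channel-first branch
print(normalized.shape, normalized[0, 0, 0])            # (3, 4, 4) 1.0
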
/vila_u/train/transformers_replace/models/llama/configuring_llama.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | """ LLaMA model configuration"""
21 |
22 | from ...configuration_utils import PretrainedConfig
23 | from ...utils import logging
24 |
25 |
26 | logger = logging.get_logger(__name__)
27 |
28 | LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
29 |
30 |
31 | class LlamaConfig(PretrainedConfig):
32 | r"""
33 | This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
34 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
35 | defaults will yield a similar configuration to that of the LLaMA-7B.
36 |
37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38 | documentation from [`PretrainedConfig`] for more information.
39 |
40 |
41 | Args:
42 | vocab_size (`int`, *optional*, defaults to 32000):
43 | Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
44 | `inputs_ids` passed when calling [`LlamaModel`]
45 | hidden_size (`int`, *optional*, defaults to 4096):
46 | Dimension of the hidden representations.
47 | intermediate_size (`int`, *optional*, defaults to 11008):
48 | Dimension of the MLP representations.
49 | num_hidden_layers (`int`, *optional*, defaults to 32):
50 | Number of hidden layers in the Transformer decoder.
51 | num_attention_heads (`int`, *optional*, defaults to 32):
52 | Number of attention heads for each attention layer in the Transformer decoder.
53 | num_key_value_heads (`int`, *optional*):
54 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
55 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
56 | `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
57 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
58 | by meanpooling all the original heads within that group. For more details checkout [this
59 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
60 | `num_attention_heads`.
61 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
62 | The non-linear activation function (function or string) in the decoder.
63 | max_position_embeddings (`int`, *optional*, defaults to 2048):
64 | The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
65 | Llama 2 up to 4096, CodeLlama up to 16384.
66 | initializer_range (`float`, *optional*, defaults to 0.02):
67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68 | rms_norm_eps (`float`, *optional*, defaults to 1e-06):
69 | The epsilon used by the rms normalization layers.
70 | use_cache (`bool`, *optional*, defaults to `True`):
71 | Whether or not the model should return the last key/values attentions (not used by all models). Only
72 | relevant if `config.is_decoder=True`.
73 | pad_token_id (`int`, *optional*):
74 | Padding token id.
75 | bos_token_id (`int`, *optional*, defaults to 1):
76 | Beginning of stream token id.
77 | eos_token_id (`int`, *optional*, defaults to 2):
78 | End of stream token id.
79 | pretraining_tp (`int`, *optional*, defaults to 1):
80 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
81 | document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
82 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
83 | issue](https://github.com/pytorch/pytorch/issues/76232).
84 | tie_word_embeddings (`bool`, *optional*, defaults to `False`):
85 | Whether to tie weight embeddings
86 | rope_theta (`float`, *optional*, defaults to 10000.0):
87 | The base period of the RoPE embeddings.
88 | rope_scaling (`Dict`, *optional*):
89 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
90 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
91 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
92 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
93 | these scaling strategies behave:
94 | https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
95 | experimental feature, subject to breaking API changes in future versions.
96 | attention_bias (`bool`, *optional*, defaults to `False`):
97 | Whether to use a bias in the query, key, value and output projection layers during self-attention.
98 | attention_dropout (`float`, *optional*, defaults to 0.0):
99 | The dropout ratio for the attention probabilities.
100 |
101 | ```python
102 | >>> from transformers import LlamaModel, LlamaConfig
103 |
104 | >>> # Initializing a LLaMA llama-7b style configuration
105 | >>> configuration = LlamaConfig()
106 |
107 | >>> # Initializing a model from the llama-7b style configuration
108 | >>> model = LlamaModel(configuration)
109 |
110 | >>> # Accessing the model configuration
111 | >>> configuration = model.config
112 | ```"""
113 | model_type = "llama"
114 | keys_to_ignore_at_inference = ["past_key_values"]
115 |
116 | def __init__(
117 | self,
118 | vocab_size=32000,
119 | hidden_size=4096,
120 | intermediate_size=11008,
121 | num_hidden_layers=32,
122 | num_attention_heads=32,
123 | num_key_value_heads=None,
124 | hidden_act="silu",
125 | max_position_embeddings=2048,
126 | initializer_range=0.02,
127 | rms_norm_eps=1e-6,
128 | use_cache=True,
129 | pad_token_id=None,
130 | bos_token_id=1,
131 | eos_token_id=2,
132 | pretraining_tp=1,
133 | tie_word_embeddings=False,
134 | rope_theta=10000.0,
135 | rope_scaling=None,
136 | attention_bias=False,
137 | attention_dropout=0.0,
138 | **kwargs,
139 | ):
140 | self.vocab_size = vocab_size
141 | self.max_position_embeddings = max_position_embeddings
142 | self.hidden_size = hidden_size
143 | self.intermediate_size = intermediate_size
144 | self.num_hidden_layers = num_hidden_layers
145 | self.num_attention_heads = num_attention_heads
146 |
147 | # for backward compatibility
148 | if num_key_value_heads is None:
149 | num_key_value_heads = num_attention_heads
150 |
151 | self.num_key_value_heads = num_key_value_heads
152 | self.hidden_act = hidden_act
153 | self.initializer_range = initializer_range
154 | self.rms_norm_eps = rms_norm_eps
155 | self.pretraining_tp = pretraining_tp
156 | self.use_cache = use_cache
157 | self.rope_theta = rope_theta
158 | self.rope_scaling = rope_scaling
159 | self._rope_scaling_validation()
160 | self.attention_bias = attention_bias
161 | self.attention_dropout = attention_dropout
162 |
163 | super().__init__(
164 | pad_token_id=pad_token_id,
165 | bos_token_id=bos_token_id,
166 | eos_token_id=eos_token_id,
167 | tie_word_embeddings=tie_word_embeddings,
168 | **kwargs,
169 | )
170 |
171 | def _rope_scaling_validation(self):
172 | """
173 | Validate the `rope_scaling` configuration.
174 | """
175 | if self.rope_scaling is None:
176 | return
177 |
178 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
179 | raise ValueError(
180 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
181 | f"got {self.rope_scaling}"
182 | )
183 | rope_scaling_type = self.rope_scaling.get("type", None)
184 | rope_scaling_factor = self.rope_scaling.get("factor", None)
185 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
186 | raise ValueError(
187 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
188 | )
189 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
190 | raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
191 |
--------------------------------------------------------------------------------
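A sketch of a `rope_scaling` value that passes `_rope_scaling_validation` above: a two-field dictionary whose type is "linear" or "dynamic" and whose factor is a float greater than 1.

rope_scaling = {"type": "linear", "factor": 2.0}
assert set(rope_scaling) == {"type", "factor"}
assert rope_scaling["type"] in ["linear", "dynamic"]
assert isinstance(rope_scaling["factor"], float) and rope_scaling["factor"] > 1.0
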
/vila_u/train/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import re
4 | import torch
5 |
6 | from dataclasses import dataclass
7 | from transformers import PretrainedConfig
8 |
9 |
10 | def rprint(*args, **kwargs):
11 | rank = int(os.environ.get("RANK", 0))
12 | world_size = int(os.environ.get("WORLD_SIZE", 1))
13 | if world_size > 1:
14 | return print(f"[dist-{rank}-of-{world_size}]", *args, **kwargs)
15 | else:
16 | return print(*args, **kwargs)
17 |
18 |
19 | def mprint(*args, **kwargs):
20 | rank = int(os.environ.get("RANK", 0))
21 | world_size = int(os.environ.get("WORLD_SIZE", 1))
22 | if world_size > 1:
23 | if rank == 0:
24 | return print(f"[dist-{rank}-of-{world_size}]", *args, **kwargs)
25 | else:
26 | return
27 | else:
28 | return print(*args, **kwargs)
29 |
30 |
31 | def is_local(model_name_or_path: str) -> bool:
32 | return os.path.isdir(model_name_or_path)
33 |
34 |
35 | def get_checkpoint_path(
36 | output_dir: str, checkpoint_prefix: str = "checkpoint"
37 | ) -> tuple[str | None, bool]:
38 | output_dir = os.path.abspath(output_dir)
39 | pathlib_dir = pathlib.Path(output_dir)
40 |
41 | if list(pathlib_dir.glob("config.json")):
42 | return output_dir, False
43 | else:
44 | try:
45 | ordering_and_checkpoint_path = []
46 | glob_checkpoints = [
47 | str(x)
48 | for x in pathlib.Path(output_dir).glob(f"{checkpoint_prefix}-*")
49 | if os.path.isdir(x)
50 | ]
51 | for path in glob_checkpoints:
52 | regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
53 | if regex_match is not None and regex_match.groups() is not None:
54 | ordering_and_checkpoint_path.append(
55 | (int(regex_match.groups()[0]), path)
56 | )
57 | checkpoints_sorted = sorted(ordering_and_checkpoint_path)
58 | return checkpoints_sorted[-1][1], True
59 | except:
60 | return None, True
61 |
62 |
63 | def prepare_config_for_training(
64 | config: PretrainedConfig,
65 | model_args: dataclass,
66 | training_args: dataclass,
67 | data_args: dataclass,
68 | ) -> None:
69 | assert model_args.vision_tower is not None, "requires vision tower"
70 |
71 | if getattr(config, "llm_cfg", None) is None:
72 | config.llm_cfg = model_args.model_name_or_path
73 | if getattr(config, "vision_tower_cfg", None) is None:
74 | config.vision_tower_cfg = model_args.vision_tower
75 | if getattr(config, "mm_projector_cfg", None) is None:
76 | config.mm_projector_cfg = model_args.mm_projector
77 |
78 | config.model_dtype = torch.bfloat16 if training_args.bf16 else torch.float16
79 | config.model_dtype = config.model_dtype.__str__()
80 |
81 | config.tune_language_model = training_args.tune_language_model
82 | config.tune_vision_tower = training_args.tune_vision_tower
83 | config.tune_mm_projector = training_args.tune_mm_projector
84 |
85 | config.image_aspect_ratio = data_args.image_aspect_ratio
86 |
87 | if getattr(config, "vision_tower_cfg", None) is not None:
88 | config.mm_vision_select_layer = model_args.mm_vision_select_layer
89 | config.interpolate_mode = model_args.interpolate_mode
90 | config.drop_path_rate = model_args.drop_path_rate
--------------------------------------------------------------------------------
/vila_u/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
--------------------------------------------------------------------------------
/vila_u/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 |
4 | from torch import distributed as dist
5 | from typing import Any, List, Optional
6 |
7 |
8 | __all__ = [
9 | "init",
10 | "is_initialized",
11 | "size",
12 | "rank",
13 | "local_size",
14 | "local_rank",
15 | "is_main",
16 | "barrier",
17 | "gather",
18 | "all_gather",
19 | ]
20 |
21 |
22 | def init() -> None:
23 | if "RANK" not in os.environ:
24 | warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.")
25 | return
26 | dist.init_process_group(backend="nccl", init_method="env://")
27 |
28 |
29 | def is_initialized() -> bool:
30 | return dist.is_initialized()
31 |
32 |
33 | def size() -> int:
34 | return int(os.environ.get("WORLD_SIZE", 1))
35 |
36 |
37 | def rank() -> int:
38 | return int(os.environ.get("RANK", 0))
39 |
40 |
41 | def local_size() -> int:
42 | return int(os.environ.get("LOCAL_WORLD_SIZE", 1))
43 |
44 |
45 | def local_rank() -> int:
46 | return int(os.environ.get("LOCAL_RANK", 0))
47 |
48 |
49 | def is_main() -> bool:
50 | return rank() == 0
51 |
52 |
53 | def barrier() -> None:
54 | dist.barrier()
55 |
56 |
57 | def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]:
58 | if is_main():
59 | objs = [None for _ in range(size())]
60 | dist.gather_object(obj, objs, dst=dst)
61 | return objs
62 | else:
63 | dist.gather_object(obj, dst=dst)
64 | return None
65 |
66 |
67 | def all_gather(obj: Any) -> List[Any]:
68 | objs = [None for _ in range(size())]
69 | dist.all_gather_object(objs, obj)
70 | return objs
71 |
--------------------------------------------------------------------------------
/vila_u/utils/io.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import os
4 | import pickle
5 | import torch
6 | import yaml
7 |
8 | from contextlib import contextmanager
9 | from typing import IO, Any, BinaryIO, Callable, Dict, Iterator, TextIO, Union
10 |
11 |
12 | __all__ = [
13 | "load",
14 | "save",
15 | "load_json",
16 | "save_json",
17 | "load_jsonl",
18 | "save_jsonl",
19 | "load_mat",
20 | "save_mat",
21 | "load_npy",
22 | "save_npy",
23 | "load_npz",
24 | "save_npz",
25 | "load_pt",
26 | "save_pt",
27 | "load_yaml",
28 | "save_yaml",
29 | ]
30 |
31 |
32 | @contextmanager
33 | def file_descriptor(f: Union[str, IO], mode: str = "r") -> Iterator[IO]:
34 | opened = False
35 | try:
36 | if isinstance(f, str):
37 | f = open(f, mode)
38 | opened = True
39 | yield f
40 | finally:
41 | if opened:
42 | f.close()
43 |
44 |
45 | def load_json(f: Union[str, TextIO], **kwargs) -> Any:
46 | with file_descriptor(f, mode="r") as fd:
47 | return json.load(fd, **kwargs)
48 |
49 |
50 | def save_json(f: Union[str, TextIO], obj: Any, **kwargs) -> None:
51 | with file_descriptor(f, mode="w") as fd:
52 | json.dump(obj, fd, **kwargs)
53 |
54 |
55 | def load_jsonl(f: Union[str, TextIO], **kwargs) -> Any:
56 | with file_descriptor(f, mode="r") as fd:
57 | return [json.loads(datum, **kwargs) for datum in fd.readlines()]
58 |
59 |
60 | def save_jsonl(f: Union[str, TextIO], obj: Any, **kwargs) -> None:
61 | with file_descriptor(f, mode="w") as fd:
62 | fd.write("\n".join(json.dumps(datum, **kwargs) for datum in obj))
63 |
64 |
65 | def load_mat(f: Union[str, BinaryIO], **kwargs) -> Any:
66 | import scipy.io
67 |
68 | return scipy.io.loadmat(f, **kwargs)
69 |
70 |
71 | def save_mat(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
72 | import scipy.io
73 |
74 | scipy.io.savemat(f, obj, **kwargs)
75 |
76 |
77 | def load_npy(f: Union[str, BinaryIO], **kwargs) -> Any:
78 | return np.load(f, **kwargs)
79 |
80 |
81 | def save_npy(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
82 | np.save(f, obj, **kwargs)
83 |
84 |
85 | def load_npz(f: Union[str, BinaryIO], **kwargs) -> Any:
86 | return np.load(f, **kwargs)
87 |
88 |
89 | def save_npz(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
90 | np.savez(f, obj, **kwargs)
91 |
92 |
93 | def load_pkl(f: Union[str, BinaryIO], **kwargs) -> Any:
94 | with file_descriptor(f, mode="rb") as fd:
95 | try:
96 | return pickle.load(fd, **kwargs)
97 | except UnicodeDecodeError:
98 | if "encoding" in kwargs:
99 | raise
100 | fd.seek(0)
101 | return pickle.load(fd, encoding="latin1", **kwargs)
102 |
103 |
104 | def save_pkl(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
105 | with file_descriptor(f, mode="wb") as fd:
106 | pickle.dump(obj, fd, **kwargs)
107 |
108 |
109 | def load_pt(f: Union[str, BinaryIO], **kwargs) -> Any:
110 | return torch.load(f, **kwargs)
111 |
112 |
113 | def save_pt(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None:
114 | torch.save(obj, f, **kwargs)
115 |
116 |
117 | def load_yaml(f: Union[str, TextIO]) -> Any:
118 | with file_descriptor(f, mode="r") as fd:
119 | return yaml.safe_load(fd)
120 |
121 |
122 | def save_yaml(f: Union[str, TextIO], obj: Any, **kwargs) -> None:
123 | with file_descriptor(f, mode="w") as fd:
124 | yaml.safe_dump(obj, fd, **kwargs)
125 |
126 |
127 | def load_txt(f: Union[str, TextIO]) -> Any:
128 | with file_descriptor(f, mode="r") as fd:
129 | return fd.read()
130 |
131 |
132 | def save_txt(f: Union[str, TextIO], obj: Any, **kwargs) -> None:
133 | with file_descriptor(f, mode="w") as fd:
134 | fd.write(obj)
135 |
136 |
137 | __io_registry: Dict[str, Dict[str, Callable]] = {
138 | ".txt": {"load": load_txt, "save": save_txt},
139 | ".json": {"load": load_json, "save": save_json},
140 | ".jsonl": {"load": load_jsonl, "save": save_jsonl},
141 | ".mat": {"load": load_mat, "save": save_mat},
142 | ".npy": {"load": load_npy, "save": save_npy},
143 | ".npz": {"load": load_npz, "save": save_npz},
144 | ".pkl": {"load": load_pkl, "save": save_pkl},
145 | ".pt": {"load": load_pt, "save": save_pt},
146 | ".pth": {"load": load_pt, "save": save_pt},
147 | ".pth.tar": {"load": load_pt, "save": save_pt},
148 | ".yaml": {"load": load_yaml, "save": save_yaml},
149 | ".yml": {"load": load_yaml, "save": save_yaml},
150 | }
151 |
152 |
153 | def load(fpath: str, **kwargs) -> Any:
154 | assert isinstance(fpath, str), type(fpath)
155 |
156 | for extension in sorted(__io_registry.keys(), key=len, reverse=True):
157 | if fpath.endswith(extension) and "load" in __io_registry[extension]:
158 | return __io_registry[extension]["load"](fpath, **kwargs)
159 |
160 | raise NotImplementedError(f'"{fpath}" cannot be loaded.')
161 |
162 |
163 | def save(fpath: str, obj: Any, **kwargs) -> None:
164 | assert isinstance(fpath, str), type(fpath)
165 |     os.makedirs(os.path.dirname(fpath) or ".", exist_ok=True)
166 |
167 | for extension in sorted(__io_registry.keys(), key=len, reverse=True):
168 | if fpath.endswith(extension) and "save" in __io_registry[extension]:
169 | __io_registry[extension]["save"](fpath, obj, **kwargs)
170 | return
171 |
172 | raise NotImplementedError(f'"{fpath}" cannot be saved.')
173 |
--------------------------------------------------------------------------------
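A short sketch of the extension-dispatched `load`/`save` entry points above; the output paths and record contents are illustrative.

    from vila_u.utils import io

    record = {"step": 100, "loss": 0.25}
    io.save("outputs/metrics.json", record, indent=2)        # dispatches to save_json
    assert io.load("outputs/metrics.json")["step"] == 100

    io.save("outputs/history.jsonl", [record, record])       # dispatches to save_jsonl
    rows = io.load("outputs/history.jsonl")                   # -> list of two dicts
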
/vila_u/utils/logging.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | if typing.TYPE_CHECKING:
4 | from loguru import Logger
5 | else:
6 | Logger = None
7 |
8 | __all__ = ["logger"]
9 |
10 |
11 | def __get_logger() -> Logger:
12 | from loguru import logger
13 |
14 | return logger
15 |
16 |
17 | logger = __get_logger()
--------------------------------------------------------------------------------
/vila_u/utils/media.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import glob
3 | import numpy as np
4 | import os
5 | import requests
6 | import PIL
7 | import PIL.Image
8 |
9 | from collections import defaultdict
10 | from transformers import PretrainedConfig
11 | from typing import Any, Dict, List, Optional, Union
12 |
13 | from vila_u.constants import DEFAULT_IMAGE_TOKEN
14 | from vila_u.media import Image, Video
15 | from vila_u.utils import make_list
16 | from vila_u.utils.logging import logger
17 |
18 | __all__ = ["extract_media"]
19 |
20 |
21 | def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
22 | if isinstance(image, Image):
23 | if image.path.startswith("http://") or image.path.startswith("https://"):
24 | image = PIL.Image.open(requests.get(image.path, stream=True).raw)
25 | else:
26 | image = PIL.Image.open(image.path)
27 | return image
28 |
29 |
30 | def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
31 | # Load video frames from a directory
32 | if os.path.isdir(video_path):
33 | frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
34 | indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
35 | return [PIL.Image.open(frame_paths[index]) for index in indices]
36 |
37 | # Load video frames from a video file
38 | vidcap = cv2.VideoCapture(video_path)
39 |
40 | # Find the last frame as frame count might not be accurate
41 | frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
42 | while frame_count > 0:
43 | vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
44 | if vidcap.grab():
45 | break
46 | frame_count -= 1
47 | else:
48 | raise ValueError(f"Video '{video_path}' has no frames.")
49 | vidcap.set(cv2.CAP_PROP_POS_FRAMES, 0)
50 |
51 | # Extract frames uniformly
52 | indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
53 | frames = {}
54 | for index in range(frame_count):
55 | success = vidcap.grab()
56 | if not success:
57 | raise ValueError(f"Failed to grab frame {index} from video '{video_path}'.")
58 | if index not in indices:
59 | continue
60 | success, frame = vidcap.retrieve()
61 | if not success:
62 | logger.warning(f"Failed to retrieve frame {index} from video '{video_path}'. Skipped.")
63 | continue
64 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65 | frames[index] = PIL.Image.fromarray(frame)
66 | return [frames[index] for index in indices if index in frames]
67 |
68 |
69 | def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]:
70 | num_frames = config.num_video_frames
71 |
72 | frames = _load_video(video.path, num_frames=num_frames)
73 | return frames
74 |
75 |
76 | def extract_media(
77 | messages: List[Dict[str, Any]],
78 | config: Optional[PretrainedConfig] = None,
79 | draft: bool = False,
80 | ) -> Dict[str, List[Any]]:
81 | media = defaultdict(list)
82 | for message in messages:
83 | text = ""
84 | for part in make_list(message["value"]):
85 | if isinstance(part, str):
86 | text += part
87 | elif isinstance(part, (Image, PIL.Image.Image)):
88 | if draft:
89 | media["image"].append(part)
90 | else:
91 | image = _extract_image(part)
92 | text += DEFAULT_IMAGE_TOKEN + "\n"
93 | media["image"].append(image)
94 | elif isinstance(part, Video):
95 | if draft:
96 | media["video"].append(part)
97 | else:
98 | video = _extract_video(part, config)
99 | text += (DEFAULT_IMAGE_TOKEN + "\n") * len(video)
100 | media["image"].extend(video)
101 | else:
102 | raise ValueError(f"Unsupported prompt part type: {type(part)}")
103 | message["value"] = text
104 |
105 | return media
--------------------------------------------------------------------------------
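A sketch of `extract_media` on a mixed text-and-image message, assuming `vila_u.media.Image` is constructed from a file path (as `_extract_image` suggests); `draft=True` defers decoding, so no file is actually opened here.

    from vila_u.media import Image
    from vila_u.utils.media import extract_media

    messages = [{"from": "human", "value": [Image("assets/example_image1.jpg"), "Describe this image."]}]
    media = extract_media(messages, draft=True)
    print(len(media["image"]))        # -> 1
    print(messages[0]["value"])       # media parts stripped; only the text remains
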
/vila_u/utils/tokenizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import transformers
3 |
4 | from typing import Any, Dict, List, Optional, Sequence
5 |
6 | from vila_u import conversation as conversation_lib
7 | from vila_u.constants import IGNORE_INDEX, SENTINEL_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_VI_START_TOKEN
8 | from vila_u.mm_utils import tokenizer_image_token
9 | from vila_u.utils.logging import logger
10 |
11 | __all__ = [
12 | "tokenize_conversation",
13 | ]
14 |
15 | DUMMY_CONVERSATION = [
16 | {"from": "human", "value": "question"},
17 | {"from": "gpt", "value": "answer"},
18 | ] * 10
19 |
20 | def tokenize_conversation(
21 | messages: Sequence[Dict[str, str]],
22 | tokenizer: transformers.PreTrainedTokenizer,
23 | add_generation_prompt: bool = False,
24 | overrides: Optional[Dict[str, str]] = None,
25 | no_system_prompt: bool = False,
26 | image_generation: bool = False,
27 | video_generation: bool = False,
28 | ) -> torch.Tensor:
29 | for message in messages:
30 | message["value"] = message["value"].strip()
31 |
32 | conv = conversation_lib.default_conversation.copy()
33 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
34 |
35 | if no_system_prompt:
36 | conv.system = ""
37 |
38 | # Skip the first message if it is not from human
39 | if messages[0]["from"] != "human":
40 | messages = messages[1:]
41 |
42 | # Add a generation prompt if needed
43 | if add_generation_prompt:
44 | messages.append({"from": "gpt", "value": None})
45 |
46 | conv.messages = []
47 | for turn, message in enumerate(messages):
48 | role = roles[message["from"]]
49 | assert role == conv.roles[turn % 2]
50 | if overrides is not None and message["from"] in overrides:
51 | conv.append_message(role, overrides[message["from"]])
52 | else:
53 | conv.append_message(role, message["value"])
54 |
55 | prompt = conv.get_prompt()
56 | if image_generation:
57 | prompt += f" {DEFAULT_IM_START_TOKEN}"
58 | elif video_generation:
59 | prompt += f" {DEFAULT_VI_START_TOKEN}"
60 | else:
61 | pass
62 |
63 | return tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
64 |
65 | def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
66 | if not hasattr(tokenizer, "sentinel_token"):
67 | tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
68 | tokenizer.sentinel_token = SENTINEL_TOKEN
69 | tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
70 |
71 | def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
72 | _maybe_add_sentinel_token(tokenizer)
73 | template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
74 |
75 | stop_tokens = {tokenizer.eos_token}
76 | for k in range(template.size(0) - 1):
77 | if template[k] == tokenizer.sentinel_token_id:
78 | stop_token = tokenizer.decode(template[k + 1])
79 | stop_tokens.add(stop_token)
80 | return list(stop_tokens)
81 |
--------------------------------------------------------------------------------
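A sketch of `tokenize_conversation` and `infer_stop_tokens`; the checkpoint path is a placeholder, and the prompt template comes from the default conversation in `vila_u.conversation`.

    from transformers import AutoTokenizer
    from vila_u.utils.tokenizer import tokenize_conversation, infer_stop_tokens

    tokenizer = AutoTokenizer.from_pretrained("path/to/vila-u-checkpoint")
    messages = [{"from": "human", "value": "What is in this image?"}]
    input_ids = tokenize_conversation(messages, tokenizer, add_generation_prompt=True)
    stop_tokens = infer_stop_tokens(tokenizer)    # always includes tokenizer.eos_token
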
/vila_u/utils/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List
2 |
3 | __all__ = ["make_list", "disable_torch_init"]
4 |
5 |
6 | def make_list(obj: Any) -> List:
7 | return obj if isinstance(obj, list) else [obj]
8 |
9 |
10 | def disable_torch_init():
11 | """
12 | Disable the redundant torch default initialization to accelerate model creation.
13 | """
14 | import torch
15 |
16 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
17 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
--------------------------------------------------------------------------------
/vila_u/wids/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
2 | # This file is part of the WebDataset library.
3 | # See the LICENSE file for licensing terms (BSD-style).
4 | #
5 | # flake8: noqa
6 |
7 | from .wids import (
8 | ChunkedSampler,
9 | DistributedChunkedSampler,
10 | ShardedSampler,
11 | ShardListDataset,
12 | DistributedLocalSampler,
13 | )
14 |
--------------------------------------------------------------------------------
/vila_u/wids/wids_bench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | from . import wids
5 | from webdataset import WebDataset  # this vendored wids package has no local `compat` module
6 |
7 |
8 | def main_wids(args):
9 | desc = json.load(open(args.dataset))
10 | files = desc["files"]
11 | dataset = wids.ShardListDataset(files, cache_size=4)
12 | print(len(dataset))
13 | for i in range(len(dataset)):
14 | print(i, dataset[i]["__key__"])
15 | dataset.close()
16 |
17 |
18 | def main_wds(args):
19 | desc = json.load(open(args.dataset))
20 | files = desc["files"]
21 | urls = [f["url"] for f in files]
22 | dataset = WebDataset(urls)
23 | for i, sample in enumerate(dataset):
24 | print(i, sample["__key__"])
25 |
26 |
27 | if __name__ == "__main__":
28 | parser = argparse.ArgumentParser()
29 | # there are two subcommands: wids and wds
30 | subparsers = parser.add_subparsers(dest="command")
31 | wids_parser = subparsers.add_parser("wids")
32 | wds_parser = subparsers.add_parser("wds")
33 |
34 | # wids subcommand
35 | wids_parser.add_argument("dataset", help="dataset name")
36 |
37 | # wds subcommand
38 | wds_parser.add_argument("dataset", help="dataset name")
39 |
40 | args = parser.parse_args()
41 |
42 | if args.command == "wids":
43 | main_wids(args)
44 | elif args.command == "wds":
45 | main_wds(args)
46 | else:
47 | raise ValueError(f"Unknown command: {args.command}")
48 |
--------------------------------------------------------------------------------
/vila_u/wids/wids_cleanup.py:
--------------------------------------------------------------------------------
1 | """
2 | This module provides utilities for managing files in a directory.
3 |
4 | It includes a function `keep_most_recent_files` that keeps the most recent
5 | files in a directory, deleting the rest based on the maximum size of the directory
6 | in bytes and the maximum number of files to keep.
7 |
8 | The cleanup job can be run in the background using `create_cleanup_background_process`.
9 | """
10 |
11 | import fcntl
12 | import glob
13 | import os
14 | import time
15 |
16 | import numpy as np
17 |
18 |
19 | def keep_most_recent_files(pattern, maxsize=int(1e12), maxfiles=1000, debug=False):
20 | """Keep the most recent files in a directory, deleting the rest.
21 |
22 | The maxsize is the maximum size of the directory in bytes. The maxfiles is
23 | the maximum number of files to keep. The files are sorted by modification
24 | time, and the most recent files are kept. If the directory is already
25 | smaller than maxsize, then no files are deleted. If there are fewer than
26 | maxfiles, then no files are deleted."""
27 |
28 | # get the list of files in the directory
29 | fnames = glob.glob(pattern)
30 | # compute a list of (mtime, fname, size) triples
31 | files = []
32 | for fname in fnames:
33 | try:
34 | s = os.stat(fname)
35 | except FileNotFoundError:
36 | continue
37 | files.append((s.st_mtime, fname, s.st_size))
38 | # sort the list by mtime, most recent first
39 | files.sort(reverse=True)
40 | # compute an accumulated total of the file sizes in order using np.cumsum
41 | sizes = np.cumsum([size for mtime, fname, size in files])
42 | # compute a cutoff index based on maxsize
43 | cutoff = np.searchsorted(sizes, maxsize)
44 | # compute a cutoff index based on maxfiles
45 | cutoff = min(cutoff, maxfiles)
46 | # delete the files above the cutoff in reverse order
47 | for mtime, fname, size in files[cutoff:][::-1]:
48 | try:
49 | os.unlink(fname)
50 | except FileNotFoundError:
51 | pass
52 |
53 |
54 | class ExclusiveLock:
55 | """A simple non-blocking exclusive lock using fcntl."""
56 |
57 | def __init__(self, lockfile):
58 | self.lockfile = lockfile
59 |
60 | def try_lock(self):
61 | try:
62 | self.lock = open(self.lockfile, "w")
63 | fcntl.flock(self.lock.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
64 | return True
65 |         except OSError as e:
66 |             import errno  # local import: `errno` is not imported at module level
67 |             if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
68 |                 return False
69 |             raise
70 |
71 | def release_lock(self):
72 | self.lock.close()
73 | os.unlink(self.lockfile)
74 |
75 |
76 | def create_cleanup_background_process(
77 | pattern, maxsize=int(1e12), maxfiles=1000, every=60
78 | ):
79 | """Create a background process that keeps a directory below a certain size."""
80 |
81 | def cleanup_worker(every):
82 | # use a lock file to ensure that only one cleanup worker is running
83 | lockfile = os.path.join(os.path.dirname(pattern), ".cleanup.lock")
84 | lock = ExclusiveLock(lockfile)
85 | if not lock.try_lock():
86 | return
87 | while True:
88 | keep_most_recent_files(pattern, maxsize=maxsize, maxfiles=maxfiles)
89 | time.sleep(every)
90 |
91 | import multiprocessing
92 |
93 | p = multiprocessing.Process(target=cleanup_worker, args=(every,))
94 | p.start()
95 | return p
96 |
--------------------------------------------------------------------------------
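A sketch of trimming a shard cache with the helpers above; the cache directory, size budget, and interval are illustrative.

    from vila_u.wids.wids_cleanup import keep_most_recent_files, create_cleanup_background_process

    # One-off trim: keep at most ~10 GB or the 500 most recently touched shards.
    keep_most_recent_files("/tmp/wids_cache/*.tar", maxsize=int(1e10), maxfiles=500)

    # Or keep trimming every five minutes in a background process.
    proc = create_cleanup_background_process("/tmp/wids_cache/*.tar", maxsize=int(1e10), maxfiles=500, every=300)
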
/vila_u/wids/wids_dir.py:
--------------------------------------------------------------------------------
1 | """
2 | # dynamically create a shard index
3 | class DirectoryDataset(ShardListDataset):
4 | def __init__(self, directory):
5 | pass
6 | """
7 |
8 | """
9 | # randomly choose shards from a directory
10 | class DirectoryQueueDataset(IterableDataset):
11 | def __init__(self, directory, strategy="replace", choice="random", downloader=None, transformations="PIL"):
12 | pass
13 | def add_transform(self, transform):
14 | pass
15 | def __iter__(self):
16 | # pick file according to strategy
17 | # rename file to .active
18 | # randomly yield samples from file
19 | # rename file back to its original name or unlink it, according to strategy
20 | pass
21 | """
22 |
--------------------------------------------------------------------------------
/vila_u/wids/wids_dl.py:
--------------------------------------------------------------------------------
1 | import fcntl
2 | import os
3 | import shutil
4 | import sys
5 | import time
6 | from collections import deque
7 | from datetime import datetime
8 | from urllib.parse import urlparse
9 |
10 | recent_downloads = deque(maxlen=1000)
11 |
12 | open_objects = {}
13 | max_open_objects = 100
14 |
15 |
16 | class ULockFile:
17 | """A simple locking class. We don't need any of the third
18 | party libraries since we rely on POSIX semantics for linking
19 | below anyway."""
20 |
21 | def __init__(self, path):
22 | self.lockfile_path = path
23 | self.lockfile = None
24 |
25 | def __enter__(self):
26 | self.lockfile = open(self.lockfile_path, "w")
27 | fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_EX)
28 | return self
29 |
30 | def __exit__(self, exc_type, exc_val, exc_tb):
31 | fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_UN)
32 | self.lockfile.close()
33 | self.lockfile = None
34 | try:
35 | os.unlink(self.lockfile_path)
36 | except FileNotFoundError:
37 | pass
38 |
39 |
40 | def pipe_download(remote, local):
41 | """Perform a download for a pipe: url."""
42 | assert remote.startswith("pipe:")
43 | cmd = remote[5:]
44 | cmd = cmd.format(local=local)
45 | assert os.system(cmd) == 0, "Command failed: %s" % cmd
46 |
47 |
48 | def copy_file(remote, local):
49 | remote = urlparse(remote)
50 | assert remote.scheme in ["file", ""]
51 | # use absolute path
52 | remote = os.path.abspath(remote.path)
53 | local = urlparse(local)
54 | assert local.scheme in ["file", ""]
55 | local = os.path.abspath(local.path)
56 | if remote == local:
57 | return
58 | # check if the local file exists
59 | shutil.copyfile(remote, local)
60 |
61 |
62 | verbose_cmd = int(os.environ.get("WIDS_VERBOSE_CMD", "0"))
63 |
64 |
65 | def vcmd(flag, verbose_flag=""):
66 | return verbose_flag if verbose_cmd else flag
67 |
68 |
69 | default_cmds = {
70 | "posixpath": copy_file,
71 | "file": copy_file,
72 | "pipe": pipe_download,
73 | "http": "curl " + vcmd("-s") + " -L {url} -o {local}",
74 | "https": "curl " + vcmd("-s") + " -L {url} -o {local}",
75 | "ftp": "curl " + vcmd("-s") + " -L {url} -o {local}",
76 | "ftps": "curl " + vcmd("-s") + " -L {url} -o {local}",
77 | "gs": "gsutil " + vcmd("-q") + " cp {url} {local}",
78 | "s3": "aws s3 cp {url} {local}",
79 | }
80 |
81 | #TODO(ligeng): change HTTPS download to python requests library
82 |
83 | def download_file_no_log(remote, local, handlers=default_cmds):
84 | """Download a file from a remote url to a local path.
85 | The remote url can be a pipe: url, in which case the remainder of
86 | the url is treated as a command template that is executed to perform the download.
87 | """
88 |
89 | if remote.startswith("pipe:"):
90 | schema = "pipe"
91 | else:
92 | schema = urlparse(remote).scheme
93 | if schema is None or schema == "":
94 | schema = "posixpath"
95 | # get the handler
96 | handler = handlers.get(schema)
97 | if handler is None:
98 | raise ValueError("Unknown schema: %s" % schema)
99 | # call the handler
100 | if callable(handler):
101 | handler(remote, local)
102 | else:
103 | assert isinstance(handler, str)
104 | cmd = handler.format(url=remote, local=local)
105 | assert os.system(cmd) == 0, "Command failed: %s" % cmd
106 | return local
107 |
108 |
109 | def download_file(remote, local, handlers=default_cmds, verbose=False):
110 | start = time.time()
111 | try:
112 | return download_file_no_log(remote, local, handlers=handlers)
113 | finally:
114 | recent_downloads.append((remote, local, time.time(), time.time() - start))
115 | if verbose:
116 | print(
117 | "downloaded",
118 | remote,
119 | "to",
120 | local,
121 | "in",
122 | time.time() - start,
123 | "seconds",
124 | file=sys.stderr,
125 | )
126 |
127 |
128 | def download_and_open(remote, local, mode="rb", handlers=default_cmds, verbose=False):
129 | with ULockFile(local + ".lock"):
130 | if os.path.exists(remote):
131 | # print("enter1", remote, local, mode)
132 | result = open(remote, mode)
133 | else:
134 | # print("enter2", remote, local, mode)
135 | if not os.path.exists(local):
136 | if verbose:
137 | print("downloading", remote, "to", local, file=sys.stderr)
138 | download_file(remote, local, handlers=handlers)
139 | else:
140 | if verbose:
141 | print("using cached", local, file=sys.stderr)
142 | result = open(local, mode)
143 |
144 | # input()
145 |
146 | if open_objects is not None:
147 | for k, v in list(open_objects.items()):
148 | if v.closed:
149 | del open_objects[k]
150 | if len(open_objects) > max_open_objects:
151 | raise RuntimeError("Too many open objects")
152 | current_time = datetime.now().strftime("%Y%m%d%H%M%S")
153 | key = tuple(str(x) for x in [remote, local, mode, current_time])
154 | open_objects[key] = result
155 | return result
156 |
--------------------------------------------------------------------------------
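A sketch of the download helpers above; the URL and cache path are placeholders, and note that http/https URLs are fetched by shelling out to `curl`.

    from vila_u.wids.wids_dl import download_and_open

    remote = "https://example.com/shards/shard-000000.tar"
    local = "/tmp/shard-000000.tar"

    # Downloads (or reuses the cached copy) under a lock file, then opens it.
    with download_and_open(remote, local, mode="rb", verbose=True) as f:
        header = f.read(512)
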
/vila_u/wids/wids_lru.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 |
4 | class LRUCache:
5 | def __init__(self, capacity: int, release_handler=None):
6 | """Initialize a new LRU cache with the given capacity."""
7 | self.capacity = capacity
8 | self.cache = OrderedDict()
9 | self.release_handler = release_handler
10 |
11 | def __getitem__(self, key):
12 | """Return the value associated with the given key, or None."""
13 | if key not in self.cache:
14 | return None
15 | self.cache.move_to_end(key)
16 | return self.cache[key]
17 |
18 | def __setitem__(self, key, value):
19 | """Associate the given value with the given key."""
20 | if key in self.cache:
21 | self.cache.move_to_end(key)
22 | self.cache[key] = value
23 | if len(self.cache) > self.capacity:
24 | key, value = self.cache.popitem(last=False)
25 | if self.release_handler is not None:
26 | self.release_handler(key, value)
27 |
28 | def __delitem__(self, key):
29 | """Remove the given key from the cache."""
30 | if key in self.cache:
31 | if self.release_handler is not None:
32 | value = self.cache[key]
33 | self.release_handler(key, value)
34 | del self.cache[key]
35 |
36 | def __len__(self):
37 | """Return the number of entries in the cache."""
38 | return len(self.cache)
39 |
40 | def __contains__(self, key):
41 | """Return whether the cache contains the given key."""
42 | return key in self.cache
43 |
44 | def items(self):
45 |         """Return an iterator over the (key, value) pairs of the cache."""
46 | return self.cache.items()
47 |
48 | def keys(self):
49 | """Return an iterator over the keys of the cache."""
50 | return self.cache.keys()
51 |
52 | def values(self):
53 | """Return an iterator over the values of the cache."""
54 | return self.cache.values()
55 |
56 | def clear(self):
57 | for key in list(self.keys()):
58 | value = self.cache[key]
59 | if self.release_handler is not None:
60 | self.release_handler(key, value)
61 | del self[key]
62 |
63 | def __del__(self):
64 | self.clear()
65 |
--------------------------------------------------------------------------------
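A sketch of `LRUCache` with a release handler that closes evicted values, a pattern that fits open file-like objects; the payloads are illustrative.

    import io
    from vila_u.wids.wids_lru import LRUCache

    cache = LRUCache(capacity=2, release_handler=lambda key, value: value.close())
    cache["a"] = io.BytesIO(b"shard a")
    cache["b"] = io.BytesIO(b"shard b")
    cache["c"] = io.BytesIO(b"shard c")    # evicts "a" and the handler closes it
    assert "a" not in cache and len(cache) == 2
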
/vila_u/wids/wids_mmtar.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import fcntl
3 | import io
4 | import mmap
5 | import os
6 | import struct
7 |
8 | TarHeader = collections.namedtuple(
9 | "TarHeader",
10 | [
11 | "name",
12 | "mode",
13 | "uid",
14 | "gid",
15 | "size",
16 | "mtime",
17 | "chksum",
18 | "typeflag",
19 | "linkname",
20 | "magic",
21 | "version",
22 | "uname",
23 | "gname",
24 | "devmajor",
25 | "devminor",
26 | "prefix",
27 | ],
28 | )
29 |
30 |
31 | def parse_tar_header(header_bytes):
32 | header = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes)
33 | return TarHeader(*header)
34 |
35 |
36 | def next_header(offset, header):
37 | block_size = 512
38 | size = header.size.decode("utf-8").strip("\x00")
39 | if size == "":
40 | return -1
41 | size = int(size, 8)
42 | # compute the file size rounded up to the next block size if it is a partial block
43 | padded_file_size = (size + block_size - 1) // block_size * block_size
44 | return offset + block_size + padded_file_size
45 |
46 |
47 | # TODO(ligeng): support gzip stream
48 | class MMIndexedTar:
49 | def __init__(self, fname, index_file=None, verbose=True, cleanup_callback=None):
50 | self.verbose = verbose
51 | self.cleanup_callback = cleanup_callback
52 | if isinstance(fname, str):
53 | self.stream = open(fname, "rb")
54 | self.fname = fname
55 | elif isinstance(fname, io.IOBase):
56 | self.stream = fname
57 | self.fname = None
58 | self.mmapped_file = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ)
59 | if cleanup_callback:
60 | cleanup_callback(fname, self.stream.fileno(), "start")
61 | self._build_index()
62 |
63 | def close(self, dispose=False):
64 | if self.cleanup_callback:
65 | self.cleanup_callback(self.fname, self.stream.fileno(), "end")
66 | self.mmapped_file.close()
67 | self.stream.close()
68 |
69 | def _build_index(self):
70 | self.by_name = {}
71 | self.by_index = []
72 | offset = 0
73 | while offset >= 0 and offset < len(self.mmapped_file):
74 | header = parse_tar_header(self.mmapped_file[offset : offset + 500])
75 | name = header.name.decode("utf-8").strip("\x00")
76 | typeflag = header.typeflag.decode("utf-8").strip("\x00")
77 | if name != "" and name != "././@PaxHeader" and typeflag in ["0", ""]:
78 | try:
79 | size = int(header.size.decode("utf-8")[:-1], 8)
80 | except ValueError as exn:
81 | print(header)
82 | raise exn
83 | self.by_name[name] = offset
84 | self.by_index.append((name, offset, size))
85 | offset = next_header(offset, header)
86 |
87 | def names(self):
88 | return self.by_name.keys()
89 |
90 | def get_at_offset(self, offset):
91 | header = parse_tar_header(self.mmapped_file[offset : offset + 500])
92 | name = header.name.decode("utf-8").strip("\x00")
93 | start = offset + 512
94 | end = start + int(header.size.decode("utf-8")[:-1], 8)
95 | return name, self.mmapped_file[start:end]
96 |
97 | def get_at_index(self, index):
98 | name, offset, size = self.by_index[index]
99 | return self.get_at_offset(offset)
100 |
101 | def get_by_name(self, name):
102 | offset = self.by_name[name]
103 | return self.get_at_offset(offset)
104 |
105 | def __iter__(self):
106 | for name, offset, size in self.by_index:
107 | yield name, self.mmapped_file[offset + 512 : offset + 512 + size]
108 |
109 | def __getitem__(self, key):
110 | if isinstance(key, int):
111 | return self.get_at_index(key)
112 | else:
113 | return self.get_by_name(key)
114 |
115 | def __len__(self):
116 | return len(self.by_index)
117 |
118 | def get_file(self, i):
119 | fname, data = self.get_at_index(i)
120 | return fname, io.BytesIO(data)
121 |
122 |
123 | def keep_while_reading(fname, fd, phase, delay=0.0):
124 |     """This is a possible cleanup_callback for MMIndexedTar.
125 |
126 | It assumes that as long as there are some readers for a file,
127 | more readers may be trying to open it.
128 |
129 | Note that on Linux, unlinking the file doesn't matter after
130 | it has been mmapped. The contents will only be deleted when
131 | all readers close the file. The unlinking merely makes the file
132 | unavailable to new readers, since the downloader checks first
133 | whether the file exists.
134 | """
135 | assert delay == 0.0, "delay not implemented"
136 | if fd < 0 or fname is None:
137 | return
138 | if phase == "start":
139 | fcntl.flock(fd, fcntl.LOCK_SH)
140 | elif phase == "end":
141 | try:
142 | fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
143 | os.unlink(fname)
144 | except FileNotFoundError:
145 | # someone else deleted it already
146 | pass
147 | except BlockingIOError:
148 | # we couldn't get an exclusive lock, so someone else is still reading
149 | pass
150 | else:
151 | raise ValueError(f"Unknown phase {phase}")
152 |
--------------------------------------------------------------------------------
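A sketch of random access through `MMIndexedTar`; the tar path is a placeholder.

    from vila_u.wids.wids_mmtar import MMIndexedTar

    tar = MMIndexedTar("/tmp/shard-000000.tar")
    print(len(tar), list(tar.names())[:3])
    name, data = tar[0]               # by index -> (member name, raw bytes)
    name, stream = tar.get_file(0)    # same member wrapped in a BytesIO
    tar.close()
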
/vila_u/wids/wids_specs.py:
--------------------------------------------------------------------------------
1 | import io
2 | import json
3 | import os
4 | import tempfile
5 | from urllib.parse import urlparse, urlunparse
6 |
7 | from .wids_dl import download_and_open
8 |
9 |
10 | def urldir(url):
11 | """Return the directory part of a url."""
12 | parsed_url = urlparse(url)
13 | path = parsed_url.path
14 | directory = os.path.dirname(path)
15 | return parsed_url._replace(path=directory).geturl()
16 |
17 |
18 | def urlmerge(base, url):
19 | """Merge a base URL and a relative URL.
20 |
21 | The function fills in any missing part of the url from the base,
22 | except for params, query, and fragment, which are taken only from the 'url'.
23 | For the pathname component, it merges the paths like os.path.join:
24 | an absolute path in 'url' overrides the base path, otherwise the paths are merged.
25 |
26 | Parameters:
27 | base (str): The base URL.
28 | url (str): The URL to merge with the base.
29 |
30 | Returns:
31 | str: The merged URL.
32 | """
33 | # Parse the base and the relative URL
34 | parsed_base = urlparse(base)
35 | parsed_url = urlparse(url)
36 |
37 | # Merge paths using os.path.join
38 | # If the url path is absolute, it overrides the base path
39 | if parsed_url.path.startswith("/"):
40 | merged_path = parsed_url.path
41 | else:
42 | merged_path = os.path.normpath(os.path.join(parsed_base.path, parsed_url.path))
43 |
44 | # Construct the merged URL
45 | merged_url = urlunparse(
46 | (
47 | parsed_url.scheme or parsed_base.scheme,
48 | parsed_url.netloc or parsed_base.netloc,
49 | merged_path,
50 | parsed_url.params, # Use params from the url only
51 | parsed_url.query, # Use query from the url only
52 | parsed_url.fragment, # Use fragment from the url only
53 | )
54 | )
55 |
56 | return merged_url
57 |
58 |
59 | def check_shards(l):
60 | """Check that a list of shards is well-formed.
61 |
62 | This checks that the list is a list of dictionaries, and that
63 | each dictionary has a "url" and a "nsamples" key.
64 | """
65 | assert isinstance(l, list)
66 | for shard in l:
67 | assert isinstance(shard, dict)
68 | assert "url" in shard
69 | assert "nsamples" in shard
70 | return l
71 |
72 |
73 | def set_all(l, k, v):
74 | """Set a key to a value in a list of dictionaries."""
75 | if v is None:
76 | return
77 | for x in l:
78 | if k not in x:
79 | x[k] = v
80 |
81 |
82 | def load_remote_dsdesc_raw(source):
83 | """Load a remote or local dataset description in JSON format."""
84 | if isinstance(source, str):
85 | with tempfile.TemporaryDirectory() as tmpdir:
86 | dlname = os.path.join(tmpdir, "dataset.json")
87 | with download_and_open(source, dlname) as f:
88 | dsdesc = json.load(f)
89 | elif isinstance(source, io.IOBase):
90 | dsdesc = json.load(source)
91 | else:
92 | # FIXME: use gopen
93 | import requests
94 |
95 | jsondata = requests.get(source).text
96 | dsdesc = json.loads(jsondata)
97 | return dsdesc
98 |
99 |
100 | def rebase_shardlist(shardlist, base):
101 | """Rebase the URLs in a shardlist."""
102 | if base is None:
103 | return shardlist
104 | for shard in shardlist:
105 | shard["url"] = urlmerge(base, shard["url"])
106 | return shardlist
107 |
108 |
109 | def resolve_dsdesc(dsdesc, *, options=None, base=None):
110 | """Resolve a dataset description.
111 |
112 | This rebases the shards as necessary and loads any remote references.
113 |
114 |     Dataset descriptions are JSON files. They must have the following format:
115 |
116 | {
117 | "wids_version": 1,
118 | # optional immediate shardlist
119 | "shardlist": [
120 | {"url": "http://example.com/file.tar", "nsamples": 1000},
121 | ...
122 | ],
123 | # sub-datasets
124 | "datasets": [
125 | {"source_url": "http://example.com/dataset.json"},
126 | {"shardlist": [
127 | {"url": "http://example.com/file.tar", "nsamples": 1000},
128 | ...
129 | ]}
130 | ...
131 | ]
132 | }
133 | """
134 | if options is None:
135 | options = {}
136 | assert isinstance(dsdesc, dict)
137 | dsdesc = dict(dsdesc, **options)
138 | shardlist = rebase_shardlist(dsdesc.get("shardlist", []), base)
139 | assert shardlist is not None
140 | set_all(shardlist, "weight", dsdesc.get("weight"))
141 | set_all(shardlist, "name", dsdesc.get("name"))
142 | check_shards(shardlist)
143 | assert "wids_version" in dsdesc, "No wids_version in dataset description"
144 | assert dsdesc["wids_version"] == 1, "Unknown wids_version"
145 | for component in dsdesc.get("datasets", []):
146 | # we use the weight from the reference to the dataset,
147 | # regardless of remote loading
148 | weight = component.get("weight")
149 | # follow any source_url dsdescs through remote loading
150 | source_url = None
151 | if "source_url" in component:
152 | source_url = component["source_url"]
153 | component = load_remote_dsdesc_raw(source_url)
154 | assert (
155 | "source_url" not in component
156 | ), "double indirection in dataset description"
157 | assert "shardlist" in component, "no shardlist in dataset description"
158 | # if the component has a base, use it to rebase the shardlist
159 | # otherwise use the base from the source_url, if any
160 | subbase = component.get("base", urldir(source_url) if source_url else None)
161 | if subbase is not None:
162 | rebase_shardlist(component["shardlist"], subbase)
163 | l = check_shards(component["shardlist"])
164 | set_all(l, "weight", weight)
165 | set_all(l, "source_url", source_url)
166 | set_all(l, "dataset", component.get("name"))
167 | shardlist.extend(l)
168 | assert len(shardlist) > 0, "No shards found"
169 | dsdesc["shardlist"] = shardlist
170 | return dsdesc
171 |
172 |
173 | def load_dsdesc_and_resolve(source, *, options=None, base=None):
174 | if options is None:
175 | options = {}
176 | dsdesc = load_remote_dsdesc_raw(source)
177 | return resolve_dsdesc(dsdesc, base=base, options=options)
178 |
--------------------------------------------------------------------------------
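A sketch of resolving a minimal dataset description with `load_dsdesc_and_resolve`; the shard URLs, sample counts, and base URL are illustrative.

    import json
    from vila_u.wids.wids_specs import load_dsdesc_and_resolve

    spec = {
        "wids_version": 1,
        "shardlist": [
            {"url": "shards/shard-000000.tar", "nsamples": 1000},
            {"url": "shards/shard-000001.tar", "nsamples": 1000},
        ],
    }
    with open("/tmp/dataset.json", "w") as f:
        json.dump(spec, f)

    resolved = load_dsdesc_and_resolve("/tmp/dataset.json", base="https://example.com/data/")
    print([shard["url"] for shard in resolved["shardlist"]])
    # -> ['https://example.com/data/shards/shard-000000.tar', ...]
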
/vila_u/wids/wids_tar.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import os.path
4 | import pickle
5 | import re
6 | import tarfile
7 |
8 | import numpy as np
9 |
10 |
11 | def find_index_file(file):
12 | prefix, last_ext = os.path.splitext(file)
13 | if re.match("._[0-9]+_$", last_ext):
14 | return prefix + ".index"
15 | else:
16 | return file + ".index"
17 |
18 |
19 | class TarFileReader:
20 | def __init__(self, file, index_file=find_index_file, verbose=True):
21 | self.verbose = verbose
22 | if callable(index_file):
23 | index_file = index_file(file)
24 | self.index_file = index_file
25 |
26 | # Open the tar file and keep it open
27 | if isinstance(file, str):
28 | self.tar_file = tarfile.open(file, "r")
29 | else:
30 | self.tar_file = tarfile.open(fileobj=file, mode="r")
31 |
32 | # Create the index
33 | self._create_tar_index()
34 |
35 | def _create_tar_index(self):
36 | if self.index_file is not None and os.path.exists(self.index_file):
37 | if self.verbose:
38 | print("Loading tar index from", self.index_file)
39 | with open(self.index_file, "rb") as stream:
40 | self.fnames, self.index = pickle.load(stream)
41 | return
42 | # Create an empty list for the index
43 | self.fnames = []
44 | self.index = []
45 |
46 | if self.verbose:
47 | print("Creating tar index for", self.tar_file.name, "at", self.index_file)
48 | # Iterate over the members of the tar file
49 | for member in self.tar_file:
50 | # If the member is a file, add it to the index
51 | if member.isfile():
52 | # Get the file's offset
53 | offset = self.tar_file.fileobj.tell()
54 | self.fnames.append(member.name)
55 | self.index.append([offset, member.size])
56 | if self.verbose:
57 | print(
58 | "Done creating tar index for", self.tar_file.name, "at", self.index_file
59 | )
60 | self.index = np.array(self.index)
61 | if self.index_file is not None:
62 | if os.path.exists(self.index_file + ".temp"):
63 | os.unlink(self.index_file + ".temp")
64 | with open(self.index_file + ".temp", "wb") as stream:
65 | pickle.dump((self.fnames, self.index), stream)
66 | os.rename(self.index_file + ".temp", self.index_file)
67 |
68 | def names(self):
69 | return self.fnames
70 |
71 | def __len__(self):
72 | return len(self.index)
73 |
74 | def get_file(self, i):
75 | name = self.fnames[i]
76 | offset, size = self.index[i]
77 | self.tar_file.fileobj.seek(offset)
78 | file_bytes = self.tar_file.fileobj.read(size)
79 | return name, io.BytesIO(file_bytes)
80 |
81 | def close(self):
82 | # Close the tar file
83 | self.tar_file.close()
84 |
--------------------------------------------------------------------------------
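A sketch of `TarFileReader`, which builds and caches a pickle index next to the tar file on first use; the path is a placeholder.

    from vila_u.wids.wids_tar import TarFileReader

    reader = TarFileReader("/tmp/shard-000000.tar")    # writes /tmp/shard-000000.tar.index
    name, stream = reader.get_file(0)                  # random access by position
    payload = stream.read()
    reader.close()
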