├── LICENSE ├── README.md ├── app.py ├── assets ├── example_image1.jpg ├── example_image2.jpg ├── example_video1.mp4 ├── example_video2.mp4 ├── method1.jpg └── method2.jpg ├── environment_setup.sh ├── inference.py ├── pyproject.toml ├── scripts ├── eval │ └── lmms.sh ├── train │ ├── pretrain.sh │ └── sft.sh └── zero2.json └── vila_u ├── __init__.py ├── cli └── eval.py ├── constants.py ├── conversation.py ├── data ├── __init__.py ├── dataset.py ├── datasets_mixture.py └── simple_vila_webdataset.py ├── entry.py ├── eval ├── __init__.py ├── lmms │ ├── models │ │ ├── __init__.py │ │ └── vila_u.py │ └── tasks │ │ └── __init__.py └── registry.yaml ├── media.py ├── mm_utils.py ├── model ├── __init__.py ├── builder.py ├── configuration_vila_u.py ├── language_model │ ├── builder.py │ └── vila_u_llama.py ├── multimodal_encoder │ ├── builder.py │ ├── rqvaesigliptransformer │ │ ├── __init__.py │ │ ├── configuration_rqvaesigliptransformer.py │ │ ├── modeling_rqvaesigliptransformer.py │ │ ├── rqtransformer │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── configuration_rqtransformer.py │ │ │ └── modeling_rqtransformer.py │ │ └── rqvaesiglip │ │ │ ├── __init__.py │ │ │ ├── configuration_rqvaesiglip.py │ │ │ ├── modeling_rqvaesiglip.py │ │ │ ├── modules.py │ │ │ ├── quantizations.py │ │ │ └── siglip │ │ │ ├── __init__.py │ │ │ ├── configuration_siglip.py │ │ │ ├── convert_siglip_to_hf.py │ │ │ ├── image_processing_siglip.py │ │ │ ├── modeling_siglip.py │ │ │ ├── processing_siglip.py │ │ │ └── tokenization_siglip.py │ └── rqvaesigliptransformer_encoder.py ├── multimodal_projector │ ├── base_projector.py │ └── builder.py ├── utils.py └── vila_u_arch.py ├── train ├── args.py ├── callbacks │ └── autoresume_callback.py ├── train.py ├── train_mem.py ├── transformer_normalize_monkey_patch.py ├── transformers_replace │ ├── generation │ │ └── utils.py │ └── models │ │ └── llama │ │ ├── configuring_llama.py │ │ ├── modeling_llama.py │ │ └── tokenization_llama.py ├── utils.py └── vila_u_trainer.py ├── utils ├── __init__.py ├── distributed.py ├── io.py ├── logging.py ├── media.py ├── tokenizer.py └── utils.py └── wids ├── __init__.py ├── wids.py ├── wids_bench.py ├── wids_cleanup.py ├── wids_dir.py ├── wids_dl.py ├── wids_index.py ├── wids_lru.py ├── wids_mmtar.py ├── wids_specs.py └── wids_tar.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 MIT HAN Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VILA-U: a Unified Foundation Model Integrating Visual Understanding and Generation 2 | 3 | \[[Online Demo](https://vila-u.hanlab.ai)\] \[[Paper](https://arxiv.org/abs/2409.04429#)\] \[[Project](https://hanlab.mit.edu/projects/vila-u)\] \[[Models](https://huggingface.co/collections/mit-han-lab/vila-u-7b-6716f7dd5331e4bdf944ffa6)\] 4 | 5 | 6 |

[image: assets/method1.jpg]
11 | Figure 1: Multi-token in, Multi-token out Training and Inference.
[image: assets/method2.jpg]
19 | Figure 2: Unified Foundation Vision Tower.
21 | 22 | 23 | ## News 24 | 25 | - \[2025/01\] 🎉 VILA-U has been accepted to ICLR2025! 26 | - \[2024/10\] Online demo of VILA-U is available: [https://vila-u.hanlab.ai](https://vila-u.hanlab.ai). Have a try! 27 | - \[2024/10\] We release the code and [models](https://huggingface.co/collections/mit-han-lab/vila-u-7b-6716f7dd5331e4bdf944ffa6) for VILA-U! 28 | 29 | ## Abstract 30 | 31 | **VILA-U** is a **U**nified foundation model that integrates **V**ideo, **I**mage, **La**nguage understanding and generation. Traditional visual language models (VLMs) use separate modules for understanding and generating visual content, which can lead to misalignment and increased complexity. In contrast, VILA-U employs a single autoregressive next-token prediction framework for both tasks, eliminating the need for additional components like diffusion models. This approach not only simplifies the model but also achieves near state-of-the-art performance in visual language understanding and generation. The success of VILA-U is attributed to two main factors: the unified vision tower that aligns discrete visual tokens with textual inputs during pretraining, which enhances visual perception, and autoregressive image generation can achieve similar quality as diffusion models with high-quality dataset. This allows VILA-U to perform comparably to more complex models using a fully token-based autoregressive framework. 32 | 33 | ## Preparation 34 | 35 | ### Environment Setup 36 | 37 | ```bash 38 | git clone https://github.com/mit-han-lab/vila-u 39 | cd vila-u 40 | ./environment_setup.sh vila-u 41 | ``` 42 | 43 | ### Download Models 44 | 45 | Please download our [models](https://huggingface.co/collections/mit-han-lab/vila-u-7b-6716f7dd5331e4bdf944ffa6) from HuggingFace. 46 | 47 | ```bash 48 | git lfs install 49 | git clone https://huggingface.co/mit-han-lab/vila-u-7b-256 50 | ``` 51 | 52 | ## Usage 53 | 54 | ### Gradio Demo 55 | 56 | Run the following command to launch a local gradio demo: 57 | ```bash 58 | CUDA_VISIBLE_DEVICES=0 python app.py --model_path path/to/your_downloaded_model 59 | ``` 60 | 61 | ### Command Line Inference 62 | 63 | ```bash 64 | # Image Understanding 65 | CUDA_VISIBLE_DEVICES=0 python inference.py --model_path path/to/your_downloaded_model --image_path assets/example_image1.jpg --query "Can you describe what is happening?" 66 | ``` 67 | 68 | ```bash 69 | # Video Understanding 70 | CUDA_VISIBLE_DEVICES=0 python inference.py --model_path path/to/your_downloaded_model --video_path assets/example_video1.mp4 --query "Elaborate on the visual and narrative elements of the video in detail." 71 | ``` 72 | 73 | ```bash 74 | # Image Generation 75 | CUDA_VISIBLE_DEVICES=0 python inference.py --model_path path/to/your_downloaded_model --prompt "A snowy mountain." --save_path path/to/save_images --generation_nums 8 76 | ``` 77 | 78 | ```bash 79 | # Video Generation 80 | CUDA_VISIBLE_DEVICES=0 python inference.py --model_path path/to/your_downloaded_model --prompt "Fireworks in the air." --video_generation True --save_path path/to/save_videos 81 | ``` 82 | 83 | ### Evaluation 84 | 85 | Evaluate VILA-U on visual language benchmarks with the following command: 86 | ```bash 87 | vila_u-eval -m path/to/model -c vicuna_v1 -ti local 88 | ``` 89 | Please refer to `vila_u/cli/eval.py` for more argument details. 90 | 91 | ### Training 92 | 93 | Note: Please prepare data before training. Data preparation details are in the file `vila_u/data/datasets_mixture.py`. 
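For reference, adding your own data follows the registration pattern already used in `vila_u/data/datasets_mixture.py`. The sketch below is a minimal, hypothetical example (the dataset name and both paths are placeholders); a new entry like this would go inside `register_datasets_mixtures()`:

```python
# Hypothetical entry for vila_u/data/datasets_mixture.py::register_datasets_mixtures().
# "my_sft_data" and the paths below are placeholders -- point them at your own data.
my_sft_data = Dataset(
    dataset_name="my_sft_data",             # must be unique; '+' is not allowed in the name
    dataset_type="torch",                   # same loader type used by e.g. sharegpt4v_sft
    data_path="/path/to/annotations.json",  # annotation file for your data
    image_path="/path/to/image_folder",     # root folder containing the referenced images
)
add_dataset(my_sft_data)                    # registers the entry into DATASETS
```

Once registered, the dataset name can be appended to the `--data_mixture` argument in `scripts/train/pretrain.sh` or `scripts/train/sft.sh`, where individual datasets are joined with `+`.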
94 | 95 | ```bash 96 | # Pretrain 97 | srun -p your_slurm_partition -N 8 -t 04:00:00 -A your_slurm_account -J vila-u:pretrain --gpus-per-node 8 --exclusive --dependency singleton bash scripts/train/pretrain.sh & 98 | ``` 99 | 100 | ```bash 101 | # SFT 102 | srun -p your_slurm_partition -N 8 -t 04:00:00 -A your_slurm_account -J vila-u:sft --gpus-per-node 8 --exclusive --dependency singleton bash scripts/train/sft.sh & 103 | ``` 104 | 105 | ## Acknowledgment 106 | 107 | We thank Zhijian Liu from NVIDIA for his assistance with the evaluation setup. 108 | 109 | ## Citation 110 | 111 | If you find VILA-U useful or relevant to your project and research, please kindly cite our paper: 112 | 113 | ```bibtex 114 | @article{wu2024vila, 115 | title={Vila-u: a unified foundation model integrating visual understanding and generation}, 116 | author={Wu, Yecheng and Zhang, Zhuoyang and Chen, Junyu and Tang, Haotian and Li, Dacheng and Fang, Yunhao and Zhu, Ligeng and Xie, Enze and Yin, Hongxu and Yi, Li and others}, 117 | journal={arXiv preprint arXiv:2409.04429}, 118 | year={2024} 119 | } 120 | ``` 121 | 122 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import gradio as gr 4 | import imghdr 5 | import numpy as np 6 | import os 7 | import shutil 8 | import signal 9 | import sys 10 | import torch 11 | import uuid 12 | import vila_u 13 | 14 | CFG = 3.0 15 | TEMPERATURE = 0.9 16 | TOP_P = 0.6 17 | 18 | 19 | def is_image_file(filepath): 20 | return imghdr.what(filepath) is not None 21 | 22 | 23 | def generate_response(image, video, query, chat_history): 24 | if query is not None and image is None and video is None: 25 | response = model.generate_image_content(prompt=query, cfg=CFG)[0] 26 | out = response.permute(1, 2, 0) 27 | out = out.cpu().numpy().astype(np.uint8) 28 | out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR) 29 | image_filename = f"{uuid.uuid4()}.png" 30 | image_path = os.path.join(temp_dir, image_filename) 31 | cv2.imwrite(image_path, out) 32 | 33 | return chat_history + [(query, "Here is the image generated:"), (None, (image_path,))] 34 | elif image is not None: 35 | generation_config = model.default_generation_config 36 | generation_config.temperature = TEMPERATURE 37 | generation_config.top_p = TOP_P 38 | answer = model.generate_content([vila_u.Image(image), query], generation_config) 39 | media_display = image 40 | elif video is not None: 41 | generation_config = model.default_generation_config 42 | generation_config.temperature = TEMPERATURE 43 | generation_config.top_p = TOP_P 44 | answer = model.generate_content([vila_u.Video(video), query], generation_config) 45 | media_display = video 46 | else: 47 | return chat_history + [(None, "No input!")] 48 | 49 | return chat_history + [((media_display,), None), (query, answer)] 50 | 51 | 52 | def clear_chat(): 53 | return None, None, None, [] 54 | 55 | 56 | def regenerate_last_answer(chat_history): 57 | if len(chat_history) < 1: 58 | return chat_history 59 | 60 | last_query, last_answer = chat_history[-1] 61 | if last_query is None: 62 | if last_answer == "No input!": 63 | return chat_history 64 | else: 65 | return generate_response(None, None, chat_history[-2][0], chat_history[:-2]) 66 | else: 67 | last_media = chat_history[-2][0][0] 68 | if is_image_file(last_media): 69 | return generate_response(last_media, None, last_query, chat_history[:-2]) 70 | else: 71 | return 
generate_response(None, last_media, last_query, chat_history[:-2]) 72 | 73 | 74 | def cleanup(): 75 | if os.path.exists(temp_dir): 76 | shutil.rmtree(temp_dir) 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument("--model_path", type=str) 82 | args = parser.parse_args() 83 | 84 | if torch.cuda.is_available(): 85 | device = 'cuda' 86 | else: 87 | raise ValueError("CUDA is not available on this machine. Please use a CUDA-enabled machine to run this demo.") 88 | model = vila_u.load(args.model_path).to(device) 89 | 90 | temp_dir = 'temp/' 91 | os.makedirs(temp_dir, exist_ok=True) 92 | 93 | signal.signal(signal.SIGINT, lambda s, f: (cleanup(), sys.exit())) 94 | 95 | with gr.Blocks(title='VILA-U') as demo: 96 | gr.Markdown("# VILA-U: a Unified Foundation Model Integrating Visual Understanding and Generation") 97 | websites = ( 98 | """ 99 | [[Paper](https://arxiv.org/abs/2409.04429)] 100 | [[Project](https://hanlab.mit.edu/projects/vila-u)] 101 | [[GitHub](https://github.com/mit-han-lab/vila-u)] 102 | [[Models](https://huggingface.co/collections/mit-han-lab/vila-u-7b-6716f7dd5331e4bdf944ffa6)] 103 | """ 104 | ) 105 | gr.Markdown(websites) 106 | 107 | with gr.Row(): 108 | with gr.Column(scale=2): 109 | image_input = gr.Image(label="Upload Image", type="filepath") 110 | video_input = gr.Video(label="Upload Video", type="filepath") 111 | 112 | with gr.Column(scale=4): 113 | output_container = gr.Chatbot( 114 | label="VILA-U Chatbot", 115 | height=400, 116 | layout="panel", 117 | ) 118 | 119 | with gr.Row(): 120 | question_input = gr.Textbox(show_label=False, \ 121 | placeholder="Submit a question along with visual input, or provide an image generation prompt alone.", container=False, scale=6) 122 | 123 | submit_button = gr.Button("Submit", variant="primary", scale=1) 124 | clear_button = gr.Button(value="🗑️ Clear", scale=1) 125 | retry_button = gr.Button(value="🔄 Retry", scale=1) 126 | 127 | with gr.Row(): 128 | gr.Examples(examples=[ 129 | ["assets/example_image1.jpg", "Can you describe what is happening?"], 130 | ["assets/example_image2.jpg", "What is the brand of the silver car in the image?"], 131 | ], inputs=[image_input, question_input], cache_examples=False, label="Image Understanding Examples.") 132 | 133 | gr.Examples(examples=[ 134 | ["assets/example_video1.mp4", "Elaborate on the visual and narrative elements of the video in detail."], 135 | ["assets/example_video2.mp4", "What is the man putting on the plate?"], 136 | ], inputs=[video_input, question_input], cache_examples=False, label="Video Understanding Examples.") 137 | 138 | gr.Examples(examples=[ 139 | ["An elephant walking in the water."], 140 | ["A melting apple."], 141 | ["An astronaut riding a horse on the moon, oil painting by Van Gogh."], 142 | ["New England fall with leaves, house and river."], 143 | ["An old man with white beard."], 144 | ["A crystal tree shimmering under a starry sky."], 145 | ["A deep forest clearing with a mirrored pond reflecting a galaxyfilled night sky."], 146 | ["Happy dreamy owl monster sitting on a tree branch, colorful glittering particles, forest background, detailed feathers."] 147 | ], inputs=[question_input], cache_examples=False, label="Image Generation Examples.") 148 | 149 | submit_button.click(generate_response, inputs=[image_input, video_input, question_input, output_container], outputs=output_container) 150 | clear_button.click(clear_chat, outputs=[image_input, video_input, question_input, output_container]) 151 | 
retry_button.click(regenerate_last_answer, inputs=output_container, outputs=output_container) 152 | 153 | try: 154 | demo.launch(share=True) 155 | finally: 156 | cleanup() -------------------------------------------------------------------------------- /assets/example_image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/vila-u/4ea42ead2ce22b035ee74c7966eed245bca3e927/assets/example_image1.jpg -------------------------------------------------------------------------------- /assets/example_image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/vila-u/4ea42ead2ce22b035ee74c7966eed245bca3e927/assets/example_image2.jpg -------------------------------------------------------------------------------- /assets/example_video1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/vila-u/4ea42ead2ce22b035ee74c7966eed245bca3e927/assets/example_video1.mp4 -------------------------------------------------------------------------------- /assets/example_video2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/vila-u/4ea42ead2ce22b035ee74c7966eed245bca3e927/assets/example_video2.mp4 -------------------------------------------------------------------------------- /assets/method1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/vila-u/4ea42ead2ce22b035ee74c7966eed245bca3e927/assets/method1.jpg -------------------------------------------------------------------------------- /assets/method2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/vila-u/4ea42ead2ce22b035ee74c7966eed245bca3e927/assets/method2.jpg -------------------------------------------------------------------------------- /environment_setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | CONDA_ENV=${1:-""} 5 | if [ -n "$CONDA_ENV" ]; then 6 | # This is required to activate conda environment 7 | eval "$(conda shell.bash hook)" 8 | 9 | conda create -n $CONDA_ENV python=3.10.14 -y 10 | conda activate $CONDA_ENV 11 | # This is optional if you prefer to use built-in nvcc 12 | conda install -c nvidia cuda-toolkit -y 13 | else 14 | echo "Skipping conda environment creation. Make sure you have the correct environment activated." 
15 | fi 16 | 17 | # This is required to enable PEP 660 support 18 | pip install --upgrade pip setuptools 19 | 20 | # Install FlashAttention2 21 | pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl 22 | 23 | # Install VILA 24 | pip install -e ".[train,eval]" 25 | 26 | pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git 27 | 28 | pip install git+https://github.com/huggingface/transformers@v4.36.2 29 | 30 | # Replace transformers and deepspeed files 31 | site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])') 32 | cp -rv ./vila_u/train/transformers_replace/* $site_pkg_path/transformers/ 33 | # Avoid confused warning 34 | rm -rf $site_pkg_path/lmms_eval/models/mplug_owl_video/modeling_mplug_owl.py -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import os 5 | import vila_u 6 | 7 | 8 | def save_image(response, path): 9 | os.makedirs(path, exist_ok=True) 10 | for i in range(response.shape[0]): 11 | image = response[i].permute(1, 2, 0) 12 | image = image.cpu().numpy().astype(np.uint8) 13 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 14 | cv2.imwrite(os.path.join(path, f"image_{i}.png"), image) 15 | 16 | 17 | def save_video(response, path): 18 | os.makedirs(path, exist_ok=True) 19 | for i in range(response.shape[0]): 20 | video = response[i].permute(0, 2, 3, 1) 21 | video = video.cpu().numpy().astype(np.uint8) 22 | video = np.concatenate(video, axis=1) 23 | video = cv2.cvtColor(video, cv2.COLOR_RGB2BGR) 24 | cv2.imwrite(os.path.join(path, f"video_{i}.png"), video) 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--model_path", type=str, required=True) 30 | ### image/video understanding arguments 31 | parser.add_argument("--image_path", type=str, default=None) 32 | parser.add_argument("--video_path", type=str, default=None) 33 | parser.add_argument("--query", type=str, default=None) 34 | parser.add_argument("--temperature", type=float, default=0.9, help="The value of temperature for text generation.") 35 | parser.add_argument("--top_p", type=float, default=0.6, help="The value of top-p for text generation.") 36 | ### image and video generation arguments 37 | parser.add_argument("--prompt", type=str, default=None) 38 | parser.add_argument("--video_generation", type=bool, default=False) 39 | parser.add_argument("--cfg", type=float, default=3.0, help="The value of the classifier free guidance for image generation.") 40 | parser.add_argument("--save_path", type=str, default="generated_images/") 41 | parser.add_argument("--generation_nums", type=int, default=1) 42 | args = parser.parse_args() 43 | 44 | if args.model_path is not None: 45 | model = vila_u.load(args.model_path) 46 | else: 47 | raise ValueError("No model path provided!") 48 | 49 | if args.query is not None: 50 | generation_config = model.default_generation_config 51 | generation_config.temperature = args.temperature 52 | generation_config.top_p = args.top_p 53 | if args.image_path is not None: 54 | image = vila_u.Image(args.image_path) 55 | response = model.generate_content([image, args.query]) 56 | print("\033[1;32mResponse:\033[0m", response) 57 | exit() 58 | elif args.video_path is not None: 59 | video = vila_u.Video(args.video_path) 60 | response = 
model.generate_content([video, args.query]) 61 | print("\033[1;32mResponse:\033[0m", response) 62 | exit() 63 | else: 64 | raise ValueError("No visual content input!") 65 | elif args.prompt is not None: 66 | if args.video_generation: 67 | response = model.generate_video_content(args.prompt, args.cfg, args.generation_nums) 68 | save_video(response, args.save_path) 69 | exit() 70 | else: 71 | response = model.generate_image_content(args.prompt, args.cfg, args.generation_nums) 72 | save_image(response, args.save_path) 73 | exit() 74 | else: 75 | raise ValueError("No query or prompt provided!") -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "vila-u" 7 | version = "1.0.0" 8 | description = "VILA-U: a Unified Foundation Model Integrating Visual Understanding and Generation" 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.3.0", "torchvision==0.18.0", 17 | "reka-api", "google-generativeai", "anthropic", 18 | "tokenizers>=0.15.2", "sentencepiece==0.1.99", "shortuuid", 19 | "accelerate==0.34.2", "peft>=0.9.0", "bitsandbytes==0.41.0", 20 | "pydantic<2,>=1", "markdown2[all]", "numpy==1.26.4", "scikit-learn==1.2.2", 21 | "gradio==3.35.2", "gradio_client==0.2.9", 22 | "requests", "httpx", "uvicorn", "fastapi", "fire", 23 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.9.12", 24 | "openpyxl==3.1.2", "pytorchvideo==0.1.5", "decord==0.6.0", 25 | "datasets==2.16.1", "openai==1.8.0", "webdataset==0.2.86", 26 | "nltk==3.3", "pywsd==1.2.4", "opencv-python-headless==4.8.0.76", 27 | "tyro", "pytest", "pre-commit", "loguru", "hydra-core" 28 | ] 29 | 30 | [project.scripts] 31 | vila_u-eval = "vila_u.cli.eval:main" 32 | 33 | [project.optional-dependencies] 34 | train = ["deepspeed==0.9.5", "ninja", "wandb"] 35 | eval = ["mmengine", "word2number", "Levenshtein", "nltk", "pywsd"] 36 | 37 | [project.urls] 38 | "Homepage" = "https://github.com/mit-han-lab/vila-u" 39 | 40 | [tool.setuptools.packages.find] 41 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 42 | 43 | [tool.wheel] 44 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 45 | -------------------------------------------------------------------------------- /scripts/eval/lmms.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | TASK=$1 6 | MODEL_PATH=$2 7 | CONV_MODE=$3 8 | if [[ "$TASK" =~ videomme ]]; then 9 | NUM_VIDEO_FRAMES=$(echo "$TASK" | cut -d'-' -f2-) 10 | IFS='-' read -ra segments <<< "$TASK" 11 | unset segments[${#segments[@]}-1] 12 | TASK=$(IFS=-; echo "${segments[*]}") 13 | else 14 | NUM_VIDEO_FRAMES=8 15 | fi 16 | 17 | MODEL_NAME=$(basename $MODEL_PATH) 18 | OUTPUT_DIR=${OUTPUT_DIR:-"runs/eval/$MODEL_NAME/lmms-$TASK"} 19 | 20 | NPROC_PER_NODE=${NPROC_PER_NODE:-$(nvidia-smi -L | wc -l)} 21 | 22 | export LMMS_EVAL_PLUGINS=vila_u.eval.lmms 23 | export HF_HOME=$HOME/.cache/huggingface 24 | export CACHE_DIR=$OUTPUT_DIR/cache 25 | 26 | torchrun --nproc_per_node=$NPROC_PER_NODE \ 27 | -m lmms_eval \ 28 | --model vila_u \ 29 | --model_args 
model_path=$MODEL_PATH,conv_mode=$CONV_MODE,num_video_frames=$NUM_VIDEO_FRAMES \ 30 | --tasks $TASK \ 31 | --log_samples \ 32 | --output_path $OUTPUT_DIR 33 | 34 | mv $OUTPUT_DIR/*_$MODEL_NAME/*_results.json $OUTPUT_DIR/results.json || true 35 | mv $OUTPUT_DIR/*_$MODEL_NAME/*_samples_*.jsonl $OUTPUT_DIR/samples.jsonl || true 36 | mv $OUTPUT_DIR/*_$MODEL_NAME/* $OUTPUT_DIR || true 37 | rm -r $OUTPUT_DIR/*_$MODEL_NAME || true 38 | 39 | mv $OUTPUT_DIR/*_vila_u_*/* $OUTPUT_DIR || true 40 | rm -r $OUTPUT_DIR/*_vila_u_* || true 41 | rm -r $OUTPUT_DIR/rank*_metric_eval_done.txt || true 42 | -------------------------------------------------------------------------------- /scripts/train/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export NCCL_IB_SL=1 4 | export CUDA_DEVICE_MAX_CONNECTIONS=1 5 | export NCCL_ASYNC_ERROR_HANDLING=1 6 | 7 | source activate vila-u 8 | 9 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 10 | export MASTER_ADDR=${master_addr:-"127.0.0.1"} 11 | export CURRENT_RANK=${SLURM_PROCID:-"0"} 12 | worker_list=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | tr '\n' ' ') 13 | n_node=${SLURM_JOB_NUM_NODES:-1} 14 | 15 | echo "MASTER_ADDR="$MASTER_ADDR 16 | echo "JobID: $SLURM_JOB_ID | Full list: $worker_list" 17 | 18 | global_bs=${BATCH_SIZE:-512} 19 | acc_step=${ACC_STEP:-1} 20 | bs=$((global_bs / n_node / acc_step)) 21 | 22 | echo "PER_DEVICE_TRAIN_BATCH_SIZE="$bs 23 | 24 | torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=25001 \ 25 | --master_addr $MASTER_ADDR --node_rank=$SLURM_PROCID \ 26 | vila_u/train/train_mem.py \ 27 | --deepspeed ./scripts/zero2.json \ 28 | --model_name_or_path meta-llama/Llama-2-7b \ 29 | --version v1 \ 30 | --data_mixture sharegpt4v_pretrain+mmc4core+internal_generation+openvid_generation \ 31 | --vision_tower path/to/prepared_vision_tower \ 32 | --chunk_sampler True \ 33 | --mm_projector mlp2x_gelu \ 34 | --tune_mm_projector True \ 35 | --tune_language_model True \ 36 | --mm_vision_select_layer -2 \ 37 | --mm_use_im_start_end True \ 38 | --mm_use_vi_start_end True \ 39 | --mm_use_im_patch_token False \ 40 | --image_aspect_ratio resize \ 41 | --bf16 True \ 42 | --output_dir ./checkpoints/vila-u-pretrain \ 43 | --num_train_epochs 1 \ 44 | --per_device_train_batch_size $bs \ 45 | --per_device_eval_batch_size 4 \ 46 | --gradient_accumulation_steps $acc_step \ 47 | --evaluation_strategy "no" \ 48 | --save_strategy "steps" \ 49 | --save_steps 100 \ 50 | --save_total_limit 1 \ 51 | --learning_rate 5e-5 \ 52 | --weight_decay 0. 
\ 53 | --warmup_ratio 0.03 \ 54 | --lr_scheduler_type "cosine" \ 55 | --logging_steps 1 \ 56 | --tf32 True \ 57 | --model_max_length 8192 \ 58 | --gradient_checkpointing True \ 59 | --dataloader_num_workers 8 \ 60 | --lazy_preprocess True \ 61 | --report_to wandb \ -------------------------------------------------------------------------------- /scripts/train/sft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export NCCL_IB_SL=1 4 | export CUDA_DEVICE_MAX_CONNECTIONS=1 5 | export NCCL_ASYNC_ERROR_HANDLING=1 6 | 7 | source activate vila-u 8 | 9 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 10 | export MASTER_ADDR=${master_addr:-"127.0.0.1"} 11 | export CURRENT_RANK=${SLURM_PROCID:-"0"} 12 | worker_list=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | tr '\n' ' ') 13 | n_node=${SLURM_JOB_NUM_NODES:-1} 14 | 15 | echo "MASTER_ADDR="$MASTER_ADDR 16 | echo "JobID: $SLURM_JOB_ID | Full list: $worker_list" 17 | 18 | global_bs=${BATCH_SIZE:-512} 19 | acc_step=${ACC_STEP:-1} 20 | bs=$((global_bs / n_node / acc_step)) 21 | 22 | echo "PER_DEVICE_TRAIN_BATCH_SIZE="$bs 23 | 24 | torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=25001 \ 25 | --master_addr $MASTER_ADDR --node_rank=$SLURM_PROCID \ 26 | vila_u/train/train_mem.py \ 27 | --deepspeed ./scripts/zero2.json \ 28 | --model_name_or_path path/to/vila-u-pretrain \ 29 | --version v1 \ 30 | --data_mixture sharegpt4v_sft+vflan+shot2story_shotonly+video_chatgpt+youcook2+vatex+sharegpt_video+scienceqa+wit_subset+sherlock+internal_generation+openvid_generation \ 31 | --chunk_sampler True \ 32 | --mm_projector mlp2x_gelu \ 33 | --tune_mm_projector True \ 34 | --tune_language_model True \ 35 | --mm_vision_select_layer -2 \ 36 | --mm_use_im_start_end True \ 37 | --mm_use_vi_start_end True \ 38 | --mm_use_im_patch_token False \ 39 | --image_aspect_ratio resize \ 40 | --bf16 True \ 41 | --output_dir ./checkpoints/vila-u-sft \ 42 | --num_train_epochs 1 \ 43 | --per_device_train_batch_size $bs \ 44 | --per_device_eval_batch_size 4 \ 45 | --gradient_accumulation_steps $acc_step \ 46 | --evaluation_strategy "no" \ 47 | --save_strategy "steps" \ 48 | --save_steps 100 \ 49 | --save_total_limit 1 \ 50 | --learning_rate 1e-4 \ 51 | --weight_decay 0. 
\ 52 | --warmup_ratio 0.03 \ 53 | --lr_scheduler_type "cosine" \ 54 | --logging_steps 1 \ 55 | --tf32 True \ 56 | --model_max_length 8192 \ 57 | --gradient_checkpointing True \ 58 | --dataloader_num_workers 4 \ 59 | --lazy_preprocess True \ 60 | --vflan_no_system_prompt True \ 61 | --report_to wandb \ -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /vila_u/__init__.py: -------------------------------------------------------------------------------- 1 | from .entry import * 2 | from .media import * -------------------------------------------------------------------------------- /vila_u/cli/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import subprocess 4 | import time 5 | 6 | from argparse import ArgumentParser 7 | from collections import deque 8 | from tabulate import tabulate 9 | from typing import Dict, List, Optional 10 | 11 | from vila_u.eval import EVAL_ROOT, TASKS 12 | from vila_u.utils import io 13 | from vila_u.utils.logging import logger 14 | 15 | 16 | def lstr(s: Optional[str]) -> Optional[List[str]]: 17 | if s is not None: 18 | s = s.split(",") if "," in s else [s] 19 | return s 20 | 21 | 22 | def _load_results(output_dir: str, task: str) -> Optional[Dict]: 23 | for fname in ["results.json", "metrics.json"]: 24 | if os.path.exists(os.path.join(output_dir, task, fname)): 25 | return io.load(os.path.join(output_dir, task, fname)) 26 | return None 27 | 28 | 29 | def main() -> None: 30 | parser = ArgumentParser() 31 | parser.add_argument("--model-path", "-m", type=str, required=True) 32 | parser.add_argument("--model-name", type=str, default=None) 33 | parser.add_argument("--conv-mode", "-c", type=str, required=True) 34 | parser.add_argument("--nproc-per-node", "-n", type=int, default=8) 35 | parser.add_argument("--tasks", "-t", type=lstr) 36 | parser.add_argument("--tags-include", "-ti", type=lstr) 37 | parser.add_argument("--tags-exclude", "-te", type=lstr) 38 | parser.add_argument("--num_video_frames", "-nf", type=str, default="8") 39 | parser.add_argument("--output-dir", type=str, default=None) 40 | parser.add_argument("--report-to", "-r", choices=["wandb", None], default=None) 41 | args = parser.parse_args() 42 | 43 | # Get the model name and output directory 44 | model_name = os.path.basename(os.path.normpath(args.model_path)).lower() 45 | if args.model_name is not None: 46 | model_name = args.model_name 47 | output_dir = os.path.join("runs", "eval", model_name) 48 | num_video_frames = args.num_video_frames 49 | if args.output_dir is not None: 50 | output_dir = osp.expanduser(args.output_dir) 51 | 52 | # Filter tasks based on name and tags 53 | tasks = [] 54 | for task, metainfo in TASKS.items(): 55 | tags = set(metainfo.get("tags", [])) 56 | if args.tasks 
is not None and task not in args.tasks: 57 | continue 58 | if args.tags_include is not None and tags.isdisjoint(args.tags_include): 59 | continue 60 | if args.tags_exclude is not None and tags.intersection(args.tags_exclude): 61 | continue 62 | tasks.append(task) 63 | logger.info(f"Running evaluation for '{model_name}' on {len(tasks)} tasks: {tasks}") 64 | 65 | # Prepare the evaluation commands 66 | cmds = {} 67 | for task in tasks: 68 | if _load_results(output_dir, task=task): 69 | logger.warning(f"Skipping '{task}' as it has already been evaluated.") 70 | continue 71 | 72 | cmd = [] 73 | if task.startswith("lmms-"): 74 | cmd += [f"{EVAL_ROOT}/lmms.sh", task.replace("lmms-", "")] 75 | elif "_" in task: 76 | name, split = task.split("_") 77 | cmd += [f"{EVAL_ROOT}/{name}.sh", split] 78 | else: 79 | cmd += [f"{EVAL_ROOT}/{task}.sh"] 80 | cmd += [args.model_path, args.conv_mode] 81 | 82 | concurrency = 1 83 | final_cmd = cmd 84 | cmds[task] = " ".join(final_cmd) 85 | 86 | # Prepare the environment variables 87 | env = os.environ.copy() 88 | env["NPROC_PER_NODE"] = str(args.nproc_per_node) 89 | 90 | # Run the commands with the specified concurrency 91 | remaining = deque(cmds.keys()) 92 | processes, returncodes = {}, {} 93 | try: 94 | while remaining or processes: 95 | while remaining and len(processes) < concurrency: 96 | task = remaining.popleft() 97 | logger.info(f"Running '{cmds[task]}'") 98 | processes[task] = subprocess.Popen( 99 | cmds[task], 100 | stdout=subprocess.DEVNULL if concurrency > 1 else None, 101 | stderr=subprocess.DEVNULL if concurrency > 1 else None, 102 | shell=True, 103 | env=env, 104 | ) 105 | 106 | for task, process in processes.items(): 107 | if process.poll() is not None: 108 | returncodes[task] = process.returncode 109 | processes.pop(task) 110 | break 111 | 112 | time.sleep(1) 113 | except KeyboardInterrupt: 114 | logger.warning("Terminating all processes...") 115 | for _, process in processes.items(): 116 | process.terminate() 117 | for _, process in processes.items(): 118 | process.wait() 119 | 120 | final_return_code = 0 121 | # Check the return codes 122 | for task, returncode in returncodes.items(): 123 | if returncode != 0: 124 | logger.error(f"Error running '{task}' evaluation (return code: {returncode})") 125 | final_return_code = returncode 126 | 127 | if args.report_to == "wandb": 128 | import wandb 129 | 130 | wandb_project = os.environ.get("WANDB_PROJECT", "vila-u-eval") 131 | wandb_entity = os.environ.get("WANDB_ENTITY", None) 132 | wandb_name = os.environ.get("WANDB_NAME", model_name) 133 | logger.info(f"initiating wandb run for '{wandb_project}/{wandb_name}'") 134 | wandb.init( 135 | project=wandb_project, 136 | entity=wandb_entity, 137 | name=wandb_name, 138 | config={ 139 | "model_path": args.model_path, 140 | "conv_mode": args.conv_mode, 141 | }, 142 | ) 143 | 144 | # Collect the results and save them 145 | metrics = {} 146 | for task in tasks: 147 | results = _load_results(output_dir, task=task) 148 | if results is None: 149 | continue 150 | for name, path in TASKS[task].get("metrics", {}).items(): 151 | val = results 152 | for key in path.split("/") if "/" in path else [path]: 153 | val = val[key] 154 | metrics[f"{task}/{name}"] = val 155 | 156 | if args.report_to == "wandb": 157 | logger.info(f"Logging '{task}/{name}' to wandb") 158 | wandb.log({f"{task}/{name}": val}) 159 | io.save(os.path.join(output_dir, "metrics.json"), metrics, indent=4) 160 | logger.info(f"Saved all metrics to '{output_dir}/metrics.json'") 161 | if args.report_to == 
"wandb": 162 | logger.info(f"Saved wandb url to '{output_dir}/wandb.txt'") 163 | io.save(os.path.join(output_dir, "wandb.txt"), wandb.run.get_url()) 164 | 165 | # Print the metrics in a tabular format 166 | logger.info("Results:\n" + tabulate(metrics.items(), tablefmt="simple_outline", headers=["Metric", "Value"])) 167 | 168 | return final_return_code 169 | 170 | 171 | if __name__ == "__main__": 172 | main() -------------------------------------------------------------------------------- /vila_u/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | IGNORE_INDEX = -100 7 | IMAGE_TOKEN_INDEX = -200 8 | DEFAULT_IMAGE_TOKEN = "" 9 | DEFAULT_IMAGE_PATCH_TOKEN = "" 10 | DEFAULT_IM_START_TOKEN = "" 11 | DEFAULT_IM_END_TOKEN = "" 12 | DEFAULT_VI_START_TOKEN = "" 13 | DEFAULT_VI_END_TOKEN = "" 14 | IMAGE_PLACEHOLDER = "" 15 | 16 | SENTINEL_TOKEN = "" -------------------------------------------------------------------------------- /vila_u/conversation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | from enum import auto, Enum 4 | from typing import List, Tuple 5 | 6 | from vila_u.utils.logging import logger 7 | 8 | 9 | class SeparatorStyle(Enum): 10 | """Different separator style.""" 11 | SINGLE = auto() 12 | TWO = auto() 13 | 14 | 15 | @dataclasses.dataclass 16 | class Conversation: 17 | """A class that keeps all conversation history.""" 18 | system: str 19 | roles: List[str] 20 | messages: List[List[str]] 21 | offset: int 22 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 23 | sep: str = "###" 24 | sep2: str = None 25 | version: str = "Unknown" 26 | 27 | skip_next: bool = False 28 | 29 | def get_prompt(self): 30 | messages = self.messages 31 | if len(messages) > 0 and type(messages[0][1]) is tuple: 32 | messages = self.messages.copy() 33 | init_role, init_msg = messages[0].copy() 34 | init_msg = init_msg[0].replace("", "").strip() 35 | messages[0] = (init_role, "\n" + init_msg) 36 | 37 | if self.sep_style == SeparatorStyle.SINGLE: 38 | ret = self.system + self.sep 39 | for role, message in messages: 40 | if message: 41 | if type(message) is tuple: 42 | message, _, _ = message 43 | ret += role + ": " + message + self.sep 44 | else: 45 | ret += role + ":" 46 | elif self.sep_style == SeparatorStyle.TWO: 47 | seps = [self.sep, self.sep2] 48 | ret = self.system + seps[0] 49 | for i, (role, message) in enumerate(messages): 50 | if message: 51 | if type(message) is tuple: 52 | message, _, _ = message 53 | ret += role + ": " + message + seps[i % 2] 54 | else: 55 | ret += role + ":" 56 | else: 57 | raise ValueError(f"Invalid style: {self.sep_style}") 58 | 59 | return ret 60 | 61 | def append_message(self, role, message): 62 | self.messages.append([role, message]) 63 | 64 | def copy(self): 65 | return Conversation( 66 | system=self.system, 67 | roles=self.roles, 68 | messages=[[x, y] for x, y in self.messages], 69 | offset=self.offset, 70 | sep_style=self.sep_style, 71 | sep=self.sep, 72 | sep2=self.sep2, 73 | version=self.version) 74 | 75 | conv_vicuna_v0 = Conversation( 76 | system="A chat between a curious human and an artificial intelligence assistant. 
" 77 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 78 | roles=("Human", "Assistant"), 79 | messages=( 80 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"), 81 | ("Assistant", 82 | "Renewable energy sources are those that can be replenished naturally in a relatively " 83 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " 84 | "Non-renewable energy sources, on the other hand, are finite and will eventually be " 85 | "depleted, such as coal, oil, and natural gas. Here are some key differences between " 86 | "renewable and non-renewable energy sources:\n" 87 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " 88 | "energy sources are finite and will eventually run out.\n" 89 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact " 90 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " 91 | "and other negative effects.\n" 92 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " 93 | "have lower operational costs than non-renewable sources.\n" 94 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " 95 | "locations than non-renewable sources.\n" 96 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " 97 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n" 98 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " 99 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") 100 | ), 101 | offset=2, 102 | sep_style=SeparatorStyle.SINGLE, 103 | sep="###", 104 | ) 105 | 106 | conv_vicuna_v1 = Conversation( 107 | system="A chat between a curious user and an artificial intelligence assistant. 
" 108 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 109 | roles=("USER", "ASSISTANT"), 110 | version="v1", 111 | messages=(), 112 | offset=0, 113 | sep_style=SeparatorStyle.TWO, 114 | sep=" ", 115 | sep2="", 116 | ) 117 | 118 | default_conversation = conv_vicuna_v1 119 | 120 | conv_templates = { 121 | "v0": conv_vicuna_v0, 122 | "v1": conv_vicuna_v1, 123 | "vicuna_v1": conv_vicuna_v1, 124 | } 125 | 126 | CONVERSATION_MODE_MAPPING = { 127 | "vila-u-7b-256": "vicuna_v1", 128 | "vila-u-7b-384": "vicuna_v1", 129 | } 130 | 131 | def auto_set_conversation_mode(model_name_or_path: str) -> str: 132 | global default_conversation 133 | for k, v in CONVERSATION_MODE_MAPPING.items(): 134 | if k in model_name_or_path.lower(): 135 | logger.info(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.") 136 | default_conversation = conv_templates[v] 137 | return 138 | -------------------------------------------------------------------------------- /vila_u/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * 2 | from .datasets_mixture import * 3 | from .simple_vila_webdataset import VILAWebDataset 4 | -------------------------------------------------------------------------------- /vila_u/data/datasets_mixture.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from dataclasses import dataclass, field 4 | 5 | 6 | @dataclass 7 | class Dataset: 8 | dataset_name: str 9 | dataset_type: str = field(default="torch") 10 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 11 | meta_path: str = field(default=None, metadata={"help": "Path to the meta data for webdataset."}) 12 | image_path: str = field(default=None, metadata={"help": "Path to the training image data."}) 13 | description: str = field( 14 | default=None, 15 | metadata={ 16 | "help": "Detailed desciption of where the data is from, how it is labelled, intended use case and the size of the dataset." 17 | }, 18 | ) 19 | test_script: str = (None,) 20 | maintainer: str = (None,) 21 | 22 | 23 | DATASETS = {} 24 | 25 | 26 | def add_dataset(dataset): 27 | if dataset.dataset_name in DATASETS: 28 | warnings.warn(f"{dataset.dataset_name} already existed in DATASETS. Make sure the name is unique.") 29 | assert "+" not in dataset.dataset_name, "Dataset name cannot include symbol '+'." 30 | 31 | DATASETS.update({dataset.dataset_name: dataset}) 32 | 33 | 34 | def register_datasets_mixtures(): 35 | internal_generation = Dataset( 36 | dataset_name="internal_generation", 37 | dataset_type="internal-generation", 38 | data_path="", 39 | meta_path="", 40 | ) 41 | add_dataset(internal_generation) 42 | 43 | 44 | # Please download data from https://github.com/NJU-PCALab/OpenVid-1M to prepare openvid_generation. 45 | openvid_generation = Dataset( 46 | dataset_name="openvid_generation", 47 | dataset_type="openvid-generation", 48 | data_path="", 49 | ) 50 | add_dataset(openvid_generation) 51 | 52 | 53 | # Please download data from https://sharegpt4v.github.io/ to prepare sharegpt4v. 
54 | sharegpt4v_pretrain = Dataset( 55 | dataset_name="sharegpt4v_pretrain", 56 | dataset_type="torch", 57 | data_path="", 58 | image_path="", 59 | description="Original data source: https://sharegpt4v.github.io/ ~1M long Image - Text pair generated by ShareGPT4V captioner.", 60 | ) 61 | add_dataset(sharegpt4v_pretrain) 62 | 63 | 64 | sharegpt4v_sft = Dataset( 65 | dataset_name="sharegpt4v_sft", 66 | dataset_type="torch", 67 | data_path="", 68 | image_path="", 69 | description="Original data source: https://sharegpt4v.github.io/ 655K llava_1_5_sft data relablled w/ ShareGPT4V captioner.", 70 | ) 71 | add_dataset(sharegpt4v_sft) 72 | 73 | 74 | # Please refer to https://github.com/NVlabs/VILA/tree/main/data_prepare to prepare the following datasets. 75 | mmc4core = Dataset( 76 | dataset_name="mmc4core", 77 | dataset_type="mmc4", 78 | data_path="", 79 | ) 80 | add_dataset(mmc4core) 81 | 82 | 83 | vflan = Dataset( 84 | dataset_name="vflan", 85 | dataset_type="vflan", 86 | data_path="", 87 | ) 88 | add_dataset(vflan) 89 | 90 | 91 | shot2story_shotonly = Dataset( 92 | dataset_name="shot2story_shotonly", 93 | dataset_type="torch", 94 | data_path="", 95 | image_path="", 96 | ) 97 | add_dataset(shot2story_shotonly) 98 | 99 | 100 | video_chatgpt = Dataset( 101 | dataset_name="video_chatgpt", 102 | dataset_type="torch", 103 | data_path="", 104 | image_path="", 105 | ) 106 | add_dataset(video_chatgpt) 107 | 108 | 109 | youcook2 = Dataset( 110 | dataset_name="youcook2", 111 | dataset_type="torch", 112 | data_path="", 113 | image_path="", 114 | ) 115 | add_dataset(youcook2) 116 | 117 | 118 | vatex = Dataset( 119 | dataset_name="vatex", 120 | dataset_type="torch", 121 | data_path="", 122 | image_path="", 123 | ) 124 | add_dataset(vatex) 125 | 126 | 127 | sharegpt_video = Dataset( 128 | dataset_name="sharegpt_video", 129 | dataset_type="torch", 130 | data_path="", 131 | image_path="", 132 | ) 133 | add_dataset(sharegpt_video) 134 | 135 | 136 | scienceqa = Dataset( 137 | dataset_name="scienceqa", 138 | dataset_type="torch", 139 | data_path="", 140 | image_path="", 141 | ) 142 | add_dataset(scienceqa) 143 | 144 | 145 | wit_subset = Dataset( 146 | dataset_name="wit_subset", 147 | dataset_type="torch", 148 | data_path="", 149 | image_path="" 150 | ) 151 | add_dataset(wit_subset) 152 | 153 | 154 | sherlock = Dataset( 155 | dataset_name="sherlock", 156 | dataset_type="torch", 157 | data_path="", 158 | image_path="", 159 | ) 160 | add_dataset(sherlock) -------------------------------------------------------------------------------- /vila_u/data/simple_vila_webdataset.py: -------------------------------------------------------------------------------- 1 | import getpass 2 | import hashlib 3 | import os.path as osp 4 | 5 | from functools import reduce 6 | from vila_u.wids import ShardListDataset 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class VILAWebDataset(Dataset): 11 | def __init__( 12 | self, 13 | data_path=None, 14 | meta_path=None, 15 | cache_dir=None, 16 | max_shards_to_load=None, 17 | ): 18 | self.data_path = osp.expanduser(data_path) 19 | self.meta_path = osp.expanduser(meta_path) if meta_path is not None else None 20 | 21 | _local_meta_path = osp.join(self.data_path, "wids-meta.json") 22 | if meta_path is None and osp.exists(_local_meta_path): 23 | print(f"loading from {_local_meta_path}") 24 | self.meta_path = meta_path = _local_meta_path 25 | 26 | if meta_path is None: 27 | self.meta_path = osp.join( 28 | osp.expanduser(cache_dir), 29 | self.data_path.replace("/", "--") + 
f".max_shards:{max_shards_to_load}" + ".wdsmeta.json", 30 | ) 31 | 32 | assert osp.exists( 33 | self.meta_path 34 | ), f"meta path not found in [{self.meta_path}] or [{_local_meta_path}]" 35 | print(f"[SimplyCoyo] Loading meta infomation {self.meta_path}", flush=True) 36 | 37 | uuid = hashlib.sha256(self.meta_path.encode()).hexdigest()[:8] 38 | self.dataset = ShardListDataset( 39 | self.meta_path, 40 | cache_dir=osp.expanduser(f"~/.cache/_wids_cache/{getpass.getuser()}-{uuid}"), 41 | ) 42 | 43 | def __getitem__(self, idx): 44 | return self.dataset[idx] 45 | 46 | def __len__(self): 47 | return len(self.dataset) 48 | 49 | @staticmethod 50 | def simple_collate(batch): 51 | batched_data = {} 52 | for data in batch: 53 | for k, v in data.items(): 54 | if k not in batched_data: 55 | batched_data[k] = [] 56 | batched_data[k].append(v) 57 | return dict(batched_data) 58 | 59 | @staticmethod 60 | def custom_collate(batch): 61 | def transform2list(a: dict): 62 | for k, v in a.items(): 63 | if isinstance(v, dict): 64 | a[k] = transform2list(v) 65 | else: 66 | a[k] = [ 67 | v, 68 | ] 69 | return a 70 | 71 | def merge(a: dict, b: dict, path=[], strict=False): 72 | c = {} 73 | keys = set(a.keys()).union(b.keys()) 74 | for key in keys: 75 | if key in a and key in b: 76 | if isinstance(a[key], dict) and isinstance(b[key], dict): 77 | c[key] = merge(a[key], b[key], path + [str(key)], strict=strict) 78 | else: 79 | c[key] = a[key] + b[key] 80 | else: 81 | if strict: 82 | raise Exception("Conflict at " + ".".join(path + [str(key)])) 83 | c[key] = a[key] if key in a else b[key] 84 | return c 85 | 86 | tasks = (transform2list(_) for _ in batch) 87 | return reduce(merge, tasks) -------------------------------------------------------------------------------- /vila_u/entry.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import typing 4 | 5 | from typing import List, Optional 6 | 7 | if typing.TYPE_CHECKING: 8 | from transformers import PreTrainedModel 9 | else: 10 | PreTrainedModel = None 11 | 12 | from vila_u.conversation import auto_set_conversation_mode 13 | from vila_u.model.builder import load_pretrained_model 14 | 15 | __all__ = ["load"] 16 | 17 | 18 | def load( 19 | model_path: str, 20 | devices: Optional[List[int]] = None, 21 | **kwargs, 22 | ) -> PreTrainedModel: 23 | auto_set_conversation_mode(model_path) 24 | 25 | model_path = os.path.expanduser(model_path) 26 | if os.path.exists(os.path.join(model_path, "model")): 27 | model_path = os.path.join(model_path, "model") 28 | 29 | # Set `max_memory` to constrain which GPUs to use 30 | if devices is not None: 31 | assert "max_memory" not in kwargs, "`max_memory` should not be set when `devices` is set" 32 | kwargs.update(max_memory={device: torch.cuda.get_device_properties(device).total_memory for device in devices}) 33 | 34 | model = load_pretrained_model(model_path, **kwargs)[1] 35 | 36 | return model -------------------------------------------------------------------------------- /vila_u/eval/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from vila_u.utils import io 4 | 5 | __all__ = ["EVAL_ROOT", "TASKS"] 6 | 7 | 8 | EVAL_ROOT = "scripts/eval" 9 | TASKS = io.load(os.path.join(os.path.dirname(__file__), "registry.yaml")) 10 | -------------------------------------------------------------------------------- /vila_u/eval/lmms/models/__init__.py: 
-------------------------------------------------------------------------------- 1 | AVAILABLE_MODELS = { 2 | "vila_u": "VILA_U", 3 | } 4 | -------------------------------------------------------------------------------- /vila_u/eval/lmms/models/vila_u.py: -------------------------------------------------------------------------------- 1 | import accelerate 2 | import os 3 | import requests 4 | import torch 5 | 6 | from lmms_eval.api.instance import Instance 7 | from lmms_eval.api.model import lmms 8 | from lmms_eval.api.registry import register_model 9 | from tqdm import tqdm 10 | from typing import List, Tuple 11 | 12 | import vila_u 13 | from vila_u import conversation as conversation_lib 14 | from vila_u.media import Video 15 | from vila_u.utils import distributed as dist 16 | from vila_u.utils import io 17 | 18 | 19 | @register_model("vila_u") 20 | class VILA_U(lmms): 21 | def __init__( 22 | self, model_path: str, conv_mode: str, num_video_frames: int = 8, batch_size: int = 1 23 | ) -> None: 24 | super().__init__() 25 | assert batch_size == 1, "VILA-U only supports batch size of 1 at the moment." 26 | self._update_gpt_eval_model() 27 | 28 | devices = range(dist.local_rank(), torch.cuda.device_count(), dist.local_size()) 29 | torch.cuda.set_device(devices[0]) 30 | 31 | self.model = vila_u.load(model_path, devices=devices) 32 | self.model.config.num_video_frames = num_video_frames 33 | context_length = num_video_frames * 512 34 | self.model.config.model_max_length = context_length 35 | self.model.config.tokenizer_model_max_length = context_length 36 | self.model.llm.config.model_max_length = context_length 37 | self.model.llm.config.tokenizer_model_max_length = context_length 38 | self.model.tokenizer.model_max_length = context_length 39 | 40 | conversation_lib.default_conversation = conversation_lib.conv_templates[conv_mode].copy() 41 | 42 | self.accelerator = accelerate.Accelerator() 43 | self.device = torch.device(f"cuda:{devices[0]}") 44 | self._world_size = dist.size() 45 | self._rank = dist.rank() 46 | 47 | def _update_gpt_eval_model(self) -> None: 48 | _unpatched_post = requests.post 49 | 50 | def _patched_post(url, json, **kwargs): 51 | if json is not None and "model" in json: 52 | if json["model"] == "gpt-3.5-turbo-0613": 53 | json["model"] = "gpt-4o-mini" 54 | return _unpatched_post(url, json=json, **kwargs) 55 | 56 | requests.post = _patched_post 57 | 58 | def generate_until(self, requests: List[Instance]) -> List[str]: 59 | responses = [] 60 | for request in tqdm(requests, disable=not dist.is_main()): 61 | prompt, generation_kwargs, doc_to_visual, doc_id, task, split = self._patch(request.args) 62 | doc = self.task_dict[task][split][doc_id] 63 | 64 | # Generate multimodal prompt 65 | medias = [] 66 | for media in doc_to_visual(doc): 67 | if isinstance(media, str): 68 | if any(media.endswith(ext) for ext in [".mp4", ".mkv", ".webm"]): 69 | media = Video(media) 70 | else: 71 | raise NotImplementedError(f"Unsupported media type: {media}") 72 | medias.append(media) 73 | prompt = medias + [prompt] 74 | 75 | # Override generation config 76 | generation_config = self.model.default_generation_config 77 | generation_config.update(**generation_kwargs) 78 | 79 | # Generate and cache response 80 | cache_path = None 81 | if "CACHE_DIR" in os.environ: 82 | cache_path = os.path.join(os.environ["CACHE_DIR"], f"{task}_{split}_{doc_id}.txt") 83 | 84 | if cache_path is not None and os.path.exists(cache_path): 85 | response = io.load(cache_path) 86 | else: 87 | response = 
self.model.generate_content(prompt, generation_config=generation_config) 88 | if cache_path is not None: 89 | io.save(cache_path, response) 90 | responses.append(response) 91 | 92 | print("Prompt:", prompt) 93 | print("Response:", response) 94 | return responses 95 | 96 | def _patch(self, args: Tuple) -> Tuple: 97 | prompt, generation_kwargs, doc_to_visual, doc_id, task, split = args 98 | doc = self.task_dict[task][split][doc_id] 99 | 100 | return prompt, generation_kwargs, doc_to_visual, doc_id, task, split 101 | 102 | def generate_until_multi_round(self, requests: List[Instance]) -> List[str]: 103 | raise NotImplementedError 104 | 105 | def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: 106 | raise NotImplementedError 107 | -------------------------------------------------------------------------------- /vila_u/eval/lmms/tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/vila-u/4ea42ead2ce22b035ee74c7966eed245bca3e927/vila_u/eval/lmms/tasks/__init__.py -------------------------------------------------------------------------------- /vila_u/eval/registry.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | lmms-gqa: 3 | tags: 4 | - core 5 | - local 6 | - regression 7 | metrics: 8 | accuracy: results/gqa/exact_match,none 9 | 10 | 11 | lmms-mme: 12 | tags: 13 | - core 14 | - local 15 | - regression 16 | metrics: 17 | cognition: results/mme/mme_cognition_score,none 18 | perception: results/mme/mme_percetion_score,none 19 | 20 | 21 | lmms-mmvet: 22 | tags: 23 | - core 24 | - openai 25 | 26 | 27 | lmms-pope: 28 | tags: 29 | - core 30 | - local 31 | - regression 32 | metrics: 33 | accuracy: results/pope/pope_accuracy,none 34 | precision: results/pope/pope_precision,none 35 | recall: results/pope/pope_recall,none 36 | f1: results/pope/pope_f1_score,none 37 | 38 | 39 | lmms-seedbench: 40 | tags: 41 | - core 42 | - local 43 | - regression 44 | metrics: 45 | all: results/seedbench/seed_all,none 46 | image: results/seedbench/seed_image,none 47 | video: results/seedbench/seed_video,none 48 | 49 | 50 | lmms-textvqa_test: 51 | tags: 52 | - submission 53 | 54 | 55 | lmms-vqav2_test: 56 | tags: 57 | - core 58 | - submission 59 | 60 | 61 | lmms-activitynetqa: 62 | tags: 63 | - openai 64 | metrics: 65 | accuracy: results/activitynetqa/gpt_eval_accuracy,none 66 | score: results/activitynetqa/gpt_eval_score,none 67 | -------------------------------------------------------------------------------- /vila_u/media.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Media", "File", "Image", "Video"] 2 | 3 | 4 | class Media: 5 | pass 6 | 7 | 8 | class File(Media): 9 | def __init__(self, path: str) -> None: 10 | self.path = path 11 | 12 | 13 | class Image(File): 14 | pass 15 | 16 | 17 | class Video(File): 18 | pass -------------------------------------------------------------------------------- /vila_u/mm_utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import cv2 3 | import numpy as np 4 | import os 5 | import re 6 | import torch 7 | import tempfile 8 | 9 | from io import BytesIO 10 | from PIL import Image 11 | from torchvision.transforms import CenterCrop 12 | from transformers import StoppingCriteria 13 | 14 | from vila_u.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 15 | 16 | 17 | def get_frame_from_vcap(vidcap, 
num_frames=10, fps=None, frame_count=None): 18 | if fps == None or frame_count == None: 19 | fps = vidcap.get(cv2.CAP_PROP_FPS) 20 | frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 21 | 22 | if fps == 0 or frame_count == 0: 23 | print("Video file not found. return empty images.") 24 | return [ 25 | Image.new("RGB", (720, 720)), 26 | ] * num_frames 27 | 28 | duration = frame_count / fps 29 | frame_interval = frame_count // num_frames 30 | 31 | if frame_interval == 0 and frame_count <= 1: 32 | print("frame_interval is equal to 0. return empty image.") 33 | return [ 34 | Image.new("RGB", (720, 720)), 35 | ] * num_frames 36 | 37 | images = [] 38 | count = 0 39 | success = True 40 | frame_indices = np.linspace(0, frame_count - 2, num_frames, dtype=int) 41 | 42 | while success: 43 | if frame_count >= num_frames: 44 | success, frame = vidcap.read() 45 | if success: 46 | if count in frame_indices: 47 | img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 48 | im_pil = Image.fromarray(img) 49 | images.append(im_pil) 50 | if len(images) >= num_frames: 51 | return images 52 | count += 1 53 | else: 54 | break 55 | else: 56 | success, frame = vidcap.read() 57 | if success: 58 | img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 59 | im_pil = Image.fromarray(img) 60 | images.append(im_pil) 61 | count += 1 62 | elif count >= 1: 63 | width, height = images[-1].size 64 | images = [Image.new("RGB", (width, height))] * (num_frames - len(images)) + images 65 | print("padding frames:", (num_frames - len(images))) 66 | return images 67 | else: 68 | break 69 | 70 | print("fail") 71 | images = [Image.new("RGB", (720, 720))] * num_frames 72 | 73 | return images 74 | 75 | 76 | def opencv_extract_frames(vpath_or_bytesio, frames=6, fps=None, frame_count=None): 77 | """ 78 | Extract frames from a video using OpenCV. 79 | 80 | Args: 81 | vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video. 82 | frames (int): Number of frames to extract from the video. 83 | 84 | Returns: 85 | list: List of PIL Images extracted from the video. 86 | 87 | Raises: 88 | NotImplementedError: If the type of `vpath_or_bytesio` is not supported. 89 | """ 90 | 91 | if isinstance(vpath_or_bytesio, str): 92 | vidcap = cv2.VideoCapture(vpath_or_bytesio) 93 | return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count) 94 | elif isinstance(vpath_or_bytesio, (BytesIO,)): 95 | with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video: 96 | temp_video.write(vpath_or_bytesio.read()) 97 | temp_video_name = temp_video.name 98 | vidcap = cv2.VideoCapture(temp_video_name) 99 | return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count) 100 | else: 101 | raise NotImplementedError(type(vpath_or_bytesio)) 102 | 103 | 104 | def load_image_from_base64(image): 105 | return Image.open(BytesIO(base64.b64decode(image))) 106 | 107 | 108 | def expand2square(pil_img, background_color): 109 | """ 110 | Expand the given PIL image to a square shape by adding padding. 111 | 112 | Parameters: 113 | - pil_img: The PIL image to be expanded. 114 | - background_color: The color of the padding to be added. 115 | 116 | Returns: 117 | - The expanded PIL image. 118 | 119 | If the image is already square, it is returned as is. 120 | If the image is wider than it is tall, padding is added to the top and bottom. 121 | If the image is taller than it is wide, padding is added to the left and right. 
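    For example (illustrative): a 640x480 image padded with the processor's mean
    color becomes 640x640, with the original pasted 80 pixels from the top so the
    padding is split evenly between top and bottom.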
122 | """ 123 | width, height = pil_img.size 124 | if pil_img.mode == 'L': 125 | background_color = background_color[0] 126 | if width == height: 127 | return pil_img 128 | elif width > height: 129 | result = Image.new(pil_img.mode, (width, width), background_color) 130 | result.paste(pil_img, (0, (width - height) // 2)) 131 | return result 132 | else: 133 | result = Image.new(pil_img.mode, (height, height), background_color) 134 | result.paste(pil_img, ((height - width) // 2, 0)) 135 | return result 136 | 137 | 138 | def process_image(image_file, data_args, image_folder, generation_mode=False): 139 | processor = data_args.image_processor 140 | if isinstance(image_file, str): 141 | if image_folder is not None: 142 | image = Image.open(os.path.join(image_folder, image_file)).convert("RGB") 143 | else: 144 | image = Image.open(image_file).convert("RGB") 145 | elif isinstance(image_file, BytesIO): 146 | image = Image.open(image_file).convert("RGB") 147 | else: 148 | image = image_file 149 | 150 | if generation_mode: 151 | if image.size[0] < image.size[1]: 152 | image = image.crop((0, 0, min(image.size), min(image.size))) 153 | else: 154 | ccrop = CenterCrop(min(image.size)) 155 | image = ccrop(image) 156 | elif data_args.image_aspect_ratio == "resize": 157 | if hasattr(data_args.image_processor, "crop_size"): 158 | crop_size = data_args.image_processor.crop_size 159 | else: 160 | assert hasattr(data_args.image_processor, "size") 161 | crop_size = data_args.image_processor.size 162 | image = image.resize((crop_size["height"], crop_size["width"])) 163 | elif data_args.image_aspect_ratio == "pad": 164 | image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean)) 165 | else: 166 | raise NotImplementedError() 167 | 168 | image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] 169 | 170 | return image 171 | 172 | 173 | def process_images(images, image_processor, model_cfg): 174 | 175 | model_cfg.image_processor = image_processor 176 | new_images = [process_image(image, model_cfg, None) for image in images] 177 | 178 | if all(x.shape == new_images[0].shape for x in new_images): 179 | new_images = torch.stack(new_images, dim=0) 180 | return new_images 181 | 182 | 183 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 184 | prompt_chunks = re.split(f"({DEFAULT_IMAGE_TOKEN})", prompt) 185 | input_ids = [tokenizer.bos_token_id] 186 | for chunk in prompt_chunks: 187 | if chunk == DEFAULT_IMAGE_TOKEN: 188 | input_ids.append(image_token_index) 189 | else: 190 | input_ids.extend(tokenizer(chunk).input_ids[1:]) 191 | 192 | if return_tensors is not None: 193 | if return_tensors == "pt": 194 | return torch.tensor(input_ids, dtype=torch.long) 195 | raise ValueError(f"Unsupported tensor type: {return_tensors}") 196 | 197 | return input_ids 198 | 199 | 200 | def get_model_name_from_path(model_path): 201 | model_path = model_path.strip("/") 202 | model_paths = model_path.split("/") 203 | if model_paths[-1].startswith("checkpoint-"): 204 | return model_paths[-2] + "_" + model_paths[-1] 205 | else: 206 | return model_paths[-1] 207 | 208 | 209 | class KeywordsStoppingCriteria(StoppingCriteria): 210 | def __init__(self, keywords, tokenizer, input_ids): 211 | self.keywords = keywords 212 | self.keyword_ids = [] 213 | self.max_keyword_len = 0 214 | for keyword in keywords: 215 | cur_keyword_ids = tokenizer(keyword).input_ids 216 | if ( 217 | len(cur_keyword_ids) > 1 218 | and cur_keyword_ids[0] == tokenizer.bos_token_id 219 | 
): 220 | cur_keyword_ids = cur_keyword_ids[1:] 221 | if len(cur_keyword_ids) > self.max_keyword_len: 222 | self.max_keyword_len = len(cur_keyword_ids) 223 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 224 | self.tokenizer = tokenizer 225 | self.start_len = input_ids.shape[1] 226 | 227 | def call_for_batch( 228 | self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs 229 | ) -> bool: 230 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 231 | self.keyword_ids = [ 232 | keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids 233 | ] 234 | for keyword_id in self.keyword_ids: 235 | if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all(): 236 | return True 237 | outputs = self.tokenizer.batch_decode( 238 | output_ids[:, -offset:], skip_special_tokens=True 239 | )[0] 240 | for keyword in self.keywords: 241 | if keyword in outputs: 242 | return True 243 | return False 244 | 245 | def __call__( 246 | self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs 247 | ) -> bool: 248 | outputs = [] 249 | for i in range(output_ids.shape[0]): 250 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 251 | return all(outputs) 252 | -------------------------------------------------------------------------------- /vila_u/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.vila_u_llama import VILAULlamaModel, VILAULlamaConfig -------------------------------------------------------------------------------- /vila_u/model/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoConfig 3 | 4 | from vila_u.model import VILAULlamaModel 5 | from vila_u.constants import ( 6 | DEFAULT_IMAGE_PATCH_TOKEN, 7 | DEFAULT_IM_START_TOKEN, 8 | DEFAULT_IM_END_TOKEN, 9 | DEFAULT_VI_START_TOKEN, 10 | DEFAULT_VI_END_TOKEN, 11 | ) 12 | 13 | 14 | def load_pretrained_model( 15 | model_path, 16 | model_dtype=torch.bfloat16, 17 | device_map="auto", 18 | device="cuda", 19 | **kwargs, 20 | ): 21 | kwargs = {"device_map": device_map, **kwargs} 22 | 23 | if device != "cuda": 24 | kwargs["device_map"] = {"": device} 25 | 26 | config = AutoConfig.from_pretrained(model_path) 27 | config.resume_path = model_path 28 | config.model_dtype = model_dtype.__str__() 29 | 30 | model = VILAULlamaModel( 31 | config=config, 32 | low_cpu_mem_usage=True, 33 | **kwargs 34 | ) 35 | tokenizer = model.tokenizer 36 | 37 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 38 | mm_use_vi_start_end = getattr(model.config, "mm_use_vi_start_end", False) 39 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 40 | if mm_use_im_patch_token: 41 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 42 | if mm_use_im_start_end and mm_use_vi_start_end: 43 | tokenizer.add_tokens( 44 | [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VI_START_TOKEN, DEFAULT_VI_END_TOKEN], special_tokens=True 45 | ) 46 | elif mm_use_im_start_end: 47 | tokenizer.add_tokens( 48 | [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True 49 | ) 50 | model.resize_token_embeddings(len(tokenizer)) 51 | model.eval() 52 | 53 | vision_tower = model.get_vision_tower() 54 | vision_tower.to(device=device, dtype=model_dtype) 55 | 56 | mm_projector = model.get_mm_projector() 57 | mm_projector.to(device=device, dtype=model_dtype) 58 | 59 | image_processor = 
vision_tower.image_processor 60 | 61 | if hasattr(model.llm.config, "max_sequence_length"): 62 | context_len = model.config.max_sequence_length 63 | else: 64 | context_len = 2048 65 | 66 | return tokenizer, model, image_processor, context_len -------------------------------------------------------------------------------- /vila_u/model/configuration_vila_u.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | 4 | class VILAUConfig(PretrainedConfig): 5 | model_type = "vila_u" 6 | 7 | def __init__( 8 | self, 9 | llm_cfg=None, 10 | vision_tower_cfg=None, 11 | mm_projector_cfg=None, 12 | architectures=None, 13 | resume_path=None, 14 | hidden_size=None, 15 | mm_hidden_size=None, 16 | image_aspect_ratio=None, 17 | num_video_frames=None, 18 | mm_use_im_start_end=False, 19 | mm_use_vi_start_end=False, 20 | mm_use_im_patch_token=True, 21 | **kwargs 22 | ): 23 | super().__init__() 24 | 25 | self.llm_cfg = llm_cfg 26 | self.vision_tower_cfg = vision_tower_cfg 27 | self.mm_projector_cfg = mm_projector_cfg 28 | self.architectures = architectures 29 | self.resume_path = resume_path 30 | self.hidden_size = hidden_size 31 | self.mm_hidden_size = mm_hidden_size 32 | self.image_aspect_ratio = image_aspect_ratio 33 | self.num_video_frames = num_video_frames 34 | self.mm_use_im_start_end = mm_use_im_start_end 35 | self.mm_use_vi_start_end = mm_use_vi_start_end 36 | self.mm_use_im_patch_token = mm_use_im_patch_token -------------------------------------------------------------------------------- /vila_u/model/language_model/builder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from transformers import ( 5 | AutoTokenizer, 6 | AutoModelForCausalLM, 7 | AutoConfig, 8 | PretrainedConfig, 9 | PreTrainedModel, 10 | ) 11 | 12 | 13 | def context_length_extension(config): 14 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 15 | model_max_length = getattr(config, "model_max_length", None) 16 | if orig_ctx_len and model_max_length > orig_ctx_len: 17 | print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}") 18 | scaling_factor = float(math.ceil(model_max_length / orig_ctx_len)) 19 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 20 | return config 21 | 22 | 23 | def build_llm_and_tokenizer( 24 | model_name_or_path: str, 25 | config: PretrainedConfig, 26 | attn_implementation=None, 27 | model_max_length=None, 28 | *args, 29 | **kwargs, 30 | ) -> PreTrainedModel: 31 | llm_cfg = AutoConfig.from_pretrained(model_name_or_path) 32 | llm_cfg._attn_implementation = attn_implementation 33 | llm_cfg.model_max_length = model_max_length 34 | if model_max_length is not None: 35 | context_length_extension(llm_cfg) 36 | 37 | llm = AutoModelForCausalLM.from_pretrained( 38 | model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs 39 | ) 40 | 41 | tokenizer = AutoTokenizer.from_pretrained( 42 | model_name_or_path, 43 | model_max_length=llm_cfg.model_max_length, 44 | padding_side="right", 45 | use_fast=False, 46 | legacy=False, 47 | ) 48 | 49 | config.hidden_size = llm.config.hidden_size 50 | return llm, tokenizer -------------------------------------------------------------------------------- /vila_u/model/language_model/vila_u_llama.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from torch.nn import CrossEntropyLoss 5 | 
from typing import List, Optional, Tuple, Union 6 | from transformers import ( 7 | AutoConfig, 8 | AutoModel, 9 | PreTrainedModel, 10 | PretrainedConfig, 11 | ) 12 | from transformers.modeling_outputs import CausalLMOutputWithPast 13 | 14 | from ..configuration_vila_u import VILAUConfig 15 | from ..vila_u_arch import VILAUMetaModel, VILAUMetaForCausalLM 16 | 17 | 18 | class VILAULlamaConfig(VILAUConfig): 19 | model_type = "vila_u_llama" 20 | 21 | 22 | class VILAULlamaModel(VILAUMetaModel, VILAUMetaForCausalLM, PreTrainedModel): 23 | config_class = VILAULlamaConfig 24 | main_input_name = "input_embeds" 25 | supports_gradient_checkpointing = True 26 | 27 | def __init__(self, config: VILAULlamaConfig = None, *args, **kwargs) -> None: 28 | super().__init__(config) 29 | 30 | return self.init_vlm(config=config, *args, **kwargs) 31 | 32 | @classmethod 33 | def from_pretrained( 34 | cls, 35 | pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], 36 | *model_args, 37 | config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, 38 | cache_dir: Optional[Union[str, os.PathLike]] = None, 39 | ignore_mismatched_sizes: bool = False, 40 | force_download: bool = False, 41 | local_files_only: bool = False, 42 | token: Optional[Union[str, bool]] = None, 43 | revision: str = "main", 44 | use_safetensors: bool = None, 45 | **kwargs, 46 | ): 47 | if hasattr(cls, "load_pretrained"): 48 | return cls.load_pretrained( 49 | pretrained_model_name_or_path, 50 | *model_args, 51 | config=config, 52 | cache_dir=cache_dir, 53 | ignore_mismatched_sizes=ignore_mismatched_sizes, 54 | force_download=force_download, 55 | local_files_only=local_files_only, 56 | token=token, 57 | revision=revision, 58 | use_safetensors=use_safetensors, 59 | **kwargs, 60 | ) 61 | 62 | return super(VILAULlamaModel).from_pretrained( 63 | pretrained_model_name_or_path, 64 | *model_args, 65 | config=config, 66 | cache_dir=cache_dir, 67 | ignore_mismatched_sizes=ignore_mismatched_sizes, 68 | force_download=force_download, 69 | local_files_only=local_files_only, 70 | token=token, 71 | revision=revision, 72 | use_safetensors=use_safetensors, 73 | **kwargs, 74 | ) 75 | 76 | def forward( 77 | self, 78 | input_ids: torch.LongTensor = None, 79 | images: Optional[torch.FloatTensor] = None, 80 | attention_mask: Optional[torch.Tensor] = None, 81 | position_ids: Optional[torch.LongTensor] = None, 82 | past_key_values: Optional[List[torch.FloatTensor]] = None, 83 | inputs_embeds: Optional[torch.FloatTensor] = None, 84 | labels: Optional[torch.LongTensor] = None, 85 | use_cache: Optional[bool] = None, 86 | output_attentions: Optional[bool] = None, 87 | output_hidden_states: Optional[bool] = None, 88 | return_dict: Optional[bool] = None, 89 | ) -> Union[Tuple, CausalLMOutputWithPast]: 90 | if inputs_embeds is None: 91 | ( 92 | input_ids, 93 | position_ids, 94 | attention_mask, 95 | past_key_values, 96 | inputs_embeds, 97 | labels, 98 | ) = self.prepare_inputs_labels_for_multimodal( 99 | input_ids, 100 | position_ids, 101 | attention_mask, 102 | past_key_values, 103 | labels, 104 | images, 105 | ) 106 | 107 | if self.training: 108 | ( 109 | _, 110 | new_position_ids, 111 | new_attention_mask, 112 | _, 113 | new_inputs_embeds, 114 | new_labels, 115 | sorted_seqlens_in_batch, 116 | ) = self.repack_multimodal_data( 117 | input_ids, 118 | position_ids, 119 | attention_mask, 120 | past_key_values, 121 | inputs_embeds, 122 | labels, 123 | ) 124 | new_input_ids = None 125 | past_key_values = None 126 | else: 127 | new_attention_mask = attention_mask 
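            # At inference time the multimodal-prepared tensors are reused as-is;
            # the sequence repacking above only happens during training.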
128 | new_position_ids = position_ids 129 | new_inputs_embeds = inputs_embeds 130 | new_labels = labels 131 | sorted_seqlens_in_batch = attention_mask.sum(-1).int() 132 | new_input_ids = input_ids 133 | 134 | output_attentions = output_attentions if output_attentions is not None else self.llm.config.output_attentions 135 | output_hidden_states = ( 136 | output_hidden_states if output_hidden_states is not None else self.llm.config.output_hidden_states 137 | ) 138 | return_dict = return_dict if return_dict is not None else self.llm.config.use_return_dict 139 | 140 | outputs = self.llm.model( 141 | input_ids=new_input_ids, 142 | attention_mask=new_attention_mask, 143 | position_ids=new_position_ids, 144 | past_key_values=past_key_values, 145 | inputs_embeds=new_inputs_embeds, 146 | use_cache=use_cache, 147 | output_attentions=output_attentions, 148 | output_hidden_states=output_hidden_states, 149 | return_dict=return_dict, 150 | seqlens_in_batch=sorted_seqlens_in_batch, 151 | ) 152 | 153 | hidden_states = outputs[0] 154 | 155 | image_hidden_states = [] 156 | image_labels = [] 157 | noimage_labels = [] 158 | 159 | for i in range(hidden_states.shape[0]): 160 | label = new_labels[i] 161 | hidden_state = hidden_states[i] 162 | label_zero = label[:, 0].clone() 163 | 164 | if self.config.mm_use_vi_start_end: 165 | image_start_index = torch.nonzero(torch.eq(label_zero, self.llm.vocab_size - 4)).squeeze(1) 166 | image_end_index = torch.nonzero(torch.eq(label_zero, self.llm.vocab_size - 3)).squeeze(1) 167 | video_start_index = torch.nonzero(torch.eq(label_zero, self.llm.vocab_size - 2)).squeeze(1) 168 | video_end_index = torch.nonzero(torch.eq(label_zero, self.llm.vocab_size - 1)).squeeze(1) 169 | image_start_index = torch.cat([image_start_index, video_start_index]) 170 | image_end_index = torch.cat([image_end_index, video_end_index]) 171 | else: 172 | image_start_index = torch.nonzero(torch.eq(label_zero, self.llm.vocab_size - 2)).squeeze(1) 173 | image_end_index = torch.nonzero(torch.eq(label_zero, self.llm.vocab_size - 1)).squeeze(1) 174 | 175 | assert len(image_start_index) == len(image_end_index), f"length of image_start_index is {len(image_start_index)}, length of image_end_index is {len(image_end_index)}" 176 | 177 | if len(image_start_index) > 0: 178 | for start_idx, end_idx in zip(image_start_index, image_end_index): 179 | image_label = label[start_idx+1:end_idx, :] 180 | image_labels.append(image_label) 181 | image_hidden_state = hidden_state[start_idx:end_idx-1, :] 182 | image_hidden_states.append(image_hidden_state) 183 | label_zero[start_idx+1:end_idx] = -100 184 | 185 | noimage_labels.append(label_zero) 186 | 187 | # For video 188 | image_hidden_states_aux = [] 189 | image_labels_aux = [] 190 | image_hidden_states_length = [img.shape[0] for img in image_hidden_states] 191 | image_hidden_states_length_relative = [img // min(image_hidden_states_length) for img in image_hidden_states_length] 192 | for l in range(len(image_hidden_states_length_relative)): 193 | if image_hidden_states_length_relative[l] > 1: 194 | image_hidden_states_aux += torch.split(image_hidden_states[l], min(image_hidden_states_length), dim=0) 195 | image_labels_aux += torch.split(image_labels[l], min(image_hidden_states_length), dim=0) 196 | else: 197 | image_hidden_states_aux.append(image_hidden_states[l]) 198 | image_labels_aux.append(image_labels[l]) 199 | 200 | if len(image_hidden_states_aux) > 0: 201 | image_hidden_states = torch.stack(image_hidden_states_aux, 0) 202 | image_labels = torch.stack(image_labels_aux, 
0) 203 | 204 | noimage_labels = torch.stack(noimage_labels, 0) 205 | 206 | logits = self.llm.lm_head(hidden_states) 207 | 208 | loss_fct = CrossEntropyLoss() 209 | 210 | image_loss = None 211 | if torch.is_tensor(image_hidden_states): 212 | if hasattr(self.vision_tower.vision_tower, "rqvaesiglip"): 213 | outs = self.vision_tower.vision_tower.rqtransformer(image_hidden_states, image_labels - self.llm.vocab_size, self.vision_tower.vision_tower.rqvaesiglip) 214 | else: 215 | raise NotImplementedError() 216 | B, seq_len, D, C = outs.shape 217 | image_logits = outs.reshape(B*seq_len*D, C).contiguous() 218 | image_labels = image_labels.reshape(B*seq_len*D).contiguous() - self.llm.vocab_size 219 | image_loss = loss_fct(image_logits, image_labels) 220 | 221 | loss = None 222 | shift_logits = logits[..., :-1, :].contiguous() 223 | shift_labels = noimage_labels[..., 1:].contiguous() 224 | shift_logits = shift_logits.view(-1, self.llm.config.vocab_size) 225 | shift_labels = shift_labels.view(-1) 226 | shift_labels = shift_labels.to(shift_logits.device) 227 | loss = loss_fct(shift_logits, shift_labels) 228 | 229 | if image_loss is not None: 230 | loss = loss + image_loss 231 | 232 | return CausalLMOutputWithPast( 233 | loss=loss, 234 | logits=logits, 235 | past_key_values=outputs.past_key_values, 236 | hidden_states=outputs.hidden_states, 237 | attentions=outputs.attentions, 238 | ) 239 | 240 | 241 | AutoConfig.register("vila_u_llama", VILAULlamaConfig) 242 | AutoModel.register(VILAULlamaConfig, VILAULlamaModel) -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from transformers import AutoConfig, PretrainedConfig, PreTrainedModel 4 | 5 | from .rqvaesigliptransformer_encoder import RQVAESIGLIPTransformerVisionTower 6 | 7 | 8 | def build_vision_tower( 9 | model_name_or_path: str, config: PretrainedConfig 10 | ) -> PreTrainedModel: 11 | if model_name_or_path is None: 12 | return None 13 | 14 | vision_tower_arch = None 15 | if config.resume_path: 16 | assert os.path.exists( 17 | model_name_or_path 18 | ), f"Resume vision tower path {model_name_or_path} does not exist!" 
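        # When resuming, read the vision tower config from the checkpoint directory
        # so its architecture name can be recovered below.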
19 | vision_tower_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) 20 | vision_tower_arch = vision_tower_cfg.architectures[0].lower() 21 | vision_tower_name = ( 22 | vision_tower_arch if vision_tower_arch is not None else model_name_or_path 23 | ) 24 | 25 | vision_tower = RQVAESIGLIPTransformerVisionTower(model_name_or_path, config) 26 | 27 | config.mm_hidden_size = vision_tower.config.hidden_size 28 | 29 | return vision_tower -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_rqvaesigliptransformer import RQVAESIGLIPTransformerConfig 2 | from .modeling_rqvaesigliptransformer import RQVAESIGLIPTransformer -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer/configuration_rqvaesigliptransformer.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | 4 | class RQVAESIGLIPTransformerConfig(PretrainedConfig): 5 | model_type = "rqvaesigliptransformer_model" 6 | def __init__( 7 | self, 8 | rqvaesiglip=None, 9 | rqtransformer=None, 10 | hidden_size=None, 11 | architectures=None, 12 | **kwargs, 13 | ): 14 | super().__init__() 15 | 16 | self.rqvaesiglip = rqvaesiglip 17 | self.rqtransformer = rqtransformer 18 | self.hidden_size = hidden_size 19 | self.architectures = architectures -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer/modeling_rqvaesigliptransformer.py: -------------------------------------------------------------------------------- 1 | from transformers import PreTrainedModel, AutoConfig, AutoModel 2 | 3 | from .configuration_rqvaesigliptransformer import RQVAESIGLIPTransformerConfig 4 | from .rqvaesiglip import RQVAESiglipModel 5 | from .rqtransformer import RQTransformer 6 | 7 | 8 | class RQVAESIGLIPTransformer(PreTrainedModel): 9 | config_class = RQVAESIGLIPTransformerConfig 10 | def __init__(self, config: RQVAESIGLIPTransformerConfig): 11 | super().__init__(config) 12 | 13 | rqvaesiglip_config = RQVAESiglipModel.config_class.from_dict(config.rqvaesiglip) 14 | rqtransformer_config = RQTransformer.config_class.from_dict(config.rqtransformer) 15 | 16 | self.rqvaesiglip = RQVAESiglipModel._from_config(rqvaesiglip_config) 17 | self.rqtransformer = RQTransformer._from_config(rqtransformer_config) 18 | 19 | 20 | AutoConfig.register("rqvaesigliptransformer_model", RQVAESIGLIPTransformerConfig) 21 | AutoModel.register(RQVAESIGLIPTransformerConfig, RQVAESIGLIPTransformer) -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqtransformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_rqtransformer import RQTransformer -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqtransformer/attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqtransformer/attentions.py. 
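
This module provides the masked multi-head self-attention stack used as the depth
transformer head in RQTransformer, with optional key/value caching for
autoregressive decoding.

Illustrative usage (not part of the original file; small config values are chosen
for brevity):

    import torch
    from .configuration_rqtransformer import AttentionBlockConfig, AttentionStackConfig

    block_cfg = AttentionBlockConfig(embed_dim=256, n_head=8)
    stack = AttentionStack(AttentionStackConfig(n_layer=2, block=block_cfg))
    x = torch.randn(4, 16, 256)   # (batch, sequence, embed_dim)
    y = stack(x)                  # same shape: (4, 16, 256)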
3 | """ 4 | 5 | import math 6 | import torch 7 | 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from typing import Iterable 11 | 12 | from .configuration_rqtransformer import AttentionBlockConfig, AttentionStackConfig 13 | 14 | 15 | class MultiSelfAttention(nn.Module): 16 | """ 17 | Optimized by batched matmul operations 18 | """ 19 | 20 | def __init__(self, config: AttentionBlockConfig, mask=True): 21 | super().__init__() 22 | assert config.embed_dim % config.n_head == 0 23 | 24 | self.key = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias) 25 | self.query = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias) 26 | self.value = nn.Linear(config.embed_dim, config.embed_dim, bias=config.attn_bias) 27 | 28 | self.attn_drop = nn.Dropout(config.attn_pdrop, inplace=False) 29 | self.resid_drop = nn.Dropout(config.resid_pdrop, inplace=True) 30 | 31 | self.proj = nn.Linear(config.embed_dim, config.embed_dim, config.attn_bias) 32 | 33 | self.n_head = config.n_head 34 | self.mask = mask 35 | 36 | def forward(self, x, caching=False, past_kv=None): 37 | (B, T, C) = x.shape 38 | 39 | if not caching: 40 | assert past_kv is None 41 | 42 | x = x.transpose(0, 1).contiguous() 43 | 44 | k = self.key(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1) 45 | q = self.query(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1) 46 | v = self.value(x).view(T, B*self.n_head, C//self.n_head).transpose(0, 1) 47 | 48 | if past_kv is not None: 49 | past_key, past_value = past_kv 50 | k = torch.cat([past_key, k], dim=-2) 51 | v = torch.cat([past_value, v], dim=-2) 52 | T_past = past_key.shape[1] 53 | else: 54 | T_past = 0 55 | 56 | if caching: 57 | present = torch.stack([k, v]) 58 | else: 59 | present = None 60 | 61 | att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))) 62 | if self.mask: 63 | mask = torch.tril(torch.ones(T_past+T, T_past+T, device=x.device, dtype=torch.bool)) 64 | mask = mask.view(1, T_past+T, T_past+T) 65 | att = att.masked_fill(~mask[:, T_past:T_past+T, :T_past+T], float('-inf')) 66 | att = F.softmax(att, dim=-1) 67 | att = self.attn_drop(att) 68 | 69 | y = torch.bmm(att, v) 70 | y = y.transpose(0, 1).contiguous().view(T, B, C) 71 | 72 | y = self.resid_drop(self.proj(y)) 73 | 74 | if caching: 75 | return y.transpose(0, 1).contiguous(), present 76 | else: 77 | return y.transpose(0, 1).contiguous() 78 | 79 | 80 | class AttentionBlock(nn.Module): 81 | """ an unassuming Transformer block """ 82 | 83 | def __init__(self, config: AttentionBlockConfig): 84 | super().__init__() 85 | 86 | self.ln1 = nn.LayerNorm(config.embed_dim) 87 | self.ln2 = nn.LayerNorm(config.embed_dim) 88 | 89 | self.attn = MultiSelfAttention(config, mask=True) 90 | self.mlp = nn.Sequential( 91 | nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=config.mlp_bias), 92 | nn.GELU(), 93 | nn.Linear(4 * config.embed_dim, config.embed_dim, bias=config.mlp_bias), 94 | nn.Dropout(config.resid_pdrop, inplace=True), 95 | ) 96 | self._cache = None 97 | 98 | def forward(self, x): 99 | 100 | attn = self.attn(self.ln1(x)) 101 | 102 | x = x + attn 103 | x = x + self.mlp(self.ln2(x)) 104 | 105 | return x 106 | 107 | def cached_forward(self, x_present): 108 | 109 | attn, present = self.attn(self.ln1(x_present), caching=True, past_kv=self._cache['past_kv']) 110 | self._cache['past_kv'] = present 111 | 112 | x_present = x_present + attn 113 | x_present = x_present + self.mlp(self.ln2(x_present)) 114 | 115 | return x_present 116 | 117 | def init_cache(self): 118 
| self._cache = {'past_kv': None} 119 | 120 | 121 | class AttentionStack(nn.Module): 122 | 123 | blocks: Iterable[AttentionBlock] 124 | 125 | def __init__(self, config: AttentionStackConfig): 126 | super().__init__() 127 | 128 | self.blocks = nn.ModuleList([AttentionBlock(config.block) for _ in range(config.n_layer)]) 129 | 130 | def forward(self, x): 131 | for block in self.blocks: 132 | x = block(x) 133 | 134 | return x 135 | 136 | def cached_forward(self, x_present): 137 | for block in self.blocks: 138 | x_present = block.cached_forward(x_present) 139 | 140 | return x_present 141 | 142 | def init_cache(self): 143 | for block in self.blocks: 144 | block.init_cache() -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqtransformer/configuration_rqtransformer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from transformers import PretrainedConfig 3 | 4 | 5 | @dataclass 6 | class AttentionBlockConfig: 7 | embed_dim: int = 2560 8 | n_head: int = 40 9 | mlp_bias: bool = True 10 | attn_bias: bool = True 11 | attn_pdrop: float = 0.0 12 | resid_pdrop: float = 0.1 13 | 14 | 15 | @dataclass 16 | class AttentionStackConfig: 17 | n_layer: int = 6 18 | block: AttentionBlockConfig = AttentionBlockConfig() 19 | 20 | 21 | class RQTransformerConfig(PretrainedConfig): 22 | model_type = "rqtransformer_model" 23 | def __init__( 24 | self, 25 | block_size=None, 26 | input_embed_dim_1=None, 27 | input_embed_dim_2=None, 28 | embed_dim=None, 29 | vocab_size=None, 30 | head=None, 31 | architectures=None, 32 | **kwargs, 33 | ): 34 | super().__init__() 35 | 36 | self.block_size = block_size 37 | self.input_embed_dim_1 = input_embed_dim_1 38 | self.input_embed_dim_2 = input_embed_dim_2 39 | self.embed_dim = embed_dim 40 | self.vocab_size = vocab_size 41 | self.head = head 42 | self.architectures = architectures -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqtransformer/modeling_rqtransformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqtransformer/transformers.py. 
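
Defines RQTransformer, the depth transformer that predicts the D residual codes of
each spatial token conditioned on the hidden state coming from the language model,
together with top-k / top-p sampling helpers and classifier-free-guidance
generation.

Illustrative sampling example (not part of the original file):

    import torch

    logits = torch.randn(4, 1024)   # (n_samples, vocab_size)
    ids = sample_from_logits(logits, temperature=1.0, top_k=50, top_p=0.9)
    # ids is a LongTensor of shape (4,)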
3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from collections import OrderedDict 9 | from torch.nn import functional as F 10 | from transformers import PreTrainedModel, AutoConfig, AutoModel 11 | 12 | from .attention import AttentionStack 13 | from .configuration_rqtransformer import RQTransformerConfig, AttentionStackConfig, AttentionBlockConfig 14 | 15 | 16 | def top_k_logits(logits, k): 17 | v, ix = torch.topk(logits, k) 18 | out = logits.clone() 19 | out[out < v[:, [-1]]] = -float('Inf') 20 | 21 | return out 22 | 23 | 24 | def top_p_probs(probs, p): 25 | sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) 26 | cum_probs = torch.cumsum(sorted_probs, dim=-1) 27 | 28 | sorted_idx_remove_cond = cum_probs >= p 29 | 30 | sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone() 31 | sorted_idx_remove_cond[..., 0] = 0 32 | 33 | indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond) 34 | probs = probs.masked_fill(indices_to_remove, 0.0) 35 | norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True) 36 | 37 | return norm_probs 38 | 39 | 40 | def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None): 41 | """Take a 2-dim tensor, apply softmax along each row, and sample from 42 | each multinomial distribution defined by the rows. 43 | 44 | Args: 45 | logits: 2-dim tensor of shape (n_samples, logit_dim) 46 | temperature (float): softmax temperature 47 | top_k (Optional[int]): if given, sample only using `top_k` logits 48 | top_p (Optional[float]): if given, sample only using `top_p` logits 49 | 50 | Returns: 51 | samples: 1-dim integer tensor of shape (n_samples,) 52 | """ 53 | 54 | logits = logits.to(dtype=torch.float32) 55 | logits = logits / temperature 56 | 57 | if top_k is not None: 58 | logits = top_k_logits(logits, top_k) 59 | 60 | if torch.sum(torch.isnan(logits)): 61 | print('WARNING... 
NaN observed') 62 | logits[torch.isnan(logits)] = -float('Inf') 63 | 64 | probs = F.softmax(logits, dim=-1) 65 | 66 | if top_p is not None: 67 | probs = top_p_probs(probs, top_p) 68 | 69 | try: 70 | samples = torch.multinomial(probs, num_samples=1) 71 | except: 72 | raise RuntimeError 73 | 74 | return samples.view(-1) 75 | 76 | 77 | class RQTransformer(PreTrainedModel): 78 | config_class = RQTransformerConfig 79 | def __init__(self, config: RQTransformerConfig): 80 | super().__init__(config) 81 | self.in_mlp_1 = nn.Linear(config.input_embed_dim_1, config.embed_dim) 82 | self.in_mlp_2 = nn.Linear(config.input_embed_dim_2, config.embed_dim) 83 | 84 | blockconfig = AttentionBlockConfig(embed_dim=config.embed_dim, n_head=config.head["block"]["n_head"]) 85 | stackconfig = AttentionStackConfig(n_layer=config.head["n_layer"], block=blockconfig) 86 | self.head_transformer = AttentionStack(stackconfig) 87 | 88 | self.pos_emb_d = nn.Parameter(torch.zeros(1, config.block_size[2], config.embed_dim)) 89 | self.pos_emb_d.data.normal_(mean=0.0, std=0.02) 90 | 91 | self.classifier_mlp = nn.Sequential(OrderedDict([ 92 | ('layer_norm', nn.LayerNorm(config.embed_dim)), 93 | ('linear', nn.Linear(config.embed_dim, config.vocab_size)), 94 | ])) 95 | 96 | def embed_with_model_aux(self, code, model_aux): 97 | xs_emb, _ = model_aux.get_code_emb_with_depth(code) 98 | return xs_emb 99 | 100 | def forward(self, embed_from_body, code, model_aux=None): 101 | B, seq_len, D = code.shape 102 | 103 | depth_ctx = self.embed_with_model_aux(code, model_aux) 104 | depth_ctx = torch.cumsum(depth_ctx, dim=-2) 105 | depth_ctx = self.in_mlp_1(depth_ctx) 106 | 107 | embed_from_body = self.in_mlp_2(embed_from_body) 108 | 109 | depth_ctx_full = torch.cat( 110 | [ 111 | embed_from_body.view(B, seq_len, 1, -1), 112 | depth_ctx[:, :, :-1, :], 113 | ], 114 | dim=-2, 115 | ) 116 | 117 | depth_ctx_full = depth_ctx_full.reshape(B * seq_len, D, -1) 118 | depth_ctx_full = depth_ctx_full + self.pos_emb_d[:, :D, :] 119 | 120 | head_outputs = self.head_transformer(depth_ctx_full) 121 | head_outputs = head_outputs.reshape(B, seq_len, D, -1) 122 | head_outputs = self.classifier_mlp(head_outputs) 123 | 124 | return head_outputs 125 | 126 | def generate(self, embed_from_body, model_aux=None, cfg=3.0): 127 | generate_idx = 1 128 | B, seq_len, _ = embed_from_body.shape 129 | 130 | embed_from_body = self.in_mlp_2(embed_from_body) 131 | 132 | depth_ctx_full = embed_from_body.view(B, seq_len, 1, -1) 133 | depth_ctx_full = depth_ctx_full.reshape(B * seq_len, generate_idx, -1) 134 | depth_ctx_full = depth_ctx_full + self.pos_emb_d[:, :generate_idx, :] 135 | 136 | head_outputs = self.head_transformer(depth_ctx_full) 137 | head_outputs = head_outputs.reshape(B, -1) 138 | 139 | logits = self.classifier_mlp(head_outputs) 140 | 141 | logits = logits[B//2:, :] + cfg * (logits[:B//2, :] - logits[B//2:, :]) 142 | code = sample_from_logits(logits, temperature=1.0, top_p=0.96, top_k=900) 143 | code = code.reshape(B//2, seq_len, 1).repeat(2, 1, self.pos_emb_d.shape[1]) 144 | 145 | for i in range(self.pos_emb_d.shape[1]-1): 146 | generate_idx += 1 147 | depth_ctx = self.embed_with_model_aux(code, model_aux) 148 | depth_ctx = torch.cumsum(depth_ctx, dim=-2)[:, :, :i+1, :] 149 | if len(depth_ctx.shape) == 3: 150 | depth_ctx = depth_ctx.unsqueeze(2) 151 | depth_ctx = self.in_mlp_1(depth_ctx) 152 | 153 | depth_ctx_full = torch.cat( 154 | [ 155 | embed_from_body.view(B, seq_len, 1, -1), 156 | depth_ctx, 157 | ], 158 | dim=-2, 159 | ) 160 | 161 | depth_ctx_full = 
depth_ctx_full.reshape(B * seq_len, generate_idx, -1) 162 | depth_ctx_full = depth_ctx_full + self.pos_emb_d[:, :generate_idx, :] 163 | 164 | head_outputs = self.head_transformer(depth_ctx_full) 165 | head_outputs = head_outputs[:, -1, :] 166 | 167 | logits = self.classifier_mlp(head_outputs) 168 | 169 | logits = logits[B//2:, :] + cfg * (logits[:B//2, :] - logits[B//2:, :]) 170 | code_generate = sample_from_logits(logits, temperature=1.0, top_p=0.96, top_k=900) 171 | code_generate = code_generate.reshape(B//2, seq_len).repeat(2, 1) 172 | code[:, :, i+1] = code_generate 173 | 174 | out_features = self.embed_with_model_aux(code, model_aux) 175 | out_features = torch.cumsum(out_features, dim=-2)[:, :, -1, :] 176 | 177 | return out_features, code 178 | 179 | 180 | AutoConfig.register("rqtransformer_model", RQTransformerConfig) 181 | AutoModel.register(RQTransformerConfig, RQTransformer) -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqvaesiglip/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_rqvaesiglip import RQVAESiglipModel -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqvaesiglip/configuration_rqvaesiglip.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | 4 | class RQVAESiglipConfig(PretrainedConfig): 5 | model_type = "rqvaesiglip_model" 6 | def __init__( 7 | self, 8 | embed_dim=None, 9 | n_embed=None, 10 | latent_shape=None, 11 | code_shape=None, 12 | shared_codebook=None, 13 | restart_unused_codes=None, 14 | ddconfig=None, 15 | decay=0.99, 16 | latent_loss_weight=0.25, 17 | architectures=None, 18 | decoder_latent_shape=None, 19 | pretrained_model="google/siglip-large-patch16-256", 20 | **kwargs, 21 | ): 22 | super().__init__() 23 | 24 | self.embed_dim = embed_dim 25 | self.n_embed = n_embed 26 | self.latent_shape = latent_shape 27 | self.code_shape = code_shape 28 | self.shared_codebook = shared_codebook 29 | self.restart_unused_codes = restart_unused_codes 30 | self.ddconfig = ddconfig 31 | self.decay = decay 32 | self.latent_loss_weight = latent_loss_weight 33 | self.architectures = architectures 34 | self.decoder_latent_shape = decoder_latent_shape 35 | self.pretrained_model = pretrained_model -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqvaesiglip/modeling_rqvaesiglip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from transformers import PreTrainedModel, AutoConfig, AutoModel 6 | from typing import Optional 7 | 8 | from .configuration_rqvaesiglip import RQVAESiglipConfig 9 | from .modules import Decoder 10 | from .quantizations import RQBottleneck 11 | from .siglip import SiglipModel 12 | 13 | 14 | class RQVAESiglipModel(PreTrainedModel): 15 | config_class = RQVAESiglipConfig 16 | def __init__(self, config: RQVAESiglipConfig): 17 | super().__init__(config) 18 | 19 | siglip_config = SiglipModel.config_class.from_pretrained(config.pretrained_model) 20 | self.siglip_model = SiglipModel._from_config(siglip_config) 21 | 22 | self.quantizer = RQBottleneck( 23 | latent_shape=config.latent_shape, 24 | code_shape=config.code_shape, 25 | 
n_embed=config.n_embed, 26 | decay=config.decay, 27 | shared_codebook=config.shared_codebook, 28 | restart_unused_codes=config.restart_unused_codes, 29 | ) 30 | self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.ddconfig["z_channels"], 1) 31 | 32 | self.decoder = Decoder(**config.ddconfig) 33 | 34 | try: 35 | self.decoder_latent_shape = config.decoder_latent_shape 36 | except: 37 | self.decoder_latent_shape = None 38 | 39 | self.logit_scale = self.siglip_model.logit_scale 40 | self.logit_bias = self.siglip_model.logit_bias 41 | 42 | def encode_image(self, image): 43 | vision_model = self.siglip_model.vision_model 44 | hidden_states = vision_model.embeddings(image) 45 | 46 | attention_mask = None 47 | output_attentions = None 48 | for i, encoder_layer in enumerate(vision_model.encoder.layers): 49 | if vision_model.encoder.gradient_checkpointing and vision_model.encoder.training: 50 | layer_outputs = vision_model.encoder._gradient_checkpointing_func( 51 | encoder_layer.__call__, 52 | hidden_states, 53 | attention_mask, 54 | output_attentions, 55 | ) 56 | else: 57 | layer_outputs = encoder_layer( 58 | hidden_states, 59 | attention_mask, 60 | output_attentions=output_attentions, 61 | ) 62 | hidden_states = layer_outputs[0] 63 | if i == len(vision_model.encoder.layers) - 2: 64 | B, L, C = hidden_states.shape 65 | hidden_states = hidden_states.reshape(B, int(L**0.5), int(L**0.5), C) 66 | z_q, quant_loss, code = self.quantizer(hidden_states) 67 | 68 | return code, z_q 69 | 70 | def decode(self, z_q): 71 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 72 | 73 | if self.decoder_latent_shape is not None: 74 | z_q = F.interpolate(z_q.to(torch.float32), size=tuple(self.decoder_latent_shape), mode='bilinear').to(torch.bfloat16) 75 | 76 | z_q = self.post_quant_conv(z_q) 77 | out = self.decoder(z_q) 78 | 79 | return out 80 | 81 | @torch.no_grad() 82 | def get_code_emb_with_depth(self, code): 83 | return self.quantizer.embed_code_with_depth(code) 84 | 85 | 86 | AutoConfig.register("rqvaesiglip_model", RQVAESiglipConfig) 87 | AutoModel.register(RQVAESiglipConfig, RQVAESiglipModel) -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqvaesiglip/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/modules.py. 
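
Contains the convolutional decoder used by RQVAESiglipModel to map quantized
SigLIP features back to pixels: ResNet blocks, single-head spatial attention
blocks, and nearest-neighbor upsampling. Note that Upsample casts its output to
bfloat16, which assumes the surrounding decoder weights are also kept in bfloat16.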
3 | """ 4 | 5 | import torch 6 | 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.utils.checkpoint import checkpoint 10 | 11 | 12 | def nonlinearity(x): 13 | return F.silu(x, inplace=True) 14 | 15 | 16 | def Normalize(in_channels): 17 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 18 | 19 | 20 | class Upsample(nn.Module): 21 | def __init__(self, in_channels, with_conv): 22 | super().__init__() 23 | self.with_conv = with_conv 24 | if self.with_conv: 25 | self.conv = torch.nn.Conv2d(in_channels, 26 | in_channels, 27 | kernel_size=3, 28 | stride=1, 29 | padding=1) 30 | 31 | def forward(self, x): 32 | x = torch.nn.functional.interpolate(x.to(torch.float32), scale_factor=2.0, mode="nearest").to(torch.bfloat16) 33 | if self.with_conv: 34 | x = self.conv(x) 35 | return x 36 | 37 | 38 | class ResnetBlock(nn.Module): 39 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 40 | dropout, temb_channels=512): 41 | super().__init__() 42 | self.in_channels = in_channels 43 | out_channels = in_channels if out_channels is None else out_channels 44 | self.out_channels = out_channels 45 | self.use_conv_shortcut = conv_shortcut 46 | self.checkpointing = False 47 | 48 | self.norm1 = Normalize(in_channels) 49 | self.conv1 = torch.nn.Conv2d(in_channels, 50 | out_channels, 51 | kernel_size=3, 52 | stride=1, 53 | padding=1) 54 | if temb_channels > 0: 55 | self.temb_proj = torch.nn.Linear(temb_channels, 56 | out_channels) 57 | self.norm2 = Normalize(out_channels) 58 | self.dropout = torch.nn.Dropout(dropout, inplace=True) 59 | self.conv2 = torch.nn.Conv2d(out_channels, 60 | out_channels, 61 | kernel_size=3, 62 | stride=1, 63 | padding=1) 64 | if self.in_channels != self.out_channels: 65 | if self.use_conv_shortcut: 66 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 67 | out_channels, 68 | kernel_size=3, 69 | stride=1, 70 | padding=1) 71 | else: 72 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 73 | out_channels, 74 | kernel_size=1, 75 | stride=1, 76 | padding=0) 77 | 78 | def _forward(self, x, temb): 79 | h = x 80 | h = self.norm1(h) 81 | h = nonlinearity(h) 82 | h = self.conv1(h) 83 | 84 | if temb is not None: 85 | h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] 86 | 87 | h = self.norm2(h) 88 | h = nonlinearity(h) 89 | h = self.dropout(h) 90 | h = self.conv2(h) 91 | 92 | if self.in_channels != self.out_channels: 93 | if self.use_conv_shortcut: 94 | x = self.conv_shortcut(x) 95 | else: 96 | x = self.nin_shortcut(x) 97 | 98 | return x+h 99 | 100 | def forward(self, x, temb): 101 | if self.checkpointing and self.training: 102 | out = checkpoint(self._forward, x, temb) 103 | else: 104 | out = self._forward(x, temb) 105 | return out 106 | 107 | 108 | class AttnBlock(nn.Module): 109 | def __init__(self, in_channels): 110 | super().__init__() 111 | self.in_channels = in_channels 112 | 113 | self.norm = Normalize(in_channels) 114 | self.q = torch.nn.Conv2d(in_channels, 115 | in_channels, 116 | kernel_size=1, 117 | stride=1, 118 | padding=0) 119 | self.k = torch.nn.Conv2d(in_channels, 120 | in_channels, 121 | kernel_size=1, 122 | stride=1, 123 | padding=0) 124 | self.v = torch.nn.Conv2d(in_channels, 125 | in_channels, 126 | kernel_size=1, 127 | stride=1, 128 | padding=0) 129 | self.proj_out = torch.nn.Conv2d(in_channels, 130 | in_channels, 131 | kernel_size=1, 132 | stride=1, 133 | padding=0) 134 | 135 | 136 | def forward(self, x): 137 | h_ = x 138 | h_ = self.norm(h_) 139 | q = self.q(h_) 140 | k = 
self.k(h_) 141 | v = self.v(h_) 142 | 143 | b,c,h,w = q.shape 144 | q = q.reshape(b,c,h*w) 145 | q = q.permute(0,2,1) 146 | k = k.reshape(b,c,h*w) 147 | w_ = torch.bmm(q,k) 148 | w_ = w_ * (int(c)**(-0.5)) 149 | w_ = torch.nn.functional.softmax(w_, dim=2) 150 | 151 | v = v.reshape(b,c,h*w) 152 | w_ = w_.permute(0,2,1) 153 | h_ = torch.bmm(v,w_) 154 | h_ = h_.reshape(b,c,h,w) 155 | 156 | h_ = self.proj_out(h_) 157 | 158 | return x+h_ 159 | 160 | 161 | class Decoder(nn.Module): 162 | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, 163 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 164 | resolution, z_channels, give_pre_end=False, **ignorekwargs): 165 | super().__init__() 166 | self.ch = ch 167 | self.temb_ch = 0 168 | self.num_resolutions = len(ch_mult) 169 | self.num_res_blocks = num_res_blocks 170 | self.resolution = resolution 171 | self.in_channels = in_channels 172 | self.give_pre_end = give_pre_end 173 | 174 | in_ch_mult = (1,)+tuple(ch_mult) 175 | block_in = ch*ch_mult[self.num_resolutions-1] 176 | curr_res = resolution // 2**(self.num_resolutions-1) 177 | self.z_shape = (1, z_channels, curr_res, curr_res) 178 | 179 | self.conv_in = torch.nn.Conv2d(z_channels, 180 | block_in, 181 | kernel_size=3, 182 | stride=1, 183 | padding=1) 184 | 185 | self.mid = nn.Module() 186 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 187 | out_channels=block_in, 188 | temb_channels=self.temb_ch, 189 | dropout=dropout) 190 | self.mid.attn_1 = AttnBlock(block_in) 191 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 192 | out_channels=block_in, 193 | temb_channels=self.temb_ch, 194 | dropout=dropout) 195 | 196 | self.up = nn.ModuleList() 197 | for i_level in reversed(range(self.num_resolutions)): 198 | block = nn.ModuleList() 199 | attn = nn.ModuleList() 200 | block_out = ch*ch_mult[i_level] 201 | for i_block in range(self.num_res_blocks+1): 202 | block.append(ResnetBlock(in_channels=block_in, 203 | out_channels=block_out, 204 | temb_channels=self.temb_ch, 205 | dropout=dropout)) 206 | block_in = block_out 207 | if curr_res in attn_resolutions: 208 | attn.append(AttnBlock(block_in)) 209 | up = nn.Module() 210 | up.block = block 211 | up.attn = attn 212 | if i_level != 0: 213 | up.upsample = Upsample(block_in, resamp_with_conv) 214 | curr_res = curr_res * 2 215 | self.up.insert(0, up) 216 | 217 | self.norm_out = Normalize(block_in) 218 | self.conv_out = torch.nn.Conv2d(block_in, 219 | out_ch, 220 | kernel_size=3, 221 | stride=1, 222 | padding=1) 223 | 224 | def forward(self, z): 225 | self.last_z_shape = z.shape 226 | 227 | temb = None 228 | 229 | h = self.conv_in(z) 230 | 231 | h = self.mid.block_1(h, temb) 232 | h = self.mid.attn_1(h) 233 | h = self.mid.block_2(h, temb) 234 | 235 | for i_level in reversed(range(self.num_resolutions)): 236 | for i_block in range(self.num_res_blocks+1): 237 | h = self.up[i_level].block[i_block](h, temb) 238 | if len(self.up[i_level].attn) > 0: 239 | h = self.up[i_level].attn[i_block](h) 240 | if i_level != 0: 241 | h = self.up[i_level].upsample(h) 242 | 243 | if self.give_pre_end: 244 | return h 245 | 246 | h = self.norm_out(h) 247 | h = nonlinearity(h) 248 | h = self.conv_out(h) 249 | return h -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqvaesiglip/quantizations.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from 
https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py. 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from typing import Iterable 12 | 13 | 14 | class VQEmbedding(nn.Embedding): 15 | """VQ embedding module with ema update.""" 16 | 17 | def __init__(self, n_embed, embed_dim, ema=True, decay=0.99, restart_unused_codes=True, eps=1e-5): 18 | super().__init__(n_embed + 1, embed_dim, padding_idx=n_embed) 19 | 20 | self.ema = ema 21 | self.decay = decay 22 | self.eps = eps 23 | self.restart_unused_codes = restart_unused_codes 24 | self.n_embed = n_embed 25 | 26 | if self.ema: 27 | _ = [p.requires_grad_(False) for p in self.parameters()] 28 | 29 | self.register_buffer('cluster_size_ema', torch.zeros(n_embed)) 30 | self.register_buffer('embed_ema', self.weight[:-1, :].detach().clone()) 31 | 32 | @torch.no_grad() 33 | def compute_distances(self, inputs): 34 | codebook_t = self.weight[:-1, :].t() 35 | 36 | (embed_dim, _) = codebook_t.shape 37 | inputs_shape = inputs.shape 38 | assert inputs_shape[-1] == embed_dim 39 | 40 | inputs_flat = inputs.reshape(-1, embed_dim) 41 | 42 | inputs_norm_sq = inputs_flat.pow(2.).sum(dim=1, keepdim=True) 43 | codebook_t_norm_sq = codebook_t.pow(2.).sum(dim=0, keepdim=True) 44 | distances = torch.addmm( 45 | inputs_norm_sq + codebook_t_norm_sq, 46 | inputs_flat, 47 | codebook_t, 48 | alpha=-2.0, 49 | ) 50 | distances = distances.reshape(*inputs_shape[:-1], -1) 51 | return distances 52 | 53 | @torch.no_grad() 54 | def find_nearest_embedding(self, inputs): 55 | distances = self.compute_distances(inputs) 56 | embed_idxs = distances.argmin(dim=-1) 57 | 58 | return embed_idxs 59 | 60 | @torch.no_grad() 61 | def _tile_with_noise(self, x, target_n): 62 | B, embed_dim = x.shape 63 | n_repeats = (target_n + B -1) // B 64 | std = x.new_ones(embed_dim) * 0.01 / np.sqrt(embed_dim) 65 | x = x.repeat(n_repeats, 1) 66 | x = x + torch.rand_like(x) * std 67 | return x 68 | 69 | @torch.no_grad() 70 | def _update_buffers(self, vectors, idxs): 71 | 72 | n_embed, embed_dim = self.weight.shape[0]-1, self.weight.shape[-1] 73 | 74 | vectors = vectors.reshape(-1, embed_dim) 75 | idxs = idxs.reshape(-1) 76 | 77 | n_vectors = vectors.shape[0] 78 | n_total_embed = n_embed 79 | 80 | one_hot_idxs = vectors.new_zeros(n_total_embed, n_vectors) 81 | one_hot_idxs.scatter_(dim=0, 82 | index=idxs.unsqueeze(0), 83 | src=vectors.new_ones(1, n_vectors) 84 | ) 85 | 86 | cluster_size = one_hot_idxs.sum(dim=1) 87 | vectors_sum_per_cluster = one_hot_idxs @ vectors 88 | 89 | if dist.is_initialized(): 90 | dist.all_reduce(vectors_sum_per_cluster, op=dist.ReduceOp.SUM) 91 | dist.all_reduce(cluster_size, op=dist.ReduceOp.SUM) 92 | 93 | self.cluster_size_ema.mul_(self.decay).add_(cluster_size, alpha=1 - self.decay) 94 | self.embed_ema.mul_(self.decay).add_(vectors_sum_per_cluster, alpha=1 - self.decay) 95 | 96 | if self.restart_unused_codes: 97 | if n_vectors < n_embed: 98 | vectors = self._tile_with_noise(vectors, n_embed) 99 | n_vectors = vectors.shape[0] 100 | _vectors_random = vectors[torch.randperm(n_vectors, device=vectors.device)][:n_embed] 101 | 102 | if dist.is_initialized(): 103 | dist.broadcast(_vectors_random, 0) 104 | 105 | usage = (self.cluster_size_ema.view(-1, 1) >= 1).float() 106 | self.embed_ema.mul_(usage).add_(_vectors_random * (1-usage)) 107 | self.cluster_size_ema.mul_(usage.view(-1)) 108 | 
self.cluster_size_ema.add_(torch.ones_like(self.cluster_size_ema) * (1-usage).view(-1)) 109 | 110 | @torch.no_grad() 111 | def _update_embedding(self): 112 | 113 | n_embed = self.weight.shape[0] - 1 114 | n = self.cluster_size_ema.sum() 115 | normalized_cluster_size = ( 116 | n * (self.cluster_size_ema + self.eps) / (n + n_embed * self.eps) 117 | ) 118 | self.weight[:-1, :] = self.embed_ema / normalized_cluster_size.reshape(-1, 1) 119 | 120 | def forward(self, inputs): 121 | embed_idxs = self.find_nearest_embedding(inputs) 122 | if self.training: 123 | if self.ema: 124 | self._update_buffers(inputs, embed_idxs) 125 | 126 | embeds = self.embed(embed_idxs) 127 | 128 | if self.ema and self.training: 129 | self._update_embedding() 130 | 131 | return embeds, embed_idxs 132 | 133 | def embed(self, idxs): 134 | embeds = super().forward(idxs) 135 | return embeds 136 | 137 | 138 | class RQBottleneck(nn.Module): 139 | """ 140 | Quantization bottleneck via Residual Quantization. 141 | 142 | Arguments: 143 | latent_shape (Tuple[int, int, int]): the shape of latents, denoted (H, W, D) 144 | code_shape (Tuple[int, int, int]): the shape of codes, denoted (h, w, d) 145 | n_embed (int, List, or Tuple): the number of embeddings (i.e., the size of codebook) 146 | If isinstance(n_embed, int), the sizes of all codebooks are same. 147 | shared_codebook (bool): If True, codebooks are shared in all location. If False, 148 | uses separate codebooks along the ``depth'' dimension. (default: False) 149 | restart_unused_codes (bool): If True, it randomly assigns a feature vector in the curruent batch 150 | as the new embedding of unused codes in training. (default: True) 151 | """ 152 | 153 | def __init__(self, 154 | latent_shape, 155 | code_shape, 156 | n_embed, 157 | decay=0.99, 158 | shared_codebook=False, 159 | restart_unused_codes=True, 160 | commitment_loss='cumsum' 161 | ): 162 | super().__init__() 163 | 164 | if not len(code_shape) == len(latent_shape) == 3: 165 | raise ValueError("incompatible code shape or latent shape") 166 | if any([y % x != 0 for x, y in zip(code_shape[:2], latent_shape[:2])]): 167 | raise ValueError("incompatible code shape or latent shape") 168 | 169 | embed_dim = np.prod(latent_shape[:2]) // np.prod(code_shape[:2]) * latent_shape[2] 170 | 171 | self.latent_shape = torch.Size(latent_shape) 172 | self.code_shape = torch.Size(code_shape) 173 | self.shape_divisor = torch.Size([latent_shape[i] // code_shape[i] for i in range(len(latent_shape))]) 174 | 175 | self.shared_codebook = shared_codebook 176 | if self.shared_codebook: 177 | if isinstance(n_embed, Iterable) or isinstance(decay, Iterable): 178 | raise ValueError("Shared codebooks are incompatible \ 179 | with list types of momentums or sizes: Change it into int") 180 | 181 | self.restart_unused_codes = restart_unused_codes 182 | self.n_embed = n_embed if isinstance(n_embed, Iterable) else [n_embed for _ in range(self.code_shape[-1])] 183 | self.decay = decay if isinstance(decay, Iterable) else [decay for _ in range(self.code_shape[-1])] 184 | assert len(self.n_embed) == self.code_shape[-1] 185 | assert len(self.decay) == self.code_shape[-1] 186 | 187 | if self.shared_codebook: 188 | codebook0 = VQEmbedding(self.n_embed[0], 189 | embed_dim, 190 | decay=self.decay[0], 191 | restart_unused_codes=restart_unused_codes, 192 | ) 193 | self.codebooks = nn.ModuleList([codebook0 for _ in range(self.code_shape[-1])]) 194 | else: 195 | codebooks = [VQEmbedding(self.n_embed[idx], 196 | embed_dim, 197 | decay=self.decay[idx], 198 | 
restart_unused_codes=restart_unused_codes, 199 | ) for idx in range(self.code_shape[-1])] 200 | self.codebooks = nn.ModuleList(codebooks) 201 | 202 | self.commitment_loss = commitment_loss 203 | 204 | def to_code_shape(self, x): 205 | (B, H, W, D) = x.shape 206 | (rH, rW, _) = self.shape_divisor 207 | 208 | x = x.reshape(B, H//rH, rH, W//rW, rW, D) 209 | x = x.permute(0, 1, 3, 2, 4, 5) 210 | x = x.reshape(B, H//rH, W//rW, -1) 211 | 212 | return x 213 | 214 | def to_latent_shape(self, x): 215 | (B, h, w, _) = x.shape 216 | (_, _, D) = self.latent_shape 217 | (rH, rW, _) = self.shape_divisor 218 | 219 | x = x.reshape(B, h, w, rH, rW, D) 220 | x = x.permute(0, 1, 3, 2, 4, 5) 221 | x = x.reshape(B, h*rH, w*rW, D) 222 | 223 | return x 224 | 225 | def quantize(self, x): 226 | r""" 227 | Return list of quantized features and the selected codewords by the residual quantization. 228 | The code is selected by the residuals between x and quantized features by the previous codebooks. 229 | 230 | Arguments: 231 | x (Tensor): bottleneck feature maps to quantize. 232 | 233 | Returns: 234 | quant_list (list): list of sequentially aggregated and quantized feature maps by codebooks. 235 | codes (LongTensor): codewords index, corresponding to quants. 236 | 237 | Shape: 238 | - x: (B, h, w, embed_dim) 239 | - quant_list[i]: (B, h, w, embed_dim) 240 | - codes: (B, h, w, d) 241 | """ 242 | B, h, w, embed_dim = x.shape 243 | 244 | residual_feature = x.detach().clone() 245 | 246 | quant_list = [] 247 | code_list = [] 248 | aggregated_quants = torch.zeros_like(x) 249 | for i in range(self.code_shape[-1]): 250 | quant, code = self.codebooks[i](residual_feature) 251 | 252 | residual_feature.sub_(quant) 253 | aggregated_quants.add_(quant) 254 | 255 | quant_list.append(aggregated_quants.clone()) 256 | code_list.append(code.unsqueeze(-1)) 257 | 258 | codes = torch.cat(code_list, dim=-1) 259 | return quant_list, codes 260 | 261 | def forward(self, x): 262 | x_reshaped = self.to_code_shape(x) 263 | quant_list, codes = self.quantize(x_reshaped) 264 | 265 | commitment_loss = self.compute_commitment_loss(x_reshaped, quant_list) 266 | quants_trunc = self.to_latent_shape(quant_list[-1]) 267 | quants_trunc = x + (quants_trunc - x).detach() 268 | 269 | return quants_trunc, commitment_loss, codes 270 | 271 | def compute_commitment_loss(self, x, quant_list): 272 | r""" 273 | Compute the commitment loss for the residual quantization. 274 | The loss is iteratively computed by aggregating quantized features. 
275 | """ 276 | loss_list = [] 277 | 278 | for idx, quant in enumerate(quant_list): 279 | partial_loss = (x-quant.detach()).pow(2.0).mean() 280 | loss_list.append(partial_loss) 281 | 282 | commitment_loss = torch.mean(torch.stack(loss_list)) 283 | return commitment_loss 284 | 285 | @torch.no_grad() 286 | def embed_code(self, code): 287 | assert code.shape[1:] == self.code_shape 288 | 289 | code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1) 290 | 291 | if self.shared_codebook: 292 | embeds = [self.codebooks[0].embed(code_slice) for i, code_slice in enumerate(code_slices)] 293 | else: 294 | embeds = [self.codebooks[i].embed(code_slice) for i, code_slice in enumerate(code_slices)] 295 | 296 | embeds = torch.cat(embeds, dim=-2).sum(-2) 297 | embeds = self.to_latent_shape(embeds) 298 | 299 | return embeds 300 | 301 | @torch.no_grad() 302 | def embed_code_with_depth(self, code, to_latent_shape=False): 303 | assert code.shape[-1] == self.code_shape[-1] 304 | 305 | code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1) 306 | 307 | if self.shared_codebook: 308 | embeds = [self.codebooks[0].embed(code_slice) for i, code_slice in enumerate(code_slices)] 309 | else: 310 | embeds = [self.codebooks[i].embed(code_slice) for i, code_slice in enumerate(code_slices)] 311 | 312 | if to_latent_shape: 313 | embeds = [self.to_latent_shape(embed.squeeze(-2)).unsqueeze(-2) for embed in embeds] 314 | embeds = torch.cat(embeds, dim=-2) 315 | 316 | return embeds, None 317 | -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqvaesiglip/siglip/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
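# Note (added for clarity): this __init__ follows the standard Hugging Face
# lazy-import pattern. Public names are declared in `_import_structure`, and at
# import time the module object is swapped for a `_LazyModule`, so heavy
# submodules such as `modeling_siglip` are only imported on first attribute
# access. A minimal sketch of the idea (hypothetical toy names, not part of
# this package):
#
#     import sys
#     from transformers.utils import _LazyModule
#
#     _import_structure = {"my_module": ["MyClass"]}
#     sys.modules[__name__] = _LazyModule(
#         __name__, globals()["__file__"], _import_structure, module_spec=__spec__
#     )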
14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import ( 17 | OptionalDependencyNotAvailable, 18 | _LazyModule, 19 | is_torch_available, 20 | is_vision_available, 21 | ) 22 | 23 | 24 | _import_structure = { 25 | "configuration_siglip": [ 26 | "SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP", 27 | "SiglipConfig", 28 | "SiglipTextConfig", 29 | "SiglipVisionConfig", 30 | ], 31 | "processing_siglip": ["SiglipProcessor"], 32 | "tokenization_siglip": ["SiglipTokenizer"], 33 | } 34 | 35 | try: 36 | if not is_vision_available(): 37 | raise OptionalDependencyNotAvailable() 38 | except OptionalDependencyNotAvailable: 39 | pass 40 | else: 41 | _import_structure["image_processing_siglip"] = ["SiglipImageProcessor"] 42 | 43 | try: 44 | if not is_torch_available(): 45 | raise OptionalDependencyNotAvailable() 46 | except OptionalDependencyNotAvailable: 47 | pass 48 | else: 49 | _import_structure["modeling_siglip"] = [ 50 | "SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST", 51 | "SiglipModel", 52 | "SiglipPreTrainedModel", 53 | "SiglipTextModel", 54 | "SiglipVisionModel", 55 | ] 56 | 57 | 58 | if TYPE_CHECKING: 59 | from .configuration_siglip import ( 60 | SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, 61 | SiglipConfig, 62 | SiglipTextConfig, 63 | SiglipVisionConfig, 64 | ) 65 | from .processing_siglip import SiglipProcessor 66 | from .tokenization_siglip import SiglipTokenizer 67 | 68 | try: 69 | if not is_vision_available(): 70 | raise OptionalDependencyNotAvailable() 71 | except OptionalDependencyNotAvailable: 72 | pass 73 | else: 74 | from .image_processing_siglip import SiglipImageProcessor 75 | 76 | try: 77 | if not is_torch_available(): 78 | raise OptionalDependencyNotAvailable() 79 | except OptionalDependencyNotAvailable: 80 | pass 81 | else: 82 | from .modeling_siglip import ( 83 | SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST, 84 | SiglipModel, 85 | SiglipPreTrainedModel, 86 | SiglipTextModel, 87 | SiglipVisionModel, 88 | ) 89 | 90 | 91 | else: 92 | import sys 93 | 94 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 95 | -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer/rqvaesiglip/siglip/processing_siglip.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Image/Text processor class for SigLIP. 
17 | """ 18 | 19 | from typing import List, Optional, Union 20 | 21 | from transformers.feature_extraction_utils import BatchFeature 22 | from transformers.image_utils import ImageInput 23 | from transformers.processing_utils import ProcessorMixin 24 | from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy 25 | from transformers.utils import TensorType 26 | 27 | 28 | class SiglipProcessor(ProcessorMixin): 29 | r""" 30 | Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor. 31 | 32 | [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the 33 | [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information. 34 | 35 | Args: 36 | image_processor ([`SiglipImageProcessor`]): 37 | The image processor is a required input. 38 | tokenizer ([`SiglipTokenizer`]): 39 | The tokenizer is a required input. 40 | """ 41 | 42 | attributes = ["image_processor", "tokenizer"] 43 | image_processor_class = "SiglipImageProcessor" 44 | tokenizer_class = "SiglipTokenizer" 45 | 46 | def __init__(self, image_processor, tokenizer): 47 | super().__init__(image_processor, tokenizer) 48 | 49 | def __call__( 50 | self, 51 | text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, 52 | images: ImageInput = None, 53 | padding: Union[bool, str, PaddingStrategy] = "max_length", 54 | truncation: Union[bool, str, TruncationStrategy] = None, 55 | max_length=None, 56 | return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, 57 | ) -> BatchFeature: 58 | """ 59 | Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` 60 | and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode 61 | the text. To prepare the image(s), this method forwards the `images` argument to 62 | SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring 63 | of the above two methods for more information. 64 | 65 | Args: 66 | text (`str`, `List[str]`, `List[List[str]]`): 67 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings 68 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set 69 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 70 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): 71 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch 72 | tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a 73 | number of channels, H and W are image height and width. 74 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `max_length`): 75 | Select a strategy to pad the returned sequences (according to the model's padding side and padding 76 | index) among: 77 | - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 78 | sequence if provided). 79 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 80 | acceptable input length for the model if that argument is not provided. 
81 | - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different 82 | lengths). 83 | max_length (`int`, *optional*): 84 | Maximum length of the returned list and optionally padding length (see above). 85 | truncation (`bool`, *optional*): 86 | Activates truncation to cut input sequences longer than `max_length` to `max_length`. 87 | return_tensors (`str` or [`~utils.TensorType`], *optional*): 88 | If set, will return tensors of a particular framework. Acceptable values are: 89 | 90 | - `'tf'`: Return TensorFlow `tf.constant` objects. 91 | - `'pt'`: Return PyTorch `torch.Tensor` objects. 92 | - `'np'`: Return NumPy `np.ndarray` objects. 93 | - `'jax'`: Return JAX `jnp.ndarray` objects. 94 | 95 | Returns: 96 | [`BatchFeature`]: A [`BatchFeature`] with the following fields: 97 | 98 | - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. 99 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when 100 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not 101 | `None`). 102 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 103 | """ 104 | 105 | if text is None and images is None: 106 | raise ValueError("You have to specify either text or images. Both cannot be none.") 107 | 108 | if text is not None: 109 | encoding = self.tokenizer( 110 | text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length 111 | ) 112 | 113 | if images is not None: 114 | image_features = self.image_processor(images, return_tensors=return_tensors) 115 | 116 | if text is not None and images is not None: 117 | encoding["pixel_values"] = image_features.pixel_values 118 | return encoding 119 | elif text is not None: 120 | return encoding 121 | else: 122 | return BatchFeature(data=dict(**image_features), tensor_type=return_tensors) 123 | 124 | def decode(self, *args, **kwargs): 125 | """ 126 | This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to 127 | the docstring of this method for more information. 128 | """ 129 | return self.tokenizer.decode(*args, **kwargs) 130 | 131 | def batch_decode(self, *args, **kwargs): 132 | """ 133 | This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please 134 | refer to the docstring of this method for more information. 
135 | """ 136 | return self.tokenizer.batch_decode(*args, **kwargs) 137 | 138 | @property 139 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip 140 | def model_input_names(self): 141 | tokenizer_input_names = self.tokenizer.model_input_names 142 | image_processor_input_names = self.image_processor.model_input_names 143 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) 144 | -------------------------------------------------------------------------------- /vila_u/model/multimodal_encoder/rqvaesigliptransformer_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPImageProcessor, PreTrainedModel, PretrainedConfig 5 | from transformers.image_processing_utils import BaseImageProcessor 6 | 7 | from .rqvaesigliptransformer import RQVAESIGLIPTransformerConfig, RQVAESIGLIPTransformer 8 | 9 | 10 | class RQVAESIGLIPTransformerVisionTower(nn.Module): 11 | def __init__(self, model_name_or_path, config: PretrainedConfig): 12 | super().__init__() 13 | self.config = RQVAESIGLIPTransformerConfig.from_pretrained(model_name_or_path) 14 | self.vision_tower = RQVAESIGLIPTransformer.from_pretrained(model_name_or_path, torch_dtype=eval(config.model_dtype)) 15 | self.is_loaded = True 16 | 17 | if self.config.hidden_size == 1152: 18 | self.image_processor = CLIPImageProcessor( 19 | size={"height": 384, "width": 384}, 20 | crop_size={"height": 384, "width": 384}, 21 | image_mean=[0.5, 0.5, 0.5], 22 | image_std=[0.5, 0.5, 0.5] 23 | ) 24 | self.image_tokens = 729 25 | elif self.config.hidden_size == 1024: 26 | self.image_processor = CLIPImageProcessor( 27 | size={"height": 256, "width": 256}, 28 | crop_size={"height": 256, "width": 256}, 29 | image_mean=[0.5, 0.5, 0.5], 30 | image_std=[0.5, 0.5, 0.5] 31 | ) 32 | self.image_tokens = 256 33 | else: 34 | raise NotImplementedError() 35 | 36 | def forward(self, images, text_vocab_size): 37 | output = self.vision_tower.rqvaesiglip.encode_image(images) 38 | image_features, tokens = output[-1], output[-2] 39 | 40 | bs, patch_size, _, dim = image_features.shape 41 | image_features = torch.reshape(image_features, [bs, patch_size**2, dim]) 42 | tokens = torch.add(torch.reshape(tokens, [bs, patch_size**2, -1]), text_vocab_size) 43 | 44 | return image_features, tokens -------------------------------------------------------------------------------- /vila_u/model/multimodal_projector/base_projector.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch.nn as nn 3 | import torch 4 | 5 | from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel 6 | 7 | 8 | class IdentityMap(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, *args, **kwargs): 13 | return x 14 | 15 | @property 16 | def config(self): 17 | return {"mm_projector_type": "identity"} 18 | 19 | 20 | class MultimodalProjectorConfig(PretrainedConfig): 21 | model_type = "v2l_projector" 22 | 23 | def __init__(self, mm_projector_type: str=None, **kwargs): 24 | super().__init__() 25 | 26 | self.mm_projector_type = mm_projector_type 27 | 28 | 29 | class MultimodalProjector(PreTrainedModel): 30 | config_class = MultimodalProjectorConfig 31 | 32 | def __init__( 33 | self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig 34 | ): 35 | super().__init__(mm_projector_cfg) 36 | 
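# Descriptive note (added): the projector type string selects the module built
# below -- "identity" passes features through unchanged, "linear" is a single
# mm_hidden_size -> hidden_size projection, and "mlp{N}x_gelu" expands into an
# N-layer GELU MLP. For example, "mlp2x_gelu" yields:
#
#     nn.Sequential(
#         nn.Linear(config.mm_hidden_size, config.hidden_size),
#         nn.GELU(),
#         nn.Linear(config.hidden_size, config.hidden_size),
#     )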
mm_projector_type = mm_projector_cfg.mm_projector_type 37 | if mm_projector_type == "identity": 38 | self.layers = IdentityMap() 39 | elif mm_projector_type == "linear": 40 | self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size) 41 | else: 42 | mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type) 43 | if mlp_gelu_match: 44 | mlp_depth = int(mlp_gelu_match.group(1)) 45 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 46 | for _ in range(1, mlp_depth): 47 | modules.append(nn.GELU()) 48 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 49 | self.layers = nn.Sequential(*modules) 50 | else: 51 | raise ValueError(f"Unknown projector type: {mm_projector_type}") 52 | 53 | def forward(self, x, *args, **kwargs): 54 | return self.layers(x) 55 | 56 | 57 | AutoConfig.register("v2l_projector", MultimodalProjectorConfig) 58 | AutoModel.register(MultimodalProjectorConfig, MultimodalProjector) -------------------------------------------------------------------------------- /vila_u/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from transformers import PretrainedConfig, PreTrainedModel 5 | 6 | from .base_projector import MultimodalProjectorConfig, MultimodalProjector 7 | 8 | 9 | def build_mm_projector( 10 | model_type_or_path: str, config: PretrainedConfig 11 | ) -> PreTrainedModel: 12 | if model_type_or_path is None: 13 | return None 14 | 15 | if config.resume_path: 16 | assert os.path.exists( 17 | model_type_or_path 18 | ), f"Resume mm projector path {model_type_or_path} does not exist!" 19 | return MultimodalProjector.from_pretrained( 20 | model_type_or_path, config, torch_dtype=eval(config.model_dtype) 21 | ) 22 | else: 23 | mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path) 24 | mm_projector = MultimodalProjector(mm_projector_cfg, config).to( 25 | eval(config.model_dtype) 26 | ) 27 | return mm_projector -------------------------------------------------------------------------------- /vila_u/model/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | from huggingface_hub import snapshot_download, repo_exists 5 | from huggingface_hub.utils import HFValidationError 6 | from transformers import PretrainedConfig 7 | 8 | 9 | def get_model_config(config): 10 | default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"] 11 | 12 | if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2: 13 | root_path = config._name_or_path 14 | else: 15 | root_path = config.resume_path 16 | 17 | if root_path is not None and not osp.exists(root_path): 18 | try: 19 | valid_hf_repo = repo_exists(root_path) 20 | except HFValidationError as e: 21 | valid_hf_repo = False 22 | if valid_hf_repo: 23 | root_path = snapshot_download(root_path) 24 | 25 | return_list = [] 26 | for key in default_keys: 27 | cfg = getattr(config, key, None) 28 | if isinstance(cfg, dict): 29 | try: 30 | return_list.append(os.path.join(root_path, key[:-4])) 31 | except: 32 | raise ValueError(f"Cannot find resume path in config for {key}!") 33 | elif isinstance(cfg, PretrainedConfig): 34 | return_list.append(os.path.join(root_path, key[:-4])) 35 | elif isinstance(cfg, str): 36 | return_list.append(cfg) 37 | 38 | return return_list -------------------------------------------------------------------------------- /vila_u/train/args.py: 
-------------------------------------------------------------------------------- 1 | import transformers 2 | 3 | from dataclasses import dataclass, field 4 | from typing import Dict, Optional, Sequence, List 5 | 6 | 7 | @dataclass 8 | class DataArguments: 9 | data_mixture: str = "llava_1_5_mm_align" 10 | image_aspect_ratio: str = "square" 11 | lazy_preprocess: bool = False 12 | vflan_no_system_prompt: bool = False 13 | num_video_frames: int = 8 14 | 15 | 16 | @dataclass 17 | class ModelArguments: 18 | version: Optional[str] = field(default="v0") 19 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 20 | vision_tower: Optional[str] = field(default="google/siglip-so400m-patch14-384") 21 | mm_projector: Optional[str] = field(default="mlp2x_gelu") 22 | mm_use_im_start_end: bool = field(default=False) 23 | mm_use_vi_start_end: bool = field(default=False) 24 | mm_use_im_patch_token: bool = field(default=True) 25 | mm_vision_select_layer: Optional[int] = field(default=-1) 26 | interpolate_mode: Optional[str] = field(default="linear") 27 | drop_path_rate: Optional[float] = field(default=0.) 28 | 29 | 30 | @dataclass 31 | class TrainingArguments(transformers.TrainingArguments): 32 | cache_dir: Optional[str] = field(default=None) 33 | remove_unused_columns: bool = field(default=False) 34 | tune_vision_tower: bool = field(default=False) 35 | tune_language_model: bool = field(default=False) 36 | tune_mm_projector: bool = field(default=False) 37 | chunk_sampler: bool = field(default=False) 38 | model_dtype: str = field(default="torch.bfloat16") 39 | model_max_length: int = field( 40 | default=512, 41 | metadata={ 42 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 43 | }, 44 | ) 45 | mm_projector_lr: Optional[float] = None 46 | group_by_modality_length: bool = field(default=False) 47 | total_time_limit: int = field( 48 | default=-1, metadata={"help": "Timeout limit for this job (in minutes)."} 49 | ) 50 | pre_terminate_time: int = field( 51 | default=10, 52 | metadata={ 53 | "help": "Time to terminate the task inadvance (minutes), saveing checkpoints needs time." 54 | }, 55 | ) -------------------------------------------------------------------------------- /vila_u/train/callbacks/autoresume_callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import transformers 5 | 6 | from transformers.utils import logging 7 | 8 | logger = logging.get_logger("transformers") 9 | 10 | 11 | def rank_print(*s): 12 | if not torch.distributed.is_initialized(): 13 | rank = 0 14 | else: 15 | rank =torch.distributed.get_rank() 16 | print(rank, *s) 17 | 18 | sys.path.append(os.environ.get("SUBMIT_SCRIPTS", ".")) 19 | try: 20 | logger.info("Importing AutoResume lib...") 21 | from userlib.auto_resume import AutoResume 22 | 23 | AutoResume.init() 24 | logger.info("Found AutoResume SDK!") 25 | except: 26 | logger.warn("Did not find AutoResume SDK!") 27 | AutoResume = None 28 | 29 | 30 | class AutoResumeCallback(transformers.TrainerCallback): 31 | """ 32 | A [`TrainerCallback`] that handles autoresume. 33 | 34 | Args: 35 | interval: interval (in number of iterations) between checks as to 36 | whether to suspend. 
37 | """ 38 | 39 | def __init__(self, interval: int = 50): 40 | self.interval = interval 41 | 42 | def on_step_end(self, args, state, control, **kwargs): 43 | if state.global_step % self.interval == 0: 44 | rank_print("AutoResumeHook: Checking whether to suspend...") 45 | 46 | # Check whether to suspend the job. 47 | should_preempt = AutoResume is not None and AutoResume.termination_requested() 48 | 49 | if should_preempt: 50 | if state.is_local_process_zero: 51 | logger.warn(f"AutoResumeHook: Request resume...") 52 | if AutoResume is not None: 53 | AutoResume.request_resume() 54 | control.should_training_stop = True 55 | control.should_save = True 56 | -------------------------------------------------------------------------------- /vila_u/train/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | import transformers 5 | 6 | from torch.utils.data import Dataset 7 | from transformers import HfArgumentParser, AutoConfig 8 | from transformers import set_seed 9 | from typing import Dict, Tuple, cast 10 | 11 | from vila_u import conversation as conversation_lib 12 | from vila_u.data import make_supervised_data_module 13 | from vila_u.model import VILAULlamaModel, VILAULlamaConfig 14 | from vila_u.model.multimodal_encoder.rqvaesigliptransformer_encoder import RQVAESIGLIPTransformerVisionTower 15 | from vila_u.train.vila_u_trainer import VILAUTrainer 16 | from vila_u.train.args import TrainingArguments, ModelArguments, DataArguments 17 | from vila_u.train.callbacks.autoresume_callback import AutoResumeCallback 18 | from vila_u.train.utils import ( 19 | get_checkpoint_path, 20 | prepare_config_for_training, 21 | mprint, 22 | ) 23 | 24 | local_rank = None 25 | 26 | if "WANDB_PROJECT" not in os.environ: 27 | os.environ["WANDB_PROJECT"] = "VILA-U" 28 | 29 | 30 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 31 | """Collects the state dict and dump to disk.""" 32 | if trainer.deepspeed: 33 | torch.cuda.synchronize() 34 | trainer.save_model(output_dir, _internal_call=True) 35 | return 36 | 37 | state_dict = trainer.model.state_dict() 38 | if trainer.args.should_save: 39 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 40 | del state_dict 41 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 42 | 43 | 44 | def smart_tokenizer_and_embedding_resize( 45 | special_tokens_dict: Dict, 46 | tokenizer: transformers.PreTrainedTokenizer, 47 | model: transformers.PreTrainedModel, 48 | ): 49 | """Resize tokenizer and embedding. 50 | 51 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
52 | """ 53 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 54 | model.resize_token_embeddings(len(tokenizer)) 55 | 56 | if num_new_tokens > 0: 57 | input_embeddings = model.get_input_embeddings().weight.data 58 | output_embeddings = model.get_output_embeddings().weight.data 59 | 60 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( 61 | dim=0, keepdim=True 62 | ) 63 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( 64 | dim=0, keepdim=True 65 | ) 66 | 67 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 68 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 69 | 70 | 71 | def train(): 72 | global local_rank 73 | 74 | parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 75 | model_args, data_args, training_args = cast(Tuple[ModelArguments, DataArguments, TrainingArguments], parser.parse_args_into_dataclasses()) 76 | training_args.run_name = training_args.output_dir.split("/")[-1] 77 | local_rank = training_args.local_rank 78 | compute_dtype = ( 79 | torch.float16 80 | if training_args.fp16 81 | else (torch.bfloat16 if training_args.bf16 else torch.float32) 82 | ) 83 | 84 | set_seed(training_args.seed) 85 | 86 | resume_path, continue_training = get_checkpoint_path(training_args.output_dir) 87 | 88 | if not continue_training: 89 | print(f"Models has been ready under {training_args.output_dir}. Skipp training") 90 | exit(0) 91 | 92 | if resume_path: 93 | resume_from_checkpoint = True 94 | config = AutoConfig.from_pretrained(resume_path, trust_remote_code=True) 95 | config.resume_path = resume_path 96 | model_cls = eval(config.architectures[0]) 97 | else: 98 | resume_from_checkpoint = False 99 | model_cls = VILAULlamaModel 100 | config = VILAULlamaConfig.from_pretrained( 101 | model_args.model_name_or_path, 102 | resume=resume_from_checkpoint 103 | ) 104 | if getattr(config, "resume_path", None) is not None: 105 | config.resume_path = model_args.model_name_or_path 106 | 107 | prepare_config_for_training(config, model_args, training_args, data_args) 108 | 109 | model = model_cls( 110 | config=config, 111 | attn_implementation="flash_attention_2", 112 | model_max_length=training_args.model_max_length, 113 | cache_dir=training_args.cache_dir, 114 | ) 115 | 116 | mprint(model) 117 | 118 | model.llm.config.use_cache = False 119 | model.get_llm().requires_grad_(training_args.tune_language_model) 120 | mprint(f"Tunable parameters:\nlanguage model {training_args.tune_language_model}") 121 | 122 | if model.get_vision_tower(): 123 | model.get_vision_tower().requires_grad_(training_args.tune_vision_tower) 124 | model.get_mm_projector().requires_grad_(training_args.tune_mm_projector) 125 | if isinstance(model.get_vision_tower(), RQVAESIGLIPTransformerVisionTower): 126 | model.get_vision_tower().vision_tower.rqvaesiglip.eval() 127 | model.get_vision_tower().vision_tower.rqtransformer.requires_grad_(True) 128 | else: 129 | raise NotImplementedError() 130 | print(f"vision tower {training_args.tune_vision_tower}") 131 | print(f"mm projector {training_args.tune_mm_projector}") 132 | 133 | if not any([training_args.tune_language_model, training_args.tune_vision_tower, training_args.tune_mm_projector]): 134 | logging.warning( 135 | "You are not tuning any part of the model. Please check if this is intended." 
136 | ) 137 | 138 | def need_to_modify_do_sample(generation_config): 139 | if generation_config.do_sample is False: 140 | if ( 141 | generation_config.temperature is not None 142 | and generation_config.temperature != 1.0 143 | ): 144 | return True 145 | if generation_config.top_p is not None and generation_config.top_p != 1.0: 146 | return True 147 | return False 148 | 149 | if need_to_modify_do_sample(model.llm.generation_config): 150 | model.llm.generation_config.do_sample = True 151 | 152 | if training_args.gradient_checkpointing: 153 | if hasattr(model.llm, "enable_input_require_grads"): 154 | model.llm.enable_input_require_grads() 155 | else: 156 | 157 | def make_inputs_require_grad(module, input, output): 158 | output.requires_grad_(True) 159 | 160 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 161 | 162 | tokenizer = model.tokenizer 163 | if model_args.version == "v0": 164 | if tokenizer.pad_token is None: 165 | smart_tokenizer_and_embedding_resize( 166 | special_tokens_dict=dict(pad_token="[PAD]"), 167 | tokenizer=tokenizer, 168 | model=model.llm, 169 | ) 170 | else: 171 | tokenizer.pad_token = tokenizer.unk_token 172 | if tokenizer.pad_token is None: 173 | smart_tokenizer_and_embedding_resize( 174 | special_tokens_dict=dict(pad_token="[PAD]"), 175 | tokenizer=tokenizer, 176 | model=model.llm, 177 | ) 178 | if model_args.version in conversation_lib.conv_templates: 179 | conversation_lib.default_conversation = conversation_lib.conv_templates[ 180 | model_args.version 181 | ] 182 | else: 183 | conversation_lib.default_conversation = conversation_lib.conv_templates[ 184 | "vicuna_v1" 185 | ] 186 | 187 | model.llm.pad_token_id = tokenizer.pad_token_id 188 | model.llm.config.tokenizer_padding_side = tokenizer.padding_side 189 | model.llm.config.tokenizer_model_max_length = tokenizer.model_max_length 190 | 191 | vision_tower = model.get_vision_tower() 192 | if vision_tower is not None: 193 | data_args.image_processor = vision_tower.image_processor 194 | data_args.is_multimodal = True 195 | 196 | model.config.num_video_frames = data_args.num_video_frames 197 | model.config.image_aspect_ratio = data_args.image_aspect_ratio 198 | model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = ( 199 | model_args.mm_use_im_start_end 200 | ) 201 | model.config.mm_use_vi_start_end = data_args.mm_use_vi_start_end = ( 202 | model_args.mm_use_vi_start_end 203 | ) 204 | model.config.mm_projector_lr = training_args.mm_projector_lr 205 | training_args.use_im_start_end = model_args.mm_use_im_start_end 206 | training_args.use_vi_start_end = model_args.mm_use_vi_start_end 207 | model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token 208 | model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) 209 | 210 | data_module = make_supervised_data_module( 211 | tokenizer=tokenizer, 212 | data_args=data_args, 213 | training_args=training_args, 214 | ) 215 | callbacks = [AutoResumeCallback()] 216 | trainer = VILAUTrainer( 217 | model=model, 218 | tokenizer=tokenizer, 219 | args=training_args, 220 | callbacks=callbacks, 221 | **data_module, 222 | ) 223 | 224 | print( 225 | "length of dataloader:", 226 | len(trainer.get_train_dataloader()), 227 | len(trainer.train_dataset), 228 | flush=True, 229 | ) 230 | print( 231 | "[GPU memory] before trainer", 232 | torch.cuda.memory_allocated() / 1024 / 1024 / 1024, 233 | flush=True, 234 | ) 235 | 236 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 237 | trainer.save_state() 238 | 239 | 
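# Descriptive note (added): once training finishes, the KV cache is re-enabled
# for inference, the output directory is recorded as both `resume_path` and
# `_name_or_path` so later runs (see get_checkpoint_path / get_model_config)
# can resolve the checkpoint, and the weights are written out via
# safe_save_model_for_hf_trainer, which consolidates a CPU state dict on the
# saving process (or defers to DeepSpeed's own saving when it is enabled).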
model.llm.config.use_cache = True 240 | model.config.resume_path = model.config._name_or_path = training_args.output_dir 241 | safe_save_model_for_hf_trainer( 242 | trainer=trainer, output_dir=training_args.output_dir 243 | ) 244 | 245 | if __name__ == "__main__": 246 | train() -------------------------------------------------------------------------------- /vila_u/train/train_mem.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | from vila_u.train.train import train 3 | from vila_u.train.transformer_normalize_monkey_patch import patched_normalize 4 | 5 | 6 | def __len__(self): 7 | return len(self.batch_sampler) 8 | 9 | 10 | def __iter__(self): 11 | return self.batch_sampler.__iter__() 12 | 13 | if __name__ == "__main__": 14 | with ( 15 | mock.patch('transformers.image_processing_utils.normalize', new=patched_normalize), 16 | mock.patch('accelerate.data_loader.BatchSamplerShard.__len__', new=__len__), 17 | mock.patch('accelerate.data_loader.BatchSamplerShard.__iter__', new=__iter__) 18 | ): 19 | train() 20 | -------------------------------------------------------------------------------- /vila_u/train/transformer_normalize_monkey_patch.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | from transformers.image_transforms import np, Union, Iterable, Optional, ChannelDimension, \ 3 | infer_channel_dimension_format, get_channel_dimension_axis, to_channel_dimension_format 4 | 5 | def patched_normalize( 6 | image: np.ndarray, 7 | mean: Union[float, Iterable[float]], 8 | std: Union[float, Iterable[float]], 9 | data_format: Optional[ChannelDimension] = None, 10 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 11 | ) -> np.ndarray: 12 | """ 13 | Normalizes `image` using the mean and standard deviation specified by `mean` and `std`. 14 | 15 | image = (image - mean) / std 16 | 17 | Args: 18 | image (`np.ndarray`): 19 | The image to normalize. 20 | mean (`float` or `Iterable[float]`): 21 | The mean to use for normalization. 22 | std (`float` or `Iterable[float]`): 23 | The standard deviation to use for normalization. 24 | data_format (`ChannelDimension`, *optional*): 25 | The channel dimension format of the output image. If unset, will use the inferred format from the input. 
26 | """ 27 | if not isinstance(image, np.ndarray): 28 | raise ValueError("image must be a numpy array") 29 | 30 | input_data_format = infer_channel_dimension_format(image) 31 | channel_axis = get_channel_dimension_axis(image) 32 | num_channels = image.shape[channel_axis] 33 | 34 | if isinstance(mean, Iterable): 35 | if len(mean) != num_channels: 36 | if num_channels == 1: 37 | num_channels = 3 38 | image = np.concatenate([image, image, image], axis=channel_axis) 39 | else: 40 | raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}") 41 | else: 42 | mean = [mean] * num_channels 43 | mean = np.array(mean, dtype=image.dtype) 44 | 45 | if isinstance(std, Iterable): 46 | if len(std) != num_channels: 47 | raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}") 48 | else: 49 | std = [std] * num_channels 50 | std = np.array(std, dtype=image.dtype) 51 | 52 | if input_data_format == ChannelDimension.LAST: 53 | image = (image - mean) / std 54 | else: 55 | image = ((image.T - mean) / std).T 56 | 57 | image = to_channel_dimension_format(image, data_format) if data_format is not None else image 58 | return image 59 | 60 | 61 | def patch_normalize_preprocess(): 62 | transformers.image_transforms.normalize = patched_normalize -------------------------------------------------------------------------------- /vila_u/train/transformers_replace/models/llama/configuring_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ LLaMA model configuration""" 21 | 22 | from ...configuration_utils import PretrainedConfig 23 | from ...utils import logging 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 29 | 30 | 31 | class LlamaConfig(PretrainedConfig): 32 | r""" 33 | This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA 34 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 35 | defaults will yield a similar configuration to that of the LLaMA-7B. 36 | 37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 38 | documentation from [`PretrainedConfig`] for more information. 39 | 40 | 41 | Args: 42 | vocab_size (`int`, *optional*, defaults to 32000): 43 | Vocabulary size of the LLaMA model. 
Defines the number of different tokens that can be represented by the 44 | `inputs_ids` passed when calling [`LlamaModel`] 45 | hidden_size (`int`, *optional*, defaults to 4096): 46 | Dimension of the hidden representations. 47 | intermediate_size (`int`, *optional*, defaults to 11008): 48 | Dimension of the MLP representations. 49 | num_hidden_layers (`int`, *optional*, defaults to 32): 50 | Number of hidden layers in the Transformer decoder. 51 | num_attention_heads (`int`, *optional*, defaults to 32): 52 | Number of attention heads for each attention layer in the Transformer decoder. 53 | num_key_value_heads (`int`, *optional*): 54 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 55 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 56 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 57 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 58 | by meanpooling all the original heads within that group. For more details checkout [this 59 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 60 | `num_attention_heads`. 61 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 62 | The non-linear activation function (function or string) in the decoder. 63 | max_position_embeddings (`int`, *optional*, defaults to 2048): 64 | The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, 65 | Llama 2 up to 4096, CodeLlama up to 16384. 66 | initializer_range (`float`, *optional*, defaults to 0.02): 67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 68 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 69 | The epsilon used by the rms normalization layers. 70 | use_cache (`bool`, *optional*, defaults to `True`): 71 | Whether or not the model should return the last key/values attentions (not used by all models). Only 72 | relevant if `config.is_decoder=True`. 73 | pad_token_id (`int`, *optional*): 74 | Padding token id. 75 | bos_token_id (`int`, *optional*, defaults to 1): 76 | Beginning of stream token id. 77 | eos_token_id (`int`, *optional*, defaults to 2): 78 | End of stream token id. 79 | pretraining_tp (`int`, *optional*, defaults to 1): 80 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 81 | document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is 82 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this 83 | issue](https://github.com/pytorch/pytorch/issues/76232). 84 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 85 | Whether to tie weight embeddings 86 | rope_theta (`float`, *optional*, defaults to 10000.0): 87 | The base period of the RoPE embeddings. 88 | rope_scaling (`Dict`, *optional*): 89 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 90 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is 91 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 92 | `max_position_embeddings` to the expected new maximum. 
See the following thread for more information on how 93 | these scaling strategies behave: 94 | https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an 95 | experimental feature, subject to breaking API changes in future versions. 96 | attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): 97 | Whether to use a bias in the query, key, value and output projection layers during self-attention. 98 | attention_dropout (`float`, *optional*, defaults to 0.0): 99 | The dropout ratio for the attention probabilities. 100 | 101 | ```python 102 | >>> from transformers import LlamaModel, LlamaConfig 103 | 104 | >>> # Initializing a LLaMA llama-7b style configuration 105 | >>> configuration = LlamaConfig() 106 | 107 | >>> # Initializing a model from the llama-7b style configuration 108 | >>> model = LlamaModel(configuration) 109 | 110 | >>> # Accessing the model configuration 111 | >>> configuration = model.config 112 | ```""" 113 | model_type = "llama" 114 | keys_to_ignore_at_inference = ["past_key_values"] 115 | 116 | def __init__( 117 | self, 118 | vocab_size=32000, 119 | hidden_size=4096, 120 | intermediate_size=11008, 121 | num_hidden_layers=32, 122 | num_attention_heads=32, 123 | num_key_value_heads=None, 124 | hidden_act="silu", 125 | max_position_embeddings=2048, 126 | initializer_range=0.02, 127 | rms_norm_eps=1e-6, 128 | use_cache=True, 129 | pad_token_id=None, 130 | bos_token_id=1, 131 | eos_token_id=2, 132 | pretraining_tp=1, 133 | tie_word_embeddings=False, 134 | rope_theta=10000.0, 135 | rope_scaling=None, 136 | attention_bias=False, 137 | attention_dropout=0.0, 138 | **kwargs, 139 | ): 140 | self.vocab_size = vocab_size 141 | self.max_position_embeddings = max_position_embeddings 142 | self.hidden_size = hidden_size 143 | self.intermediate_size = intermediate_size 144 | self.num_hidden_layers = num_hidden_layers 145 | self.num_attention_heads = num_attention_heads 146 | 147 | # for backward compatibility 148 | if num_key_value_heads is None: 149 | num_key_value_heads = num_attention_heads 150 | 151 | self.num_key_value_heads = num_key_value_heads 152 | self.hidden_act = hidden_act 153 | self.initializer_range = initializer_range 154 | self.rms_norm_eps = rms_norm_eps 155 | self.pretraining_tp = pretraining_tp 156 | self.use_cache = use_cache 157 | self.rope_theta = rope_theta 158 | self.rope_scaling = rope_scaling 159 | self._rope_scaling_validation() 160 | self.attention_bias = attention_bias 161 | self.attention_dropout = attention_dropout 162 | 163 | super().__init__( 164 | pad_token_id=pad_token_id, 165 | bos_token_id=bos_token_id, 166 | eos_token_id=eos_token_id, 167 | tie_word_embeddings=tie_word_embeddings, 168 | **kwargs, 169 | ) 170 | 171 | def _rope_scaling_validation(self): 172 | """ 173 | Validate the `rope_scaling` configuration. 
174 | """ 175 | if self.rope_scaling is None: 176 | return 177 | 178 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 179 | raise ValueError( 180 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " 181 | f"got {self.rope_scaling}" 182 | ) 183 | rope_scaling_type = self.rope_scaling.get("type", None) 184 | rope_scaling_factor = self.rope_scaling.get("factor", None) 185 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 186 | raise ValueError( 187 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 188 | ) 189 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: 190 | raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") 191 | -------------------------------------------------------------------------------- /vila_u/train/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import re 4 | import torch 5 | 6 | from dataclasses import dataclass 7 | from transformers import PretrainedConfig 8 | 9 | 10 | def rprint(*args, **kwargs): 11 | rank = int(os.environ.get("RANK", 0)) 12 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 13 | if world_size > 1: 14 | return print(f"[dist-{rank}-of-{world_size}]", *args, **kwargs) 15 | else: 16 | return print(*args, **kwargs) 17 | 18 | 19 | def mprint(*args, **kwargs): 20 | rank = int(os.environ.get("RANK", 0)) 21 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 22 | if world_size > 1: 23 | if rank == 0: 24 | return print(f"[dist-{rank}-of-{world_size}]", *args, **kwargs) 25 | else: 26 | return 27 | else: 28 | return print(*args, **kwargs) 29 | 30 | 31 | def is_local(model_name_or_path: str) -> bool: 32 | return os.path.isdir(model_name_or_path) 33 | 34 | 35 | def get_checkpoint_path( 36 | output_dir: str, checkpoint_prefix: str = "checkpoint" 37 | ) -> str | None: 38 | output_dir = os.path.abspath(output_dir) 39 | pathlib_dir = pathlib.Path(output_dir) 40 | 41 | if list(pathlib_dir.glob("config.json")): 42 | return output_dir, False 43 | else: 44 | try: 45 | ordering_and_checkpoint_path = [] 46 | glob_checkpoints = [ 47 | str(x) 48 | for x in pathlib.Path(output_dir).glob(f"{checkpoint_prefix}-*") 49 | if os.path.isdir(x) 50 | ] 51 | for path in glob_checkpoints: 52 | regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path) 53 | if regex_match is not None and regex_match.groups() is not None: 54 | ordering_and_checkpoint_path.append( 55 | (int(regex_match.groups()[0]), path) 56 | ) 57 | checkpoints_sorted = sorted(ordering_and_checkpoint_path) 58 | return checkpoints_sorted[-1][1], True 59 | except: 60 | return None, True 61 | 62 | 63 | def prepare_config_for_training( 64 | config: PretrainedConfig, 65 | model_args: dataclass, 66 | training_args: dataclass, 67 | data_args: dataclass, 68 | ) -> None: 69 | assert model_args.vision_tower is not None, "requires vision tower" 70 | 71 | if getattr(config, "llm_cfg", None) is None: 72 | config.llm_cfg = model_args.model_name_or_path 73 | if getattr(config, "vision_tower_cfg", None) is None: 74 | config.vision_tower_cfg = model_args.vision_tower 75 | if getattr(config, "mm_projector_cfg", None) is None: 76 | config.mm_projector_cfg = model_args.mm_projector 77 | 78 | config.model_dtype = torch.bfloat16 if training_args.bf16 else torch.float16 79 | config.model_dtype = config.model_dtype.__str__() 80 | 
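# Descriptive note (added): model_dtype is stored as its string form (e.g.
# "torch.bfloat16") because downstream builders recover the dtype with
# eval(config.model_dtype) when loading the vision tower and the mm projector
# (see build_mm_projector and RQVAESIGLIPTransformerVisionTower earlier in this
# repo).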
81 | config.tune_language_model = training_args.tune_language_model 82 | config.tune_vision_tower = training_args.tune_vision_tower 83 | config.tune_mm_projector = training_args.tune_mm_projector 84 | 85 | config.image_aspect_ratio = data_args.image_aspect_ratio 86 | 87 | if getattr(config, "vision_tower_cfg", None) is not None: 88 | config.mm_vision_select_layer = model_args.mm_vision_select_layer 89 | config.interpolate_mode = model_args.interpolate_mode 90 | config.drop_path_rate = model_args.drop_path_rate -------------------------------------------------------------------------------- /vila_u/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * -------------------------------------------------------------------------------- /vila_u/utils/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | 4 | from torch import distributed as dist 5 | from typing import Any, List, Optional 6 | 7 | 8 | __all__ = [ 9 | "init", 10 | "is_initialized", 11 | "size", 12 | "rank", 13 | "local_size", 14 | "local_rank", 15 | "is_main", 16 | "barrier", 17 | "gather", 18 | "all_gather", 19 | ] 20 | 21 | 22 | def init() -> None: 23 | if "RANK" not in os.environ: 24 | warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.") 25 | return 26 | dist.init_process_group(backend="nccl", init_method="env://") 27 | 28 | 29 | def is_initialized() -> bool: 30 | return dist.is_initialized() 31 | 32 | 33 | def size() -> int: 34 | return int(os.environ.get("WORLD_SIZE", 1)) 35 | 36 | 37 | def rank() -> int: 38 | return int(os.environ.get("RANK", 0)) 39 | 40 | 41 | def local_size() -> int: 42 | return int(os.environ.get("LOCAL_WORLD_SIZE", 1)) 43 | 44 | 45 | def local_rank() -> int: 46 | return int(os.environ.get("LOCAL_RANK", 0)) 47 | 48 | 49 | def is_main() -> bool: 50 | return rank() == 0 51 | 52 | 53 | def barrier() -> None: 54 | dist.barrier() 55 | 56 | 57 | def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]: 58 | if is_main(): 59 | objs = [None for _ in range(size())] 60 | dist.gather_object(obj, objs, dst=dst) 61 | return objs 62 | else: 63 | dist.gather_object(obj, dst=dst) 64 | return None 65 | 66 | 67 | def all_gather(obj: Any) -> List[Any]: 68 | objs = [None for _ in range(size())] 69 | dist.all_gather_object(objs, obj) 70 | return objs 71 | -------------------------------------------------------------------------------- /vila_u/utils/io.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | import pickle 5 | import torch 6 | import yaml 7 | 8 | from contextlib import contextmanager 9 | from typing import IO, Any, BinaryIO, Callable, Dict, Iterator, TextIO, Union 10 | 11 | 12 | __all__ = [ 13 | "load", 14 | "save", 15 | "load_json", 16 | "save_json", 17 | "load_jsonl", 18 | "save_jsonl", 19 | "load_mat", 20 | "save_mat", 21 | "load_npy", 22 | "save_npy", 23 | "load_npz", 24 | "save_npz", 25 | "load_pt", 26 | "save_pt", 27 | "load_yaml", 28 | "save_yaml", 29 | ] 30 | 31 | 32 | @contextmanager 33 | def file_descriptor(f: Union[str, IO], mode: str = "r") -> Iterator[IO]: 34 | opened = False 35 | try: 36 | if isinstance(f, str): 37 | f = open(f, mode) 38 | opened = True 39 | yield f 40 | finally: 41 | if opened: 42 | f.close() 43 | 44 | 45 | def load_json(f: Union[str, TextIO], **kwargs) -> Any: 46 | with file_descriptor(f, mode="r") as fd: 47 
| return json.load(fd, **kwargs) 48 | 49 | 50 | def save_json(f: Union[str, TextIO], obj: Any, **kwargs) -> None: 51 | with file_descriptor(f, mode="w") as fd: 52 | json.dump(obj, fd, **kwargs) 53 | 54 | 55 | def load_jsonl(f: Union[str, TextIO], **kwargs) -> Any: 56 | with file_descriptor(f, mode="r") as fd: 57 | return [json.loads(datum, **kwargs) for datum in fd.readlines()] 58 | 59 | 60 | def save_jsonl(f: Union[str, TextIO], obj: Any, **kwargs) -> None: 61 | with file_descriptor(f, mode="w") as fd: 62 | fd.write("\n".join(json.dumps(datum, **kwargs) for datum in obj)) 63 | 64 | 65 | def load_mat(f: Union[str, BinaryIO], **kwargs) -> Any: 66 | import scipy.io 67 | 68 | return scipy.io.loadmat(f, **kwargs) 69 | 70 | 71 | def save_mat(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None: 72 | import scipy.io 73 | 74 | scipy.io.savemat(f, obj, **kwargs) 75 | 76 | 77 | def load_npy(f: Union[str, BinaryIO], **kwargs) -> Any: 78 | return np.load(f, **kwargs) 79 | 80 | 81 | def save_npy(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None: 82 | np.save(f, obj, **kwargs) 83 | 84 | 85 | def load_npz(f: Union[str, BinaryIO], **kwargs) -> Any: 86 | return np.load(f, **kwargs) 87 | 88 | 89 | def save_npz(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None: 90 | np.savez(f, obj, **kwargs) 91 | 92 | 93 | def load_pkl(f: Union[str, BinaryIO], **kwargs) -> Any: 94 | with file_descriptor(f, mode="rb") as fd: 95 | try: 96 | return pickle.load(fd, **kwargs) 97 | except UnicodeDecodeError: 98 | if "encoding" in kwargs: 99 | raise 100 | fd.seek(0) 101 | return pickle.load(fd, encoding="latin1", **kwargs) 102 | 103 | 104 | def save_pkl(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None: 105 | with file_descriptor(f, mode="wb") as fd: 106 | pickle.dump(obj, fd, **kwargs) 107 | 108 | 109 | def load_pt(f: Union[str, BinaryIO], **kwargs) -> Any: 110 | return torch.load(f, **kwargs) 111 | 112 | 113 | def save_pt(f: Union[str, BinaryIO], obj: Any, **kwargs) -> None: 114 | torch.save(obj, f, **kwargs) 115 | 116 | 117 | def load_yaml(f: Union[str, TextIO]) -> Any: 118 | with file_descriptor(f, mode="r") as fd: 119 | return yaml.safe_load(fd) 120 | 121 | 122 | def save_yaml(f: Union[str, TextIO], obj: Any, **kwargs) -> None: 123 | with file_descriptor(f, mode="w") as fd: 124 | yaml.safe_dump(obj, fd, **kwargs) 125 | 126 | 127 | def load_txt(f: Union[str, TextIO]) -> Any: 128 | with file_descriptor(f, mode="r") as fd: 129 | return fd.read() 130 | 131 | 132 | def save_txt(f: Union[str, TextIO], obj: Any, **kwargs) -> None: 133 | with file_descriptor(f, mode="w") as fd: 134 | fd.write(obj) 135 | 136 | 137 | __io_registry: Dict[str, Dict[str, Callable]] = { 138 | ".txt": {"load": load_txt, "save": save_txt}, 139 | ".json": {"load": load_json, "save": save_json}, 140 | ".jsonl": {"load": load_jsonl, "save": save_jsonl}, 141 | ".mat": {"load": load_mat, "save": save_mat}, 142 | ".npy": {"load": load_npy, "save": save_npy}, 143 | ".npz": {"load": load_npz, "save": save_npz}, 144 | ".pkl": {"load": load_pkl, "save": save_pkl}, 145 | ".pt": {"load": load_pt, "save": save_pt}, 146 | ".pth": {"load": load_pt, "save": save_pt}, 147 | ".pth.tar": {"load": load_pt, "save": save_pt}, 148 | ".yaml": {"load": load_yaml, "save": save_yaml}, 149 | ".yml": {"load": load_yaml, "save": save_yaml}, 150 | } 151 | 152 | 153 | def load(fpath: str, **kwargs) -> Any: 154 | assert isinstance(fpath, str), type(fpath) 155 | 156 | for extension in sorted(__io_registry.keys(), key=len, reverse=True): 157 | if fpath.endswith(extension) and "load" 
in __io_registry[extension]: 158 | return __io_registry[extension]["load"](fpath, **kwargs) 159 | 160 | raise NotImplementedError(f'"{fpath}" cannot be loaded.') 161 | 162 | 163 | def save(fpath: str, obj: Any, **kwargs) -> None: 164 | assert isinstance(fpath, str), type(fpath) 165 | os.makedirs(os.path.dirname(fpath), exist_ok=True) 166 | 167 | for extension in sorted(__io_registry.keys(), key=len, reverse=True): 168 | if fpath.endswith(extension) and "save" in __io_registry[extension]: 169 | __io_registry[extension]["save"](fpath, obj, **kwargs) 170 | return 171 | 172 | raise NotImplementedError(f'"{fpath}" cannot be saved.') 173 | -------------------------------------------------------------------------------- /vila_u/utils/logging.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | if typing.TYPE_CHECKING: 4 | from loguru import Logger 5 | else: 6 | Logger = None 7 | 8 | __all__ = ["logger"] 9 | 10 | 11 | def __get_logger() -> Logger: 12 | from loguru import logger 13 | 14 | return logger 15 | 16 | 17 | logger = __get_logger() -------------------------------------------------------------------------------- /vila_u/utils/media.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import glob 3 | import numpy as np 4 | import os 5 | import requests 6 | import PIL 7 | import PIL.Image 8 | 9 | from collections import defaultdict 10 | from transformers import PretrainedConfig 11 | from typing import Any, Dict, List, Optional, Union 12 | 13 | from vila_u.constants import DEFAULT_IMAGE_TOKEN 14 | from vila_u.media import Image, Video 15 | from vila_u.utils import make_list 16 | from vila_u.utils.logging import logger 17 | 18 | __all__ = ["extract_media"] 19 | 20 | 21 | def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image: 22 | if isinstance(image, Image): 23 | if image.path.startswith("http://") or image.path.startswith("https://"): 24 | image = PIL.Image.open(requests.get(image.path, stream=True).raw) 25 | else: 26 | image = PIL.Image.open(image.path) 27 | return image 28 | 29 | 30 | def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]: 31 | # Load video frames from a directory 32 | if os.path.isdir(video_path): 33 | frame_paths = sorted(glob.glob(os.path.join(video_path, "*"))) 34 | indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int) 35 | return [PIL.Image.open(frame_paths[index]) for index in indices] 36 | 37 | # Load video frames from a video file 38 | vidcap = cv2.VideoCapture(video_path) 39 | 40 | # Find the last frame as frame count might not be accurate 41 | frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 42 | while frame_count > 0: 43 | vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1) 44 | if vidcap.grab(): 45 | break 46 | frame_count -= 1 47 | else: 48 | raise ValueError(f"Video '{video_path}' has no frames.") 49 | vidcap.set(cv2.CAP_PROP_POS_FRAMES, 0) 50 | 51 | # Extract frames uniformly 52 | indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int) 53 | frames = {} 54 | for index in range(frame_count): 55 | success = vidcap.grab() 56 | if not success: 57 | raise ValueError(f"Failed to grab frame {index} from video '{video_path}'.") 58 | if index not in indices: 59 | continue 60 | success, frame = vidcap.retrieve() 61 | if not success: 62 | logger.warning(f"Failed to retrieve frame {index} from video '{video_path}'. 
Skipped.") 63 | continue 64 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 65 | frames[index] = PIL.Image.fromarray(frame) 66 | return [frames[index] for index in indices if index in frames] 67 | 68 | 69 | def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]: 70 | num_frames = config.num_video_frames 71 | 72 | frames = _load_video(video.path, num_frames=num_frames) 73 | return frames 74 | 75 | 76 | def extract_media( 77 | messages: List[Dict[str, Any]], 78 | config: Optional[PretrainedConfig] = None, 79 | draft: bool = False, 80 | ) -> Dict[str, List[Any]]: 81 | media = defaultdict(list) 82 | for message in messages: 83 | text = "" 84 | for part in make_list(message["value"]): 85 | if isinstance(part, str): 86 | text += part 87 | elif isinstance(part, (Image, PIL.Image.Image)): 88 | if draft: 89 | media["image"].append(part) 90 | else: 91 | image = _extract_image(part) 92 | text += DEFAULT_IMAGE_TOKEN + "\n" 93 | media["image"].append(image) 94 | elif isinstance(part, Video): 95 | if draft: 96 | media["video"].append(part) 97 | else: 98 | video = _extract_video(part, config) 99 | text += (DEFAULT_IMAGE_TOKEN + "\n") * len(video) 100 | media["image"].extend(video) 101 | else: 102 | raise ValueError(f"Unsupported prompt part type: {type(part)}") 103 | message["value"] = text 104 | 105 | return media -------------------------------------------------------------------------------- /vila_u/utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | 4 | from typing import Any, Dict, List, Optional, Sequence 5 | 6 | from vila_u import conversation as conversation_lib 7 | from vila_u.constants import IGNORE_INDEX, SENTINEL_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_VI_START_TOKEN 8 | from vila_u.mm_utils import tokenizer_image_token 9 | from vila_u.utils.logging import logger 10 | 11 | __all__ = [ 12 | "tokenize_conversation", 13 | ] 14 | 15 | DUMMY_CONVERSATION = [ 16 | {"from": "human", "value": "question"}, 17 | {"from": "gpt", "value": "answer"}, 18 | ] * 10 19 | 20 | def tokenize_conversation( 21 | messages: Sequence[Dict[str, str]], 22 | tokenizer: transformers.PreTrainedTokenizer, 23 | add_generation_prompt: bool = False, 24 | overrides: Optional[Dict[str, str]] = None, 25 | no_system_prompt: bool = False, 26 | image_generation: bool = False, 27 | video_generation: bool = False, 28 | ) -> torch.Tensor: 29 | for message in messages: 30 | message["value"] = message["value"].strip() 31 | 32 | conv = conversation_lib.default_conversation.copy() 33 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 34 | 35 | if no_system_prompt: 36 | conv.system = "" 37 | 38 | # Skip the first message if it is not from human 39 | if messages[0]["from"] != "human": 40 | messages = messages[1:] 41 | 42 | # Add a generation prompt if needed 43 | if add_generation_prompt: 44 | messages.append({"from": "gpt", "value": None}) 45 | 46 | conv.messages = [] 47 | for turn, message in enumerate(messages): 48 | role = roles[message["from"]] 49 | assert role == conv.roles[turn % 2] 50 | if overrides is not None and message["from"] in overrides: 51 | conv.append_message(role, overrides[message["from"]]) 52 | else: 53 | conv.append_message(role, message["value"]) 54 | 55 | prompt = conv.get_prompt() 56 | if image_generation: 57 | prompt += f" {DEFAULT_IM_START_TOKEN}" 58 | elif video_generation: 59 | prompt += f" {DEFAULT_VI_START_TOKEN}" 60 | else: 61 | pass 62 | 63 | return tokenizer_image_token(prompt, 
tokenizer, return_tensors="pt") 64 | 65 | def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None: 66 | if not hasattr(tokenizer, "sentinel_token"): 67 | tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True) 68 | tokenizer.sentinel_token = SENTINEL_TOKEN 69 | tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN) 70 | 71 | def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]: 72 | _maybe_add_sentinel_token(tokenizer) 73 | template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN}) 74 | 75 | stop_tokens = {tokenizer.eos_token} 76 | for k in range(template.size(0) - 1): 77 | if template[k] == tokenizer.sentinel_token_id: 78 | stop_token = tokenizer.decode(template[k + 1]) 79 | stop_tokens.add(stop_token) 80 | return list(stop_tokens) 81 | -------------------------------------------------------------------------------- /vila_u/utils/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | __all__ = ["make_list", "disable_torch_init"] 4 | 5 | 6 | def make_list(obj: Any) -> List: 7 | return obj if isinstance(obj, list) else [obj] 8 | 9 | 10 | def disable_torch_init(): 11 | """ 12 | Disable the redundant torch default initialization to accelerate model creation. 13 | """ 14 | import torch 15 | 16 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 17 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) -------------------------------------------------------------------------------- /vila_u/wids/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved. 2 | # This file is part of the WebDataset library. 3 | # See the LICENSE file for licensing terms (BSD-style). 4 | # 5 | # flake8: noqa 6 | 7 | from .wids import ( 8 | ChunkedSampler, 9 | DistributedChunkedSampler, 10 | ShardedSampler, 11 | ShardListDataset, 12 | DistributedLocalSampler, 13 | ) 14 | -------------------------------------------------------------------------------- /vila_u/wids/wids_bench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | from . 
import wids

try:
    from .compat import WebDataset
except ImportError:
    # This tree does not ship a local `compat` module; fall back to the upstream
    # `webdataset` package (assumed to be installed) for the wds comparison path.
    from webdataset import WebDataset


def main_wids(args):
    desc = json.load(open(args.dataset))
    files = desc["files"]
    dataset = wids.ShardListDataset(files, cache_size=4)
    print(len(dataset))
    for i in range(len(dataset)):
        print(i, dataset[i]["__key__"])
    dataset.close()


def main_wds(args):
    desc = json.load(open(args.dataset))
    files = desc["files"]
    urls = [f["url"] for f in files]
    dataset = WebDataset(urls)
    for i, sample in enumerate(dataset):
        print(i, sample["__key__"])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # there are two subcommands: wids and wds
    subparsers = parser.add_subparsers(dest="command")
    wids_parser = subparsers.add_parser("wids")
    wds_parser = subparsers.add_parser("wds")

    # wids subcommand
    wids_parser.add_argument("dataset", help="dataset name")

    # wds subcommand
    wds_parser.add_argument("dataset", help="dataset name")

    args = parser.parse_args()

    if args.command == "wids":
        main_wids(args)
    elif args.command == "wds":
        main_wds(args)
    else:
        raise ValueError(f"Unknown command: {args.command}")
--------------------------------------------------------------------------------
/vila_u/wids/wids_cleanup.py:
--------------------------------------------------------------------------------
"""
This module provides utilities for managing files in a directory.

It includes a function `keep_most_recent_files` that keeps the most recent
files in a directory, deleting the rest based on the maximum size of the directory
in bytes and the maximum number of files to keep.

The cleanup job can be run in the background using `create_cleanup_background_process`.
"""

import errno
import fcntl
import glob
import os
import time

import numpy as np


def keep_most_recent_files(pattern, maxsize=int(1e12), maxfiles=1000, debug=False):
    """Keep the most recent files in a directory, deleting the rest.

    The maxsize is the maximum size of the directory in bytes. The maxfiles is
    the maximum number of files to keep. The files are sorted by modification
    time, and the most recent files are kept. If the directory is already
    smaller than maxsize, then no files are deleted.
If there are fewer than 26 | maxfiles, then no files are deleted.""" 27 | 28 | # get the list of files in the directory 29 | fnames = glob.glob(pattern) 30 | # compute a list of (mtime, fname, size) triples 31 | files = [] 32 | for fname in fnames: 33 | try: 34 | s = os.stat(fname) 35 | except FileNotFoundError: 36 | continue 37 | files.append((s.st_mtime, fname, s.st_size)) 38 | # sort the list by mtime, most recent first 39 | files.sort(reverse=True) 40 | # compute an accumulated total of the file sizes in order using np.cumsum 41 | sizes = np.cumsum([size for mtime, fname, size in files]) 42 | # compute a cutoff index based on maxsize 43 | cutoff = np.searchsorted(sizes, maxsize) 44 | # compute a cutoff index based on maxfiles 45 | cutoff = min(cutoff, maxfiles) 46 | # delete the files above the cutoff in reverse order 47 | for mtime, fname, size in files[cutoff:][::-1]: 48 | try: 49 | os.unlink(fname) 50 | except FileNotFoundError: 51 | pass 52 | 53 | 54 | class ExclusiveLock: 55 | """A simple non-blocking exclusive lock using fcntl.""" 56 | 57 | def __init__(self, lockfile): 58 | self.lockfile = lockfile 59 | 60 | def try_lock(self): 61 | try: 62 | self.lock = open(self.lockfile, "w") 63 | fcntl.flock(self.lock.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) 64 | return True 65 | except OSError as e: 66 | if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK): 67 | return False 68 | else: 69 | raise 70 | 71 | def release_lock(self): 72 | self.lock.close() 73 | os.unlink(self.lockfile) 74 | 75 | 76 | def create_cleanup_background_process( 77 | pattern, maxsize=int(1e12), maxfiles=1000, every=60 78 | ): 79 | """Create a background process that keeps a directory below a certain size.""" 80 | 81 | def cleanup_worker(every): 82 | # use a lock file to ensure that only one cleanup worker is running 83 | lockfile = os.path.join(os.path.dirname(pattern), ".cleanup.lock") 84 | lock = ExclusiveLock(lockfile) 85 | if not lock.try_lock(): 86 | return 87 | while True: 88 | keep_most_recent_files(pattern, maxsize=maxsize, maxfiles=maxfiles) 89 | time.sleep(every) 90 | 91 | import multiprocessing 92 | 93 | p = multiprocessing.Process(target=cleanup_worker, args=(every,)) 94 | p.start() 95 | return p 96 | -------------------------------------------------------------------------------- /vila_u/wids/wids_dir.py: -------------------------------------------------------------------------------- 1 | """ 2 | # dynamically create a shard index 3 | class DirectoryDataset(ShardListDataset): 4 | def __init__(self, directory): 5 | pass 6 | """ 7 | 8 | """ 9 | # randomly choose shards from a directory 10 | class DirectoryQueueDataset(IterableDataset): 11 | def __init__(self, directory, strategy="replace", choice="random", downloader=None, transformations="PIL"): 12 | pass 13 | def add_transform(self, transform): 14 | pass 15 | def __iter__(self): 16 | # pick file according to strategy 17 | # rename file to .active 18 | # randomly yield samples from file 19 | # rename file back to its original name or unlink it, according to strategy 20 | pass 21 | """ 22 | -------------------------------------------------------------------------------- /vila_u/wids/wids_dl.py: -------------------------------------------------------------------------------- 1 | import fcntl 2 | import os 3 | import shutil 4 | import sys 5 | import time 6 | from collections import deque 7 | from datetime import datetime 8 | from urllib.parse import urlparse 9 | 10 | recent_downloads = deque(maxlen=1000) 11 | 12 | open_objects = {} 13 | max_open_objects = 100 14 | 
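# The helpers in this module implement cache-aware shard downloading:
# `download_file` dispatches on the URL scheme through `default_cmds` (curl for
# http/https/ftp/ftps, gsutil for gs://, the aws CLI for s3://, a plain copy for
# local paths, and "pipe:" command templates), while `download_and_open`
# serializes concurrent downloads of the same shard through a POSIX lock file
# before returning an open file object. Illustrative use (the URL and cache
# path below are hypothetical, not part of this module):
#
#     with download_and_open("https://example.com/shards/shard-000.tar",
#                            "/tmp/wids_cache/shard-000.tar") as stream:
#         data = stream.read()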
15 | 16 | class ULockFile: 17 | """A simple locking class. We don't need any of the third 18 | party libraries since we rely on POSIX semantics for linking 19 | below anyway.""" 20 | 21 | def __init__(self, path): 22 | self.lockfile_path = path 23 | self.lockfile = None 24 | 25 | def __enter__(self): 26 | self.lockfile = open(self.lockfile_path, "w") 27 | fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_EX) 28 | return self 29 | 30 | def __exit__(self, exc_type, exc_val, exc_tb): 31 | fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_UN) 32 | self.lockfile.close() 33 | self.lockfile = None 34 | try: 35 | os.unlink(self.lockfile_path) 36 | except FileNotFoundError: 37 | pass 38 | 39 | 40 | def pipe_download(remote, local): 41 | """Perform a download for a pipe: url.""" 42 | assert remote.startswith("pipe:") 43 | cmd = remote[5:] 44 | cmd = cmd.format(local=local) 45 | assert os.system(cmd) == 0, "Command failed: %s" % cmd 46 | 47 | 48 | def copy_file(remote, local): 49 | remote = urlparse(remote) 50 | assert remote.scheme in ["file", ""] 51 | # use absolute path 52 | remote = os.path.abspath(remote.path) 53 | local = urlparse(local) 54 | assert local.scheme in ["file", ""] 55 | local = os.path.abspath(local.path) 56 | if remote == local: 57 | return 58 | # check if the local file exists 59 | shutil.copyfile(remote, local) 60 | 61 | 62 | verbose_cmd = int(os.environ.get("WIDS_VERBOSE_CMD", "0")) 63 | 64 | 65 | def vcmd(flag, verbose_flag=""): 66 | return verbose_flag if verbose_cmd else flag 67 | 68 | 69 | default_cmds = { 70 | "posixpath": copy_file, 71 | "file": copy_file, 72 | "pipe": pipe_download, 73 | "http": "curl " + vcmd("-s") + " -L {url} -o {local}", 74 | "https": "curl " + vcmd("-s") + " -L {url} -o {local}", 75 | "ftp": "curl " + vcmd("-s") + " -L {url} -o {local}", 76 | "ftps": "curl " + vcmd("-s") + " -L {url} -o {local}", 77 | "gs": "gsutil " + vcmd("-q") + " cp {url} {local}", 78 | "s3": "aws s3 cp {url} {local}", 79 | } 80 | 81 | #TODO(ligeng): change HTTPS download to python requests library 82 | 83 | def download_file_no_log(remote, local, handlers=default_cmds): 84 | """Download a file from a remote url to a local path. 85 | The remote url can be a pipe: url, in which case the remainder of 86 | the url is treated as a command template that is executed to perform the download. 
87 | """ 88 | 89 | if remote.startswith("pipe:"): 90 | schema = "pipe" 91 | else: 92 | schema = urlparse(remote).scheme 93 | if schema is None or schema == "": 94 | schema = "posixpath" 95 | # get the handler 96 | handler = handlers.get(schema) 97 | if handler is None: 98 | raise ValueError("Unknown schema: %s" % schema) 99 | # call the handler 100 | if callable(handler): 101 | handler(remote, local) 102 | else: 103 | assert isinstance(handler, str) 104 | cmd = handler.format(url=remote, local=local) 105 | assert os.system(cmd) == 0, "Command failed: %s" % cmd 106 | return local 107 | 108 | 109 | def download_file(remote, local, handlers=default_cmds, verbose=False): 110 | start = time.time() 111 | try: 112 | return download_file_no_log(remote, local, handlers=handlers) 113 | finally: 114 | recent_downloads.append((remote, local, time.time(), time.time() - start)) 115 | if verbose: 116 | print( 117 | "downloaded", 118 | remote, 119 | "to", 120 | local, 121 | "in", 122 | time.time() - start, 123 | "seconds", 124 | file=sys.stderr, 125 | ) 126 | 127 | 128 | def download_and_open(remote, local, mode="rb", handlers=default_cmds, verbose=False): 129 | with ULockFile(local + ".lock"): 130 | if os.path.exists(remote): 131 | # print("enter1", remote, local, mode) 132 | result = open(remote, mode) 133 | else: 134 | # print("enter2", remote, local, mode) 135 | if not os.path.exists(local): 136 | if verbose: 137 | print("downloading", remote, "to", local, file=sys.stderr) 138 | download_file(remote, local, handlers=handlers) 139 | else: 140 | if verbose: 141 | print("using cached", local, file=sys.stderr) 142 | result = open(local, mode) 143 | 144 | # input() 145 | 146 | if open_objects is not None: 147 | for k, v in list(open_objects.items()): 148 | if v.closed: 149 | del open_objects[k] 150 | if len(open_objects) > max_open_objects: 151 | raise RuntimeError("Too many open objects") 152 | current_time = datetime.now().strftime("%Y%m%d%H%M%S") 153 | key = tuple(str(x) for x in [remote, local, mode, current_time]) 154 | open_objects[key] = result 155 | return result 156 | -------------------------------------------------------------------------------- /vila_u/wids/wids_lru.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | 4 | class LRUCache: 5 | def __init__(self, capacity: int, release_handler=None): 6 | """Initialize a new LRU cache with the given capacity.""" 7 | self.capacity = capacity 8 | self.cache = OrderedDict() 9 | self.release_handler = release_handler 10 | 11 | def __getitem__(self, key): 12 | """Return the value associated with the given key, or None.""" 13 | if key not in self.cache: 14 | return None 15 | self.cache.move_to_end(key) 16 | return self.cache[key] 17 | 18 | def __setitem__(self, key, value): 19 | """Associate the given value with the given key.""" 20 | if key in self.cache: 21 | self.cache.move_to_end(key) 22 | self.cache[key] = value 23 | if len(self.cache) > self.capacity: 24 | key, value = self.cache.popitem(last=False) 25 | if self.release_handler is not None: 26 | self.release_handler(key, value) 27 | 28 | def __delitem__(self, key): 29 | """Remove the given key from the cache.""" 30 | if key in self.cache: 31 | if self.release_handler is not None: 32 | value = self.cache[key] 33 | self.release_handler(key, value) 34 | del self.cache[key] 35 | 36 | def __len__(self): 37 | """Return the number of entries in the cache.""" 38 | return len(self.cache) 39 | 40 | def __contains__(self, key): 41 
| """Return whether the cache contains the given key.""" 42 | return key in self.cache 43 | 44 | def items(self): 45 | """Return an iterator over the keys of the cache.""" 46 | return self.cache.items() 47 | 48 | def keys(self): 49 | """Return an iterator over the keys of the cache.""" 50 | return self.cache.keys() 51 | 52 | def values(self): 53 | """Return an iterator over the values of the cache.""" 54 | return self.cache.values() 55 | 56 | def clear(self): 57 | for key in list(self.keys()): 58 | value = self.cache[key] 59 | if self.release_handler is not None: 60 | self.release_handler(key, value) 61 | del self[key] 62 | 63 | def __del__(self): 64 | self.clear() 65 | -------------------------------------------------------------------------------- /vila_u/wids/wids_mmtar.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import fcntl 3 | import io 4 | import mmap 5 | import os 6 | import struct 7 | 8 | TarHeader = collections.namedtuple( 9 | "TarHeader", 10 | [ 11 | "name", 12 | "mode", 13 | "uid", 14 | "gid", 15 | "size", 16 | "mtime", 17 | "chksum", 18 | "typeflag", 19 | "linkname", 20 | "magic", 21 | "version", 22 | "uname", 23 | "gname", 24 | "devmajor", 25 | "devminor", 26 | "prefix", 27 | ], 28 | ) 29 | 30 | 31 | def parse_tar_header(header_bytes): 32 | header = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes) 33 | return TarHeader(*header) 34 | 35 | 36 | def next_header(offset, header): 37 | block_size = 512 38 | size = header.size.decode("utf-8").strip("\x00") 39 | if size == "": 40 | return -1 41 | size = int(size, 8) 42 | # compute the file size rounded up to the next block size if it is a partial block 43 | padded_file_size = (size + block_size - 1) // block_size * block_size 44 | return offset + block_size + padded_file_size 45 | 46 | 47 | # TODO(ligeng): support gzip stream 48 | class MMIndexedTar: 49 | def __init__(self, fname, index_file=None, verbose=True, cleanup_callback=None): 50 | self.verbose = verbose 51 | self.cleanup_callback = cleanup_callback 52 | if isinstance(fname, str): 53 | self.stream = open(fname, "rb") 54 | self.fname = fname 55 | elif isinstance(fname, io.IOBase): 56 | self.stream = fname 57 | self.fname = None 58 | self.mmapped_file = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ) 59 | if cleanup_callback: 60 | cleanup_callback(fname, self.stream.fileno(), "start") 61 | self._build_index() 62 | 63 | def close(self, dispose=False): 64 | if self.cleanup_callback: 65 | self.cleanup_callback(self.fname, self.stream.fileno(), "end") 66 | self.mmapped_file.close() 67 | self.stream.close() 68 | 69 | def _build_index(self): 70 | self.by_name = {} 71 | self.by_index = [] 72 | offset = 0 73 | while offset >= 0 and offset < len(self.mmapped_file): 74 | header = parse_tar_header(self.mmapped_file[offset : offset + 500]) 75 | name = header.name.decode("utf-8").strip("\x00") 76 | typeflag = header.typeflag.decode("utf-8").strip("\x00") 77 | if name != "" and name != "././@PaxHeader" and typeflag in ["0", ""]: 78 | try: 79 | size = int(header.size.decode("utf-8")[:-1], 8) 80 | except ValueError as exn: 81 | print(header) 82 | raise exn 83 | self.by_name[name] = offset 84 | self.by_index.append((name, offset, size)) 85 | offset = next_header(offset, header) 86 | 87 | def names(self): 88 | return self.by_name.keys() 89 | 90 | def get_at_offset(self, offset): 91 | header = parse_tar_header(self.mmapped_file[offset : offset + 500]) 92 | name = 
header.name.decode("utf-8").strip("\x00") 93 | start = offset + 512 94 | end = start + int(header.size.decode("utf-8")[:-1], 8) 95 | return name, self.mmapped_file[start:end] 96 | 97 | def get_at_index(self, index): 98 | name, offset, size = self.by_index[index] 99 | return self.get_at_offset(offset) 100 | 101 | def get_by_name(self, name): 102 | offset = self.by_name[name] 103 | return self.get_at_offset(offset) 104 | 105 | def __iter__(self): 106 | for name, offset, size in self.by_index: 107 | yield name, self.mmapped_file[offset + 512 : offset + 512 + size] 108 | 109 | def __getitem__(self, key): 110 | if isinstance(key, int): 111 | return self.get_at_index(key) 112 | else: 113 | return self.get_by_name(key) 114 | 115 | def __len__(self): 116 | return len(self.by_index) 117 | 118 | def get_file(self, i): 119 | fname, data = self.get_at_index(i) 120 | return fname, io.BytesIO(data) 121 | 122 | 123 | def keep_while_reading(fname, fd, phase, delay=0.0): 124 | """This is a possible cleanup callback for cleanup_callback of MIndexedTar. 125 | 126 | It assumes that as long as there are some readers for a file, 127 | more readers may be trying to open it. 128 | 129 | Note that on Linux, unlinking the file doesn't matter after 130 | it has been mmapped. The contents will only be deleted when 131 | all readers close the file. The unlinking merely makes the file 132 | unavailable to new readers, since the downloader checks first 133 | whether the file exists. 134 | """ 135 | assert delay == 0.0, "delay not implemented" 136 | if fd < 0 or fname is None: 137 | return 138 | if phase == "start": 139 | fcntl.flock(fd, fcntl.LOCK_SH) 140 | elif phase == "end": 141 | try: 142 | fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) 143 | os.unlink(fname) 144 | except FileNotFoundError: 145 | # someone else deleted it already 146 | pass 147 | except BlockingIOError: 148 | # we couldn't get an exclusive lock, so someone else is still reading 149 | pass 150 | else: 151 | raise ValueError(f"Unknown phase {phase}") 152 | -------------------------------------------------------------------------------- /vila_u/wids/wids_specs.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | import os 4 | import tempfile 5 | from urllib.parse import urlparse, urlunparse 6 | 7 | from .wids_dl import download_and_open 8 | 9 | 10 | def urldir(url): 11 | """Return the directory part of a url.""" 12 | parsed_url = urlparse(url) 13 | path = parsed_url.path 14 | directory = os.path.dirname(path) 15 | return parsed_url._replace(path=directory).geturl() 16 | 17 | 18 | def urlmerge(base, url): 19 | """Merge a base URL and a relative URL. 20 | 21 | The function fills in any missing part of the url from the base, 22 | except for params, query, and fragment, which are taken only from the 'url'. 23 | For the pathname component, it merges the paths like os.path.join: 24 | an absolute path in 'url' overrides the base path, otherwise the paths are merged. 25 | 26 | Parameters: 27 | base (str): The base URL. 28 | url (str): The URL to merge with the base. 29 | 30 | Returns: 31 | str: The merged URL. 
32 | """ 33 | # Parse the base and the relative URL 34 | parsed_base = urlparse(base) 35 | parsed_url = urlparse(url) 36 | 37 | # Merge paths using os.path.join 38 | # If the url path is absolute, it overrides the base path 39 | if parsed_url.path.startswith("/"): 40 | merged_path = parsed_url.path 41 | else: 42 | merged_path = os.path.normpath(os.path.join(parsed_base.path, parsed_url.path)) 43 | 44 | # Construct the merged URL 45 | merged_url = urlunparse( 46 | ( 47 | parsed_url.scheme or parsed_base.scheme, 48 | parsed_url.netloc or parsed_base.netloc, 49 | merged_path, 50 | parsed_url.params, # Use params from the url only 51 | parsed_url.query, # Use query from the url only 52 | parsed_url.fragment, # Use fragment from the url only 53 | ) 54 | ) 55 | 56 | return merged_url 57 | 58 | 59 | def check_shards(l): 60 | """Check that a list of shards is well-formed. 61 | 62 | This checks that the list is a list of dictionaries, and that 63 | each dictionary has a "url" and a "nsamples" key. 64 | """ 65 | assert isinstance(l, list) 66 | for shard in l: 67 | assert isinstance(shard, dict) 68 | assert "url" in shard 69 | assert "nsamples" in shard 70 | return l 71 | 72 | 73 | def set_all(l, k, v): 74 | """Set a key to a value in a list of dictionaries.""" 75 | if v is None: 76 | return 77 | for x in l: 78 | if k not in x: 79 | x[k] = v 80 | 81 | 82 | def load_remote_dsdesc_raw(source): 83 | """Load a remote or local dataset description in JSON format.""" 84 | if isinstance(source, str): 85 | with tempfile.TemporaryDirectory() as tmpdir: 86 | dlname = os.path.join(tmpdir, "dataset.json") 87 | with download_and_open(source, dlname) as f: 88 | dsdesc = json.load(f) 89 | elif isinstance(source, io.IOBase): 90 | dsdesc = json.load(source) 91 | else: 92 | # FIXME: use gopen 93 | import requests 94 | 95 | jsondata = requests.get(source).text 96 | dsdesc = json.loads(jsondata) 97 | return dsdesc 98 | 99 | 100 | def rebase_shardlist(shardlist, base): 101 | """Rebase the URLs in a shardlist.""" 102 | if base is None: 103 | return shardlist 104 | for shard in shardlist: 105 | shard["url"] = urlmerge(base, shard["url"]) 106 | return shardlist 107 | 108 | 109 | def resolve_dsdesc(dsdesc, *, options=None, base=None): 110 | """Resolve a dataset description. 111 | 112 | This rebases the shards as necessary and loads any remote references. 113 | 114 | Dataset descriptions are JSON files. They must have the following format; 115 | 116 | { 117 | "wids_version": 1, 118 | # optional immediate shardlist 119 | "shardlist": [ 120 | {"url": "http://example.com/file.tar", "nsamples": 1000}, 121 | ... 122 | ], 123 | # sub-datasets 124 | "datasets": [ 125 | {"source_url": "http://example.com/dataset.json"}, 126 | {"shardlist": [ 127 | {"url": "http://example.com/file.tar", "nsamples": 1000}, 128 | ... 129 | ]} 130 | ... 
131 | ] 132 | } 133 | """ 134 | if options is None: 135 | options = {} 136 | assert isinstance(dsdesc, dict) 137 | dsdesc = dict(dsdesc, **options) 138 | shardlist = rebase_shardlist(dsdesc.get("shardlist", []), base) 139 | assert shardlist is not None 140 | set_all(shardlist, "weight", dsdesc.get("weight")) 141 | set_all(shardlist, "name", dsdesc.get("name")) 142 | check_shards(shardlist) 143 | assert "wids_version" in dsdesc, "No wids_version in dataset description" 144 | assert dsdesc["wids_version"] == 1, "Unknown wids_version" 145 | for component in dsdesc.get("datasets", []): 146 | # we use the weight from the reference to the dataset, 147 | # regardless of remote loading 148 | weight = component.get("weight") 149 | # follow any source_url dsdescs through remote loading 150 | source_url = None 151 | if "source_url" in component: 152 | source_url = component["source_url"] 153 | component = load_remote_dsdesc_raw(source_url) 154 | assert ( 155 | "source_url" not in component 156 | ), "double indirection in dataset description" 157 | assert "shardlist" in component, "no shardlist in dataset description" 158 | # if the component has a base, use it to rebase the shardlist 159 | # otherwise use the base from the source_url, if any 160 | subbase = component.get("base", urldir(source_url) if source_url else None) 161 | if subbase is not None: 162 | rebase_shardlist(component["shardlist"], subbase) 163 | l = check_shards(component["shardlist"]) 164 | set_all(l, "weight", weight) 165 | set_all(l, "source_url", source_url) 166 | set_all(l, "dataset", component.get("name")) 167 | shardlist.extend(l) 168 | assert len(shardlist) > 0, "No shards found" 169 | dsdesc["shardlist"] = shardlist 170 | return dsdesc 171 | 172 | 173 | def load_dsdesc_and_resolve(source, *, options=None, base=None): 174 | if options is None: 175 | options = {} 176 | dsdesc = load_remote_dsdesc_raw(source) 177 | return resolve_dsdesc(dsdesc, base=base, options=options) 178 | -------------------------------------------------------------------------------- /vila_u/wids/wids_tar.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import os.path 4 | import pickle 5 | import re 6 | import tarfile 7 | 8 | import numpy as np 9 | 10 | 11 | def find_index_file(file): 12 | prefix, last_ext = os.path.splitext(file) 13 | if re.match("._[0-9]+_$", last_ext): 14 | return prefix + ".index" 15 | else: 16 | return file + ".index" 17 | 18 | 19 | class TarFileReader: 20 | def __init__(self, file, index_file=find_index_file, verbose=True): 21 | self.verbose = verbose 22 | if callable(index_file): 23 | index_file = index_file(file) 24 | self.index_file = index_file 25 | 26 | # Open the tar file and keep it open 27 | if isinstance(file, str): 28 | self.tar_file = tarfile.open(file, "r") 29 | else: 30 | self.tar_file = tarfile.open(fileobj=file, mode="r") 31 | 32 | # Create the index 33 | self._create_tar_index() 34 | 35 | def _create_tar_index(self): 36 | if self.index_file is not None and os.path.exists(self.index_file): 37 | if self.verbose: 38 | print("Loading tar index from", self.index_file) 39 | with open(self.index_file, "rb") as stream: 40 | self.fnames, self.index = pickle.load(stream) 41 | return 42 | # Create an empty list for the index 43 | self.fnames = [] 44 | self.index = [] 45 | 46 | if self.verbose: 47 | print("Creating tar index for", self.tar_file.name, "at", self.index_file) 48 | # Iterate over the members of the tar file 49 | for member in self.tar_file: 50 | # 
If the member is a file, add it to the index 51 | if member.isfile(): 52 | # Get the file's offset 53 | offset = self.tar_file.fileobj.tell() 54 | self.fnames.append(member.name) 55 | self.index.append([offset, member.size]) 56 | if self.verbose: 57 | print( 58 | "Done creating tar index for", self.tar_file.name, "at", self.index_file 59 | ) 60 | self.index = np.array(self.index) 61 | if self.index_file is not None: 62 | if os.path.exists(self.index_file + ".temp"): 63 | os.unlink(self.index_file + ".temp") 64 | with open(self.index_file + ".temp", "wb") as stream: 65 | pickle.dump((self.fnames, self.index), stream) 66 | os.rename(self.index_file + ".temp", self.index_file) 67 | 68 | def names(self): 69 | return self.fnames 70 | 71 | def __len__(self): 72 | return len(self.index) 73 | 74 | def get_file(self, i): 75 | name = self.fnames[i] 76 | offset, size = self.index[i] 77 | self.tar_file.fileobj.seek(offset) 78 | file_bytes = self.tar_file.fileobj.read(size) 79 | return name, io.BytesIO(file_bytes) 80 | 81 | def close(self): 82 | # Close the tar file 83 | self.tar_file.close() 84 | --------------------------------------------------------------------------------
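# Usage sketch for the tar readers above (the shard path is hypothetical): a
# TarFileReader builds a pickle index next to the shard ("<shard>.tar.index")
# on first use, so later opens skip the full tar scan and can seek straight to
# each member.
#
#     reader = TarFileReader("/data/shards/shard-000.tar")
#     print(len(reader), reader.names()[:3])
#     name, stream = reader.get_file(0)
#     payload = stream.read()
#     reader.close()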