├── .dockerignore ├── .editorconfig ├── .gitattributes ├── .github └── workflows │ └── release-pypi.yaml ├── .gitignore ├── LICENSE.md ├── Makefile ├── README.md ├── example └── llavaction_video_demo.ipynb ├── llavaction ├── __init__.py ├── action │ ├── benchmark.py │ ├── chatgpt_utils.py │ ├── dataset.py │ ├── ek_eval.py │ ├── generate_description.py │ ├── generate_interval_pred.py │ ├── generate_temporal_detection_data.py │ ├── llava_inference.py │ ├── make_visualizations.py │ ├── prediction_analysis.py │ ├── render_utils.py │ ├── selective_inference.py │ ├── utils.py │ └── vis_utils.py ├── constants.py ├── conversation.py ├── eval │ ├── evaluate_interleave.py │ └── model_vqa.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_gemma.py │ │ ├── llava_llama.py │ │ ├── llava_mistral.py │ │ ├── llava_mixtral.py │ │ ├── llava_mpt.py │ │ ├── llava_qwen.py │ │ ├── llava_qwen_moe.py │ │ └── modeling_llama.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ ├── dev_eva_clip │ │ │ ├── eva_clip │ │ │ │ ├── __init__.py │ │ │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ │ │ ├── constants.py │ │ │ │ ├── eva_vit_model.py │ │ │ │ ├── factory.py │ │ │ │ ├── hf_configs.py │ │ │ │ ├── hf_model.py │ │ │ │ ├── loss.py │ │ │ │ ├── model.py │ │ │ │ ├── model_configs │ │ │ │ │ ├── EVA-CLIP-18B.json │ │ │ │ │ ├── EVA-CLIP-8B-plus.json │ │ │ │ │ ├── EVA-CLIP-8B.json │ │ │ │ │ ├── EVA01-CLIP-B-16.json │ │ │ │ │ ├── EVA01-CLIP-g-14-plus.json │ │ │ │ │ ├── EVA01-CLIP-g-14.json │ │ │ │ │ ├── EVA02-CLIP-B-16.json │ │ │ │ │ ├── EVA02-CLIP-L-14-336.json │ │ │ │ │ ├── EVA02-CLIP-L-14.json │ │ │ │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ │ │ │ ├── EVA02-CLIP-bigE-14.json │ │ │ │ │ ├── Internal-EVA02-CLIP-10B-14-448.json │ │ │ │ │ └── Internal-EVA02-CLIP-10B-14.json │ │ │ │ ├── modified_resnet.py │ │ │ │ ├── openai.py │ │ │ │ ├── pretrained.py │ │ │ │ ├── rope.py │ │ │ │ ├── timm_model.py │ │ │ │ ├── tokenizer.py │ │ │ │ ├── transform.py │ │ │ │ ├── transformer.py │ │ │ │ └── utils.py │ │ │ └── eva_vit.py │ │ ├── eva_clip │ │ │ ├── eva_clip_encoder.py │ │ │ ├── eva_clip_processors.py │ │ │ ├── eva_vit.py │ │ │ ├── factory.py │ │ │ └── model_configs │ │ │ │ ├── EVA-CLIP-18B.json │ │ │ │ ├── EVA-CLIP-8B-plus.json │ │ │ │ ├── EVA-CLIP-8B.json │ │ │ │ ├── EVA01-CLIP-B-16.json │ │ │ │ ├── EVA01-CLIP-g-14-plus.json │ │ │ │ ├── EVA01-CLIP-g-14.json │ │ │ │ ├── EVA02-CLIP-B-16.json │ │ │ │ ├── EVA02-CLIP-L-14-336.json │ │ │ │ ├── EVA02-CLIP-L-14.json │ │ │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ │ │ ├── EVA02-CLIP-bigE-14.json │ │ │ │ ├── Internal-EVA02-CLIP-10B-14-448.json │ │ │ │ └── Internal-EVA02-CLIP-10B-14.json │ │ ├── hf_vision.py │ │ ├── imagebind.py │ │ ├── open_clip_encoder.py │ │ └── siglip_encoder.py │ ├── multimodal_projector │ │ ├── builder.py │ │ └── pooler_projector.py │ ├── multimodal_resampler │ │ ├── builder.py │ │ ├── masked_drop.py │ │ ├── perceiver.py │ │ ├── qformer.py │ │ └── spatial_pool.py │ └── utils.py ├── serve │ ├── __init__.py │ ├── cli.py │ ├── controller.py │ ├── examples │ │ ├── extreme_ironing.jpg │ │ └── waterview.jpg │ ├── gradio_multi_image.py │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ ├── sglang_worker.py │ └── test_message.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── llava_trainer_eval.py │ ├── train.py │ └── train_mem.py └── utils.py ├── llavaction_figs └── Fig1.ipynb ├── 
pyproject.toml ├── resinstall.sh └── scripts ├── qwen.py ├── train ├── avion_tim_top5_gpt4o_detection_direct.yaml ├── avion_tim_top5_gpt4o_detection_direct_178K_100percent.yaml └── tim_top20_official_key_gpt4o_direct_detection.yaml ├── zero2.json ├── zero2_fused_adamw.json ├── zero2_offload.json ├── zero3.json ├── zero3_offload.json └── zero3pp.json /.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | 19 | # Exclude some weights 20 | /openai 21 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | # Unix-style newlines with a newline ending every file 4 | [*] 5 | end_of_line = lf 6 | insert_final_newline = true 7 | trim_trailing_whitespace = true 8 | charset = utf-8 9 | 10 | # 4 space indentation 11 | [*.{py,json}] 12 | indent_style = space 13 | indent_size = 4 14 | 15 | # 2 space indentation 16 | [*.{md,sh,yaml,yml}] 17 | indent_style = space 18 | indent_size = 2 -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # https://git-scm.com/docs/gitattributes 2 | 3 | # Set the default behavior, in case people don't have core.autocrlf set. 4 | # https://git-scm.com/docs/gitattributes#_end_of_line_conversion 5 | * text=auto 6 | 7 | # common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes 8 | # Source files 9 | # ============ 10 | *.pxd text diff=python 11 | *.py text diff=python 12 | *.py3 text diff=python 13 | *.pyw text diff=python 14 | *.pyx text diff=python 15 | *.pyz text diff=python 16 | *.pyi text diff=python 17 | 18 | # Binary files 19 | # ============ 20 | *.db binary 21 | *.p binary 22 | *.pkl binary 23 | *.pickle binary 24 | *.pyc binary export-ignore 25 | *.pyo binary export-ignore 26 | *.pyd binary 27 | 28 | # Jupyter notebook 29 | *.ipynb text eol=lf 30 | -------------------------------------------------------------------------------- /.github/workflows/release-pypi.yaml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*.*.*' 7 | pull_request: 8 | branches: 9 | - main 10 | types: 11 | - labeled 12 | - opened 13 | - edited 14 | - synchronize 15 | - reopened 16 | 17 | jobs: 18 | release: 19 | runs-on: ubuntu-latest 20 | 21 | steps: 22 | - name: Cache dependencies 23 | id: pip-cache 24 | uses: actions/cache@v3 25 | with: 26 | path: ~/.cache/pip 27 | key: ${{ runner.os }}-pip 28 | restore-keys: | 29 | ${{ runner.os }}-pip 30 | 31 | - name: Install dependencies 32 | run: | 33 | pip install --upgrade pip 34 | pip install wheel 35 | # see https://github.com/pypa/twine/issues/1216#issuecomment-2629069669 36 | pip install "packaging>=24.2" 37 | 38 | - name: Checkout code 39 | uses: actions/checkout@v3 40 | 41 | 42 | - name: Build and publish to PyPI 43 | if: ${{ github.event_name == 'push' }} 44 | 
env: 45 | TWINE_USERNAME: __token__ 46 | TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} 47 | run: | 48 | make dist 49 | ls dist/ 50 | tar tvf dist/llavaction-*.tar.gz 51 | python3 -m twine upload dist/* 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Distribution / packaging 2 | .Python 3 | build/ 4 | develop-eggs/ 5 | dist/ 6 | downloads/ 7 | eggs/ 8 | .eggs/ 9 | lib/ 10 | lib64/ 11 | parts/ 12 | sdist/ 13 | var/ 14 | wheels/ 15 | share/python-wheels/ 16 | *.egg-info/ 17 | .installed.cfg 18 | *.egg 19 | MANIFEST 20 | 21 | # Log 22 | *.log 23 | *.log.* 24 | # *.json 25 | # *.jsonl 26 | 27 | # Data 28 | !**/alpaca-data-conversation.json 29 | # Editor 30 | .idea 31 | *.swp 32 | 33 | # Other 34 | .DS_Store 35 | wandb 36 | output 37 | 38 | checkpoints 39 | project_checkpoints 40 | debug_checkpoints 41 | ckpts* 42 | 43 | # DevContainer 44 | !.devcontainer/* 45 | 46 | # demo/ 47 | 48 | 49 | experiments/ 50 | *.out 51 | pretrained_models/ 52 | 53 | huggingface/ 54 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | LLAVACTION_VERSION := 0.0.1 2 | 3 | dist: 4 | python3 -m pip install virtualenv 5 | python3 -m pip install --upgrade build twine 6 | python3 -m build --wheel --sdist 7 | 8 | build: dist 9 | 10 | archlinux: 11 | mkdir -p dist/arch 12 | cp PKGBUILD dist/arch 13 | cp dist/llavaction-${LLAVACTION_VERSION}.tar.gz dist/arch 14 | (cd dist/arch; makepkg --skipchecksums -f) 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLaVAction: Evaluating and Training Multi-Modal Large Language Models for Action Recognition 2 | 3 | [![Static Badge](https://img.shields.io/badge/LLaVAction-paper-green)](https://arxiv.org/abs/2503.18712) 4 | [![Demo Website](https://img.shields.io/badge/LLaVAction-website-red)](https://mmathislab.github.io/llavaction/) 5 | [![llavaction-checkpoints](https://img.shields.io/badge/LLaVAction-checkpoints_🤗-blue)](https://huggingface.co/MLAdaptiveIntelligence) 6 | 7 | [![Downloads](https://static.pepy.tech/badge/llavaction)](https://pepy.tech/project/llavaction) 8 | [![Downloads](https://static.pepy.tech/badge/llavaction/month)](https://pepy.tech/project/llavaction) 9 | [![PyPI version](https://badge.fury.io/py/llavaction.svg)](https://badge.fury.io/py/llavaction) 10 | ![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-red) 11 | 12 | ## Abstract 13 | 14 | Understanding human behavior requires measuring behavioral actions. Due to its complexity, behavior is best mapped onto a rich, semantic structure such as language. The recent development of multi-modal large language models (MLLMs) is a promising candidate for a wide range of action understanding tasks. In this work, we focus on evaluating and then improving MLLMs to perform action recognition. We reformulate EPIC-KITCHENS-100, one of the largest and most challenging egocentric action datasets, to the form of video multiple question answering (EPIC-KITCHENS-100-MQA). We show that when we sample difficult incorrect answers as distractors, leading MLLMs struggle to recognize the correct actions. 
We propose a series of methods that greatly improve the MLLMs' ability to perform action recognition, achieving state-of-the-art performance on the EPIC-KITCHENS-100 Challenge and outperforming GPT-4o by 21 points in accuracy on EPIC-KITCHENS-100-MQA. Lastly, we show improvements on other action-related video benchmarks such as VideoMME, PerceptionTest, and MVBench. 15 | 16 | ## Code 17 | 18 | - This repository contains the implementation for our preprint on evaluating and training multi-modal large language models for action recognition. 19 | - Our code is built on [LLaVA-NeXT](https://github.com/LLaVA-VL/LLaVA-NeXT), and the files in the `llavaction/action` directory are related to our work. We thank the authors of LLaVA-NeXT for making their code publicly available. 20 | - The files in the `/eval`, `/model`, `/serve`, and `/train` directories are taken directly from [LLaVA-NeXT](https://github.com/LLaVA-VL/LLaVA-NeXT), except for the following modified files noted below: 21 | - `/model/llava_arch.py` 22 | - `/model/language_model/llava_qwen.py` 23 | - `/train/train.py` 24 | - `/train/llava_trainer.py` 25 | - `/utils.py` 26 | 27 | ## Demo 28 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AdaptiveMotorControlLab/LLaVAction/blob/main/example/llavaction_video_demo.ipynb) 29 | We provide code to run video inference in a Jupyter Notebook (which can be run on Google Colaboratory). 30 | 31 | 32 | ### Installation guide for video inference: 33 | ```bash 34 | conda create -n llavaction python=3.10 -y 35 | conda activate llavaction 36 | pip install --upgrade pip # Enable PEP 660 support. 37 | pip install --pre llavaction 38 | ``` 39 | 40 | - Please see the `/example` directory for a demo notebook. 41 | 42 | ## EPIC-KITCHENS-100-MQA 43 | 44 | In our work, we introduce a new way to evaluate MLLMs for action recognition by casting EPIC-KITCHENS-100 into a multiple-question-answering (MQA) benchmark. This has not yet been released [as of 3/2025], but please check the issues or open an issue if you are interested in accessing this resource before the paper is published. We also plan to integrate this into the package [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval). 45 | 46 | # Acknowledgments 47 | We thank the Swiss AI Initiative (Project ID a03) from the Swiss National Supercomputing Centre (CSCS); the Boehringer Ingelheim Fonds for a PhD stipend (H.Q.); M.W.M. thanks the Vallee Foundation; and M.W.M. and A.M. thank the SNSF for grant No. 320030-227871. 48 | 49 | ![group-logo](https://github.com/user-attachments/assets/ad034dc3-5e92-4e8b-915b-85e443b3bdb2) 50 | 51 | -------------------------------------------------------------------------------- /llavaction/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llavaction/action/benchmark.py: -------------------------------------------------------------------------------- 1 | # benchmark gpt-4o on avion_mcq_top5_500 2 | # benchmark gpt-4o on tim_mcq_top5_500 3 | # benchmark gpt-4o on random_mcq_top5_500 4 | from llavaction.action.chatgpt_utils import GPTInferenceAnnotator 5 | import glob 6 | import json 7 | import os 8 | import re 9 | 10 | def process_raw_pred(raw_pred): 11 | matches = re.findall(r"[A-Z]\.\s(.+)", raw_pred) 12 | 13 | if 'None' in raw_pred: 14 | return raw_pred.replace('None.
', '') 15 | 16 | if matches: 17 | # Get the last match 18 | last_match = matches[-1] 19 | # Remove a trailing period and anything after it 20 | last_match = re.sub(r"\.\s*.*$", "", last_match) 21 | return last_match 22 | else: 23 | return raw_pred 24 | 25 | # root = '/data/EK100/EK100_320p_15sec_30fps_libx264' 26 | # annotation_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv' 27 | # avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json' 28 | # tim_prediction_file = '/data/epic_kitchen/TIM_PREDS/tim_pred_ids_val.json' 29 | 30 | root = '/data/anonymous/EK100/' 31 | annotation_file = '/data/anonymous/epic-kitchens-100-annotations/EPIC_100_validation.csv' 32 | avion_prediction_file = '/data/anonymous/AVION_PREDS/avion_pred_ids_val.json' 33 | tim_prediction_file = '/data/anonymous/TIM_PREDS/tim_pred_ids_val.json' 34 | 35 | 36 | n_frames = 8 37 | topk = 5 38 | action_representation = 'GT_random_narration' 39 | perspective = 'first_person' 40 | benchmark_testing = True 41 | 42 | 43 | def benchmark_avion_mcq(n_samples, gpt_model, action_representation, benchmark_testing = True, n_frames = 8): 44 | 45 | inferencer = GPTInferenceAnnotator(gpt_model, 46 | root, 47 | annotation_file, 48 | gen_type = 'avion', 49 | prediction_file = avion_prediction_file, 50 | clip_length = n_frames, 51 | question_type = 'mc_', 52 | action_representation=action_representation, 53 | perspective = perspective, 54 | benchmark_testing = benchmark_testing, 55 | topk = topk) 56 | inferencer.multi_process_run(n_samples = n_samples, 57 | offset = 0) 58 | 59 | def benchmark_tim_mcq(n_samples, gpt_model, action_representation, benchmark_testing = True, n_frames = 8): 60 | 61 | inferencer = GPTInferenceAnnotator(gpt_model, 62 | root, 63 | annotation_file, 64 | gen_type = 'tim', 65 | prediction_file = tim_prediction_file, 66 | clip_length = n_frames, 67 | question_type = 'mc_', 68 | action_representation=action_representation, 69 | perspective = perspective, 70 | benchmark_testing = benchmark_testing, 71 | topk = topk) 72 | inferencer.multi_process_run(n_samples = n_samples, offset = 0) 73 | 74 | def benchmark_random_mcq(n_samples, gpt_model, action_representation, benchmark_testing = True, n_frames = 8): 75 | inferencer = GPTInferenceAnnotator(gpt_model, 76 | root, 77 | annotation_file, 78 | gen_type = 'random', 79 | prediction_file = avion_prediction_file, 80 | clip_length = n_frames, 81 | question_type = 'mc_', 82 | action_representation=action_representation, 83 | perspective = perspective, 84 | benchmark_testing = benchmark_testing, 85 | topk = topk) 86 | 87 | inferencer.multi_process_run(n_samples = n_samples, offset = 0) 88 | 89 | def calcuate_acc_from_jsons(json_folder): 90 | files = glob.glob(os.path.join(json_folder, '*.json')) 91 | for file in files: 92 | print (file) 93 | preds = json.load(open(file)) 94 | correct = 0 95 | something = 0 96 | for k,v in preds.items(): 97 | options = v['options'] 98 | options = [process_raw_pred(e) for e in options] 99 | 100 | #assert v['gt_name'] in options, f"{v['gt_name']} not in {options}" 101 | if v['gt_name'] not in options: 102 | print ('what?', options) 103 | print ('what?', v) 104 | break 105 | 106 | if v['gt_name'] == v['chatgpt_answer']: 107 | correct+=1 108 | else: 109 | pass 110 | #print ('wrong prediction! 
pred: gt', v['chatgpt_answer'] + "," + v['gt_name']) 111 | print ('acc ', correct/len(preds)) 112 | print ('gt not in options', something) 113 | 114 | 115 | 116 | if __name__ == '__main__': 117 | # benchmark_avion_mcq(-1, 'gpt-4o-mini-2024-07-18', 'GT_random_narration', benchmark_testing = True, n_frames = 8) 118 | # benchmark_tim_mcq(-1, 'gpt-4o-mini-2024-07-18', 'GT_random_narration', benchmark_testing = True, n_frames = 8) 119 | # benchmark_random_mcq(-1, 'gpt-4o-mini-2024-07-18', 'GT_random_narration', benchmark_testing = True, n_frames = 8) 120 | # benchmark_avion_mcq(-1, 'gpt-4o-2024-08-06', 'GT_random_narration', benchmark_testing = True, n_frames = 8) 121 | # benchmark_tim_mcq(-1, 'gpt-4o-2024-08-06', 'GT_random_narration', benchmark_testing = True, n_frames = 8) 122 | # benchmark_random_mcq(-1, 'gpt-4o-2024-08-06', 'GT_random_narration', benchmark_testing = True, n_frames = 8) 123 | benchmark_tim_mcq(1, 'gpt-4o-mini-2024-07-18', 'official_key', benchmark_testing = False, n_frames = 16) 124 | #benchmark_tim_mcq(-1, 'gpt-4o-mini-2024-07-18', 'GT_random_narration', benchmark_testing = False, n_frames = 16) 125 | #calcuate_acc_from_jsons('gpt_EK100_results') -------------------------------------------------------------------------------- /llavaction/action/generate_temporal_detection_data.py: -------------------------------------------------------------------------------- 1 | 2 | import csv 3 | from llavaction.action.dataset import datetime2sec 4 | import random 5 | from llavaction.action.utils import generate_label_map 6 | from pathlib import Path 7 | import json 8 | 9 | def get_temporal_detection(train_ann, delta = 5): 10 | 11 | labels, mapping_vn2narration, mapping_vn2act, verb_maps, noun_maps = generate_label_map(Path(train_ann).parent, 'GT_random_narration') 12 | 13 | csv_reader = csv.reader(open(train_ann)) 14 | 15 | _ = next(csv_reader) 16 | 17 | ret = [] 18 | 19 | for idx, row in enumerate(csv_reader): 20 | 21 | start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5]) 22 | pid, vid = row[1:3] 23 | vn_str = f'{row[10]}:{row[12]}' 24 | vid_path = '{}-{}'.format(pid, vid) 25 | process = lambda x: str(round(x, 2)) 26 | 27 | action_duration = process(end_timestamp - start_timestamp) 28 | 29 | action_start_timestamp = process(start_timestamp) 30 | action_end_timestamp = process(end_timestamp) 31 | action_gt_narration = row[8] 32 | start_padding = random.uniform(0, delta) 33 | end_padding = delta - start_padding 34 | start_timestamp = process(max(0, start_timestamp - start_padding)) 35 | end_timestamp = process(end_timestamp + end_padding) 36 | 37 | relative_start_time = process(float(action_start_timestamp) - float(start_timestamp)) 38 | relative_end_time = process(float(action_end_timestamp) - float(start_timestamp)) 39 | # print ('action_star_timestamp', action_start_timestamp) 40 | # print ('video start_timestamp', start_timestamp) 41 | # print ('relative_start_time', relative_start_time) 42 | 43 | # print ('action_end_timestamp', action_end_timestamp) 44 | # print ('video end_timestamp', end_timestamp) 45 | # print ('relative_end_time', relative_end_time) 46 | 47 | 48 | 49 | conversation = [ 50 | {"from": "human", "value": f"The provided video contains an action '{action_gt_narration}' that lasts {action_duration} seconds. What is the relative start and end time of the action in seconds? 
Format it as 'start_timestamp: end_timestamp' and round to 2 decimal places."}, 51 | {"from": "gpt", "value": f"{relative_start_time}, {relative_end_time}"} 52 | ] 53 | 54 | 55 | data = {'video': vid_path, 56 | 'conversations': conversation, 57 | 'id': vid_path, 58 | 'split': 'train', 59 | 'task_instruction': '', 60 | 'num_samples': 1, 61 | 'question_type': f'temporal_detection', 62 | 'dataset_name': 'EK100', 63 | 'start_timestamp': start_timestamp, 64 | 'end_timestamp': end_timestamp, 65 | 'verb_id': int(row[10]), 66 | 'noun_id': int(row[12]), 67 | 'action_id': mapping_vn2act[vn_str]} 68 | 69 | ret.append(data) 70 | 71 | return ret 72 | 73 | 74 | res = get_temporal_detection('/data/anonymous/epic-kitchens-100-annotations/EPIC_100_train.csv') 75 | 76 | # write to jsonl 77 | 78 | with open('/data/anonymous/EK100_inst_train/temporal_detection.jsonl', 'w') as f: 79 | for item in res: 80 | f.write(json.dumps(item) + '\n') 81 | 82 | -------------------------------------------------------------------------------- /llavaction/action/llava_inference.py: -------------------------------------------------------------------------------- 1 | from llavaction.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token 2 | from llavaction.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX 3 | from llavaction.conversation import conv_templates, SeparatorStyle 4 | 5 | import torch 6 | import numpy as np 7 | import copy 8 | from llavaction.action.utils import format_llava_prompt 9 | from llavaction.utils import rank0_print 10 | 11 | 12 | def llava_inference( 13 | video_frames, 14 | tokenizer, 15 | model, 16 | image_processor, 17 | input, 18 | clip_length = 16, 19 | num_frames = 16, 20 | temperature = 0, 21 | test_type = 'base', 22 | time_meta = None, 23 | learn_neighbor_actions = "", 24 | meta_data = None, 25 | perspective = "first_person", 26 | include_time_instruction = False 27 | ): 28 | 29 | model.eval() 30 | device = "cuda" 31 | # this [0] is only for batch size 1. 32 | video_frames = video_frames[0] 33 | 34 | temporal_stride = clip_length // num_frames 35 | 36 | video_frames = video_frames[::temporal_stride] 37 | 38 | image_tensors = [] 39 | 40 | video_duration = time_meta['duration'] 41 | n_frames = time_meta['n_frames'] 42 | 43 | frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].cuda().to(torch.bfloat16) 44 | image_tensors.append(frames) 45 | 46 | conv_template = "qwen_1_5" 47 | original_input = input 48 | if isinstance(input, dict): 49 | input = input['options'][0] if input else None 50 | 51 | if test_type == 'base': 52 | question_type = "mc_top5_official_key" 53 | else: 54 | question_type = test_type 55 | 56 | if test_type == 'caption_then_answer': 57 | caption_answer = llava_inference([video_frames], 58 | tokenizer, 59 | model, 60 | image_processor, 61 | original_input, 62 | test_type = 'caption', 63 | clip_length = clip_length, 64 | num_frames = num_frames, 65 | temperature = 0, 66 | time_meta = time_meta) 67 | 68 | question = format_llava_prompt(DEFAULT_IMAGE_TOKEN, 69 | input, 70 | video_duration, 71 | n_frames, 72 | "mc_top5_official_key", 73 | include_frame_time = False, 74 | learn_neighbor_actions = learn_neighbor_actions, 75 | perspective = perspective, 76 | include_time_instruction= include_time_instruction) 77 | 78 | question = f"You observed the video before and wrote down the notes: {caption_answer}. Now you watch the same video again and you can do better. 
" + question 79 | 80 | else: 81 | question = format_llava_prompt(DEFAULT_IMAGE_TOKEN, 82 | input, 83 | video_duration, 84 | n_frames, 85 | question_type, 86 | include_frame_time = False, 87 | learn_neighbor_actions = learn_neighbor_actions, 88 | include_time_instruction= include_time_instruction, 89 | perspective = perspective, 90 | meta_data=meta_data) 91 | 92 | 93 | #rank0_print ("debugging", question) 94 | 95 | conv = copy.deepcopy(conv_templates[conv_template]) 96 | conv.append_message(conv.roles[0], question) 97 | conv.append_message(conv.roles[1], None) 98 | prompt_question = conv.get_prompt() 99 | 100 | 101 | input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device) 102 | image_sizes = [frame.size for frame in video_frames] 103 | 104 | # Generate response 105 | cont = model.generate( 106 | input_ids, 107 | images=image_tensors, 108 | image_sizes=image_sizes, 109 | do_sample=False, 110 | temperature=temperature, 111 | max_new_tokens=4096, 112 | modalities=["video"], 113 | ) 114 | 115 | text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True) 116 | 117 | return text_outputs[0] -------------------------------------------------------------------------------- /llavaction/action/render_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import ast 4 | from PIL import Image, ImageDraw, ImageFont 5 | 6 | color_rgb = [(255,255,0), (255, 128,0), (128,255,0), (0,128,255), (0,0,255), (127,0,255), (255,0,255), (255,0,127), (255,0,0), (255,204,153), (255,102,102), (153,255,153), (153,153,255), (0,0,153)] 7 | color_rgba = [(255,255,0,70), (255, 128,0,70), (128,255,0,70), (0,128,255,70), (0,0,255,70), (127,0,255,70), (255,0,255,70), (255,0,127,70), (255,0,0,70), (255,204,153,70), (255,102,102,70), (153,255,153,70), (153,153,255,70), (0,0,153,70)] 8 | 9 | 10 | hand_rgb = [(0, 90, 181), (220, 50, 32)] 11 | hand_rgba = [(0, 90, 181, 70), (220, 50, 32, 70)] 12 | 13 | obj_rgb = (255, 194, 10) 14 | obj_rgba = (255, 194, 10, 70) 15 | 16 | side_map = {'l':'Left', 'r':'Right'} 17 | side_map2 = {0:'Left', 1:'Right'} 18 | side_map3 = {0:'L', 1:'R'} 19 | state_map = {0:'No Contact', 1:'Self Contact', 2:'Another Person', 3:'Portable Object', 4:'Stationary Object'} 20 | state_map2 = {0:'N', 1:'S', 2:'O', 3:'P', 4:'F'} 21 | 22 | vis_settings = {'font_size':20, 'line_width':2, 'point_radius':4, 'hand_color':hand_rgb, 'hand_alpha':[None, None], 'obj_color':obj_rgb, 'obj_alpha':None, 'text_alpha':(255, 255, 255, 255)} 23 | 24 | def calculate_center(bb): 25 | return [(bb[0] + bb[2])/2, (bb[1] + bb[3])/2] 26 | 27 | def filter_object(obj_dets, hand_dets): 28 | filtered_object = [] 29 | object_cc_list = [] 30 | for j in range(obj_dets.shape[0]): 31 | object_cc_list.append(calculate_center(obj_dets[j,:4])) 32 | object_cc_list = np.array(object_cc_list) 33 | img_obj_id = [] 34 | for i in range(hand_dets.shape[0]): 35 | if hand_dets[i, 5] <= 0: 36 | img_obj_id.append(-1) 37 | continue 38 | hand_cc = np.array(calculate_center(hand_dets[i,:4])) 39 | point_cc = np.array([(hand_cc[0]+hand_dets[i,6]*10000*hand_dets[i,7]), (hand_cc[1]+hand_dets[i,6]*10000*hand_dets[i,8])]) 40 | dist = np.sum((object_cc_list - point_cc)**2,axis=1) 41 | dist_min = np.argmin(dist) 42 | img_obj_id.append(dist_min) 43 | return img_obj_id 44 | 45 | def draw_obj_mask(image, draw, obj_idx, obj_bbox, obj_score, width, height): 46 | font = ImageFont.truetype('llavaction/action/times_b.ttf', 
size=vis_settings['font_size']) 47 | mask = Image.new('RGBA', (width, height)) 48 | pmask = ImageDraw.Draw(mask) 49 | pmask.rectangle(obj_bbox, outline=vis_settings['obj_color'], width=vis_settings['line_width'], fill=vis_settings['obj_alpha']) 50 | image.paste(mask, (0,0), mask) 51 | 52 | draw.rectangle([obj_bbox[0], max(0, obj_bbox[1]-vis_settings['font_size']), obj_bbox[0]+vis_settings['font_size']+2, 53 | max(0, obj_bbox[1]-vis_settings['font_size'])+vis_settings['font_size']], 54 | fill=vis_settings['text_alpha'], outline=vis_settings['obj_color'], width=vis_settings['line_width']) 55 | draw.text((obj_bbox[0]+5, max(0, obj_bbox[1]-vis_settings['font_size'])-2), f'O', font=font, fill=(0,0,0)) # 56 | 57 | return image 58 | 59 | def draw_hand_mask(image, draw, hand_idx, hand_bbox, hand_score, side, state, width, height): 60 | font = ImageFont.truetype('llavaction/action/times_b.ttf', size=vis_settings['font_size']) 61 | if side == 0: 62 | side_idx = 0 63 | elif side == 1: 64 | side_idx = 1 65 | mask = Image.new('RGBA', (width, height)) 66 | pmask = ImageDraw.Draw(mask) 67 | pmask.rectangle(hand_bbox, outline=vis_settings['hand_color'][side_idx], width=vis_settings['line_width'], fill=vis_settings['hand_alpha'][side_idx]) 68 | image.paste(mask, (0,0), mask) 69 | # text 70 | 71 | draw = ImageDraw.Draw(image) 72 | draw.rectangle([hand_bbox[0], max(0, hand_bbox[1]-vis_settings['font_size']), hand_bbox[0]+vis_settings['font_size']*2+2, 73 | max(0, hand_bbox[1]-vis_settings['font_size'])+vis_settings['font_size']], 74 | fill=vis_settings['text_alpha'], outline=vis_settings['hand_color'][side_idx], width=vis_settings['line_width']) 75 | draw.text((hand_bbox[0]+6, max(0, hand_bbox[1]-vis_settings['font_size'])-2), f'{side_map3[int(float(side))]}-{state_map2[int(float(state))]}', font=font, fill=(0,0,0)) # 76 | 77 | return image 78 | 79 | def draw_line_point(draw, side_idx, hand_center, object_center): 80 | 81 | draw.line([hand_center, object_center], fill=vis_settings['hand_color'][side_idx], width=vis_settings['line_width']) 82 | x, y = hand_center[0], hand_center[1] 83 | r=vis_settings['point_radius'] 84 | draw.ellipse((x-r, y-r, x+r, y+r), fill=vis_settings['hand_color'][side_idx]) 85 | x, y = object_center[0], object_center[1] 86 | draw.ellipse((x-r, y-r, x+r, y+r), fill=vis_settings['obj_color']) 87 | 88 | def vis_detections_PIL(im, class_name, dets, thresh=0.8): 89 | """Visual debugging of detections.""" 90 | 91 | image = Image.fromarray(im).convert("RGBA") 92 | draw = ImageDraw.Draw(image) 93 | width, height = image.size 94 | 95 | for hand_idx, i in enumerate(range(np.minimum(10, dets.shape[0]))): 96 | bbox = list(int(np.round(x)) for x in dets[i, :4]) 97 | score = dets[i, 4] 98 | lr = dets[i, -1] 99 | state = dets[i, 5] 100 | if score > thresh: 101 | image = draw_hand_mask(image, draw, hand_idx, bbox, score, lr, state, width, height) 102 | 103 | return image 104 | 105 | def vis_detections_filtered_objects_PIL(im, obj_dets, hand_dets, thresh_hand=0.8, thresh_obj=0.01): 106 | 107 | # convert to PIL 108 | im = im[:,:,::-1] 109 | image = Image.fromarray(im).convert("RGBA") 110 | draw = ImageDraw.Draw(image) 111 | width, height = image.size 112 | 113 | if (obj_dets is not None) and (hand_dets is not None): 114 | img_obj_id = filter_object(obj_dets, hand_dets) 115 | for obj_idx, i in enumerate(range(np.minimum(10, obj_dets.shape[0]))): 116 | bbox = list(int(np.round(x)) for x in obj_dets[i, :4]) 117 | score = obj_dets[i, 4] 118 | if score > thresh_obj and i in img_obj_id: 119 | # viz obj by 
PIL 120 | image = draw_obj_mask(image, draw, obj_idx, bbox, score, width, height) 121 | 122 | for hand_idx, i in enumerate(range(np.minimum(10, hand_dets.shape[0]))): 123 | bbox = list(int(np.round(x)) for x in hand_dets[i, :4]) 124 | score = hand_dets[i, 4] 125 | lr = hand_dets[i, -1] 126 | state = hand_dets[i, 5] 127 | if score > thresh_hand: 128 | # viz hand by PIL 129 | image = draw_hand_mask(image, draw, hand_idx, bbox, score, lr, state, width, height) 130 | 131 | if state > 0: # in contact hand 132 | 133 | obj_cc, hand_cc = calculate_center(obj_dets[img_obj_id[i],:4]), calculate_center(bbox) 134 | # viz line by PIL 135 | if lr == 0: 136 | side_idx = 0 137 | elif lr == 1: 138 | side_idx = 1 139 | draw_line_point(draw, side_idx, (int(hand_cc[0]), int(hand_cc[1])), (int(obj_cc[0]), int(obj_cc[1]))) 140 | 141 | elif hand_dets is not None: 142 | image = vis_detections_PIL(im, 'hand', hand_dets, thresh_hand) 143 | 144 | return image 145 | 146 | def render_frame(im, hand_dets, obj_dets, thresh_hand=0.5, thresh_obj=0.5): 147 | import cv2 148 | im_show = im.copy() 149 | im_show = cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR) 150 | hand_dets = np.array(ast.literal_eval(hand_dets)) if hand_dets != '[]' else None 151 | obj_dets = np.array(ast.literal_eval(obj_dets)) if obj_dets != '[]' else None 152 | im_show = vis_detections_filtered_objects_PIL(im_show, obj_dets, hand_dets, thresh_hand, thresh_obj) 153 | # im_show.save('test.png') 154 | im_show = np.array(im_show) 155 | return im_show -------------------------------------------------------------------------------- /llavaction/action/selective_inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Instead of running the whole validation set, 3 | """ 4 | from llavaction.action.ek_eval import prepare_llava 5 | from llavaction.action.generate_interval_pred import get_lookup_dict 6 | from llavaction.action.llava_inference import llava_inference 7 | from llavaction.action.utils import avion_video_loader 8 | 9 | from llavaction.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 10 | 11 | val_metadata = '/data/anonymous/epic-kitchens-100-annotations/EPIC_100_validation.csv' 12 | data_root = '/data/anonymous/EK100_512/EK100' 13 | 14 | n_frames = 32 15 | action_representation = 'GT_random_narration' 16 | perspective = 'first_person' 17 | include_time_instruction = False 18 | image_token = DEFAULT_IMAGE_TOKEN 19 | 20 | 21 | 22 | def get_frames_by_uid(uid, root): 23 | vid_path = '_'.join(uid.split('_')[:2]).replace('-', '/') 24 | print ('debug', uid) 25 | start_timestamp, end_timestamp = uid.split('_')[2:] 26 | start_timestamp = float(start_timestamp) 27 | end_timestamp = float(end_timestamp) 28 | print (vid_path, start_timestamp, end_timestamp) 29 | # split uid to video path and start, end second 30 | frames, time_meta = avion_video_loader(root, 31 | vid_path, 32 | 'MP4', 33 | start_timestamp, 34 | end_timestamp, 35 | chunk_len = 15, 36 | clip_length = n_frames, 37 | threads = 1, 38 | fast_rrc=False, 39 | fast_rcc = False, 40 | jitter = False) 41 | return frames, time_meta 42 | # 43 | 44 | 45 | 46 | 47 | 48 | 49 | # for prior actions 50 | def get_meta_data(): 51 | pass 52 | 53 | 54 | def inference_task_by_uid(data_root, question, checkpoint_folder, uid, task): 55 | 56 | tokenizer, model, image_processor, max_length = prepare_llava(checkpoint_folder) 57 | 58 | frames, time_meta = get_frames_by_uid(uid, data_root) 59 | 60 | meta_data = None 61 | learn_neighbor_actions = "" 62 | if 'temporal_cot' in task: 
63 | lookup_table = get_lookup_dict(val_metadata, 64 | action_representation, 65 | test_type = task, 66 | pseudo_folder = '') 67 | meta_data = lookup_table.get(uid, None) 68 | learn_neighbor_actions = "prior" 69 | 70 | video_duration = time_meta['duration'] 71 | 72 | 73 | pred = llava_inference( 74 | [frames], 75 | tokenizer, 76 | model, 77 | image_processor, 78 | question, 79 | test_type = task, 80 | clip_length = n_frames, 81 | num_frames= n_frames, 82 | temperature = 0, 83 | time_meta = time_meta, 84 | learn_neighbor_actions = learn_neighbor_actions, 85 | meta_data = meta_data, 86 | perspective = perspective, 87 | include_time_instruction = include_time_instruction 88 | ) 89 | return pred 90 | 91 | class SelectiveInferencer: 92 | def __init__(self, data_root, checkpoint_folder, include_time_instruction = False, n_frames = 32, use_flash_attention = True): 93 | self.data_root = data_root 94 | self.checkpoint_folder = checkpoint_folder 95 | self.tokenizer, self.model, self.image_processor, self.max_length = prepare_llava(checkpoint_folder, use_flash_attention = use_flash_attention) 96 | self.include_time_instruction = include_time_instruction 97 | self.n_frames = n_frames 98 | def inference(self, question, uid, task): 99 | frames, time_meta = get_frames_by_uid(uid, self.data_root) 100 | 101 | meta_data = None 102 | learn_neighbor_actions = "" 103 | if 'temporal_cot' in task: 104 | lookup_table = get_lookup_dict(val_metadata, 105 | action_representation, 106 | test_type = task, 107 | pseudo_folder = '') 108 | meta_data = lookup_table.get(uid, None) 109 | learn_neighbor_actions = "prior" 110 | 111 | 112 | pred = llava_inference( 113 | [frames], 114 | self.tokenizer, 115 | self.model, 116 | self.image_processor, 117 | question, 118 | test_type = task, 119 | clip_length = self.n_frames, 120 | num_frames= self.n_frames, 121 | temperature = 0, 122 | time_meta = time_meta, 123 | learn_neighbor_actions = learn_neighbor_actions, 124 | meta_data = meta_data, 125 | perspective = perspective, 126 | include_time_instruction = self.include_time_instruction 127 | ) 128 | return pred 129 | 130 | 131 | if __name__ == '__main__': 132 | pretrained_model_folder = 'experiments/dev_LLaVA-Video-7B-Qwen2' 133 | uid = 'P28-P28_15_50.66_51.69' 134 | task = 'open-ended' 135 | question = "What is the object that is to the left of the knife?" 136 | 137 | inference_task_by_uid(data_root, 138 | question, 139 | pretrained_model_folder, 140 | uid, 141 | task) -------------------------------------------------------------------------------- /llavaction/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | -------------------------------------------------------------------------------- /llavaction/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | AVAILABLE_MODELS = { 4 | "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig", 5 | "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig", 6 | "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig", 7 | "llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig", 8 | # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig", 9 | # Add other models as needed 10 | } 11 | 12 | for model_name, model_classes in AVAILABLE_MODELS.items(): 13 | try: 14 | exec(f"from .language_model.{model_name} import {model_classes}") 15 | except Exception as e: 16 | print(f"Failed to import {model_name} from llavaction.model.language_model.{model_name}. Error: {e}") 17 | -------------------------------------------------------------------------------- /llavaction/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from llavaction import LlavaLlamaForCausalLM 12 | 13 | 14 | def apply_delta(base_model_path, target_model_path, delta_path): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] += bparam 33 | 34 | print("Saving target model") 35 | delta.save_pretrained(target_model_path) 36 | delta_tokenizer.save_pretrained(target_model_path) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--base-model-path", type=str, required=True) 42 | parser.add_argument("--target-model-path", type=str, required=True) 43 | parser.add_argument("--delta-path", type=str, required=True) 44 | 45 | args = parser.parse_args() 46 | 47 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 48 | -------------------------------------------------------------------------------- /llavaction/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src
~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llavaction.model import * 11 | from llavaction.model.utils import auto_upgrade 12 | 13 | 14 | def consolidate_ckpt(src_path, dst_path): 15 | print("Loading model") 16 | auto_upgrade(src_path) 17 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 18 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 19 | src_model.save_pretrained(dst_path) 20 | src_tokenizer.save_pretrained(dst_path) 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--src", type=str, required=True) 26 | parser.add_argument("--dst", type=str, required=True) 27 | 28 | args = parser.parse_args() 29 | 30 | consolidate_ckpt(args.src, args.dst) 31 | -------------------------------------------------------------------------------- /llavaction/model/language_model/llava_gemma.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Duc Q. Nguyen, Haotian Liu and Bo Li 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
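# A minimal usage sketch (the checkpoint path below is hypothetical, not part of this repository):
# the AutoConfig.register("llava_gemma", ...) and AutoModelForCausalLM.register(...) calls at the
# bottom of this module let the standard Hugging Face Auto classes resolve checkpoints whose
# config.model_type is "llava_gemma", provided this module has been imported so those calls run:
#
#     import llavaction.model.language_model.llava_gemma  # executes the register() calls below
#     from transformers import AutoModelForCausalLM
#     model = AutoModelForCausalLM.from_pretrained("path/to/llava-gemma-checkpoint")  # hypothetical path
#     output_ids = model.generate(input_ids, images=image_tensor, image_sizes=image_sizes)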
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, GemmaConfig, GemmaModel, GemmaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaGemmaConfig(GemmaConfig): 31 | model_type = "llava_gemma" 32 | 33 | 34 | class LlavaGemmaModel(LlavaMetaModel, GemmaModel): 35 | config_class = LlavaGemmaConfig 36 | 37 | def __init__(self, config: GemmaConfig): 38 | super(LlavaGemmaModel, self).__init__(config) 39 | 40 | 41 | class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaGemmaConfig 43 | 44 | def __init__(self, config): 45 | super(GemmaForCausalLM, self).__init__(config) 46 | self.model = LlavaGemmaModel(config) 47 | 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | position_ids: Optional[torch.LongTensor] = None, 61 | past_key_values: Optional[List[torch.FloatTensor]] = None, 62 | inputs_embeds: Optional[torch.FloatTensor] = None, 63 | labels: Optional[torch.LongTensor] = None, 64 | use_cache: Optional[bool] = None, 65 | output_attentions: Optional[bool] = None, 66 | output_hidden_states: Optional[bool] = None, 67 | images: Optional[torch.FloatTensor] = None, 68 | image_sizes: Optional[List[List[int]]] = None, 69 | return_dict: Optional[bool] = None, 70 | cache_position: Optional[torch.LongTensor] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) 75 | 76 | return super().forward( 77 | input_ids=input_ids, 78 | attention_mask=attention_mask, 79 | position_ids=position_ids, 80 | past_key_values=past_key_values, 81 | inputs_embeds=inputs_embeds, 82 | labels=labels, 83 | use_cache=use_cache, 84 | output_attentions=output_attentions, 85 | output_hidden_states=output_hidden_states, 86 | return_dict=return_dict, 87 | cache_position=cache_position, 88 | ) 89 | 90 | @torch.no_grad() 91 | def generate( 92 | self, 93 | inputs: Optional[torch.Tensor] = None, 94 | images: Optional[torch.Tensor] = None, 95 | image_sizes: Optional[torch.Tensor] = None, 96 | **kwargs, 97 | ) -> Union[GenerateOutput, torch.LongTensor]: 98 | position_ids = kwargs.pop("position_ids", None) 99 | attention_mask = kwargs.pop("attention_mask", None) 100 | if "inputs_embeds" in kwargs: 101 | raise NotImplementedError("`inputs_embeds` is not supported") 102 | 103 | if images is not None: 104 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) 105 | else: 106 | inputs_embeds = self.get_model().embed_tokens(inputs) 107 | 108 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 109 | 110 | def 
prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 111 | images = kwargs.pop("images", None) 112 | image_sizes = kwargs.pop("image_sizes", None) 113 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 114 | if images is not None: 115 | inputs["images"] = images 116 | if image_sizes is not None: 117 | inputs["image_sizes"] = image_sizes 118 | return inputs 119 | 120 | 121 | AutoConfig.register("llava_gemma", LlavaGemmaConfig) 122 | AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM) 123 | -------------------------------------------------------------------------------- /llavaction/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig 22 | 23 | from torch.nn import CrossEntropyLoss 24 | 25 | 26 | # , LlamaModel, LlamaForCausalLM, GenerationConfig 27 | # from .modeling_llama import LlamaModel, LlamaForCausalLM 28 | from transformers import LlamaModel, LlamaForCausalLM 29 | from transformers.modeling_outputs import CausalLMOutputWithPast 30 | from transformers.generation.utils import GenerateOutput 31 | 32 | from llavaction.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 33 | 34 | 35 | class LlavaConfig(LlamaConfig): 36 | model_type = "llava_llama" 37 | temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna 38 | max_new_tokens: int = 1024 39 | do_sample: bool = False 40 | top_p: Optional[float] = None 41 | # rope_scaling: Optional[dict] = {} 42 | 43 | 44 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 45 | config_class = LlavaConfig 46 | 47 | def __init__(self, config: LlamaConfig): 48 | super(LlavaLlamaModel, self).__init__(config) 49 | 50 | 51 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 52 | config_class = LlavaConfig 53 | 54 | def __init__(self, config): 55 | LlamaForCausalLM.__init__(self, config) 56 | 57 | # configure default generation settings 58 | config.model_type = "llava_llama" 59 | # config.rope_scaling = None 60 | 61 | self.model = LlavaLlamaModel(config) 62 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 63 | # Initialize weights and apply final processing 64 | self.post_init() 65 | 66 | def get_model(self): 67 | return self.model 68 | 69 | def forward( 70 | self, 71 | input_ids: torch.LongTensor = None, 72 | attention_mask: Optional[torch.Tensor] = None, 73 | position_ids: Optional[torch.LongTensor] = None, 74 | past_key_values: Optional[List[torch.FloatTensor]] = None, 75 | inputs_embeds: Optional[torch.FloatTensor] = None, 76 | labels: Optional[torch.LongTensor] = None, 77 | 
use_cache: Optional[bool] = None, 78 | output_attentions: Optional[bool] = None, 79 | output_hidden_states: Optional[bool] = None, 80 | images: Optional[torch.FloatTensor] = None, 81 | image_sizes: Optional[List[List[int]]] = None, 82 | return_dict: Optional[bool] = None, 83 | modalities: Optional[List[str]] = ["image"], 84 | dpo_forward: Optional[bool] = None, 85 | cache_position=None, 86 | ) -> Union[Tuple, CausalLMOutputWithPast]: 87 | 88 | if inputs_embeds is None: 89 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes) 90 | 91 | if dpo_forward: 92 | outputs = self.model( 93 | input_ids=input_ids, 94 | attention_mask=attention_mask, 95 | position_ids=position_ids, 96 | past_key_values=past_key_values, 97 | inputs_embeds=inputs_embeds, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict, 102 | ) 103 | 104 | hidden_states = outputs[0] 105 | logits = self.lm_head(hidden_states) 106 | return logits, labels 107 | 108 | else: 109 | return super().forward( 110 | input_ids=input_ids, 111 | attention_mask=attention_mask, 112 | position_ids=position_ids, 113 | past_key_values=past_key_values, 114 | inputs_embeds=inputs_embeds, 115 | labels=labels, 116 | use_cache=use_cache, 117 | output_attentions=output_attentions, 118 | output_hidden_states=output_hidden_states, 119 | return_dict=return_dict, 120 | ) 121 | 122 | @torch.no_grad() 123 | def generate( 124 | self, 125 | inputs: Optional[torch.Tensor] = None, 126 | images: Optional[torch.Tensor] = None, 127 | image_sizes: Optional[torch.Tensor] = None, 128 | modalities: Optional[List[str]] = ["image"], 129 | **kwargs, 130 | ) -> Union[GenerateOutput, torch.LongTensor]: 131 | modalities = kwargs.pop("modalities", None) if "modalities" in kwargs and modalities is None else modalities 132 | position_ids = kwargs.pop("position_ids", None) 133 | attention_mask = kwargs.pop("attention_mask", None) 134 | if "inputs_embeds" in kwargs: 135 | raise NotImplementedError("`inputs_embeds` is not supported") 136 | 137 | if images is not None: 138 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes) 139 | else: 140 | inputs_embeds = self.get_model().embed_tokens(inputs) 141 | 142 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 145 | images = kwargs.pop("images", None) 146 | image_sizes = kwargs.pop("image_sizes", None) 147 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 148 | if images is not None: 149 | inputs["images"] = images 150 | if image_sizes is not None: 151 | inputs["image_sizes"] = image_sizes 152 | return inputs 153 | 154 | 155 | AutoConfig.register("llava_llama", LlavaConfig) 156 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 157 | -------------------------------------------------------------------------------- /llavaction/model/language_model/llava_mistral.py: -------------------------------------------------------------------------------- 1 | # 
Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, MistralConfig, MistralModel, MistralForCausalLM, GenerationConfig 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMistralConfig(MistralConfig): 31 | model_type = "llava_mistral" 32 | temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna 33 | max_new_tokens: int = 1024 34 | do_sample: bool = False 35 | top_p: Optional[float] = None 36 | 37 | 38 | class LlavaMistralModel(LlavaMetaModel, MistralModel): 39 | config_class = LlavaMistralConfig 40 | 41 | def __init__(self, config: MistralConfig): 42 | super(LlavaMistralModel, self).__init__(config) 43 | 44 | 45 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM): 46 | config_class = LlavaMistralConfig 47 | 48 | def __init__(self, config): 49 | super(MistralForCausalLM, self).__init__(config) 50 | 51 | config.model_type = "llava_mistral" 52 | config.rope_scaling = None 53 | 54 | self.model = LlavaMistralModel(config) 55 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 56 | # Initialize weights and apply final processing 57 | self.post_init() 58 | 59 | def get_model(self): 60 | return self.model 61 | 62 | def forward( 63 | self, 64 | input_ids: torch.LongTensor = None, 65 | attention_mask: Optional[torch.Tensor] = None, 66 | position_ids: Optional[torch.LongTensor] = None, 67 | past_key_values: Optional[List[torch.FloatTensor]] = None, 68 | inputs_embeds: Optional[torch.FloatTensor] = None, 69 | labels: Optional[torch.LongTensor] = None, 70 | use_cache: Optional[bool] = None, 71 | output_attentions: Optional[bool] = None, 72 | output_hidden_states: Optional[bool] = None, 73 | images: Optional[torch.FloatTensor] = None, 74 | image_sizes: Optional[List[List[int]]] = None, 75 | return_dict: Optional[bool] = None, 76 | cache_position=None, 77 | ) -> Union[Tuple, CausalLMOutputWithPast]: 78 | 79 | if inputs_embeds is None: 80 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) 81 | 82 | return super().forward( 83 | input_ids=input_ids, 84 | attention_mask=attention_mask, 85 | position_ids=position_ids, 86 | past_key_values=past_key_values, 87 | inputs_embeds=inputs_embeds, 88 | labels=labels, 89 | use_cache=use_cache, 90 | output_attentions=output_attentions, 91 | output_hidden_states=output_hidden_states, 92 | return_dict=return_dict, 93 | ) 94 | 95 | @torch.no_grad() 96 | def generate( 97 | self, 98 | inputs: 
Optional[torch.Tensor] = None, 99 | images: Optional[torch.Tensor] = None, 100 | image_sizes: Optional[torch.Tensor] = None, 101 | **kwargs, 102 | ) -> Union[GenerateOutput, torch.LongTensor]: 103 | position_ids = kwargs.pop("position_ids", None) 104 | attention_mask = kwargs.pop("attention_mask", None) 105 | if "inputs_embeds" in kwargs: 106 | raise NotImplementedError("`inputs_embeds` is not supported") 107 | 108 | if images is not None: 109 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) 110 | else: 111 | inputs_embeds = self.get_model().embed_tokens(inputs) 112 | 113 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 114 | 115 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 116 | images = kwargs.pop("images", None) 117 | image_sizes = kwargs.pop("image_sizes", None) 118 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 119 | if images is not None: 120 | inputs["images"] = images 121 | if image_sizes is not None: 122 | inputs["image_sizes"] = image_sizes 123 | return inputs 124 | 125 | 126 | AutoConfig.register("llava_mistral", LlavaMistralConfig) 127 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM) 128 | -------------------------------------------------------------------------------- /llavaction/model/language_model/llava_mixtral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
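# LLaVA wrapper around Mixtral: LlavaMixtralForCausalLM overrides forward() so that,
# when images are supplied, prepare_inputs_labels_for_multimodal() splices the projected
# vision features into the token embeddings before delegating to MixtralForCausalLM;
# with dpo_forward=True it returns raw (logits, labels) instead of a CausalLMOutputWithPast,
# which is what preference-optimization training loops typically consume.
# Minimal usage sketch (checkpoint path and tensor contents are hypothetical):
#   model = LlavaMixtralForCausalLM.from_pretrained("path/to/llava-mixtral-checkpoint")
#   output_ids = model.generate(inputs=input_ids, images=video_frames, modalities=["video"])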
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, MixtralConfig, MixtralModel, MixtralForCausalLM, GenerationConfig 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMixtralConfig(MixtralConfig): 31 | model_type = "llava_mixtral" 32 | 33 | 34 | class LlavaMixtralModel(LlavaMetaModel, MixtralModel): 35 | config_class = LlavaMixtralConfig 36 | 37 | def __init__(self, config: MixtralConfig): 38 | super(LlavaMixtralModel, self).__init__(config) 39 | 40 | 41 | class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaMixtralConfig 43 | 44 | def __init__(self, config): 45 | super(MixtralForCausalLM, self).__init__(config) 46 | 47 | config.model_type = "llava_mixtral" 48 | config.rope_scaling = None 49 | self.model = LlavaMixtralModel(config) 50 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | modalities: Optional[List[str]] = ["image"], 72 | dpo_forward: Optional[bool] = None, 73 | cache_position=None, 74 | ) -> Union[Tuple, CausalLMOutputWithPast]: 75 | 76 | if inputs_embeds is None: 77 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes) 78 | 79 | if dpo_forward: 80 | outputs = self.model( 81 | input_ids=input_ids, 82 | attention_mask=attention_mask, 83 | position_ids=position_ids, 84 | past_key_values=past_key_values, 85 | inputs_embeds=inputs_embeds, 86 | use_cache=use_cache, 87 | output_attentions=output_attentions, 88 | output_hidden_states=output_hidden_states, 89 | return_dict=return_dict, 90 | ) 91 | 92 | hidden_states = outputs[0] 93 | logits = self.lm_head(hidden_states) 94 | return logits, labels 95 | 96 | else: 97 | return super().forward( 98 | input_ids=input_ids, 99 | attention_mask=attention_mask, 100 | position_ids=position_ids, 101 | past_key_values=past_key_values, 102 | inputs_embeds=inputs_embeds, 103 | labels=labels, 104 | use_cache=use_cache, 105 | output_attentions=output_attentions, 106 | output_hidden_states=output_hidden_states, 107 | return_dict=return_dict, 108 | ) 109 | 110 | @torch.no_grad() 111 | def generate( 112 | self, 113 | inputs: Optional[torch.Tensor] = None, 114 | images: Optional[torch.Tensor] = None, 115 | image_sizes: Optional[torch.Tensor] = None, 116 | modalities: Optional[List[str]] = ["image"], 117 | **kwargs, 
118 | ) -> Union[GenerateOutput, torch.LongTensor]: 119 | position_ids = kwargs.pop("position_ids", None) 120 | attention_mask = kwargs.pop("attention_mask", None) 121 | if "inputs_embeds" in kwargs: 122 | raise NotImplementedError("`inputs_embeds` is not supported") 123 | 124 | if images is not None: 125 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes) 126 | else: 127 | inputs_embeds = self.get_model().embed_tokens(inputs) 128 | 129 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 130 | 131 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 132 | images = kwargs.pop("images", None) 133 | image_sizes = kwargs.pop("image_sizes", None) 134 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 135 | if images is not None: 136 | inputs["images"] = images 137 | if image_sizes is not None: 138 | inputs["image_sizes"] = image_sizes 139 | return inputs 140 | 141 | 142 | AutoConfig.register("llava_mixtral", LlavaMixtralConfig) 143 | AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM) 144 | -------------------------------------------------------------------------------- /llavaction/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
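# LLaVA wrapper around MPT: LlavaMptModel aliases config.d_model to hidden_size and exposes
# embed_tokens() via the MPT word-token embedding (wte). LlavaMptForCausalLM pins a greedy
# GenerationConfig (temperature=0.0, do_sample=False, max_new_tokens=1024) and routes
# forward() through prepare_inputs_labels_for_multimodal() so image features are merged
# into the input embeddings before the base MptForCausalLM forward runs.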
14 | 15 | 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | 20 | from transformers import AutoConfig, AutoModelForCausalLM, MptConfig, MptForCausalLM, MptModel, GenerationConfig 21 | from llavaction.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 22 | 23 | 24 | class LlavaMptConfig(MptConfig): 25 | model_type = "llava_mpt" 26 | 27 | 28 | class LlavaMptModel(LlavaMetaModel, MptModel): 29 | config_class = LlavaMptConfig 30 | 31 | def __init__(self, config: MptConfig): 32 | config.hidden_size = config.d_model 33 | super(LlavaMptModel, self).__init__(config) 34 | 35 | def embed_tokens(self, x): 36 | return self.wte(x) 37 | 38 | 39 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): 40 | config_class = LlavaMptConfig 41 | supports_gradient_checkpointing = True 42 | 43 | def __init__(self, config): 44 | super(MptForCausalLM, self).__init__(config) 45 | 46 | config.model_type = "llava_mpt" 47 | config.rope_scaling = None 48 | self.generation_config = GenerationConfig( 49 | temperature=0.0, 50 | max_new_tokens=1024, 51 | do_sample=False, 52 | top_p=None, 53 | ) 54 | 55 | self.transformer = LlavaMptModel(config) 56 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 57 | 58 | # Initialize weights and apply final processing 59 | self.post_init() 60 | 61 | def get_model(self): 62 | return self.transformer 63 | 64 | def _set_gradient_checkpointing(self, module, value=False): 65 | if isinstance(module, LlavaMptModel): 66 | module.gradient_checkpointing = value 67 | 68 | def forward( 69 | self, 70 | input_ids: Optional[torch.LongTensor] = None, 71 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 72 | attention_mask: Optional[torch.Tensor] = None, 73 | inputs_embeds: Optional[torch.Tensor] = None, 74 | labels: Optional[torch.Tensor] = None, 75 | use_cache: Optional[bool] = None, 76 | output_attentions: Optional[bool] = None, 77 | output_hidden_states: Optional[bool] = None, 78 | return_dict: Optional[bool] = None, 79 | cache_position=None, 80 | images=None, 81 | ): 82 | 83 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 84 | 85 | return super().forward( 86 | input_ids, 87 | past_key_values=past_key_values, 88 | attention_mask=attention_mask, 89 | inputs_embeds=inputs_embeds, 90 | labels=labels, 91 | use_cache=use_cache, 92 | output_attentions=output_attentions, 93 | output_hidden_states=output_hidden_states, 94 | return_dict=return_dict, 95 | ) 96 | 97 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 98 | images = kwargs.pop("images", None) 99 | _inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 100 | _inputs["images"] = images 101 | return _inputs 102 | 103 | 104 | AutoConfig.register("llava_mpt", LlavaMptConfig) 105 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) 106 | -------------------------------------------------------------------------------- /llavaction/model/language_model/llava_qwen_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Hao Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union, Dict 17 | import torch 18 | import torch.nn as nn 19 | from torch.nn import CrossEntropyLoss 20 | 21 | import transformers 22 | from transformers import AutoConfig, AutoModelForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | # from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 28 | from llavaction.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 29 | from transformers import Qwen2MoeConfig, Qwen2MoeModel, Qwen2MoeForCausalLM 30 | 31 | # from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel 32 | # from .qwen.configuration_qwen import QWenConfig 33 | 34 | 35 | class LlavaQwenMoeConfig(Qwen2MoeConfig): 36 | model_type = "llava_qwen_moe" 37 | 38 | 39 | class LlavaQwenMoeModel(LlavaMetaModel, Qwen2MoeModel): 40 | config_class = LlavaQwenMoeConfig 41 | 42 | def __init__(self, config: Qwen2MoeConfig): 43 | super(LlavaQwenMoeModel, self).__init__(config) 44 | 45 | 46 | class LlavaQwenMoeForCausalLM(Qwen2MoeForCausalLM, LlavaMetaForCausalLM): 47 | config_class = LlavaQwenMoeConfig 48 | 49 | def __init__(self, config): 50 | # super(Qwen2MoeForCausalLM, self).__init__(config) 51 | Qwen2MoeForCausalLM.__init__(self, config) 52 | config.model_type = "llava_qwen_moe" 53 | config.rope_scaling = None 54 | 55 | self.model = LlavaQwenMoeModel(config) 56 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 57 | # Initialize weights and apply final processing 58 | self.post_init() 59 | 60 | def get_model(self): 61 | return self.model 62 | 63 | def forward( 64 | self, 65 | input_ids: torch.LongTensor = None, 66 | attention_mask: Optional[torch.Tensor] = None, 67 | position_ids: Optional[torch.LongTensor] = None, 68 | past_key_values: Optional[List[torch.FloatTensor]] = None, 69 | inputs_embeds: Optional[torch.FloatTensor] = None, 70 | labels: Optional[torch.LongTensor] = None, 71 | use_cache: Optional[bool] = None, 72 | output_attentions: Optional[bool] = None, 73 | output_hidden_states: Optional[bool] = None, 74 | images: Optional[torch.FloatTensor] = None, 75 | image_sizes: Optional[List[List[int]]] = None, 76 | return_dict: Optional[bool] = None, 77 | modalities: Optional[List[str]] = ["image"], 78 | dpo_forward: Optional[bool] = False, 79 | cache_position=None, 80 | ) -> Union[Tuple, CausalLMOutputWithPast]: 81 | 82 | if inputs_embeds is None: 83 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes) 84 | 85 | if dpo_forward: 86 | outputs = self.model( 87 | input_ids=input_ids, 88 | attention_mask=attention_mask, 89 | position_ids=position_ids, 90 | past_key_values=past_key_values, 91 | inputs_embeds=inputs_embeds, 92 | use_cache=use_cache, 93 | output_attentions=output_attentions, 94 | 
output_hidden_states=output_hidden_states, 95 | return_dict=return_dict, 96 | ) 97 | 98 | hidden_states = outputs[0] 99 | logits = self.lm_head(hidden_states) 100 | return logits, labels 101 | 102 | else: 103 | return super().forward( 104 | input_ids=input_ids, 105 | attention_mask=attention_mask, 106 | position_ids=position_ids, 107 | past_key_values=past_key_values, 108 | inputs_embeds=inputs_embeds, 109 | labels=labels, 110 | use_cache=use_cache, 111 | output_attentions=output_attentions, 112 | output_hidden_states=output_hidden_states, 113 | return_dict=return_dict, 114 | ) 115 | 116 | @torch.no_grad() 117 | def generate( 118 | self, 119 | inputs: Optional[torch.Tensor] = None, 120 | images: Optional[torch.Tensor] = None, 121 | image_sizes: Optional[torch.Tensor] = None, 122 | modalities: Optional[List[str]] = ["image"], 123 | **kwargs, 124 | ) -> Union[GenerateOutput, torch.LongTensor]: 125 | position_ids = kwargs.pop("position_ids", None) 126 | attention_mask = kwargs.pop("attention_mask", None) 127 | if "inputs_embeds" in kwargs: 128 | raise NotImplementedError("`inputs_embeds` is not supported") 129 | 130 | if images is not None: 131 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes) 132 | else: 133 | inputs_embeds = self.get_model().embed_tokens(inputs) 134 | 135 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 136 | 137 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 138 | images = kwargs.pop("images", None) 139 | image_sizes = kwargs.pop("image_sizes", None) 140 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 141 | if images is not None: 142 | inputs["images"] = images 143 | if image_sizes is not None: 144 | inputs["image_sizes"] = image_sizes 145 | return inputs 146 | 147 | 148 | AutoConfig.register("llava_qwen_moe", LlavaQwenMoeConfig) 149 | AutoModelForCausalLM.register(LlavaQwenMoeConfig, LlavaQwenMoeForCausalLM) 150 | -------------------------------------------------------------------------------- /llavaction/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from llavaction.model.utils import auto_upgrade 12 | 13 | 14 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | 
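        # For parameters shared with the base model, the delta is simply target - base.
        # embed_tokens / lm_head may be larger in the target (e.g. an extended vocabulary),
        # so only the slice overlapping the base weight is subtracted; the extra rows and
        # columns are stored unchanged in the delta checkpoint.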
if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | from .imagebind import ImageBindWrapper 4 | from .open_clip_encoder import OpenCLIPVisionTower 5 | from .hf_vision import HFVisionTower 6 | from .siglip_encoder import SigLipVisionTower 7 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 8 | 9 | # from .eva_clip.eva_clip_encoder import EvaClipVisionTower 10 | # from .dev_eva_clip.eva_vit import EvaViTWrapper 11 | 12 | 13 | def build_vision_tower(vision_tower_cfg, **kwargs): 14 | vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)) 15 | is_absolute_path_exists = os.path.exists(vision_tower) 16 | use_s2 = getattr(vision_tower_cfg, "s2", False) 17 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 18 | if use_s2: 19 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 20 | else: 21 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 22 | elif "siglip" in vision_tower: 23 | return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs) 24 | elif vision_tower.startswith("hf:"): 25 | return HFVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 26 | elif vision_tower in ["imagebind_huge"]: 27 | return ImageBindWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 28 | elif vision_tower.startswith("open_clip_hub"): 29 | return OpenCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 30 | # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower(): 31 | # return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 32 | # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]: 33 | # return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 34 | 35 | raise ValueError(f"Unknown vision tower: {vision_tower}") 36 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py: -------------------------------------------------------------------------------- 1 | 
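# Public re-exports for the vendored EVA-CLIP package: model factories and checkpoint
# loading helpers (factory / openai / pretrained), the CLIP model classes and ClipLoss,
# plus the tokenizer and image transform used for preprocessing.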
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer 3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 4 | from .loss import ClipLoss 5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 6 | from .openai import load_openai_model, list_openai_models 7 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 8 | from .tokenizer import SimpleTokenizer, tokenize 9 | from .transform import image_transform 10 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/LLaVAction/bd9f46583c5a94333575eccdf92e1b2bd118f7bd/llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings", 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings", 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens", 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | "bert": { 46 | "config_names": { 47 | "context_length": "max_position_embeddings", 48 | "vocab_size": "vocab_size", 49 | "width": "hidden_size", 50 | "heads": "num_attention_heads", 51 | 
"layers": "num_hidden_layers", 52 | "layer_attr": "layer", 53 | "token_embeddings_attr": "embeddings", 54 | }, 55 | "pooler": "mean_pooler", 56 | }, 57 | } 58 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | try: 7 | import torch.distributed.nn 8 | from torch import distributed as dist 9 | 10 | has_distributed = True 11 | except ImportError: 12 | has_distributed = False 13 | 14 | try: 15 | import horovod.torch as hvd 16 | except ImportError: 17 | hvd = None 18 | 19 | from timm.loss import LabelSmoothingCrossEntropy 20 | 21 | 22 | def gather_features(image_features, text_features, local_loss=False, gather_with_grad=False, rank=0, world_size=1, use_horovod=False): 23 | assert has_distributed, "torch.distributed did not import correctly, please use a PyTorch version with support." 24 | if use_horovod: 25 | assert hvd is not None, "Please install horovod" 26 | if gather_with_grad: 27 | all_image_features = hvd.allgather(image_features) 28 | all_text_features = hvd.allgather(text_features) 29 | else: 30 | with torch.no_grad(): 31 | all_image_features = hvd.allgather(image_features) 32 | all_text_features = hvd.allgather(text_features) 33 | if not local_loss: 34 | # ensure grads for local rank when all_* features don't have a gradient 35 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 36 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 37 | gathered_image_features[rank] = image_features 38 | gathered_text_features[rank] = text_features 39 | all_image_features = torch.cat(gathered_image_features, dim=0) 40 | all_text_features = torch.cat(gathered_text_features, dim=0) 41 | else: 42 | # We gather tensors from all gpus 43 | if gather_with_grad: 44 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 45 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 46 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) 47 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) 48 | else: 49 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 50 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 51 | dist.all_gather(gathered_image_features, image_features) 52 | dist.all_gather(gathered_text_features, text_features) 53 | if not local_loss: 54 | # ensure grads for local rank when all_* features don't have a gradient 55 | gathered_image_features[rank] = image_features 56 | gathered_text_features[rank] = text_features 57 | all_image_features = torch.cat(gathered_image_features, dim=0) 58 | all_text_features = torch.cat(gathered_text_features, dim=0) 59 | 60 | return all_image_features, all_text_features 61 | 62 | 63 | class ClipLoss(nn.Module): 64 | 65 | def __init__( 66 | self, 67 | local_loss=False, 68 | gather_with_grad=False, 69 | cache_labels=False, 70 | rank=0, 71 | world_size=1, 72 | use_horovod=False, 73 | smoothing=0.0, 74 | ): 75 | super().__init__() 76 | self.local_loss = local_loss 77 | self.gather_with_grad = gather_with_grad 78 | self.cache_labels = cache_labels 79 | self.rank = rank 80 | self.world_size = world_size 
81 | self.use_horovod = use_horovod 82 | self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None 83 | 84 | # cache state 85 | self.prev_num_logits = 0 86 | self.labels = {} 87 | 88 | def forward(self, image_features, text_features, logit_scale=1.0): 89 | device = image_features.device 90 | if self.world_size > 1: 91 | all_image_features, all_text_features = gather_features(image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 92 | 93 | if self.local_loss: 94 | logits_per_image = logit_scale * image_features @ all_text_features.T 95 | logits_per_text = logit_scale * text_features @ all_image_features.T 96 | else: 97 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 98 | logits_per_text = logits_per_image.T 99 | else: 100 | logits_per_image = logit_scale * image_features @ text_features.T 101 | logits_per_text = logit_scale * text_features @ image_features.T 102 | # calculated ground-truth and cache if enabled 103 | num_logits = logits_per_image.shape[0] 104 | if self.prev_num_logits != num_logits or device not in self.labels: 105 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 106 | if self.world_size > 1 and self.local_loss: 107 | labels = labels + num_logits * self.rank 108 | if self.cache_labels: 109 | self.labels[device] = labels 110 | self.prev_num_logits = num_logits 111 | else: 112 | labels = self.labels[device] 113 | 114 | if self.label_smoothing_cross_entropy: 115 | total_loss = (self.label_smoothing_cross_entropy(logits_per_image, labels) + self.label_smoothing_cross_entropy(logits_per_text, labels)) / 2 116 | else: 117 | total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2 118 | 119 | acc = None 120 | i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) 121 | t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) 122 | acc = {"i2t": i2t_acc, "t2i": t2i_acc} 123 | return total_loss, acc 124 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-18B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1536, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 5120, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-18b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": true, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-plus-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | 
"text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": 
"eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | 
"width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .utils import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion))])) 37 | 38 | def forward(self, x: torch.Tensor): 39 | identity = x 40 | 41 | out = self.act1(self.bn1(self.conv1(x))) 42 | out = self.act2(self.bn2(self.conv2(out))) 43 | out = self.avgpool(out) 44 | out = self.bn3(self.conv3(out)) 45 | 46 | if self.downsample is not None: 47 | identity = self.downsample(x) 48 | 49 | out += identity 50 | out = self.act3(out) 51 | return out 52 | 53 | 54 | class AttentionPool2d(nn.Module): 55 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 56 | super().__init__() 57 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) 58 | self.k_proj = nn.Linear(embed_dim, embed_dim) 59 | self.q_proj = nn.Linear(embed_dim, embed_dim) 60 | self.v_proj = nn.Linear(embed_dim, embed_dim) 61 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 62 | self.num_heads = num_heads 63 | 64 | def forward(self, x): 65 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 66 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 67 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 68 | x, _ = F.multi_head_attention_forward( 69 | query=x, 70 | key=x, 71 | value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0.0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False, 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.image_size = image_size 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.act1 = nn.ReLU(inplace=True) 110 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 111 | self.bn2 = nn.BatchNorm2d(width // 2) 112 | self.act2 = nn.ReLU(inplace=True) 113 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 114 | self.bn3 = nn.BatchNorm2d(width) 115 | self.act3 = nn.ReLU(inplace=True) 116 | self.avgpool = nn.AvgPool2d(2) 117 | 118 | # residual layers 119 | self._inplanes = width # this is a *mutable* variable used during construction 120 | self.layer1 = self._make_layer(width, layers[0]) 121 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 122 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 123 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 124 | 125 | embed_dim = width * 32 # the ResNet feature dimension 126 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 127 | 128 | self.init_parameters() 129 | 130 | def _make_layer(self, planes, blocks, stride=1): 131 | layers = [Bottleneck(self._inplanes, planes, stride)] 132 | 133 | self._inplanes = planes * Bottleneck.expansion 134 | for _ in range(1, blocks): 135 | layers.append(Bottleneck(self._inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def init_parameters(self): 140 | if self.attnpool is not None: 141 | std = self.attnpool.c_proj.in_features**-0.5 142 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 143 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 144 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 146 | 147 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 148 | for name, param in resnet_block.named_parameters(): 149 | if name.endswith("bn3.weight"): 150 | nn.init.zeros_(param) 151 | 152 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 153 | assert unlocked_groups == 0, "partial locking not currently supported for this model" 154 | for param in self.parameters(): 155 | param.requires_grad = False 156 | if freeze_bn_stats: 157 | freeze_batch_norm_2d(self) 158 | 159 | @torch.jit.ignore 160 | def set_grad_checkpointing(self, enable=True): 161 | # FIXME support for non-transformer 162 | pass 163 | 164 | def stem(self, x): 165 | x = self.act1(self.bn1(self.conv1(x))) 166 | x = self.act2(self.bn2(self.conv2(x))) 167 | x = self.act3(self.bn3(self.conv3(x))) 168 | x = self.avgpool(x) 169 | return x 170 | 171 | def forward(self, x): 172 | x = self.stem(x) 173 | x = self.layer1(x) 174 | x = self.layer2(x) 175 | x = self.layer3(x) 176 | x = self.layer4(x) 177 | x = self.attnpool(x) 178 | 179 | return x 180 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted 
from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_models_by_tag("openai") 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | jit: bool = True, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | cache_dir : Optional[str] 43 | The directory to cache the downloaded model weights 44 | 45 | Returns 46 | ------- 47 | model : torch.nn.Module 48 | The CLIP model 49 | preprocess : Callable[[PIL.Image], torch.Tensor] 50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 51 | """ 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = "fp32" if device == "cpu" else "fp16" 56 | 57 | if get_pretrained_url(name, "openai"): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, "openai"), cache_dir=cache_dir) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 67 | state_dict = None 68 | except RuntimeError: 69 | # loading saved state dict 70 | if jit: 71 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 72 | jit = False 73 | state_dict = torch.load(model_path, map_location="cpu") 74 | 75 | if not jit: 76 | # Build a non-jit model from the OpenAI jitted model state dict 77 | cast_dtype = get_cast_dtype(precision) 78 | try: 79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 80 | except KeyError: 81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 83 | 84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 85 | model = model.to(device) 86 | if precision.startswith("amp") or precision == "fp32": 87 | model.float() 88 | elif precision == "bf16": 89 | convert_weights_to_lp(model, dtype=torch.bfloat16) 90 | 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 96 | 97 | def patch_device(module): 98 | try: 99 | graphs = [module.graph] if hasattr(module, "graph") else [] 100 | except RuntimeError: 101 | graphs = [] 102 | 103 | if hasattr(module, "forward1"): 104 | graphs.append(module.forward1.graph) 105 | 106 | for graph in graphs: 107 | for node in graph.findAllNodes("prim::Constant"): 108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 109 | node.copyAttributes(device_node) 110 | 111 | model.apply(patch_device) 112 | patch_device(model.encode_image) 113 | patch_device(model.encode_text) 114 | 115 | # patch dtype to float32 (typically for CPU) 116 | if precision == "fp32": 117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 119 | float_node = float_input.node() 120 | 121 | def patch_float(module): 122 | try: 123 | graphs = [module.graph] if hasattr(module, "graph") else [] 124 | except RuntimeError: 125 | graphs = [] 126 | 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("aten::to"): 132 | inputs = list(node.inputs()) 133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 134 | if inputs[i].node()["value"] == 5: 135 | inputs[i].node().copyAttributes(float_node) 136 | 137 | model.apply(patch_float) 138 | patch_float(model.encode_image) 139 | patch_float(model.encode_text) 140 | model.float() 141 | 142 | # ensure image_size attr available at consistent location for both jit and non-jit 143 | model.visual.image_size = model.input_resolution.item() 144 | return model 145 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | import torch 3 | from torch import nn 4 | from einops import rearrange, repeat 5 | import logging 6 | 7 | 8 | def broadcat(tensors, dim=-1): 9 | num_tensors = len(tensors) 10 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 11 | assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" 12 | shape_len = list(shape_lens)[0] 13 | dim = (dim + shape_len) if dim < 0 else dim 14 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 15 | 
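    # dims now holds, for each axis, the tuple of sizes observed across all input tensors.
    # Every axis except the concatenation axis must have at most two distinct sizes so it
    # can be broadcast; those axes are expanded to their maximum size before the tensors
    # are concatenated along `dim`.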
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] 16 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatentation" 17 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 18 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 19 | expanded_dims.insert(dim, (dim, dims[dim])) 20 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 21 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 22 | return torch.cat(tensors, dim=dim) 23 | 24 | 25 | def rotate_half(x): 26 | x = rearrange(x, "... (d r) -> ... d r", r=2) 27 | x1, x2 = x.unbind(dim=-1) 28 | x = torch.stack((-x2, x1), dim=-1) 29 | return rearrange(x, "... d r -> ... (d r)") 30 | 31 | 32 | class VisionRotaryEmbedding(nn.Module): 33 | def __init__( 34 | self, 35 | dim, 36 | pt_seq_len, 37 | ft_seq_len=None, 38 | custom_freqs=None, 39 | freqs_for="lang", 40 | theta=10000, 41 | max_freq=10, 42 | num_freqs=1, 43 | ): 44 | super().__init__() 45 | if custom_freqs: 46 | freqs = custom_freqs 47 | elif freqs_for == "lang": 48 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 49 | elif freqs_for == "pixel": 50 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi 51 | elif freqs_for == "constant": 52 | freqs = torch.ones(num_freqs).float() 53 | else: 54 | raise ValueError(f"unknown modality {freqs_for}") 55 | 56 | if ft_seq_len is None: 57 | ft_seq_len = pt_seq_len 58 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 59 | 60 | freqs_h = torch.einsum("..., f -> ... f", t, freqs) 61 | freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) 62 | 63 | freqs_w = torch.einsum("..., f -> ... f", t, freqs) 64 | freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) 65 | 66 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) 67 | 68 | self.register_buffer("freqs_cos", freqs.cos()) 69 | self.register_buffer("freqs_sin", freqs.sin()) 70 | 71 | logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") 72 | 73 | def forward(self, t, start_index=0): 74 | rot_dim = self.freqs_cos.shape[-1] 75 | end_index = start_index + rot_dim 76 | assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" 77 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 78 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) 79 | 80 | return torch.cat((t_left, t, t_right), dim=-1) 81 | 82 | 83 | class VisionRotaryEmbeddingFast(nn.Module): 84 | def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0): 85 | super().__init__() 86 | if custom_freqs: 87 | freqs = custom_freqs 88 | elif freqs_for == "lang": 89 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 90 | elif freqs_for == "pixel": 91 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi 92 | elif freqs_for == "constant": 93 | freqs = torch.ones(num_freqs).float() 94 | else: 95 | raise ValueError(f"unknown modality {freqs_for}") 96 | 97 | if ft_seq_len is None: 98 | ft_seq_len = pt_seq_len 99 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 100 | 101 | freqs = torch.einsum("..., f -> ... f", t, freqs) 102 | freqs = repeat(freqs, "... n -> ... 
(n r)", r=2) 103 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) 104 | 105 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) 106 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) 107 | 108 | self.patch_dropout = patch_dropout 109 | 110 | self.register_buffer("freqs_cos", freqs_cos) 111 | self.register_buffer("freqs_sin", freqs_sin) 112 | 113 | logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") 114 | 115 | def forward(self, t, patch_indices_keep=None): 116 | if patch_indices_keep is not None: 117 | batch = t.size()[0] 118 | batch_indices = torch.arange(batch) 119 | batch_indices = batch_indices[..., None] 120 | 121 | freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) 122 | freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) 123 | 124 | freqs_cos = freqs_cos[batch_indices, patch_indices_keep] 125 | freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j") 126 | freqs_sin = freqs_sin[batch_indices, patch_indices_keep] 127 | freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j") 128 | 129 | return t * freqs_cos + rotate_half(t) * freqs_sin 130 | 131 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin 132 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | 6 | import logging 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | try: 13 | import timm 14 | from timm.models.layers import Mlp, to_2tuple 15 | 16 | try: 17 | # old timm imports < 0.8.1 18 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 19 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 20 | except ImportError: 21 | # new timm imports >= 0.8.1 22 | from timm.layers import RotAttentionPool2d 23 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 24 | except ImportError: 25 | timm = None 26 | 27 | from .utils import freeze_batch_norm_2d 28 | 29 | 30 | class TimmModel(nn.Module): 31 | """timm model adapter 32 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 33 | """ 34 | 35 | def __init__(self, model_name, embed_dim, image_size=224, pool="avg", proj="linear", proj_bias=False, drop=0.0, pretrained=False): 36 | super().__init__() 37 | if timm is None: 38 | raise RuntimeError("Please `pip install timm` to use timm models.") 39 | 40 | self.image_size = to_2tuple(image_size) 41 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 42 | feat_size = self.trunk.default_cfg.get("pool_size", None) 43 | feature_ndim = 1 if not feat_size else 2 44 | if pool in ("abs_attn", "rot_attn"): 45 | assert feature_ndim == 2 46 | # if attn pooling used, remove both classifier and default pool 47 | self.trunk.reset_classifier(0, global_pool="") 48 | else: 49 | # reset global pool if pool config set, otherwise leave as network default 50 | reset_kwargs = dict(global_pool=pool) if pool else {} 51 | self.trunk.reset_classifier(0, **reset_kwargs) 52 | prev_chs = self.trunk.num_features 53 | 54 | head_layers = OrderedDict() 55 | if pool == "abs_attn": 56 | head_layers["pool"] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 57 | prev_chs 
= embed_dim 58 | elif pool == "rot_attn": 59 | head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 60 | prev_chs = embed_dim 61 | else: 62 | assert proj, "projection layer needed if non-attention pooling is used." 63 | 64 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 65 | if proj == "linear": 66 | head_layers["drop"] = nn.Dropout(drop) 67 | head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 68 | elif proj == "mlp": 69 | head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) 70 | 71 | self.head = nn.Sequential(head_layers) 72 | 73 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 74 | """lock modules 75 | Args: 76 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 77 | """ 78 | if not unlocked_groups: 79 | # lock full model 80 | for param in self.trunk.parameters(): 81 | param.requires_grad = False 82 | if freeze_bn_stats: 83 | freeze_batch_norm_2d(self.trunk) 84 | else: 85 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 86 | try: 87 | # FIXME import here until API stable and in an official release 88 | from timm.models.helpers import group_parameters, group_modules 89 | except ImportError: 90 | raise RuntimeError("Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`") 91 | matcher = self.trunk.group_matcher() 92 | gparams = group_parameters(self.trunk, matcher) 93 | max_layer_id = max(gparams.keys()) 94 | max_layer_id = max_layer_id - unlocked_groups 95 | for group_idx in range(max_layer_id + 1): 96 | group = gparams[group_idx] 97 | for param in group: 98 | self.trunk.get_parameter(param).requires_grad = False 99 | if freeze_bn_stats: 100 | gmodules = group_modules(self.trunk, matcher, reverse=True) 101 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 102 | freeze_batch_norm_2d(self.trunk, gmodules) 103 | 104 | @torch.jit.ignore 105 | def set_grad_checkpointing(self, enable=True): 106 | try: 107 | self.trunk.set_grad_checkpointing(enable) 108 | except Exception as e: 109 | logging.warning("grad checkpointing not supported for this timm image tower, continuing without...") 110 | 111 | def forward(self, x): 112 | x = self.trunk(x) 113 | x = self.head(x) 114 | return x 115 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms.functional as F 6 | 7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop 8 | 9 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 10 | 11 | 12 | class ResizeMaxSize(nn.Module): 13 | 14 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0): 15 | super().__init__() 16 | if not isinstance(max_size, int): 17 | raise TypeError(f"Size should be int. 
Got {type(max_size)}") 18 | self.max_size = max_size 19 | self.interpolation = interpolation 20 | self.fn = min if fn == "min" else min 21 | self.fill = fill 22 | 23 | def forward(self, img): 24 | if isinstance(img, torch.Tensor): 25 | height, width = img.shape[:2] 26 | else: 27 | width, height = img.size 28 | scale = self.max_size / float(max(height, width)) 29 | if scale != 1.0: 30 | new_size = tuple(round(dim * scale) for dim in (height, width)) 31 | img = F.resize(img, new_size, self.interpolation) 32 | pad_h = self.max_size - new_size[0] 33 | pad_w = self.max_size - new_size[1] 34 | img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill) 35 | return img 36 | 37 | 38 | def _convert_to_rgb(image): 39 | return image.convert("RGB") 40 | 41 | 42 | # class CatGen(nn.Module): 43 | # def __init__(self, num=4): 44 | # self.num = num 45 | # def mixgen_batch(image, text): 46 | # batch_size = image.shape[0] 47 | # index = np.random.permutation(batch_size) 48 | 49 | # cat_images = [] 50 | # for i in range(batch_size): 51 | # # image mixup 52 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] 53 | # # text concat 54 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] 55 | # text = torch.stack(text) 56 | # return image, text 57 | 58 | 59 | def image_transform( 60 | image_size: int, 61 | is_train: bool, 62 | mean: Optional[Tuple[float, ...]] = None, 63 | std: Optional[Tuple[float, ...]] = None, 64 | resize_longest_max: bool = False, 65 | fill_color: int = 0, 66 | ): 67 | mean = mean or OPENAI_DATASET_MEAN 68 | if not isinstance(mean, (list, tuple)): 69 | mean = (mean,) * 3 70 | 71 | std = std or OPENAI_DATASET_STD 72 | if not isinstance(std, (list, tuple)): 73 | std = (std,) * 3 74 | 75 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 76 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 77 | image_size = image_size[0] 78 | 79 | normalize = Normalize(mean=mean, std=std) 80 | if is_train: 81 | return Compose( 82 | [ 83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 84 | _convert_to_rgb, 85 | ToTensor(), 86 | normalize, 87 | ] 88 | ) 89 | else: 90 | if resize_longest_max: 91 | transforms = [ResizeMaxSize(image_size, fill=fill_color)] 92 | else: 93 | transforms = [ 94 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 95 | CenterCrop(image_size), 96 | ] 97 | transforms.extend( 98 | [ 99 | _convert_to_rgb, 100 | ToTensor(), 101 | normalize, 102 | ] 103 | ) 104 | return Compose(transforms) 105 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/dev_eva_clip/eva_vit.py: -------------------------------------------------------------------------------- 1 | # Based on EVA, BEIT, timm and DeiT code bases 2 | # https://github.com/baaivision/EVA 3 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 4 | # https://github.com/microsoft/unilm/tree/master/beit 5 | # https://github.com/facebookresearch/deit/ 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | # not tested yet 9 | import math 10 | from transformers import CLIPImageProcessor 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.utils.checkpoint as checkpoint 16 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 17 | from .eva_clip 
import create_model_and_transforms, get_model_config 18 | import torch 19 | import torchvision 20 | import time 21 | 22 | from llavaction.utils import rank0_print 23 | 24 | 25 | class EvaViTWrapper(nn.Module): 26 | def __init__(self, vision_tower, args, delay_load=False): 27 | super().__init__() 28 | 29 | self.is_loaded = False 30 | self.vision_tower_name = vision_tower 31 | self.pretrained = args.vision_tower_pretrained 32 | self.args = args 33 | 34 | self.select_layer = args.mm_vision_select_layer 35 | if self.select_layer < -1: 36 | self.select_layer += 1 37 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 38 | 39 | self.model_config = get_model_config(self.vision_tower_name) 40 | 41 | if not delay_load: 42 | rank0_print(f"Loading vision tower: {vision_tower}") 43 | self.load_model() 44 | elif getattr(args, "unfreeze_mm_vision_tower", False): 45 | # TODO: better detector is needed. 46 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 47 | self.load_model() 48 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 49 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 50 | self.load_model() 51 | 52 | def load_model(self): 53 | rank0_print(f"Loading: {self.vision_tower_name}") 54 | rank0_print(f"Pretrained: {self.pretrained}") 55 | time_start = time.time() 56 | model, _, image_processor = create_model_and_transforms(self.vision_tower_name, self.pretrained, force_custom_clip=True, precision="fp16") 57 | time_end = time.time() 58 | rank0_print(f"Loaded: {self.vision_tower_name} in {time_end - time_start:.2f}s") 59 | self.device = next(model.parameters()).device 60 | self.dtype = next(model.parameters()).dtype 61 | if self.device.type != "meta": 62 | model = model.to("cuda") 63 | self.vision_tower = model.visual 64 | resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0] 65 | normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0] 66 | self.resize_transform_size = resize_transform.size 67 | self.image_processor = CLIPImageProcessor.from_pretrained( 68 | "openai/clip-vit-large-patch14", 69 | crop_size=resize_transform.size, 70 | size={"shortest_edge": resize_transform.size}, 71 | image_mean=list(normalize_transform.mean), 72 | image_std=list(normalize_transform.std), 73 | ) 74 | rank0_print(f"Loaded image processor: {self.image_processor}") 75 | self.vision_tower.requires_grad_(False) 76 | self.is_loaded = True 77 | 78 | def feature_select(self, image_features): 79 | select_feature_type = self.select_feature 80 | 81 | # if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: 82 | # select_every_k_layer = len(image_features) // 4 83 | # image_features = torch.cat([image_features[i] for i in range(select_every_k_layer + self.select_layer, len(image_features), select_every_k_layer)], dim=-1) 84 | # select_feature_type = select_feature_type.replace("slicefour_", "") 85 | # elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: 86 | # select_layers = [-1, -4, -7, -10, 6] 87 | # image_features = torch.cat([image_features[i] for i in select_layers], dim=-1) 88 | # select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") 89 | # else: 90 | # image_features = image_features[self.select_layer] 91 | 92 | if select_feature_type == "patch": 93 | 
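# "patch" keeps only the patch tokens by dropping the leading CLS token.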
image_features = image_features[:, 1:] 94 | elif select_feature_type == "cls_patch": 95 | image_features = image_features 96 | else: 97 | raise ValueError(f"Unexpected select feature: {select_feature_type}") 98 | return image_features 99 | 100 | def train(self, mode=True): 101 | self.training = mode 102 | 103 | if self.is_loaded: 104 | self.vision_tower.eval() 105 | 106 | def forward(self, images): 107 | if type(images) is list: 108 | image_features = [] 109 | for image in images: 110 | image_feature = self.vision_tower.forward_features(image.to(self.dtype), return_all_features=True) 111 | image_feature = self.feature_select(image_feature).to(self.dtype) 112 | image_features.append(image_feature) 113 | else: 114 | image_features = self.vision_tower.forward_features(images.to(self.dtype), return_all_features=True) 115 | image_features = self.feature_select(image_features).to(self.dtype) 116 | 117 | return image_features 118 | 119 | @property 120 | def dummy_feature(self): 121 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 122 | 123 | @property 124 | def hidden_size(self): 125 | return self.model_config["vision_cfg"]["width"] 126 | 127 | @property 128 | def num_patches(self): 129 | return (self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]) ** 2 130 | 131 | @property 132 | def num_patches_per_side(self): 133 | return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"] 134 | 135 | @property 136 | def config(self): 137 | return self.model_config 138 | 139 | @property 140 | def image_size(self): 141 | return self.model_config["vision_cfg"]["image_size"] 142 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/eva_clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .eva_clip_processors import EvaClipImageTrainProcessor 5 | from .eva_vit import EVAEncoderWrapper 6 | from .factory import list_models, add_model_config, get_model_config 7 | 8 | from llavaction.utils import rank0_print 9 | 10 | 11 | class EvaClipVisionTower(nn.Module): 12 | def __init__(self, vision_tower, args, delay_load=False): 13 | super().__init__() 14 | 15 | self.is_loaded = False 16 | self.vision_tower_name = vision_tower 17 | self.vision_tower_pretrained = args.vision_tower_pretrained 18 | self.config = get_model_config(vision_tower) 19 | 20 | if not delay_load: 21 | rank0_print(f"Loading EVA ViT: {self.vision_tower_name}") 22 | self.load_model() 23 | elif getattr(args, "unfreeze_mm_vision_tower", False): 24 | # TODO: better detector is needed.
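# If the vision tower is marked as tunable/unfrozen, its weights are expected to be part of the checkpoint, so it is loaded eagerly even though delay_load=True.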
25 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 26 | self.load_model() 27 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 28 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 29 | self.load_model() 30 | else: 31 | self.cfg_only = self.config 32 | 33 | def load_model(self, device_map=None): 34 | rank0_print(f"Pretrained: {self.vision_tower_pretrained}") 35 | self.image_processor = EvaClipImageTrainProcessor(self.config["vision_cfg"]["image_size"]) 36 | self.vision_tower = EVAEncoderWrapper(self.vision_tower_pretrained, self.config) 37 | rank0_print(f"Loaded image processor: {self.image_processor}") 38 | self.vision_tower.requires_grad_(False) 39 | self.is_loaded = True 40 | 41 | def forward(self, images): 42 | if type(images) is list: 43 | image_features = [] 44 | for image in images: 45 | image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype) 49 | 50 | return image_features 51 | 52 | @property 53 | def dtype(self): 54 | return self.vision_tower.dtype 55 | 56 | @property 57 | def device(self): 58 | return self.vision_tower.device 59 | 60 | @property 61 | def hidden_size(self): 62 | return self.config["vision_cfg"]["width"] 63 | 64 | @property 65 | def num_patches(self): 66 | return (self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]) ** 2 67 | 68 | @property 69 | def num_patches_per_side(self): 70 | return self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"] 71 | 72 | @property 73 | def image_size(self): 74 | return self.config["vision_cfg"]["image_size"] 75 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/eva_clip_processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP 3 | """ 4 | 5 | from torchvision import transforms 6 | from torchvision.transforms.functional import InterpolationMode 7 | from transformers.image_processing_utils import BatchFeature 8 | from PIL import Image 9 | from transformers.image_transforms import convert_to_rgb 10 | 11 | 12 | class BaseProcessor: 13 | def __init__(self): 14 | self.transform = lambda x: x 15 | return 16 | 17 | def __call__(self, item): 18 | return self.transform(item) 19 | 20 | 21 | class EvaClipImageBaseProcessor(BaseProcessor): 22 | def __init__(self, mean=None, std=None): 23 | self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean 24 | self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std 25 | 26 | self.normalize = transforms.Normalize(self.mean, self.std) 27 | 28 | @property 29 | def image_mean(self): 30 | return self.mean 31 | 32 | 33 | class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor): 34 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): 35 | super().__init__(mean=mean, std=std) 36 | 37 | self.transform = transforms.Compose( 38 | [ 39 | convert_to_rgb, 40 | transforms.Resize( 41 | image_size, 42 | interpolation=InterpolationMode.BICUBIC, 43 | ), 44 | transforms.CenterCrop(image_size), 45 | transforms.ToTensor(), 46 | 
self.normalize, 47 | ] 48 | ) 49 | 50 | self.image_size = image_size 51 | 52 | def preprocess(self, images, return_tensors): 53 | if isinstance(images, Image.Image): 54 | images = [images] 55 | else: 56 | assert isinstance(images, list) 57 | 58 | transformed_images = [self.transform(image).numpy() for image in images] 59 | data = {"pixel_values": transformed_images} 60 | 61 | return BatchFeature(data=data, tensor_type=return_tensors) 62 | 63 | def __call__(self, item): 64 | return self.transform(item) 65 | 66 | @property 67 | def crop_size(self): 68 | return {"height": self.image_size, "width": self.image_size} 69 | 70 | @property 71 | def size(self): 72 | return {"shortest_edge": self.image_size} 73 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | from typing import Optional, Tuple, Union, Dict, Any 9 | import torch 10 | 11 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 12 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 13 | 14 | 15 | def _natural_key(string_): 16 | return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] 17 | 18 | 19 | def _rescan_model_configs(): 20 | global _MODEL_CONFIGS 21 | 22 | config_ext = (".json",) 23 | config_files = [] 24 | for config_path in _MODEL_CONFIG_PATHS: 25 | if config_path.is_file() and config_path.suffix in config_ext: 26 | config_files.append(config_path) 27 | elif config_path.is_dir(): 28 | for ext in config_ext: 29 | config_files.extend(config_path.glob(f"*{ext}")) 30 | 31 | for cf in config_files: 32 | with open(cf, "r", encoding="utf8") as f: 33 | model_cfg = json.load(f) 34 | if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")): 35 | _MODEL_CONFIGS[cf.stem] = model_cfg 36 | 37 | _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) 38 | 39 | 40 | _rescan_model_configs() # initial populate of model config registry 41 | 42 | 43 | def list_models(): 44 | """enumerate available model architectures based on config files""" 45 | return list(_MODEL_CONFIGS.keys()) 46 | 47 | 48 | def add_model_config(path): 49 | """add model config path or file and update registry""" 50 | if not isinstance(path, Path): 51 | path = Path(path) 52 | _MODEL_CONFIG_PATHS.append(path) 53 | _rescan_model_configs() 54 | 55 | 56 | def get_model_config(model_name): 57 | if model_name in _MODEL_CONFIGS: 58 | return deepcopy(_MODEL_CONFIGS[model_name]) 59 | else: 60 | return None 61 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1536, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 5120, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-18b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": true, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | 
"fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-plus-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 
77, 17 | "vocab_size": 49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- 
/llavaction/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/hf_vision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import AutoModel, AutoImageProcessor, AutoConfig, CLIPImageProcessor 5 | from llavaction.utils import rank0_print 6 | 7 | 8 | class HFVisionTower(nn.Module): 9 | def __init__(self, vision_tower, args, delay_load=False): 10 | super().__init__() 11 | 12 | self.is_loaded = False 13 | 14 | self.vision_tower_name = vision_tower.replace("hf:", "", 1) 15 | self.select_layer = args.mm_vision_select_layer 16 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 17 | 18 | if not delay_load: 19 | self.load_model() 20 | else: 21 | self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name) 22 | 23 | def load_model(self): 24 | try: 25 | self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name) 26 | except Exception as e: 27 | if "448" in self.vision_tower_name: 28 | image_size = 448 29 | # use image processor with conig 30 | self.image_processor = CLIPImageProcessor(size={"shortest_edge": image_size}, do_center_crop=True, crop_size=image_size) 31 | else: 32 | self.image_processor = 
CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") 33 | rank0_print(f"Loaded image processor: {self.image_processor}") 34 | self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda") 35 | self.device = self.vision_tower.device 36 | self.dtype = self.vision_tower.dtype 37 | self.config = self.vision_tower.config 38 | 39 | if hasattr(self.vision_tower, "vision_model"): 40 | self.vision_tower = self.vision_tower.vision_model 41 | self.vision_tower.requires_grad_(False) 42 | # self.vision_tower.eval() 43 | self.is_loaded = True 44 | 45 | def feature_select(self, image_forward_outs): 46 | select_feature_type = self.select_feature 47 | 48 | if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: 49 | select_every_k_layer = len(image_forward_outs.hidden_states) // 4 50 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) 51 | select_feature_type = select_feature_type.replace("slicefour_", "") 52 | else: 53 | image_features = image_forward_outs.hidden_states[self.select_layer] 54 | 55 | if select_feature_type == "patch": 56 | image_features = image_features[:, 1:] 57 | elif select_feature_type == "cls_patch": 58 | image_features = image_features 59 | else: 60 | raise ValueError(f"Unexpected select feature: {select_feature_type}") 61 | return image_features 62 | 63 | def forward(self, images): 64 | if type(images) is list: 65 | image_features = [] 66 | for image in images: 67 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 68 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 69 | image_features.append(image_feature) 70 | else: 71 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 72 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 73 | 74 | return image_features 75 | 76 | @property 77 | def dummy_feature(self): 78 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 79 | 80 | # @property 81 | # def dtype(self): 82 | # return self.vision_tower.dtype 83 | 84 | # @property 85 | # def device(self): 86 | # return self.vision_tower.device 87 | 88 | @property 89 | def hidden_size(self): 90 | try: 91 | _hidden_size = self.config.hidden_size 92 | except: 93 | _hidden_size = self.config.vision_config.hidden_size 94 | if "slicefour" in self.select_feature: 95 | _hidden_size *= 4 96 | return _hidden_size 97 | 98 | @property 99 | def num_patches(self): 100 | _num_patches = (self.config.image_size // self.config.patch_size) ** 2 101 | if "cls_patch" in self.select_feature: 102 | _num_patches += 1 103 | return _num_patches 104 | 105 | @property 106 | def num_patches_per_side(self): 107 | return self.config.image_size // self.config.patch_size 108 | 109 | @property 110 | def image_size(self): 111 | return self.config.image_size 112 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/imagebind.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPImageProcessor 5 | 6 | try: 7 | from imagebind.models import imagebind_model 8 | from imagebind.models.imagebind_model import 
ModalityType 9 | from imagebind.data import load_and_transform_audio_data 10 | except ImportError: 11 | pass 12 | 13 | 14 | class ImageBindWrapper(nn.Module): 15 | def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False): 16 | super().__init__() 17 | 18 | self.is_loaded = False 19 | 20 | self.vision_tower_name = vision_tower 21 | self.select_layer = select_layer 22 | self.select_feature = select_feature 23 | 24 | if not delay_load: 25 | self.load_model() 26 | 27 | def load_model(self): 28 | self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") 29 | self.vision_tower = imagebind_model.imagebind_huge(pretrained=True) 30 | for p in self.vision_tower.parameters(): 31 | p.requires_grad = False 32 | self.vision_tower.eval() 33 | self.is_loaded = True 34 | 35 | def train(self, mode=True): 36 | self.training = mode 37 | 38 | if self.is_loaded: 39 | self.vision_tower.eval() 40 | 41 | @torch.no_grad() 42 | def forward(self, x): 43 | if type(x) == dict: 44 | if x["audios"] is not None: 45 | inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()} 46 | embeddings = self.vision_tower(inputs) 47 | audio_embedding = embeddings[ModalityType.AUDIO] 48 | return audio_embedding.unsqueeze(1) 49 | else: 50 | inputs = {ModalityType.VISION: x.to(dtype=self.dtype)} 51 | embeddings = self.vision_tower(inputs) 52 | vision_embedding = embeddings[ModalityType.VISION] 53 | if vision_embedding.ndim == 2: 54 | return vision_embedding.unsqueeze(1) 55 | if vision_embedding.shape[1] == 257: 56 | return vision_embedding[:, 1:] 57 | raise ValueError(f"Unexpected shape: {vision_embedding.shape}") 58 | 59 | @property 60 | def dummy_feature(self): 61 | return torch.zeros(1, 1024, device=self.device, dtype=self.dtype) 62 | 63 | @property 64 | def dtype(self): 65 | return self.vision_tower.modality_preprocessors.vision.cls_token.dtype 66 | 67 | @property 68 | def device(self): 69 | return self.vision_tower.modality_preprocessors.vision.cls_token.device 70 | 71 | @property 72 | def hidden_size(self): 73 | return 1024 74 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_encoder/open_clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import CLIPImageProcessor 4 | from llavaction.utils import rank0_print 5 | 6 | try: 7 | import open_clip 8 | import torchvision 9 | from open_clip.transformer import _expand_token 10 | except ImportError: 11 | print("OpenCLIP not installed") 12 | open_clip = None 13 | 14 | HIDDEN_SIZE_DICT = { 15 | "ViT-H-14-378-quickgelu": 1280, 16 | } 17 | 18 | 19 | class OpenCLIPVisionTower(nn.Module): 20 | def __init__(self, vision_tower, args, delay_load=False): 21 | super().__init__() 22 | 23 | self.is_loaded = False 24 | self.model_name = vision_tower.replace("open_clip_hub:", "") 25 | self.pretrained = args.vision_tower_pretrained 26 | self.select_layer = args.mm_vision_select_layer 27 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 28 | 29 | if not delay_load: 30 | rank0_print(f"Loading vision tower: {vision_tower}") 31 | self.load_model() 32 | elif getattr(args, "unfreeze_mm_vision_tower", False): 33 | # TODO: better detector is needed. 
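# Same heuristic as the other vision towers: an unfrozen tower implies its weights are shipped in the checkpoint, so the delayed load is skipped.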
34 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 35 | self.load_model() 36 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 37 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 38 | self.load_model() 39 | 40 | def load_model(self, device_map="auto"): 41 | rank0_print(f"Loading OpenCLIP model: {self.model_name}") 42 | rank0_print(f"Pretrained: {self.pretrained}") 43 | vision_tower, _, image_processor = open_clip.create_model_and_transforms(model_name=self.model_name, pretrained=self.pretrained, precision="fp32", device="cuda") 44 | 45 | resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0] 46 | normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0] 47 | self.resize_transform_size = resize_transform.size # 224 or 384 48 | self.patch_size = vision_tower.visual.conv1.kernel_size[0] # 14 or 16 49 | 50 | self.image_processor = CLIPImageProcessor.from_pretrained( 51 | "openai/clip-vit-large-patch14", 52 | crop_size=resize_transform.size, 53 | size={"shortest_edge": resize_transform.size}, 54 | image_mean=list(normalize_transform.mean), 55 | image_std=list(normalize_transform.std), 56 | ) 57 | rank0_print(f"Loaded image processor: {self.image_processor}") 58 | self.vision_tower = vision_tower.visual 59 | self.vision_tower.requires_grad_(False) 60 | 61 | self.is_loaded = True 62 | 63 | def feature_select(self, image_forward_outs): 64 | image_features = image_forward_outs[self.select_layer] 65 | if self.select_feature == "patch": 66 | image_features = image_features[:, 1:] 67 | elif self.select_feature == "cls_patch": 68 | image_features = image_features 69 | elif self.select_feature == "conv_flatten": 70 | image_features = image_features.flatten(2).transpose(1, 2) 71 | else: 72 | raise ValueError(f"Unexpected select feature: {self.select_feature}") 73 | return image_features 74 | 75 | def forward_visual(self, x, output_hidden_states=False): 76 | if hasattr(self.vision_tower, "trunk") and hasattr(self.vision_tower.trunk, "_intermediate_layers"): 77 | return self.vision_tower.trunk._intermediate_layers(x, abs(self.select_layer)) 78 | else: 79 | 80 | def forward_openclip(self, x: torch.Tensor): 81 | features = [] 82 | x = self.conv1(x) # shape = [*, width, grid, grid] 83 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 84 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 85 | 86 | # class embeddings and positional embeddings 87 | x = torch.cat( 88 | [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], 89 | dim=1, 90 | ) 91 | # shape = [*, grid ** 2 + 1, width] 92 | x = x + self.positional_embedding.to(x.dtype) 93 | 94 | x = self.patch_dropout(x) 95 | x = self.ln_pre(x) 96 | 97 | x = x.permute(1, 0, 2) # NLD -> LND 98 | for r in self.transformer.resblocks: 99 | x = r(x, attn_mask=None) 100 | features.append(x) 101 | return features 102 | 103 | return forward_openclip(self.vision_tower, x) 104 | 105 | def forward(self, images): 106 | if type(images) is list: 107 | image_features = [] 108 | for image in images: 109 | image_forward_out = self.forward_visual(image.to(self.dtype).unsqueeze(0), output_hidden_states=True) 110 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 111 | image_features.append(image_feature) 112 | else: 113 | image_forward_outs = 
self.forward_visual(images.to(self.dtype), output_hidden_states=True) 114 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 115 | 116 | return image_features 117 | 118 | @property 119 | def dummy_feature(self): 120 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 121 | 122 | @property 123 | def dtype(self): 124 | if hasattr(self.vision_tower, "conv1"): 125 | return self.vision_tower.conv1.weight.dtype 126 | if hasattr(self.vision_tower, "trunk"): 127 | return self.vision_tower.trunk.patch_embed.proj.weight.dtype 128 | raise NotImplementedError 129 | 130 | @property 131 | def device(self): 132 | if hasattr(self.vision_tower, "conv1"): 133 | return self.vision_tower.conv1.weight.device 134 | if hasattr(self.vision_tower, "trunk"): 135 | return self.vision_tower.trunk.patch_embed.proj.weight.device 136 | raise NotImplementedError 137 | 138 | @property 139 | def config(self): 140 | return None 141 | 142 | @property 143 | def hidden_size(self): 144 | if self.model_name in HIDDEN_SIZE_DICT: 145 | return HIDDEN_SIZE_DICT[self.model_name] 146 | else: 147 | raise NotImplementedError 148 | 149 | @property 150 | def num_patches(self): 151 | image_size = self.resize_transform_size if isinstance(self.resize_transform_size, int) else self.resize_transform_size[0] 152 | _num_patches = (image_size // self.patch_size) ** 2 153 | if "cls_patch" in self.select_feature: 154 | _num_patches += 1 155 | return _num_patches 156 | 157 | @property 158 | def image_size(self): 159 | return self.resize_transform_size 160 | 161 | @property 162 | def num_patches_per_side(self): 163 | return self.resize_transform_size // self.patch_size 164 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | from .pooler_projector import PoolerProjector 6 | 7 | 8 | class IdentityMap(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, *args, **kwargs): 13 | return x 14 | 15 | @property 16 | def config(self): 17 | return {"mm_projector_type": "identity"} 18 | 19 | 20 | class SimpleResBlock(nn.Module): 21 | def __init__(self, channels): 22 | super().__init__() 23 | self.pre_norm = nn.LayerNorm(channels) 24 | 25 | self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)) 26 | 27 | def forward(self, x): 28 | x = self.pre_norm(x) 29 | return x + self.proj(x) 30 | 31 | 32 | def build_vision_projector(config, delay_load=False, **kwargs): 33 | projector_type = getattr(config, "mm_projector_type", "linear") 34 | 35 | if projector_type == "linear": 36 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 37 | 38 | if projector_type == "pooler": 39 | return PoolerProjector(config, kwargs["vision_cfg"]) 40 | 41 | mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) 42 | if mlp_gelu_match: 43 | mlp_depth = int(mlp_gelu_match.group(1)) 44 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 45 | for _ in range(1, mlp_depth): 46 | modules.append(nn.GELU()) 47 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 48 | return nn.Sequential(*modules) 49 | 50 | mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type) 51 | if mlp_gelu_resnet_match: 52 | mlp_depth = int(mlp_gelu_resnet_match.group(1)) 53 | res_depth = 
int(mlp_gelu_resnet_match.group(2)) 54 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 55 | for _ in range(1, mlp_depth): 56 | modules.append(nn.GELU()) 57 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 58 | for _ in range(res_depth): 59 | modules.append(SimpleResBlock(config.hidden_size)) 60 | return nn.Sequential(*modules) 61 | 62 | if projector_type == "identity": 63 | return IdentityMap() 64 | 65 | raise ValueError(f"Unknown projector type: {projector_type}") 66 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_projector/pooler_projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import math 5 | 6 | from transformers.models.clip.modeling_clip import CLIPVisionModel 7 | 8 | 9 | class PoolerProjector(nn.Module): 10 | def __init__(self, config, vision_cfg): 11 | super().__init__() 12 | self._config = config 13 | self.hw = vision_cfg.image_size // vision_cfg.patch_size 14 | 15 | self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2) 16 | 17 | self.proj = nn.Sequential( 18 | nn.GELU(), 19 | nn.Linear(config.hidden_size, config.hidden_size), 20 | ) 21 | 22 | def forward(self, x, *args, **kwargs): 23 | height = width = self.hw 24 | assert height * width == x.shape[1] 25 | x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) 26 | x = self.conv_pool(x) 27 | x = x.flatten(2).transpose(1, 2) 28 | x = self.proj(x) 29 | return x 30 | 31 | @property 32 | def config(self): 33 | return {"mm_projector_type": "pooler"} 34 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_resampler/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .masked_drop import MaskedDrop 4 | from .spatial_pool import SpatialPool 5 | from .perceiver import PerceiverResampler 6 | from .qformer import Qformer 7 | 8 | 9 | class IdentityMap(torch.nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, x, *args, **kwargs): 14 | return x 15 | 16 | @property 17 | def config(self): 18 | return {"mm_resampler_type": None} 19 | 20 | 21 | def build_vision_resampler(model_args, delay_load=False, **kwargs): 22 | resampler_type = getattr(model_args, "mm_resampler_type", None) 23 | if resampler_type == "masked_drop": 24 | return MaskedDrop(model_args) 25 | elif resampler_type == "spatial_pool": 26 | return SpatialPool(model_args, **kwargs) 27 | elif resampler_type == "perceiver": 28 | return PerceiverResampler(model_args, **kwargs) 29 | elif resampler_type == "qformer": 30 | return Qformer(model_args, **kwargs) 31 | elif resampler_type is None: 32 | return IdentityMap() 33 | 34 | raise ValueError(f"Unknown resampler type: {resampler_type}") 35 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_resampler/masked_drop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import random 5 | 6 | 7 | class MaskedDrop(nn.Module): 8 | def __init__(self, model_args): 9 | super().__init__() 10 | 11 | self.mode = model_args.mm_mask_drop_mode 12 | self.skip_percentage = model_args.mm_mask_drop_skip_percentage 13 | self.ratio = model_args.mm_mask_drop_ratio 14 | self.ratio_upper = model_args.mm_mask_drop_ratio_upper 
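# ratio_upper / ratio_lower bound the randomly sampled keep ratio when mode == "range"; "fixed" mode uses self.ratio directly.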
15 | self.ratio_lower = model_args.mm_mask_drop_ratio_lower 16 | 17 | def forward(self, image_features, *args, **kwargs): 18 | 19 | if not self.training: 20 | return image_features 21 | 22 | if self.skip_percentage > random.random(): 23 | return image_features 24 | 25 | masked_features = [] 26 | 27 | for image_feature in image_features: 28 | num_tokens = image_feature.shape[0] 29 | if self.mode == "fixed": 30 | num_keep = int(num_tokens * self.ratio) 31 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0]) 32 | elif self.mode == "range": 33 | num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper)) 34 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0]) 35 | elif self.mode == "cls_only": 36 | masked_features.append(image_feature[0:1]) 37 | else: 38 | raise ValueError(f"Unexpected masked drop mode: {self.mode}") 39 | 40 | if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]): 41 | masked_features = torch.stack(masked_features, dim=0) 42 | 43 | return masked_features 44 | 45 | @property 46 | def config(self): 47 | return { 48 | "mm_resampler_type": "masked_drop", 49 | "mm_mask_drop_mode": self.mode, 50 | "mm_mask_drop_skip_percentage": self.skip_percentage, 51 | "mm_mask_drop_ratio": self.ratio, 52 | "mm_mask_drop_ratio_upper": self.ratio_upper, 53 | "mm_mask_drop_ratio_lower": self.ratio_lower, 54 | } 55 | 56 | def random_masking(self, x, len_keep): 57 | """ 58 | Perform per-sample random masking by per-sample shuffling. 59 | Per-sample shuffling is done by argsort random noise. 60 | x: [N, L, D], sequence 61 | """ 62 | N, L, D = x.shape # batch, length, dim 63 | 64 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 65 | 66 | # sort noise for each sample 67 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 68 | ids_restore = torch.argsort(ids_shuffle, dim=1) 69 | 70 | # keep the first subset 71 | ids_keep = ids_shuffle[:, :len_keep] 72 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 73 | 74 | # generate the binary mask: 0 is keep, 1 is remove 75 | mask = torch.ones([N, L], device=x.device) 76 | mask[:, :len_keep] = 0 77 | # unshuffle to get the binary mask 78 | mask = torch.gather(mask, dim=1, index=ids_restore) 79 | 80 | return x_masked, mask, ids_restore 81 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_resampler/perceiver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from https://github.com/lucidrains/flamingo-pytorch 3 | """ 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | 8 | try: 9 | from einops_exts import rearrange_many 10 | except: 11 | pass 12 | 13 | from torch import einsum, nn 14 | 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | 20 | def FeedForward(dim, mult=4): 21 | inner_dim = int(dim * mult) 22 | return nn.Sequential( 23 | nn.LayerNorm(dim), 24 | nn.Linear(dim, inner_dim, bias=False), 25 | nn.GELU(), 26 | nn.Linear(inner_dim, dim, bias=False), 27 | ) 28 | 29 | 30 | class PerceiverAttention(nn.Module): 31 | def __init__(self, *, dim, dim_head=64, heads=8): 32 | super().__init__() 33 | self.scale = dim_head**-0.5 34 | self.heads = heads 35 | inner_dim = dim_head * heads 36 | 37 | self.norm_media = nn.LayerNorm(dim) 38 | self.norm_latents = nn.LayerNorm(dim) 39 | 40 | self.to_q = nn.Linear(dim, inner_dim, 
bias=False) 41 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 42 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 43 | 44 | def forward(self, x, latents): 45 | """ 46 | Args: 47 | x (torch.Tensor): image features 48 | shape (b, T, n1, D) 49 | latent (torch.Tensor): latent features 50 | shape (b, T, n2, D) 51 | """ 52 | x = self.norm_media(x) 53 | latents = self.norm_latents(latents) 54 | 55 | h = self.heads 56 | 57 | q = self.to_q(latents) 58 | kv_input = torch.cat((x, latents), dim=-2) 59 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 60 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 61 | q = q * self.scale 62 | 63 | # attention 64 | sim = einsum("... i d, ... j d -> ... i j", q, k) 65 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 66 | attn = sim.softmax(dim=-1) 67 | 68 | out = einsum("... i j, ... j d -> ... i d", attn, v) 69 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 70 | return self.to_out(out) 71 | 72 | 73 | class PerceiverResamplerModule(nn.Module): 74 | def __init__( 75 | self, 76 | *, 77 | dim, 78 | depth=6, 79 | dim_head=64, 80 | heads=8, 81 | num_latents=64, 82 | max_num_media=None, 83 | max_num_frames=None, 84 | ff_mult=4, 85 | ): 86 | super().__init__() 87 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 88 | self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None 89 | self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None 90 | 91 | self.layers = nn.ModuleList([]) 92 | for _ in range(depth): 93 | self.layers.append( 94 | nn.ModuleList( 95 | [ 96 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 97 | FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(), 98 | ] 99 | ) 100 | ) 101 | 102 | self.norm = nn.LayerNorm(dim) 103 | 104 | def forward(self, x): 105 | """ 106 | Args: 107 | x (torch.Tensor): image features 108 | shape (b, T, F, v, D) 109 | Returns: 110 | shape (b, T, n, D) where n is self.num_latents 111 | """ 112 | b, T, F, v = x.shape[:4] 113 | 114 | # frame and media time embeddings 115 | if exists(self.frame_embs): 116 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 117 | x = x + frame_embs 118 | x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions 119 | if exists(self.media_time_embs): 120 | x = x + self.media_time_embs[:T] 121 | 122 | # blocks 123 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 124 | for attn, ff in self.layers: 125 | latents = attn(x, latents) + latents 126 | latents = ff(latents) + latents 127 | return self.norm(latents) 128 | 129 | 130 | class PerceiverResampler(nn.Module): 131 | def __init__(self, model_args, vision_tower): 132 | super().__init__() 133 | 134 | self.depth = model_args.mm_perceiver_depth 135 | self.num_latents = model_args.mm_perceiver_latents 136 | self.ff_mult = model_args.mm_perceiver_ff_mult 137 | self.pretrained = model_args.mm_perceiver_pretrained 138 | 139 | self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult) 140 | 141 | if self.pretrained is not None: 142 | self.load_state_dict(torch.load(self.pretrained)) 143 | 144 | def forward(self, image_features, *args, **kwargs): 145 | return self.perceiver(image_features[:, None, None]).squeeze(1) 146 | 147 | @property 148 | def config(self): 149 | return { 150 | "mm_resampler_type": "perceiver", 151 | "mm_perceiver_depth": 
self.depth, 152 | "mm_perceiver_latents": self.num_latents, 153 | "mm_perceiver_ff_mult": self.ff_mult, 154 | "mm_perceiver_pretrained": self.pretrained, 155 | } 156 | -------------------------------------------------------------------------------- /llavaction/model/multimodal_resampler/spatial_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class SpatialPool(nn.Module): 7 | def __init__(self, model_args, vision_tower): 8 | super().__init__() 9 | 10 | self.mode = model_args.mm_spatial_pool_mode 11 | self.stride = model_args.mm_spatial_pool_stride 12 | self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size) 13 | 14 | if self.mode == "average": 15 | self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride) 16 | elif self.mode == "max": 17 | self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride) 18 | elif self.mode == "conv": 19 | self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride) 20 | else: 21 | raise ValueError(f"Unknown pooling mode: {self.mode}.") 22 | 23 | def forward(self, image_features, images, *args, **kwargs): 24 | ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2]))  # width of the vision-token grid, recovered from the image aspect ratio 25 | ori_H = int(ori_W * images.shape[2] // images.shape[3])  # corresponding height of the vision-token grid 26 | 27 | B, _, F = image_features.shape 28 | 29 | image_features_spatial = image_features.view(B, ori_H, ori_W, F).permute(0, 3, 1, 2)  # reshape tokens to (B, F, H, W) for 2D pooling 30 | image_features_spatial_pool = self.pool(image_features_spatial) 31 | 32 | return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() 33 | 34 | @property 35 | def config(self): 36 | return { 37 | "mm_resampler_type": "spatial_pool", 38 | "mm_spatial_pool_stride": self.stride, 39 | "mm_spatial_pool_mode": self.mode, 40 | "mm_spatial_pool_out_channels": self.out_channels, 41 | } 42 | 43 | @property 44 | def hidden_size(self): 45 | return self.out_channels 46 | -------------------------------------------------------------------------------- /llavaction/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if "llava" in config and "llava" not in cfg.model_type: 7 | assert cfg.model_type == "llama" 8 | print("You are using the newer LLaVA code base, while the checkpoint of v0 is from the older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = "LlavaLlamaForCausalLM" 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llavaction/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/LLaVAction/bd9f46583c5a94333575eccdf92e1b2bd118f7bd/llavaction/serve/__init__.py -------------------------------------------------------------------------------- /llavaction/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llavaction.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llavaction.conversation import conv_templates, SeparatorStyle 6 | from llavaction.model.builder import load_pretrained_model 7 | from llavaction.utils import disable_torch_init 8 | from llavaction.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith("http") or image_file.startswith("https"): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert("RGB") 22 | else: 23 | image = Image.open(image_file).convert("RGB") 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit) 33 | 34 | if "llama-2" in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "v1" in model_name.lower(): 37 | conv_mode = "llava_v1" 38 | elif "mpt" in model_name.lower(): 39 | conv_mode = "mpt" 40 | else: 41 | conv_mode = "llava_v0" 42 | 43 | if args.conv_mode is not None and conv_mode != args.conv_mode: 44 | print("[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(conv_mode, args.conv_mode, args.conv_mode)) 45 | else: 46 | args.conv_mode = conv_mode 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | if "mpt" in model_name.lower(): 50 | roles = ("user", "assistant") 51 | else: 52 | roles = conv.roles 53 | 54 | image = load_image(args.image_file) 55 | image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().cuda() 56 | 57 | while True: 58 | try: 59 | inp = input(f"{roles[0]}: ") 60 | except EOFError: 61 | inp = "" 62 | if not inp: 63 | print("exit...") 64 | break 65 | 66 | print(f"{roles[1]}: ", end="") 67 | 68 | if image is not None: 69 | # first message 70 | if model.config.mm_use_im_start_end: 71 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp 72 | else: 73 | inp = DEFAULT_IMAGE_TOKEN + "\n" + inp 74 | conv.append_message(conv.roles[0], inp) 75 | image = None 76 | else: 77 | # later messages 78 | conv.append_message(conv.roles[0], inp) 79 | conv.append_message(conv.roles[1], None) 80 | 
prompt = conv.get_prompt() 81 | 82 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda() 83 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 84 | keywords = [stop_str] 85 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 86 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 87 | 88 | with torch.inference_mode(): 89 | output_ids = model.generate(input_ids, images=image_tensor, do_sample=True if args.temperature > 0 else False, temperature=args.temperature, max_new_tokens=args.max_new_tokens, streamer=streamer, use_cache=True, stopping_criteria=[stopping_criteria])  # use the CLI-provided sampling settings instead of hard-coded values 90 | 91 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip() 92 | conv.messages[-1][-1] = outputs 93 | 94 | if args.debug: 95 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 101 | parser.add_argument("--model-base", type=str, default=None) 102 | parser.add_argument("--image-file", type=str, required=True) 103 | parser.add_argument("--num-gpus", type=int, default=1) 104 | parser.add_argument("--conv-mode", type=str, default=None) 105 | parser.add_argument("--temperature", type=float, default=0.2) 106 | parser.add_argument("--max-new-tokens", type=int, default=512) 107 | parser.add_argument("--load-8bit", action="store_true") 108 | parser.add_argument("--load-4bit", action="store_true") 109 | parser.add_argument("--debug", action="store_true") 110 | args = parser.parse_args() 111 | main(args) 112 | -------------------------------------------------------------------------------- /llavaction/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/LLaVAction/bd9f46583c5a94333575eccdf92e1b2bd118f7bd/llavaction/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llavaction/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/LLaVAction/bd9f46583c5a94333575eccdf92e1b2bd118f7bd/llavaction/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llavaction/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers.
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llavaction/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llavaction.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", json={"model": args.model_name}) 21 | worker_addr = ret.json()["address"] 22 | print(f"worker_addr: {worker_addr}") 23 | 24 | if worker_addr == "": 25 | return 26 | 27 | conv = default_conversation.copy() 28 | conv.append_message(conv.roles[0], args.message) 29 | prompt = conv.get_prompt() 30 | 31 | headers = {"User-Agent": "LLaVA Client"} 32 | pload = { 33 | "model": args.model_name, 34 | "prompt": prompt, 35 | "max_new_tokens": args.max_new_tokens, 36 | "temperature": 0.7, 37 | "stop": conv.sep, 38 | } 39 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True) 40 | 41 | print(prompt.replace(conv.sep, "\n"), end="") 42 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 43 | if chunk: 44 | data = json.loads(chunk.decode("utf-8")) 45 | output = data["text"].split(conv.sep)[-1] 46 | print(output, end="\r") 47 | print("") 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 53 | parser.add_argument("--worker-address", type=str) 54 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 55 | parser.add_argument("--max-new-tokens", type=int, default=32) 56 | parser.add_argument("--message", type=str, default="Tell me a story with more than 1000 words.") 57 | args = parser.parse_args() 58 | 59 | main() 60 | -------------------------------------------------------------------------------- /llavaction/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import 
flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | padding_mask: Optional[torch.Tensor] = None, 25 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 26 | if output_attentions: 27 | warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.") 28 | 29 | bsz, q_len, _ = hidden_states.size() 30 | 31 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 32 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 33 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # shape: (b, num_heads, s, head_dim) 34 | 35 | kv_seq_len = key_states.shape[-2] 36 | if past_key_value is not None: 37 | kv_seq_len += past_key_value[0].shape[-2] 38 | 39 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 40 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 41 | 42 | if past_key_value is not None: 43 | # reuse k, v 44 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 45 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 46 | 47 | past_key_value = (key_states, value_states) if use_cache else None 48 | 49 | # repeat k/v heads if n_kv_heads < n_heads 50 | key_states = repeat_kv(key_states, self.num_key_value_groups) 51 | value_states = repeat_kv(value_states, self.num_key_value_groups) 52 | 53 | # Transform the data into the format required by flash attention 54 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 55 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 56 | key_padding_mask = attention_mask 57 | 58 | if key_padding_mask is None: 59 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 60 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) 61 | max_s = q_len 62 | output = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 63 | output = output.view(bsz, q_len, -1) 64 | else: 65 | qkv = qkv.reshape(bsz, q_len, -1) 66 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 67 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 68 | output_unpad = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 69 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 70 | output = pad_input(output_unpad, indices, bsz, q_len) 71 | 72 | return self.o_proj(output), None, past_key_value 73 | 74 | 75 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 76 | # requires the attention mask to be the same as the key_padding_mask 77 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 78 | # [bsz, seq_len] 79 | return attention_mask 80 | 81 | 82 | def replace_llama_attn_with_flash_attn(): 83 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 84 | if cuda_major < 8: 85 | 
warnings.warn("Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593") 86 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 87 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 88 | -------------------------------------------------------------------------------- /llavaction/train/llava_trainer_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | 4 | from llavaction.train.llava_trainer import LLaVATrainer 5 | 6 | 7 | class LLaVAEvalTrainer(LLaVATrainer): 8 | def evaluate(self, evaluate_args): 9 | cmd = f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval \ 10 | --model {evaluate_args.model} \ 11 | --model_args {evaluate_args.model_args} \ 12 | --tasks {evaluate_args.task_names} \ 13 | --batch_size {evaluate_args.batch_size} \ 14 | --log_samples_suffix {evaluate_args.log_samples_suffix} \ 15 | --output_path {evaluate_args.output_path}" 16 | if evaluate_args.limit: 17 | cmd += f" --limit {evaluate_args.limit}" 18 | if evaluate_args.num_fewshot: 19 | cmd += f" --num_fewshot {evaluate_args.num_fewshot}" 20 | if evaluate_args.gen_kwargs != "": 21 | cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}" 22 | if evaluate_args.log_samples: 23 | cmd += f" --log_samples" 24 | else: 25 | assert False, "Please log samples so that the result can be parsed" 26 | results = subprocess.run([cmd], shell=True, capture_output=True, text=True) 27 | try: 28 | result_file_index_start = results.stdout.index("Saved samples to ") 29 | result_file_index_end = results.stdout.index(f".json") 30 | result_file_index_start += len("Saved samples to ") 31 | file = results.stdout[result_file_index_start:result_file_index_end] 32 | except: 33 | result_file_index_start = results.stderr.index("Saved samples to ") 34 | result_file_index_end = results.stderr.index(f".json") 35 | result_file_index_start += len("Saved samples to ") 36 | file = results.stderr[result_file_index_start:result_file_index_end] 37 | file = file.split("/")[:-1] 38 | file = "/".join(file) + "/results.json" 39 | with open(file, "r") as f: 40 | lmms_eval_results = json.load(f) 41 | result_dict = {} 42 | tasks_list = evaluate_args.task_names.split(",") 43 | for task in tasks_list: 44 | task_results = lmms_eval_results["results"][task] 45 | for k, v in task_results.items(): 46 | if k != "alias" and "stderr" not in k: 47 | metric = k.split(",")[0] 48 | result_dict[f"{task}_{metric}"] = v 49 | return result_dict 50 | 51 | """def evaluate(self, evaluate_args): 52 | initialize_tasks() 53 | tasks_list = evaluate_args.task_names.split(",") 54 | result_dict = {} 55 | results = evaluator.simple_evaluate( 56 | model=evaluate_args.model, 57 | model_args=evaluate_args.model_args, 58 | tasks=tasks_list, 59 | num_fewshot=evaluate_args.num_fewshot, 60 | batch_size=evaluate_args.batch_size, 61 | device=evaluate_args.device, 62 | limit=evaluate_args.limit, 63 | check_integrity=evaluate_args.check_integrity, 64 | show_task_to_terminal=evaluate_args.show_task_to_terminal, 65 | log_samples=evaluate_args.log_samples, 66 | gen_kwargs=evaluate_args.gen_kwargs, 67 | cli_args=evaluate_args, 68 | ) 69 | for task in tasks_list: 70 | task_results = results["results"][task] 71 | for k,v in task_results.items(): 72 | if k != "alias" and "stderr" not in k: 
73 | metric = k.split(",")[0] 74 | result_dict[f"{task}_{metric}"] = v 75 | 76 | return result_dict""" 77 | -------------------------------------------------------------------------------- /llavaction/train/train_mem.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0])) 4 | # sys.path.append(os.path.dirname(sys.path[0])) 5 | 6 | from llavaction.train.train import train 7 | 8 | if __name__ == "__main__": 9 | train() 10 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 240 3 | 4 | [build-system] 5 | requires = ["setuptools>=61.0"] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [project] 9 | name = "llavaction" 10 | version = "0.0.1" 11 | description = "LLaVAction: Evaluating and Training Multi-Modal Large Language Models for Action Recognition" 12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: Apache Software License", 17 | ] 18 | 19 | [project.optional-dependencies] 20 | standalone = [ 21 | "shortuuid", 22 | "httpx==0.24.0", 23 | "einops", 24 | "ftfy", 25 | ] 26 | 27 | 28 | train = [ 29 | "llavaction[standalone]", 30 | "open_clip_torch", 31 | "fastapi", 32 | "markdown2[all]", 33 | "numpy", 34 | "requests", 35 | "sentencepiece", 36 | "uvicorn", 37 | "wandb", 38 | "deepspeed==0.14.4", 39 | "peft==0.4.0", 40 | "bitsandbytes==0.41.0", 41 | "einops==0.6.1", 42 | "einops-exts==0.0.4", 43 | "gradio_client==0.2.9", 44 | "urllib3<=2.0.0", 45 | "pydantic==1.10.8", 46 | "hf_transfer", 47 | "opencv-python", 48 | "av", 49 | "decord", 50 | "tyro", 51 | "scipy", 52 | ] 53 | 54 | 55 | [tool.setuptools.packages.find] 56 | include = ["llavaction*"] 57 | exclude = [ 58 | "assets*", 59 | "benchmark*", 60 | "docs", 61 | "dist*", 62 | "playground*", 63 | "scripts*", 64 | "tests*", 65 | "checkpoints*", 66 | "project_checkpoints*", 67 | "debug_checkpoints*", 68 | "mlx_configs*", 69 | "wandb*", 70 | "notebooks*", 71 | ] 72 | 73 | [tool.wheel] 74 | exclude = [ 75 | "assets*", 76 | "benchmark*", 77 | "docs", 78 | "dist*", 79 | "playground*", 80 | "scripts*", 81 | "tests*", 82 | "checkpoints*", 83 | "project_checkpoints*", 84 | "debug_checkpoints*", 85 | "mlx_configs*", 86 | "wandb*", 87 | "notebooks*", 88 | ] 89 | -------------------------------------------------------------------------------- /resinstall.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Re-install the package. By running './reinstall.sh' 4 | # 5 | # Note that llavaction uses the build 6 | # system specified in 7 | # PEP517 https://peps.python.org/pep-0517/ and 8 | # PEP518 https://peps.python.org/pep-0518/ 9 | # and hence there is no setup.py file. 10 | 11 | set -e # abort on error 12 | 13 | pip uninstall -y llavaction 14 | 15 | # Get version 16 | VERSION=0.0.1 17 | echo "Upgrading to LLaVAction v${VERSION}" 18 | 19 | # Upgrade the build system (PEP517/518 compatible) 20 | python3 -m pip install virtualenv 21 | python3 -m pip install --upgrade build 22 | python3 -m build --sdist --wheel . 
23 | 24 | # Reinstall the package with most recent version 25 | pip install --upgrade --no-cache-dir "dist/llavaction-${VERSION}-py3-none-any.whl" 26 | -------------------------------------------------------------------------------- /scripts/qwen.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import torch 3 | 4 | device = "cuda" # the device to load the model onto 5 | 6 | model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B-Chat", torch_dtype=torch.bfloat16, device_map="auto") 7 | tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B-Chat") 8 | 9 | prompt = "Give me a short introduction to large language model." 10 | messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] 11 | text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 12 | model_inputs = tokenizer([text], return_tensors="pt").to(device) 13 | 14 | generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512) 15 | generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)] 16 | 17 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 18 | 19 | print(response) 20 | -------------------------------------------------------------------------------- /scripts/train/avion_tim_top5_gpt4o_detection_direct.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | - json_path: /anonymous/VFM/llava_data/EK100_inst_train/avion_mc_top5_GT_random_narration/train_convs_narration_actionids.jsonl 3 | sampling_strategy: all 4 | - json_path: /anonymous/VFM/llava_data/EK100_inst_train/cross_validation/tim_mc_top5_GT_random_narration/train_convs_narration_actionids.jsonl 5 | sampling_strategy: all 6 | - json_path: /anonymous/VFM/llava_data/first_person_annos/train_anno_gpt-gt-reason_4_first_person_all_action_idx.jsonl 7 | sampling_strategy: all 8 | - json_path: /anonymous/VFM/llava_data/first_person_annos/train_anno_gpt-gt-instruct-reason_4_first_person_all_action_idx.jsonl 9 | sampling_strategy: all 10 | - json_path: /anonymous/VFM/llava_data/EK100_inst_train/temporal_detection.jsonl 11 | sampling_strategy: all 12 | - json_path: /anonymous/VFM/llava_data/EK100_inst_train/direct_narration_GT_random_narration/train_convs_narration_actionids.jsonl 13 | sampling_strategy: all 14 | -------------------------------------------------------------------------------- /scripts/train/avion_tim_top5_gpt4o_detection_direct_178K_100percent.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | - json_path: /anonymous/VFM/llava_data/EK100_inst_train/avion_mc_top5_GT_random_narration/train_convs_narration_actionids.jsonl 3 | sampling_strategy: all 4 | - json_path: /anonymous/VFM/llava_data/EK100_inst_train/cross_validation/tim_mc_top5_GT_random_narration/train_convs_narration_actionids.jsonl 5 | sampling_strategy: all 6 | - json_path: /anonymous/VFM/llava_data/first_person_annos/train_anno_gpt-gt-reason_4_first_person_all_action_idx.jsonl 7 | sampling_strategy: all 8 | - json_path: /anonymous/VFM/llava_data/first_person_annos/train_anno_gpt-gt-instruct-reason_4_first_person_all_action_idx.jsonl 9 | sampling_strategy: all 10 | - json_path: /anonymous/VFM/llava_data/EK100_inst_train/temporal_detection.jsonl 11 | sampling_strategy: all 12 | - json_path: 
/anonymous/VFM/llava_data/EK100_inst_train/direct_narration_GT_random_narration/train_convs_narration_actionids.jsonl 13 | sampling_strategy: all 14 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/0_30_s_academic_v0_1/0_30_s_academic_v0_1_cap_processed.json 15 | sampling_strategy: all 16 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/0_30_s_youtube_v0_1/0_30_s_youtube_v0_1_cap_processed.json 17 | sampling_strategy: all 18 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/30_60_s_academic_v0_1/30_60_s_academic_v0_1_cap_processed.json 19 | sampling_strategy: all 20 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/30_60_s_youtube_v0_1/30_60_s_youtube_v0_1_cap_processed.json 21 | sampling_strategy: all 22 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/1_2_m_academic_v0_1/1_2_m_academic_v0_1_cap_processed.json 23 | sampling_strategy: all 24 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/1_2_m_youtube_v0_1/1_2_m_youtube_v0_1_cap_processed.json 25 | sampling_strategy: all 26 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/0_30_s_academic_v0_1/0_30_s_academic_oe_v0_1_qa_processed.json 27 | sampling_strategy: all 28 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/0_30_s_academic_v0_1/0_30_s_academic_oe_v0_1_qa_processed.json 29 | sampling_strategy: all 30 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/0_30_s_youtube_v0_1/0_30_s_youtube_oe_v0_1_qa_processed.json 31 | sampling_strategy: all 32 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/0_30_s_youtube_v0_1/0_30_s_youtube_mc_v0_1_qa_processed.json 33 | sampling_strategy: all 34 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/0_30_s_activitynetqa/0_30_s_activitynetqa_oe_qa_processed.json 35 | sampling_strategy: all 36 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/0_30_s_nextqa/0_30_s_nextqa_oe_qa_processed.json 37 | sampling_strategy: all 38 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/0_30_s_nextqa/0_30_s_nextqa_mc_qa_processed.json 39 | sampling_strategy: all 40 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/0_30_s_perceptiontest/0_30_s_perceptiontest_mc_qa_processed.json 41 | sampling_strategy: all 42 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/30_60_s_academic_v0_1/30_60_s_academic_oe_v0_1_qa_processed.json 43 | sampling_strategy: all 44 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/30_60_s_academic_v0_1/30_60_s_academic_mc_v0_1_qa_processed.json 45 | sampling_strategy: all 46 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/30_60_s_youtube_v0_1/30_60_s_youtube_oe_v0_1_qa_processed.json 47 | sampling_strategy: all 48 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/30_60_s_youtube_v0_1/30_60_s_youtube_mc_v0_1_qa_processed.json 49 | sampling_strategy: all 50 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/30_60_s_activitynetqa/30_60_s_activitynetqa_oe_qa_processed.json 51 | sampling_strategy: all 52 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/30_60_s_nextqa/30_60_s_nextqa_oe_qa_processed.json 53 | sampling_strategy: all 54 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/30_60_s_nextqa/30_60_s_nextqa_mc_qa_processed.json 55 | sampling_strategy: all 56 | - json_path: 
/anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/30_60_s_perceptiontest/30_60_s_perceptiontest_mc_qa_processed.json 57 | sampling_strategy: all 58 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/1_2_m_academic_v0_1/1_2_m_academic_oe_v0_1_qa_processed.json 59 | sampling_strategy: all 60 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/1_2_m_academic_v0_1/1_2_m_academic_mc_v0_1_qa_processed.json 61 | sampling_strategy: all 62 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/1_2_m_youtube_v0_1/1_2_m_youtube_oe_v0_1_qa_processed.json 63 | sampling_strategy: all 64 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/1_2_m_youtube_v0_1/1_2_m_youtube_mc_v0_1_qa_processed.json 65 | sampling_strategy: all 66 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/1_2_m_activitynetqa/1_2_m_activitynetqa_oe_qa_processed.json 67 | sampling_strategy: all 68 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/1_2_m_nextqa/1_2_m_nextqa_oe_qa_processed.json 69 | sampling_strategy: all 70 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/1_2_m_nextqa/1_2_m_nextqa_mc_qa_processed.json 71 | sampling_strategy: all 72 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/2_3_m_academic_v0_1/2_3_m_academic_v0_1_cap_processed.json 73 | sampling_strategy: all 74 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/2_3_m_youtube_v0_1/2_3_m_youtube_v0_1_cap_processed.json 75 | sampling_strategy: all 76 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/2_3_m_academic_v0_1/2_3_m_academic_oe_v0_1_qa_processed.json 77 | sampling_strategy: all 78 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/2_3_m_academic_v0_1/2_3_m_academic_mc_v0_1_qa_processed.json 79 | sampling_strategy: all 80 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/2_3_m_youtube_v0_1/2_3_m_youtube_oe_v0_1_qa_processed.json 81 | sampling_strategy: all 82 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/2_3_m_youtube_v0_1/2_3_m_youtube_mc_v0_1_qa_processed.json 83 | sampling_strategy: all 84 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/2_3_m_nextqa/2_3_m_nextqa_oe_qa_processed.json 85 | sampling_strategy: all 86 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/2_3_m_nextqa/2_3_m_nextqa_mc_qa_processed.json 87 | sampling_strategy: all 88 | - json_path: /anonymous/VFM/onevision/llava_video/LLaVA-Video-178K/2_3_m_activitynetqa/2_3_m_activitynetqa_oe_qa_processed.json 89 | sampling_strategy: all -------------------------------------------------------------------------------- /scripts/train/tim_top20_official_key_gpt4o_direct_detection.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | - json_path: /anonymous/VFM/llava_data/EK100_inst_train/cross_validation/tim_mc_top20_official_key/train_convs_narration_actionids.jsonl 3 | sampling_strategy: all 4 | - json_path: /anonymous/VFM/llava_data/first_person_annos/train_anno_gpt-gt-reason_4_first_person_all_action_idx.jsonl 5 | sampling_strategy: all 6 | - json_path: /anonymous/VFM/llava_data/first_person_annos/train_anno_gpt-gt-instruct-reason_4_first_person_all_action_idx.jsonl 7 | sampling_strategy: all 8 | - json_path: /anonymous/VFM/llava_data/EK100_inst_train/temporal_detection.jsonl 9 | sampling_strategy: all 10 | - json_path: 
/anonymous/VFM/llava_data/EK100_inst_train/direct_narration_official_key/train_convs_narration_actionids.jsonl 11 | sampling_strategy: all -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "none", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": false, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/zero2_fused_adamw.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "none", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": true, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/zero2_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "offload_optimizer": { 19 | "device": "cpu", 20 | "pin_memory": true 21 | }, 22 | "offload_param": { 23 | "device": "cpu", 24 | "pin_memory": true 25 | }, 26 | "overlap_comm": true, 27 | "contiguous_gradients": true, 28 | "sub_group_size": 1e9, 29 | "reduce_bucket_size": "auto" 30 | } 31 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | 
"loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "none", 18 | "pin_memory": true 19 | }, 20 | "offload_param": { 21 | "device": "none", 22 | "pin_memory": true 23 | }, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 3, 24 | "offload_optimizer": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "offload_param": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "overlap_comm": true, 33 | "contiguous_gradients": true, 34 | "sub_group_size": 1e9, 35 | "reduce_bucket_size": "auto", 36 | "stage3_prefetch_bucket_size": "auto", 37 | "stage3_param_persistence_threshold": "auto", 38 | "stage3_max_live_parameters": 1e9, 39 | "stage3_max_reuse_distance": 1e9, 40 | "gather_16bit_weights_on_model_save": true 41 | }, 42 | "gradient_accumulation_steps": "auto", 43 | "gradient_clipping": "auto", 44 | "train_batch_size": "auto", 45 | "train_micro_batch_size_per_gpu": "auto", 46 | "steps_per_print": 1e5, 47 | "wall_clock_breakdown": false 48 | } -------------------------------------------------------------------------------- /scripts/zero3pp.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "zero_optimization": { 24 | "stage": 3, 25 | "offload_optimizer": { 26 | "device": "none", 27 | "pin_memory": true 28 | }, 29 | "offload_param": { 30 | "device": "none", 31 | "pin_memory": true 32 | }, 33 | "overlap_comm": true, 34 | "contiguous_gradients": true, 35 | "zero_quantized_weights": true, 36 | "zero_hpz_partition_size": 16, 37 | "zero_quantized_gradients": true, 38 | "sub_group_size": 1e9, 39 | "reduce_bucket_size": "auto", 40 | "stage3_prefetch_bucket_size": "auto", 41 | "stage3_param_persistence_threshold": "auto", 42 | "stage3_max_live_parameters": 1e9, 43 | "stage3_max_reuse_distance": 1e9, 44 | 
"stage3_gather_16bit_weights_on_model_save": true 45 | }, 46 | 47 | "gradient_accumulation_steps": "auto", 48 | "gradient_clipping": "auto", 49 | "steps_per_print": 100, 50 | "train_batch_size": "auto", 51 | "train_micro_batch_size_per_gpu": "auto", 52 | "wall_clock_breakdown": false 53 | } --------------------------------------------------------------------------------