├── .gitignore ├── LICENSE ├── README.md ├── easy_context ├── __init__.py ├── accelerate_configs │ ├── deepspeed_inference.yaml │ ├── single_node.yaml │ ├── two_node.yaml │ ├── zero3_offload.json │ └── zero3_offload_inference.json ├── dist_flash_attn │ ├── README.md │ ├── async_communication.py │ ├── lightseq_async_attn.py │ ├── lightseq_async_attn_varlen.py │ ├── monkey_patch.py │ └── prepare_input.py ├── low_mem_cross_ent.py ├── low_mem_cross_ent_tests │ ├── test_correctness.py │ └── test_mem_and_speed.py ├── modeling_qwen2.py ├── ulysses_attn │ ├── monkey_patch.py │ └── prepare_inputs.py ├── unsloth_offloaded_gradient_checkpoint │ └── monkey_patch.py └── zigzag_ring_attn │ ├── monkey_patch.py │ └── prepare_inputs.py ├── local_demo ├── assets │ ├── assistant_logo.png │ ├── dc_demo.mp4 │ ├── favicon.ico │ ├── jobs.mp4 │ ├── llava_logo.png │ ├── llava_next.jpg │ ├── lmms-eval.png │ ├── otter_books.jpg │ ├── user_example_01.jpg │ ├── user_example_02.jpg │ ├── user_example_03.jpg │ ├── user_example_04.jpg │ ├── user_example_05.jpg │ ├── user_example_06.jpg │ ├── user_example_07.jpg │ ├── user_example_08.jpg │ ├── user_example_09.jpg │ ├── user_example_10.png │ ├── user_example_11.png │ ├── user_logo.png │ ├── water.mp4 │ └── white_cat_smile.jpg ├── cache │ ├── 0266b3f7b9311a93f0f38bc8cf0698b019c829b1 │ │ └── user_example_09.jpg │ ├── 0333f1f38db26d76752fee6cd2d938f06617dd2f │ │ └── dc_demo.mp4 │ ├── 0c299f0f979725494d926c3e76cfeb75c19ef564 │ │ └── assistant_logo.png │ ├── 3c94cf4ccdbb877666285b42dd9263073d1fa19c │ │ └── user_example_07.jpg │ ├── 447d9ea0a12a7be4cf7f279dc4e667b1b938a0c2 │ │ └── otter_books.jpg │ ├── 7abebd5007998873c3d36612f984d241c3b574d8 │ │ └── user_logo.png │ ├── 7dcdac5ff3df641ae7eb47d8b7aaeab740e02de9 │ │ └── user_example_06.jpg │ ├── b7b8f04e74cda37d9a574c1a2098d7e4ee97b212 │ │ └── user_example_05.jpg │ └── dbb5054b8579d1943ff74045f254e0e759efec99 │ │ └── white_cat_smile.jpg ├── constants.py ├── debug_json_api.py ├── longva_backend.py ├── multimodal_chat.py ├── theme_dropdown.py └── themes │ └── theme_schema@0.1.1.json ├── longva ├── data_processing │ └── utils.py ├── longva │ ├── __init__.py │ ├── constants.py │ ├── conversation.py │ ├── mm_utils.py │ ├── model │ │ ├── __init__.py │ │ ├── apply_delta.py │ │ ├── builder.py │ │ ├── consolidate.py │ │ ├── language_model │ │ │ ├── llava_llama.py │ │ │ ├── llava_mistral.py │ │ │ ├── llava_mpt.py │ │ │ ├── llava_qwen.py │ │ │ └── modeling_llama.py │ │ ├── llava_arch.py │ │ ├── make_delta.py │ │ ├── multimodal_encoder │ │ │ ├── builder.py │ │ │ └── clip_encoder.py │ │ ├── multimodal_projector │ │ │ ├── builder.py │ │ │ └── pooler_projector.py │ │ ├── multimodal_resampler │ │ │ ├── builder.py │ │ │ ├── masked_drop.py │ │ │ ├── perceiver.py │ │ │ ├── qformer.py │ │ │ └── spatial_pool.py │ │ └── utils.py │ ├── train │ │ ├── llama_flash_attn_monkey_patch.py │ │ ├── llava_trainer.py │ │ ├── train.py │ │ ├── train_dpo.py │ │ └── train_mem.py │ └── utils.py ├── pyproject.toml ├── scripts │ ├── dpo.sh │ ├── finetune.sh │ ├── pretrain.sh │ └── zero3.json └── trl │ ├── __init__.py │ ├── core.py │ ├── environment │ ├── __init__.py │ └── base_environment.py │ ├── extras │ ├── __init__.py │ ├── best_of_n_sampler.py │ └── dataset_formatting.py │ ├── import_utils.py │ ├── models │ ├── __init__.py │ ├── modeling_base.py │ ├── modeling_sd_base.py │ ├── modeling_value_head.py │ └── utils.py │ └── trainer │ ├── __init__.py │ ├── base.py │ ├── ddpo_config.py │ ├── ddpo_trainer.py │ ├── dpo_trainer.py │ ├── iterative_sft_trainer.py │ ├── 
model_config.py │ ├── ppo_config.py │ ├── ppo_trainer.py │ ├── reward_config.py │ ├── reward_trainer.py │ ├── sft_trainer.py │ └── utils.py ├── requirements.txt ├── text_extend ├── PaulGrahamEssays │ ├── addiction.txt │ ├── aord.txt │ ├── apple.txt │ ├── avg.txt │ ├── before.txt │ ├── bias.txt │ ├── boss.txt │ ├── copy.txt │ ├── corpdev.txt │ ├── desres.txt │ ├── diff.txt │ ├── ecw.txt │ ├── founders.txt │ ├── foundervisa.txt │ ├── gap.txt │ ├── gba.txt │ ├── gh.txt │ ├── goodtaste.txt │ ├── hubs.txt │ ├── iflisp.txt │ ├── island.txt │ ├── know.txt │ ├── langdes.txt │ ├── laundry.txt │ ├── love.txt │ ├── mod.txt │ ├── newideas.txt │ ├── nft.txt │ ├── philosophy.txt │ ├── popular.txt │ ├── pow.txt │ ├── rootsoflisp.txt │ ├── rss.txt │ ├── siliconvalley.txt │ ├── startuplessons.txt │ ├── submarine.txt │ ├── sun.txt │ ├── superangels.txt │ ├── todo.txt │ ├── unions.txt │ ├── useful.txt │ ├── vb.txt │ ├── vcsqueeze.txt │ ├── vw.txt │ ├── want.txt │ ├── web20.txt │ ├── weird.txt │ ├── wisdom.txt │ └── worked.txt ├── eval.sh ├── eval_text_niah.py ├── eval_text_ppl.py ├── extend_qwen2.sh ├── niah_output │ ├── distractor_0 │ │ └── Qwen2-7B-Instruct-extend-step_1000 │ │ │ ├── all_accuracies.json │ │ │ ├── avg_accuracy.txt │ │ │ └── heatmap.png │ ├── distractor_3 │ │ ├── Meta-Llama-3-8B-Instruct-extend-step1000 │ │ │ ├── all_accuracies.json │ │ │ ├── avg_accuracy.txt │ │ │ └── heatmap.png │ │ └── Qwen2-7B-Instruct-extend-step_1000 │ │ │ ├── all_accuracies.json │ │ │ ├── avg_accuracy.txt │ │ │ └── heatmap.png │ ├── distractor_5 │ │ └── Qwen2-7B-Instruct-extend-step_1000 │ │ │ ├── all_accuracies.json │ │ │ ├── avg_accuracy.txt │ │ │ └── heatmap.png │ └── git_placeholder ├── plot_ppl.py └── text_extend_train.py └── vision_niah ├── data ├── haystack_embeddings │ └── git_placeholder ├── haystack_videos │ └── git_placeholder └── needle_embeddings │ └── git_placeholder ├── eval.sh ├── eval_vision_niah.py ├── eval_vision_niah_sampling.py ├── model_weights └── git_placeholder ├── needle_datasets ├── dataset.json ├── generate_hf_dataset.py ├── git_placeholder └── images │ ├── astronaut.png │ ├── construction_site.png │ ├── dolphin.png │ ├── panda_scientist.png │ ├── selenium_green.jpg │ ├── sora_balloon.png │ ├── teddy_bear_times_square.png │ └── ucsd.jpeg ├── niah_output ├── LLaVA-NeXT-Video-7B-32K │ ├── rope_theta_1000000 │ │ ├── all_accuracies.json │ │ ├── avg_accuracy.txt │ │ └── heatmap.png │ └── rope_theta_100000000 │ │ ├── all_accuracies.json │ │ ├── avg_accuracy.txt │ │ └── heatmap.png └── LongVA-7B │ ├── all_accuracies.json │ ├── avg_accuracy.txt │ └── heatmap.png ├── produce_haystack_embedding.py └── produce_needle_embedding.py /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | **__pycache__/ 3 | wandb/ 4 | *.pyc 5 | *.egg-info 6 | *.log 7 | *.log.* 8 | .deepspeed_env 9 | text_extend/training_output 10 | vision_niah/model_weights 11 | **/**.pt 12 | # **/**.mp4 13 | longva/build 14 | vision_niah/data/haystack_videos 15 | local_demo/cache/ 16 | i18n -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2022 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and 6 | binary forms, with or without modification, are permitted provided 7 | that the following conditions are met: 8 | 9 | 1. 
Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in 14 | the documentation and/or other materials provided with the 15 | distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived 19 | from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | 33 | In the event that redistribution and/or use for commercial purpose in 34 | source or binary forms, with or without modification is required, 35 | please contact the contributor(s) of the work. 36 | -------------------------------------------------------------------------------- /easy_context/__init__.py: -------------------------------------------------------------------------------- 1 | from .dist_flash_attn.prepare_input import prepare_dist_flash_attn_inputs 2 | from .dist_flash_attn.monkey_patch import apply_dist_flash_attn_monkey_patch_llama 3 | from .zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs 4 | from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_llama 5 | from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_mistral 6 | from .unsloth_offloaded_gradient_checkpoint.monkey_patch import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch 7 | from .ulysses_attn.prepare_inputs import prepare_ulysses_attn_inputs 8 | from .ulysses_attn.monkey_patch import apply_ulysses_attn_monkey_patch_llama 9 | from .modeling_qwen2 import Qwen2ForCausalLM_RingAttn 10 | def prepare_seq_parallel_inputs( 11 | seq_algo, input_ids, position_ids, target_ids, rank, world_size, device 12 | ): 13 | if seq_algo == "zigzag_ring_attn": 14 | return prepare_zigzag_ring_attn_inputs( 15 | input_ids, position_ids, target_ids, rank, world_size, device 16 | ) 17 | elif seq_algo == "dist_flash_attn": 18 | return prepare_dist_flash_attn_inputs( 19 | input_ids, position_ids, target_ids, rank, world_size, device 20 | ) 21 | elif seq_algo == "ulysses_attn": 22 | return prepare_ulysses_attn_inputs( 23 | input_ids, position_ids, target_ids, rank, world_size, device 24 | ) 25 | elif seq_algo == "data_parallel": 26 | return { 27 | "local_input_ids": input_ids.to(device), 28 | "local_position_ids": position_ids.to(device), 29 | "local_target_ids": target_ids.to(device), 30 | } 31 | else: 32 | raise ValueError(f"Invalid seq_algo: {seq_algo}") 33 | 34 | def apply_seq_parallel_monkey_patch( 35 | seq_algo, model 36 | ): 37 | assert seq_algo in ["zigzag_ring_attn", "dist_flash_attn", 
"ulysses_attn", "data_parallel"], f"Invalid seq_algo: {seq_algo}" 38 | assert model in ["llama", "mistral"], f"Invalid model: {model}" 39 | if seq_algo == "data_parallel": 40 | return 41 | elif seq_algo == "zigzag_ring_attn" and model == "llama": 42 | apply_zigzag_ring_attn_monkey_patch_llama() 43 | elif seq_algo == "zigzag_ring_attn" and model == "mistral": 44 | apply_zigzag_ring_attn_monkey_patch_mistral() 45 | elif seq_algo == "dist_flash_attn" and model == "llama": 46 | apply_dist_flash_attn_monkey_patch_llama() 47 | elif seq_algo == "ulysses_attn" and model == "llama": 48 | apply_ulysses_attn_monkey_patch_llama() 49 | else: 50 | raise ValueError(f"Invalid seq_algo: {seq_algo} or model: {model}") 51 | 52 | def prepare_dataloader(seq_algo, dataloader, acclerator): 53 | if seq_algo == "data_parallel": 54 | return acclerator.prepare(dataloader) 55 | else: 56 | return dataloader -------------------------------------------------------------------------------- /easy_context/accelerate_configs/deepspeed_inference.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_config_file: easy_context/accelerate_configs/zero3_offload_inference.json 5 | zero3_init_flag: false 6 | distributed_type: DEEPSPEED 7 | downcast_bf16: 'no' 8 | machine_rank: 0 9 | main_training_function: main 10 | num_machines: 1 11 | num_processes: 8 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false 18 | -------------------------------------------------------------------------------- /easy_context/accelerate_configs/single_node.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_config_file: easy_context/accelerate_configs/zero3_offload.json 5 | zero3_init_flag: false 6 | distributed_type: DEEPSPEED 7 | downcast_bf16: 'no' 8 | machine_rank: 0 9 | main_training_function: main 10 | num_machines: 1 11 | num_processes: 8 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false 18 | -------------------------------------------------------------------------------- /easy_context/accelerate_configs/two_node.yaml: -------------------------------------------------------------------------------- 1 | debug: false 2 | deepspeed_config: 3 | deepspeed_config_file: easy_context/accelerate_configs/zero3_offload.json 4 | deepspeed_multinode_launcher: standard 5 | zero3_init_flag: false 6 | distributed_type: DEEPSPEED 7 | downcast_bf16: 'no' 8 | num_machines: 2 9 | num_processes: 16 10 | main_training_function: main 11 | rdzv_backend: c10d 12 | same_network: false 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /easy_context/accelerate_configs/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "fp16": { 6 | "enabled": "auto" 7 | }, 8 | "scheduler": { 9 | "type": "WarmupLR", 10 | "params": { 11 | "warmup_min_lr": 1e-5, 12 | "warmup_max_lr": 1e-5, 13 | "warmup_num_steps": 0, 14 | "warmup_type": "linear" 15 | } 16 | }, 17 | "optimizer": { 18 | "type": "AdamW", 19 | "params": { 20 | "lr": "auto", 21 | 
"betas": [0.9, 0.95], 22 | "eps": 1e-8, 23 | "weight_decay": 0.1 24 | } 25 | }, 26 | "zero_optimization": { 27 | "stage": 3, 28 | "offload_optimizer": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "offload_param": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "overlap_comm": true, 37 | "contiguous_gradients": true, 38 | "sub_group_size": 1e9, 39 | "reduce_bucket_size": "auto", 40 | "stage3_prefetch_bucket_size": "auto", 41 | "stage3_param_persistence_threshold": "auto", 42 | "stage3_max_live_parameters": 1e9, 43 | "stage3_max_reuse_distance": 1e9, 44 | "stage3_gather_16bit_weights_on_model_save": true 45 | }, 46 | "gradient_accumulation_steps": "auto", 47 | "gradient_clipping": "auto", 48 | "steps_per_print": 2000, 49 | "train_batch_size": "auto", 50 | "train_micro_batch_size_per_gpu": 1, 51 | "wall_clock_breakdown": false 52 | } 53 | -------------------------------------------------------------------------------- /easy_context/accelerate_configs/zero3_offload_inference.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "fp16": { 6 | "enabled": "auto" 7 | }, 8 | "zero_optimization": { 9 | "stage": 3, 10 | "stage3_prefetch_bucket_size": 33554432, 11 | "stage3_param_persistence_threshold": 4096, 12 | "stage3_max_live_parameters":33554432, 13 | "offload_param": { 14 | "device": "cpu", 15 | "pin_memory": true 16 | } 17 | }, 18 | "train_batch_size": 8, 19 | "train_micro_batch_size_per_gpu": 1, 20 | "wall_clock_breakdown": false 21 | } -------------------------------------------------------------------------------- /easy_context/dist_flash_attn/README.md: -------------------------------------------------------------------------------- 1 | # LightSeq 2 | Taken from https://github.com/RulinShao/LightSeq. All credits to the authors. 
3 | 4 | ``` 5 | @article{li2023lightseq, 6 | title={LIGHTSEQ: SEQUENCE LEVEL PARALLELISM FOR DISTRIBUTED TRAINING OF LONG CONTEXT TRANSFORMERS}, 7 | author={Li, Dacheng and Shao, Rulin and Xie, Anze and Xing, Eric P and Gonzalez, Joseph E and Stoica, Ion and Ma, Xuezhe and Zhang, Hao}, 8 | journal={arXiv preprint arXiv:2310.03294}, 9 | year={2023} 10 | } 11 | ``` -------------------------------------------------------------------------------- /easy_context/dist_flash_attn/prepare_input.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def extract_local(value, rank, world_size, device, dim=1): 4 | value_local = value.chunk(world_size, dim=dim)[rank] 5 | return value_local.to(device) 6 | 7 | 8 | def prepare_dist_flash_attn_inputs( 9 | input_ids, position_ids, target_ids, rank, world_size, device 10 | ): 11 | local_input_ids = extract_local( 12 | input_ids, 13 | rank, 14 | world_size, 15 | device, 16 | ) 17 | local_position_ids = extract_local( 18 | position_ids, 19 | rank, 20 | world_size, 21 | device, 22 | ) 23 | if target_ids is not None: 24 | local_target_ids = extract_local( 25 | target_ids, 26 | rank, 27 | world_size, 28 | device, 29 | ) 30 | else: 31 | local_target_ids = None 32 | return { 33 | "local_input_ids": local_input_ids, 34 | "local_position_ids": local_position_ids, 35 | "local_target_ids": local_target_ids, 36 | } -------------------------------------------------------------------------------- /easy_context/low_mem_cross_ent.py: -------------------------------------------------------------------------------- 1 | """Low memory cross entropy without materializing the logits 2 | 3 | This module enables long-context training of large vocab models, e.g., Gemma has 250K vocab and Llama 3 has 150K 4 | 5 | Yao Fu, University of Edinburgh 6 | yao.fu@ed.ac.uk 7 | """ 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | 13 | def cross_ent_normal(x, weight, labels): 14 | logits = torch.einsum("bsh, vh -> bsv", x, weight) 15 | vocab = weight.size(0) 16 | loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1)) 17 | return loss 18 | 19 | class LowMemLogitProjCrossEnt(torch.autograd.Function): 20 | """Low memory implementation of logits projection plus cross entropy loss. 21 | Useful for reducing the peak memory when dealing with vocabulary larger than 100000 22 | 23 | TODO: integrate this function into easy context 24 | 25 | Two tricks used here 26 | 1. Shard the data to reduce peak memory 27 | 2. 
Do not save the logits 28 | """ 29 | 30 | @staticmethod 31 | # @torch.compile() # Currently we do not use torch.compile because it uses additional memory 32 | def forward(ctx, x: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, sp: int=4): 33 | """ 34 | Args: 35 | x: size = [batch, seqlen, hidden] 36 | weight: size = [vocab, hidden] 37 | labels: size = [batch, seqlen] 38 | """ 39 | bsz, seqlen, hidden = x.size() 40 | vocab = weight.size(0) 41 | micro_seqlen = seqlen // sp 42 | 43 | loss = 0 44 | for i in range(sp): # shard data along the sequence dimension 45 | logits_i_slice = torch.einsum("bsh, vh -> bsv", x[:, micro_seqlen * i: micro_seqlen * (i + 1)], weight) 46 | loss_i = F.cross_entropy(logits_i_slice.reshape(-1, vocab), labels[:, micro_seqlen * i: micro_seqlen * (i + 1)].reshape(-1)) 47 | loss = loss + loss_i 48 | 49 | loss = loss / sp 50 | ctx.save_for_backward(x, weight, labels) # because we do no save logits, we save memory 51 | ctx.sp = sp 52 | return loss 53 | 54 | # @torch.compile() 55 | @staticmethod 56 | def backward(ctx, grad_output): 57 | """Manually calculate the gradient in a memory-efficient way 58 | Ref: https://indii.org/blog/gradients-of-softmax-and-logsumexp/ 59 | """ 60 | x, weight, labels = ctx.saved_tensors 61 | sp = ctx.sp 62 | device = x.device 63 | dtype = x.dtype 64 | bsz, seqlen, hidden = x.size() 65 | vocab, hidden = weight.size() 66 | micro_seqlen = seqlen // sp 67 | 68 | d_weight = torch.zeros_like(weight, device=weight.device) 69 | d_x = [] 70 | for i in range(sp): # shard data along sequence dimension, reduce peak memory 71 | x_ = x[:, micro_seqlen * i: micro_seqlen * (i + 1)] 72 | p = F.softmax( 73 | torch.einsum("blh, vh -> blv", x_, weight), 74 | dim=-1 75 | ) 76 | 77 | # memory efficient in-place backprop 78 | # loss -> d_logits 79 | d_logits = -p.view(-1) # [b * l * v] 80 | labels_ = labels[:, micro_seqlen * i: micro_seqlen * (i + 1)].reshape(-1) # [b * l] 81 | index = torch.arange(bsz * micro_seqlen, device=device) * vocab + labels_ 82 | source = torch.tensor([1] * bsz * micro_seqlen, dtype=dtype, device=device) 83 | d_logits.index_add_(0, index, source) 84 | d_logits = -d_logits.view(bsz, micro_seqlen, vocab) / (bsz * seqlen) 85 | 86 | # d_logits -> d_x and d_weight 87 | d_x.append(torch.einsum("blv, vh -> blh", d_logits, weight)) 88 | d_weight += torch.einsum("blv, blh -> vh", d_logits, x_) 89 | 90 | d_weight = grad_output * d_weight 91 | d_x = grad_output * torch.concat(d_x, 1) 92 | return d_x, d_weight, None, None 93 | 94 | low_mem_cross_ent = LowMemLogitProjCrossEnt.apply 95 | -------------------------------------------------------------------------------- /easy_context/low_mem_cross_ent_tests/test_correctness.py: -------------------------------------------------------------------------------- 1 | """Test the correctness (up to certain tolerance of numerical error) of low-memory cross-ent 2 | 3 | Yao Fu, University of Edinburgh 4 | yao.fu@ed.ac.uk 5 | """ 6 | 7 | import sys 8 | sys.path.append("..") 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from low_mem_cross_ent import low_mem_cross_ent, cross_ent_normal 13 | 14 | bsz = 1 15 | seqlen = 50000 16 | hidden = 4096 17 | vocab = 15000 18 | dtype = torch.bfloat16 19 | rtol=1e-05 # relative tolerance when comparing the gradients from two implementations 20 | atol=1e-07 # absolute tolerance when comparing the gradients from two implementations 21 | # in Pytorch its default is 1e-8 but our implementation cannot pass this threshold 22 | # 1e-7 seems to be the smallest 
rolerance we can pass 23 | 24 | x = torch.normal(mean=0, std=0.01, size=(bsz, seqlen, hidden), 25 | device="cuda", dtype=dtype, requires_grad=True) 26 | weight = torch.normal(mean=0, std=0.01, size=(vocab, hidden), 27 | device="cuda", dtype=dtype, requires_grad=True) 28 | labels = torch.randint(low=0, high=vocab - 1, size=(bsz, seqlen), device="cuda") 29 | 30 | loss_normal = cross_ent_normal(x, weight, labels) 31 | print("loss normal: %.4f" % loss_normal.cpu().item()) 32 | loss_normal.backward() 33 | x_grad = x.grad.clone() 34 | weight_grad = weight.grad.clone() 35 | # print(x.grad) 36 | # print(weight.grad) 37 | 38 | 39 | # TODO: this one almost reduce memory to half. Maybe further increase sp 40 | x.grad = None 41 | weight.grad = None 42 | loss_low_mem = low_mem_cross_ent(x, weight, labels) 43 | print("loss low mem: %.4f" % loss_low_mem.cpu().item()) 44 | loss_low_mem.backward() 45 | # print(x.grad) 46 | # print(weight.grad) 47 | 48 | ## Test implementation by asserting close 49 | assert(torch.allclose(x_grad, x.grad, rtol=rtol, atol=atol)) 50 | assert(torch.allclose(weight_grad, weight.grad, rtol=rtol, atol=atol)) 51 | print("PASS: gradients from normal computation and low memory computation are close.") 52 | 53 | 54 | # #### Test gradient of logits 55 | # x.grad = None 56 | # weight.grad = None 57 | # logits = torch.einsum("bsh, vh -> bsv", x, weight) 58 | # loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1)) 59 | # d_logits = torch.autograd.grad(loss, logits) 60 | # p = F.softmax(torch.einsum("blh, vh -> blv", x, weight), dim=-1) 61 | # p_ = p / (bsz * seqlen) 62 | 63 | # #### test index add 64 | # x = torch.tensor([1, 2, 3, 4, 5, 6, 7]) 65 | # index = torch.tensor([1, 3, 4]) 66 | # source = torch.tensor([1, 1, 1]) 67 | # x.index_add_(dim=0, index=index, source=source) 68 | 69 | # #### test index add 2 70 | # sp = 4 71 | # micro_seqlen = seqlen // sp 72 | # p = torch.normal(mean=0, std=0.01, size=(bsz, micro_seqlen, vocab), 73 | # device="cuda", dtype=torch.bfloat16) 74 | # labels_ = labels[:, :micro_seqlen].view(-1) 75 | # index = torch.arange(bsz * micro_seqlen, device="cuda") * vocab 76 | # index += labels_ 77 | # d_logits = -p.view(-1) 78 | # source = torch.tensor([1] * bsz * micro_seqlen, dtype=torch.bfloat16, device="cuda") 79 | # d_logits.index_add_(0, index, source) 80 | # d_logits = d_logits.view(bsz, micro_seqlen, vocab) 81 | 82 | -------------------------------------------------------------------------------- /easy_context/low_mem_cross_ent_tests/test_mem_and_speed.py: -------------------------------------------------------------------------------- 1 | """Test the memory and speed and MFU of low memory cross entropy 2 | 3 | Yao Fu, University of Edinburgh 4 | yao.fu@ed.ac.uk 5 | 6 | bf16, seqlen=50000, vocab=150000, without torch.compile 7 | | | normal | low_mem | - | 8 | | sp | - | 4 | 16 | 9 | | peak mem | 43.4G | 18.5G | 8.1G | 10 | | forward | 0.307 | 0.310 | 0.315 | 11 | | backward | 0.631 | 0.896 | 0.914 | 12 | | MFU | 0.57 | 0.45 | 0.44 | 13 | 14 | NOTE: tried torch.compile and it takes significantly larger memory, so do not use 15 | TODO: profile and check why backward is slower 16 | """ 17 | import sys 18 | sys.path.append("..") 19 | 20 | import torch 21 | import numpy as np 22 | import torch.nn.functional as F 23 | from low_mem_cross_ent import low_mem_cross_ent, cross_ent_normal 24 | 25 | implementation = "low_mem" # "normal", "low_mem" 26 | device_type = "A100" 27 | bsz = 1 28 | seqlen = 50000 29 | hidden = 4096 30 | vocab = 150000 31 | sp=16 32 | 
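# Worked check of the MFU column in the docstring table, derived only from the
# numbers defined above and in the docstring ("T" follows this script's 1024**4
# convention; 312 TFLOPS is the A100 bf16 peak assumed by device_flop below):
#   flop = 6 * bsz * seqlen * hidden * vocab
#        = 6 * 1 * 50000 * 4096 * 150000 ≈ 1.84e14 FLOPs per fwd+bwd  (≈ 167.6 T)
#   normal:         167.6 / (0.307 + 0.631) ≈ 178.7 T/s  ->  178.7 / 312 ≈ 0.57 MFU
#   low_mem, sp=16: 167.6 / (0.315 + 0.914) ≈ 136.4 T/s  ->  136.4 / 312 ≈ 0.44 MFU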
dtype = torch.bfloat16 33 | # dtype = torch.float 34 | G = 1024 ** 3 35 | T = 1024 ** 4 36 | 37 | x = torch.normal(mean=0, std=0.01, size=(bsz, seqlen, hidden), 38 | device="cuda", dtype=dtype, requires_grad=True) 39 | weight = torch.normal(mean=0, std=0.01, size=(vocab, hidden), 40 | device="cuda", dtype=dtype, requires_grad=True) 41 | labels = torch.randint(low=0, high=vocab - 1, size=(bsz, seqlen), device="cuda") 42 | 43 | def timed(fn): 44 | start = torch.cuda.Event(enable_timing=True) 45 | end = torch.cuda.Event(enable_timing=True) 46 | start.record() 47 | result = fn() 48 | end.record() 49 | torch.cuda.synchronize() 50 | return result, start.elapsed_time(end) / 1000 51 | 52 | n_runs = 50 53 | flop = 6 * bsz * seqlen * hidden * vocab 54 | if(implementation == "normal"): 55 | forward_times, backward_times = [], [] 56 | for _ in range(n_runs): 57 | loss_normal, time_elapse = timed(lambda: cross_ent_normal(x, weight, labels)) 58 | forward_times.append(time_elapse) 59 | _, time_elapse = timed(lambda: loss_normal.backward()) 60 | backward_times.append(time_elapse) 61 | mem = torch.cuda.max_memory_allocated() 62 | elif(implementation == "low_mem"): 63 | forward_times, backward_times = [], [] 64 | for _ in range(n_runs): 65 | loss_low_mem, time_elapse = timed(lambda: low_mem_cross_ent(x, weight, labels, sp)) 66 | forward_times.append(time_elapse) 67 | _, time_elapse = timed(lambda: loss_low_mem.backward()) 68 | backward_times.append(time_elapse) 69 | mem = torch.cuda.max_memory_allocated() 70 | else: raise NameError("Implementation %s not recognized" % implementation) 71 | 72 | forward_time = np.median(forward_times) 73 | backward_time = np.median(backward_times) 74 | flops = (flop / T) / (forward_time + backward_time) 75 | if(device_type == "A100"): 76 | device_flop = 312 77 | else: raise NameError("device %s not recognized" % device_type) 78 | 79 | print("%s, peak memory %.1fG, forward time %.4f, backward time %.4f, flops %.2fT, util %.2f" % 80 | (implementation, mem / G, forward_time, backward_time, flops, flops / device_flop)) 81 | -------------------------------------------------------------------------------- /easy_context/ulysses_attn/monkey_patch.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | from typing import List, Optional, Tuple, Union 3 | import warnings 4 | import torch 5 | import torch.utils.checkpoint 6 | try: 7 | from yunchang.ulysses import UlyssesAttention 8 | ulysses_attn = UlyssesAttention() 9 | except: 10 | ulysses_attn = None 11 | 12 | 13 | def new_flash_attn_forward( 14 | self, 15 | query_states, 16 | key_states, 17 | value_states, 18 | attention_mask, 19 | query_length, 20 | dropout=0.0, 21 | softmax_scale=None, 22 | use_sliding_windows=False, 23 | ): 24 | if not self._flash_attn_uses_top_left_mask: 25 | causal = self.is_causal 26 | else: 27 | causal = self.is_causal and query_length != 1 28 | 29 | # Contains at least one padding token in the sequence 30 | assert attention_mask is None 31 | assert causal is True 32 | assert use_sliding_windows is False 33 | attn_output = ulysses_attn( 34 | query_states, 35 | key_states, 36 | value_states, 37 | dropout, 38 | softmax_scale, 39 | causal=causal, 40 | ) 41 | 42 | return attn_output 43 | 44 | 45 | def new_decoder_forward( 46 | self, 47 | hidden_states: torch.Tensor, 48 | attention_mask: Optional[torch.Tensor] = None, 49 | position_ids: Optional[torch.LongTensor] = None, 50 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 51 | output_attentions: 
Optional[bool] = False, 52 | use_cache: Optional[bool] = False, 53 | cache_position: Optional[torch.LongTensor] = None, 54 | **kwargs, 55 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 56 | assert isinstance( 57 | self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 58 | ) or isinstance( 59 | self.self_attn, 60 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2, 61 | ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." 62 | 63 | if "padding_mask" in kwargs: 64 | warnings.warn( 65 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 66 | ) 67 | 68 | residual = hidden_states 69 | 70 | hidden_states = self.input_layernorm(hidden_states) 71 | 72 | # Self Attention 73 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 74 | hidden_states=hidden_states, 75 | attention_mask=attention_mask, 76 | position_ids=position_ids, 77 | past_key_value=past_key_value, 78 | output_attentions=output_attentions, 79 | use_cache=use_cache, 80 | cache_position=cache_position, 81 | **kwargs, 82 | ) 83 | hidden_states = residual + hidden_states 84 | 85 | # Fully Connected 86 | residual = hidden_states 87 | hidden_states = self.post_attention_layernorm(hidden_states) 88 | hidden_states = self.mlp(hidden_states) 89 | hidden_states = residual + hidden_states 90 | 91 | outputs = (hidden_states,) 92 | 93 | if output_attentions: 94 | outputs += (self_attn_weights,) 95 | 96 | if use_cache: 97 | outputs += (present_key_value,) 98 | 99 | return outputs 100 | 101 | 102 | def apply_ulysses_attn_monkey_patch_llama(): 103 | transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( 104 | new_flash_attn_forward 105 | ) 106 | transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( 107 | new_decoder_forward 108 | ) 109 | 110 | 111 | -------------------------------------------------------------------------------- /easy_context/ulysses_attn/prepare_inputs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def extract_local(value, rank, world_size, device, dim=1): 5 | dimension_size = value.shape[dim] 6 | sub_seq_length = dimension_size // world_size 7 | 8 | sub_seq_start = rank * sub_seq_length 9 | sub_seq_end = (rank + 1) * sub_seq_length 10 | local_value = value[:, sub_seq_start:sub_seq_end] 11 | 12 | return local_value.to(device) 13 | 14 | 15 | def prepare_ulysses_attn_inputs( 16 | input_ids, position_ids, target_ids, rank, world_size, device 17 | ): 18 | 19 | local_input_ids = extract_local( 20 | input_ids, 21 | rank, 22 | world_size, 23 | device, 24 | ) 25 | local_position_ids = extract_local( 26 | position_ids, 27 | rank, 28 | world_size, 29 | device, 30 | ) 31 | 32 | if target_ids is not None: 33 | local_target_ids = extract_local( 34 | target_ids, 35 | rank, 36 | world_size, 37 | device, 38 | ) 39 | else: 40 | local_target_ids = None 41 | return { 42 | "local_input_ids": local_input_ids, 43 | "local_position_ids": local_position_ids, 44 | "local_target_ids": local_target_ids, 45 | } 46 | -------------------------------------------------------------------------------- /easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import transformers 17 | import inspect 18 | 19 | 20 | class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function): 21 | """ 22 | Saves VRAM by smartly offloading to RAM. 23 | Tiny hit to performance, since we mask the movement via non blocking calls. 24 | """ 25 | 26 | @staticmethod 27 | @torch.cuda.amp.custom_fwd 28 | def forward(ctx, forward_function, hidden_states, *args): 29 | saved_hidden_states = hidden_states.to("cpu", non_blocking=True) 30 | with torch.no_grad(): 31 | output = forward_function(hidden_states, *args) 32 | ctx.save_for_backward(saved_hidden_states) 33 | ctx.forward_function = forward_function 34 | ctx.args = args 35 | 36 | return output 37 | 38 | pass 39 | 40 | @staticmethod 41 | @torch.cuda.amp.custom_bwd 42 | def backward(ctx, dY): 43 | (hidden_states,) = ctx.saved_tensors 44 | hidden_states = hidden_states.to("cuda", non_blocking=True).detach() 45 | hidden_states.requires_grad = True 46 | with torch.enable_grad(): 47 | (output,) = ctx.forward_function(hidden_states, *ctx.args) 48 | torch.autograd.backward(output, dY) 49 | return ( 50 | None, 51 | hidden_states.grad, 52 | ) + ( 53 | None, 54 | ) * len(ctx.args) 55 | 56 | pass 57 | 58 | 59 | pass 60 | 61 | 62 | def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): 63 | assert gradient_checkpointing_kwargs == None 64 | if not self.supports_gradient_checkpointing: 65 | raise ValueError( 66 | f"{self.__class__.__name__} does not support gradient checkpointing." 67 | ) 68 | 69 | gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply 70 | # For old GC format (transformers < 4.35.0) for models that live on the Hub 71 | # we will fall back to the overwritten `_set_gradient_checkpointing` method 72 | _is_using_old_format = ( 73 | "value" in inspect.signature(self._set_gradient_checkpointing).parameters 74 | ) 75 | 76 | if not _is_using_old_format: 77 | self._set_gradient_checkpointing( 78 | enable=True, gradient_checkpointing_func=gradient_checkpointing_func 79 | ) 80 | else: 81 | raise NotImplementedError() 82 | 83 | if getattr(self, "_hf_peft_config_loaded", False): 84 | # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True 85 | # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 86 | # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate 87 | # the gradients to make sure the gradient flows. 
88 | self.enable_input_require_grads() 89 | 90 | 91 | def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch(): 92 | transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = ( 93 | new_gradient_checkpointing_enable 94 | ) 95 | -------------------------------------------------------------------------------- /easy_context/zigzag_ring_attn/monkey_patch.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | from typing import List, Optional, Tuple, Union 3 | import warnings 4 | import torch 5 | import torch.utils.checkpoint 6 | from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func 7 | 8 | 9 | def new_flash_attn_forward( 10 | self, 11 | query_states, 12 | key_states, 13 | value_states, 14 | attention_mask, 15 | query_length, 16 | dropout=0.0, 17 | softmax_scale=None, 18 | use_sliding_windows=False, 19 | ): 20 | if not self._flash_attn_uses_top_left_mask: 21 | causal = self.is_causal 22 | else: 23 | causal = self.is_causal and query_length != 1 24 | 25 | # Contains at least one padding token in the sequence 26 | assert attention_mask is None 27 | assert causal is True 28 | assert use_sliding_windows is False 29 | attn_output = zigzag_ring_flash_attn_func( 30 | query_states, 31 | key_states, 32 | value_states, 33 | dropout, 34 | softmax_scale, 35 | causal=causal, 36 | ) 37 | 38 | return attn_output 39 | 40 | 41 | def new_decoder_forward( 42 | self, 43 | hidden_states: torch.Tensor, 44 | attention_mask: Optional[torch.Tensor] = None, 45 | position_ids: Optional[torch.LongTensor] = None, 46 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 47 | output_attentions: Optional[bool] = False, 48 | use_cache: Optional[bool] = False, 49 | cache_position: Optional[torch.LongTensor] = None, 50 | **kwargs, 51 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 52 | assert isinstance( 53 | self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 54 | ) or isinstance( 55 | self.self_attn, 56 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2, 57 | ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." 58 | 59 | if "padding_mask" in kwargs: 60 | warnings.warn( 61 | "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" 62 | ) 63 | 64 | residual = hidden_states 65 | 66 | hidden_states = self.input_layernorm(hidden_states) 67 | 68 | # Self Attention 69 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 70 | hidden_states=hidden_states, 71 | attention_mask=attention_mask, 72 | position_ids=position_ids, 73 | past_key_value=past_key_value, 74 | output_attentions=output_attentions, 75 | use_cache=use_cache, 76 | cache_position=cache_position, 77 | **kwargs, 78 | ) 79 | hidden_states = residual + hidden_states 80 | 81 | # Fully Connected 82 | residual = hidden_states 83 | hidden_states = self.post_attention_layernorm(hidden_states) 84 | hidden_states = self.mlp(hidden_states) 85 | hidden_states = residual + hidden_states 86 | 87 | outputs = (hidden_states,) 88 | 89 | if output_attentions: 90 | outputs += (self_attn_weights,) 91 | 92 | if use_cache: 93 | outputs += (present_key_value,) 94 | 95 | return outputs 96 | 97 | 98 | def apply_zigzag_ring_attn_monkey_patch_llama(): 99 | transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( 100 | new_flash_attn_forward 101 | ) 102 | transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( 103 | new_decoder_forward 104 | ) 105 | 106 | 107 | def apply_zigzag_ring_attn_monkey_patch_mistral(): 108 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = ( 109 | new_flash_attn_forward 110 | ) 111 | transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = ( 112 | new_decoder_forward 113 | ) 114 | -------------------------------------------------------------------------------- /easy_context/zigzag_ring_attn/prepare_inputs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def extract_local(value, rank, world_size, device, dim=1): 5 | value_chunks = value.chunk(2 * world_size, dim=dim) 6 | local_value = torch.cat( 7 | [value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim 8 | ) 9 | return local_value.to(device) 10 | 11 | 12 | def prepare_zigzag_ring_attn_inputs( 13 | input_ids, position_ids, target_ids, rank, world_size, device 14 | ): 15 | local_input_ids = extract_local( 16 | input_ids, 17 | rank, 18 | world_size, 19 | device, 20 | ) 21 | local_position_ids = extract_local( 22 | position_ids, 23 | rank, 24 | world_size, 25 | device, 26 | ) 27 | if target_ids is not None: 28 | local_target_ids = extract_local( 29 | target_ids, 30 | rank, 31 | world_size, 32 | device, 33 | ) 34 | else: 35 | local_target_ids = None 36 | return { 37 | "local_input_ids": local_input_ids, 38 | "local_position_ids": local_position_ids, 39 | "local_target_ids": local_target_ids, 40 | } 41 | -------------------------------------------------------------------------------- /local_demo/assets/assistant_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/assistant_logo.png -------------------------------------------------------------------------------- /local_demo/assets/dc_demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/dc_demo.mp4 -------------------------------------------------------------------------------- /local_demo/assets/favicon.ico: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/favicon.ico -------------------------------------------------------------------------------- /local_demo/assets/jobs.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/jobs.mp4 -------------------------------------------------------------------------------- /local_demo/assets/llava_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/llava_logo.png -------------------------------------------------------------------------------- /local_demo/assets/llava_next.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/llava_next.jpg -------------------------------------------------------------------------------- /local_demo/assets/lmms-eval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/lmms-eval.png -------------------------------------------------------------------------------- /local_demo/assets/otter_books.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/otter_books.jpg -------------------------------------------------------------------------------- /local_demo/assets/user_example_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_01.jpg -------------------------------------------------------------------------------- /local_demo/assets/user_example_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_02.jpg -------------------------------------------------------------------------------- /local_demo/assets/user_example_03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_03.jpg -------------------------------------------------------------------------------- /local_demo/assets/user_example_04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_04.jpg -------------------------------------------------------------------------------- /local_demo/assets/user_example_05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_05.jpg 
-------------------------------------------------------------------------------- /local_demo/assets/user_example_06.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_06.jpg -------------------------------------------------------------------------------- /local_demo/assets/user_example_07.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_07.jpg -------------------------------------------------------------------------------- /local_demo/assets/user_example_08.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_08.jpg -------------------------------------------------------------------------------- /local_demo/assets/user_example_09.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_09.jpg -------------------------------------------------------------------------------- /local_demo/assets/user_example_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_10.png -------------------------------------------------------------------------------- /local_demo/assets/user_example_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_11.png -------------------------------------------------------------------------------- /local_demo/assets/user_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_logo.png -------------------------------------------------------------------------------- /local_demo/assets/water.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/water.mp4 -------------------------------------------------------------------------------- /local_demo/assets/white_cat_smile.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/white_cat_smile.jpg -------------------------------------------------------------------------------- /local_demo/cache/0266b3f7b9311a93f0f38bc8cf0698b019c829b1/user_example_09.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/0266b3f7b9311a93f0f38bc8cf0698b019c829b1/user_example_09.jpg -------------------------------------------------------------------------------- 
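A note on the zigzag sharding defined in `easy_context/zigzag_ring_attn/prepare_inputs.py` above: unlike the contiguous split used by the dist_flash_attn and ulysses helpers, each rank receives chunk `rank` together with its mirror chunk `2*world_size - rank - 1`. Pairing an early (cheap) and a late (expensive) part of a causal sequence is the usual way zigzag ring attention balances work across ranks; that rationale is standard background, not stated in this repo. A small self-contained illustration with a toy tensor and a made-up world size:

```python
import torch

# Mirrors extract_local from easy_context/zigzag_ring_attn/prepare_inputs.py:
# split into 2*world_size chunks, keep chunk `rank` and its mirror chunk.
def extract_local(value, rank, world_size, device, dim=1):
    chunks = value.chunk(2 * world_size, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * world_size - rank - 1]], dim=dim).to(device)

seq = torch.arange(16).unsqueeze(0)  # toy sequence of 16 positions, world_size = 4
for rank in range(4):
    print(rank, extract_local(seq, rank, 4, "cpu").tolist())
# rank 0 -> [[0, 1, 14, 15]]
# rank 1 -> [[2, 3, 12, 13]]
# rank 2 -> [[4, 5, 10, 11]]
# rank 3 -> [[6, 7, 8, 9]]
```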
/local_demo/cache/0333f1f38db26d76752fee6cd2d938f06617dd2f/dc_demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/0333f1f38db26d76752fee6cd2d938f06617dd2f/dc_demo.mp4 -------------------------------------------------------------------------------- /local_demo/cache/0c299f0f979725494d926c3e76cfeb75c19ef564/assistant_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/0c299f0f979725494d926c3e76cfeb75c19ef564/assistant_logo.png -------------------------------------------------------------------------------- /local_demo/cache/3c94cf4ccdbb877666285b42dd9263073d1fa19c/user_example_07.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/3c94cf4ccdbb877666285b42dd9263073d1fa19c/user_example_07.jpg -------------------------------------------------------------------------------- /local_demo/cache/447d9ea0a12a7be4cf7f279dc4e667b1b938a0c2/otter_books.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/447d9ea0a12a7be4cf7f279dc4e667b1b938a0c2/otter_books.jpg -------------------------------------------------------------------------------- /local_demo/cache/7abebd5007998873c3d36612f984d241c3b574d8/user_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/7abebd5007998873c3d36612f984d241c3b574d8/user_logo.png -------------------------------------------------------------------------------- /local_demo/cache/7dcdac5ff3df641ae7eb47d8b7aaeab740e02de9/user_example_06.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/7dcdac5ff3df641ae7eb47d8b7aaeab740e02de9/user_example_06.jpg -------------------------------------------------------------------------------- /local_demo/cache/b7b8f04e74cda37d9a574c1a2098d7e4ee97b212/user_example_05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/b7b8f04e74cda37d9a574c1a2098d7e4ee97b212/user_example_05.jpg -------------------------------------------------------------------------------- /local_demo/cache/dbb5054b8579d1943ff74045f254e0e759efec99/white_cat_smile.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/dbb5054b8579d1943ff74045f254e0e759efec99/white_cat_smile.jpg -------------------------------------------------------------------------------- /local_demo/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | 4 | ############# 5 | # LongVA Demo Utils 6 | ############# 7 | 8 | title_markdown = """ 9 | # 
[LongVA Multimodal Chat](https://lmm-lab.github.io/LongVA/) 10 | """ 11 | 12 | subtitle_markdown = """ 13 | ### This is our research preview of LongVA, a multimodal model that is capable of accurately retrieving visual information from 2000 frames or more than 200K visual tokens. 14 | """ 15 | 16 | html_header = """ 17 | 78 | 79 |
80 | 81 | LLaVA-NeXT 82 | 83 |
84 | Long Context Transfer from Language to Vision 85 | Github | Huggingface | Blog | ArXiv Paper 86 | This is our research preview of LongVA, a multimodal model that is capable of accurately retrieving visual information from 2000 frames or more than 200K visual tokens. 87 | 88 | 
89 | """ 90 | 91 | block_css = """ 92 | #buttons button { 93 | min-width: min(120px,100%); 94 | } 95 | """ 96 | 97 | tos_markdown = """ 98 | ## Terms of use 99 | By using this service, users are required to agree to the following terms: 100 | The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. 101 | Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator. 102 | For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. 103 | """ 104 | 105 | 106 | learn_more_markdown = """ 107 | ## License 108 | The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. 109 | """ 110 | 111 | bibtext = """ 112 | ## Citation 113 | ``` 114 | @misc{zhang2024longva, 115 | title={LongVA: Long Context Transfer from Language to Vision}, 116 | url={https://lmms-lab.github.io/posts/longva/}, 117 | author={Zhang, Peiyuan and Zhang, Kaichen and Li, Bo and Zeng, Guangtao and Yang, Jingkang and Zhang, Yuanhan and Li, Chunyuan and Liu, Ziwei}, 118 | month={June}, 119 | year={2024} 120 | } 121 | ``` 122 | """ 123 | 124 | PARENT_FOLDER = os.path.dirname(os.path.abspath(__file__)) 125 | ################## BACKEND ################## 126 | os.environ["GRADIO_EXAMPLES_CACHE"] = ( 127 | f"{PARENT_FOLDER}/cache" 128 | ) 129 | os.environ["GRADIO_TEMP_DIR"] = ( 130 | f"{PARENT_FOLDER}/cache" 131 | ) 132 | 133 | def generate_file_hash(file_path): 134 | sha256_hash = hashlib.sha256() 135 | with open(file_path, "rb") as f: 136 | for byte_block in iter(lambda: f.read(4096), b""): 137 | sha256_hash.update(byte_block) 138 | return sha256_hash.hexdigest()[:6] -------------------------------------------------------------------------------- /local_demo/debug_json_api.py: -------------------------------------------------------------------------------- 1 | from gradio_client import Client, handle_file 2 | 3 | client = Client("http://127.0.0.1:8000/") 4 | 5 | # state = [ 6 | # (('/mnt/sfs-common/peiyuan/peiyuan/LongVa/local_demo/assets/otter_books.jpg',), None), 7 | # ("What does this image show?", None) 8 | # ] 9 | state = [ ("What's the video about?", None)] 10 | 11 | 12 | # result = client.predict( 13 | # history=[], 14 | # message={'text': "What's the video about?", 'files': []}, 15 | # video_input="/mnt/sfs-common/peiyuan/peiyuan/LongVa/local_demo/cache/0333f1f38db26d76752fee6cd2d938f06617dd2f/dc_demo.mp4", 16 | # api_name="/add_message" 17 | # ) 18 | 19 | # result = client.predict( 20 | # video_input=None, 21 | # state=state, 22 | # sample_frames=16, 23 | # temperature=0.7, 24 | # max_new_tokens=1024, 25 | # top_p=1, 26 | # api_name="/bot_response_1" 27 | # ) 28 | # print(result) 29 | 30 | image_path = "local_demo/assets/assistant_logo.png" 31 | # image to base64 32 | import base64 33 | with open(image_path, "rb") as image_file: 34 | base64_image = 
base64.b64encode(image_file.read()).decode() 35 | 36 | input_json_url = [ 37 | { 38 | "role": "user", 39 | "content": [ 40 | {"type": "text", "text": "What’s in this image?"}, 41 | { 42 | "type": "image_url", 43 | "image_url": { 44 | "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", 45 | }, 46 | }, 47 | ], 48 | }, 49 | { 50 | "role": "assistant", 51 | "content": [ 52 | {"type": "text", "text": "The image shows a serene, open field with lush green grass and a few scattered trees or shrubs in the distance. The sky above is mostly clear with some wispy clouds, suggesting it's a bright and likely pleasant day. There's a light-colored wooden boardwalk or path that meanders through the field, inviting a peaceful walk along its length. The overall scene conveys a sense of tranquility and natural beauty."}, 53 | ] 54 | }, 55 | { 56 | "role": "user", 57 | "content": [ 58 | {"type": "text", "text": "Where is this place?"}, 59 | ] 60 | } 61 | ] 62 | 63 | input_json_base64 = [ 64 | { 65 | "role": "user", 66 | "content": [ 67 | {"type": "text", "text": "What’s in this image?"}, 68 | { 69 | "type": "image_url", 70 | "image_url": { 71 | "url": f"data:image/jpeg;base64,{base64_image}", 72 | }, 73 | }, 74 | ], 75 | } 76 | ] 77 | result = client.predict( 78 | input_json=input_json_base64, 79 | api_name="/base64_api" 80 | ) 81 | print(result) 82 | -------------------------------------------------------------------------------- /local_demo/theme_dropdown.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | from gradio.themes.utils import ThemeAsset 5 | 6 | 7 | def create_theme_dropdown(): 8 | import gradio as gr 9 | 10 | asset_path = pathlib.Path(__file__).parent / "themes" 11 | themes = [] 12 | for theme_asset in os.listdir(str(asset_path)): 13 | themes.append( 14 | (ThemeAsset(theme_asset), gr.Theme.load(str(asset_path / theme_asset))) 15 | ) 16 | 17 | def make_else_if(theme_asset): 18 | return f""" 19 | else if (theme == '{str(theme_asset[0].version)}') {{ 20 | var theme_css = `{theme_asset[1]._get_theme_css()}` 21 | }}""" 22 | 23 | head, tail = themes[0], themes[1:] 24 | if_statement = f""" 25 | if (theme == "{str(head[0].version)}") {{ 26 | var theme_css = `{head[1]._get_theme_css()}` 27 | }} {" ".join(make_else_if(t) for t in tail)} 28 | """ 29 | 30 | latest_to_oldest = sorted([t[0] for t in themes], key=lambda asset: asset.version)[ 31 | ::-1 32 | ] 33 | latest_to_oldest = [str(t.version) for t in latest_to_oldest] 34 | 35 | component = gr.Dropdown( 36 | choices=latest_to_oldest, 37 | value=latest_to_oldest[0], 38 | render=False, 39 | label="Select Version", 40 | ) 41 | 42 | return ( 43 | component, 44 | f""" 45 | (theme) => {{ 46 | if (!document.querySelector('.theme-css')) {{ 47 | var theme_elem = document.createElement('style'); 48 | theme_elem.classList.add('theme-css'); 49 | document.head.appendChild(theme_elem); 50 | }} else {{ 51 | var theme_elem = document.querySelector('.theme-css'); 52 | }} 53 | {if_statement} 54 | theme_elem.innerHTML = theme_css; 55 | }} 56 | """, 57 | ) -------------------------------------------------------------------------------- /longva/longva/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- 
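The `/base64_api` endpoint exercised by `local_demo/debug_json_api.py` above accepts an OpenAI-style message list in which images are inlined as data URLs. A minimal helper that packages a local image and a question into that shape — the server address and image path are placeholders, and the only API calls used are the `gradio_client` ones already shown in that script:

```python
import base64
from gradio_client import Client


def ask_about_image(image_path: str, question: str, server: str = "http://127.0.0.1:8000/"):
    """Send one image plus a question to the demo's /base64_api endpoint."""
    with open(image_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": question},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
            ],
        }
    ]
    return Client(server).predict(input_json=messages, api_name="/base64_api")


# Example (placeholder path): ask_about_image("local_demo/assets/user_example_05.jpg", "What is in this image?")
```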
/longva/longva/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | -------------------------------------------------------------------------------- /longva/longva/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | AVAILABLE_MODELS = { 4 | "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig", 5 | "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig", 6 | "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig", 7 | # Add other models as needed 8 | } 9 | 10 | for model_name, model_classes in AVAILABLE_MODELS.items(): 11 | try: 12 | exec(f"from .language_model.{model_name} import {model_classes}") 13 | except Exception as e: 14 | print(f"Failed to import {model_name} from longva.language_model.{model_name}") 15 | raise e 16 | -------------------------------------------------------------------------------- /longva/longva/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from longva import LlavaLlamaForCausalLM 12 | 13 | 14 | def apply_delta(base_model_path, target_model_path, delta_path): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] += bparam 33 | 34 | print("Saving target model") 35 | delta.save_pretrained(target_model_path) 36 | delta_tokenizer.save_pretrained(target_model_path) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--base-model-path", type=str, required=True) 42 | parser.add_argument("--target-model-path", type=str, required=True) 43 | parser.add_argument("--delta-path", type=str, required=True) 44 | 45 | args = parser.parse_args() 46 | 47 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 48 | -------------------------------------------------------------------------------- /longva/longva/model/consolidate.py: --------------------------------------------------------------------------------
1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from longva.model import * 11 | from longva.model.utils import auto_upgrade 12 | 13 | 14 | def consolidate_ckpt(src_path, dst_path): 15 | print("Loading model") 16 | auto_upgrade(src_path) 17 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 18 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 19 | src_model.save_pretrained(dst_path) 20 | src_tokenizer.save_pretrained(dst_path) 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--src", type=str, required=True) 26 | parser.add_argument("--dst", type=str, required=True) 27 | 28 | args = parser.parse_args() 29 | 30 | consolidate_ckpt(args.src, args.dst) 31 | -------------------------------------------------------------------------------- /longva/longva/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | 20 | from transformers import AutoConfig, AutoModelForCausalLM, MptConfig, MptForCausalLM, MptModel, GenerationConfig 21 | from longva.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 22 | 23 | 24 | class LlavaMptConfig(MptConfig): 25 | model_type = "llava_mpt" 26 | 27 | 28 | class LlavaMptModel(LlavaMetaModel, MptModel): 29 | config_class = LlavaMptConfig 30 | 31 | def __init__(self, config: MptConfig): 32 | config.hidden_size = config.d_model 33 | super(LlavaMptModel, self).__init__(config) 34 | 35 | def embed_tokens(self, x): 36 | return self.wte(x) 37 | 38 | 39 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): 40 | config_class = LlavaMptConfig 41 | supports_gradient_checkpointing = True 42 | 43 | def __init__(self, config): 44 | super(MptForCausalLM, self).__init__(config) 45 | 46 | config.model_type = "llava_mpt" 47 | config.rope_scaling = None 48 | self.generation_config = GenerationConfig( 49 | temperature=0.0, 50 | max_new_tokens=1024, 51 | do_sample=False, 52 | top_p=None, 53 | ) 54 | 55 | self.transformer = LlavaMptModel(config) 56 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 57 | 58 | # Initialize weights and apply final processing 59 | self.post_init() 60 | 61 | def get_model(self): 62 | return self.transformer 63 | 64 | def _set_gradient_checkpointing(self, module, value=False): 65 | if isinstance(module, LlavaMptModel): 66 | module.gradient_checkpointing = value 67 | 68 | def forward( 69 | self, 70 | input_ids: Optional[torch.LongTensor] = None, 71 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 72 | attention_mask: Optional[torch.Tensor] = None, 73 | inputs_embeds: Optional[torch.Tensor] = None, 74 | labels: Optional[torch.Tensor] = None, 75 | use_cache: Optional[bool] = None, 76 | output_attentions: Optional[bool] = None, 77 | output_hidden_states: Optional[bool] = None, 78 | return_dict: Optional[bool] = None, 79 | cache_position=None, 80 | images=None, 81 | ): 82 | 83 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 84 | 85 | return super().forward( 86 | input_ids, 87 | past_key_values=past_key_values, 88 | attention_mask=attention_mask, 89 | inputs_embeds=inputs_embeds, 90 | labels=labels, 91 | use_cache=use_cache, 92 | output_attentions=output_attentions, 93 | output_hidden_states=output_hidden_states, 94 | return_dict=return_dict, 95 | ) 96 | 97 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 98 | images = kwargs.pop("images", None) 99 | _inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 100 | _inputs["images"] = images 101 | return _inputs 102 | 103 | 104 | AutoConfig.register("llava_mpt", LlavaMptConfig) 105 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) 106 | -------------------------------------------------------------------------------- /longva/longva/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | 
from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from longva.model.utils import auto_upgrade 12 | 13 | 14 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /longva/longva/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 4 | 5 | # from .eva_clip.eva_clip_encoder import EvaClipVisionTower 6 | # from .dev_eva_clip.eva_vit import EvaViTWrapper 7 | 8 | 9 | def build_vision_tower(vision_tower_cfg, **kwargs): 10 | vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)) 11 | is_absolute_path_exists = os.path.exists(vision_tower) 12 | use_s2 = getattr(vision_tower_cfg, "s2", False) 13 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 14 | if use_s2: 15 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 16 | else: 17 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 18 | # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower(): 19 | # return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 20 | # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]: 21 | # return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 22 | 23 | raise ValueError(f"Unknown vision tower: {vision_tower}") 24 | 
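A minimal usage sketch for `build_vision_tower` above, assuming a bare namespace config. The checkpoint name and `mm_vision_select_layer=-2` mirror the training scripts in `longva/scripts/`, while the exact attribute set consumed by `CLIPVisionTower` lives in `clip_encoder.py` (not shown here), so the extra fields are assumptions.

```python
from types import SimpleNamespace

from longva.model.multimodal_encoder.builder import build_vision_tower

# Hypothetical minimal config; only fields read by the builder / CLIP tower are set.
vision_cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14-336",
    mm_vision_select_layer=-2,         # matches scripts/finetune.sh
    mm_vision_select_feature="patch",  # assumed default
    s2=False,
)

vision_tower = build_vision_tower(vision_cfg)  # downloads and wraps the CLIP vision encoder
print(type(vision_tower).__name__)             # expected: CLIPVisionTower
```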
-------------------------------------------------------------------------------- /longva/longva/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | from .pooler_projector import PoolerProjector 6 | 7 | 8 | class IdentityMap(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, *args, **kwargs): 13 | return x 14 | 15 | @property 16 | def config(self): 17 | return {"mm_projector_type": "identity"} 18 | 19 | 20 | class SimpleResBlock(nn.Module): 21 | def __init__(self, channels): 22 | super().__init__() 23 | self.pre_norm = nn.LayerNorm(channels) 24 | 25 | self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)) 26 | 27 | def forward(self, x): 28 | x = self.pre_norm(x) 29 | return x + self.proj(x) 30 | 31 | 32 | def build_vision_projector(config, delay_load=False, **kwargs): 33 | projector_type = getattr(config, "mm_projector_type", "linear") 34 | 35 | if projector_type == "linear": 36 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 37 | 38 | if projector_type == "pooler": 39 | return PoolerProjector(config, kwargs["vision_cfg"]) 40 | 41 | mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) 42 | if mlp_gelu_match: 43 | mlp_depth = int(mlp_gelu_match.group(1)) 44 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 45 | for _ in range(1, mlp_depth): 46 | modules.append(nn.GELU()) 47 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 48 | return nn.Sequential(*modules) 49 | 50 | mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type) 51 | if mlp_gelu_resnet_match: 52 | mlp_depth = int(mlp_gelu_resnet_match.group(1)) 53 | res_depth = int(mlp_gelu_resnet_match.group(2)) 54 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 55 | for _ in range(1, mlp_depth): 56 | modules.append(nn.GELU()) 57 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 58 | for _ in range(res_depth): 59 | modules.append(SimpleResBlock(config.hidden_size)) 60 | return nn.Sequential(*modules) 61 | 62 | if projector_type == "identity": 63 | return IdentityMap() 64 | 65 | raise ValueError(f"Unknown projector type: {projector_type}") 66 | -------------------------------------------------------------------------------- /longva/longva/model/multimodal_projector/pooler_projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import math 5 | 6 | from transformers.models.clip.modeling_clip import CLIPVisionModel 7 | 8 | 9 | class PoolerProjector(nn.Module): 10 | def __init__(self, config, vision_cfg): 11 | super().__init__() 12 | self._config = config 13 | self.hw = vision_cfg.image_size // vision_cfg.patch_size 14 | 15 | self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2) 16 | 17 | self.proj = nn.Sequential( 18 | nn.GELU(), 19 | nn.Linear(config.hidden_size, config.hidden_size), 20 | ) 21 | 22 | def forward(self, x, *args, **kwargs): 23 | height = width = self.hw 24 | assert height * width == x.shape[1] 25 | x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) 26 | x = self.conv_pool(x) 27 | x = x.flatten(2).transpose(1, 2) 28 | x = self.proj(x) 29 | return x 30 | 31 | @property 32 | def config(self): 33 | return {"mm_projector_type": "pooler"} 34 | 
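A short sketch of the naming convention that `build_vision_projector` above parses: `mlp{N}x_gelu` expands to an N-layer GELU MLP mapping the vision feature width onto the language-model width. The dimensions below are hypothetical placeholders rather than values taken from a released checkpoint.

```python
from types import SimpleNamespace

from longva.model.multimodal_projector.builder import build_vision_projector

# Hypothetical dims: mm_hidden_size = vision feature width, hidden_size = LM width.
cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=3584)

proj = build_vision_projector(cfg)
# "mlp2x_gelu" -> Sequential(Linear(1024, 3584), GELU(), Linear(3584, 3584))
print(proj)
```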
-------------------------------------------------------------------------------- /longva/longva/model/multimodal_resampler/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .masked_drop import MaskedDrop 4 | from .spatial_pool import SpatialPool 5 | from .perceiver import PerceiverResampler 6 | from .qformer import Qformer 7 | 8 | 9 | class IdentityMap(torch.nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, x, *args, **kwargs): 14 | return x 15 | 16 | @property 17 | def config(self): 18 | return {"mm_resampler_type": None} 19 | 20 | 21 | def build_vision_resampler(model_args, delay_load=False, **kwargs): 22 | resampler_type = getattr(model_args, "mm_resampler_type", None) 23 | if resampler_type == "masked_drop": 24 | return MaskedDrop(model_args) 25 | elif resampler_type == "spatial_pool": 26 | return SpatialPool(model_args, **kwargs) 27 | elif resampler_type == "perceiver": 28 | return PerceiverResampler(model_args, **kwargs) 29 | elif resampler_type == "qformer": 30 | return Qformer(model_args, **kwargs) 31 | elif resampler_type is None: 32 | return IdentityMap() 33 | 34 | raise ValueError(f"Unknown resampler type: {resampler_type}") 35 | -------------------------------------------------------------------------------- /longva/longva/model/multimodal_resampler/masked_drop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import random 5 | 6 | 7 | class MaskedDrop(nn.Module): 8 | def __init__(self, model_args): 9 | super().__init__() 10 | 11 | self.mode = model_args.mm_mask_drop_mode 12 | self.skip_percentage = model_args.mm_mask_drop_skip_percentage 13 | self.ratio = model_args.mm_mask_drop_ratio 14 | self.ratio_upper = model_args.mm_mask_drop_ratio_upper 15 | self.ratio_lower = model_args.mm_mask_drop_ratio_lower 16 | 17 | def forward(self, image_features, *args, **kwargs): 18 | 19 | if not self.training: 20 | return image_features 21 | 22 | if self.skip_percentage > random.random(): 23 | return image_features 24 | 25 | masked_features = [] 26 | 27 | for image_feature in image_features: 28 | num_tokens = image_feature.shape[0] 29 | if self.mode == "fixed": 30 | num_keep = int(num_tokens * self.ratio) 31 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0]) 32 | elif self.mode == "range": 33 | num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper)) 34 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0]) 35 | elif self.mode == "cls_only": 36 | masked_features.append(image_feature[0:1]) 37 | else: 38 | raise ValueError(f"Unexpected masked drop mode: {self.mode}") 39 | 40 | if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]): 41 | masked_features = torch.stack(masked_features, dim=0) 42 | 43 | return masked_features 44 | 45 | @property 46 | def config(self): 47 | return { 48 | "mm_resampler_type": "masked_drop", 49 | "mm_mask_drop_mode": self.mode, 50 | "mm_mask_drop_skip_percentage": self.skip_percentage, 51 | "mm_mask_drop_ratio": self.ratio, 52 | "mm_mask_drop_ratio_upper": self.ratio_upper, 53 | "mm_mask_drop_ratio_lower": self.ratio_lower, 54 | } 55 | 56 | def random_masking(self, x, len_keep): 57 | """ 58 | Perform per-sample random masking by per-sample shuffling. 59 | Per-sample shuffling is done by argsort random noise. 
60 | x: [N, L, D], sequence 61 | """ 62 | N, L, D = x.shape # batch, length, dim 63 | 64 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 65 | 66 | # sort noise for each sample 67 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 68 | ids_restore = torch.argsort(ids_shuffle, dim=1) 69 | 70 | # keep the first subset 71 | ids_keep = ids_shuffle[:, :len_keep] 72 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 73 | 74 | # generate the binary mask: 0 is keep, 1 is remove 75 | mask = torch.ones([N, L], device=x.device) 76 | mask[:, :len_keep] = 0 77 | # unshuffle to get the binary mask 78 | mask = torch.gather(mask, dim=1, index=ids_restore) 79 | 80 | return x_masked, mask, ids_restore 81 | -------------------------------------------------------------------------------- /longva/longva/model/multimodal_resampler/perceiver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from https://github.com/lucidrains/flamingo-pytorch 3 | """ 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | 8 | try: 9 | from einops_exts import rearrange_many 10 | except: 11 | pass 12 | 13 | from torch import einsum, nn 14 | 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | 20 | def FeedForward(dim, mult=4): 21 | inner_dim = int(dim * mult) 22 | return nn.Sequential( 23 | nn.LayerNorm(dim), 24 | nn.Linear(dim, inner_dim, bias=False), 25 | nn.GELU(), 26 | nn.Linear(inner_dim, dim, bias=False), 27 | ) 28 | 29 | 30 | class PerceiverAttention(nn.Module): 31 | def __init__(self, *, dim, dim_head=64, heads=8): 32 | super().__init__() 33 | self.scale = dim_head**-0.5 34 | self.heads = heads 35 | inner_dim = dim_head * heads 36 | 37 | self.norm_media = nn.LayerNorm(dim) 38 | self.norm_latents = nn.LayerNorm(dim) 39 | 40 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 41 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 42 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 43 | 44 | def forward(self, x, latents): 45 | """ 46 | Args: 47 | x (torch.Tensor): image features 48 | shape (b, T, n1, D) 49 | latent (torch.Tensor): latent features 50 | shape (b, T, n2, D) 51 | """ 52 | x = self.norm_media(x) 53 | latents = self.norm_latents(latents) 54 | 55 | h = self.heads 56 | 57 | q = self.to_q(latents) 58 | kv_input = torch.cat((x, latents), dim=-2) 59 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 60 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 61 | q = q * self.scale 62 | 63 | # attention 64 | sim = einsum("... i d, ... j d -> ... i j", q, k) 65 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 66 | attn = sim.softmax(dim=-1) 67 | 68 | out = einsum("... i j, ... j d -> ... 
i d", attn, v) 69 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 70 | return self.to_out(out) 71 | 72 | 73 | class PerceiverResamplerModule(nn.Module): 74 | def __init__( 75 | self, 76 | *, 77 | dim, 78 | depth=6, 79 | dim_head=64, 80 | heads=8, 81 | num_latents=64, 82 | max_num_media=None, 83 | max_num_frames=None, 84 | ff_mult=4, 85 | ): 86 | super().__init__() 87 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 88 | self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None 89 | self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None 90 | 91 | self.layers = nn.ModuleList([]) 92 | for _ in range(depth): 93 | self.layers.append( 94 | nn.ModuleList( 95 | [ 96 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 97 | FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(), 98 | ] 99 | ) 100 | ) 101 | 102 | self.norm = nn.LayerNorm(dim) 103 | 104 | def forward(self, x): 105 | """ 106 | Args: 107 | x (torch.Tensor): image features 108 | shape (b, T, F, v, D) 109 | Returns: 110 | shape (b, T, n, D) where n is self.num_latents 111 | """ 112 | b, T, F, v = x.shape[:4] 113 | 114 | # frame and media time embeddings 115 | if exists(self.frame_embs): 116 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 117 | x = x + frame_embs 118 | x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions 119 | if exists(self.media_time_embs): 120 | x = x + self.media_time_embs[:T] 121 | 122 | # blocks 123 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 124 | for attn, ff in self.layers: 125 | latents = attn(x, latents) + latents 126 | latents = ff(latents) + latents 127 | return self.norm(latents) 128 | 129 | 130 | class PerceiverResampler(nn.Module): 131 | def __init__(self, model_args, vision_tower): 132 | super().__init__() 133 | 134 | self.depth = model_args.mm_perceiver_depth 135 | self.num_latents = model_args.mm_perceiver_latents 136 | self.ff_mult = model_args.mm_perceiver_ff_mult 137 | self.pretrained = model_args.mm_perceiver_pretrained 138 | 139 | self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult) 140 | 141 | if self.pretrained is not None: 142 | self.load_state_dict(torch.load(self.pretrained)) 143 | 144 | def forward(self, image_features, *args, **kwargs): 145 | return self.perceiver(image_features[:, None, None]).squeeze(1) 146 | 147 | @property 148 | def config(self): 149 | return { 150 | "mm_resampler_type": "perceiver", 151 | "mm_perceiver_depth": self.depth, 152 | "mm_perceiver_latents": self.num_latents, 153 | "mm_perceiver_ff_mult": self.ff_mult, 154 | "mm_perceiver_pretrained": self.pretrained, 155 | } 156 | -------------------------------------------------------------------------------- /longva/longva/model/multimodal_resampler/spatial_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class SpatialPool(nn.Module): 7 | def __init__(self, model_args, vision_tower): 8 | super().__init__() 9 | 10 | self.mode = model_args.mm_spatial_pool_mode 11 | self.stride = model_args.mm_spatial_pool_stride 12 | self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size) 13 | 14 | if self.mode == "average": 15 | self.pool = nn.AvgPool2d(kernel_size=self.stride, 
stride=self.stride) 16 | elif self.mode == "max": 17 | self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride) 18 | elif self.mode == "conv": 19 | self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride) 20 | else: 21 | raise ValueError(f"Unknown pooling mode: {self.mode}.") 22 | 23 | def forward(self, image_features, images, *args, **kwargs): 24 | ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2])) 25 | ori_H = int(ori_W * images.shape[2] // images.shape[3]) 26 | 27 | B, _, F = image_features.shape 28 | 29 | image_features_spatial = image_features.view(B, ori_H, ori_W, F).permute(0, 3, 1, 2) 30 | image_features_spatial_pool = self.pool(image_features_spatial) 31 | 32 | return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() 33 | 34 | @property 35 | def config(self): 36 | return { 37 | "mm_resampler_type": "spatial_pool", 38 | "mm_spatial_pool_stride": self.stride, 39 | "mm_spatial_pool_mode": self.mode, 40 | "mm_spatial_pool_out_channels": self.out_channels, 41 | } 42 | 43 | @property 44 | def hidden_size(self): 45 | return self.out_channels 46 | -------------------------------------------------------------------------------- /longva/longva/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if "llava" in config and "llava" not in cfg.model_type: 7 | assert cfg.model_type == "llama" 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = "LlavaLlamaForCausalLM" 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /longva/longva/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | padding_mask: Optional[torch.Tensor] = None, 25 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 26 | if output_attentions: 27 | warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.") 28 | 29 | bsz, q_len, _ = hidden_states.size() 30 | 31 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 32 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 33 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # shape: (b, num_heads, s, head_dim) 34 | 35 | kv_seq_len = key_states.shape[-2] 36 | if past_key_value is not None: 37 | kv_seq_len += past_key_value[0].shape[-2] 38 | 39 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 40 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 41 | 42 | if past_key_value is not None: 43 | # reuse k, v 44 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 45 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 46 | 47 | past_key_value = (key_states, value_states) if use_cache else None 48 | 49 | # repeat k/v heads if n_kv_heads < n_heads 50 | key_states = repeat_kv(key_states, self.num_key_value_groups) 51 | value_states = repeat_kv(value_states, self.num_key_value_groups) 52 | 53 | # Transform the data into the format required by flash attention 54 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 55 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 56 | key_padding_mask = attention_mask 57 | 58 | if key_padding_mask is None: 59 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 60 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) 61 | max_s = q_len 62 | output = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 63 | output = output.view(bsz, q_len, -1) 64 | else: 65 | qkv = qkv.reshape(bsz, q_len, -1) 
66 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 67 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 68 | output_unpad = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 69 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 70 | output = pad_input(output_unpad, indices, bsz, q_len) 71 | 72 | return self.o_proj(output), None, past_key_value 73 | 74 | 75 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 76 | # requires the attention mask to be the same as the key_padding_mask 77 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 78 | # [bsz, seq_len] 79 | return attention_mask 80 | 81 | 82 | def replace_llama_attn_with_flash_attn(): 83 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 84 | if cuda_major < 8: 85 | warnings.warn("Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593") 86 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 87 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 88 | -------------------------------------------------------------------------------- /longva/longva/train/train_mem.py: -------------------------------------------------------------------------------- 1 | from longva.train.train import train 2 | 3 | if __name__ == "__main__": 4 | train() 5 | -------------------------------------------------------------------------------- /longva/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 240 3 | 4 | [build-system] 5 | requires = ["setuptools>=61.0"] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [project] 9 | name = "longva" 10 | version = "1.7.0.dev0" 11 | description = "Towards GPT-4 like large language and visual assistant." 
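# Assumed install flow (not spelled out in this file): `pip install -e .` from the longva/
# directory, or `pip install -e ".[train]"` to also pull in the optional training extras below.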
12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: Apache Software License", 17 | ] 18 | 19 | [project.optional-dependencies] 20 | standalone = [ 21 | "shortuuid", 22 | "httpx==0.24.0", 23 | "einops", 24 | "ftfy", 25 | ] 26 | 27 | train = [ 28 | "longva[standalone]", 29 | "open_clip_torch", 30 | "fastapi", 31 | "gradio", 32 | "markdown2[all]", 33 | "numpy", 34 | "requests", 35 | "sentencepiece", 36 | "uvicorn", 37 | "wandb==0.16.5", 38 | "deepspeed==0.14.2", 39 | "torch==2.1.2", 40 | "torchvision==0.16.2", 41 | "peft==0.4.0", 42 | "accelerate>=0.29.1", 43 | "tokenizers~=0.15.2", 44 | "transformers@git+https://github.com/huggingface/transformers.git@1c39974a4c4036fd641bc1191cc32799f85715a4", 45 | "bitsandbytes==0.41.0", 46 | "scikit-learn==1.2.2", 47 | "sentencepiece~=0.1.99", 48 | "einops==0.6.1", 49 | "einops-exts==0.0.4", 50 | "gradio_client", 51 | "urllib3<=2.0.0", 52 | "datasets==2.16.1", 53 | "pydantic==1.10.8", 54 | "timm", 55 | "hf_transfer", 56 | "opencv-python", 57 | "av", 58 | "decord", 59 | "tyro", 60 | "scipy", 61 | ] 62 | 63 | 64 | [tool.setuptools.packages.find] 65 | include = ["longva*", "trl*"] 66 | exclude = [ 67 | "assets*", 68 | "benchmark*", 69 | "docs", 70 | "dist*", 71 | "playground*", 72 | "scripts*", 73 | "tests*", 74 | "checkpoints*", 75 | "project_checkpoints*", 76 | "debug_checkpoints*", 77 | "mlx_configs*", 78 | "wandb*", 79 | "notebooks*", 80 | ] 81 | 82 | [tool.wheel] 83 | exclude = [ 84 | "assets*", 85 | "benchmark*", 86 | "docs", 87 | "dist*", 88 | "playground*", 89 | "scripts*", 90 | "tests*", 91 | "checkpoints*", 92 | "project_checkpoints*", 93 | "debug_checkpoints*", 94 | "mlx_configs*", 95 | "wandb*", 96 | "notebooks*", 97 | ] -------------------------------------------------------------------------------- /longva/scripts/dpo.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=8 2 | export NCCL_IB_DISABLE=0 3 | export NCCL_IB_GID_INDEX=3 4 | # export NCCL_IB_HCA=${ARNOLD_RDMA_DEVICE} 5 | export NCCL_SOCKET_IFNAME=eth0 6 | export NCCL_DEBUG=INFO 7 | 8 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 9 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 10 | 11 | ############### Pretrain ################ 12 | 13 | # Stage 2 14 | PROMPT_VERSION="qwen_1_5" 15 | 16 | #torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" --nnodes="${ARNOLD_WORKER_NUM}" --node_rank="${ARNOLD_ID}" --master_addr="${METIS_WORKER_0_HOST}" --master_port="${port_in_cmd}" \ 17 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" --nnodes="${ARNOLD_WORKER_NUM}" --node_rank="${ARNOLD_ID}" --master_addr="${METIS_WORKER_0_HOST}" --master_port="${port_in_cmd}" \ 18 | longva/train/train_dpo.py \ 19 | --deepspeed scripts/zero3.json \ 20 | --model_name_or_path lmms-lab/LongVA-7B \ 21 | --version $PROMPT_VERSION \ 22 | --dpo_alpha 1.0 --beta 0.1 --gamma 0 \ 23 | --data_path="/data/llava_video/shareVideoGPTV/dpo/sft_dpo_17k.jsonl" \ 24 | --image_folder /data/llava_data \ 25 | --video_folder /llava_video/shareVideoGPTV/frames/all_frames/ \ 26 | --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ 27 | --vision_tower ${VISION_MODEL_VERSION} \ 28 | --mm_projector_type mlp2x_gelu \ 29 | --mm_vision_select_layer -2 \ 30 | --mm_use_im_start_end False \ 31 | --mm_use_im_patch_token False \ 32 | --mm_spatial_pool_stride 2 \ 33 | --mm_resampler_type "spatial_pool" \ 34 | 
--mm_spatial_pool_out_channels 1024 \ 35 | --group_by_modality_length True \ 36 | --image_aspect_ratio anyres \ 37 | --image_grid_pinpoints "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]" \ 38 | --mm_patch_merge_type unires \ 39 | --bf16 True \ 40 | --run_name $MID_RUN_NAME \ 41 | --output_dir "/checkpoints/${MID_RUN_NAME}" \ 42 | --num_train_epochs 3 \ 43 | --per_device_train_batch_size 1 \ 44 | --per_device_eval_batch_size 4 \ 45 | --gradient_accumulation_steps 16 \ 46 | --evaluation_strategy "no" \ 47 | --save_strategy "steps" \ 48 | --save_steps 3000 \ 49 | --save_total_limit 1 \ 50 | --learning_rate 5e-7 \ 51 | --weight_decay 0. \ 52 | --warmup_ratio 0.1 \ 53 | --lr_scheduler_type "linear" \ 54 | --logging_steps 1 \ 55 | --tf32 True \ 56 | --model_max_length 32768 \ 57 | --gradient_checkpointing True \ 58 | --dataloader_num_workers 16 \ 59 | --lazy_preprocess True \ 60 | --report_to wandb \ 61 | --torch_compile True \ 62 | --torch_compile_backend "inductor" \ 63 | --dataloader_drop_last True \ 64 | --attn_implementation sdpa -------------------------------------------------------------------------------- /longva/scripts/finetune.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=8 2 | export NCCL_IB_DISABLE=0 3 | export NCCL_IB_GID_INDEX=3 4 | export NCCL_SOCKET_IFNAME=eth0 5 | export NCCL_DEBUG=INFO 6 | 7 | LLM_VERSION="lmms-lab/Qwen2-7B-Instruct-224K" 8 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 9 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 10 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 11 | 12 | ############### Pretrain ################ 13 | 14 | PROMPT_VERSION=qwen_1_5 15 | 16 | BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain" 17 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 18 | 19 | CKPT_PATH=$LLM_VERSION # this could also be the previous stage checkpoint 20 | 21 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ 22 | longva/train/train_mem.py \ 23 | --deepspeed scripts/zero3.json \ 24 | --model_name_or_path ${CKPT_PATH} \ 25 | --version ${PROMPT_VERSION} \ 26 | --data_path=llava_1_6.json \ 27 | --image_folder your_image_folder \ 28 | --pretrain_mm_mlp_adapter="/checkpoints/projectors/${BASE_RUN_NAME}/mm_projector.bin" \ 29 | --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ 30 | --mm_vision_tower_lr=2e-6 \ 31 | --vision_tower ${VISION_MODEL_VERSION} \ 32 | --mm_projector_type mlp2x_gelu \ 33 | --mm_vision_select_layer -2 \ 34 | --mm_use_im_start_end False \ 35 | --mm_use_im_patch_token False \ 36 | --group_by_modality_length True \ 37 | --image_aspect_ratio anyres \ 38 | --image_grid_pinpoints "[(336, 672), (336, 1008), (336, 1344), (336, 1680), (336, 2016), (336, 2352), (336, 2688), (336, 3024), (336, 3360), (336, 3696), (336, 4032), (336, 4368), (336, 4704), (336, 5040), (336, 5376), (336, 5712), (336, 6048), (336, 6384), (336, 6720), (336, 7056), (336, 7392), (336, 7728), (336, 8064), (336, 8400), (336, 8736), (336, 9072), (336, 9408), (336, 9744), (336, 10080), (336, 10416), (336, 10752), (336, 11088), (336, 11424), (336, 11760), (336, 12096), (336, 12432), (336, 12768), (336, 13104), (336, 13440), (336, 13776), (336, 14112), (336, 14448), (336, 14784), (336, 15120), (336, 15456), (336, 15792), (336, 16128), (336, 16464), (672, 336), (672, 672), (672, 1008), (672, 1344), (672, 
1680), (672, 2016), (672, 2352), (672, 2688), (672, 3024), (672, 3360), (672, 3696), (672, 4032), (672, 4368), (672, 4704), (672, 5040), (672, 5376), (672, 5712), (672, 6048), (672, 6384), (672, 6720), (672, 7056), (672, 7392), (672, 7728), (672, 8064), (1008, 336), (1008, 672), (1008, 1008), (1008, 1344), (1008, 1680), (1008, 2016), (1008, 2352), (1008, 2688), (1008, 3024), (1008, 3360), (1008, 3696), (1008, 4032), (1008, 4368), (1008, 4704), (1008, 5040), (1008, 5376), (1344, 336), (1344, 672), (1344, 1008), (1344, 1344), (1344, 1680), (1344, 2016), (1344, 2352), (1344, 2688), (1344, 3024), (1344, 3360), (1344, 3696), (1344, 4032), (1680, 336), (1680, 672), (1680, 1008), (1680, 1344), (1680, 1680), (1680, 2016), (1680, 2352), (1680, 2688), (1680, 3024), (2016, 336), (2016, 672), (2016, 1008), (2016, 1344), (2016, 1680), (2016, 2016), (2016, 2352), (2016, 2688), (2352, 336), (2352, 672), (2352, 1008), (2352, 1344), (2352, 1680), (2352, 2016), (2352, 2352), (2688, 336), (2688, 672), (2688, 1008), (2688, 1344), (2688, 1680), (2688, 2016), (3024, 336), (3024, 672), (3024, 1008), (3024, 1344), (3024, 1680), (3360, 336), (3360, 672), (3360, 1008), (3360, 1344), (3696, 336), (3696, 672), (3696, 1008), (3696, 1344), (4032, 336), (4032, 672), (4032, 1008), (4032, 1344), (4368, 336), (4368, 672), (4368, 1008), (4704, 336), (4704, 672), (4704, 1008), (5040, 336), (5040, 672), (5040, 1008), (5376, 336), (5376, 672), (5376, 1008), (5712, 336), (5712, 672), (6048, 336), (6048, 672), (6384, 336), (6384, 672), (6720, 336), (6720, 672), (7056, 336), (7056, 672), (7392, 336), (7392, 672), (7728, 336), (7728, 672), (8064, 336), (8064, 672), (8400, 336), (8736, 336), (9072, 336), (9408, 336), (9744, 336), (10080, 336), (10416, 336), (10752, 336), (11088, 336), (11424, 336), (11760, 336), (12096, 336), (12432, 336), (12768, 336), (13104, 336), (13440, 336), (13776, 336), (14112, 336), (14448, 336), (14784, 336), (15120, 336), (15456, 336), (15792, 336), (16128, 336), (16464, 336)]" \ 39 | --mm_patch_merge_type unires \ 40 | --bf16 True \ 41 | --run_name $MID_RUN_NAME \ 42 | --output_dir "/checkpoints/${MID_RUN_NAME}" \ 43 | --num_train_epochs 1 \ 44 | --per_device_train_batch_size 4 \ 45 | --per_device_eval_batch_size 4 \ 46 | --gradient_accumulation_steps 1 \ 47 | --evaluation_strategy "no" \ 48 | --save_strategy "steps" \ 49 | --save_steps 3000 \ 50 | --save_total_limit 1 \ 51 | --learning_rate 1e-5 \ 52 | --weight_decay 0. 
\ 53 | --warmup_ratio 0.03 \ 54 | --lr_scheduler_type "cosine" \ 55 | --logging_steps 1 \ 56 | --tf32 True \ 57 | --model_max_length 32768 \ 58 | --gradient_checkpointing True \ 59 | --dataloader_num_workers 16 \ 60 | --lazy_preprocess True \ 61 | --report_to wandb \ 62 | --torch_compile True \ 63 | --torch_compile_backend "inductor" \ 64 | --dataloader_drop_last True \ 65 | --attn_implementation sdpa 66 | 67 | # You can delete the sdpa attn_implementation if you want to use flash attn 68 | -------------------------------------------------------------------------------- /longva/scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=8 2 | export NCCL_IB_DISABLE=0 3 | export NCCL_IB_GID_INDEX=3 4 | export NCCL_SOCKET_IFNAME=eth0 5 | export NCCL_DEBUG=INFO 6 | 7 | LLM_VERSION="Qwen/Qwen2-7B-Instruct" 8 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 9 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 10 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 11 | 12 | ############### Pretrain ################ 13 | 14 | PROMPT_VERSION=plain 15 | 16 | BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain" 17 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 18 | 19 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ 20 | longva/train/train_mem.py \ 21 | --deepspeed scripts/zero3.json \ 22 | --model_name_or_path ${LLM_VERSION} \ 23 | --version ${PROMPT_VERSION} \ 24 | --data_path /blip_558k/blip_558k_plain.json \ 25 | --image_folder /blip_558k/images \ 26 | --vision_tower ${VISION_MODEL_VERSION} \ 27 | --mm_tunable_parts="mm_mlp_adapter" \ 28 | --mm_vision_select_layer -2 \ 29 | --mm_projector_type mlp2x_gelu \ 30 | --mm_use_im_start_end False \ 31 | --mm_use_im_patch_token False \ 32 | --bf16 True \ 33 | --output_dir /checkpoints/projectors/${BASE_RUN_NAME} \ 34 | --num_train_epochs 1 \ 35 | --per_device_train_batch_size 16 \ 36 | --per_device_eval_batch_size 4 \ 37 | --gradient_accumulation_steps 1 \ 38 | --evaluation_strategy "no" \ 39 | --save_strategy "no" \ 40 | --save_steps 50000 \ 41 | --learning_rate 1e-3 \ 42 | --weight_decay 0. 
\ 43 | --warmup_ratio 0.03 \ 44 | --lr_scheduler_type "cosine" \ 45 | --logging_steps 1 \ 46 | --tf32 True \ 47 | --model_max_length 8192 \ 48 | --gradient_checkpointing True \ 49 | --dataloader_num_workers 16 \ 50 | --lazy_preprocess True \ 51 | --report_to wandb \ 52 | --run_name $BASE_RUN_NAME \ 53 | --attn_implementation sdpa 54 | 55 | # You can delete the sdpa attn_implementation if you want to use flash attn -------------------------------------------------------------------------------- /longva/scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "none", 18 | "pin_memory": true 19 | }, 20 | "offload_param": { 21 | "device": "none", 22 | "pin_memory": true 23 | }, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /longva/trl/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | __version__ = "0.7.11.dev0" 4 | 5 | from .core import set_seed 6 | from .environment import TextEnvironment, TextHistory 7 | from .extras import BestOfNSampler 8 | from .import_utils import ( 9 | is_bitsandbytes_available, 10 | is_diffusers_available, 11 | is_npu_available, 12 | is_peft_available, 13 | is_wandb_available, 14 | is_xpu_available, 15 | ) 16 | from .models import ( 17 | AutoModelForCausalLMWithValueHead, 18 | AutoModelForSeq2SeqLMWithValueHead, 19 | PreTrainedModelWrapper, 20 | create_reference_model, 21 | setup_chat_format, 22 | ) 23 | from .trainer import ( 24 | DataCollatorForCompletionOnlyLM, 25 | DPOTrainer, 26 | IterativeSFTTrainer, 27 | ModelConfig, 28 | PPOConfig, 29 | PPOTrainer, 30 | RewardConfig, 31 | RewardTrainer, 32 | SFTTrainer, 33 | ) 34 | from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config 35 | 36 | 37 | if is_diffusers_available(): 38 | from .models import ( 39 | DDPOPipelineOutput, 40 | DDPOSchedulerOutput, 41 | DDPOStableDiffusionPipeline, 42 | DefaultDDPOStableDiffusionPipeline, 43 | ) 44 | from .trainer import DDPOConfig, DDPOTrainer 45 | -------------------------------------------------------------------------------- /longva/trl/environment/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from .base_environment import TextEnvironment, TextHistory 4 | -------------------------------------------------------------------------------- /longva/trl/extras/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | from .best_of_n_sampler import BestOfNSampler 17 | -------------------------------------------------------------------------------- /longva/trl/extras/best_of_n_sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Optional, Union 2 | 3 | import torch 4 | from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast 5 | 6 | from ..core import set_seed 7 | from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper 8 | 9 | 10 | class BestOfNSampler(object): 11 | def __init__( 12 | self, 13 | model: PreTrainedModelWrapper, 14 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 15 | queries_to_scores: Callable[[List[str]], List[float]], 16 | length_sampler: Any, 17 | sample_size: int = 4, 18 | seed: Optional[int] = None, 19 | n_candidates: int = 1, 20 | generation_config: Optional[GenerationConfig] = None, 21 | ) -> None: 22 | r""" 23 | Initialize the sampler for best-of-n generation 24 | 25 | Args: 26 | model (`PreTrainedModelWrapper`): 27 | The pretrained model to use for generation 28 | tokenizer (`PreTrainedTokenizer` or `PreTrainedTokenizerFast`): 29 | Tokenizer associated with the pretrained model 30 | queries_to_scores (`Callable[[List[str]], List[float]]`): 31 | Callable that takes a list of generated texts and returns the associated reward scores 32 | length_sampler (`Any`): 33 | Sampler used to sample the length of the generated text 34 | sample_size (`int`): 35 | Number of samples to generate for each query 36 | seed (`int`, *optional*): 37 | Random seed used to control generation 38 | n_candidates (`int`): 39 | Number of candidates to return for each query 40 | generation_config (`GenerationConfig`, *optional*): 41 | Generation config passed to the underlying model's `generate` method. 
42 | See `GenerationConfig` (https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig) for more details 43 | """ 44 | if seed is not None: 45 | set_seed(seed) 46 | 47 | if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): 48 | raise ValueError(f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}") 49 | if not isinstance(model, (SUPPORTED_ARCHITECTURES)): 50 | raise ValueError(f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}") 51 | 52 | self.model = model 53 | self.tokenizer = tokenizer 54 | 55 | self.queries_to_scores = queries_to_scores 56 | self.length_sampler = length_sampler 57 | self.gen_config = generation_config 58 | self.sample_size = sample_size 59 | self.n_candidates = n_candidates 60 | 61 | def generate( 62 | self, 63 | tokenized_query: Union[List[int], torch.Tensor, List[torch.Tensor], List[List[int]]], 64 | skip_special_tokens: bool = True, 65 | device: Optional[Union[str, torch.device]] = None, 66 | **generation_kwargs, 67 | ) -> List[List[str]]: 68 | r""" 69 | Generate the best of n samples for input queries 70 | 71 | Args: 72 | tokenized_query (`List[int]` or `torch.Tensor` or `List[torch.Tensor]` or `List[int]`): 73 | represents either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers) 74 | skip_special_tokens (`bool`): 75 | Whether to remove the special tokens from the output 76 | device (`str` or `torch.device`, *optional*): 77 | The device on which the model will be loaded 78 | **generation_kwargs (`dict`, *optional*): 79 | Additional keyword arguments passed along to the underlying model's `generate` method. 
80 | This is used to override generation config 81 | 82 | Returns: 83 | List[List[str]]: A list of lists of generated texts 84 | """ 85 | queries = None 86 | 87 | if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1: 88 | queries = tokenized_query.unsqueeze(0) 89 | elif isinstance(tokenized_query, List): 90 | element_type = type(tokenized_query[0]) 91 | if element_type == int: 92 | queries = torch.tensor(tokenized_query).unsqueeze(0) 93 | elif element_type == torch.Tensor: 94 | queries = [tensor.reshape((1, -1)) for tensor in tokenized_query] 95 | else: 96 | queries = [torch.tensor(query).reshape((1, -1)) for query in tokenized_query] 97 | 98 | result = [] 99 | 100 | for query in queries: 101 | queries = query.repeat((self.sample_size, 1)) 102 | output = self.model.generate( 103 | queries.to(device), 104 | max_new_tokens=self.length_sampler(), 105 | generation_config=self.gen_config, 106 | **generation_kwargs, 107 | ).squeeze() 108 | output = self.tokenizer.batch_decode(output, skip_special_tokens=skip_special_tokens) 109 | scores = torch.tensor(self.queries_to_scores(output)) 110 | output = [output[i] for i in scores.topk(self.n_candidates).indices] 111 | result.append(output) 112 | 113 | return result 114 | -------------------------------------------------------------------------------- /longva/trl/extras/dataset_formatting.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Callable, Literal, Optional, Union 3 | 4 | from datasets import Dataset, Value 5 | from transformers import AutoTokenizer 6 | 7 | from ..trainer.utils import ConstantLengthDataset 8 | 9 | 10 | FORMAT_MAPPING = { 11 | "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}], 12 | "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)}, 13 | } 14 | 15 | 16 | def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]): 17 | r""" 18 | return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the tokenizer 19 | apply chat template to the dataset 20 | """ 21 | 22 | def format_dataset(examples): 23 | if isinstance(examples[messages_field][0], list): 24 | output_texts = [] 25 | for i in range(len(examples[messages_field])): 26 | output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False)) 27 | return output_texts 28 | else: 29 | return tokenizer.apply_chat_template(examples[messages_field], tokenize=False) 30 | 31 | return format_dataset 32 | 33 | 34 | def instructions_formatting_function(tokenizer: AutoTokenizer): 35 | r""" 36 | return a callable function that takes in an "instructions" dataset and returns a formatted dataset, based on the tokenizer 37 | apply chat template to the dataset 38 | """ 39 | 40 | def format_dataset(examples): 41 | if isinstance(examples["prompt"], list): 42 | output_texts = [] 43 | for i in range(len(examples["prompt"])): 44 | converted_sample = [ 45 | {"role": "user", "content": examples["prompt"][i]}, 46 | {"role": "assistant", "content": examples["completion"][i]}, 47 | ] 48 | output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False)) 49 | return output_texts 50 | else: 51 | converted_sample = [ 52 | {"role": "user", "content": examples["prompt"]}, 53 | {"role": "assistant", "content": examples["completion"]}, 54 | ] 55 | return 
tokenizer.apply_chat_template(converted_sample, tokenize=False) 56 | 57 | return format_dataset 58 | 59 | 60 | def get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer) -> Optional[Callable]: 61 | r""" 62 | Finds the correct formatting function based on the dataset structure. Currently supported datasets are: 63 | - `ChatML` with [{"role": str, "content": str}] 64 | - `instruction` with [{"prompt": str, "completion": str}] 65 | 66 | Args: 67 | dataset (Dataset): User dataset 68 | tokenizer (AutoTokenizer): Tokenizer used for formatting 69 | 70 | Returns: 71 | Callable: Formatting function if the dataset format is supported else None 72 | """ 73 | if isinstance(dataset, Dataset): 74 | if "messages" in dataset.features: 75 | if dataset.features["messages"] == FORMAT_MAPPING["chatml"]: 76 | logging.info("Formatting dataset with chatml format") 77 | return conversations_formatting_function(tokenizer, "messages") 78 | if "conversations" in dataset.features: 79 | if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]: 80 | logging.info("Formatting dataset with chatml format") 81 | return conversations_formatting_function(tokenizer, "conversations") 82 | elif dataset.features == FORMAT_MAPPING["instruction"]: 83 | logging.info("Formatting dataset with instruction format") 84 | return instructions_formatting_function(tokenizer) 85 | 86 | return None 87 | -------------------------------------------------------------------------------- /longva/trl/import_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import importlib 15 | import sys 16 | 17 | 18 | if sys.version_info < (3, 8): 19 | _is_python_greater_3_8 = False 20 | else: 21 | _is_python_greater_3_8 = True 22 | 23 | 24 | def is_peft_available() -> bool: 25 | return importlib.util.find_spec("peft") is not None 26 | 27 | 28 | def is_unsloth_available() -> bool: 29 | return importlib.util.find_spec("unsloth") is not None 30 | 31 | 32 | def is_accelerate_greater_20_0() -> bool: 33 | if _is_python_greater_3_8: 34 | from importlib.metadata import version 35 | 36 | accelerate_version = version("accelerate") 37 | else: 38 | import pkg_resources 39 | 40 | accelerate_version = pkg_resources.get_distribution("accelerate").version 41 | return accelerate_version >= "0.20.0" 42 | 43 | 44 | def is_transformers_greater_than(version: str) -> bool: 45 | _transformers_version = importlib.metadata.version("transformers") 46 | return _transformers_version > version 47 | 48 | 49 | def is_torch_greater_2_0() -> bool: 50 | if _is_python_greater_3_8: 51 | from importlib.metadata import version 52 | 53 | torch_version = version("torch") 54 | else: 55 | import pkg_resources 56 | 57 | torch_version = pkg_resources.get_distribution("torch").version 58 | return torch_version >= "2.0" 59 | 60 | 61 | def is_diffusers_available() -> bool: 62 | return importlib.util.find_spec("diffusers") is not None 63 | 64 | 65 | def is_bitsandbytes_available() -> bool: 66 | import torch 67 | 68 | # bnb can be imported without GPU but is not usable. 69 | return importlib.util.find_spec("bitsandbytes") is not None and torch.cuda.is_available() 70 | 71 | 72 | def is_torchvision_available() -> bool: 73 | return importlib.util.find_spec("torchvision") is not None 74 | 75 | 76 | def is_rich_available() -> bool: 77 | return importlib.util.find_spec("rich") is not None 78 | 79 | 80 | def is_wandb_available() -> bool: 81 | return importlib.util.find_spec("wandb") is not None 82 | 83 | 84 | def is_xpu_available() -> bool: 85 | if is_accelerate_greater_20_0(): 86 | import accelerate 87 | 88 | return accelerate.utils.is_xpu_available() 89 | else: 90 | if importlib.util.find_spec("intel_extension_for_pytorch") is None: 91 | return False 92 | try: 93 | import torch 94 | 95 | return hasattr(torch, "xpu") and torch.xpu.is_available() 96 | except RuntimeError: 97 | return False 98 | 99 | 100 | def is_npu_available() -> bool: 101 | """Checks if `torch_npu` is installed and potentially if a NPU is in the environment""" 102 | if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None: 103 | return False 104 | 105 | import torch 106 | import torch_npu # noqa: F401 107 | 108 | return hasattr(torch, "npu") and torch.npu.is_available() 109 | -------------------------------------------------------------------------------- /longva/trl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
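# A minimal sketch of how the availability helpers above are typically consumed:
# guard optional dependencies so heavy imports only happen when the package is
# installed. The import path `longva.trl` and the LoRA hyperparameters are assumptions.
from longva.trl.import_utils import is_peft_available

if is_peft_available():
    from peft import LoraConfig

    peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05)
else:
    peft_config = None  # fall back to full-parameter training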
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | from .modeling_base import PreTrainedModelWrapper, create_reference_model 17 | from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead 18 | from .utils import setup_chat_format 19 | 20 | 21 | SUPPORTED_ARCHITECTURES = ( 22 | AutoModelForCausalLMWithValueHead, 23 | AutoModelForSeq2SeqLMWithValueHead, 24 | ) 25 | 26 | from ..import_utils import is_diffusers_available 27 | 28 | 29 | if is_diffusers_available(): 30 | from .modeling_sd_base import ( 31 | DDPOPipelineOutput, 32 | DDPOSchedulerOutput, 33 | DDPOStableDiffusionPipeline, 34 | DefaultDDPOStableDiffusionPipeline, 35 | ) 36 | -------------------------------------------------------------------------------- /longva/trl/models/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal, Optional, Tuple 3 | 4 | from transformers import PreTrainedModel, PreTrainedTokenizer 5 | 6 | 7 | # TODO: Add Abstract Base Class if more formats are added 8 | @dataclass 9 | class ChatMlSpecialTokens: 10 | """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens.""" 11 | 12 | bos_token: str = "<|im_start|>" 13 | eos_token: str = "<|im_end|>" 14 | pad_token: str = "<|im_end|>" 15 | 16 | @property 17 | def system(self): 18 | return f"{self.bos_token}system" 19 | 20 | @property 21 | def user(self): 22 | return f"{self.bos_token}user" 23 | 24 | @property 25 | def assistant(self): 26 | return f"{self.bos_token}assistant" 27 | 28 | @property 29 | def chat_template(self): 30 | return ( 31 | "{% for message in messages %}" 32 | f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}" 33 | "{% endfor %}" 34 | "{% if add_generation_prompt %}" 35 | f"{{{{ '{self.assistant}\n' }}}}" 36 | "{% endif %}" 37 | ) 38 | 39 | 40 | FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens} 41 | 42 | 43 | def setup_chat_format( 44 | model: PreTrainedModel, 45 | tokenizer: PreTrainedTokenizer, 46 | format: Optional[Literal["chatml"]] = "chatml", 47 | resize_to_multiple_of: Optional[int] = None, 48 | ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: 49 | """ 50 | Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens. 51 | 52 | Args: 53 | model (`~transformers.PreTrainedModel`): The model to be modified. 54 | tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified. 55 | format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml". 56 | resize_to_multiple_of (`Optional[int]`): Number to resize the embedding layer to. Defaults to None. 57 | Returns: 58 | model (`~transformers.PreTrainedModel`): The modified model. 59 | tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer. 60 | """ 61 | # check if format available and retrieve 62 | if format not in FORMAT_MAPPING: 63 | raise ValueError(f"Format {format} not available. 
Please use one of {FORMAT_MAPPING.keys()}") 64 | 65 | chat_format = FORMAT_MAPPING[format]() 66 | 67 | # set special tokens and them 68 | tokenizer.eos_token = chat_format.eos_token 69 | tokenizer.pad_token = chat_format.pad_token 70 | tokenizer.bos_token = chat_format.bos_token 71 | tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]}) 72 | # set chat format for tokenizer 73 | tokenizer.chat_template = chat_format.chat_template 74 | 75 | # resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377 76 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None) 77 | # Make sure to update the generation config to use the new eos & bos token 78 | if getattr(model, "generation_config", None) is not None: 79 | model.generation_config.bos_token_id = tokenizer.bos_token_id 80 | model.generation_config.eos_token_id = tokenizer.eos_token_id 81 | model.generation_config.pad_token_id = tokenizer.pad_token_id 82 | 83 | return model, tokenizer 84 | -------------------------------------------------------------------------------- /longva/trl/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # There is a circular import in the PPOTrainer if we let isort sort these 18 | # isort: off 19 | from .utils import ( 20 | AdaptiveKLController, 21 | FixedKLController, 22 | ConstantLengthDataset, 23 | DataCollatorForCompletionOnlyLM, 24 | RunningMoments, 25 | disable_dropout_in_model, 26 | peft_module_casting_to_bf16, 27 | ) 28 | 29 | # isort: on 30 | 31 | from ..import_utils import is_diffusers_available 32 | from .base import BaseTrainer 33 | from .ddpo_config import DDPOConfig 34 | 35 | 36 | if is_diffusers_available(): 37 | from .ddpo_trainer import DDPOTrainer 38 | 39 | from .dpo_trainer import DPOTrainer 40 | from .iterative_sft_trainer import IterativeSFTTrainer 41 | from .model_config import ModelConfig 42 | from .ppo_config import PPOConfig 43 | from .ppo_trainer import PPOTrainer 44 | from .reward_config import RewardConfig 45 | from .reward_trainer import RewardTrainer, compute_accuracy 46 | from .sft_trainer import SFTTrainer 47 | -------------------------------------------------------------------------------- /longva/trl/trainer/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from huggingface_hub import PyTorchModelHubMixin 16 | 17 | 18 | class BaseTrainer(PyTorchModelHubMixin): 19 | r""" 20 | Base class for all trainers - this base class implements the basic functions that we 21 | need for a trainer. 22 | 23 | The trainer needs to have the following functions: 24 | - step: takes in a batch of data and performs a step of training 25 | - loss: takes in a batch of data and returns the loss 26 | - compute_rewards: takes in a batch of data and returns the rewards 27 | - _build_models_and_tokenizer: builds the models and tokenizer 28 | - _build_dataset: builds the dataset 29 | Each user is expected to implement their own trainer class that inherits from this base 30 | if they want to use a new training algorithm. 31 | """ 32 | 33 | def __init__(self, config): 34 | self.config = config 35 | 36 | def step(self, *args): 37 | raise NotImplementedError("Not implemented") 38 | 39 | def loss(self, *args): 40 | raise NotImplementedError("Not implemented") 41 | 42 | def compute_rewards(self, *args): 43 | raise NotImplementedError("Not implemented") 44 | 45 | def _save_pretrained(self, save_directory): 46 | raise NotImplementedError("Not implemented") 47 | -------------------------------------------------------------------------------- /longva/trl/trainer/ddpo_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | from dataclasses import dataclass, field 5 | from typing import Literal, Optional 6 | 7 | from ..core import flatten_dict 8 | from ..import_utils import is_bitsandbytes_available, is_torchvision_available 9 | 10 | 11 | @dataclass 12 | class DDPOConfig: 13 | """ 14 | Configuration class for DDPOTrainer 15 | """ 16 | 17 | # common parameters 18 | exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")] 19 | """the name of this experiment (by default is the file name without the extension name)""" 20 | run_name: Optional[str] = "" 21 | """Run name for wandb logging and checkpoint saving.""" 22 | seed: int = 0 23 | """Seed value for random generations""" 24 | log_with: Optional[Literal["wandb", "tensorboard"]] = None 25 | """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details""" 26 | tracker_kwargs: dict = field(default_factory=dict) 27 | """Keyword arguments for the tracker (e.g. wandb_project)""" 28 | accelerator_kwargs: dict = field(default_factory=dict) 29 | """Keyword arguments for the accelerator""" 30 | project_kwargs: dict = field(default_factory=dict) 31 | """Keyword arguments for the accelerator project config (e.g. 
`logging_dir`)""" 32 | tracker_project_name: str = "trl" 33 | """Name of project to use for tracking""" 34 | logdir: str = "logs" 35 | """Top-level logging directory for checkpoint saving.""" 36 | 37 | # hyperparameters 38 | num_epochs: int = 100 39 | """Number of epochs to train.""" 40 | save_freq: int = 1 41 | """Number of epochs between saving model checkpoints.""" 42 | num_checkpoint_limit: int = 5 43 | """Number of checkpoints to keep before overwriting old ones.""" 44 | mixed_precision: str = "fp16" 45 | """Mixed precision training.""" 46 | allow_tf32: bool = True 47 | """Allow tf32 on Ampere GPUs.""" 48 | resume_from: Optional[str] = "" 49 | """Resume training from a checkpoint.""" 50 | sample_num_steps: int = 50 51 | """Number of sampler inference steps.""" 52 | sample_eta: float = 1.0 53 | """Eta parameter for the DDIM sampler.""" 54 | sample_guidance_scale: float = 5.0 55 | """Classifier-free guidance weight.""" 56 | sample_batch_size: int = 1 57 | """Batch size (per GPU!) to use for sampling.""" 58 | sample_num_batches_per_epoch: int = 2 59 | """Number of batches to sample per epoch.""" 60 | train_batch_size: int = 1 61 | """Batch size (per GPU!) to use for training.""" 62 | train_use_8bit_adam: bool = False 63 | """Whether to use the 8bit Adam optimizer from bitsandbytes.""" 64 | train_learning_rate: float = 3e-4 65 | """Learning rate.""" 66 | train_adam_beta1: float = 0.9 67 | """Adam beta1.""" 68 | train_adam_beta2: float = 0.999 69 | """Adam beta2.""" 70 | train_adam_weight_decay: float = 1e-4 71 | """Adam weight decay.""" 72 | train_adam_epsilon: float = 1e-8 73 | """Adam epsilon.""" 74 | train_gradient_accumulation_steps: int = 1 75 | """Number of gradient accumulation steps.""" 76 | train_max_grad_norm: float = 1.0 77 | """Maximum gradient norm for gradient clipping.""" 78 | train_num_inner_epochs: int = 1 79 | """Number of inner epochs per outer epoch.""" 80 | train_cfg: bool = True 81 | """Whether or not to use classifier-free guidance during training.""" 82 | train_adv_clip_max: float = 5 83 | """Clip advantages to the range.""" 84 | train_clip_range: float = 1e-4 85 | """The PPO clip range.""" 86 | train_timestep_fraction: float = 1.0 87 | """The fraction of timesteps to train on.""" 88 | per_prompt_stat_tracking: bool = False 89 | """Whether to track statistics for each prompt separately.""" 90 | per_prompt_stat_tracking_buffer_size: int = 16 91 | """Number of reward values to store in the buffer for each prompt.""" 92 | per_prompt_stat_tracking_min_count: int = 16 93 | """The minimum number of reward values to store in the buffer.""" 94 | async_reward_computation: bool = False 95 | """Whether to compute rewards asynchronously.""" 96 | max_workers: int = 2 97 | """The maximum number of workers to use for async reward computation.""" 98 | negative_prompts: Optional[str] = "" 99 | """Comma-separated list of prompts to use as negative examples.""" 100 | 101 | def to_dict(self): 102 | output_dict = {} 103 | for key, value in self.__dict__.items(): 104 | output_dict[key] = value 105 | return flatten_dict(output_dict) 106 | 107 | def __post_init__(self): 108 | if self.log_with not in ["wandb", "tensorboard"]: 109 | warnings.warn(("Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'.")) 110 | 111 | if self.log_with == "wandb" and not is_torchvision_available(): 112 | warnings.warn("Wandb image logging requires torchvision to be installed") 113 | 114 | if self.train_use_8bit_adam and not is_bitsandbytes_available(): 115 | raise 
ImportError("You need to install bitsandbytes to use 8bit Adam. " "You can install it with `pip install bitsandbytes`.") 116 | -------------------------------------------------------------------------------- /longva/trl/trainer/model_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List, Optional 3 | 4 | from ..core import flatten_dict 5 | 6 | 7 | @dataclass 8 | class ModelConfig: 9 | """ 10 | Arguments which define the model and tokenizer to load. 11 | """ 12 | 13 | model_name_or_path: Optional[str] = field( 14 | default=None, 15 | metadata={"help": ("The model checkpoint for weights initialization.")}, 16 | ) 17 | model_revision: str = field( 18 | default="main", 19 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 20 | ) 21 | torch_dtype: Optional[str] = field( 22 | default=None, 23 | metadata={ 24 | "help": ("Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " "dtype will be automatically derived from the model's weights."), 25 | "choices": ["auto", "bfloat16", "float16", "float32"], 26 | }, 27 | ) 28 | trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) 29 | attn_implementation: Optional[str] = field( 30 | default=None, 31 | metadata={"help": ("Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`")}, 32 | ) 33 | use_peft: bool = field( 34 | default=False, 35 | metadata={"help": ("Whether to use PEFT or not for training.")}, 36 | ) 37 | lora_r: Optional[int] = field( 38 | default=16, 39 | metadata={"help": ("LoRA R value.")}, 40 | ) 41 | lora_alpha: Optional[int] = field( 42 | default=32, 43 | metadata={"help": ("LoRA alpha.")}, 44 | ) 45 | lora_dropout: Optional[float] = field( 46 | default=0.05, 47 | metadata={"help": ("LoRA dropout.")}, 48 | ) 49 | lora_target_modules: Optional[List[str]] = field( 50 | default=None, 51 | metadata={"help": ("LoRA target modules.")}, 52 | ) 53 | lora_modules_to_save: Optional[List[str]] = field( 54 | default=None, 55 | metadata={"help": ("Model layers to unfreeze & train")}, 56 | ) 57 | load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}) 58 | load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}) 59 | 60 | bnb_4bit_quant_type: Optional[str] = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}) 61 | use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) 62 | 63 | def to_dict(self): 64 | output_dict = {} 65 | for key, value in self.__dict__.items(): 66 | output_dict[key] = value 67 | return flatten_dict(output_dict) 68 | 69 | def __post_init__(self): 70 | if self.load_in_8bit and self.load_in_4bit: 71 | raise ValueError("You can't use 8 bit and 4 bit precision at the same time") 72 | -------------------------------------------------------------------------------- /longva/trl/trainer/reward_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 
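# A minimal sketch showing how the ModelConfig dataclass above can be filled
# from the command line via transformers' HfArgumentParser; the CLI flags in
# the comment and the import path `longva.trl` are assumptions.
from transformers import HfArgumentParser

from longva.trl.trainer.model_config import ModelConfig

parser = HfArgumentParser(ModelConfig)
# e.g. python train.py --model_name_or_path Qwen/Qwen2-7B-Instruct --use_peft --lora_r 16
(model_config,) = parser.parse_args_into_dataclasses()

print(model_config.to_dict())  # flattened view, convenient for experiment logging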
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | from typing import Optional 17 | 18 | from transformers import TrainingArguments 19 | 20 | 21 | @dataclass 22 | class RewardConfig(TrainingArguments): 23 | """ 24 | RewardConfig collects all training arguments related to the [`RewardTrainer`] class. 25 | 26 | Using [`HfArgumentParser`] we can turn this class into 27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 28 | command line. 29 | 30 | Parameters: 31 | max_length (`int`, *optional*, defaults to `None`): 32 | The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. 33 | gradient_checkpointing (`bool`, *optional*, defaults to `True`): 34 | If True, use gradient checkpointing to save memory at the expense of slower backward pass. 35 | """ 36 | 37 | max_length: Optional[int] = None 38 | """The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.""" 39 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.28.0 2 | decord>=0.6.0 3 | datasets 4 | evaluate>=0.4.1 5 | tqdm 6 | einops 7 | sentencepiece 8 | scikit-learn 9 | matplotlib 10 | numpy==1.26.4 11 | ring_flash_attn@git+https://github.com/zhuzilin/ring-flash-attention.git@b39e427e737ed9a515c88122f88f9183b7eb32de 12 | deepspeed==0.14.0 13 | wandb 14 | seaborn 15 | pandas 16 | pytest 17 | loguru 18 | gradio==4.29.0 -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/bias.txt: -------------------------------------------------------------------------------- 1 | October 2015This will come as a surprise to a lot of people, but in some cases 2 | it's possible to detect bias in a selection process without knowing 3 | anything about the applicant pool. Which is exciting because among 4 | other things it means third parties can use this technique to detect 5 | bias whether those doing the selecting want them to or not.You can use this technique whenever (a) you have at least 6 | a random sample of the applicants that were selected, (b) their 7 | subsequent performance is measured, and (c) the groups of 8 | applicants you're comparing have roughly equal distribution of ability.How does it work? Think about what it means to be biased. What 9 | it means for a selection process to be biased against applicants 10 | of type x is that it's harder for them to make it through. Which 11 | means applicants of type x have to be better to get selected than 12 | applicants not of type x. 13 | [1] 14 | Which means applicants of type x 15 | who do make it through the selection process will outperform other 16 | successful applicants. 
And if the performance of all the successful 17 | applicants is measured, you'll know if they do.Of course, the test you use to measure performance must be a valid 18 | one. And in particular it must not be invalidated by the bias you're 19 | trying to measure. 20 | But there are some domains where performance can be measured, and 21 | in those detecting bias is straightforward. Want to know if the 22 | selection process was biased against some type of applicant? Check 23 | whether they outperform the others. This is not just a heuristic 24 | for detecting bias. It's what bias means.For example, many suspect that venture capital firms are biased 25 | against female founders. This would be easy to detect: among their 26 | portfolio companies, do startups with female founders outperform 27 | those without? A couple months ago, one VC firm (almost certainly 28 | unintentionally) published a study showing bias of this type. First 29 | Round Capital found that among its portfolio companies, startups 30 | with female founders outperformed 31 | those without by 63%. 32 | [2]The reason I began by saying that this technique would come as a 33 | surprise to many people is that we so rarely see analyses of this 34 | type. I'm sure it will come as a surprise to First Round that they 35 | performed one. I doubt anyone there realized that by limiting their 36 | sample to their own portfolio, they were producing a study not of 37 | startup trends but of their own biases when selecting companies.I predict we'll see this technique used more in the future. The 38 | information needed to conduct such studies is increasingly available. 39 | Data about who applies for things is usually closely guarded by the 40 | organizations selecting them, but nowadays data about who gets 41 | selected is often publicly available to anyone who takes the trouble 42 | to aggregate it. 43 | Notes[1] 44 | This technique wouldn't work if the selection process looked 45 | for different things from different types of applicants—for 46 | example, if an employer hired men based on their ability but women 47 | based on their appearance.[2] 48 | As Paul Buchheit points out, First Round excluded their most 49 | successful investment, Uber, from the study. And while it 50 | makes sense to exclude outliers from some types of studies, 51 | studies of returns from startup investing, which is all about 52 | hitting outliers, are not one of them. 53 | Thanks to Sam Altman, Jessica Livingston, and Geoff Ralston for reading 54 | drafts of this. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/copy.txt: -------------------------------------------------------------------------------- 1 | July 2006 2 | When I was in high school I spent a lot of time imitating bad 3 | writers. What we studied in English classes was mostly fiction, 4 | so I assumed that was the highest form of writing. Mistake number 5 | one. The stories that seemed to be most admired were ones in which 6 | people suffered in complicated ways. Anything funny or 7 | gripping was ipso facto suspect, unless it was old enough to be hard to 8 | understand, like Shakespeare or Chaucer. Mistake number two. The 9 | ideal medium seemed the short story, which I've since learned had 10 | quite a brief life, roughly coincident with the peak of magazine 11 | publishing. But since their size made them perfect for use in 12 | high school classes, we read a lot of them, which gave us the 13 | impression the short story was flourishing. 
Mistake number three. 14 | And because they were so short, nothing really had to happen; you 15 | could just show a randomly truncated slice of life, and that was 16 | considered advanced. Mistake number four. The result was that I 17 | wrote a lot of stories in which nothing happened except that someone 18 | was unhappy in a way that seemed deep.For most of college I was a philosophy major. I was very impressed 19 | by the papers published in philosophy journals. They were so 20 | beautifully typeset, and their tone was just captivating—alternately 21 | casual and buffer-overflowingly technical. A fellow would be walking 22 | along a street and suddenly modality qua modality would spring upon 23 | him. I didn't ever quite understand these papers, but I figured 24 | I'd get around to that later, when I had time to reread them more 25 | closely. In the meantime I tried my best to imitate them. This 26 | was, I can now see, a doomed undertaking, because they weren't 27 | really saying anything. No philosopher ever refuted another, for 28 | example, because no one said anything definite enough to refute. 29 | Needless to say, my imitations didn't say anything either.In grad school I was still wasting time imitating the wrong things. 30 | There was then a fashionable type of program called an expert system, 31 | at the core of which was something called an inference engine. I 32 | looked at what these things did and thought "I could write that in 33 | a thousand lines of code." And yet eminent professors were writing 34 | books about them, and startups were selling them for a year's salary 35 | a copy. What an opportunity, I thought; these impressive things 36 | seem easy to me; I must be pretty sharp. Wrong. It was simply a 37 | fad. The books the professors wrote about expert systems are now 38 | ignored. They were not even on a path to anything interesting. 39 | And the customers paying so much for them were largely the same 40 | government agencies that paid thousands for screwdrivers and toilet 41 | seats.How do you avoid copying the wrong things? Copy only what you 42 | genuinely like. That would have saved me in all three cases. I 43 | didn't enjoy the short stories we had to read in English classes; 44 | I didn't learn anything from philosophy papers; I didn't use expert 45 | systems myself. I believed these things were good because they 46 | were admired.It can be hard to separate the things you like from the things 47 | you're impressed with. One trick is to ignore presentation. Whenever 48 | I see a painting impressively hung in a museum, I ask myself: how 49 | much would I pay for this if I found it at a garage sale, dirty and 50 | frameless, and with no idea who painted it? If you walk around a 51 | museum trying this experiment, you'll find you get some truly 52 | startling results. Don't ignore this data point just because it's 53 | an outlier.Another way to figure out what you like is to look at what you enjoy 54 | as guilty pleasures. Many things people like, especially if they're 55 | young and ambitious, they like largely for the feeling of virtue 56 | in liking them. 99% of people reading Ulysses are thinking 57 | "I'm reading Ulysses" as they do it. A guilty pleasure is 58 | at least a pure one. What do you read when you don't feel up to being 59 | virtuous? What kind of book do you read and feel sad that there's 60 | only half of it left, instead of being impressed that you're half 61 | way through? 
That's what you really like.Even when you find genuinely good things to copy, there's another 62 | pitfall to be avoided. Be careful to copy what makes them good, 63 | rather than their flaws. It's easy to be drawn into imitating 64 | flaws, because they're easier to see, and of course easier to copy 65 | too. For example, most painters in the eighteenth and nineteenth 66 | centuries used brownish colors. They were imitating the great 67 | painters of the Renaissance, whose paintings by that time were brown 68 | with dirt. Those paintings have since been cleaned, revealing 69 | brilliant colors; their imitators are of course still brown.It was painting, incidentally, that cured me of copying the wrong 70 | things. Halfway through grad school I decided I wanted to try being 71 | a painter, and the art world was so manifestly corrupt that it 72 | snapped the leash of credulity. These people made philosophy 73 | professors seem as scrupulous as mathematicians. It was so clearly 74 | a choice of doing good work xor being an insider that I was forced 75 | to see the distinction. It's there to some degree in almost every 76 | field, but I had till then managed to avoid facing it.That was one of the most valuable things I learned from painting: 77 | you have to figure out for yourself what's 78 | good. You can't trust 79 | authorities. They'll lie to you on this one. 80 | 81 | Comment on this essay. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/diff.txt: -------------------------------------------------------------------------------- 1 | December 2001 (rev. May 2002) 2 | 3 | (This article came about in response to some questions on 4 | the LL1 mailing list. It is now 5 | incorporated in Revenge of the Nerds.)When McCarthy designed Lisp in the late 1950s, it was 6 | a radical departure from existing languages, 7 | the most important of which was Fortran.Lisp embodied nine new ideas: 8 | 1. Conditionals. A conditional is an if-then-else 9 | construct. We take these for granted now. They were 10 | invented 11 | by McCarthy in the course of developing Lisp. 12 | (Fortran at that time only had a conditional 13 | goto, closely based on the branch instruction in the 14 | underlying hardware.) McCarthy, who was on the Algol committee, got 15 | conditionals into Algol, whence they spread to most other 16 | languages.2. A function type. In Lisp, functions are first class 17 | objects-- they're a data type just like integers, strings, 18 | etc, and have a literal representation, can be stored in variables, 19 | can be passed as arguments, and so on.3. Recursion. Recursion existed as a mathematical concept 20 | before Lisp of course, but Lisp was the first programming language to support 21 | it. (It's arguably implicit in making functions first class 22 | objects.)4. A new concept of variables. In Lisp, all variables 23 | are effectively pointers. Values are what 24 | have types, not variables, and assigning or binding 25 | variables means copying pointers, not what they point to.5. Garbage-collection.6. Programs composed of expressions. Lisp programs are 26 | trees of expressions, each of which returns a value. 27 | (In some Lisps expressions 28 | can return multiple values.) 
This is in contrast to Fortran 29 | and most succeeding languages, which distinguish between 30 | expressions and statements.It was natural to have this 31 | distinction in Fortran because (not surprisingly in a language 32 | where the input format was punched cards) the language was 33 | line-oriented. You could not nest statements. And 34 | so while you needed expressions for math to work, there was 35 | no point in making anything else return a value, because 36 | there could not be anything waiting for it.This limitation 37 | went away with the arrival of block-structured languages, 38 | but by then it was too late. The distinction between 39 | expressions and statements was entrenched. It spread from 40 | Fortran into Algol and thence to both their descendants.When a language is made entirely of expressions, you can 41 | compose expressions however you want. You can say either 42 | (using Arc syntax)(if foo (= x 1) (= x 2))or(= x (if foo 1 2))7. A symbol type. Symbols differ from strings in that 43 | you can test equality by comparing a pointer.8. A notation for code using trees of symbols.9. The whole language always available. 44 | There is 45 | no real distinction between read-time, compile-time, and runtime. 46 | You can compile or run code while reading, read or run code 47 | while compiling, and read or compile code at runtime.Running code at read-time lets users reprogram Lisp's syntax; 48 | running code at compile-time is the basis of macros; compiling 49 | at runtime is the basis of Lisp's use as an extension 50 | language in programs like Emacs; and reading at runtime 51 | enables programs to communicate using s-expressions, an 52 | idea recently reinvented as XML. 53 | When Lisp was first invented, all these ideas were far 54 | removed from ordinary programming practice, which was 55 | dictated largely by the hardware available in the late 1950s.Over time, the default language, embodied 56 | in a succession of popular languages, has 57 | gradually evolved toward Lisp. 1-5 are now widespread. 58 | 6 is starting to appear in the mainstream. 59 | Python has a form of 7, though there doesn't seem to be 60 | any syntax for it. 61 | 8, which (with 9) is what makes Lisp macros 62 | possible, is so far still unique to Lisp, 63 | perhaps because (a) it requires those parens, or something 64 | just as bad, and (b) if you add that final increment of power, 65 | you can no 66 | longer claim to have invented a new language, but only 67 | to have designed a new dialect of Lisp ; -)Though useful to present-day programmers, it's 68 | strange to describe Lisp in terms of its 69 | variation from the random expedients other languages 70 | adopted. That was not, probably, how McCarthy 71 | thought of it. Lisp wasn't designed to fix the mistakes 72 | in Fortran; it came about more as the byproduct of an 73 | attempt to axiomatize computation. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/founders.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | Want to start a startup? Get funded by 4 | Y Combinator. 5 | 6 | 7 | 8 | 9 | October 2010 10 | 11 | (I wrote this for Forbes, who asked me to write something 12 | about the qualities we look for in founders. In print they had to cut 13 | the last item because they didn't have room.)1. DeterminationThis has turned out to be the most important quality in startup 14 | founders. 
We thought when we started Y Combinator that the most 15 | important quality would be intelligence. That's the myth in the 16 | Valley. And certainly you don't want founders to be stupid. But 17 | as long as you're over a certain threshold of intelligence, what 18 | matters most is determination. You're going to hit a lot of 19 | obstacles. You can't be the sort of person who gets demoralized 20 | easily.Bill Clerico and Rich Aberman of WePay 21 | are a good example. They're 22 | doing a finance startup, which means endless negotiations with big, 23 | bureaucratic companies. When you're starting a startup that depends 24 | on deals with big companies to exist, it often feels like they're 25 | trying to ignore you out of existence. But when Bill Clerico starts 26 | calling you, you may as well do what he asks, because he is not 27 | going away. 28 | 2. FlexibilityYou do not however want the sort of determination implied by phrases 29 | like "don't give up on your dreams." The world of startups is so 30 | unpredictable that you need to be able to modify your dreams on the 31 | fly. The best metaphor I've found for the combination of determination 32 | and flexibility you need is a running back. 33 | He's determined to get 34 | downfield, but at any given moment he may need to go sideways or 35 | even backwards to get there.The current record holder for flexibility may be Daniel Gross of 36 | Greplin. He applied to YC with 37 | some bad ecommerce idea. We told 38 | him we'd fund him if he did something else. He thought for a second, 39 | and said ok. He then went through two more ideas before settling 40 | on Greplin. He'd only been working on it for a couple days when 41 | he presented to investors at Demo Day, but he got a lot of interest. 42 | He always seems to land on his feet. 43 | 3. ImaginationIntelligence does matter a lot of course. It seems like the type 44 | that matters most is imagination. It's not so important to be able 45 | to solve predefined problems quickly as to be able to come up with 46 | surprising new ideas. In the startup world, most good ideas 47 | seem 48 | bad initially. If they were obviously good, someone would already 49 | be doing them. So you need the kind of intelligence that produces 50 | ideas with just the right level of craziness.Airbnb is that kind of idea. 51 | In fact, when we funded Airbnb, we 52 | thought it was too crazy. We couldn't believe large numbers of 53 | people would want to stay in other people's places. We funded them 54 | because we liked the founders so much. As soon as we heard they'd 55 | been supporting themselves by selling Obama and McCain branded 56 | breakfast cereal, they were in. And it turned out the idea was on 57 | the right side of crazy after all. 58 | 4. NaughtinessThough the most successful founders are usually good people, they 59 | tend to have a piratical gleam in their eye. They're not Goody 60 | Two-Shoes type good. Morally, they care about getting the big 61 | questions right, but not about observing proprieties. That's why 62 | I'd use the word naughty rather than evil. They delight in 63 | breaking 64 | rules, but not rules that matter. This quality may be redundant 65 | though; it may be implied by imagination.Sam Altman of Loopt 66 | is one of the most successful alumni, so we 67 | asked him what question we could put on the Y Combinator application 68 | that would help us discover more people like him. 
He said to ask 69 | about a time when they'd hacked something to their advantage—hacked in the sense of beating the system, not breaking into 70 | computers. It has become one of the questions we pay most attention 71 | to when judging applications. 72 | 5. FriendshipEmpirically it seems to be hard to start a startup with just 73 | one 74 | founder. Most of the big successes have two or three. And the 75 | relationship between the founders has to be strong. They must 76 | genuinely like one another, and work well together. Startups do 77 | to the relationship between the founders what a dog does to a sock: 78 | if it can be pulled apart, it will be.Emmett Shear and Justin Kan of Justin.tv 79 | are a good example of close 80 | friends who work well together. They've known each other since 81 | second grade. They can practically read one another's minds. I'm 82 | sure they argue, like all founders, but I have never once sensed 83 | any unresolved tension between them.Thanks to Jessica Livingston and Chris Steiner for reading drafts of this. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/foundervisa.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | April 2009I usually avoid politics, but since we now seem to have an administration that's open to suggestions, I'm going to risk making one. The single biggest thing the government could do to increase the number of startups in this country is a policy that would cost nothing: establish a new class of visa for startup founders.The biggest constraint on the number of new startups that get created in the US is not tax policy or employment law or even Sarbanes-Oxley. It's that we won't let the people who want to start them into the country.Letting just 10,000 startup founders into the country each year could have a visible effect on the economy. If we assume 4 people per startup, which is probably an overestimate, that's 2500 new companies. Each year. They wouldn't all grow as big as Google, but out of 2500 some would come close.By definition these 10,000 founders wouldn't be taking jobs from Americans: it could be part of the terms of the visa that they couldn't work for existing companies, only new ones they'd founded. In fact they'd cause there to be 4 | more jobs for Americans, because the companies they started would hire more employees as they grew.The tricky part might seem to be how one defined a startup. But that could be solved quite easily: let the market decide. Startup investors work hard to find the best startups. The government could not do better than to piggyback on their expertise, and use investment by recognized startup investors as the test of whether a company was a real startup.How would the government decide who's a startup investor? The same way they decide what counts as a university for student visas. We'll establish our own accreditation procedure. We know who one another are.10,000 people is a drop in the bucket by immigration standards, but would represent a huge increase in the pool of startup founders. I think this would have such a visible effect on the economy that it would make the legislator who introduced the bill famous. The only way to know for sure would be to try it, and that would cost practically nothing. 
5 | Thanks to Trevor Blackwell, Paul Buchheit, Jeff Clavier, David Hornik, Jessica Livingston, Greg Mcadoo, Aydin Senkut, and Fred Wilson for reading drafts of this.Related: -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/iflisp.txt: -------------------------------------------------------------------------------- 1 | May 2003If Lisp is so great, why don't more people use it? I was 2 | asked this question by a student in the audience at a 3 | talk I gave recently. Not for the first time, either.In languages, as in so many things, there's not much 4 | correlation between popularity and quality. Why does 5 | John Grisham (King of Torts sales rank, 44) outsell 6 | Jane Austen (Pride and Prejudice sales rank, 6191)? 7 | Would even Grisham claim that it's because he's a better 8 | writer?Here's the first sentence of Pride and Prejudice: 9 | 10 | It is a truth universally acknowledged, that a single man 11 | in possession of a good fortune must be in want of a 12 | wife. 13 | 14 | "It is a truth universally acknowledged?" Long words for 15 | the first sentence of a love story.Like Jane Austen, Lisp looks hard. Its syntax, or lack 16 | of syntax, makes it look completely unlike 17 | the languages 18 | most people are used to. Before I learned Lisp, I was afraid 19 | of it too. I recently came across a notebook from 1983 20 | in which I'd written: 21 | 22 | I suppose I should learn Lisp, but it seems so foreign. 23 | 24 | Fortunately, I was 19 at the time and not too resistant to learning 25 | new things. I was so ignorant that learning 26 | almost anything meant learning new things.People frightened by Lisp make up other reasons for not 27 | using it. The standard 28 | excuse, back when C was the default language, was that Lisp 29 | was too slow. Now that Lisp dialects are among 30 | the faster 31 | languages available, that excuse has gone away. 32 | Now the standard excuse is openly circular: that other languages 33 | are more popular.(Beware of such reasoning. It gets you Windows.)Popularity is always self-perpetuating, but it's especially 34 | so in programming languages. More libraries 35 | get written for popular languages, which makes them still 36 | more popular. Programs often have to work with existing programs, 37 | and this is easier if they're written in the same language, 38 | so languages spread from program to program like a virus. 39 | And managers prefer popular languages, because they give them 40 | more leverage over developers, who can more easily be replaced.Indeed, if programming languages were all more or less equivalent, 41 | there would be little justification for using any but the most 42 | popular. But they aren't all equivalent, not by a long 43 | shot. And that's why less popular languages, like Jane Austen's 44 | novels, continue to survive at all. When everyone else is reading 45 | the latest John Grisham novel, there will always be a few people 46 | reading Jane Austen instead. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/island.txt: -------------------------------------------------------------------------------- 1 | July 2006I've discovered a handy test for figuring out what you're addicted 2 | to. Imagine you were going to spend the weekend at a friend's house 3 | on a little island off the coast of Maine. There are no shops on 4 | the island and you won't be able to leave while you're there. 
Also, 5 | you've never been to this house before, so you can't assume it will 6 | have more than any house might.What, besides clothes and toiletries, do you make a point of packing? 7 | That's what you're addicted to. For example, if you find yourself 8 | packing a bottle of vodka (just in case), you may want to stop and 9 | think about that.For me the list is four things: books, earplugs, a notebook, and a 10 | pen.There are other things I might bring if I thought of it, like music, 11 | or tea, but I can live without them. I'm not so addicted to caffeine 12 | that I wouldn't risk the house not having any tea, just for a 13 | weekend.Quiet is another matter. I realize it seems a bit eccentric to 14 | take earplugs on a trip to an island off the coast of Maine. If 15 | anywhere should be quiet, that should. But what if the person in 16 | the next room snored? What if there was a kid playing basketball? 17 | (Thump, thump, thump... thump.) Why risk it? Earplugs are small.Sometimes I can think with noise. If I already have momentum on 18 | some project, I can work in noisy places. I can edit an essay or 19 | debug code in an airport. But airports are not so bad: most of the 20 | noise is whitish. I couldn't work with the sound of a sitcom coming 21 | through the wall, or a car in the street playing thump-thump music.And of course there's another kind of thinking, when you're starting 22 | something new, that requires complete quiet. You never 23 | know when this will strike. It's just as well to carry plugs.The notebook and pen are professional equipment, as it were. Though 24 | actually there is something druglike about them, in the sense that 25 | their main purpose is to make me feel better. I hardly ever go 26 | back and read stuff I write down in notebooks. It's just that if 27 | I can't write things down, worrying about remembering one idea gets 28 | in the way of having the next. Pen and paper wick ideas.The best notebooks I've found are made by a company called Miquelrius. 29 | I use their smallest size, which is about 2.5 x 4 in. 30 | The secret to writing on such 31 | narrow pages is to break words only when you run out of space, like 32 | a Latin inscription. I use the cheapest plastic Bic ballpoints, 33 | partly because their gluey ink doesn't seep through pages, and 34 | partly so I don't worry about losing them.I only started carrying a notebook about three years ago. Before 35 | that I used whatever scraps of paper I could find. But the problem 36 | with scraps of paper is that they're not ordered. In a notebook 37 | you can guess what a scribble means by looking at the pages 38 | around it. In the scrap era I was constantly finding notes I'd 39 | written years before that might say something I needed to remember, 40 | if I could only figure out what.As for books, I know the house would probably have something to 41 | read. On the average trip I bring four books and only read one of 42 | them, because I find new books to read en route. Really bringing 43 | books is insurance.I realize this dependence on books is not entirely good—that what 44 | I need them for is distraction. The books I bring on trips are 45 | often quite virtuous, the sort of stuff that might be assigned 46 | reading in a college class. But I know my motives aren't virtuous. 47 | I bring books because if the world gets boring I need to be able 48 | to slip into another distilled by some writer. It's like eating 49 | jam when you know you should be eating fruit.There is a point where I'll do without books. 
I was walking in 50 | some steep mountains once, and decided I'd rather just think, if I 51 | was bored, rather than carry a single unnecessary ounce. It wasn't 52 | so bad. I found I could entertain myself by having ideas instead 53 | of reading other people's. If you stop eating jam, fruit starts 54 | to taste better.So maybe I'll try not bringing books on some future trip. They're 55 | going to have to pry the plugs out of my cold, dead ears, however. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/know.txt: -------------------------------------------------------------------------------- 1 | December 2014I've read Villehardouin's chronicle of the Fourth Crusade at least 2 | two times, maybe three. And yet if I had to write down everything 3 | I remember from it, I doubt it would amount to much more than a 4 | page. Multiply this times several hundred, and I get an uneasy 5 | feeling when I look at my bookshelves. What use is it to read all 6 | these books if I remember so little from them?A few months ago, as I was reading Constance Reid's excellent 7 | biography of Hilbert, I figured out if not the answer to this 8 | question, at least something that made me feel better about it. 9 | She writes: 10 | 11 | Hilbert had no patience with mathematical lectures which filled 12 | the students with facts but did not teach them how to frame a 13 | problem and solve it. He often used to tell them that "a perfect 14 | formulation of a problem is already half its solution." 15 | 16 | That has always seemed to me an important point, and I was even 17 | more convinced of it after hearing it confirmed by Hilbert.But how had I come to believe in this idea in the first place? A 18 | combination of my own experience and other things I'd read. None 19 | of which I could at that moment remember! And eventually I'd forget 20 | that Hilbert had confirmed it too. But my increased belief in the 21 | importance of this idea would remain something I'd learned from 22 | this book, even after I'd forgotten I'd learned it.Reading and experience train your model of the world. And even if 23 | you forget the experience or what you read, its effect on your model 24 | of the world persists. Your mind is like a compiled program you've 25 | lost the source of. It works, but you don't know why.The place to look for what I learned from Villehardouin's chronicle 26 | is not what I remember from it, but my mental models of the crusades, 27 | Venice, medieval culture, siege warfare, and so on. Which doesn't 28 | mean I couldn't have read more attentively, but at least the harvest 29 | of reading is not so miserably small as it might seem.This is one of those things that seem obvious in retrospect. But 30 | it was a surprise to me and presumably would be to anyone else who 31 | felt uneasy about (apparently) forgetting so much they'd read.Realizing it does more than make you feel a little better about 32 | forgetting, though. There are specific implications.For example, reading and experience are usually "compiled" at the 33 | time they happen, using the state of your brain at that time. The 34 | same book would get compiled differently at different points in 35 | your life. Which means it is very much worth reading important 36 | books multiple times. I always used to feel some misgivings about 37 | rereading books. 
I unconsciously lumped reading together with work 38 | like carpentry, where having to do something again is a sign you 39 | did it wrong the first time. Whereas now the phrase "already read" 40 | seems almost ill-formed.Intriguingly, this implication isn't limited to books. Technology 41 | will increasingly make it possible to relive our experiences. When 42 | people do that today it's usually to enjoy them again (e.g. when 43 | looking at pictures of a trip) or to find the origin of some bug in 44 | their compiled code (e.g. when Stephen Fry succeeded in remembering 45 | the childhood trauma that prevented him from singing). But as 46 | technologies for recording and playing back your life improve, it 47 | may become common for people to relive experiences without any goal 48 | in mind, simply to learn from them again as one might when rereading 49 | a book.Eventually we may be able not just to play back experiences but 50 | also to index and even edit them. So although not knowing how you 51 | know things may seem part of being human, it may not be. 52 | Thanks to Sam Altman, Jessica Livingston, and Robert Morris for reading 53 | drafts of this. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/mod.txt: -------------------------------------------------------------------------------- 1 | December 2019There are two distinct ways to be politically moderate: on purpose 2 | and by accident. Intentional moderates are trimmers, deliberately 3 | choosing a position mid-way between the extremes of right and left. 4 | Accidental moderates end up in the middle, on average, because they 5 | make up their own minds about each question, and the far right and 6 | far left are roughly equally wrong.You can distinguish intentional from accidental moderates by the 7 | distribution of their opinions. If the far left opinion on some 8 | matter is 0 and the far right opinion 100, an intentional moderate's 9 | opinion on every question will be near 50. Whereas an accidental 10 | moderate's opinions will be scattered over a broad range, but will, 11 | like those of the intentional moderate, average to about 50.Intentional moderates are similar to those on the far left and the 12 | far right in that their opinions are, in a sense, not their own. 13 | The defining quality of an ideologue, whether on the left or the 14 | right, is to acquire one's opinions in bulk. You don't get to pick 15 | and choose. Your opinions about taxation can be predicted from your 16 | opinions about sex. And although intentional moderates 17 | might seem to be the opposite of ideologues, their beliefs (though 18 | in their case the word "positions" might be more accurate) are also 19 | acquired in bulk. If the median opinion shifts to the right or left, 20 | the intentional moderate must shift with it. Otherwise they stop 21 | being moderate.Accidental moderates, on the other hand, not only choose their own 22 | answers, but choose their own questions. They may not care at all 23 | about questions that the left and right both think are terribly 24 | important. 
So you can only even measure the politics of an accidental 25 | moderate from the intersection of the questions they care about and 26 | those the left and right care about, and this can 27 | sometimes be vanishingly small.It is not merely a manipulative rhetorical trick to say "if you're 28 | not with us, you're against us," but often simply false.Moderates are sometimes derided as cowards, particularly by 29 | the extreme left. But while it may be accurate to call intentional 30 | moderates cowards, openly being an accidental moderate requires the 31 | most courage of all, because you get attacked from both right and 32 | left, and you don't have the comfort of being an orthodox member 33 | of a large group to sustain you.Nearly all the most impressive people I know are accidental moderates. 34 | If I knew a lot of professional athletes, or people in the entertainment 35 | business, that might be different. Being on the far left or far 36 | right doesn't affect how fast you run or how well you sing. But 37 | someone who works with ideas has to be independent-minded to do it 38 | well.Or more precisely, you have to be independent-minded about the ideas 39 | you work with. You could be mindlessly doctrinaire in your politics 40 | and still be a good mathematician. In the 20th century, a lot of 41 | very smart people were Marxists — just no one who was smart about 42 | the subjects Marxism involves. But if the ideas you use in your 43 | work intersect with the politics of your time, you have two choices: 44 | be an accidental moderate, or be mediocre.Notes[1] It's possible in theory for one side to be entirely right and 45 | the other to be entirely wrong. Indeed, ideologues must always 46 | believe this is the case. But historically it rarely has been.[2] For some reason the far right tend to ignore moderates rather 47 | than despise them as backsliders. I'm not sure why. Perhaps it 48 | means that the far right is less ideological than the far left. Or 49 | perhaps that they are more confident, or more resigned, or simply 50 | more disorganized. I just don't know.[3] Having heretical opinions doesn't mean you have to express 51 | them openly. It may be 52 | easier to have them if you don't. 53 | Thanks to Austen Allred, Trevor Blackwell, Patrick Collison, Jessica Livingston, 54 | Amjad Masad, Ryan Petersen, and Harj Taggar for reading drafts of this. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/nft.txt: -------------------------------------------------------------------------------- 1 | May 2021Noora Health, a nonprofit I've 2 | supported for years, just launched 3 | a new NFT. It has a dramatic name, Save Thousands of Lives, 4 | because that's what the proceeds will do.Noora has been saving lives for 7 years. They run programs in 5 | hospitals in South Asia to teach new mothers how to take care of 6 | their babies once they get home. They're in 165 hospitals now. And 7 | because they know the numbers before and after they start at a new 8 | hospital, they can measure the impact they have. It is massive. 9 | For every 1000 live births, they save 9 babies.This number comes from a study 10 | of 133,733 families at 28 different 11 | hospitals that Noora conducted in collaboration with the Better 12 | Birth team at Ariadne Labs, a joint center for health systems 13 | innovation at Brigham and Women’s Hospital and Harvard T.H. 
Chan 14 | School of Public Health.Noora is so effective that even if you measure their costs in the 15 | most conservative way, by dividing their entire budget by the number 16 | of lives saved, the cost of saving a life is the lowest I've seen. 17 | $1,235.For this NFT, they're going to issue a public report tracking how 18 | this specific tranche of money is spent, and estimating the number 19 | of lives saved as a result.NFTs are a new territory, and this way of using them is especially 20 | new, but I'm excited about its potential. And I'm excited to see 21 | what happens with this particular auction, because unlike an NFT 22 | representing something that has already happened, 23 | this NFT gets better as the price gets higher.The reserve price was about $2.5 million, because that's what it 24 | takes for the name to be accurate: that's what it costs to save 25 | 2000 lives. But the higher the price of this NFT goes, the more 26 | lives will be saved. What a sentence to be able to write. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/pow.txt: -------------------------------------------------------------------------------- 1 | January 2017People who are powerful but uncharismatic will tend to be disliked. 2 | Their power makes them a target for criticism that they don't have 3 | the charisma to disarm. That was Hillary Clinton's problem. It also 4 | tends to be a problem for any CEO who is more of a builder than a 5 | schmoozer. And yet the builder-type CEO is (like Hillary) probably 6 | the best person for the job.I don't think there is any solution to this problem. It's human 7 | nature. The best we can do is to recognize that it's happening, and 8 | to understand that being a magnet for criticism is sometimes a sign 9 | not that someone is the wrong person for a job, but that they're 10 | the right one. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/rootsoflisp.txt: -------------------------------------------------------------------------------- 1 | May 2001 2 | 3 | (I wrote this article to help myself understand exactly 4 | what McCarthy discovered. You don't need to know this stuff 5 | to program in Lisp, but it should be helpful to 6 | anyone who wants to 7 | understand the essence of Lisp — both in the sense of its 8 | origins and its semantic core. The fact that it has such a core 9 | is one of Lisp's distinguishing features, and the reason why, 10 | unlike other languages, Lisp has dialects.)In 1960, John 11 | McCarthy published a remarkable paper in 12 | which he did for programming something like what Euclid did for 13 | geometry. He showed how, given a handful of simple 14 | operators and a notation for functions, you can 15 | build a whole programming language. 16 | He called this language Lisp, for "List Processing," 17 | because one of his key ideas was to use a simple 18 | data structure called a list for both 19 | code and data.It's worth understanding what McCarthy discovered, not 20 | just as a landmark in the history of computers, but as 21 | a model for what programming is tending to become in 22 | our own time. It seems to me that there have been 23 | two really clean, consistent models of programming so 24 | far: the C model and the Lisp model. 25 | These two seem points of high ground, with swampy lowlands 26 | between them. 
As computers have grown more powerful, 27 | the new languages being developed have been moving 28 | steadily toward the Lisp model. A popular recipe 29 | for new programming languages in the past 20 years 30 | has been to take the C model of computing and add to 31 | it, piecemeal, parts taken from the Lisp model, 32 | like runtime typing and garbage collection.In this article I'm going to try to explain in the 33 | simplest possible terms what McCarthy discovered. 34 | The point is not just to learn about an interesting 35 | theoretical result someone figured out forty years ago, 36 | but to show where languages are heading. 37 | The unusual thing about Lisp — in fact, the defining 38 | quality of Lisp — is that it can be written in 39 | itself. To understand what McCarthy meant by this, 40 | we're going to retrace his steps, with his mathematical 41 | notation translated into running Common Lisp code. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/rss.txt: -------------------------------------------------------------------------------- 1 | Aaron Swartz created a scraped 2 | feed 3 | of the essays page. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/sun.txt: -------------------------------------------------------------------------------- 1 | September 2017The most valuable insights are both general and surprising. 2 | F = ma for example. But general and surprising is a hard 3 | combination to achieve. That territory tends to be picked 4 | clean, precisely because those insights are so valuable.Ordinarily, the best that people can do is one without the 5 | other: either surprising without being general (e.g. 6 | gossip), or general without being surprising (e.g. 7 | platitudes).Where things get interesting is the moderately valuable 8 | insights. You get those from small additions of whichever 9 | quality was missing. The more common case is a small 10 | addition of generality: a piece of gossip that's more than 11 | just gossip, because it teaches something interesting about 12 | the world. But another less common approach is to focus on 13 | the most general ideas and see if you can find something new 14 | to say about them. Because these start out so general, you 15 | only need a small delta of novelty to produce a useful 16 | insight.A small delta of novelty is all you'll be able to get most 17 | of the time. Which means if you take this route, your ideas 18 | will seem a lot like ones that already exist. Sometimes 19 | you'll find you've merely rediscovered an idea that did 20 | already exist. But don't be discouraged. Remember the huge 21 | multiplier that kicks in when you do manage to think of 22 | something even a little new.Corollary: the more general the ideas you're talking about, 23 | the less you should worry about repeating yourself. If you 24 | write enough, it's inevitable you will. Your brain is much 25 | the same from year to year and so are the stimuli that hit 26 | it. I feel slightly bad when I find I've said something 27 | close to what I've said before, as if I were plagiarizing 28 | myself. But rationally one shouldn't. You won't say 29 | something exactly the same way the second time, and that 30 | variation increases the chance you'll get that tiny but 31 | critical delta of novelty.And of course, ideas beget ideas. (That sounds 32 | familiar.) 33 | An idea with a small amount of novelty could lead to one 34 | with more. 
But only if you keep going. So it's doubly 35 | important not to let yourself be discouraged by people who 36 | say there's not much new about something you've discovered. 37 | "Not much new" is a real achievement when you're talking 38 | about the most general ideas. It's not true that there's nothing new under the sun. There 39 | are some domains where there's almost nothing new. But 40 | there's a big difference between nothing and almost nothing, 41 | when it's multiplied by the area under the sun. 42 | Thanks to Sam Altman, Patrick Collison, and Jessica 43 | Livingston for reading drafts of this. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/todo.txt: -------------------------------------------------------------------------------- 1 | April 2012A palliative care nurse called Bronnie Ware made a list of the 2 | biggest regrets 3 | of the dying. Her list seems plausible. I could see 4 | myself — can see myself — making at least 4 of these 5 | 5 mistakes.If you had to compress them into a single piece of advice, it might 6 | be: don't be a cog. The 5 regrets paint a portrait of post-industrial 7 | man, who shrinks himself into a shape that fits his circumstances, 8 | then turns dutifully till he stops.The alarming thing is, the mistakes that produce these regrets are 9 | all errors of omission. You forget your dreams, ignore your family, 10 | suppress your feelings, neglect your friends, and forget to be 11 | happy. Errors of omission are a particularly dangerous type of 12 | mistake, because you make them by default.I would like to avoid making these mistakes. But how do you avoid 13 | mistakes you make by default? Ideally you transform your life so 14 | it has other defaults. But it may not be possible to do that 15 | completely. As long as these mistakes happen by default, you probably 16 | have to be reminded not to make them. So I inverted the 5 regrets, 17 | yielding a list of 5 commands 18 | 19 | Don't ignore your dreams; don't work too much; say what you 20 | think; cultivate friendships; be happy. 21 | 22 | which I then put at the top of the file I use as a todo list. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/unions.txt: -------------------------------------------------------------------------------- 1 | May 2007People who worry about the increasing gap between rich and poor 2 | generally look back on the mid twentieth century as a golden age. 3 | In those days we had a large number of high-paying union manufacturing 4 | jobs that boosted the median income. I wouldn't quite call the 5 | high-paying union job a myth, but I think people who dwell on it 6 | are reading too much into it.Oddly enough, it was working with startups that made me realize 7 | where the high-paying union job came from. In a rapidly growing 8 | market, you don't worry too much about efficiency. It's more 9 | important to grow fast. If there's some mundane problem getting 10 | in your way, and there's a simple solution that's somewhat expensive, 11 | just take it and get on with more important things. EBay didn't 12 | win by paying less for servers than their competitors.Difficult though it may be to imagine now, manufacturing was a 13 | growth industry in the mid twentieth century. This was an era when 14 | small firms making everything from cars to candy were getting 15 | consolidated into a new kind of corporation with national reach and 16 | huge economies of scale. 
You had to grow fast or die. Workers 17 | were for these companies what servers are for an Internet startup. 18 | A reliable supply was more important than low cost.If you looked in the head of a 1950s auto executive, the attitude 19 | must have been: sure, give 'em whatever they ask for, so long as 20 | the new model isn't delayed.In other words, those workers were not paid what their work was 21 | worth. Circumstances being what they were, companies would have 22 | been stupid to insist on paying them so little.If you want a less controversial example of this phenomenon, ask 23 | anyone who worked as a consultant building web sites during the 24 | Internet Bubble. In the late nineties you could get paid huge sums 25 | of money for building the most trivial things. And yet does anyone 26 | who was there have any expectation those days will ever return? I 27 | doubt it. Surely everyone realizes that was just a temporary 28 | aberration.The era of labor unions seems to have been the same kind of aberration, 29 | just spread 30 | over a longer period, and mixed together with a lot of ideology 31 | that prevents people from viewing it with as cold an eye as they 32 | would something like consulting during the Bubble.Basically, unions were just Razorfish.People who think the labor movement was the creation of heroic union 33 | organizers have a problem to explain: why are unions shrinking now? 34 | The best they can do is fall back on the default explanation of 35 | people living in fallen civilizations. Our ancestors were giants. 36 | The workers of the early twentieth century must have had a moral 37 | courage that's lacking today.In fact there's a simpler explanation. The early twentieth century 38 | was just a fast-growing startup overpaying for infrastructure. And 39 | we in the present are not a fallen people, who have abandoned 40 | whatever mysterious high-minded principles produced the high-paying 41 | union job. We simply live in a time when the fast-growing companies 42 | overspend on different things. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/vw.txt: -------------------------------------------------------------------------------- 1 | January 2012A few hours before the Yahoo acquisition was announced in June 1998 2 | I took a snapshot of Viaweb's 3 | site. I thought it might be interesting to look at one day.The first thing one notices is is how tiny the pages are. Screens 4 | were a lot smaller in 1998. If I remember correctly, our frontpage 5 | used to just fit in the size window people typically used then.Browsers then (IE 6 was still 3 years in the future) had few fonts 6 | and they weren't antialiased. If you wanted to make pages that 7 | looked good, you had to render display text as images.You may notice a certain similarity between the Viaweb and Y Combinator logos. We did that 8 | as an inside joke when we started YC. Considering how basic a red 9 | circle is, it seemed surprising to me when we started Viaweb how 10 | few other companies used one as their logo. A bit later I realized 11 | why.On the Company 12 | page you'll notice a mysterious individual called John McArtyem. 13 | Robert Morris (aka Rtm) was so publicity averse after the 14 | Worm that he 15 | didn't want his name on the site. I managed to get him to agree 16 | to a compromise: we could use his bio but not his name. 
He has 17 | since relaxed a bit 18 | on that point.Trevor graduated at about the same time the acquisition closed, so in the 19 | course of 4 days he went from impecunious grad student to millionaire 20 | PhD. The culmination of my career as a writer of press releases 21 | was one celebrating 22 | his graduation, illustrated with a drawing I did of him during 23 | a meeting.(Trevor also appears as Trevino 24 | Bagwell in our directory of web designers merchants could hire 25 | to build stores for them. We inserted him as a ringer in case some 26 | competitor tried to spam our web designers. We assumed his logo 27 | would deter any actual customers, but it did not.)Back in the 90s, to get users you had to get mentioned in magazines 28 | and newspapers. There were not the same ways to get found online 29 | that there are today. So we used to pay a PR 30 | firm $16,000 a month to get us mentioned in the press. Fortunately 31 | reporters liked 32 | us.In our advice about 33 | getting traffic from search engines (I don't think the term SEO 34 | had been coined yet), we say there are only 7 that matter: Yahoo, 35 | AltaVista, Excite, WebCrawler, InfoSeek, Lycos, and HotBot. Notice 36 | anything missing? Google was incorporated that September.We supported online transactions via a company called 37 | Cybercash, 38 | since if we lacked that feature we'd have gotten beaten up in product 39 | comparisons. But Cybercash was so bad and most stores' order volumes 40 | were so low that it was better if merchants processed orders like phone orders. We had a page in our site trying to talk merchants 41 | out of doing real time authorizations.The whole site was organized like a funnel, directing people to the 42 | test drive. 43 | It was a novel thing to be able to try out software online. We put 44 | cgi-bin in our dynamic urls to fool competitors about how our 45 | software worked.We had some well 46 | known users. Needless to say, Frederick's of Hollywood got the 47 | most traffic. We charged a flat fee of $300/month for big stores, 48 | so it was a little alarming to have users who got lots of traffic. 49 | I once calculated how much Frederick's was costing us in bandwidth, 50 | and it was about $300/month.Since we hosted all the stores, which together were getting just 51 | over 10 million page views per month in June 1998, we consumed what 52 | at the time seemed a lot of bandwidth. We had 2 T1s (3 Mb/sec) 53 | coming into our offices. In those days there was no AWS. Even 54 | colocating servers seemed too risky, considering how often things 55 | went wrong with them. So we had our servers in our offices. Or 56 | more precisely, in Trevor's office. In return for the unique 57 | privilege of sharing his office with no other humans, he had to 58 | share it with 6 shrieking tower servers. His office was nicknamed 59 | the Hot Tub on account of the heat they generated. Most days his 60 | stack of window air conditioners could keep up.For describing pages, we had a template language called RTML, which 61 | supposedly stood for something, but which in fact I named after 62 | Rtm. RTML was Common Lisp augmented by some macros and libraries, 63 | and concealed under a structure editor that made it look like it 64 | had syntax.Since we did continuous releases, our software didn't actually have 65 | versions. But in those days the trade press expected versions, so 66 | we made them up. If we wanted to get lots of attention, we made 67 | the version number an 68 | integer. 
That "version 4.0" icon was generated by our own 69 | button generator, incidentally. The whole Viaweb site was made 70 | with our software, even though it wasn't an online store, because 71 | we wanted to experience what our users did.At the end of 1997, we released a general purpose shopping search 72 | engine called Shopfind. It 73 | was pretty advanced for the time. It had a programmable crawler 74 | that could crawl most of the different stores online and pick out 75 | the products. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/want.txt: -------------------------------------------------------------------------------- 1 | November 2022Since I was about 9 I've been puzzled by the apparent contradiction 2 | between being made of matter that behaves in a predictable way, and 3 | the feeling that I could choose to do whatever I wanted. At the 4 | time I had a self-interested motive for exploring the question. At 5 | that age (like most succeeding ages) I was always in trouble with 6 | the authorities, and it seemed to me that there might possibly be 7 | some way to get out of trouble by arguing that I wasn't responsible 8 | for my actions. I gradually lost hope of that, but the puzzle 9 | remained: How do you reconcile being a machine made of matter with 10 | the feeling that you're free to choose what you do? 11 | [1]The best way to explain the answer may be to start with a slightly 12 | wrong version, and then fix it. The wrong version is: You can do 13 | what you want, but you can't want what you want. Yes, you can control 14 | what you do, but you'll do what you want, and you can't control 15 | that.The reason this is mistaken is that people do sometimes change what 16 | they want. People who don't want to want something — drug addicts, 17 | for example — can sometimes make themselves stop wanting it. And 18 | people who want to want something — who want to like classical 19 | music, or broccoli — sometimes succeed.So we modify our initial statement: You can do what you want, but 20 | you can't want to want what you want.That's still not quite true. It's possible to change what you want 21 | to want. I can imagine someone saying "I decided to stop wanting 22 | to like classical music." But we're getting closer to the truth. 23 | It's rare for people to change what they want to want, and the more 24 | "want to"s we add, the rarer it gets.We can get arbitrarily close to a true statement by adding more "want 25 | to"s in much the same way we can get arbitrarily close to 1 by adding 26 | more 9s to a string of 9s following a decimal point. In practice 27 | three or four "want to"s must surely be enough. It's hard even to 28 | envision what it would mean to change what you want to want to want 29 | to want, let alone actually do it.So one way to express the correct answer is to use a regular 30 | expression. You can do what you want, but there's some statement 31 | of the form "you can't (want to)* want what you want" that's true. 32 | Ultimately you get back to a want that you don't control. 33 | [2] 34 | Notes[1] 35 | I didn't know when I was 9 that matter might behave randomly, 36 | but I don't think it affects the problem much. Randomness destroys 37 | the ghost in the machine as effectively as determinism.[2] 38 | If you don't like using an expression, you can make the same 39 | point using higher-order desires: There is some n such that you 40 | don't control your nth-order desires. 
41 | Thanks to Trevor Blackwell, 42 | Jessica Livingston, Robert Morris, and 43 | Michael Nielsen for reading drafts of this. -------------------------------------------------------------------------------- /text_extend/PaulGrahamEssays/weird.txt: -------------------------------------------------------------------------------- 1 | August 2021When people say that in their experience all programming languages 2 | are basically equivalent, they're making a statement not about 3 | languages but about the kind of programming they've done.99.5% of programming consists of gluing together calls to library 4 | functions. All popular languages are equally good at this. So one 5 | can easily spend one's whole career operating in the intersection 6 | of popular programming languages.But the other .5% of programming is disproportionately interesting. 7 | If you want to learn what it consists of, the weirdness of weird 8 | languages is a good clue to follow.Weird languages aren't weird by accident. Not the good ones, at 9 | least. The weirdness of the good ones usually implies the existence 10 | of some form of programming that's not just the usual gluing together 11 | of library calls.A concrete example: Lisp macros. Lisp macros seem weird even to 12 | many Lisp programmers. They're not only not in the intersection of 13 | popular languages, but by their nature would be hard to implement 14 | properly in a language without turning it into a dialect of 15 | Lisp. And macros are definitely evidence of techniques that go 16 | beyond glue programming. For example, solving problems by first 17 | writing a language for problems of that type, and then writing 18 | your specific application in it. Nor is this all you can do with 19 | macros; it's just one region in a space of program-manipulating 20 | techniques that even now is far from fully explored.So if you want to expand your concept of what programming can be, 21 | one way to do it is by learning weird languages. Pick a language 22 | that most programmers consider weird but whose median user is smart, 23 | and then focus on the differences between this language and the 24 | intersection of popular languages. What can you say in this language 25 | that would be impossibly inconvenient to say in others? In the 26 | process of learning how to say things you couldn't previously say, 27 | you'll probably be learning how to think things you couldn't 28 | previously think. 29 | Thanks to Trevor Blackwell, Patrick Collison, Daniel Gackle, Amjad 30 | Masad, and Robert Morris for reading drafts of this. 
31 | -------------------------------------------------------------------------------- /text_extend/eval.sh: -------------------------------------------------------------------------------- 1 | for model in Qwen2-7B-Instrcuct-224K 2 | do 3 | for num_distractor in 0 3 5 4 | do 5 | accelerate launch --num_processes 8 --config_file easy_context/accelerate_configs/deepspeed_inference.yaml --main_process_port 6000 text_extend/eval_text_niah.py \ 6 | --model text_extend/training_output/$model \ 7 | --max_context_length 512000 \ 8 | --min_context_length 32000 \ 9 | --context_interval 32000 \ 10 | --depth_interval 0.1 \ 11 | --num_samples 5 \ 12 | --rnd_number_digits 7 \ 13 | --haystack_dir text_extend/PaulGrahamEssays \ 14 | --num_distractor $num_distractor 15 | done 16 | done 17 | 18 | -------------------------------------------------------------------------------- /text_extend/extend_qwen2.sh: -------------------------------------------------------------------------------- 1 | accelerate launch \ 2 | --config_file easy_context/accelerate_configs/single_node.yaml \ 3 | text_extend/text_extend_train.py \ 4 | --batch-size 1 \ 5 | --gradient-accumulate-every 4 \ 6 | --seed 2024 \ 7 | --output-dir text_extend/training_output/Qwen2-7B-Instrcuct-224K \ 8 | --wandb LMExtend \ 9 | --max-train-steps 1000 \ 10 | --learning-rate 1e-5 \ 11 | --dataset PY007/slimpajama_Qwen2_tokenized_upsample_4096_chunk_256K \ 12 | --model Qwen/Qwen2-7B-Instruct \ 13 | --seq-length 224000 \ 14 | --rope-theta 1000000000 \ 15 | --parallel_mode zigzag_ring_attn \ 16 | --checkpointing-steps 200 17 | 18 | rm text_extend/training_output/Qwen2-7B-Instruct-extend/model.safetensors 19 | -------------------------------------------------------------------------------- /text_extend/niah_output/distractor_0/Qwen2-7B-Instruct-extend-step_1000/avg_accuracy.txt: -------------------------------------------------------------------------------- 1 | Average Accuracy: 0.9977272727272727 2 | -------------------------------------------------------------------------------- /text_extend/niah_output/distractor_0/Qwen2-7B-Instruct-extend-step_1000/heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/text_extend/niah_output/distractor_0/Qwen2-7B-Instruct-extend-step_1000/heatmap.png -------------------------------------------------------------------------------- /text_extend/niah_output/distractor_3/Meta-Llama-3-8B-Instruct-extend-step1000/avg_accuracy.txt: -------------------------------------------------------------------------------- 1 | Average Accuracy: 0.905681818181818 2 | -------------------------------------------------------------------------------- /text_extend/niah_output/distractor_3/Meta-Llama-3-8B-Instruct-extend-step1000/heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/text_extend/niah_output/distractor_3/Meta-Llama-3-8B-Instruct-extend-step1000/heatmap.png -------------------------------------------------------------------------------- /text_extend/niah_output/distractor_3/Qwen2-7B-Instruct-extend-step_1000/avg_accuracy.txt: -------------------------------------------------------------------------------- 1 | Average Accuracy: 0.9352272727272729 2 | -------------------------------------------------------------------------------- 
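Each avg_accuracy.txt above is a one-line summary of the form "Average Accuracy: <float>" that appears to be produced by the text_extend/eval.sh run (via eval_text_niah.py), one file per distractor setting and model under text_extend/niah_output/. Purely as an illustrative sketch, not a script shipped in the repository, those summaries could be collected into a single table; the only assumptions are the directory layout visible in this tree and that one-line format:

from pathlib import Path
import re

# Walk text_extend/niah_output/distractor_*/<model>/avg_accuracy.txt and print one row per result.
for path in sorted(Path("text_extend/niah_output").glob("distractor_*/*/avg_accuracy.txt")):
    distractor, model = path.parts[-3], path.parts[-2]
    match = re.search(r"Average Accuracy:\s*([0-9.]+)", path.read_text())
    if match:  # ignore malformed files rather than crashing
        print(f"{model:45s} {distractor:14s} {float(match.group(1)):.4f}")
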
/text_extend/niah_output/distractor_3/Qwen2-7B-Instruct-extend-step_1000/heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/text_extend/niah_output/distractor_3/Qwen2-7B-Instruct-extend-step_1000/heatmap.png -------------------------------------------------------------------------------- /text_extend/niah_output/distractor_5/Qwen2-7B-Instruct-extend-step_1000/avg_accuracy.txt: -------------------------------------------------------------------------------- 1 | Average Accuracy: 0.9204545454545454 2 | -------------------------------------------------------------------------------- /text_extend/niah_output/distractor_5/Qwen2-7B-Instruct-extend-step_1000/heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/text_extend/niah_output/distractor_5/Qwen2-7B-Instruct-extend-step_1000/heatmap.png -------------------------------------------------------------------------------- /text_extend/niah_output/git_placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/text_extend/niah_output/git_placeholder -------------------------------------------------------------------------------- /text_extend/plot_ppl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | 5 | 6 | def main(args): 7 | data = pd.read_csv(args.csv) 8 | fig, ax = plt.subplots(figsize=(10, 5)) 9 | 10 | x_data = [float(x) for x in data.columns[1:]] 11 | for row in data.values: 12 | label = row[0].replace("NousResearch/", "") 13 | ax.plot(x_data, [float(x) for x in row[1:]], label=label) 14 | 15 | ax.set_xlabel("Context Window") 16 | ax.set_ylabel("Perplexity (lower is better)") 17 | 18 | ax.set_xlim(args.xmin, args.xmax) 19 | ax.set_ylim(args.ymin, args.ymax) 20 | 21 | ax.legend(loc="upper right") 22 | 23 | fig.savefig(args.csv + ".png") 24 | fig.savefig(args.csv + ".pdf", transparent=True) 25 | 26 | 27 | if __name__ == "__main__": 28 | args = argparse.ArgumentParser() 29 | args.add_argument("csv", type=str) 30 | args.add_argument("--xmin", type=int, default=0) 31 | args.add_argument("--xmax", type=int, default=32768) 32 | args.add_argument("--ymin", type=float, default=3) 33 | args.add_argument("--ymax", type=float, default=10) 34 | main(args.parse_args()) 35 | -------------------------------------------------------------------------------- /vision_niah/data/haystack_embeddings/git_placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/data/haystack_embeddings/git_placeholder -------------------------------------------------------------------------------- /vision_niah/data/haystack_videos/git_placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/data/haystack_videos/git_placeholder -------------------------------------------------------------------------------- /vision_niah/data/needle_embeddings/git_placeholder: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/data/needle_embeddings/git_placeholder -------------------------------------------------------------------------------- /vision_niah/eval.sh: -------------------------------------------------------------------------------- 1 | for MODEL_NAME in LongVA-7B 2 | do 3 | mkdir vision_niah/data/haystack_embeddings/$MODEL_NAME 4 | mkdir vision_niah/data/needle_embeddings/$MODEL_NAME 5 | python vision_niah/produce_haystack_embedding.py --model vision_niah/model_weights/$MODEL_NAME --output_dir vision_niah/data/haystack_embeddings/$MODEL_NAME --sampled_frames_num 3000 --pooling_size 2 6 | python vision_niah/produce_needle_embedding.py --model vision_niah/model_weights/$MODEL_NAME --output_dir vision_niah/data/needle_embeddings/$MODEL_NAME --pooling_size 2 --needle_dataset LongVa/v_niah_needles 7 | 8 | accelerate launch --num_processes 8 --config_file easy_context/accelerate_configs/deepspeed_inference.yaml --main_process_port 6000 vision_niah/eval_vision_niah.py \ 9 | --model vision_niah/model_weights/$MODEL_NAME \ 10 | --needle_embedding_dir vision_niah/data/needle_embeddings/$MODEL_NAME \ 11 | --haystack_dir vision_niah/data/haystack_embeddings/$MODEL_NAME \ 12 | --needle_dataset lmms-lab/v_niah_needles \ 13 | --prompt_template qwen2 \ 14 | --max_frame_num 3000 \ 15 | --min_frame_num 200\ 16 | --frame_interval 200 \ 17 | --depth_interval 0.2 18 | done 19 | 20 | 21 | -------------------------------------------------------------------------------- /vision_niah/eval_vision_niah_sampling.py: -------------------------------------------------------------------------------- 1 | from longva.model.builder import load_pretrained_model 2 | from longva.mm_utils import tokenizer_image_token, process_images 3 | from longva.constants import IMAGE_TOKEN_INDEX 4 | from PIL import Image 5 | from datasets import load_dataset 6 | from decord import VideoReader, cpu 7 | import torch 8 | import numpy as np 9 | # fix seed 10 | torch.manual_seed(0) 11 | 12 | model_path = "lmms-lab/LongVA-7B" 13 | video_path = "vision_niah/data/haystack_videos/movie.mp4" 14 | haystack_frames = 256 # you can change this to several thousands so long you GPU memory can handle it :) 15 | gen_kwargs = {"do_sample": False, "use_cache": True, "max_new_tokens": 1024} 16 | tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None, "llava_qwen", device_map="auto") 17 | 18 | preprompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n" 19 | postprompt = "<|im_end|>\n<|im_start|>assistant\n" 20 | #video input 21 | 22 | 23 | vniah_needle_dataset = load_dataset("lmms-lab/v_niah_needles")["test"] 24 | question = vniah_needle_dataset[1]["question"] 25 | image = vniah_needle_dataset[1]["image"].convert("RGB") 26 | answer = vniah_needle_dataset[1]["answer"] 27 | images_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"].to(model.device, dtype=torch.float16) 28 | 29 | 30 | 31 | prompt = preprompt + "" + question + postprompt 32 | print(prompt) 33 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device) 34 | vr = VideoReader(video_path, ctx=cpu(0)) 35 | total_frame_num = len(vr) 36 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, haystack_frames, dtype=int) 37 | frame_idx = uniform_sampled_frames.tolist() 38 | 
frames = vr.get_batch(frame_idx).asnumpy() 39 | video_tensor = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].to(model.device, dtype=torch.float16) 40 | # insert image_tensor in the middle 41 | video_tensor = torch.cat([video_tensor[:len(video_tensor)//2], images_tensor, video_tensor[len(video_tensor)//2:]], dim=0) 42 | # insert at the very end 43 | # video_tensor = torch.cat([video_tensor, images_tensor], dim=0) 44 | with torch.inference_mode(): 45 | output_ids = model.generate(input_ids, images=[video_tensor], modalities=["video"], **gen_kwargs) 46 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 47 | print("Output:", outputs) 48 | print("Answer:", answer) -------------------------------------------------------------------------------- /vision_niah/model_weights/git_placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/model_weights/git_placeholder -------------------------------------------------------------------------------- /vision_niah/needle_datasets/dataset.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "path": "ucsd.jpeg", 4 | "prompt": "\nFind the frame of the 'While You Were Out' note. What is the name of the university on that note?\nA. University of california, los angeles\nB. University of california, san diego\nC. University of california, berkeley\nD. University of california, santa barbara\nAnswer with the option's letter from the given choices directly.", 5 | "answer": "B" 6 | }, 7 | { 8 | "path": "sora_balloon.png", 9 | "prompt": "\nFind the frame of a couple in a wedding. In side the frame, there is a balloon on the bridegroom's head. What is the color of that ballon?\nAnswer the question using a single word or phrase.", 10 | "answer": "Yellow" 11 | }, 12 | { 13 | "path": "selenium_green.jpg", 14 | "prompt": "\nFind the frame with the image of Selenium tablets. How many mg does each tablet contain?\nAnswer the question using a single word or phrase.", 15 | "answer": "200" 16 | }, 17 | { 18 | "path": "panda_scientist.png", 19 | "prompt": "\nFind the frame of a scientist. The scientist is a...\nA. Bird\nB. Elephant\nC. Panda\nD. Dog\nAnswer with the option's letter from the given choices directly.", 20 | "answer": "C" 21 | }, 22 | { 23 | "path": "teddy_bear_times_square.png", 24 | "prompt": "\nFind the frame of a teddy bear. Where is this teddy bear?\nA. Times Square\nB. Eiffel Tower\nC. Taj Mahal\nD. Sydney Opera House\nAnswer with the option's letter from the given choices directly.", 25 | "answer": "A" 26 | } 27 | ] -------------------------------------------------------------------------------- /vision_niah/needle_datasets/generate_hf_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import json 17 | import os 18 | from PIL import Image 19 | import datasets 20 | 21 | 22 | # Find for instance the citation on arxiv or on the dataset repo/website 23 | _CITATION = """\ 24 | @inproceedings{masry-etal-2022-chartqa, 25 | title = "{C}hart{QA}: A Benchmark for Question Answering about Charts with Visual and Logical Reasoning", 26 | author = "Masry, Ahmed and 27 | Long, Do and 28 | Tan, Jia Qing and 29 | Joty, Shafiq and 30 | Hoque, Enamul", 31 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2022", 32 | month = may, 33 | year = "2022", 34 | address = "Dublin, Ireland", 35 | publisher = "Association for Computational Linguistics", 36 | url = "https://aclanthology.org/2022.findings-acl.177", 37 | doi = "10.18653/v1/2022.findings-acl.177", 38 | pages = "2263--2279", 39 | } 40 | """ 41 | _DESCRIPTION = "A largescale benchmark covering 9.6K human-written questions as well as 23.1K questions generated from human-written chart summaries." 42 | 43 | 44 | def get_builder_config(VERSION): 45 | builder_config = [ 46 | datasets.BuilderConfig( 47 | name=f"V-NIAH", 48 | version=VERSION, 49 | description=f"V-NIAH", 50 | ) 51 | ] 52 | return builder_config 53 | 54 | 55 | dataset_features = { 56 | "image": datasets.Image(), 57 | "question": datasets.Value("string"), 58 | "answer": datasets.Value("string"), 59 | } 60 | 61 | 62 | class ChartQA(datasets.GeneratorBasedBuilder): 63 | VERSION = datasets.Version("1.0.0") 64 | 65 | BUILDER_CONFIGS = get_builder_config(VERSION) 66 | 67 | def _info(self): 68 | features = datasets.Features(dataset_features) 69 | return datasets.DatasetInfo( 70 | # This is the description that will appear on the datasets page. 71 | description=_DESCRIPTION, 72 | # This defines the different columns of the dataset and their types 73 | features=features, # Here we define them above because they are different between the two configurations 74 | # If there's a common (input, target) tuple from the features, uncomment supervised_keys line below and 75 | # specify them. They'll be used if as_supervised=True in builder.as_dataset. 76 | # supervised_keys=("sentence", "label"), 77 | # Homepage of the dataset for documentation 78 | # Citation for the dataset 79 | citation=_CITATION, 80 | ) 81 | 82 | def _split_generators(self, dl_manager): 83 | # If several configurations are possible (listed in BUILDER_CONFIGS), the configuration selected by the user is in self.config.name 84 | 85 | # dl_manager is a datasets.download.DownloadManager that can be used to download and extract URLS 86 | # It can accept any type or nested list/dict and will give back the same structure with the url replaced with path to local files. 87 | # By default the archives will be extracted and a path to a cached folder where they are extracted is returned instead of the archive 88 | image_path = "needle_datasets/images/" 89 | augmented_annotation_path = "needle_datasets/dataset.json" 90 | return [ 91 | datasets.SplitGenerator( 92 | name=datasets.Split.TEST, 93 | gen_kwargs={ 94 | "annotation": augmented_annotation_path, 95 | "image_path": image_path, 96 | }, 97 | ), 98 | ] 99 | 100 | # method parameters are unpacked from `gen_kwargs` as given in `_split_generators` 101 | def _generate_examples(self, annotation, image_path): 102 | # The `key` is for legacy reasons (tfds) and is not important in itself, but must be unique for each example. 
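        # Each yielded record pairs one needle image (opened from needle_datasets/images/<path>)
        # with the "prompt" and "answer" fields of dataset.json; the enumeration index doubles
        # as the example key required by the datasets generator API.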
103 | with open(annotation, encoding="utf-8") as f: 104 | annotation = json.load(f) 105 | for index, data in enumerate(annotation): 106 | question = data["prompt"] 107 | answer = data["answer"] 108 | print(data["path"]) 109 | now_data = {} 110 | now_data["image"] = Image.open(image_path + data["path"]) 111 | now_data["question"] = question 112 | now_data["answer"] = answer 113 | yield index, now_data 114 | 115 | 116 | if __name__ == "__main__": 117 | from datasets import load_dataset 118 | 119 | data = load_dataset( 120 | "needle_datasets/generate_hf_dataset.py", 121 | ) 122 | data.push_to_hub("LongVa/longva_needles", private=False) -------------------------------------------------------------------------------- /vision_niah/needle_datasets/git_placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/git_placeholder -------------------------------------------------------------------------------- /vision_niah/needle_datasets/images/astronaut.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/astronaut.png -------------------------------------------------------------------------------- /vision_niah/needle_datasets/images/construction_site.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/construction_site.png -------------------------------------------------------------------------------- /vision_niah/needle_datasets/images/dolphin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/dolphin.png -------------------------------------------------------------------------------- /vision_niah/needle_datasets/images/panda_scientist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/panda_scientist.png -------------------------------------------------------------------------------- /vision_niah/needle_datasets/images/selenium_green.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/selenium_green.jpg -------------------------------------------------------------------------------- /vision_niah/needle_datasets/images/sora_balloon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/sora_balloon.png -------------------------------------------------------------------------------- /vision_niah/needle_datasets/images/teddy_bear_times_square.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/teddy_bear_times_square.png -------------------------------------------------------------------------------- /vision_niah/needle_datasets/images/ucsd.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/ucsd.jpeg -------------------------------------------------------------------------------- /vision_niah/niah_output/LLaVA-NeXT-Video-7B-32K/rope_theta_1000000/avg_accuracy.txt: -------------------------------------------------------------------------------- 1 | Average Accuracy: 0.14666666666666667 2 | -------------------------------------------------------------------------------- /vision_niah/niah_output/LLaVA-NeXT-Video-7B-32K/rope_theta_1000000/heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/niah_output/LLaVA-NeXT-Video-7B-32K/rope_theta_1000000/heatmap.png -------------------------------------------------------------------------------- /vision_niah/niah_output/LLaVA-NeXT-Video-7B-32K/rope_theta_100000000/avg_accuracy.txt: -------------------------------------------------------------------------------- 1 | Average Accuracy: 0.28444444444444444 2 | -------------------------------------------------------------------------------- /vision_niah/niah_output/LLaVA-NeXT-Video-7B-32K/rope_theta_100000000/heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/niah_output/LLaVA-NeXT-Video-7B-32K/rope_theta_100000000/heatmap.png -------------------------------------------------------------------------------- /vision_niah/niah_output/LongVA-7B/avg_accuracy.txt: -------------------------------------------------------------------------------- 1 | Average Accuracy: 0.908888888888889 2 | -------------------------------------------------------------------------------- /vision_niah/niah_output/LongVA-7B/heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/niah_output/LongVA-7B/heatmap.png -------------------------------------------------------------------------------- /vision_niah/produce_haystack_embedding.py: -------------------------------------------------------------------------------- 1 | 2 | from longva.model.builder import load_pretrained_model 3 | from longva.mm_utils import tokenizer_image_token, get_model_name_from_path 4 | from decord import VideoReader, cpu 5 | import argparse 6 | import numpy as np 7 | from tqdm import tqdm 8 | import torch 9 | from longva.mm_utils import process_images 10 | import math 11 | from PIL import Image 12 | def load_video_batches(video_path, batch_size): 13 | vr = VideoReader(video_path, ctx=cpu(0)) 14 | total_frame_num = len(vr) 15 | fps = round(vr.get_avg_fps()) 16 | frame_idx = [i for i in range(0, len(vr), fps)] 17 | for start_idx in range(0, len(frame_idx), batch_size): 18 | end_idx = min(start_idx + batch_size, total_frame_num) 19 | frame_indices = frame_idx[start_idx:end_idx] 20 | batch_frames = 
vr.get_batch(frame_indices).asnumpy() 21 | yield batch_frames 22 | 23 | def main(args): 24 | video_path = args.video_path 25 | model_path = args.model 26 | model_name = "llava_qwen" 27 | 28 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False,device_map="cuda:0") 29 | del model.model.layers 30 | model.config.image_aspect_ratio = "pad" 31 | model.config.mm_patch_merge_type="flat" 32 | # Process video in batches 33 | batch_size = 32 34 | total_batches = (args.sampled_frames_num + batch_size - 1) // batch_size 35 | image_feature_list = [] 36 | if args.add_newline_token: 37 | newline_token_embeddong = model.model.image_newline 38 | with torch.inference_mode(): 39 | for i, video_batch in tqdm(enumerate(load_video_batches(video_path, batch_size)), total=total_batches, desc="Processing Video Batches"): 40 | images = [Image.fromarray(frame).convert("RGB") for frame in video_batch] 41 | processed_images = process_images(images, image_processor,model.config).half() 42 | image_features = model.encode_images(processed_images) 43 | print(image_features.shape) 44 | if args.pooling_size != 0: 45 | B, _, F = image_features.shape 46 | 47 | image_features_spatial = image_features.view(B, int(math.sqrt(_)), int(math.sqrt(_)), F).permute(0, 3, 1, 2) # B, F, 24, 24 48 | image_features_spatial_pool = torch.nn.functional.avg_pool2d(image_features_spatial, args.pooling_size, args.pooling_size) # B, F, 12, 12 49 | image_features = image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() # B, 144, F 50 | if args.add_newline_token: 51 | image_features = torch.cat([image_features, newline_token_embeddong.unsqueeze(0).expand(image_features.shape[0], 1, -1)], dim=1) 52 | image_feature_list.append(image_features.to(torch.bfloat16).to("cpu")) 53 | if i > total_batches: 54 | break 55 | image_feature_list = torch.cat(image_feature_list, dim=0) 56 | torch.save(image_feature_list, f"{args.output_dir}/video_embeddings.pt") 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("--model", type=str, default="output/LLaVA-NeXT-Video-7B-Vicuna") 61 | parser.add_argument("--video_path", type=str, default="vision_niah/data/haystack_videos/movie.mp4") 62 | parser.add_argument("--sampled_frames_num", type=int, default=7200) 63 | parser.add_argument("--output_dir", type=str, default="video_needle_haystack/data/haystack_vicuna_embeddings") 64 | parser.add_argument("--pooling_size", type=int, default=0) 65 | parser.add_argument("--add_newline_token", action="store_true") 66 | args = parser.parse_args() 67 | main(args) 68 | -------------------------------------------------------------------------------- /vision_niah/produce_needle_embedding.py: -------------------------------------------------------------------------------- 1 | 2 | from longva.model.builder import load_pretrained_model 3 | from longva.mm_utils import tokenizer_image_token, get_model_name_from_path 4 | import json 5 | import argparse 6 | import numpy as np 7 | from tqdm import tqdm 8 | import torch 9 | from pathlib import Path 10 | from PIL import Image 11 | from datasets import load_dataset 12 | from longva.mm_utils import process_images 13 | import math 14 | def main(args): 15 | model_name = "llava_qwen" 16 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model, None, model_name, load_8bit=False,device_map="cuda:0") 17 | model.config.image_aspect_ratio = "pad" 18 | model.config.mm_patch_merge_type="flat" 19 | dataset = 
load_dataset(args.needle_dataset)["test"] 20 | for index, instance in enumerate(dataset): 21 | image = instance["image"].convert("RGB") 22 | image = process_images([image], image_processor, model.config).half() 23 | image_features = model.encode_images(image) 24 | if args.pooling_size != 0: 25 | B, _, F = image_features.shape 26 | image_features_spatial = image_features.view(B, int(math.sqrt(_)), int(math.sqrt(_)), F).permute(0, 3, 1, 2) # B, F, 24, 24 27 | image_features_spatial_pool = torch.nn.functional.avg_pool2d(image_features_spatial, args.pooling_size, args.pooling_size) # B, F, 12, 12 28 | image_features = image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() # B, 144, F 29 | image_features = image_features.squeeze(0) 30 | torch.save(image_features, f"{args.output_dir}/{index}.pt") 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--model", type=str, default="output/LLaVA-NeXT-Video-7B-Vicuna") 36 | parser.add_argument("--needle_dataset", type=str, default="lmms-lab/v_niah_needles") 37 | parser.add_argument("--output_dir", type=str, default="video_needle_haystack/data/needle_vicuna_embeddings") 38 | parser.add_argument("--pooling_size", type=int, default=0) 39 | args = parser.parse_args() 40 | main(args) 41 | --------------------------------------------------------------------------------
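produce_haystack_embedding.py writes a single video_embeddings.pt per model, and produce_needle_embedding.py writes one <index>.pt per needle image; vision_niah/eval.sh then points eval_vision_niah.py at those two directories. As a minimal sanity check, and only as a sketch that assumes the default LongVA-7B output paths used in eval.sh, the saved tensors can be reloaded and their shapes inspected before launching the full evaluation:

import glob
import torch

# Haystack: one tensor of per-frame visual features, saved as video_embeddings.pt
# by produce_haystack_embedding.py (bfloat16, kept on CPU).
haystack = torch.load("vision_niah/data/haystack_embeddings/LongVA-7B/video_embeddings.pt", map_location="cpu")
print("haystack embeddings:", tuple(haystack.shape), haystack.dtype)  # expect (num_frames, tokens_per_frame, hidden_dim)

# Needles: one tensor per needle image, saved as <index>.pt by produce_needle_embedding.py
# after squeeze(0), so each should be 2-D.
for path in sorted(glob.glob("vision_niah/data/needle_embeddings/LongVA-7B/*.pt")):
    needle = torch.load(path, map_location="cpu")
    print(path, tuple(needle.shape))  # expect (tokens_per_image, hidden_dim)

With --pooling_size 2, the 24x24 patch grid is average-pooled to 12x12, so each frame or needle image should contribute 144 visual tokens (plus one extra token per frame if --add_newline_token was set), which is what the shape printout makes easy to confirm.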