├── .gitignore
├── LICENSE
├── README.md
├── easy_context
├── __init__.py
├── accelerate_configs
│ ├── deepspeed_inference.yaml
│ ├── single_node.yaml
│ ├── two_node.yaml
│ ├── zero3_offload.json
│ └── zero3_offload_inference.json
├── dist_flash_attn
│ ├── README.md
│ ├── async_communication.py
│ ├── lightseq_async_attn.py
│ ├── lightseq_async_attn_varlen.py
│ ├── monkey_patch.py
│ └── prepare_input.py
├── low_mem_cross_ent.py
├── low_mem_cross_ent_tests
│ ├── test_correctness.py
│ └── test_mem_and_speed.py
├── modeling_qwen2.py
├── ulysses_attn
│ ├── monkey_patch.py
│ └── prepare_inputs.py
├── unsloth_offloaded_gradient_checkpoint
│ └── monkey_patch.py
└── zigzag_ring_attn
│ ├── monkey_patch.py
│ └── prepare_inputs.py
├── local_demo
├── assets
│ ├── assistant_logo.png
│ ├── dc_demo.mp4
│ ├── favicon.ico
│ ├── jobs.mp4
│ ├── llava_logo.png
│ ├── llava_next.jpg
│ ├── lmms-eval.png
│ ├── otter_books.jpg
│ ├── user_example_01.jpg
│ ├── user_example_02.jpg
│ ├── user_example_03.jpg
│ ├── user_example_04.jpg
│ ├── user_example_05.jpg
│ ├── user_example_06.jpg
│ ├── user_example_07.jpg
│ ├── user_example_08.jpg
│ ├── user_example_09.jpg
│ ├── user_example_10.png
│ ├── user_example_11.png
│ ├── user_logo.png
│ ├── water.mp4
│ └── white_cat_smile.jpg
├── cache
│ ├── 0266b3f7b9311a93f0f38bc8cf0698b019c829b1
│ │ └── user_example_09.jpg
│ ├── 0333f1f38db26d76752fee6cd2d938f06617dd2f
│ │ └── dc_demo.mp4
│ ├── 0c299f0f979725494d926c3e76cfeb75c19ef564
│ │ └── assistant_logo.png
│ ├── 3c94cf4ccdbb877666285b42dd9263073d1fa19c
│ │ └── user_example_07.jpg
│ ├── 447d9ea0a12a7be4cf7f279dc4e667b1b938a0c2
│ │ └── otter_books.jpg
│ ├── 7abebd5007998873c3d36612f984d241c3b574d8
│ │ └── user_logo.png
│ ├── 7dcdac5ff3df641ae7eb47d8b7aaeab740e02de9
│ │ └── user_example_06.jpg
│ ├── b7b8f04e74cda37d9a574c1a2098d7e4ee97b212
│ │ └── user_example_05.jpg
│ └── dbb5054b8579d1943ff74045f254e0e759efec99
│ │ └── white_cat_smile.jpg
├── constants.py
├── debug_json_api.py
├── longva_backend.py
├── multimodal_chat.py
├── theme_dropdown.py
└── themes
│ └── theme_schema@0.1.1.json
├── longva
├── data_processing
│ └── utils.py
├── longva
│ ├── __init__.py
│ ├── constants.py
│ ├── conversation.py
│ ├── mm_utils.py
│ ├── model
│ │ ├── __init__.py
│ │ ├── apply_delta.py
│ │ ├── builder.py
│ │ ├── consolidate.py
│ │ ├── language_model
│ │ │ ├── llava_llama.py
│ │ │ ├── llava_mistral.py
│ │ │ ├── llava_mpt.py
│ │ │ ├── llava_qwen.py
│ │ │ └── modeling_llama.py
│ │ ├── llava_arch.py
│ │ ├── make_delta.py
│ │ ├── multimodal_encoder
│ │ │ ├── builder.py
│ │ │ └── clip_encoder.py
│ │ ├── multimodal_projector
│ │ │ ├── builder.py
│ │ │ └── pooler_projector.py
│ │ ├── multimodal_resampler
│ │ │ ├── builder.py
│ │ │ ├── masked_drop.py
│ │ │ ├── perceiver.py
│ │ │ ├── qformer.py
│ │ │ └── spatial_pool.py
│ │ └── utils.py
│ ├── train
│ │ ├── llama_flash_attn_monkey_patch.py
│ │ ├── llava_trainer.py
│ │ ├── train.py
│ │ ├── train_dpo.py
│ │ └── train_mem.py
│ └── utils.py
├── pyproject.toml
├── scripts
│ ├── dpo.sh
│ ├── finetune.sh
│ ├── pretrain.sh
│ └── zero3.json
└── trl
│ ├── __init__.py
│ ├── core.py
│ ├── environment
│ ├── __init__.py
│ └── base_environment.py
│ ├── extras
│ ├── __init__.py
│ ├── best_of_n_sampler.py
│ └── dataset_formatting.py
│ ├── import_utils.py
│ ├── models
│ ├── __init__.py
│ ├── modeling_base.py
│ ├── modeling_sd_base.py
│ ├── modeling_value_head.py
│ └── utils.py
│ └── trainer
│ ├── __init__.py
│ ├── base.py
│ ├── ddpo_config.py
│ ├── ddpo_trainer.py
│ ├── dpo_trainer.py
│ ├── iterative_sft_trainer.py
│ ├── model_config.py
│ ├── ppo_config.py
│ ├── ppo_trainer.py
│ ├── reward_config.py
│ ├── reward_trainer.py
│ ├── sft_trainer.py
│ └── utils.py
├── requirements.txt
├── text_extend
├── PaulGrahamEssays
│ ├── addiction.txt
│ ├── aord.txt
│ ├── apple.txt
│ ├── avg.txt
│ ├── before.txt
│ ├── bias.txt
│ ├── boss.txt
│ ├── copy.txt
│ ├── corpdev.txt
│ ├── desres.txt
│ ├── diff.txt
│ ├── ecw.txt
│ ├── founders.txt
│ ├── foundervisa.txt
│ ├── gap.txt
│ ├── gba.txt
│ ├── gh.txt
│ ├── goodtaste.txt
│ ├── hubs.txt
│ ├── iflisp.txt
│ ├── island.txt
│ ├── know.txt
│ ├── langdes.txt
│ ├── laundry.txt
│ ├── love.txt
│ ├── mod.txt
│ ├── newideas.txt
│ ├── nft.txt
│ ├── philosophy.txt
│ ├── popular.txt
│ ├── pow.txt
│ ├── rootsoflisp.txt
│ ├── rss.txt
│ ├── siliconvalley.txt
│ ├── startuplessons.txt
│ ├── submarine.txt
│ ├── sun.txt
│ ├── superangels.txt
│ ├── todo.txt
│ ├── unions.txt
│ ├── useful.txt
│ ├── vb.txt
│ ├── vcsqueeze.txt
│ ├── vw.txt
│ ├── want.txt
│ ├── web20.txt
│ ├── weird.txt
│ ├── wisdom.txt
│ └── worked.txt
├── eval.sh
├── eval_text_niah.py
├── eval_text_ppl.py
├── extend_qwen2.sh
├── niah_output
│ ├── distractor_0
│ │ └── Qwen2-7B-Instruct-extend-step_1000
│ │ │ ├── all_accuracies.json
│ │ │ ├── avg_accuracy.txt
│ │ │ └── heatmap.png
│ ├── distractor_3
│ │ ├── Meta-Llama-3-8B-Instruct-extend-step1000
│ │ │ ├── all_accuracies.json
│ │ │ ├── avg_accuracy.txt
│ │ │ └── heatmap.png
│ │ └── Qwen2-7B-Instruct-extend-step_1000
│ │ │ ├── all_accuracies.json
│ │ │ ├── avg_accuracy.txt
│ │ │ └── heatmap.png
│ ├── distractor_5
│ │ └── Qwen2-7B-Instruct-extend-step_1000
│ │ │ ├── all_accuracies.json
│ │ │ ├── avg_accuracy.txt
│ │ │ └── heatmap.png
│ └── git_placeholder
├── plot_ppl.py
└── text_extend_train.py
└── vision_niah
├── data
├── haystack_embeddings
│ └── git_placeholder
├── haystack_videos
│ └── git_placeholder
└── needle_embeddings
│ └── git_placeholder
├── eval.sh
├── eval_vision_niah.py
├── eval_vision_niah_sampling.py
├── model_weights
└── git_placeholder
├── needle_datasets
├── dataset.json
├── generate_hf_dataset.py
├── git_placeholder
└── images
│ ├── astronaut.png
│ ├── construction_site.png
│ ├── dolphin.png
│ ├── panda_scientist.png
│ ├── selenium_green.jpg
│ ├── sora_balloon.png
│ ├── teddy_bear_times_square.png
│ └── ucsd.jpeg
├── niah_output
├── LLaVA-NeXT-Video-7B-32K
│ ├── rope_theta_1000000
│ │ ├── all_accuracies.json
│ │ ├── avg_accuracy.txt
│ │ └── heatmap.png
│ └── rope_theta_100000000
│ │ ├── all_accuracies.json
│ │ ├── avg_accuracy.txt
│ │ └── heatmap.png
└── LongVA-7B
│ ├── all_accuracies.json
│ ├── avg_accuracy.txt
│ └── heatmap.png
├── produce_haystack_embedding.py
└── produce_needle_embedding.py
/.gitignore:
--------------------------------------------------------------------------------
1 | wandb/
2 | **__pycache__/
3 | wandb/
4 | *.pyc
5 | *.egg-info
6 | *.log
7 | *.log.*
8 | .deepspeed_env
9 | text_extend/training_output
10 | vision_niah/model_weights
11 | **/**.pt
12 | # **/**.mp4
13 | longva/build
14 | vision_niah/data/haystack_videos
15 | local_demo/cache/
16 | i18n
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | S-Lab License 1.0
2 |
3 | Copyright 2022 S-Lab
4 |
5 | Redistribution and use for non-commercial purpose in source and
6 | binary forms, with or without modification, are permitted provided
7 | that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright
10 | notice, this list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright
13 | notice, this list of conditions and the following disclaimer in
14 | the documentation and/or other materials provided with the
15 | distribution.
16 |
17 | 3. Neither the name of the copyright holder nor the names of its
18 | contributors may be used to endorse or promote products derived
19 | from this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 |
33 | In the event that redistribution and/or use for commercial purpose in
34 | source or binary forms, with or without modification is required,
35 | please contact the contributor(s) of the work.
36 |
--------------------------------------------------------------------------------
/easy_context/__init__.py:
--------------------------------------------------------------------------------
1 | from .dist_flash_attn.prepare_input import prepare_dist_flash_attn_inputs
2 | from .dist_flash_attn.monkey_patch import apply_dist_flash_attn_monkey_patch_llama
3 | from .zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs
4 | from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_llama
5 | from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_mistral
6 | from .unsloth_offloaded_gradient_checkpoint.monkey_patch import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
7 | from .ulysses_attn.prepare_inputs import prepare_ulysses_attn_inputs
8 | from .ulysses_attn.monkey_patch import apply_ulysses_attn_monkey_patch_llama
9 | from .modeling_qwen2 import Qwen2ForCausalLM_RingAttn
10 | def prepare_seq_parallel_inputs(
11 | seq_algo, input_ids, position_ids, target_ids, rank, world_size, device
12 | ):
13 | if seq_algo == "zigzag_ring_attn":
14 | return prepare_zigzag_ring_attn_inputs(
15 | input_ids, position_ids, target_ids, rank, world_size, device
16 | )
17 | elif seq_algo == "dist_flash_attn":
18 | return prepare_dist_flash_attn_inputs(
19 | input_ids, position_ids, target_ids, rank, world_size, device
20 | )
21 | elif seq_algo == "ulysses_attn":
22 | return prepare_ulysses_attn_inputs(
23 | input_ids, position_ids, target_ids, rank, world_size, device
24 | )
25 | elif seq_algo == "data_parallel":
26 | return {
27 | "local_input_ids": input_ids.to(device),
28 | "local_position_ids": position_ids.to(device),
29 | "local_target_ids": target_ids.to(device),
30 | }
31 | else:
32 | raise ValueError(f"Invalid seq_algo: {seq_algo}")
33 |
34 | def apply_seq_parallel_monkey_patch(
35 | seq_algo, model
36 | ):
37 | assert seq_algo in ["zigzag_ring_attn", "dist_flash_attn", "ulysses_attn", "data_parallel"], f"Invalid seq_algo: {seq_algo}"
38 | assert model in ["llama", "mistral"], f"Invalid model: {model}"
39 | if seq_algo == "data_parallel":
40 | return
41 | elif seq_algo == "zigzag_ring_attn" and model == "llama":
42 | apply_zigzag_ring_attn_monkey_patch_llama()
43 | elif seq_algo == "zigzag_ring_attn" and model == "mistral":
44 | apply_zigzag_ring_attn_monkey_patch_mistral()
45 | elif seq_algo == "dist_flash_attn" and model == "llama":
46 | apply_dist_flash_attn_monkey_patch_llama()
47 | elif seq_algo == "ulysses_attn" and model == "llama":
48 | apply_ulysses_attn_monkey_patch_llama()
49 | else:
50 | raise ValueError(f"Invalid seq_algo: {seq_algo} or model: {model}")
51 |
52 | def prepare_dataloader(seq_algo, dataloader, accelerator):
53 |     if seq_algo == "data_parallel":
54 |         return accelerator.prepare(dataloader)
55 | else:
56 | return dataloader
--------------------------------------------------------------------------------
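Note: a minimal sketch of how the two entry points above might be wired into a training step. The `model`, `batch`, and `accelerator` objects are assumed to come from the surrounding training script (they are not defined in this module), and the loss computation assumes `target_ids` were already shifted before sharding; only the `easy_context` calls are taken from the code above.

```python
import torch
import torch.nn.functional as F

from easy_context import (
    apply_seq_parallel_monkey_patch,
    prepare_seq_parallel_inputs,
)

SEQ_ALGO = "zigzag_ring_attn"  # or "dist_flash_attn", "ulysses_attn", "data_parallel"

# Patch the HF Llama attention/decoder forward once, before training starts.
apply_seq_parallel_monkey_patch(SEQ_ALGO, "llama")


def training_step(model, batch, accelerator):
    # Each rank keeps only its shard of the sequence; the patched attention
    # exchanges the missing context across ranks during the forward pass.
    prepared = prepare_seq_parallel_inputs(
        SEQ_ALGO,
        batch["input_ids"],
        batch["position_ids"],
        batch["target_ids"],
        accelerator.process_index,
        accelerator.num_processes,
        accelerator.device,
    )
    logits = model(
        prepared["local_input_ids"],
        position_ids=prepared["local_position_ids"],
    ).logits
    # Assumes the targets were shifted to align with the logits before sharding.
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        prepared["local_target_ids"].reshape(-1),
    )
    accelerator.backward(loss)
    return loss
```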
/easy_context/accelerate_configs/deepspeed_inference.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_config_file: easy_context/accelerate_configs/zero3_offload_inference.json
5 | zero3_init_flag: false
6 | distributed_type: DEEPSPEED
7 | downcast_bf16: 'no'
8 | machine_rank: 0
9 | main_training_function: main
10 | num_machines: 1
11 | num_processes: 8
12 | rdzv_backend: static
13 | same_network: true
14 | tpu_env: []
15 | tpu_use_cluster: false
16 | tpu_use_sudo: false
17 | use_cpu: false
18 |
--------------------------------------------------------------------------------
/easy_context/accelerate_configs/single_node.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_config_file: easy_context/accelerate_configs/zero3_offload.json
5 | zero3_init_flag: false
6 | distributed_type: DEEPSPEED
7 | downcast_bf16: 'no'
8 | machine_rank: 0
9 | main_training_function: main
10 | num_machines: 1
11 | num_processes: 8
12 | rdzv_backend: static
13 | same_network: true
14 | tpu_env: []
15 | tpu_use_cluster: false
16 | tpu_use_sudo: false
17 | use_cpu: false
18 |
--------------------------------------------------------------------------------
/easy_context/accelerate_configs/two_node.yaml:
--------------------------------------------------------------------------------
1 | debug: false
2 | deepspeed_config:
3 | deepspeed_config_file: easy_context/accelerate_configs/zero3_offload.json
4 | deepspeed_multinode_launcher: standard
5 | zero3_init_flag: false
6 | distributed_type: DEEPSPEED
7 | downcast_bf16: 'no'
8 | num_machines: 2
9 | num_processes: 16
10 | main_training_function: main
11 | rdzv_backend: c10d
12 | same_network: false
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/easy_context/accelerate_configs/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "fp16": {
6 | "enabled": "auto"
7 | },
8 | "scheduler": {
9 | "type": "WarmupLR",
10 | "params": {
11 | "warmup_min_lr": 1e-5,
12 | "warmup_max_lr": 1e-5,
13 | "warmup_num_steps": 0,
14 | "warmup_type": "linear"
15 | }
16 | },
17 | "optimizer": {
18 | "type": "AdamW",
19 | "params": {
20 | "lr": "auto",
21 | "betas": [0.9, 0.95],
22 | "eps": 1e-8,
23 | "weight_decay": 0.1
24 | }
25 | },
26 | "zero_optimization": {
27 | "stage": 3,
28 | "offload_optimizer": {
29 | "device": "cpu",
30 | "pin_memory": true
31 | },
32 | "offload_param": {
33 | "device": "cpu",
34 | "pin_memory": true
35 | },
36 | "overlap_comm": true,
37 | "contiguous_gradients": true,
38 | "sub_group_size": 1e9,
39 | "reduce_bucket_size": "auto",
40 | "stage3_prefetch_bucket_size": "auto",
41 | "stage3_param_persistence_threshold": "auto",
42 | "stage3_max_live_parameters": 1e9,
43 | "stage3_max_reuse_distance": 1e9,
44 | "stage3_gather_16bit_weights_on_model_save": true
45 | },
46 | "gradient_accumulation_steps": "auto",
47 | "gradient_clipping": "auto",
48 | "steps_per_print": 2000,
49 | "train_batch_size": "auto",
50 | "train_micro_batch_size_per_gpu": 1,
51 | "wall_clock_breakdown": false
52 | }
53 |
--------------------------------------------------------------------------------
/easy_context/accelerate_configs/zero3_offload_inference.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "fp16": {
6 | "enabled": "auto"
7 | },
8 | "zero_optimization": {
9 | "stage": 3,
10 | "stage3_prefetch_bucket_size": 33554432,
11 | "stage3_param_persistence_threshold": 4096,
12 | "stage3_max_live_parameters":33554432,
13 | "offload_param": {
14 | "device": "cpu",
15 | "pin_memory": true
16 | }
17 | },
18 | "train_batch_size": 8,
19 | "train_micro_batch_size_per_gpu": 1,
20 | "wall_clock_breakdown": false
21 | }
--------------------------------------------------------------------------------
/easy_context/dist_flash_attn/README.md:
--------------------------------------------------------------------------------
1 | # LightSeq
2 | Taken from https://github.com/RulinShao/LightSeq. All credits to the authors.
3 |
4 | ```
5 | @article{li2023lightseq,
6 |   title={LightSeq: Sequence Level Parallelism for Distributed Training of Long Context Transformers},
7 |   author={Li, Dacheng and Shao, Rulin and Xie, Anze and Xing, Eric P. and Gonzalez, Joseph E. and Stoica, Ion and Ma, Xuezhe and Zhang, Hao},
8 | journal={arXiv preprint arXiv:2310.03294},
9 | year={2023}
10 | }
11 | ```
--------------------------------------------------------------------------------
/easy_context/dist_flash_attn/prepare_input.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def extract_local(value, rank, world_size, device, dim=1):
4 | value_local = value.chunk(world_size, dim=dim)[rank]
5 | return value_local.to(device)
6 |
7 |
8 | def prepare_dist_flash_attn_inputs(
9 | input_ids, position_ids, target_ids, rank, world_size, device
10 | ):
11 | local_input_ids = extract_local(
12 | input_ids,
13 | rank,
14 | world_size,
15 | device,
16 | )
17 | local_position_ids = extract_local(
18 | position_ids,
19 | rank,
20 | world_size,
21 | device,
22 | )
23 | if target_ids is not None:
24 | local_target_ids = extract_local(
25 | target_ids,
26 | rank,
27 | world_size,
28 | device,
29 | )
30 | else:
31 | local_target_ids = None
32 | return {
33 | "local_input_ids": local_input_ids,
34 | "local_position_ids": local_position_ids,
35 | "local_target_ids": local_target_ids,
36 | }
--------------------------------------------------------------------------------
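Note: unlike the zigzag variant further below, `extract_local` here hands each rank one contiguous slice of the sequence. A tiny, self-contained illustration of that split (hypothetical tensor, world_size of 4):

```python
import torch

# Tokens 0..7 in a single batch row, split contiguously across 4 ranks.
seq = torch.arange(8).unsqueeze(0)  # shape [1, 8]
world_size = 4
shards = [seq.chunk(world_size, dim=1)[rank] for rank in range(world_size)]
# rank 0 -> [[0, 1]], rank 1 -> [[2, 3]], rank 2 -> [[4, 5]], rank 3 -> [[6, 7]]
```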
/easy_context/low_mem_cross_ent.py:
--------------------------------------------------------------------------------
1 | """Low memory cross entropy without materializing the logits
2 |
3 | This module enables long-context training of large-vocab models, e.g., Gemma (256K vocab) and Llama 3 (128K vocab)
4 |
5 | Yao Fu, University of Edinburgh
6 | yao.fu@ed.ac.uk
7 | """
8 |
9 | import torch
10 | import torch.nn.functional as F
11 |
12 |
13 | def cross_ent_normal(x, weight, labels):
14 | logits = torch.einsum("bsh, vh -> bsv", x, weight)
15 | vocab = weight.size(0)
16 | loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1))
17 | return loss
18 |
19 | class LowMemLogitProjCrossEnt(torch.autograd.Function):
20 | """Low memory implementation of logits projection plus cross entropy loss.
21 | Useful for reducing the peak memory when dealing with vocabulary larger than 100000
22 |
23 | TODO: integrate this function into easy context
24 |
25 | Two tricks used here
26 | 1. Shard the data to reduce peak memory
27 | 2. Do not save the logits
28 | """
29 |
30 | @staticmethod
31 | # @torch.compile() # Currently we do not use torch.compile because it uses additional memory
32 | def forward(ctx, x: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, sp: int=4):
33 | """
34 | Args:
35 | x: size = [batch, seqlen, hidden]
36 | weight: size = [vocab, hidden]
37 | labels: size = [batch, seqlen]
38 | """
39 | bsz, seqlen, hidden = x.size()
40 | vocab = weight.size(0)
41 | micro_seqlen = seqlen // sp
42 |
43 | loss = 0
44 | for i in range(sp): # shard data along the sequence dimension
45 | logits_i_slice = torch.einsum("bsh, vh -> bsv", x[:, micro_seqlen * i: micro_seqlen * (i + 1)], weight)
46 | loss_i = F.cross_entropy(logits_i_slice.reshape(-1, vocab), labels[:, micro_seqlen * i: micro_seqlen * (i + 1)].reshape(-1))
47 | loss = loss + loss_i
48 |
49 | loss = loss / sp
50 |         ctx.save_for_backward(x, weight, labels) # because we do not save the logits, we save memory
51 | ctx.sp = sp
52 | return loss
53 |
54 | # @torch.compile()
55 | @staticmethod
56 | def backward(ctx, grad_output):
57 | """Manually calculate the gradient in a memory-efficient way
58 | Ref: https://indii.org/blog/gradients-of-softmax-and-logsumexp/
59 | """
60 | x, weight, labels = ctx.saved_tensors
61 | sp = ctx.sp
62 | device = x.device
63 | dtype = x.dtype
64 | bsz, seqlen, hidden = x.size()
65 | vocab, hidden = weight.size()
66 | micro_seqlen = seqlen // sp
67 |
68 | d_weight = torch.zeros_like(weight, device=weight.device)
69 | d_x = []
70 | for i in range(sp): # shard data along sequence dimension, reduce peak memory
71 | x_ = x[:, micro_seqlen * i: micro_seqlen * (i + 1)]
72 | p = F.softmax(
73 | torch.einsum("blh, vh -> blv", x_, weight),
74 | dim=-1
75 | )
76 |
77 | # memory efficient in-place backprop
78 | # loss -> d_logits
79 | d_logits = -p.view(-1) # [b * l * v]
80 | labels_ = labels[:, micro_seqlen * i: micro_seqlen * (i + 1)].reshape(-1) # [b * l]
81 | index = torch.arange(bsz * micro_seqlen, device=device) * vocab + labels_
82 | source = torch.tensor([1] * bsz * micro_seqlen, dtype=dtype, device=device)
83 | d_logits.index_add_(0, index, source)
84 | d_logits = -d_logits.view(bsz, micro_seqlen, vocab) / (bsz * seqlen)
85 |
86 | # d_logits -> d_x and d_weight
87 | d_x.append(torch.einsum("blv, vh -> blh", d_logits, weight))
88 | d_weight += torch.einsum("blv, blh -> vh", d_logits, x_)
89 |
90 | d_weight = grad_output * d_weight
91 | d_x = grad_output * torch.concat(d_x, 1)
92 | return d_x, d_weight, None, None
93 |
94 | low_mem_cross_ent = LowMemLogitProjCrossEnt.apply
95 |
--------------------------------------------------------------------------------
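Note: the manual backward above is the standard mean-reduced softmax cross-entropy gradient (the reference linked in its docstring derives it). With logits $z = xW^\top$, label $y_n$, and $N = b \cdot s$ tokens:

$$
\frac{\partial \mathcal{L}}{\partial z_n} = \frac{\mathrm{softmax}(z_n) - \mathrm{onehot}(y_n)}{N},
\qquad
\frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial z}\, W,
\qquad
\frac{\partial \mathcal{L}}{\partial W} = \Big(\frac{\partial \mathcal{L}}{\partial z}\Big)^{\!\top} x,
$$

with batch and sequence dimensions flattened in the last two products. The loop accumulates these shard by shard, so the full $[b, s, v]$ logits tensor is never materialized.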
/easy_context/low_mem_cross_ent_tests/test_correctness.py:
--------------------------------------------------------------------------------
1 | """Test the correctness (up to certain tolerance of numerical error) of low-memory cross-ent
2 |
3 | Yao Fu, University of Edinburgh
4 | yao.fu@ed.ac.uk
5 | """
6 |
7 | import sys
8 | sys.path.append("..")
9 |
10 | import torch
11 | import torch.nn.functional as F
12 | from low_mem_cross_ent import low_mem_cross_ent, cross_ent_normal
13 |
14 | bsz = 1
15 | seqlen = 50000
16 | hidden = 4096
17 | vocab = 15000
18 | dtype = torch.bfloat16
19 | rtol=1e-05 # relative tolerance when comparing the gradients from two implementations
20 | atol=1e-07 # absolute tolerance when comparing the gradients from two implementations
21 | # in PyTorch the default is 1e-8, but our implementation cannot pass that threshold
22 | # 1e-7 seems to be the smallest tolerance we can pass
23 |
24 | x = torch.normal(mean=0, std=0.01, size=(bsz, seqlen, hidden),
25 | device="cuda", dtype=dtype, requires_grad=True)
26 | weight = torch.normal(mean=0, std=0.01, size=(vocab, hidden),
27 | device="cuda", dtype=dtype, requires_grad=True)
28 | labels = torch.randint(low=0, high=vocab - 1, size=(bsz, seqlen), device="cuda")
29 |
30 | loss_normal = cross_ent_normal(x, weight, labels)
31 | print("loss normal: %.4f" % loss_normal.cpu().item())
32 | loss_normal.backward()
33 | x_grad = x.grad.clone()
34 | weight_grad = weight.grad.clone()
35 | # print(x.grad)
36 | # print(weight.grad)
37 |
38 |
39 | # TODO: this one almost reduces memory by half. Maybe increase sp further
40 | x.grad = None
41 | weight.grad = None
42 | loss_low_mem = low_mem_cross_ent(x, weight, labels)
43 | print("loss low mem: %.4f" % loss_low_mem.cpu().item())
44 | loss_low_mem.backward()
45 | # print(x.grad)
46 | # print(weight.grad)
47 |
48 | ## Test implementation by asserting close
49 | assert(torch.allclose(x_grad, x.grad, rtol=rtol, atol=atol))
50 | assert(torch.allclose(weight_grad, weight.grad, rtol=rtol, atol=atol))
51 | print("PASS: gradients from normal computation and low memory computation are close.")
52 |
53 |
54 | # #### Test gradient of logits
55 | # x.grad = None
56 | # weight.grad = None
57 | # logits = torch.einsum("bsh, vh -> bsv", x, weight)
58 | # loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1))
59 | # d_logits = torch.autograd.grad(loss, logits)
60 | # p = F.softmax(torch.einsum("blh, vh -> blv", x, weight), dim=-1)
61 | # p_ = p / (bsz * seqlen)
62 |
63 | # #### test index add
64 | # x = torch.tensor([1, 2, 3, 4, 5, 6, 7])
65 | # index = torch.tensor([1, 3, 4])
66 | # source = torch.tensor([1, 1, 1])
67 | # x.index_add_(dim=0, index=index, source=source)
68 |
69 | # #### test index add 2
70 | # sp = 4
71 | # micro_seqlen = seqlen // sp
72 | # p = torch.normal(mean=0, std=0.01, size=(bsz, micro_seqlen, vocab),
73 | # device="cuda", dtype=torch.bfloat16)
74 | # labels_ = labels[:, :micro_seqlen].view(-1)
75 | # index = torch.arange(bsz * micro_seqlen, device="cuda") * vocab
76 | # index += labels_
77 | # d_logits = -p.view(-1)
78 | # source = torch.tensor([1] * bsz * micro_seqlen, dtype=torch.bfloat16, device="cuda")
79 | # d_logits.index_add_(0, index, source)
80 | # d_logits = d_logits.view(bsz, micro_seqlen, vocab)
81 |
82 |
--------------------------------------------------------------------------------
/easy_context/low_mem_cross_ent_tests/test_mem_and_speed.py:
--------------------------------------------------------------------------------
1 | """Test the memory and speed and MFU of low memory cross entropy
2 |
3 | Yao Fu, University of Edinburgh
4 | yao.fu@ed.ac.uk
5 |
6 | bf16, seqlen=50000, vocab=150000, without torch.compile
7 | |          | normal | low_mem | low_mem |
8 | | sp | - | 4 | 16 |
9 | | peak mem | 43.4G | 18.5G | 8.1G |
10 | | forward | 0.307 | 0.310 | 0.315 |
11 | | backward | 0.631 | 0.896 | 0.914 |
12 | | MFU | 0.57 | 0.45 | 0.44 |
13 |
14 | NOTE: tried torch.compile and it uses significantly more memory, so we do not use it
15 | TODO: profile and check why backward is slower
16 | """
17 | import sys
18 | sys.path.append("..")
19 |
20 | import torch
21 | import numpy as np
22 | import torch.nn.functional as F
23 | from low_mem_cross_ent import low_mem_cross_ent, cross_ent_normal
24 |
25 | implementation = "low_mem" # "normal", "low_mem"
26 | device_type = "A100"
27 | bsz = 1
28 | seqlen = 50000
29 | hidden = 4096
30 | vocab = 150000
31 | sp=16
32 | dtype = torch.bfloat16
33 | # dtype = torch.float
34 | G = 1024 ** 3
35 | T = 1024 ** 4
36 |
37 | x = torch.normal(mean=0, std=0.01, size=(bsz, seqlen, hidden),
38 | device="cuda", dtype=dtype, requires_grad=True)
39 | weight = torch.normal(mean=0, std=0.01, size=(vocab, hidden),
40 | device="cuda", dtype=dtype, requires_grad=True)
41 | labels = torch.randint(low=0, high=vocab - 1, size=(bsz, seqlen), device="cuda")
42 |
43 | def timed(fn):
44 | start = torch.cuda.Event(enable_timing=True)
45 | end = torch.cuda.Event(enable_timing=True)
46 | start.record()
47 | result = fn()
48 | end.record()
49 | torch.cuda.synchronize()
50 | return result, start.elapsed_time(end) / 1000
51 |
52 | n_runs = 50
53 | flop = 6 * bsz * seqlen * hidden * vocab
54 | if(implementation == "normal"):
55 | forward_times, backward_times = [], []
56 | for _ in range(n_runs):
57 | loss_normal, time_elapse = timed(lambda: cross_ent_normal(x, weight, labels))
58 | forward_times.append(time_elapse)
59 | _, time_elapse = timed(lambda: loss_normal.backward())
60 | backward_times.append(time_elapse)
61 | mem = torch.cuda.max_memory_allocated()
62 | elif(implementation == "low_mem"):
63 | forward_times, backward_times = [], []
64 | for _ in range(n_runs):
65 | loss_low_mem, time_elapse = timed(lambda: low_mem_cross_ent(x, weight, labels, sp))
66 | forward_times.append(time_elapse)
67 | _, time_elapse = timed(lambda: loss_low_mem.backward())
68 | backward_times.append(time_elapse)
69 | mem = torch.cuda.max_memory_allocated()
70 | else: raise NameError("Implementation %s not recognized" % implementation)
71 |
72 | forward_time = np.median(forward_times)
73 | backward_time = np.median(backward_times)
74 | flops = (flop / T) / (forward_time + backward_time)
75 | if(device_type == "A100"):
76 | device_flop = 312
77 | else: raise NameError("device %s not recognized" % device_type)
78 |
79 | print("%s, peak memory %.1fG, forward time %.4f, backward time %.4f, flops %.2fT, util %.2f" %
80 | (implementation, mem / G, forward_time, backward_time, flops, flops / device_flop))
81 |
--------------------------------------------------------------------------------
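Note: the `flop` constant in this script is the usual dense-matmul estimate, not an exact count. The logits projection multiplies a $[b \cdot s, h]$ activation by an $[h, v]$ weight, costing about $2\,b\,s\,h\,v$ FLOPs in the forward pass and roughly twice that again in the backward pass (one matmul each for the input and weight gradients), so

$$
\mathrm{FLOPs} \approx 6\, b\, s\, h\, v,
\qquad
\mathrm{util} \approx \frac{\mathrm{FLOPs} / (t_{\mathrm{fwd}} + t_{\mathrm{bwd}})}{312\ \mathrm{TFLOPS}\ \text{(A100 bf16 peak, as hard-coded above)}}.
$$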
/easy_context/ulysses_attn/monkey_patch.py:
--------------------------------------------------------------------------------
1 | import transformers
2 | from typing import List, Optional, Tuple, Union
3 | import warnings
4 | import torch
5 | import torch.utils.checkpoint
6 | try:
7 | from yunchang.ulysses import UlyssesAttention
8 | ulysses_attn = UlyssesAttention()
9 | except Exception:
10 | ulysses_attn = None
11 |
12 |
13 | def new_flash_attn_forward(
14 | self,
15 | query_states,
16 | key_states,
17 | value_states,
18 | attention_mask,
19 | query_length,
20 | dropout=0.0,
21 | softmax_scale=None,
22 | use_sliding_windows=False,
23 | ):
24 | if not self._flash_attn_uses_top_left_mask:
25 | causal = self.is_causal
26 | else:
27 | causal = self.is_causal and query_length != 1
28 |
29 | # Contains at least one padding token in the sequence
30 | assert attention_mask is None
31 | assert causal is True
32 | assert use_sliding_windows is False
33 | attn_output = ulysses_attn(
34 | query_states,
35 | key_states,
36 | value_states,
37 | dropout,
38 | softmax_scale,
39 | causal=causal,
40 | )
41 |
42 | return attn_output
43 |
44 |
45 | def new_decoder_forward(
46 | self,
47 | hidden_states: torch.Tensor,
48 | attention_mask: Optional[torch.Tensor] = None,
49 | position_ids: Optional[torch.LongTensor] = None,
50 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
51 | output_attentions: Optional[bool] = False,
52 | use_cache: Optional[bool] = False,
53 | cache_position: Optional[torch.LongTensor] = None,
54 | **kwargs,
55 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
56 | assert isinstance(
57 | self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
58 | ) or isinstance(
59 | self.self_attn,
60 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
61 |     ), "Please toggle on the Flash Attention 2 implementation when using the Ulysses attention monkey patch."
62 |
63 | if "padding_mask" in kwargs:
64 | warnings.warn(
65 |             "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
66 | )
67 |
68 | residual = hidden_states
69 |
70 | hidden_states = self.input_layernorm(hidden_states)
71 |
72 | # Self Attention
73 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
74 | hidden_states=hidden_states,
75 | attention_mask=attention_mask,
76 | position_ids=position_ids,
77 | past_key_value=past_key_value,
78 | output_attentions=output_attentions,
79 | use_cache=use_cache,
80 | cache_position=cache_position,
81 | **kwargs,
82 | )
83 | hidden_states = residual + hidden_states
84 |
85 | # Fully Connected
86 | residual = hidden_states
87 | hidden_states = self.post_attention_layernorm(hidden_states)
88 | hidden_states = self.mlp(hidden_states)
89 | hidden_states = residual + hidden_states
90 |
91 | outputs = (hidden_states,)
92 |
93 | if output_attentions:
94 | outputs += (self_attn_weights,)
95 |
96 | if use_cache:
97 | outputs += (present_key_value,)
98 |
99 | return outputs
100 |
101 |
102 | def apply_ulysses_attn_monkey_patch_llama():
103 | transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
104 | new_flash_attn_forward
105 | )
106 | transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
107 | new_decoder_forward
108 | )
109 |
110 |
111 |
--------------------------------------------------------------------------------
/easy_context/ulysses_attn/prepare_inputs.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def extract_local(value, rank, world_size, device, dim=1):
5 | dimension_size = value.shape[dim]
6 | sub_seq_length = dimension_size // world_size
7 |
8 | sub_seq_start = rank * sub_seq_length
9 | sub_seq_end = (rank + 1) * sub_seq_length
10 | local_value = value[:, sub_seq_start:sub_seq_end]
11 |
12 | return local_value.to(device)
13 |
14 |
15 | def prepare_ulysses_attn_inputs(
16 | input_ids, position_ids, target_ids, rank, world_size, device
17 | ):
18 |
19 | local_input_ids = extract_local(
20 | input_ids,
21 | rank,
22 | world_size,
23 | device,
24 | )
25 | local_position_ids = extract_local(
26 | position_ids,
27 | rank,
28 | world_size,
29 | device,
30 | )
31 |
32 | if target_ids is not None:
33 | local_target_ids = extract_local(
34 | target_ids,
35 | rank,
36 | world_size,
37 | device,
38 | )
39 | else:
40 | local_target_ids = None
41 | return {
42 | "local_input_ids": local_input_ids,
43 | "local_position_ids": local_position_ids,
44 | "local_target_ids": local_target_ids,
45 | }
46 |
--------------------------------------------------------------------------------
/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import transformers
17 | import inspect
18 |
19 |
20 | class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
21 | """
22 | Saves VRAM by smartly offloading to RAM.
23 |     Tiny hit to performance, since we hide the data movement behind non-blocking calls.
24 | """
25 |
26 | @staticmethod
27 | @torch.cuda.amp.custom_fwd
28 | def forward(ctx, forward_function, hidden_states, *args):
29 | saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
30 | with torch.no_grad():
31 | output = forward_function(hidden_states, *args)
32 | ctx.save_for_backward(saved_hidden_states)
33 | ctx.forward_function = forward_function
34 | ctx.args = args
35 |
36 | return output
37 |
38 | pass
39 |
40 | @staticmethod
41 | @torch.cuda.amp.custom_bwd
42 | def backward(ctx, dY):
43 | (hidden_states,) = ctx.saved_tensors
44 | hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
45 | hidden_states.requires_grad = True
46 | with torch.enable_grad():
47 | (output,) = ctx.forward_function(hidden_states, *ctx.args)
48 | torch.autograd.backward(output, dY)
49 | return (
50 | None,
51 | hidden_states.grad,
52 | ) + (
53 | None,
54 | ) * len(ctx.args)
55 |
56 | pass
57 |
58 |
59 | pass
60 |
61 |
62 | def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
63 | assert gradient_checkpointing_kwargs == None
64 | if not self.supports_gradient_checkpointing:
65 | raise ValueError(
66 | f"{self.__class__.__name__} does not support gradient checkpointing."
67 | )
68 |
69 | gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply
70 | # For old GC format (transformers < 4.35.0) for models that live on the Hub
71 | # we will fall back to the overwritten `_set_gradient_checkpointing` method
72 | _is_using_old_format = (
73 | "value" in inspect.signature(self._set_gradient_checkpointing).parameters
74 | )
75 |
76 | if not _is_using_old_format:
77 | self._set_gradient_checkpointing(
78 | enable=True, gradient_checkpointing_func=gradient_checkpointing_func
79 | )
80 | else:
81 | raise NotImplementedError()
82 |
83 | if getattr(self, "_hf_peft_config_loaded", False):
84 | # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
85 | # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
86 | # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
87 | # the gradients to make sure the gradient flows.
88 | self.enable_input_require_grads()
89 |
90 |
91 | def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch():
92 | transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = (
93 | new_gradient_checkpointing_enable
94 | )
95 |
--------------------------------------------------------------------------------
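Note: a minimal usage sketch of the patch above. The checkpoint name is a placeholder; any Hugging Face model that supports gradient checkpointing with the newer `_set_gradient_checkpointing` API (transformers >= 4.35) should work.

```python
from transformers import AutoModelForCausalLM

from easy_context.unsloth_offloaded_gradient_checkpoint.monkey_patch import (
    apply_unsloth_offloaded_gradient_checkpoint_monkey_patch,
)

# Patch gradient_checkpointing_enable() before (or after) loading the model;
# the patch replaces the method on PreTrainedModel itself.
apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")  # placeholder
# From here on, checkpointed activations are offloaded to CPU RAM with
# non-blocking copies instead of being kept in VRAM.
model.gradient_checkpointing_enable()
```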
/easy_context/zigzag_ring_attn/monkey_patch.py:
--------------------------------------------------------------------------------
1 | import transformers
2 | from typing import List, Optional, Tuple, Union
3 | import warnings
4 | import torch
5 | import torch.utils.checkpoint
6 | from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func
7 |
8 |
9 | def new_flash_attn_forward(
10 | self,
11 | query_states,
12 | key_states,
13 | value_states,
14 | attention_mask,
15 | query_length,
16 | dropout=0.0,
17 | softmax_scale=None,
18 | use_sliding_windows=False,
19 | ):
20 | if not self._flash_attn_uses_top_left_mask:
21 | causal = self.is_causal
22 | else:
23 | causal = self.is_causal and query_length != 1
24 |
25 | # Contains at least one padding token in the sequence
26 | assert attention_mask is None
27 | assert causal is True
28 | assert use_sliding_windows is False
29 | attn_output = zigzag_ring_flash_attn_func(
30 | query_states,
31 | key_states,
32 | value_states,
33 | dropout,
34 | softmax_scale,
35 | causal=causal,
36 | )
37 |
38 | return attn_output
39 |
40 |
41 | def new_decoder_forward(
42 | self,
43 | hidden_states: torch.Tensor,
44 | attention_mask: Optional[torch.Tensor] = None,
45 | position_ids: Optional[torch.LongTensor] = None,
46 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
47 | output_attentions: Optional[bool] = False,
48 | use_cache: Optional[bool] = False,
49 | cache_position: Optional[torch.LongTensor] = None,
50 | **kwargs,
51 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
52 | assert isinstance(
53 | self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
54 | ) or isinstance(
55 | self.self_attn,
56 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
57 | ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch."
58 |
59 | if "padding_mask" in kwargs:
60 | warnings.warn(
61 |             "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
62 | )
63 |
64 | residual = hidden_states
65 |
66 | hidden_states = self.input_layernorm(hidden_states)
67 |
68 | # Self Attention
69 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
70 | hidden_states=hidden_states,
71 | attention_mask=attention_mask,
72 | position_ids=position_ids,
73 | past_key_value=past_key_value,
74 | output_attentions=output_attentions,
75 | use_cache=use_cache,
76 | cache_position=cache_position,
77 | **kwargs,
78 | )
79 | hidden_states = residual + hidden_states
80 |
81 | # Fully Connected
82 | residual = hidden_states
83 | hidden_states = self.post_attention_layernorm(hidden_states)
84 | hidden_states = self.mlp(hidden_states)
85 | hidden_states = residual + hidden_states
86 |
87 | outputs = (hidden_states,)
88 |
89 | if output_attentions:
90 | outputs += (self_attn_weights,)
91 |
92 | if use_cache:
93 | outputs += (present_key_value,)
94 |
95 | return outputs
96 |
97 |
98 | def apply_zigzag_ring_attn_monkey_patch_llama():
99 | transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
100 | new_flash_attn_forward
101 | )
102 | transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
103 | new_decoder_forward
104 | )
105 |
106 |
107 | def apply_zigzag_ring_attn_monkey_patch_mistral():
108 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = (
109 | new_flash_attn_forward
110 | )
111 | transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = (
112 | new_decoder_forward
113 | )
114 |
--------------------------------------------------------------------------------
/easy_context/zigzag_ring_attn/prepare_inputs.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def extract_local(value, rank, world_size, device, dim=1):
5 | value_chunks = value.chunk(2 * world_size, dim=dim)
6 | local_value = torch.cat(
7 | [value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim
8 | )
9 | return local_value.to(device)
10 |
11 |
12 | def prepare_zigzag_ring_attn_inputs(
13 | input_ids, position_ids, target_ids, rank, world_size, device
14 | ):
15 | local_input_ids = extract_local(
16 | input_ids,
17 | rank,
18 | world_size,
19 | device,
20 | )
21 | local_position_ids = extract_local(
22 | position_ids,
23 | rank,
24 | world_size,
25 | device,
26 | )
27 | if target_ids is not None:
28 | local_target_ids = extract_local(
29 | target_ids,
30 | rank,
31 | world_size,
32 | device,
33 | )
34 | else:
35 | local_target_ids = None
36 | return {
37 | "local_input_ids": local_input_ids,
38 | "local_position_ids": local_position_ids,
39 | "local_target_ids": local_target_ids,
40 | }
41 |
--------------------------------------------------------------------------------
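Note: for contrast with the contiguous split used by `dist_flash_attn` and `ulysses_attn`, the zigzag layout above pairs an early chunk with its mirrored late chunk so each rank gets a comparable share of the causal-attention work. A small, self-contained illustration (hypothetical tensor, world_size of 2):

```python
import torch

def zigzag_extract_local(value, rank, world_size, dim=1):
    # Same chunking logic as extract_local above, minus the device move.
    chunks = value.chunk(2 * world_size, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * world_size - rank - 1]], dim=dim)

seq = torch.arange(8).unsqueeze(0)  # tokens 0..7, shape [1, 8]
print(zigzag_extract_local(seq, rank=0, world_size=2))  # tensor([[0, 1, 6, 7]])
print(zigzag_extract_local(seq, rank=1, world_size=2))  # tensor([[2, 3, 4, 5]])
```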
/local_demo/assets/assistant_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/assistant_logo.png
--------------------------------------------------------------------------------
/local_demo/assets/dc_demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/dc_demo.mp4
--------------------------------------------------------------------------------
/local_demo/assets/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/favicon.ico
--------------------------------------------------------------------------------
/local_demo/assets/jobs.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/jobs.mp4
--------------------------------------------------------------------------------
/local_demo/assets/llava_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/llava_logo.png
--------------------------------------------------------------------------------
/local_demo/assets/llava_next.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/llava_next.jpg
--------------------------------------------------------------------------------
/local_demo/assets/lmms-eval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/lmms-eval.png
--------------------------------------------------------------------------------
/local_demo/assets/otter_books.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/otter_books.jpg
--------------------------------------------------------------------------------
/local_demo/assets/user_example_01.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_01.jpg
--------------------------------------------------------------------------------
/local_demo/assets/user_example_02.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_02.jpg
--------------------------------------------------------------------------------
/local_demo/assets/user_example_03.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_03.jpg
--------------------------------------------------------------------------------
/local_demo/assets/user_example_04.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_04.jpg
--------------------------------------------------------------------------------
/local_demo/assets/user_example_05.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_05.jpg
--------------------------------------------------------------------------------
/local_demo/assets/user_example_06.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_06.jpg
--------------------------------------------------------------------------------
/local_demo/assets/user_example_07.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_07.jpg
--------------------------------------------------------------------------------
/local_demo/assets/user_example_08.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_08.jpg
--------------------------------------------------------------------------------
/local_demo/assets/user_example_09.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_09.jpg
--------------------------------------------------------------------------------
/local_demo/assets/user_example_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_10.png
--------------------------------------------------------------------------------
/local_demo/assets/user_example_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_example_11.png
--------------------------------------------------------------------------------
/local_demo/assets/user_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/user_logo.png
--------------------------------------------------------------------------------
/local_demo/assets/water.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/water.mp4
--------------------------------------------------------------------------------
/local_demo/assets/white_cat_smile.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/assets/white_cat_smile.jpg
--------------------------------------------------------------------------------
/local_demo/cache/0266b3f7b9311a93f0f38bc8cf0698b019c829b1/user_example_09.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/0266b3f7b9311a93f0f38bc8cf0698b019c829b1/user_example_09.jpg
--------------------------------------------------------------------------------
/local_demo/cache/0333f1f38db26d76752fee6cd2d938f06617dd2f/dc_demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/0333f1f38db26d76752fee6cd2d938f06617dd2f/dc_demo.mp4
--------------------------------------------------------------------------------
/local_demo/cache/0c299f0f979725494d926c3e76cfeb75c19ef564/assistant_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/0c299f0f979725494d926c3e76cfeb75c19ef564/assistant_logo.png
--------------------------------------------------------------------------------
/local_demo/cache/3c94cf4ccdbb877666285b42dd9263073d1fa19c/user_example_07.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/3c94cf4ccdbb877666285b42dd9263073d1fa19c/user_example_07.jpg
--------------------------------------------------------------------------------
/local_demo/cache/447d9ea0a12a7be4cf7f279dc4e667b1b938a0c2/otter_books.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/447d9ea0a12a7be4cf7f279dc4e667b1b938a0c2/otter_books.jpg
--------------------------------------------------------------------------------
/local_demo/cache/7abebd5007998873c3d36612f984d241c3b574d8/user_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/7abebd5007998873c3d36612f984d241c3b574d8/user_logo.png
--------------------------------------------------------------------------------
/local_demo/cache/7dcdac5ff3df641ae7eb47d8b7aaeab740e02de9/user_example_06.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/7dcdac5ff3df641ae7eb47d8b7aaeab740e02de9/user_example_06.jpg
--------------------------------------------------------------------------------
/local_demo/cache/b7b8f04e74cda37d9a574c1a2098d7e4ee97b212/user_example_05.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/b7b8f04e74cda37d9a574c1a2098d7e4ee97b212/user_example_05.jpg
--------------------------------------------------------------------------------
/local_demo/cache/dbb5054b8579d1943ff74045f254e0e759efec99/white_cat_smile.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/local_demo/cache/dbb5054b8579d1943ff74045f254e0e759efec99/white_cat_smile.jpg
--------------------------------------------------------------------------------
/local_demo/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | import hashlib
3 |
4 | #############
5 | # LongVA Demo Utils
6 | #############
7 |
8 | title_markdown = """
9 | # [LongVA Multimodal Chat](https://lmm-lab.github.io/LongVA/)
10 | """
11 |
12 | subtitle_markdown = """
13 | ### This is our research preview of LongVA, a multimodal model that is capable of accurately retrieving visual information from 2000 frames or more than 200K visual tokens.
14 | """
15 |
16 | html_header = """
17 |
78 |
79 |
89 | """
90 |
91 | block_css = """
92 | #buttons button {
93 | min-width: min(120px,100%);
94 | }
95 | """
96 |
97 | tos_markdown = """
98 | ## Terms of use
99 | By using this service, users are required to agree to the following terms:
100 | The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
101 | Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
102 | For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
103 | """
104 |
105 |
106 | learn_more_markdown = """
107 | ## License
108 | The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
109 | """
110 |
111 | bibtext = """
112 | ## Citation
113 | ```
114 | @misc{zhang2024longva,
115 | title={LongVA: Long Context Transfer from Language to Vision},
116 | url={https://lmms-lab.github.io/posts/longva/},
117 | author={Zhang, Peiyuan and Zhang, Kaichen and Li, Bo and Zeng, Guangtao and Yang, Jingkang and Zhang, Yuanhan and Li, Chunyuan and Liu, Ziwei},
118 | month={June},
119 | year={2024}
120 | }
121 | ```
122 | """
123 |
124 | PARENT_FOLDER = os.path.dirname(os.path.abspath(__file__))
125 | ################## BACKEND ##################
126 | os.environ["GRADIO_EXAMPLES_CACHE"] = (
127 | f"{PARENT_FOLDER}/cache"
128 | )
129 | os.environ["GRADIO_TEMP_DIR"] = (
130 | f"{PARENT_FOLDER}/cache"
131 | )
132 |
133 | def generate_file_hash(file_path):
134 | sha256_hash = hashlib.sha256()
135 | with open(file_path, "rb") as f:
136 | for byte_block in iter(lambda: f.read(4096), b""):
137 | sha256_hash.update(byte_block)
138 | return sha256_hash.hexdigest()[:6]
--------------------------------------------------------------------------------
/local_demo/debug_json_api.py:
--------------------------------------------------------------------------------
1 | from gradio_client import Client, handle_file
2 |
3 | client = Client("http://127.0.0.1:8000/")
4 |
5 | # state = [
6 | # (('/mnt/sfs-common/peiyuan/peiyuan/LongVa/local_demo/assets/otter_books.jpg',), None),
7 | # ("What does this image show?", None)
8 | # ]
9 | state = [ ("What's the video about?", None)]
10 |
11 |
12 | # result = client.predict(
13 | # history=[],
14 | # message={'text': "What's the video about?", 'files': []},
15 | # video_input="/mnt/sfs-common/peiyuan/peiyuan/LongVa/local_demo/cache/0333f1f38db26d76752fee6cd2d938f06617dd2f/dc_demo.mp4",
16 | # api_name="/add_message"
17 | # )
18 |
19 | # result = client.predict(
20 | # video_input=None,
21 | # state=state,
22 | # sample_frames=16,
23 | # temperature=0.7,
24 | # max_new_tokens=1024,
25 | # top_p=1,
26 | # api_name="/bot_response_1"
27 | # )
28 | # print(result)
29 |
30 | image_path = "local_demo/assets/assistant_logo.png"
31 | # image to base64
32 | import base64
33 | with open(image_path, "rb") as image_file:
34 | base64_image = base64.b64encode(image_file.read()).decode()
35 |
36 | input_json_url = [
37 | {
38 | "role": "user",
39 | "content": [
40 | {"type": "text", "text": "What’s in this image?"},
41 | {
42 | "type": "image_url",
43 | "image_url": {
44 | "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
45 | },
46 | },
47 | ],
48 | },
49 | {
50 | "role": "assistant",
51 | "content": [
52 | {"type": "text", "text": "The image shows a serene, open field with lush green grass and a few scattered trees or shrubs in the distance. The sky above is mostly clear with some wispy clouds, suggesting it's a bright and likely pleasant day. There's a light-colored wooden boardwalk or path that meanders through the field, inviting a peaceful walk along its length. The overall scene conveys a sense of tranquility and natural beauty."},
53 | ]
54 | },
55 | {
56 | "role": "user",
57 | "content": [
58 | {"type": "text", "text": "Where is this place?"},
59 | ]
60 | }
61 | ]
62 |
63 | input_json_base64 = [
64 | {
65 | "role": "user",
66 | "content": [
67 | {"type": "text", "text": "What’s in this image?"},
68 | {
69 | "type": "image_url",
70 | "image_url": {
71 | "url": f"data:image/jpeg;base64,{base64_image}",
72 | },
73 | },
74 | ],
75 | }
76 | ]
77 | result = client.predict(
78 | input_json=input_json_base64,
79 | api_name="/base64_api"
80 | )
81 | print(result)
82 |
--------------------------------------------------------------------------------
/local_demo/theme_dropdown.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 |
4 | from gradio.themes.utils import ThemeAsset
5 |
6 |
7 | def create_theme_dropdown():
8 | import gradio as gr
9 |
10 | asset_path = pathlib.Path(__file__).parent / "themes"
11 | themes = []
12 | for theme_asset in os.listdir(str(asset_path)):
13 | themes.append(
14 | (ThemeAsset(theme_asset), gr.Theme.load(str(asset_path / theme_asset)))
15 | )
16 |
17 | def make_else_if(theme_asset):
18 | return f"""
19 | else if (theme == '{str(theme_asset[0].version)}') {{
20 | var theme_css = `{theme_asset[1]._get_theme_css()}`
21 | }}"""
22 |
23 | head, tail = themes[0], themes[1:]
24 | if_statement = f"""
25 | if (theme == "{str(head[0].version)}") {{
26 | var theme_css = `{head[1]._get_theme_css()}`
27 | }} {" ".join(make_else_if(t) for t in tail)}
28 | """
29 |
30 | latest_to_oldest = sorted([t[0] for t in themes], key=lambda asset: asset.version)[
31 | ::-1
32 | ]
33 | latest_to_oldest = [str(t.version) for t in latest_to_oldest]
34 |
35 | component = gr.Dropdown(
36 | choices=latest_to_oldest,
37 | value=latest_to_oldest[0],
38 | render=False,
39 | label="Select Version",
40 | )
41 |
42 | return (
43 | component,
44 | f"""
45 | (theme) => {{
46 | if (!document.querySelector('.theme-css')) {{
47 | var theme_elem = document.createElement('style');
48 | theme_elem.classList.add('theme-css');
49 | document.head.appendChild(theme_elem);
50 | }} else {{
51 | var theme_elem = document.querySelector('.theme-css');
52 | }}
53 | {if_statement}
54 | theme_elem.innerHTML = theme_css;
55 | }}
56 | """,
57 | )
--------------------------------------------------------------------------------
/longva/longva/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/longva/longva/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 |
--------------------------------------------------------------------------------
/longva/longva/model/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | AVAILABLE_MODELS = {
4 | "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig",
5 | "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig",
6 | "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig",
7 | # Add other models as needed
8 | }
9 |
10 | for model_name, model_classes in AVAILABLE_MODELS.items():
11 | try:
12 | exec(f"from .language_model.{model_name} import {model_classes}")
13 | except Exception as e:
14 |         print(f"Failed to import {model_name} from longva.model.language_model.{model_name}")
15 |         raise e
16 |
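For reference, the same dynamic import can be written without `exec` using `importlib`; this is a minimal sketch of that alternative under the same module layout, not the repository's own code:

import importlib

for model_name, model_classes in AVAILABLE_MODELS.items():
    # import the submodule relative to this package (e.g. longva.model.language_model.llava_llama)
    module = importlib.import_module(f".language_model.{model_name}", package=__package__)
    for cls_name in (name.strip() for name in model_classes.split(",")):
        globals()[cls_name] = getattr(module, cls_name)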
--------------------------------------------------------------------------------
/longva/longva/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m longva.model.apply_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/vicuna-7b --delta-path lmsys/vicuna-7b-delta
4 | """
5 |
6 | import argparse
7 |
8 | import torch
9 | from tqdm import tqdm
10 | from transformers import AutoTokenizer, AutoModelForCausalLM
11 | from longva import LlavaLlamaForCausalLM
12 |
13 |
14 | def apply_delta(base_model_path, target_model_path, delta_path):
15 | print("Loading base model")
16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading delta")
19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 |
22 | print("Applying delta")
23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 | if name not in base.state_dict():
25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data += base.state_dict()[name]
29 | else:
30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
31 | bparam = base.state_dict()[name]
32 | param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
33 |
34 | print("Saving target model")
35 | delta.save_pretrained(target_model_path)
36 | delta_tokenizer.save_pretrained(target_model_path)
37 |
38 |
39 | if __name__ == "__main__":
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument("--base-model-path", type=str, required=True)
42 | parser.add_argument("--target-model-path", type=str, required=True)
43 | parser.add_argument("--delta-path", type=str, required=True)
44 |
45 | args = parser.parse_args()
46 |
47 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
48 |
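A toy sketch of the two update rules above, on made-up tensors (all shapes are illustrative): same-shape parameters are summed elementwise, while an enlarged delta embedding only has its leading rows and columns shifted by the base weights.

import torch

# same-shape case: param.data += base.state_dict()[name]
delta_w = torch.full((3, 3), 2.0)
delta_w += torch.ones(3, 3)                        # every entry becomes 3.0

# mismatched case: the delta vocab is larger (e.g. extra multimodal tokens),
# so only the overlapping block is offset by the base embedding
base_embed = torch.ones(4, 8)                      # pretend base vocab = 4
delta_embed = torch.zeros(6, 8)                    # delta vocab = 6
delta_embed[: base_embed.shape[0], : base_embed.shape[1]] += base_embed
assert delta_embed[:4].eq(1).all() and delta_embed[4:].eq(0).all()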
--------------------------------------------------------------------------------
/longva/longva/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m longva.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 |
6 | import argparse
7 |
8 | import torch
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from longva.model import *
11 | from longva.model.utils import auto_upgrade
12 |
13 |
14 | def consolidate_ckpt(src_path, dst_path):
15 | print("Loading model")
16 | auto_upgrade(src_path)
17 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
18 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
19 | src_model.save_pretrained(dst_path)
20 | src_tokenizer.save_pretrained(dst_path)
21 |
22 |
23 | if __name__ == "__main__":
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument("--src", type=str, required=True)
26 | parser.add_argument("--dst", type=str, required=True)
27 |
28 | args = parser.parse_args()
29 |
30 | consolidate_ckpt(args.src, args.dst)
31 |
--------------------------------------------------------------------------------
/longva/longva/model/language_model/llava_mpt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import Optional, Tuple
17 |
18 | import torch
19 |
20 | from transformers import AutoConfig, AutoModelForCausalLM, MptConfig, MptForCausalLM, MptModel, GenerationConfig
21 | from longva.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
22 |
23 |
24 | class LlavaMptConfig(MptConfig):
25 | model_type = "llava_mpt"
26 |
27 |
28 | class LlavaMptModel(LlavaMetaModel, MptModel):
29 | config_class = LlavaMptConfig
30 |
31 | def __init__(self, config: MptConfig):
32 | config.hidden_size = config.d_model
33 | super(LlavaMptModel, self).__init__(config)
34 |
35 | def embed_tokens(self, x):
36 | return self.wte(x)
37 |
38 |
39 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
40 | config_class = LlavaMptConfig
41 | supports_gradient_checkpointing = True
42 |
43 | def __init__(self, config):
44 | super(MptForCausalLM, self).__init__(config)
45 |
46 | config.model_type = "llava_mpt"
47 | config.rope_scaling = None
48 | self.generation_config = GenerationConfig(
49 | temperature=0.0,
50 | max_new_tokens=1024,
51 | do_sample=False,
52 | top_p=None,
53 | )
54 |
55 | self.transformer = LlavaMptModel(config)
56 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
57 |
58 | # Initialize weights and apply final processing
59 | self.post_init()
60 |
61 | def get_model(self):
62 | return self.transformer
63 |
64 | def _set_gradient_checkpointing(self, module, value=False):
65 | if isinstance(module, LlavaMptModel):
66 | module.gradient_checkpointing = value
67 |
68 | def forward(
69 | self,
70 | input_ids: Optional[torch.LongTensor] = None,
71 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
72 | attention_mask: Optional[torch.Tensor] = None,
73 | inputs_embeds: Optional[torch.Tensor] = None,
74 | labels: Optional[torch.Tensor] = None,
75 | use_cache: Optional[bool] = None,
76 | output_attentions: Optional[bool] = None,
77 | output_hidden_states: Optional[bool] = None,
78 | return_dict: Optional[bool] = None,
79 | cache_position=None,
80 | images=None,
81 | ):
82 |
83 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
84 |
85 | return super().forward(
86 | input_ids,
87 | past_key_values=past_key_values,
88 | attention_mask=attention_mask,
89 | inputs_embeds=inputs_embeds,
90 | labels=labels,
91 | use_cache=use_cache,
92 | output_attentions=output_attentions,
93 | output_hidden_states=output_hidden_states,
94 | return_dict=return_dict,
95 | )
96 |
97 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
98 | images = kwargs.pop("images", None)
99 | _inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
100 | _inputs["images"] = images
101 | return _inputs
102 |
103 |
104 | AutoConfig.register("llava_mpt", LlavaMptConfig)
105 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
106 |
--------------------------------------------------------------------------------
/longva/longva/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m longva.model.make_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/llava-7b --delta-path ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 |
6 | import argparse
7 |
8 | import torch
9 | from tqdm import tqdm
10 | from transformers import AutoTokenizer, AutoModelForCausalLM
11 | from longva.model.utils import auto_upgrade
12 |
13 |
14 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
15 | print("Loading base model")
16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading target model")
19 | auto_upgrade(target_model_path)
20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 |
22 | print("Calculating delta")
23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 | if name not in base.state_dict():
25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data -= base.state_dict()[name]
29 | else:
30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
31 | bparam = base.state_dict()[name]
32 | param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
33 |
34 | print("Saving delta")
35 | if hub_repo_id:
36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 | else:
38 | kwargs = {}
39 | target.save_pretrained(delta_path, **kwargs)
40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 | target_tokenizer.save_pretrained(delta_path, **kwargs)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--base-model-path", type=str, required=True)
47 | parser.add_argument("--target-model-path", type=str, required=True)
48 | parser.add_argument("--delta-path", type=str, required=True)
49 | parser.add_argument("--hub-repo-id", type=str, default=None)
50 | args = parser.parse_args()
51 |
52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 |
--------------------------------------------------------------------------------
/longva/longva/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .clip_encoder import CLIPVisionTower
3 | from .clip_encoder import CLIPVisionTowerS2
4 |
5 | # from .eva_clip.eva_clip_encoder import EvaClipVisionTower
6 | # from .dev_eva_clip.eva_vit import EvaViTWrapper
7 |
8 |
9 | def build_vision_tower(vision_tower_cfg, **kwargs):
10 | vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
11 | is_absolute_path_exists = os.path.exists(vision_tower)
12 | use_s2 = getattr(vision_tower_cfg, "s2", False)
13 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
14 | if use_s2:
15 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
16 | else:
17 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
18 | # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower():
19 | # return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
20 | # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]:
21 | # return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs)
22 |
23 | raise ValueError(f"Unknown vision tower: {vision_tower}")
24 |
--------------------------------------------------------------------------------
/longva/longva/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 | from .pooler_projector import PoolerProjector
6 |
7 |
8 | class IdentityMap(nn.Module):
9 | def __init__(self):
10 | super().__init__()
11 |
12 | def forward(self, x, *args, **kwargs):
13 | return x
14 |
15 | @property
16 | def config(self):
17 | return {"mm_projector_type": "identity"}
18 |
19 |
20 | class SimpleResBlock(nn.Module):
21 | def __init__(self, channels):
22 | super().__init__()
23 | self.pre_norm = nn.LayerNorm(channels)
24 |
25 | self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
26 |
27 | def forward(self, x):
28 | x = self.pre_norm(x)
29 | return x + self.proj(x)
30 |
31 |
32 | def build_vision_projector(config, delay_load=False, **kwargs):
33 | projector_type = getattr(config, "mm_projector_type", "linear")
34 |
35 | if projector_type == "linear":
36 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
37 |
38 | if projector_type == "pooler":
39 | return PoolerProjector(config, kwargs["vision_cfg"])
40 |
41 | mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
42 | if mlp_gelu_match:
43 | mlp_depth = int(mlp_gelu_match.group(1))
44 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
45 | for _ in range(1, mlp_depth):
46 | modules.append(nn.GELU())
47 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
48 | return nn.Sequential(*modules)
49 |
50 | mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type)
51 | if mlp_gelu_resnet_match:
52 | mlp_depth = int(mlp_gelu_resnet_match.group(1))
53 | res_depth = int(mlp_gelu_resnet_match.group(2))
54 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
55 | for _ in range(1, mlp_depth):
56 | modules.append(nn.GELU())
57 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
58 | for _ in range(res_depth):
59 | modules.append(SimpleResBlock(config.hidden_size))
60 | return nn.Sequential(*modules)
61 |
62 | if projector_type == "identity":
63 | return IdentityMap()
64 |
65 | raise ValueError(f"Unknown projector type: {projector_type}")
66 |
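A minimal sketch of what the `mlp2x_gelu` branch above builds, assuming `build_vision_projector` is in scope (e.g. imported from longva.model.multimodal_projector.builder); the hidden sizes are assumptions, roughly 1024-d CLIP-L features projected into a hypothetical 4096-d language model:

from types import SimpleNamespace
import torch

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=4096)
projector = build_vision_projector(cfg)      # Linear(1024->4096) -> GELU -> Linear(4096->4096)
features = torch.randn(2, 576, 1024)         # (batch, vision tokens, vision dim)
print(projector(features).shape)             # torch.Size([2, 576, 4096])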
--------------------------------------------------------------------------------
/longva/longva/model/multimodal_projector/pooler_projector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import math
5 |
6 | from transformers.models.clip.modeling_clip import CLIPVisionModel
7 |
8 |
9 | class PoolerProjector(nn.Module):
10 | def __init__(self, config, vision_cfg):
11 | super().__init__()
12 | self._config = config
13 | self.hw = vision_cfg.image_size // vision_cfg.patch_size
14 |
15 | self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2)
16 |
17 | self.proj = nn.Sequential(
18 | nn.GELU(),
19 | nn.Linear(config.hidden_size, config.hidden_size),
20 | )
21 |
22 | def forward(self, x, *args, **kwargs):
23 | height = width = self.hw
24 | assert height * width == x.shape[1]
25 | x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
26 | x = self.conv_pool(x)
27 | x = x.flatten(2).transpose(1, 2)
28 | x = self.proj(x)
29 | return x
30 |
31 | @property
32 | def config(self):
33 | return {"mm_projector_type": "pooler"}
34 |
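A shape walkthrough for the pooler above; the numbers assume a CLIP ViT-L/14 tower at 336 px (24 x 24 = 576 patch tokens) and a hypothetical 4096-d language model, with `PoolerProjector` in scope:

from types import SimpleNamespace
import torch

cfg = SimpleNamespace(mm_hidden_size=1024, hidden_size=4096)
vision_cfg = SimpleNamespace(image_size=336, patch_size=14)   # hw = 336 // 14 = 24

proj = PoolerProjector(cfg, vision_cfg)
x = torch.randn(2, 576, 1024)          # (batch, 24*24 patch tokens, vision dim)
print(proj(x).shape)                   # torch.Size([2, 144, 4096]): the 2x2 strided conv quarters the token count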
--------------------------------------------------------------------------------
/longva/longva/model/multimodal_resampler/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .masked_drop import MaskedDrop
4 | from .spatial_pool import SpatialPool
5 | from .perceiver import PerceiverResampler
6 | from .qformer import Qformer
7 |
8 |
9 | class IdentityMap(torch.nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 | def forward(self, x, *args, **kwargs):
14 | return x
15 |
16 | @property
17 | def config(self):
18 | return {"mm_resampler_type": None}
19 |
20 |
21 | def build_vision_resampler(model_args, delay_load=False, **kwargs):
22 | resampler_type = getattr(model_args, "mm_resampler_type", None)
23 | if resampler_type == "masked_drop":
24 | return MaskedDrop(model_args)
25 | elif resampler_type == "spatial_pool":
26 | return SpatialPool(model_args, **kwargs)
27 | elif resampler_type == "perceiver":
28 | return PerceiverResampler(model_args, **kwargs)
29 | elif resampler_type == "qformer":
30 | return Qformer(model_args, **kwargs)
31 | elif resampler_type is None:
32 | return IdentityMap()
33 |
34 | raise ValueError(f"Unknown resampler type: {resampler_type}")
35 |
--------------------------------------------------------------------------------
/longva/longva/model/multimodal_resampler/masked_drop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import random
5 |
6 |
7 | class MaskedDrop(nn.Module):
8 | def __init__(self, model_args):
9 | super().__init__()
10 |
11 | self.mode = model_args.mm_mask_drop_mode
12 | self.skip_percentage = model_args.mm_mask_drop_skip_percentage
13 | self.ratio = model_args.mm_mask_drop_ratio
14 | self.ratio_upper = model_args.mm_mask_drop_ratio_upper
15 | self.ratio_lower = model_args.mm_mask_drop_ratio_lower
16 |
17 | def forward(self, image_features, *args, **kwargs):
18 |
19 | if not self.training:
20 | return image_features
21 |
22 | if self.skip_percentage > random.random():
23 | return image_features
24 |
25 | masked_features = []
26 |
27 | for image_feature in image_features:
28 | num_tokens = image_feature.shape[0]
29 | if self.mode == "fixed":
30 | num_keep = int(num_tokens * self.ratio)
31 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0])
32 | elif self.mode == "range":
33 | num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper))
34 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0])
35 | elif self.mode == "cls_only":
36 | masked_features.append(image_feature[0:1])
37 | else:
38 | raise ValueError(f"Unexpected masked drop mode: {self.mode}")
39 |
40 | if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]):
41 | masked_features = torch.stack(masked_features, dim=0)
42 |
43 | return masked_features
44 |
45 | @property
46 | def config(self):
47 | return {
48 | "mm_resampler_type": "masked_drop",
49 | "mm_mask_drop_mode": self.mode,
50 | "mm_mask_drop_skip_percentage": self.skip_percentage,
51 | "mm_mask_drop_ratio": self.ratio,
52 | "mm_mask_drop_ratio_upper": self.ratio_upper,
53 | "mm_mask_drop_ratio_lower": self.ratio_lower,
54 | }
55 |
56 | def random_masking(self, x, len_keep):
57 | """
58 | Perform per-sample random masking by per-sample shuffling.
59 | Per-sample shuffling is done by argsort random noise.
60 | x: [N, L, D], sequence
61 | """
62 | N, L, D = x.shape # batch, length, dim
63 |
64 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
65 |
66 | # sort noise for each sample
67 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
68 | ids_restore = torch.argsort(ids_shuffle, dim=1)
69 |
70 | # keep the first subset
71 | ids_keep = ids_shuffle[:, :len_keep]
72 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
73 |
74 | # generate the binary mask: 0 is keep, 1 is remove
75 | mask = torch.ones([N, L], device=x.device)
76 | mask[:, :len_keep] = 0
77 | # unshuffle to get the binary mask
78 | mask = torch.gather(mask, dim=1, index=ids_restore)
79 |
80 | return x_masked, mask, ids_restore
81 |
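A small sketch of `random_masking` on dummy features (shapes and arguments are illustrative): it keeps a random subset of tokens per sample and also returns the binary drop mask and the restore indices.

from types import SimpleNamespace
import torch

args = SimpleNamespace(mm_mask_drop_mode="fixed", mm_mask_drop_skip_percentage=0.0,
                       mm_mask_drop_ratio=0.5, mm_mask_drop_ratio_upper=None,
                       mm_mask_drop_ratio_lower=None)
drop = MaskedDrop(args)
x = torch.randn(2, 10, 8)                              # (batch, tokens, dim)
x_masked, mask, ids_restore = drop.random_masking(x, len_keep=4)
print(x_masked.shape)                                  # torch.Size([2, 4, 8])
print(mask.sum(dim=1))                                 # tensor([6., 6.]): 6 of 10 tokens dropped per sample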
--------------------------------------------------------------------------------
/longva/longva/model/multimodal_resampler/perceiver.py:
--------------------------------------------------------------------------------
1 | """
2 | Taken from https://github.com/lucidrains/flamingo-pytorch
3 | """
4 |
5 | import torch
6 | from einops import rearrange, repeat
7 |
8 | try:
9 | from einops_exts import rearrange_many
10 | except ImportError:
11 | pass
12 |
13 | from torch import einsum, nn
14 |
15 |
16 | def exists(val):
17 | return val is not None
18 |
19 |
20 | def FeedForward(dim, mult=4):
21 | inner_dim = int(dim * mult)
22 | return nn.Sequential(
23 | nn.LayerNorm(dim),
24 | nn.Linear(dim, inner_dim, bias=False),
25 | nn.GELU(),
26 | nn.Linear(inner_dim, dim, bias=False),
27 | )
28 |
29 |
30 | class PerceiverAttention(nn.Module):
31 | def __init__(self, *, dim, dim_head=64, heads=8):
32 | super().__init__()
33 | self.scale = dim_head**-0.5
34 | self.heads = heads
35 | inner_dim = dim_head * heads
36 |
37 | self.norm_media = nn.LayerNorm(dim)
38 | self.norm_latents = nn.LayerNorm(dim)
39 |
40 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
41 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
42 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
43 |
44 | def forward(self, x, latents):
45 | """
46 | Args:
47 | x (torch.Tensor): image features
48 | shape (b, T, n1, D)
49 |             latents (torch.Tensor): latent features
50 | shape (b, T, n2, D)
51 | """
52 | x = self.norm_media(x)
53 | latents = self.norm_latents(latents)
54 |
55 | h = self.heads
56 |
57 | q = self.to_q(latents)
58 | kv_input = torch.cat((x, latents), dim=-2)
59 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
60 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
61 | q = q * self.scale
62 |
63 | # attention
64 | sim = einsum("... i d, ... j d -> ... i j", q, k)
65 | sim = sim - sim.amax(dim=-1, keepdim=True).detach()
66 | attn = sim.softmax(dim=-1)
67 |
68 | out = einsum("... i j, ... j d -> ... i d", attn, v)
69 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
70 | return self.to_out(out)
71 |
72 |
73 | class PerceiverResamplerModule(nn.Module):
74 | def __init__(
75 | self,
76 | *,
77 | dim,
78 | depth=6,
79 | dim_head=64,
80 | heads=8,
81 | num_latents=64,
82 | max_num_media=None,
83 | max_num_frames=None,
84 | ff_mult=4,
85 | ):
86 | super().__init__()
87 | self.latents = nn.Parameter(torch.randn(num_latents, dim))
88 | self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
89 | self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None
90 |
91 | self.layers = nn.ModuleList([])
92 | for _ in range(depth):
93 | self.layers.append(
94 | nn.ModuleList(
95 | [
96 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
97 | FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(),
98 | ]
99 | )
100 | )
101 |
102 | self.norm = nn.LayerNorm(dim)
103 |
104 | def forward(self, x):
105 | """
106 | Args:
107 | x (torch.Tensor): image features
108 | shape (b, T, F, v, D)
109 | Returns:
110 | shape (b, T, n, D) where n is self.num_latents
111 | """
112 | b, T, F, v = x.shape[:4]
113 |
114 | # frame and media time embeddings
115 | if exists(self.frame_embs):
116 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
117 | x = x + frame_embs
118 | x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions
119 | if exists(self.media_time_embs):
120 | x = x + self.media_time_embs[:T]
121 |
122 | # blocks
123 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
124 | for attn, ff in self.layers:
125 | latents = attn(x, latents) + latents
126 | latents = ff(latents) + latents
127 | return self.norm(latents)
128 |
129 |
130 | class PerceiverResampler(nn.Module):
131 | def __init__(self, model_args, vision_tower):
132 | super().__init__()
133 |
134 | self.depth = model_args.mm_perceiver_depth
135 | self.num_latents = model_args.mm_perceiver_latents
136 | self.ff_mult = model_args.mm_perceiver_ff_mult
137 | self.pretrained = model_args.mm_perceiver_pretrained
138 |
139 | self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult)
140 |
141 | if self.pretrained is not None:
142 | self.load_state_dict(torch.load(self.pretrained))
143 |
144 | def forward(self, image_features, *args, **kwargs):
145 | return self.perceiver(image_features[:, None, None]).squeeze(1)
146 |
147 | @property
148 | def config(self):
149 | return {
150 | "mm_resampler_type": "perceiver",
151 | "mm_perceiver_depth": self.depth,
152 | "mm_perceiver_latents": self.num_latents,
153 | "mm_perceiver_ff_mult": self.ff_mult,
154 | "mm_perceiver_pretrained": self.pretrained,
155 | }
156 |
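A shape sketch for the resampler module above (dimensions are illustrative; requires `einops` and `einops_exts`): 576 patch tokens are compressed down to `num_latents` learned tokens.

import torch

resampler = PerceiverResamplerModule(dim=1024, depth=2, num_latents=64)   # small illustrative depth
x = torch.randn(1, 1, 1, 576, 1024)    # (b, T, F, v, D): one image, 576 patch tokens
print(resampler(x).shape)              # torch.Size([1, 1, 64, 1024]): 576 tokens -> 64 latents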
--------------------------------------------------------------------------------
/longva/longva/model/multimodal_resampler/spatial_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 |
6 | class SpatialPool(nn.Module):
7 | def __init__(self, model_args, vision_tower):
8 | super().__init__()
9 |
10 | self.mode = model_args.mm_spatial_pool_mode
11 | self.stride = model_args.mm_spatial_pool_stride
12 | self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size)
13 |
14 | if self.mode == "average":
15 | self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride)
16 | elif self.mode == "max":
17 | self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride)
18 | elif self.mode == "conv":
19 | self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride)
20 | else:
21 |             raise ValueError(f"Unknown pooling mode: {self.mode}.")
22 |
23 | def forward(self, image_features, images, *args, **kwargs):
24 | ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2]))
25 | ori_H = int(ori_W * images.shape[2] // images.shape[3])
26 |
27 | B, _, F = image_features.shape
28 |
29 |         image_features_spatial = image_features.view(B, ori_H, ori_W, F).permute(0, 3, 1, 2)
30 | image_features_spatial_pool = self.pool(image_features_spatial)
31 |
32 | return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
33 |
34 | @property
35 | def config(self):
36 | return {
37 | "mm_resampler_type": "spatial_pool",
38 | "mm_spatial_pool_stride": self.stride,
39 | "mm_spatial_pool_mode": self.mode,
40 | "mm_spatial_pool_out_channels": self.out_channels,
41 | }
42 |
43 | @property
44 | def hidden_size(self):
45 | return self.out_channels
46 |
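A shape sketch for the spatial pool above (sizes are assumptions: square 336 px inputs and 576 CLIP tokens, with `SpatialPool` in scope):

from types import SimpleNamespace
import torch

tower = SimpleNamespace(hidden_size=1024)
args = SimpleNamespace(mm_spatial_pool_mode="average", mm_spatial_pool_stride=2)
pool = SpatialPool(args, tower)

image_features = torch.randn(2, 576, 1024)     # 24x24 grid of vision tokens
images = torch.randn(2, 3, 336, 336)           # square inputs, so ori_H == ori_W == 24
print(pool(image_features, images).shape)      # torch.Size([2, 144, 1024]): stride-2 pooling quarters the tokens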
--------------------------------------------------------------------------------
/longva/longva/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if "llava" in config and "llava" not in cfg.model_type:
7 | assert cfg.model_type == "llama"
8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 | if confirm.lower() in ["y", "yes"]:
12 | print("Upgrading checkpoint...")
13 | assert len(cfg.architectures) == 1
14 | setattr(cfg.__class__, "model_type", "llava")
15 | cfg.architectures[0] = "LlavaLlamaForCausalLM"
16 | cfg.save_pretrained(config)
17 | print("Checkpoint upgraded.")
18 | else:
19 | print("Checkpoint upgrade aborted.")
20 | exit(1)
21 |
--------------------------------------------------------------------------------
/longva/longva/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import warnings
3 |
4 | import torch
5 |
6 | import transformers
7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8 |
9 | try:
10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11 | except ImportError:
12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13 | from flash_attn.bert_padding import unpad_input, pad_input
14 |
15 |
16 | def forward(
17 | self,
18 | hidden_states: torch.Tensor,
19 | attention_mask: Optional[torch.Tensor] = None,
20 | position_ids: Optional[torch.Tensor] = None,
21 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
22 | output_attentions: bool = False,
23 | use_cache: bool = False,
24 | padding_mask: Optional[torch.Tensor] = None,
25 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
26 | if output_attentions:
27 | warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.")
28 |
29 | bsz, q_len, _ = hidden_states.size()
30 |
31 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
32 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
33 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # shape: (b, num_heads, s, head_dim)
34 |
35 | kv_seq_len = key_states.shape[-2]
36 | if past_key_value is not None:
37 | kv_seq_len += past_key_value[0].shape[-2]
38 |
39 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
40 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
41 |
42 | if past_key_value is not None:
43 | # reuse k, v
44 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
45 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
46 |
47 | past_key_value = (key_states, value_states) if use_cache else None
48 |
49 | # repeat k/v heads if n_kv_heads < n_heads
50 | key_states = repeat_kv(key_states, self.num_key_value_groups)
51 | value_states = repeat_kv(value_states, self.num_key_value_groups)
52 |
53 | # Transform the data into the format required by flash attention
54 | qkv = torch.stack([query_states, key_states, value_states], dim=2)
55 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
56 | key_padding_mask = attention_mask
57 |
58 | if key_padding_mask is None:
59 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
60 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)
61 | max_s = q_len
62 | output = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True)
63 | output = output.view(bsz, q_len, -1)
64 | else:
65 | qkv = qkv.reshape(bsz, q_len, -1)
66 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
67 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
68 | output_unpad = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True)
69 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
70 | output = pad_input(output_unpad, indices, bsz, q_len)
71 |
72 | return self.o_proj(output), None, past_key_value
73 |
74 |
75 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
76 | # requires the attention mask to be the same as the key_padding_mask
77 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
78 | # [bsz, seq_len]
79 | return attention_mask
80 |
81 |
82 | def replace_llama_attn_with_flash_attn():
83 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
84 | if cuda_major < 8:
85 |         warnings.warn("Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. " "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593")
86 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
87 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
88 |
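A minimal sketch of how this patch would typically be applied: call `replace_llama_attn_with_flash_attn()` before the LLaMA model is constructed so the patched `forward` and attention-mask hook are in place. The checkpoint name below is a placeholder, and the patch targets the older `_prepare_decoder_attention_mask` API, so the pinned `transformers` version matters.

import transformers
from longva.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()          # patch LlamaAttention.forward first
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",               # placeholder checkpoint
    torch_dtype="auto",
)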
--------------------------------------------------------------------------------
/longva/longva/train/train_mem.py:
--------------------------------------------------------------------------------
1 | from longva.train.train import train
2 |
3 | if __name__ == "__main__":
4 | train()
5 |
--------------------------------------------------------------------------------
/longva/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 240
3 |
4 | [build-system]
5 | requires = ["setuptools>=61.0"]
6 | build-backend = "setuptools.build_meta"
7 |
8 | [project]
9 | name = "longva"
10 | version = "1.7.0.dev0"
11 | description = "Towards GPT-4 like large language and visual assistant."
12 | readme = "README.md"
13 | requires-python = ">=3.8"
14 | classifiers = [
15 | "Programming Language :: Python :: 3",
16 | "License :: OSI Approved :: Apache Software License",
17 | ]
18 |
19 | [project.optional-dependencies]
20 | standalone = [
21 | "shortuuid",
22 | "httpx==0.24.0",
23 | "einops",
24 | "ftfy",
25 | ]
26 |
27 | train = [
28 | "longva[standalone]",
29 | "open_clip_torch",
30 | "fastapi",
31 | "gradio",
32 | "markdown2[all]",
33 | "numpy",
34 | "requests",
35 | "sentencepiece",
36 | "uvicorn",
37 | "wandb==0.16.5",
38 | "deepspeed==0.14.2",
39 | "torch==2.1.2",
40 | "torchvision==0.16.2",
41 | "peft==0.4.0",
42 | "accelerate>=0.29.1",
43 | "tokenizers~=0.15.2",
44 | "transformers@git+https://github.com/huggingface/transformers.git@1c39974a4c4036fd641bc1191cc32799f85715a4",
45 | "bitsandbytes==0.41.0",
46 | "scikit-learn==1.2.2",
47 | "sentencepiece~=0.1.99",
48 | "einops==0.6.1",
49 | "einops-exts==0.0.4",
50 | "gradio_client",
51 | "urllib3<=2.0.0",
52 | "datasets==2.16.1",
53 | "pydantic==1.10.8",
54 | "timm",
55 | "hf_transfer",
56 | "opencv-python",
57 | "av",
58 | "decord",
59 | "tyro",
60 | "scipy",
61 | ]
62 |
63 |
64 | [tool.setuptools.packages.find]
65 | include = ["longva*", "trl*"]
66 | exclude = [
67 | "assets*",
68 | "benchmark*",
69 | "docs",
70 | "dist*",
71 | "playground*",
72 | "scripts*",
73 | "tests*",
74 | "checkpoints*",
75 | "project_checkpoints*",
76 | "debug_checkpoints*",
77 | "mlx_configs*",
78 | "wandb*",
79 | "notebooks*",
80 | ]
81 |
82 | [tool.wheel]
83 | exclude = [
84 | "assets*",
85 | "benchmark*",
86 | "docs",
87 | "dist*",
88 | "playground*",
89 | "scripts*",
90 | "tests*",
91 | "checkpoints*",
92 | "project_checkpoints*",
93 | "debug_checkpoints*",
94 | "mlx_configs*",
95 | "wandb*",
96 | "notebooks*",
97 | ]
--------------------------------------------------------------------------------
/longva/scripts/dpo.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=8
2 | export NCCL_IB_DISABLE=0
3 | export NCCL_IB_GID_INDEX=3
4 | # export NCCL_IB_HCA=${ARNOLD_RDMA_DEVICE}
5 | export NCCL_SOCKET_IFNAME=eth0
6 | export NCCL_DEBUG=INFO
7 |
8 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
9 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
10 |
11 | ############### Pretrain ################
12 |
13 | # Stage 2
14 | PROMPT_VERSION="qwen_1_5"
15 |
16 | #torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" --nnodes="${ARNOLD_WORKER_NUM}" --node_rank="${ARNOLD_ID}" --master_addr="${METIS_WORKER_0_HOST}" --master_port="${port_in_cmd}" \
17 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" --nnodes="${ARNOLD_WORKER_NUM}" --node_rank="${ARNOLD_ID}" --master_addr="${METIS_WORKER_0_HOST}" --master_port="${port_in_cmd}" \
18 | longva/train/train_dpo.py \
19 | --deepspeed scripts/zero3.json \
20 | --model_name_or_path lmms-lab/LongVA-7B \
21 | --version $PROMPT_VERSION \
22 | --dpo_alpha 1.0 --beta 0.1 --gamma 0 \
23 | --data_path="/data/llava_video/shareVideoGPTV/dpo/sft_dpo_17k.jsonl" \
24 | --image_folder /data/llava_data \
25 | --video_folder /llava_video/shareVideoGPTV/frames/all_frames/ \
26 | --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \
27 | --vision_tower ${VISION_MODEL_VERSION} \
28 | --mm_projector_type mlp2x_gelu \
29 | --mm_vision_select_layer -2 \
30 | --mm_use_im_start_end False \
31 | --mm_use_im_patch_token False \
32 | --mm_spatial_pool_stride 2 \
33 | --mm_resampler_type "spatial_pool" \
34 | --mm_spatial_pool_out_channels 1024 \
35 | --group_by_modality_length True \
36 | --image_aspect_ratio anyres \
37 | --image_grid_pinpoints "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]" \
38 | --mm_patch_merge_type unires \
39 | --bf16 True \
40 | --run_name $MID_RUN_NAME \
41 | --output_dir "/checkpoints/${MID_RUN_NAME}" \
42 | --num_train_epochs 3 \
43 | --per_device_train_batch_size 1 \
44 | --per_device_eval_batch_size 4 \
45 | --gradient_accumulation_steps 16 \
46 | --evaluation_strategy "no" \
47 | --save_strategy "steps" \
48 | --save_steps 3000 \
49 | --save_total_limit 1 \
50 | --learning_rate 5e-7 \
51 | --weight_decay 0. \
52 | --warmup_ratio 0.1 \
53 | --lr_scheduler_type "linear" \
54 | --logging_steps 1 \
55 | --tf32 True \
56 | --model_max_length 32768 \
57 | --gradient_checkpointing True \
58 | --dataloader_num_workers 16 \
59 | --lazy_preprocess True \
60 | --report_to wandb \
61 | --torch_compile True \
62 | --torch_compile_backend "inductor" \
63 | --dataloader_drop_last True \
64 | --attn_implementation sdpa
--------------------------------------------------------------------------------
/longva/scripts/finetune.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=8
2 | export NCCL_IB_DISABLE=0
3 | export NCCL_IB_GID_INDEX=3
4 | export NCCL_SOCKET_IFNAME=eth0
5 | export NCCL_DEBUG=INFO
6 |
7 | LLM_VERSION="lmms-lab/Qwen2-7B-Instruct-224K"
8 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
9 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
10 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
11 |
12 | ############### Pretrain ################
13 |
14 | PROMPT_VERSION=qwen_1_5
15 |
16 | BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain"
17 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
18 |
19 | CKPT_PATH=$LLM_VERSION # this could also be the previous stage checkpoint
20 |
21 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \
22 | longva/train/train_mem.py \
23 | --deepspeed scripts/zero3.json \
24 | --model_name_or_path ${CKPT_PATH} \
25 | --version ${PROMPT_VERSION} \
26 | --data_path=llava_1_6.json \
27 | --image_folder your_image_folder \
28 | --pretrain_mm_mlp_adapter="/checkpoints/projectors/${BASE_RUN_NAME}/mm_projector.bin" \
29 | --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \
30 | --mm_vision_tower_lr=2e-6 \
31 | --vision_tower ${VISION_MODEL_VERSION} \
32 | --mm_projector_type mlp2x_gelu \
33 | --mm_vision_select_layer -2 \
34 | --mm_use_im_start_end False \
35 | --mm_use_im_patch_token False \
36 | --group_by_modality_length True \
37 | --image_aspect_ratio anyres \
38 | --image_grid_pinpoints "[(336, 672), (336, 1008), (336, 1344), (336, 1680), (336, 2016), (336, 2352), (336, 2688), (336, 3024), (336, 3360), (336, 3696), (336, 4032), (336, 4368), (336, 4704), (336, 5040), (336, 5376), (336, 5712), (336, 6048), (336, 6384), (336, 6720), (336, 7056), (336, 7392), (336, 7728), (336, 8064), (336, 8400), (336, 8736), (336, 9072), (336, 9408), (336, 9744), (336, 10080), (336, 10416), (336, 10752), (336, 11088), (336, 11424), (336, 11760), (336, 12096), (336, 12432), (336, 12768), (336, 13104), (336, 13440), (336, 13776), (336, 14112), (336, 14448), (336, 14784), (336, 15120), (336, 15456), (336, 15792), (336, 16128), (336, 16464), (672, 336), (672, 672), (672, 1008), (672, 1344), (672, 1680), (672, 2016), (672, 2352), (672, 2688), (672, 3024), (672, 3360), (672, 3696), (672, 4032), (672, 4368), (672, 4704), (672, 5040), (672, 5376), (672, 5712), (672, 6048), (672, 6384), (672, 6720), (672, 7056), (672, 7392), (672, 7728), (672, 8064), (1008, 336), (1008, 672), (1008, 1008), (1008, 1344), (1008, 1680), (1008, 2016), (1008, 2352), (1008, 2688), (1008, 3024), (1008, 3360), (1008, 3696), (1008, 4032), (1008, 4368), (1008, 4704), (1008, 5040), (1008, 5376), (1344, 336), (1344, 672), (1344, 1008), (1344, 1344), (1344, 1680), (1344, 2016), (1344, 2352), (1344, 2688), (1344, 3024), (1344, 3360), (1344, 3696), (1344, 4032), (1680, 336), (1680, 672), (1680, 1008), (1680, 1344), (1680, 1680), (1680, 2016), (1680, 2352), (1680, 2688), (1680, 3024), (2016, 336), (2016, 672), (2016, 1008), (2016, 1344), (2016, 1680), (2016, 2016), (2016, 2352), (2016, 2688), (2352, 336), (2352, 672), (2352, 1008), (2352, 1344), (2352, 1680), (2352, 2016), (2352, 2352), (2688, 336), (2688, 672), (2688, 1008), (2688, 1344), (2688, 1680), (2688, 2016), (3024, 336), (3024, 672), (3024, 1008), (3024, 1344), (3024, 1680), (3360, 336), (3360, 672), (3360, 1008), (3360, 1344), (3696, 336), (3696, 672), (3696, 1008), (3696, 1344), (4032, 336), (4032, 672), (4032, 1008), (4032, 1344), (4368, 336), (4368, 672), (4368, 1008), (4704, 336), (4704, 672), (4704, 1008), (5040, 336), (5040, 672), (5040, 1008), (5376, 336), (5376, 672), (5376, 1008), (5712, 336), (5712, 672), (6048, 336), (6048, 672), (6384, 336), (6384, 672), (6720, 336), (6720, 672), (7056, 336), (7056, 672), (7392, 336), (7392, 672), (7728, 336), (7728, 672), (8064, 336), (8064, 672), (8400, 336), (8736, 336), (9072, 336), (9408, 336), (9744, 336), (10080, 336), (10416, 336), (10752, 336), (11088, 336), (11424, 336), (11760, 336), (12096, 336), (12432, 336), (12768, 336), (13104, 336), (13440, 336), (13776, 336), (14112, 336), (14448, 336), (14784, 336), (15120, 336), (15456, 336), (15792, 336), (16128, 336), (16464, 336)]" \
39 | --mm_patch_merge_type unires \
40 | --bf16 True \
41 | --run_name $MID_RUN_NAME \
42 | --output_dir "/checkpoints/${MID_RUN_NAME}" \
43 | --num_train_epochs 1 \
44 | --per_device_train_batch_size 4 \
45 | --per_device_eval_batch_size 4 \
46 | --gradient_accumulation_steps 1 \
47 | --evaluation_strategy "no" \
48 | --save_strategy "steps" \
49 | --save_steps 3000 \
50 | --save_total_limit 1 \
51 | --learning_rate 1e-5 \
52 | --weight_decay 0. \
53 | --warmup_ratio 0.03 \
54 | --lr_scheduler_type "cosine" \
55 | --logging_steps 1 \
56 | --tf32 True \
57 | --model_max_length 32768 \
58 | --gradient_checkpointing True \
59 | --dataloader_num_workers 16 \
60 | --lazy_preprocess True \
61 | --report_to wandb \
62 | --torch_compile True \
63 | --torch_compile_backend "inductor" \
64 | --dataloader_drop_last True \
65 | --attn_implementation sdpa
66 |
67 | # You can delete the sdpa attn_implementation if you want to use flash attn
68 |
--------------------------------------------------------------------------------
/longva/scripts/pretrain.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=8
2 | export NCCL_IB_DISABLE=0
3 | export NCCL_IB_GID_INDEX=3
4 | export NCCL_SOCKET_IFNAME=eth0
5 | export NCCL_DEBUG=INFO
6 |
7 | LLM_VERSION="Qwen/Qwen2-7B-Instruct"
8 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
9 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
10 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
11 |
12 | ############### Pretrain ################
13 |
14 | PROMPT_VERSION=plain
15 |
16 | BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain"
17 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
18 |
19 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \
20 | longva/train/train_mem.py \
21 | --deepspeed scripts/zero3.json \
22 | --model_name_or_path ${LLM_VERSION} \
23 | --version ${PROMPT_VERSION} \
24 | --data_path /blip_558k/blip_558k_plain.json \
25 | --image_folder /blip_558k/images \
26 | --vision_tower ${VISION_MODEL_VERSION} \
27 | --mm_tunable_parts="mm_mlp_adapter" \
28 | --mm_vision_select_layer -2 \
29 | --mm_projector_type mlp2x_gelu \
30 | --mm_use_im_start_end False \
31 | --mm_use_im_patch_token False \
32 | --bf16 True \
33 | --output_dir /checkpoints/projectors/${BASE_RUN_NAME} \
34 | --num_train_epochs 1 \
35 | --per_device_train_batch_size 16 \
36 | --per_device_eval_batch_size 4 \
37 | --gradient_accumulation_steps 1 \
38 | --evaluation_strategy "no" \
39 | --save_strategy "no" \
40 | --save_steps 50000 \
41 | --learning_rate 1e-3 \
42 | --weight_decay 0. \
43 | --warmup_ratio 0.03 \
44 | --lr_scheduler_type "cosine" \
45 | --logging_steps 1 \
46 | --tf32 True \
47 | --model_max_length 8192 \
48 | --gradient_checkpointing True \
49 | --dataloader_num_workers 16 \
50 | --lazy_preprocess True \
51 | --report_to wandb \
52 | --run_name $BASE_RUN_NAME \
53 | --attn_implementation sdpa
54 |
55 | # You can delete the sdpa attn_implementation if you want to use flash attn
--------------------------------------------------------------------------------
/longva/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 |
14 | "zero_optimization": {
15 | "stage": 3,
16 | "offload_optimizer": {
17 | "device": "none",
18 | "pin_memory": true
19 | },
20 | "offload_param": {
21 | "device": "none",
22 | "pin_memory": true
23 | },
24 | "overlap_comm": true,
25 | "contiguous_gradients": true,
26 | "sub_group_size": 1e9,
27 | "reduce_bucket_size": "auto",
28 | "stage3_prefetch_bucket_size": "auto",
29 | "stage3_param_persistence_threshold": "auto",
30 | "stage3_max_live_parameters": 1e9,
31 | "stage3_max_reuse_distance": 1e9,
32 | "stage3_gather_16bit_weights_on_model_save": true
33 | },
34 |
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/longva/trl/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 |
3 | __version__ = "0.7.11.dev0"
4 |
5 | from .core import set_seed
6 | from .environment import TextEnvironment, TextHistory
7 | from .extras import BestOfNSampler
8 | from .import_utils import (
9 | is_bitsandbytes_available,
10 | is_diffusers_available,
11 | is_npu_available,
12 | is_peft_available,
13 | is_wandb_available,
14 | is_xpu_available,
15 | )
16 | from .models import (
17 | AutoModelForCausalLMWithValueHead,
18 | AutoModelForSeq2SeqLMWithValueHead,
19 | PreTrainedModelWrapper,
20 | create_reference_model,
21 | setup_chat_format,
22 | )
23 | from .trainer import (
24 | DataCollatorForCompletionOnlyLM,
25 | DPOTrainer,
26 | IterativeSFTTrainer,
27 | ModelConfig,
28 | PPOConfig,
29 | PPOTrainer,
30 | RewardConfig,
31 | RewardTrainer,
32 | SFTTrainer,
33 | )
34 | from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
35 |
36 |
37 | if is_diffusers_available():
38 | from .models import (
39 | DDPOPipelineOutput,
40 | DDPOSchedulerOutput,
41 | DDPOStableDiffusionPipeline,
42 | DefaultDDPOStableDiffusionPipeline,
43 | )
44 | from .trainer import DDPOConfig, DDPOTrainer
45 |
--------------------------------------------------------------------------------
/longva/trl/environment/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 |
3 | from .base_environment import TextEnvironment, TextHistory
4 |
--------------------------------------------------------------------------------
/longva/trl/extras/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 |
3 | # Copyright 2022 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | from .best_of_n_sampler import BestOfNSampler
17 |
--------------------------------------------------------------------------------
/longva/trl/extras/best_of_n_sampler.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, List, Optional, Union
2 |
3 | import torch
4 | from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
5 |
6 | from ..core import set_seed
7 | from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper
8 |
9 |
10 | class BestOfNSampler(object):
11 | def __init__(
12 | self,
13 | model: PreTrainedModelWrapper,
14 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
15 | queries_to_scores: Callable[[List[str]], List[float]],
16 | length_sampler: Any,
17 | sample_size: int = 4,
18 | seed: Optional[int] = None,
19 | n_candidates: int = 1,
20 | generation_config: Optional[GenerationConfig] = None,
21 | ) -> None:
22 | r"""
23 | Initialize the sampler for best-of-n generation
24 |
25 | Args:
26 | model (`PreTrainedModelWrapper`):
27 | The pretrained model to use for generation
28 | tokenizer (`PreTrainedTokenizer` or `PreTrainedTokenizerFast`):
29 | Tokenizer associated with the pretrained model
30 | queries_to_scores (`Callable[[List[str]], List[float]]`):
31 | Callable that takes a list of generated texts and returns the associated reward scores
32 | length_sampler (`Any`):
33 | Sampler used to sample the length of the generated text
34 | sample_size (`int`):
35 | Number of samples to generate for each query
36 | seed (`int`, *optional*):
37 | Random seed used to control generation
38 | n_candidates (`int`):
39 | Number of candidates to return for each query
40 | generation_config (`GenerationConfig`, *optional*):
41 | Generation config passed to the underlying model's `generate` method.
42 | See `GenerationConfig` (https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig) for more details
43 | """
44 | if seed is not None:
45 | set_seed(seed)
46 |
47 | if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
48 | raise ValueError(f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}")
49 | if not isinstance(model, (SUPPORTED_ARCHITECTURES)):
50 | raise ValueError(f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}")
51 |
52 | self.model = model
53 | self.tokenizer = tokenizer
54 |
55 | self.queries_to_scores = queries_to_scores
56 | self.length_sampler = length_sampler
57 | self.gen_config = generation_config
58 | self.sample_size = sample_size
59 | self.n_candidates = n_candidates
60 |
61 | def generate(
62 | self,
63 | tokenized_query: Union[List[int], torch.Tensor, List[torch.Tensor], List[List[int]]],
64 | skip_special_tokens: bool = True,
65 | device: Optional[Union[str, torch.device]] = None,
66 | **generation_kwargs,
67 | ) -> List[List[str]]:
68 | r"""
69 | Generate the best of n samples for input queries
70 |
71 | Args:
72 |             tokenized_query (`List[int]` or `torch.Tensor` or `List[torch.Tensor]` or `List[List[int]]`):
73 | represents either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers)
74 | skip_special_tokens (`bool`):
75 | Whether to remove the special tokens from the output
76 | device (`str` or `torch.device`, *optional*):
77 | The device on which the model will be loaded
78 | **generation_kwargs (`dict`, *optional*):
79 | Additional keyword arguments passed along to the underlying model's `generate` method.
80 | This is used to override generation config
81 |
82 | Returns:
83 | List[List[str]]: A list of lists of generated texts
84 | """
85 | queries = None
86 |
87 | if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1:
88 | queries = tokenized_query.unsqueeze(0)
89 | elif isinstance(tokenized_query, List):
90 | element_type = type(tokenized_query[0])
91 | if element_type == int:
92 | queries = torch.tensor(tokenized_query).unsqueeze(0)
93 | elif element_type == torch.Tensor:
94 | queries = [tensor.reshape((1, -1)) for tensor in tokenized_query]
95 | else:
96 | queries = [torch.tensor(query).reshape((1, -1)) for query in tokenized_query]
97 |
98 | result = []
99 |
100 | for query in queries:
101 | queries = query.repeat((self.sample_size, 1))
102 | output = self.model.generate(
103 | queries.to(device),
104 | max_new_tokens=self.length_sampler(),
105 | generation_config=self.gen_config,
106 | **generation_kwargs,
107 | ).squeeze()
108 | output = self.tokenizer.batch_decode(output, skip_special_tokens=skip_special_tokens)
109 | scores = torch.tensor(self.queries_to_scores(output))
110 | output = [output[i] for i in scores.topk(self.n_candidates).indices]
111 | result.append(output)
112 |
113 | return result
114 |
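A minimal usage sketch for the `BestOfNSampler` above, assuming the vendored `trl` package is importable as `trl` and that `trl.extras` exports the class; the checkpoint name, toy reward function, and sampling arguments below are placeholders, not part of this repository.

# Best-of-n sampling sketch: draw several completions per query and keep the highest-scoring one.
from transformers import AutoTokenizer

from trl import AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from trl.extras import BestOfNSampler

checkpoint = "gpt2"  # placeholder causal LM
model = AutoModelForCausalLMWithValueHead.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token


def queries_to_scores(texts):
    # toy reward: prefer longer completions
    return [float(len(text)) for text in texts]


sampler = BestOfNSampler(
    model,
    tokenizer,
    queries_to_scores,
    length_sampler=LengthSampler(16, 32),
    sample_size=4,   # completions drawn per query
    n_candidates=1,  # completions kept per query
)

query_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids[0]
best = sampler.generate(query_ids, device="cpu", do_sample=True, top_k=50)
print(best)  # a list with one list of kept completions per query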
--------------------------------------------------------------------------------
/longva/trl/extras/dataset_formatting.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Callable, Literal, Optional, Union
3 |
4 | from datasets import Dataset, Value
5 | from transformers import AutoTokenizer
6 |
7 | from ..trainer.utils import ConstantLengthDataset
8 |
9 |
10 | FORMAT_MAPPING = {
11 | "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
12 | "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
13 | }
14 |
15 |
16 | def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]):
17 | r"""
18 |     Return a callable that takes in a "messages" dataset and returns a formatted dataset by applying
19 |     the tokenizer's chat template to each conversation.
20 | """
21 |
22 | def format_dataset(examples):
23 | if isinstance(examples[messages_field][0], list):
24 | output_texts = []
25 | for i in range(len(examples[messages_field])):
26 | output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False))
27 | return output_texts
28 | else:
29 | return tokenizer.apply_chat_template(examples[messages_field], tokenize=False)
30 |
31 | return format_dataset
32 |
33 |
34 | def instructions_formatting_function(tokenizer: AutoTokenizer):
35 | r"""
36 |     Return a callable that takes in an "instructions" dataset (prompt/completion pairs) and returns a
37 |     formatted dataset by converting each pair to chat messages and applying the tokenizer's chat template.
38 | """
39 |
40 | def format_dataset(examples):
41 | if isinstance(examples["prompt"], list):
42 | output_texts = []
43 | for i in range(len(examples["prompt"])):
44 | converted_sample = [
45 | {"role": "user", "content": examples["prompt"][i]},
46 | {"role": "assistant", "content": examples["completion"][i]},
47 | ]
48 | output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False))
49 | return output_texts
50 | else:
51 | converted_sample = [
52 | {"role": "user", "content": examples["prompt"]},
53 | {"role": "assistant", "content": examples["completion"]},
54 | ]
55 | return tokenizer.apply_chat_template(converted_sample, tokenize=False)
56 |
57 | return format_dataset
58 |
59 |
60 | def get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer) -> Optional[Callable]:
61 | r"""
62 | Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
63 | - `ChatML` with [{"role": str, "content": str}]
64 |     - `instruction` with {"prompt": str, "completion": str}
65 |
66 | Args:
67 | dataset (Dataset): User dataset
68 | tokenizer (AutoTokenizer): Tokenizer used for formatting
69 |
70 | Returns:
71 | Callable: Formatting function if the dataset format is supported else None
72 | """
73 | if isinstance(dataset, Dataset):
74 | if "messages" in dataset.features:
75 | if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
76 | logging.info("Formatting dataset with chatml format")
77 | return conversations_formatting_function(tokenizer, "messages")
78 | if "conversations" in dataset.features:
79 | if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
80 | logging.info("Formatting dataset with chatml format")
81 | return conversations_formatting_function(tokenizer, "conversations")
82 | elif dataset.features == FORMAT_MAPPING["instruction"]:
83 | logging.info("Formatting dataset with instruction format")
84 | return instructions_formatting_function(tokenizer)
85 |
86 | return None
87 |
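A short sketch of the format detection above, assuming this module's functions are in scope and a tokenizer checkpoint that ships a chat template (the checkpoint name below is only a placeholder).

# Format detection sketch: a "messages" column matching the chatml schema selects the chatml formatter.
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # placeholder with a chat template
dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ]
        ]
    }
)

formatting_func = get_formatting_func_from_dataset(dataset, tokenizer)
print(formatting_func(dataset[0]))  # one chat-templated string for the first example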
--------------------------------------------------------------------------------
/longva/trl/import_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import importlib
15 | import sys
16 |
17 |
18 | if sys.version_info < (3, 8):
19 | _is_python_greater_3_8 = False
20 | else:
21 | _is_python_greater_3_8 = True
22 |
23 |
24 | def is_peft_available() -> bool:
25 | return importlib.util.find_spec("peft") is not None
26 |
27 |
28 | def is_unsloth_available() -> bool:
29 | return importlib.util.find_spec("unsloth") is not None
30 |
31 |
32 | def is_accelerate_greater_20_0() -> bool:
33 | if _is_python_greater_3_8:
34 | from importlib.metadata import version
35 |
36 | accelerate_version = version("accelerate")
37 | else:
38 | import pkg_resources
39 |
40 | accelerate_version = pkg_resources.get_distribution("accelerate").version
41 | return accelerate_version >= "0.20.0"
42 |
43 |
44 | def is_transformers_greater_than(version: str) -> bool:
45 | _transformers_version = importlib.metadata.version("transformers")
46 | return _transformers_version > version
47 |
48 |
49 | def is_torch_greater_2_0() -> bool:
50 | if _is_python_greater_3_8:
51 | from importlib.metadata import version
52 |
53 | torch_version = version("torch")
54 | else:
55 | import pkg_resources
56 |
57 | torch_version = pkg_resources.get_distribution("torch").version
58 | return torch_version >= "2.0"
59 |
60 |
61 | def is_diffusers_available() -> bool:
62 | return importlib.util.find_spec("diffusers") is not None
63 |
64 |
65 | def is_bitsandbytes_available() -> bool:
66 | import torch
67 |
68 | # bnb can be imported without GPU but is not usable.
69 | return importlib.util.find_spec("bitsandbytes") is not None and torch.cuda.is_available()
70 |
71 |
72 | def is_torchvision_available() -> bool:
73 | return importlib.util.find_spec("torchvision") is not None
74 |
75 |
76 | def is_rich_available() -> bool:
77 | return importlib.util.find_spec("rich") is not None
78 |
79 |
80 | def is_wandb_available() -> bool:
81 | return importlib.util.find_spec("wandb") is not None
82 |
83 |
84 | def is_xpu_available() -> bool:
85 | if is_accelerate_greater_20_0():
86 | import accelerate
87 |
88 | return accelerate.utils.is_xpu_available()
89 | else:
90 | if importlib.util.find_spec("intel_extension_for_pytorch") is None:
91 | return False
92 | try:
93 | import torch
94 |
95 | return hasattr(torch, "xpu") and torch.xpu.is_available()
96 | except RuntimeError:
97 | return False
98 |
99 |
100 | def is_npu_available() -> bool:
101 |     """Checks if `torch_npu` is installed and, if so, whether an NPU is available in the environment."""
102 | if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
103 | return False
104 |
105 | import torch
106 | import torch_npu # noqa: F401
107 |
108 | return hasattr(torch, "npu") and torch.npu.is_available()
109 |
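These helpers are typically used to guard optional imports. A minimal sketch, assuming the vendored package is importable as `trl` and that `peft` may or may not be installed:

# Optional-dependency guard sketch: only touch peft symbols when peft is actually installed.
from trl.import_utils import is_peft_available

if is_peft_available():
    from peft import LoraConfig

    peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05)
else:
    peft_config = None  # caller falls back to full fine-tuning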
--------------------------------------------------------------------------------
/longva/trl/models/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 |
3 | # Copyright 2022 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | from .modeling_base import PreTrainedModelWrapper, create_reference_model
17 | from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
18 | from .utils import setup_chat_format
19 |
20 |
21 | SUPPORTED_ARCHITECTURES = (
22 | AutoModelForCausalLMWithValueHead,
23 | AutoModelForSeq2SeqLMWithValueHead,
24 | )
25 |
26 | from ..import_utils import is_diffusers_available
27 |
28 |
29 | if is_diffusers_available():
30 | from .modeling_sd_base import (
31 | DDPOPipelineOutput,
32 | DDPOSchedulerOutput,
33 | DDPOStableDiffusionPipeline,
34 | DefaultDDPOStableDiffusionPipeline,
35 | )
36 |
--------------------------------------------------------------------------------
/longva/trl/models/utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Literal, Optional, Tuple
3 |
4 | from transformers import PreTrainedModel, PreTrainedTokenizer
5 |
6 |
7 | # TODO: Add Abstract Base Class if more formats are added
8 | @dataclass
9 | class ChatMlSpecialTokens:
10 | """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens."""
11 |
12 | bos_token: str = "<|im_start|>"
13 | eos_token: str = "<|im_end|>"
14 | pad_token: str = "<|im_end|>"
15 |
16 | @property
17 | def system(self):
18 | return f"{self.bos_token}system"
19 |
20 | @property
21 | def user(self):
22 | return f"{self.bos_token}user"
23 |
24 | @property
25 | def assistant(self):
26 | return f"{self.bos_token}assistant"
27 |
28 | @property
29 | def chat_template(self):
30 | return (
31 | "{% for message in messages %}"
32 | f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}"
33 | "{% endfor %}"
34 | "{% if add_generation_prompt %}"
35 | f"{{{{ '{self.assistant}\n' }}}}"
36 | "{% endif %}"
37 | )
38 |
39 |
40 | FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}
41 |
42 |
43 | def setup_chat_format(
44 | model: PreTrainedModel,
45 | tokenizer: PreTrainedTokenizer,
46 | format: Optional[Literal["chatml"]] = "chatml",
47 | resize_to_multiple_of: Optional[int] = None,
48 | ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
49 | """
50 |     Set up the chat format by adding special tokens to the tokenizer, setting the chat template, and extending the model's embedding layer to cover the new special tokens.
51 |
52 | Args:
53 | model (`~transformers.PreTrainedModel`): The model to be modified.
54 | tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
55 | format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
56 |         resize_to_multiple_of (`Optional[int]`): If set, pad the resized embedding layer to a multiple of this number. Defaults to None.
57 | Returns:
58 | model (`~transformers.PreTrainedModel`): The modified model.
59 | tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
60 | """
61 | # check if format available and retrieve
62 | if format not in FORMAT_MAPPING:
63 | raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")
64 |
65 | chat_format = FORMAT_MAPPING[format]()
66 |
67 |     # set special tokens and add them to the tokenizer
68 | tokenizer.eos_token = chat_format.eos_token
69 | tokenizer.pad_token = chat_format.pad_token
70 | tokenizer.bos_token = chat_format.bos_token
71 | tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})
72 | # set chat format for tokenizer
73 | tokenizer.chat_template = chat_format.chat_template
74 |
75 |     # resize embedding layer, padding to a multiple of `resize_to_multiple_of` (e.g. 64), https://x.com/karpathy/status/1621578354024677377
76 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None)
77 | # Make sure to update the generation config to use the new eos & bos token
78 | if getattr(model, "generation_config", None) is not None:
79 | model.generation_config.bos_token_id = tokenizer.bos_token_id
80 | model.generation_config.eos_token_id = tokenizer.eos_token_id
81 | model.generation_config.pad_token_id = tokenizer.pad_token_id
82 |
83 | return model, tokenizer
84 |
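A usage sketch for `setup_chat_format`, with `gpt2` standing in for any base checkpoint that has no chat template of its own:

# ChatML setup sketch: add <|im_start|>/<|im_end|>, install the template, and resize embeddings.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
tokenizer = AutoTokenizer.from_pretrained("gpt2")

model, tokenizer = setup_chat_format(model, tokenizer, format="chatml", resize_to_multiple_of=64)

messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
]
# Renders roughly: <|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\nHi there.<|im_end|>\n
print(tokenizer.apply_chat_template(messages, tokenize=False))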
--------------------------------------------------------------------------------
/longva/trl/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 |
3 | # Copyright 2022 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # There is a circular import in the PPOTrainer if we let isort sort these
18 | # isort: off
19 | from .utils import (
20 | AdaptiveKLController,
21 | FixedKLController,
22 | ConstantLengthDataset,
23 | DataCollatorForCompletionOnlyLM,
24 | RunningMoments,
25 | disable_dropout_in_model,
26 | peft_module_casting_to_bf16,
27 | )
28 |
29 | # isort: on
30 |
31 | from ..import_utils import is_diffusers_available
32 | from .base import BaseTrainer
33 | from .ddpo_config import DDPOConfig
34 |
35 |
36 | if is_diffusers_available():
37 | from .ddpo_trainer import DDPOTrainer
38 |
39 | from .dpo_trainer import DPOTrainer
40 | from .iterative_sft_trainer import IterativeSFTTrainer
41 | from .model_config import ModelConfig
42 | from .ppo_config import PPOConfig
43 | from .ppo_trainer import PPOTrainer
44 | from .reward_config import RewardConfig
45 | from .reward_trainer import RewardTrainer, compute_accuracy
46 | from .sft_trainer import SFTTrainer
47 |
--------------------------------------------------------------------------------
/longva/trl/trainer/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from huggingface_hub import PyTorchModelHubMixin
16 |
17 |
18 | class BaseTrainer(PyTorchModelHubMixin):
19 | r"""
20 | Base class for all trainers - this base class implements the basic functions that we
21 | need for a trainer.
22 |
23 | The trainer needs to have the following functions:
24 | - step: takes in a batch of data and performs a step of training
25 | - loss: takes in a batch of data and returns the loss
26 | - compute_rewards: takes in a batch of data and returns the rewards
27 | - _build_models_and_tokenizer: builds the models and tokenizer
28 | - _build_dataset: builds the dataset
29 | Each user is expected to implement their own trainer class that inherits from this base
30 | if they want to use a new training algorithm.
31 | """
32 |
33 | def __init__(self, config):
34 | self.config = config
35 |
36 | def step(self, *args):
37 | raise NotImplementedError("Not implemented")
38 |
39 | def loss(self, *args):
40 | raise NotImplementedError("Not implemented")
41 |
42 | def compute_rewards(self, *args):
43 | raise NotImplementedError("Not implemented")
44 |
45 | def _save_pretrained(self, save_directory):
46 | raise NotImplementedError("Not implemented")
47 |
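A minimal subclass sketch showing the interface a concrete trainer is expected to fill in; the class and its batch handling are purely illustrative:

# Illustrative subclass: concrete trainers override step/loss/compute_rewards as needed.
class ToyTrainer(BaseTrainer):
    def __init__(self, config, model):
        super().__init__(config)
        self.model = model

    def step(self, batch):
        loss = self.loss(batch)
        loss.backward()
        return {"loss": loss.item()}

    def loss(self, batch):
        return self.model(**batch).loss

    def compute_rewards(self, batch):
        raise NotImplementedError("reward computation is task specific")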
--------------------------------------------------------------------------------
/longva/trl/trainer/ddpo_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import warnings
4 | from dataclasses import dataclass, field
5 | from typing import Literal, Optional
6 |
7 | from ..core import flatten_dict
8 | from ..import_utils import is_bitsandbytes_available, is_torchvision_available
9 |
10 |
11 | @dataclass
12 | class DDPOConfig:
13 | """
14 | Configuration class for DDPOTrainer
15 | """
16 |
17 | # common parameters
18 | exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")]
19 | """the name of this experiment (by default is the file name without the extension name)"""
20 | run_name: Optional[str] = ""
21 | """Run name for wandb logging and checkpoint saving."""
22 | seed: int = 0
23 | """Seed value for random generations"""
24 | log_with: Optional[Literal["wandb", "tensorboard"]] = None
25 | """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""
26 | tracker_kwargs: dict = field(default_factory=dict)
27 | """Keyword arguments for the tracker (e.g. wandb_project)"""
28 | accelerator_kwargs: dict = field(default_factory=dict)
29 | """Keyword arguments for the accelerator"""
30 | project_kwargs: dict = field(default_factory=dict)
31 | """Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
32 | tracker_project_name: str = "trl"
33 | """Name of project to use for tracking"""
34 | logdir: str = "logs"
35 | """Top-level logging directory for checkpoint saving."""
36 |
37 | # hyperparameters
38 | num_epochs: int = 100
39 | """Number of epochs to train."""
40 | save_freq: int = 1
41 | """Number of epochs between saving model checkpoints."""
42 | num_checkpoint_limit: int = 5
43 | """Number of checkpoints to keep before overwriting old ones."""
44 | mixed_precision: str = "fp16"
45 | """Mixed precision training."""
46 | allow_tf32: bool = True
47 | """Allow tf32 on Ampere GPUs."""
48 | resume_from: Optional[str] = ""
49 | """Resume training from a checkpoint."""
50 | sample_num_steps: int = 50
51 | """Number of sampler inference steps."""
52 | sample_eta: float = 1.0
53 | """Eta parameter for the DDIM sampler."""
54 | sample_guidance_scale: float = 5.0
55 | """Classifier-free guidance weight."""
56 | sample_batch_size: int = 1
57 | """Batch size (per GPU!) to use for sampling."""
58 | sample_num_batches_per_epoch: int = 2
59 | """Number of batches to sample per epoch."""
60 | train_batch_size: int = 1
61 | """Batch size (per GPU!) to use for training."""
62 | train_use_8bit_adam: bool = False
63 | """Whether to use the 8bit Adam optimizer from bitsandbytes."""
64 | train_learning_rate: float = 3e-4
65 | """Learning rate."""
66 | train_adam_beta1: float = 0.9
67 | """Adam beta1."""
68 | train_adam_beta2: float = 0.999
69 | """Adam beta2."""
70 | train_adam_weight_decay: float = 1e-4
71 | """Adam weight decay."""
72 | train_adam_epsilon: float = 1e-8
73 | """Adam epsilon."""
74 | train_gradient_accumulation_steps: int = 1
75 | """Number of gradient accumulation steps."""
76 | train_max_grad_norm: float = 1.0
77 | """Maximum gradient norm for gradient clipping."""
78 | train_num_inner_epochs: int = 1
79 | """Number of inner epochs per outer epoch."""
80 | train_cfg: bool = True
81 | """Whether or not to use classifier-free guidance during training."""
82 | train_adv_clip_max: float = 5
83 |     """Clip advantages to the range [-train_adv_clip_max, train_adv_clip_max]."""
84 | train_clip_range: float = 1e-4
85 | """The PPO clip range."""
86 | train_timestep_fraction: float = 1.0
87 | """The fraction of timesteps to train on."""
88 | per_prompt_stat_tracking: bool = False
89 | """Whether to track statistics for each prompt separately."""
90 | per_prompt_stat_tracking_buffer_size: int = 16
91 | """Number of reward values to store in the buffer for each prompt."""
92 | per_prompt_stat_tracking_min_count: int = 16
93 | """The minimum number of reward values to store in the buffer."""
94 | async_reward_computation: bool = False
95 | """Whether to compute rewards asynchronously."""
96 | max_workers: int = 2
97 | """The maximum number of workers to use for async reward computation."""
98 | negative_prompts: Optional[str] = ""
99 | """Comma-separated list of prompts to use as negative examples."""
100 |
101 | def to_dict(self):
102 | output_dict = {}
103 | for key, value in self.__dict__.items():
104 | output_dict[key] = value
105 | return flatten_dict(output_dict)
106 |
107 | def __post_init__(self):
108 | if self.log_with not in ["wandb", "tensorboard"]:
109 | warnings.warn(("Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'."))
110 |
111 | if self.log_with == "wandb" and not is_torchvision_available():
112 | warnings.warn("Wandb image logging requires torchvision to be installed")
113 |
114 | if self.train_use_8bit_adam and not is_bitsandbytes_available():
115 | raise ImportError("You need to install bitsandbytes to use 8bit Adam. " "You can install it with `pip install bitsandbytes`.")
116 |
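A small sketch of constructing a `DDPOConfig` and flattening it for a tracker; the hyperparameter values below are illustrative, not recommendations:

# DDPOConfig sketch: override a few fields, everything else keeps the defaults above.
config = DDPOConfig(
    num_epochs=10,
    sample_batch_size=2,
    train_batch_size=2,
    train_learning_rate=1e-4,
    log_with="tensorboard",  # "wandb" also supported; anything else disables image logging
)
print(config.to_dict())  # flattened dict, convenient for tracker init kwargs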
--------------------------------------------------------------------------------
/longva/trl/trainer/model_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import List, Optional
3 |
4 | from ..core import flatten_dict
5 |
6 |
7 | @dataclass
8 | class ModelConfig:
9 | """
10 | Arguments which define the model and tokenizer to load.
11 | """
12 |
13 | model_name_or_path: Optional[str] = field(
14 | default=None,
15 | metadata={"help": ("The model checkpoint for weights initialization.")},
16 | )
17 | model_revision: str = field(
18 | default="main",
19 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
20 | )
21 | torch_dtype: Optional[str] = field(
22 | default=None,
23 | metadata={
24 | "help": ("Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " "dtype will be automatically derived from the model's weights."),
25 | "choices": ["auto", "bfloat16", "float16", "float32"],
26 | },
27 | )
28 | trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
29 | attn_implementation: Optional[str] = field(
30 | default=None,
31 | metadata={"help": ("Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`")},
32 | )
33 | use_peft: bool = field(
34 | default=False,
35 | metadata={"help": ("Whether to use PEFT or not for training.")},
36 | )
37 | lora_r: Optional[int] = field(
38 | default=16,
39 | metadata={"help": ("LoRA R value.")},
40 | )
41 | lora_alpha: Optional[int] = field(
42 | default=32,
43 | metadata={"help": ("LoRA alpha.")},
44 | )
45 | lora_dropout: Optional[float] = field(
46 | default=0.05,
47 | metadata={"help": ("LoRA dropout.")},
48 | )
49 | lora_target_modules: Optional[List[str]] = field(
50 | default=None,
51 | metadata={"help": ("LoRA target modules.")},
52 | )
53 | lora_modules_to_save: Optional[List[str]] = field(
54 | default=None,
55 | metadata={"help": ("Model layers to unfreeze & train")},
56 | )
57 | load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"})
58 | load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"})
59 |
60 |     bnb_4bit_quant_type: Optional[str] = field(default="nf4", metadata={"help": "specify the quantization type (fp4 or nf4)"})
61 | use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})
62 |
63 | def to_dict(self):
64 | output_dict = {}
65 | for key, value in self.__dict__.items():
66 | output_dict[key] = value
67 | return flatten_dict(output_dict)
68 |
69 | def __post_init__(self):
70 | if self.load_in_8bit and self.load_in_4bit:
71 | raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
72 |
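A sketch of how `ModelConfig` is typically consumed: parse it from the command line with `HfArgumentParser` and feed the fields to `from_pretrained`. The CLI values are examples only:

# ModelConfig sketch: CLI parsing plus model loading from the parsed fields.
import torch
from transformers import AutoModelForCausalLM, HfArgumentParser

parser = HfArgumentParser(ModelConfig)
(model_config,) = parser.parse_args_into_dataclasses(
    ["--model_name_or_path", "gpt2", "--torch_dtype", "float16"]
)

# Map the string dtype to a torch dtype unless it is "auto"/None.
torch_dtype = (
    model_config.torch_dtype
    if model_config.torch_dtype in ("auto", None)
    else getattr(torch, model_config.torch_dtype)
)
model = AutoModelForCausalLM.from_pretrained(
    model_config.model_name_or_path,
    revision=model_config.model_revision,
    trust_remote_code=model_config.trust_remote_code,
    torch_dtype=torch_dtype,
    attn_implementation=model_config.attn_implementation,
)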
--------------------------------------------------------------------------------
/longva/trl/trainer/reward_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass
16 | from typing import Optional
17 |
18 | from transformers import TrainingArguments
19 |
20 |
21 | @dataclass
22 | class RewardConfig(TrainingArguments):
23 | """
24 | RewardConfig collects all training arguments related to the [`RewardTrainer`] class.
25 |
26 | Using [`HfArgumentParser`] we can turn this class into
27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
28 | command line.
29 |
30 | Parameters:
31 | max_length (`int`, *optional*, defaults to `None`):
32 | The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
33 | gradient_checkpointing (`bool`, *optional*, defaults to `True`):
34 | If True, use gradient checkpointing to save memory at the expense of slower backward pass.
35 | """
36 |
37 | max_length: Optional[int] = None
38 | """The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
39 |
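A short instantiation sketch; since `RewardConfig` subclasses `TrainingArguments`, all standard training arguments are available alongside `max_length`. Values are illustrative:

# RewardConfig sketch: standard TrainingArguments fields plus the reward-specific max_length.
reward_config = RewardConfig(
    output_dir="./reward_model",      # inherited from TrainingArguments
    per_device_train_batch_size=4,
    num_train_epochs=1,
    max_length=512,                   # needed by the default reward data collator
)
print(reward_config.max_length)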
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.28.0
2 | decord>=0.6.0
3 | datasets
4 | evaluate>=0.4.1
5 | tqdm
6 | einops
7 | sentencepiece
8 | scikit-learn
9 | matplotlib
10 | numpy==1.26.4
11 | ring_flash_attn@git+https://github.com/zhuzilin/ring-flash-attention.git@b39e427e737ed9a515c88122f88f9183b7eb32de
12 | deepspeed==0.14.0
13 | wandb
14 | seaborn
15 | pandas
16 | pytest
17 | loguru
18 | gradio==4.29.0
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/bias.txt:
--------------------------------------------------------------------------------
1 | October 2015This will come as a surprise to a lot of people, but in some cases
2 | it's possible to detect bias in a selection process without knowing
3 | anything about the applicant pool. Which is exciting because among
4 | other things it means third parties can use this technique to detect
5 | bias whether those doing the selecting want them to or not.You can use this technique whenever (a) you have at least
6 | a random sample of the applicants that were selected, (b) their
7 | subsequent performance is measured, and (c) the groups of
8 | applicants you're comparing have roughly equal distribution of ability.How does it work? Think about what it means to be biased. What
9 | it means for a selection process to be biased against applicants
10 | of type x is that it's harder for them to make it through. Which
11 | means applicants of type x have to be better to get selected than
12 | applicants not of type x.
13 | [1]
14 | Which means applicants of type x
15 | who do make it through the selection process will outperform other
16 | successful applicants. And if the performance of all the successful
17 | applicants is measured, you'll know if they do.Of course, the test you use to measure performance must be a valid
18 | one. And in particular it must not be invalidated by the bias you're
19 | trying to measure.
20 | But there are some domains where performance can be measured, and
21 | in those detecting bias is straightforward. Want to know if the
22 | selection process was biased against some type of applicant? Check
23 | whether they outperform the others. This is not just a heuristic
24 | for detecting bias. It's what bias means.For example, many suspect that venture capital firms are biased
25 | against female founders. This would be easy to detect: among their
26 | portfolio companies, do startups with female founders outperform
27 | those without? A couple months ago, one VC firm (almost certainly
28 | unintentionally) published a study showing bias of this type. First
29 | Round Capital found that among its portfolio companies, startups
30 | with female founders outperformed
31 | those without by 63%.
32 | [2]The reason I began by saying that this technique would come as a
33 | surprise to many people is that we so rarely see analyses of this
34 | type. I'm sure it will come as a surprise to First Round that they
35 | performed one. I doubt anyone there realized that by limiting their
36 | sample to their own portfolio, they were producing a study not of
37 | startup trends but of their own biases when selecting companies.I predict we'll see this technique used more in the future. The
38 | information needed to conduct such studies is increasingly available.
39 | Data about who applies for things is usually closely guarded by the
40 | organizations selecting them, but nowadays data about who gets
41 | selected is often publicly available to anyone who takes the trouble
42 | to aggregate it.
43 | Notes[1]
44 | This technique wouldn't work if the selection process looked
45 | for different things from different types of applicants—for
46 | example, if an employer hired men based on their ability but women
47 | based on their appearance.[2]
48 | As Paul Buchheit points out, First Round excluded their most
49 | successful investment, Uber, from the study. And while it
50 | makes sense to exclude outliers from some types of studies,
51 | studies of returns from startup investing, which is all about
52 | hitting outliers, are not one of them.
53 | Thanks to Sam Altman, Jessica Livingston, and Geoff Ralston for reading
54 | drafts of this.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/copy.txt:
--------------------------------------------------------------------------------
1 | July 2006
2 | When I was in high school I spent a lot of time imitating bad
3 | writers. What we studied in English classes was mostly fiction,
4 | so I assumed that was the highest form of writing. Mistake number
5 | one. The stories that seemed to be most admired were ones in which
6 | people suffered in complicated ways. Anything funny or
7 | gripping was ipso facto suspect, unless it was old enough to be hard to
8 | understand, like Shakespeare or Chaucer. Mistake number two. The
9 | ideal medium seemed the short story, which I've since learned had
10 | quite a brief life, roughly coincident with the peak of magazine
11 | publishing. But since their size made them perfect for use in
12 | high school classes, we read a lot of them, which gave us the
13 | impression the short story was flourishing. Mistake number three.
14 | And because they were so short, nothing really had to happen; you
15 | could just show a randomly truncated slice of life, and that was
16 | considered advanced. Mistake number four. The result was that I
17 | wrote a lot of stories in which nothing happened except that someone
18 | was unhappy in a way that seemed deep.For most of college I was a philosophy major. I was very impressed
19 | by the papers published in philosophy journals. They were so
20 | beautifully typeset, and their tone was just captivating—alternately
21 | casual and buffer-overflowingly technical. A fellow would be walking
22 | along a street and suddenly modality qua modality would spring upon
23 | him. I didn't ever quite understand these papers, but I figured
24 | I'd get around to that later, when I had time to reread them more
25 | closely. In the meantime I tried my best to imitate them. This
26 | was, I can now see, a doomed undertaking, because they weren't
27 | really saying anything. No philosopher ever refuted another, for
28 | example, because no one said anything definite enough to refute.
29 | Needless to say, my imitations didn't say anything either.In grad school I was still wasting time imitating the wrong things.
30 | There was then a fashionable type of program called an expert system,
31 | at the core of which was something called an inference engine. I
32 | looked at what these things did and thought "I could write that in
33 | a thousand lines of code." And yet eminent professors were writing
34 | books about them, and startups were selling them for a year's salary
35 | a copy. What an opportunity, I thought; these impressive things
36 | seem easy to me; I must be pretty sharp. Wrong. It was simply a
37 | fad. The books the professors wrote about expert systems are now
38 | ignored. They were not even on a path to anything interesting.
39 | And the customers paying so much for them were largely the same
40 | government agencies that paid thousands for screwdrivers and toilet
41 | seats.How do you avoid copying the wrong things? Copy only what you
42 | genuinely like. That would have saved me in all three cases. I
43 | didn't enjoy the short stories we had to read in English classes;
44 | I didn't learn anything from philosophy papers; I didn't use expert
45 | systems myself. I believed these things were good because they
46 | were admired.It can be hard to separate the things you like from the things
47 | you're impressed with. One trick is to ignore presentation. Whenever
48 | I see a painting impressively hung in a museum, I ask myself: how
49 | much would I pay for this if I found it at a garage sale, dirty and
50 | frameless, and with no idea who painted it? If you walk around a
51 | museum trying this experiment, you'll find you get some truly
52 | startling results. Don't ignore this data point just because it's
53 | an outlier.Another way to figure out what you like is to look at what you enjoy
54 | as guilty pleasures. Many things people like, especially if they're
55 | young and ambitious, they like largely for the feeling of virtue
56 | in liking them. 99% of people reading Ulysses are thinking
57 | "I'm reading Ulysses" as they do it. A guilty pleasure is
58 | at least a pure one. What do you read when you don't feel up to being
59 | virtuous? What kind of book do you read and feel sad that there's
60 | only half of it left, instead of being impressed that you're half
61 | way through? That's what you really like.Even when you find genuinely good things to copy, there's another
62 | pitfall to be avoided. Be careful to copy what makes them good,
63 | rather than their flaws. It's easy to be drawn into imitating
64 | flaws, because they're easier to see, and of course easier to copy
65 | too. For example, most painters in the eighteenth and nineteenth
66 | centuries used brownish colors. They were imitating the great
67 | painters of the Renaissance, whose paintings by that time were brown
68 | with dirt. Those paintings have since been cleaned, revealing
69 | brilliant colors; their imitators are of course still brown.It was painting, incidentally, that cured me of copying the wrong
70 | things. Halfway through grad school I decided I wanted to try being
71 | a painter, and the art world was so manifestly corrupt that it
72 | snapped the leash of credulity. These people made philosophy
73 | professors seem as scrupulous as mathematicians. It was so clearly
74 | a choice of doing good work xor being an insider that I was forced
75 | to see the distinction. It's there to some degree in almost every
76 | field, but I had till then managed to avoid facing it.That was one of the most valuable things I learned from painting:
77 | you have to figure out for yourself what's
78 | good. You can't trust
79 | authorities. They'll lie to you on this one.
80 |
81 | Comment on this essay.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/diff.txt:
--------------------------------------------------------------------------------
1 | December 2001 (rev. May 2002)
2 |
3 | (This article came about in response to some questions on
4 | the LL1 mailing list. It is now
5 | incorporated in Revenge of the Nerds.)When McCarthy designed Lisp in the late 1950s, it was
6 | a radical departure from existing languages,
7 | the most important of which was Fortran.Lisp embodied nine new ideas:
8 | 1. Conditionals. A conditional is an if-then-else
9 | construct. We take these for granted now. They were
10 | invented
11 | by McCarthy in the course of developing Lisp.
12 | (Fortran at that time only had a conditional
13 | goto, closely based on the branch instruction in the
14 | underlying hardware.) McCarthy, who was on the Algol committee, got
15 | conditionals into Algol, whence they spread to most other
16 | languages.2. A function type. In Lisp, functions are first class
17 | objects-- they're a data type just like integers, strings,
18 | etc, and have a literal representation, can be stored in variables,
19 | can be passed as arguments, and so on.3. Recursion. Recursion existed as a mathematical concept
20 | before Lisp of course, but Lisp was the first programming language to support
21 | it. (It's arguably implicit in making functions first class
22 | objects.)4. A new concept of variables. In Lisp, all variables
23 | are effectively pointers. Values are what
24 | have types, not variables, and assigning or binding
25 | variables means copying pointers, not what they point to.5. Garbage-collection.6. Programs composed of expressions. Lisp programs are
26 | trees of expressions, each of which returns a value.
27 | (In some Lisps expressions
28 | can return multiple values.) This is in contrast to Fortran
29 | and most succeeding languages, which distinguish between
30 | expressions and statements.It was natural to have this
31 | distinction in Fortran because (not surprisingly in a language
32 | where the input format was punched cards) the language was
33 | line-oriented. You could not nest statements. And
34 | so while you needed expressions for math to work, there was
35 | no point in making anything else return a value, because
36 | there could not be anything waiting for it.This limitation
37 | went away with the arrival of block-structured languages,
38 | but by then it was too late. The distinction between
39 | expressions and statements was entrenched. It spread from
40 | Fortran into Algol and thence to both their descendants.When a language is made entirely of expressions, you can
41 | compose expressions however you want. You can say either
42 | (using Arc syntax)(if foo (= x 1) (= x 2))or(= x (if foo 1 2))7. A symbol type. Symbols differ from strings in that
43 | you can test equality by comparing a pointer.8. A notation for code using trees of symbols.9. The whole language always available.
44 | There is
45 | no real distinction between read-time, compile-time, and runtime.
46 | You can compile or run code while reading, read or run code
47 | while compiling, and read or compile code at runtime.Running code at read-time lets users reprogram Lisp's syntax;
48 | running code at compile-time is the basis of macros; compiling
49 | at runtime is the basis of Lisp's use as an extension
50 | language in programs like Emacs; and reading at runtime
51 | enables programs to communicate using s-expressions, an
52 | idea recently reinvented as XML.
53 | When Lisp was first invented, all these ideas were far
54 | removed from ordinary programming practice, which was
55 | dictated largely by the hardware available in the late 1950s.Over time, the default language, embodied
56 | in a succession of popular languages, has
57 | gradually evolved toward Lisp. 1-5 are now widespread.
58 | 6 is starting to appear in the mainstream.
59 | Python has a form of 7, though there doesn't seem to be
60 | any syntax for it.
61 | 8, which (with 9) is what makes Lisp macros
62 | possible, is so far still unique to Lisp,
63 | perhaps because (a) it requires those parens, or something
64 | just as bad, and (b) if you add that final increment of power,
65 | you can no
66 | longer claim to have invented a new language, but only
67 | to have designed a new dialect of Lisp ; -)Though useful to present-day programmers, it's
68 | strange to describe Lisp in terms of its
69 | variation from the random expedients other languages
70 | adopted. That was not, probably, how McCarthy
71 | thought of it. Lisp wasn't designed to fix the mistakes
72 | in Fortran; it came about more as the byproduct of an
73 | attempt to axiomatize computation.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/founders.txt:
--------------------------------------------------------------------------------
1 |
2 |
3 | Want to start a startup? Get funded by
4 | Y Combinator.
5 |
6 |
7 |
8 |
9 | October 2010
10 |
11 | (I wrote this for Forbes, who asked me to write something
12 | about the qualities we look for in founders. In print they had to cut
13 | the last item because they didn't have room.)1. DeterminationThis has turned out to be the most important quality in startup
14 | founders. We thought when we started Y Combinator that the most
15 | important quality would be intelligence. That's the myth in the
16 | Valley. And certainly you don't want founders to be stupid. But
17 | as long as you're over a certain threshold of intelligence, what
18 | matters most is determination. You're going to hit a lot of
19 | obstacles. You can't be the sort of person who gets demoralized
20 | easily.Bill Clerico and Rich Aberman of WePay
21 | are a good example. They're
22 | doing a finance startup, which means endless negotiations with big,
23 | bureaucratic companies. When you're starting a startup that depends
24 | on deals with big companies to exist, it often feels like they're
25 | trying to ignore you out of existence. But when Bill Clerico starts
26 | calling you, you may as well do what he asks, because he is not
27 | going away.
28 | 2. FlexibilityYou do not however want the sort of determination implied by phrases
29 | like "don't give up on your dreams." The world of startups is so
30 | unpredictable that you need to be able to modify your dreams on the
31 | fly. The best metaphor I've found for the combination of determination
32 | and flexibility you need is a running back.
33 | He's determined to get
34 | downfield, but at any given moment he may need to go sideways or
35 | even backwards to get there.The current record holder for flexibility may be Daniel Gross of
36 | Greplin. He applied to YC with
37 | some bad ecommerce idea. We told
38 | him we'd fund him if he did something else. He thought for a second,
39 | and said ok. He then went through two more ideas before settling
40 | on Greplin. He'd only been working on it for a couple days when
41 | he presented to investors at Demo Day, but he got a lot of interest.
42 | He always seems to land on his feet.
43 | 3. ImaginationIntelligence does matter a lot of course. It seems like the type
44 | that matters most is imagination. It's not so important to be able
45 | to solve predefined problems quickly as to be able to come up with
46 | surprising new ideas. In the startup world, most good ideas
47 | seem
48 | bad initially. If they were obviously good, someone would already
49 | be doing them. So you need the kind of intelligence that produces
50 | ideas with just the right level of craziness.Airbnb is that kind of idea.
51 | In fact, when we funded Airbnb, we
52 | thought it was too crazy. We couldn't believe large numbers of
53 | people would want to stay in other people's places. We funded them
54 | because we liked the founders so much. As soon as we heard they'd
55 | been supporting themselves by selling Obama and McCain branded
56 | breakfast cereal, they were in. And it turned out the idea was on
57 | the right side of crazy after all.
58 | 4. NaughtinessThough the most successful founders are usually good people, they
59 | tend to have a piratical gleam in their eye. They're not Goody
60 | Two-Shoes type good. Morally, they care about getting the big
61 | questions right, but not about observing proprieties. That's why
62 | I'd use the word naughty rather than evil. They delight in
63 | breaking
64 | rules, but not rules that matter. This quality may be redundant
65 | though; it may be implied by imagination.Sam Altman of Loopt
66 | is one of the most successful alumni, so we
67 | asked him what question we could put on the Y Combinator application
68 | that would help us discover more people like him. He said to ask
69 | about a time when they'd hacked something to their advantage—hacked in the sense of beating the system, not breaking into
70 | computers. It has become one of the questions we pay most attention
71 | to when judging applications.
72 | 5. FriendshipEmpirically it seems to be hard to start a startup with just
73 | one
74 | founder. Most of the big successes have two or three. And the
75 | relationship between the founders has to be strong. They must
76 | genuinely like one another, and work well together. Startups do
77 | to the relationship between the founders what a dog does to a sock:
78 | if it can be pulled apart, it will be.Emmett Shear and Justin Kan of Justin.tv
79 | are a good example of close
80 | friends who work well together. They've known each other since
81 | second grade. They can practically read one another's minds. I'm
82 | sure they argue, like all founders, but I have never once sensed
83 | any unresolved tension between them.Thanks to Jessica Livingston and Chris Steiner for reading drafts of this.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/foundervisa.txt:
--------------------------------------------------------------------------------
1 |
2 |
3 | April 2009I usually avoid politics, but since we now seem to have an administration that's open to suggestions, I'm going to risk making one. The single biggest thing the government could do to increase the number of startups in this country is a policy that would cost nothing: establish a new class of visa for startup founders.The biggest constraint on the number of new startups that get created in the US is not tax policy or employment law or even Sarbanes-Oxley. It's that we won't let the people who want to start them into the country.Letting just 10,000 startup founders into the country each year could have a visible effect on the economy. If we assume 4 people per startup, which is probably an overestimate, that's 2500 new companies. Each year. They wouldn't all grow as big as Google, but out of 2500 some would come close.By definition these 10,000 founders wouldn't be taking jobs from Americans: it could be part of the terms of the visa that they couldn't work for existing companies, only new ones they'd founded. In fact they'd cause there to be
4 | more jobs for Americans, because the companies they started would hire more employees as they grew.The tricky part might seem to be how one defined a startup. But that could be solved quite easily: let the market decide. Startup investors work hard to find the best startups. The government could not do better than to piggyback on their expertise, and use investment by recognized startup investors as the test of whether a company was a real startup.How would the government decide who's a startup investor? The same way they decide what counts as a university for student visas. We'll establish our own accreditation procedure. We know who one another are.10,000 people is a drop in the bucket by immigration standards, but would represent a huge increase in the pool of startup founders. I think this would have such a visible effect on the economy that it would make the legislator who introduced the bill famous. The only way to know for sure would be to try it, and that would cost practically nothing.
5 | Thanks to Trevor Blackwell, Paul Buchheit, Jeff Clavier, David Hornik, Jessica Livingston, Greg Mcadoo, Aydin Senkut, and Fred Wilson for reading drafts of this.Related:
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/iflisp.txt:
--------------------------------------------------------------------------------
1 | May 2003If Lisp is so great, why don't more people use it? I was
2 | asked this question by a student in the audience at a
3 | talk I gave recently. Not for the first time, either.In languages, as in so many things, there's not much
4 | correlation between popularity and quality. Why does
5 | John Grisham (King of Torts sales rank, 44) outsell
6 | Jane Austen (Pride and Prejudice sales rank, 6191)?
7 | Would even Grisham claim that it's because he's a better
8 | writer?Here's the first sentence of Pride and Prejudice:
9 |
10 | It is a truth universally acknowledged, that a single man
11 | in possession of a good fortune must be in want of a
12 | wife.
13 |
14 | "It is a truth universally acknowledged?" Long words for
15 | the first sentence of a love story.Like Jane Austen, Lisp looks hard. Its syntax, or lack
16 | of syntax, makes it look completely unlike
17 | the languages
18 | most people are used to. Before I learned Lisp, I was afraid
19 | of it too. I recently came across a notebook from 1983
20 | in which I'd written:
21 |
22 | I suppose I should learn Lisp, but it seems so foreign.
23 |
24 | Fortunately, I was 19 at the time and not too resistant to learning
25 | new things. I was so ignorant that learning
26 | almost anything meant learning new things.People frightened by Lisp make up other reasons for not
27 | using it. The standard
28 | excuse, back when C was the default language, was that Lisp
29 | was too slow. Now that Lisp dialects are among
30 | the faster
31 | languages available, that excuse has gone away.
32 | Now the standard excuse is openly circular: that other languages
33 | are more popular.(Beware of such reasoning. It gets you Windows.)Popularity is always self-perpetuating, but it's especially
34 | so in programming languages. More libraries
35 | get written for popular languages, which makes them still
36 | more popular. Programs often have to work with existing programs,
37 | and this is easier if they're written in the same language,
38 | so languages spread from program to program like a virus.
39 | And managers prefer popular languages, because they give them
40 | more leverage over developers, who can more easily be replaced.Indeed, if programming languages were all more or less equivalent,
41 | there would be little justification for using any but the most
42 | popular. But they aren't all equivalent, not by a long
43 | shot. And that's why less popular languages, like Jane Austen's
44 | novels, continue to survive at all. When everyone else is reading
45 | the latest John Grisham novel, there will always be a few people
46 | reading Jane Austen instead.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/island.txt:
--------------------------------------------------------------------------------
1 | July 2006I've discovered a handy test for figuring out what you're addicted
2 | to. Imagine you were going to spend the weekend at a friend's house
3 | on a little island off the coast of Maine. There are no shops on
4 | the island and you won't be able to leave while you're there. Also,
5 | you've never been to this house before, so you can't assume it will
6 | have more than any house might.What, besides clothes and toiletries, do you make a point of packing?
7 | That's what you're addicted to. For example, if you find yourself
8 | packing a bottle of vodka (just in case), you may want to stop and
9 | think about that.For me the list is four things: books, earplugs, a notebook, and a
10 | pen.There are other things I might bring if I thought of it, like music,
11 | or tea, but I can live without them. I'm not so addicted to caffeine
12 | that I wouldn't risk the house not having any tea, just for a
13 | weekend.Quiet is another matter. I realize it seems a bit eccentric to
14 | take earplugs on a trip to an island off the coast of Maine. If
15 | anywhere should be quiet, that should. But what if the person in
16 | the next room snored? What if there was a kid playing basketball?
17 | (Thump, thump, thump... thump.) Why risk it? Earplugs are small.Sometimes I can think with noise. If I already have momentum on
18 | some project, I can work in noisy places. I can edit an essay or
19 | debug code in an airport. But airports are not so bad: most of the
20 | noise is whitish. I couldn't work with the sound of a sitcom coming
21 | through the wall, or a car in the street playing thump-thump music.And of course there's another kind of thinking, when you're starting
22 | something new, that requires complete quiet. You never
23 | know when this will strike. It's just as well to carry plugs.The notebook and pen are professional equipment, as it were. Though
24 | actually there is something druglike about them, in the sense that
25 | their main purpose is to make me feel better. I hardly ever go
26 | back and read stuff I write down in notebooks. It's just that if
27 | I can't write things down, worrying about remembering one idea gets
28 | in the way of having the next. Pen and paper wick ideas.The best notebooks I've found are made by a company called Miquelrius.
29 | I use their smallest size, which is about 2.5 x 4 in.
30 | The secret to writing on such
31 | narrow pages is to break words only when you run out of space, like
32 | a Latin inscription. I use the cheapest plastic Bic ballpoints,
33 | partly because their gluey ink doesn't seep through pages, and
34 | partly so I don't worry about losing them.I only started carrying a notebook about three years ago. Before
35 | that I used whatever scraps of paper I could find. But the problem
36 | with scraps of paper is that they're not ordered. In a notebook
37 | you can guess what a scribble means by looking at the pages
38 | around it. In the scrap era I was constantly finding notes I'd
39 | written years before that might say something I needed to remember,
40 | if I could only figure out what.As for books, I know the house would probably have something to
41 | read. On the average trip I bring four books and only read one of
42 | them, because I find new books to read en route. Really bringing
43 | books is insurance.I realize this dependence on books is not entirely good—that what
44 | I need them for is distraction. The books I bring on trips are
45 | often quite virtuous, the sort of stuff that might be assigned
46 | reading in a college class. But I know my motives aren't virtuous.
47 | I bring books because if the world gets boring I need to be able
48 | to slip into another distilled by some writer. It's like eating
49 | jam when you know you should be eating fruit.There is a point where I'll do without books. I was walking in
50 | some steep mountains once, and decided I'd rather just think, if I
51 | was bored, rather than carry a single unnecessary ounce. It wasn't
52 | so bad. I found I could entertain myself by having ideas instead
53 | of reading other people's. If you stop eating jam, fruit starts
54 | to taste better.So maybe I'll try not bringing books on some future trip. They're
55 | going to have to pry the plugs out of my cold, dead ears, however.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/know.txt:
--------------------------------------------------------------------------------
1 | December 2014I've read Villehardouin's chronicle of the Fourth Crusade at least
2 | two times, maybe three. And yet if I had to write down everything
3 | I remember from it, I doubt it would amount to much more than a
4 | page. Multiply this times several hundred, and I get an uneasy
5 | feeling when I look at my bookshelves. What use is it to read all
6 | these books if I remember so little from them?A few months ago, as I was reading Constance Reid's excellent
7 | biography of Hilbert, I figured out if not the answer to this
8 | question, at least something that made me feel better about it.
9 | She writes:
10 |
11 | Hilbert had no patience with mathematical lectures which filled
12 | the students with facts but did not teach them how to frame a
13 | problem and solve it. He often used to tell them that "a perfect
14 | formulation of a problem is already half its solution."
15 |
16 | That has always seemed to me an important point, and I was even
17 | more convinced of it after hearing it confirmed by Hilbert.But how had I come to believe in this idea in the first place? A
18 | combination of my own experience and other things I'd read. None
19 | of which I could at that moment remember! And eventually I'd forget
20 | that Hilbert had confirmed it too. But my increased belief in the
21 | importance of this idea would remain something I'd learned from
22 | this book, even after I'd forgotten I'd learned it.Reading and experience train your model of the world. And even if
23 | you forget the experience or what you read, its effect on your model
24 | of the world persists. Your mind is like a compiled program you've
25 | lost the source of. It works, but you don't know why.The place to look for what I learned from Villehardouin's chronicle
26 | is not what I remember from it, but my mental models of the crusades,
27 | Venice, medieval culture, siege warfare, and so on. Which doesn't
28 | mean I couldn't have read more attentively, but at least the harvest
29 | of reading is not so miserably small as it might seem.This is one of those things that seem obvious in retrospect. But
30 | it was a surprise to me and presumably would be to anyone else who
31 | felt uneasy about (apparently) forgetting so much they'd read.Realizing it does more than make you feel a little better about
32 | forgetting, though. There are specific implications.For example, reading and experience are usually "compiled" at the
33 | time they happen, using the state of your brain at that time. The
34 | same book would get compiled differently at different points in
35 | your life. Which means it is very much worth reading important
36 | books multiple times. I always used to feel some misgivings about
37 | rereading books. I unconsciously lumped reading together with work
38 | like carpentry, where having to do something again is a sign you
39 | did it wrong the first time. Whereas now the phrase "already read"
40 | seems almost ill-formed.Intriguingly, this implication isn't limited to books. Technology
41 | will increasingly make it possible to relive our experiences. When
42 | people do that today it's usually to enjoy them again (e.g. when
43 | looking at pictures of a trip) or to find the origin of some bug in
44 | their compiled code (e.g. when Stephen Fry succeeded in remembering
45 | the childhood trauma that prevented him from singing). But as
46 | technologies for recording and playing back your life improve, it
47 | may become common for people to relive experiences without any goal
48 | in mind, simply to learn from them again as one might when rereading
49 | a book.Eventually we may be able not just to play back experiences but
50 | also to index and even edit them. So although not knowing how you
51 | know things may seem part of being human, it may not be.
52 | Thanks to Sam Altman, Jessica Livingston, and Robert Morris for reading
53 | drafts of this.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/mod.txt:
--------------------------------------------------------------------------------
1 | December 2019There are two distinct ways to be politically moderate: on purpose
2 | and by accident. Intentional moderates are trimmers, deliberately
3 | choosing a position mid-way between the extremes of right and left.
4 | Accidental moderates end up in the middle, on average, because they
5 | make up their own minds about each question, and the far right and
6 | far left are roughly equally wrong.You can distinguish intentional from accidental moderates by the
7 | distribution of their opinions. If the far left opinion on some
8 | matter is 0 and the far right opinion 100, an intentional moderate's
9 | opinion on every question will be near 50. Whereas an accidental
10 | moderate's opinions will be scattered over a broad range, but will,
11 | like those of the intentional moderate, average to about 50.Intentional moderates are similar to those on the far left and the
12 | far right in that their opinions are, in a sense, not their own.
13 | The defining quality of an ideologue, whether on the left or the
14 | right, is to acquire one's opinions in bulk. You don't get to pick
15 | and choose. Your opinions about taxation can be predicted from your
16 | opinions about sex. And although intentional moderates
17 | might seem to be the opposite of ideologues, their beliefs (though
18 | in their case the word "positions" might be more accurate) are also
19 | acquired in bulk. If the median opinion shifts to the right or left,
20 | the intentional moderate must shift with it. Otherwise they stop
21 | being moderate.Accidental moderates, on the other hand, not only choose their own
22 | answers, but choose their own questions. They may not care at all
23 | about questions that the left and right both think are terribly
24 | important. So you can only even measure the politics of an accidental
25 | moderate from the intersection of the questions they care about and
26 | those the left and right care about, and this can
27 | sometimes be vanishingly small.It is not merely a manipulative rhetorical trick to say "if you're
28 | not with us, you're against us," but often simply false.Moderates are sometimes derided as cowards, particularly by
29 | the extreme left. But while it may be accurate to call intentional
30 | moderates cowards, openly being an accidental moderate requires the
31 | most courage of all, because you get attacked from both right and
32 | left, and you don't have the comfort of being an orthodox member
33 | of a large group to sustain you.Nearly all the most impressive people I know are accidental moderates.
34 | If I knew a lot of professional athletes, or people in the entertainment
35 | business, that might be different. Being on the far left or far
36 | right doesn't affect how fast you run or how well you sing. But
37 | someone who works with ideas has to be independent-minded to do it
38 | well.Or more precisely, you have to be independent-minded about the ideas
39 | you work with. You could be mindlessly doctrinaire in your politics
40 | and still be a good mathematician. In the 20th century, a lot of
41 | very smart people were Marxists just no one who was smart about
42 | the subjects Marxism involves. But if the ideas you use in your
43 | work intersect with the politics of your time, you have two choices:
44 | be an accidental moderate, or be mediocre.Notes[1] It's possible in theory for one side to be entirely right and
45 | the other to be entirely wrong. Indeed, ideologues must always
46 | believe this is the case. But historically it rarely has been.[2] For some reason the far right tend to ignore moderates rather
47 | than despise them as backsliders. I'm not sure why. Perhaps it
48 | means that the far right is less ideological than the far left. Or
49 | perhaps that they are more confident, or more resigned, or simply
50 | more disorganized. I just don't know.[3] Having heretical opinions doesn't mean you have to express
51 | them openly. It may be
52 | easier to have them if you don't.
53 | Thanks to Austen Allred, Trevor Blackwell, Patrick Collison, Jessica Livingston,
54 | Amjad Masad, Ryan Petersen, and Harj Taggar for reading drafts of this.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/nft.txt:
--------------------------------------------------------------------------------
1 | May 2021Noora Health, a nonprofit I've
2 | supported for years, just launched
3 | a new NFT. It has a dramatic name, Save Thousands of Lives,
4 | because that's what the proceeds will do.Noora has been saving lives for 7 years. They run programs in
5 | hospitals in South Asia to teach new mothers how to take care of
6 | their babies once they get home. They're in 165 hospitals now. And
7 | because they know the numbers before and after they start at a new
8 | hospital, they can measure the impact they have. It is massive.
9 | For every 1000 live births, they save 9 babies.This number comes from a study
10 | of 133,733 families at 28 different
11 | hospitals that Noora conducted in collaboration with the Better
12 | Birth team at Ariadne Labs, a joint center for health systems
13 | innovation at Brigham and Womens Hospital and Harvard T.H. Chan
14 | School of Public Health.Noora is so effective that even if you measure their costs in the
15 | most conservative way, by dividing their entire budget by the number
16 | of lives saved, the cost of saving a life is the lowest I've seen.
17 | $1,235.For this NFT, they're going to issue a public report tracking how
18 | this specific tranche of money is spent, and estimating the number
19 | of lives saved as a result.NFTs are a new territory, and this way of using them is especially
20 | new, but I'm excited about its potential. And I'm excited to see
21 | what happens with this particular auction, because unlike an NFT
22 | representing something that has already happened,
23 | this NFT gets better as the price gets higher.The reserve price was about $2.5 million, because that's what it
24 | takes for the name to be accurate: that's what it costs to save
25 | 2000 lives. But the higher the price of this NFT goes, the more
26 | lives will be saved. What a sentence to be able to write.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/pow.txt:
--------------------------------------------------------------------------------
1 | January 2017People who are powerful but uncharismatic will tend to be disliked.
2 | Their power makes them a target for criticism that they don't have
3 | the charisma to disarm. That was Hillary Clinton's problem. It also
4 | tends to be a problem for any CEO who is more of a builder than a
5 | schmoozer. And yet the builder-type CEO is (like Hillary) probably
6 | the best person for the job.I don't think there is any solution to this problem. It's human
7 | nature. The best we can do is to recognize that it's happening, and
8 | to understand that being a magnet for criticism is sometimes a sign
9 | not that someone is the wrong person for a job, but that they're
10 | the right one.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/rootsoflisp.txt:
--------------------------------------------------------------------------------
1 | May 2001
2 |
3 | (I wrote this article to help myself understand exactly
4 | what McCarthy discovered. You don't need to know this stuff
5 | to program in Lisp, but it should be helpful to
6 | anyone who wants to
7 | understand the essence of Lisp both in the sense of its
8 | origins and its semantic core. The fact that it has such a core
9 | is one of Lisp's distinguishing features, and the reason why,
10 | unlike other languages, Lisp has dialects.)In 1960, John
11 | McCarthy published a remarkable paper in
12 | which he did for programming something like what Euclid did for
13 | geometry. He showed how, given a handful of simple
14 | operators and a notation for functions, you can
15 | build a whole programming language.
16 | He called this language Lisp, for "List Processing,"
17 | because one of his key ideas was to use a simple
18 | data structure called a list for both
19 | code and data.It's worth understanding what McCarthy discovered, not
20 | just as a landmark in the history of computers, but as
21 | a model for what programming is tending to become in
22 | our own time. It seems to me that there have been
23 | two really clean, consistent models of programming so
24 | far: the C model and the Lisp model.
25 | These two seem points of high ground, with swampy lowlands
26 | between them. As computers have grown more powerful,
27 | the new languages being developed have been moving
28 | steadily toward the Lisp model. A popular recipe
29 | for new programming languages in the past 20 years
30 | has been to take the C model of computing and add to
31 | it, piecemeal, parts taken from the Lisp model,
32 | like runtime typing and garbage collection.In this article I'm going to try to explain in the
33 | simplest possible terms what McCarthy discovered.
34 | The point is not just to learn about an interesting
35 | theoretical result someone figured out forty years ago,
36 | but to show where languages are heading.
37 | The unusual thing about Lisp in fact, the defining
38 | quality of Lisp is that it can be written in
39 | itself. To understand what McCarthy meant by this,
40 | we're going to retrace his steps, with his mathematical
41 | notation translated into running Common Lisp code.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/rss.txt:
--------------------------------------------------------------------------------
1 | Aaron Swartz created a scraped
2 | feed
3 | of the essays page.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/sun.txt:
--------------------------------------------------------------------------------
1 | September 2017The most valuable insights are both general and surprising.
2 | F = ma for example. But general and surprising is a hard
3 | combination to achieve. That territory tends to be picked
4 | clean, precisely because those insights are so valuable.Ordinarily, the best that people can do is one without the
5 | other: either surprising without being general (e.g.
6 | gossip), or general without being surprising (e.g.
7 | platitudes).Where things get interesting is the moderately valuable
8 | insights. You get those from small additions of whichever
9 | quality was missing. The more common case is a small
10 | addition of generality: a piece of gossip that's more than
11 | just gossip, because it teaches something interesting about
12 | the world. But another less common approach is to focus on
13 | the most general ideas and see if you can find something new
14 | to say about them. Because these start out so general, you
15 | only need a small delta of novelty to produce a useful
16 | insight.A small delta of novelty is all you'll be able to get most
17 | of the time. Which means if you take this route, your ideas
18 | will seem a lot like ones that already exist. Sometimes
19 | you'll find you've merely rediscovered an idea that did
20 | already exist. But don't be discouraged. Remember the huge
21 | multiplier that kicks in when you do manage to think of
22 | something even a little new.Corollary: the more general the ideas you're talking about,
23 | the less you should worry about repeating yourself. If you
24 | write enough, it's inevitable you will. Your brain is much
25 | the same from year to year and so are the stimuli that hit
26 | it. I feel slightly bad when I find I've said something
27 | close to what I've said before, as if I were plagiarizing
28 | myself. But rationally one shouldn't. You won't say
29 | something exactly the same way the second time, and that
30 | variation increases the chance you'll get that tiny but
31 | critical delta of novelty.And of course, ideas beget ideas. (That sounds
32 | familiar.)
33 | An idea with a small amount of novelty could lead to one
34 | with more. But only if you keep going. So it's doubly
35 | important not to let yourself be discouraged by people who
36 | say there's not much new about something you've discovered.
37 | "Not much new" is a real achievement when you're talking
38 | about the most general ideas. It's not true that there's nothing new under the sun. There
39 | are some domains where there's almost nothing new. But
40 | there's a big difference between nothing and almost nothing,
41 | when it's multiplied by the area under the sun.
42 | Thanks to Sam Altman, Patrick Collison, and Jessica
43 | Livingston for reading drafts of this.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/todo.txt:
--------------------------------------------------------------------------------
1 | April 2012A palliative care nurse called Bronnie Ware made a list of the
2 | biggest regrets
3 | of the dying. Her list seems plausible. I could see
4 | myself — can see myself — making at least 4 of these
5 | 5 mistakes.If you had to compress them into a single piece of advice, it might
6 | be: don't be a cog. The 5 regrets paint a portrait of post-industrial
7 | man, who shrinks himself into a shape that fits his circumstances,
8 | then turns dutifully till he stops.The alarming thing is, the mistakes that produce these regrets are
9 | all errors of omission. You forget your dreams, ignore your family,
10 | suppress your feelings, neglect your friends, and forget to be
11 | happy. Errors of omission are a particularly dangerous type of
12 | mistake, because you make them by default.I would like to avoid making these mistakes. But how do you avoid
13 | mistakes you make by default? Ideally you transform your life so
14 | it has other defaults. But it may not be possible to do that
15 | completely. As long as these mistakes happen by default, you probably
16 | have to be reminded not to make them. So I inverted the 5 regrets,
17 | yielding a list of 5 commands
18 |
19 | Don't ignore your dreams; don't work too much; say what you
20 | think; cultivate friendships; be happy.
21 |
22 | which I then put at the top of the file I use as a todo list.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/unions.txt:
--------------------------------------------------------------------------------
1 | May 2007People who worry about the increasing gap between rich and poor
2 | generally look back on the mid twentieth century as a golden age.
3 | In those days we had a large number of high-paying union manufacturing
4 | jobs that boosted the median income. I wouldn't quite call the
5 | high-paying union job a myth, but I think people who dwell on it
6 | are reading too much into it.Oddly enough, it was working with startups that made me realize
7 | where the high-paying union job came from. In a rapidly growing
8 | market, you don't worry too much about efficiency. It's more
9 | important to grow fast. If there's some mundane problem getting
10 | in your way, and there's a simple solution that's somewhat expensive,
11 | just take it and get on with more important things. EBay didn't
12 | win by paying less for servers than their competitors.Difficult though it may be to imagine now, manufacturing was a
13 | growth industry in the mid twentieth century. This was an era when
14 | small firms making everything from cars to candy were getting
15 | consolidated into a new kind of corporation with national reach and
16 | huge economies of scale. You had to grow fast or die. Workers
17 | were for these companies what servers are for an Internet startup.
18 | A reliable supply was more important than low cost.If you looked in the head of a 1950s auto executive, the attitude
19 | must have been: sure, give 'em whatever they ask for, so long as
20 | the new model isn't delayed.In other words, those workers were not paid what their work was
21 | worth. Circumstances being what they were, companies would have
22 | been stupid to insist on paying them so little.If you want a less controversial example of this phenomenon, ask
23 | anyone who worked as a consultant building web sites during the
24 | Internet Bubble. In the late nineties you could get paid huge sums
25 | of money for building the most trivial things. And yet does anyone
26 | who was there have any expectation those days will ever return? I
27 | doubt it. Surely everyone realizes that was just a temporary
28 | aberration.The era of labor unions seems to have been the same kind of aberration,
29 | just spread
30 | over a longer period, and mixed together with a lot of ideology
31 | that prevents people from viewing it with as cold an eye as they
32 | would something like consulting during the Bubble.Basically, unions were just Razorfish.People who think the labor movement was the creation of heroic union
33 | organizers have a problem to explain: why are unions shrinking now?
34 | The best they can do is fall back on the default explanation of
35 | people living in fallen civilizations. Our ancestors were giants.
36 | The workers of the early twentieth century must have had a moral
37 | courage that's lacking today.In fact there's a simpler explanation. The early twentieth century
38 | was just a fast-growing startup overpaying for infrastructure. And
39 | we in the present are not a fallen people, who have abandoned
40 | whatever mysterious high-minded principles produced the high-paying
41 | union job. We simply live in a time when the fast-growing companies
42 | overspend on different things.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/vw.txt:
--------------------------------------------------------------------------------
1 | January 2012A few hours before the Yahoo acquisition was announced in June 1998
2 | I took a snapshot of Viaweb's
3 | site. I thought it might be interesting to look at one day.The first thing one notices is is how tiny the pages are. Screens
4 | were a lot smaller in 1998. If I remember correctly, our frontpage
5 | used to just fit in the size window people typically used then.Browsers then (IE 6 was still 3 years in the future) had few fonts
6 | and they weren't antialiased. If you wanted to make pages that
7 | looked good, you had to render display text as images.You may notice a certain similarity between the Viaweb and Y Combinator logos. We did that
8 | as an inside joke when we started YC. Considering how basic a red
9 | circle is, it seemed surprising to me when we started Viaweb how
10 | few other companies used one as their logo. A bit later I realized
11 | why.On the Company
12 | page you'll notice a mysterious individual called John McArtyem.
13 | Robert Morris (aka Rtm) was so publicity averse after the
14 | Worm that he
15 | didn't want his name on the site. I managed to get him to agree
16 | to a compromise: we could use his bio but not his name. He has
17 | since relaxed a bit
18 | on that point.Trevor graduated at about the same time the acquisition closed, so in the
19 | course of 4 days he went from impecunious grad student to millionaire
20 | PhD. The culmination of my career as a writer of press releases
21 | was one celebrating
22 | his graduation, illustrated with a drawing I did of him during
23 | a meeting.(Trevor also appears as Trevino
24 | Bagwell in our directory of web designers merchants could hire
25 | to build stores for them. We inserted him as a ringer in case some
26 | competitor tried to spam our web designers. We assumed his logo
27 | would deter any actual customers, but it did not.)Back in the 90s, to get users you had to get mentioned in magazines
28 | and newspapers. There were not the same ways to get found online
29 | that there are today. So we used to pay a PR
30 | firm $16,000 a month to get us mentioned in the press. Fortunately
31 | reporters liked
32 | us.In our advice about
33 | getting traffic from search engines (I don't think the term SEO
34 | had been coined yet), we say there are only 7 that matter: Yahoo,
35 | AltaVista, Excite, WebCrawler, InfoSeek, Lycos, and HotBot. Notice
36 | anything missing? Google was incorporated that September.We supported online transactions via a company called
37 | Cybercash,
38 | since if we lacked that feature we'd have gotten beaten up in product
39 | comparisons. But Cybercash was so bad and most stores' order volumes
40 | were so low that it was better if merchants processed orders like phone orders. We had a page in our site trying to talk merchants
41 | out of doing real time authorizations.The whole site was organized like a funnel, directing people to the
42 | test drive.
43 | It was a novel thing to be able to try out software online. We put
44 | cgi-bin in our dynamic urls to fool competitors about how our
45 | software worked.We had some well
46 | known users. Needless to say, Frederick's of Hollywood got the
47 | most traffic. We charged a flat fee of $300/month for big stores,
48 | so it was a little alarming to have users who got lots of traffic.
49 | I once calculated how much Frederick's was costing us in bandwidth,
50 | and it was about $300/month.Since we hosted all the stores, which together were getting just
51 | over 10 million page views per month in June 1998, we consumed what
52 | at the time seemed a lot of bandwidth. We had 2 T1s (3 Mb/sec)
53 | coming into our offices. In those days there was no AWS. Even
54 | colocating servers seemed too risky, considering how often things
55 | went wrong with them. So we had our servers in our offices. Or
56 | more precisely, in Trevor's office. In return for the unique
57 | privilege of sharing his office with no other humans, he had to
58 | share it with 6 shrieking tower servers. His office was nicknamed
59 | the Hot Tub on account of the heat they generated. Most days his
60 | stack of window air conditioners could keep up.For describing pages, we had a template language called RTML, which
61 | supposedly stood for something, but which in fact I named after
62 | Rtm. RTML was Common Lisp augmented by some macros and libraries,
63 | and concealed under a structure editor that made it look like it
64 | had syntax.Since we did continuous releases, our software didn't actually have
65 | versions. But in those days the trade press expected versions, so
66 | we made them up. If we wanted to get lots of attention, we made
67 | the version number an
68 | integer. That "version 4.0" icon was generated by our own
69 | button generator, incidentally. The whole Viaweb site was made
70 | with our software, even though it wasn't an online store, because
71 | we wanted to experience what our users did.At the end of 1997, we released a general purpose shopping search
72 | engine called Shopfind. It
73 | was pretty advanced for the time. It had a programmable crawler
74 | that could crawl most of the different stores online and pick out
75 | the products.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/want.txt:
--------------------------------------------------------------------------------
1 | November 2022Since I was about 9 I've been puzzled by the apparent contradiction
2 | between being made of matter that behaves in a predictable way, and
3 | the feeling that I could choose to do whatever I wanted. At the
4 | time I had a self-interested motive for exploring the question. At
5 | that age (like most succeeding ages) I was always in trouble with
6 | the authorities, and it seemed to me that there might possibly be
7 | some way to get out of trouble by arguing that I wasn't responsible
8 | for my actions. I gradually lost hope of that, but the puzzle
9 | remained: How do you reconcile being a machine made of matter with
10 | the feeling that you're free to choose what you do?
11 | [1]The best way to explain the answer may be to start with a slightly
12 | wrong version, and then fix it. The wrong version is: You can do
13 | what you want, but you can't want what you want. Yes, you can control
14 | what you do, but you'll do what you want, and you can't control
15 | that.The reason this is mistaken is that people do sometimes change what
16 | they want. People who don't want to want something — drug addicts,
17 | for example — can sometimes make themselves stop wanting it. And
18 | people who want to want something — who want to like classical
19 | music, or broccoli — sometimes succeed.So we modify our initial statement: You can do what you want, but
20 | you can't want to want what you want.That's still not quite true. It's possible to change what you want
21 | to want. I can imagine someone saying "I decided to stop wanting
22 | to like classical music." But we're getting closer to the truth.
23 | It's rare for people to change what they want to want, and the more
24 | "want to"s we add, the rarer it gets.We can get arbitrarily close to a true statement by adding more "want
25 | to"s in much the same way we can get arbitrarily close to 1 by adding
26 | more 9s to a string of 9s following a decimal point. In practice
27 | three or four "want to"s must surely be enough. It's hard even to
28 | envision what it would mean to change what you want to want to want
29 | to want, let alone actually do it.So one way to express the correct answer is to use a regular
30 | expression. You can do what you want, but there's some statement
31 | of the form "you can't (want to)* want what you want" that's true.
32 | Ultimately you get back to a want that you don't control.
33 | [2]
34 | Notes[1]
35 | I didn't know when I was 9 that matter might behave randomly,
36 | but I don't think it affects the problem much. Randomness destroys
37 | the ghost in the machine as effectively as determinism.[2]
38 | If you don't like using an expression, you can make the same
39 | point using higher-order desires: There is some n such that you
40 | don't control your nth-order desires.
41 | Thanks to Trevor Blackwell,
42 | Jessica Livingston, Robert Morris, and
43 | Michael Nielsen for reading drafts of this.
--------------------------------------------------------------------------------
/text_extend/PaulGrahamEssays/weird.txt:
--------------------------------------------------------------------------------
1 | August 2021When people say that in their experience all programming languages
2 | are basically equivalent, they're making a statement not about
3 | languages but about the kind of programming they've done.99.5% of programming consists of gluing together calls to library
4 | functions. All popular languages are equally good at this. So one
5 | can easily spend one's whole career operating in the intersection
6 | of popular programming languages.But the other .5% of programming is disproportionately interesting.
7 | If you want to learn what it consists of, the weirdness of weird
8 | languages is a good clue to follow.Weird languages aren't weird by accident. Not the good ones, at
9 | least. The weirdness of the good ones usually implies the existence
10 | of some form of programming that's not just the usual gluing together
11 | of library calls.A concrete example: Lisp macros. Lisp macros seem weird even to
12 | many Lisp programmers. They're not only not in the intersection of
13 | popular languages, but by their nature would be hard to implement
14 | properly in a language without turning it into a dialect of
15 | Lisp. And macros are definitely evidence of techniques that go
16 | beyond glue programming. For example, solving problems by first
17 | writing a language for problems of that type, and then writing
18 | your specific application in it. Nor is this all you can do with
19 | macros; it's just one region in a space of program-manipulating
20 | techniques that even now is far from fully explored.So if you want to expand your concept of what programming can be,
21 | one way to do it is by learning weird languages. Pick a language
22 | that most programmers consider weird but whose median user is smart,
23 | and then focus on the differences between this language and the
24 | intersection of popular languages. What can you say in this language
25 | that would be impossibly inconvenient to say in others? In the
26 | process of learning how to say things you couldn't previously say,
27 | you'll probably be learning how to think things you couldn't
28 | previously think.
29 | Thanks to Trevor Blackwell, Patrick Collison, Daniel Gackle, Amjad
30 | Masad, and Robert Morris for reading drafts of this.
31 |
--------------------------------------------------------------------------------
/text_extend/eval.sh:
--------------------------------------------------------------------------------
1 | for model in Qwen2-7B-Instruct-224K
2 | do
3 | for num_distractor in 0 3 5
4 | do
5 | accelerate launch --num_processes 8 --config_file easy_context/accelerate_configs/deepspeed_inference.yaml --main_process_port 6000 text_extend/eval_text_niah.py \
6 | --model text_extend/training_output/$model \
7 | --max_context_length 512000 \
8 | --min_context_length 32000 \
9 | --context_interval 32000 \
10 | --depth_interval 0.1 \
11 | --num_samples 5 \
12 | --rnd_number_digits 7 \
13 | --haystack_dir text_extend/PaulGrahamEssays \
14 | --num_distractor $num_distractor
15 | done
16 | done
17 |
18 |
--------------------------------------------------------------------------------
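A note on the sweep above: for each model and distractor count, eval_text_niah.py is asked for context lengths from 32K to 512K in 32K steps, needle depths in 0.1 increments, and 5 samples per cell. A minimal sketch of that grid, assuming both ranges are walked inclusively (the actual iteration logic lives in text_extend/eval_text_niah.py, which is not shown here):

# Sketch only: enumerate the NIAH grid that text_extend/eval.sh requests.
# Assumes inclusive ranges; see eval_text_niah.py for the real loop.
max_context_length = 512_000
min_context_length = 32_000
context_interval = 32_000
depth_interval = 0.1
num_samples = 5

context_lengths = list(range(min_context_length, max_context_length + 1, context_interval))
depths = [round(i * depth_interval, 2) for i in range(int(round(1 / depth_interval)) + 1)]
cells = [(c, d) for c in context_lengths for d in depths]

print(f"{len(context_lengths)} context lengths x {len(depths)} depths = {len(cells)} cells")
print(f"{len(cells) * num_samples} generations per (model, num_distractor) setting")

Under that assumption each setting covers 176 cells (880 generations), which is consistent with the average-accuracy values reported under text_extend/niah_output below.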
/text_extend/extend_qwen2.sh:
--------------------------------------------------------------------------------
1 | accelerate launch \
2 | --config_file easy_context/accelerate_configs/single_node.yaml \
3 | text_extend/text_extend_train.py \
4 | --batch-size 1 \
5 | --gradient-accumulate-every 4 \
6 | --seed 2024 \
7 | --output-dir text_extend/training_output/Qwen2-7B-Instruct-224K \
8 | --wandb LMExtend \
9 | --max-train-steps 1000 \
10 | --learning-rate 1e-5 \
11 | --dataset PY007/slimpajama_Qwen2_tokenized_upsample_4096_chunk_256K \
12 | --model Qwen/Qwen2-7B-Instruct \
13 | --seq-length 224000 \
14 | --rope-theta 1000000000 \
15 | --parallel_mode zigzag_ring_attn \
16 | --checkpointing-steps 200
17 |
18 | rm text_extend/training_output/Qwen2-7B-Instruct-224K/model.safetensors
19 |
--------------------------------------------------------------------------------
/text_extend/niah_output/distractor_0/Qwen2-7B-Instruct-extend-step_1000/avg_accuracy.txt:
--------------------------------------------------------------------------------
1 | Average Accuracy: 0.9977272727272727
2 |
--------------------------------------------------------------------------------
/text_extend/niah_output/distractor_0/Qwen2-7B-Instruct-extend-step_1000/heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/text_extend/niah_output/distractor_0/Qwen2-7B-Instruct-extend-step_1000/heatmap.png
--------------------------------------------------------------------------------
/text_extend/niah_output/distractor_3/Meta-Llama-3-8B-Instruct-extend-step1000/avg_accuracy.txt:
--------------------------------------------------------------------------------
1 | Average Accuracy: 0.905681818181818
2 |
--------------------------------------------------------------------------------
/text_extend/niah_output/distractor_3/Meta-Llama-3-8B-Instruct-extend-step1000/heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/text_extend/niah_output/distractor_3/Meta-Llama-3-8B-Instruct-extend-step1000/heatmap.png
--------------------------------------------------------------------------------
/text_extend/niah_output/distractor_3/Qwen2-7B-Instruct-extend-step_1000/avg_accuracy.txt:
--------------------------------------------------------------------------------
1 | Average Accuracy: 0.9352272727272729
2 |
--------------------------------------------------------------------------------
/text_extend/niah_output/distractor_3/Qwen2-7B-Instruct-extend-step_1000/heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/text_extend/niah_output/distractor_3/Qwen2-7B-Instruct-extend-step_1000/heatmap.png
--------------------------------------------------------------------------------
/text_extend/niah_output/distractor_5/Qwen2-7B-Instruct-extend-step_1000/avg_accuracy.txt:
--------------------------------------------------------------------------------
1 | Average Accuracy: 0.9204545454545454
2 |
--------------------------------------------------------------------------------
/text_extend/niah_output/distractor_5/Qwen2-7B-Instruct-extend-step_1000/heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/text_extend/niah_output/distractor_5/Qwen2-7B-Instruct-extend-step_1000/heatmap.png
--------------------------------------------------------------------------------
/text_extend/niah_output/git_placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/text_extend/niah_output/git_placeholder
--------------------------------------------------------------------------------
/text_extend/plot_ppl.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import matplotlib.pyplot as plt
3 | import pandas as pd
4 |
5 |
6 | def main(args):
7 | data = pd.read_csv(args.csv)
8 | fig, ax = plt.subplots(figsize=(10, 5))
9 |
10 | x_data = [float(x) for x in data.columns[1:]]
11 | for row in data.values:
12 | label = row[0].replace("NousResearch/", "")
13 | ax.plot(x_data, [float(x) for x in row[1:]], label=label)
14 |
15 | ax.set_xlabel("Context Window")
16 | ax.set_ylabel("Perplexity (lower is better)")
17 |
18 | ax.set_xlim(args.xmin, args.xmax)
19 | ax.set_ylim(args.ymin, args.ymax)
20 |
21 | ax.legend(loc="upper right")
22 |
23 | fig.savefig(args.csv + ".png")
24 | fig.savefig(args.csv + ".pdf", transparent=True)
25 |
26 |
27 | if __name__ == "__main__":
28 | args = argparse.ArgumentParser()
29 | args.add_argument("csv", type=str)
30 | args.add_argument("--xmin", type=int, default=0)
31 | args.add_argument("--xmax", type=int, default=32768)
32 | args.add_argument("--ymin", type=float, default=3)
33 | args.add_argument("--ymax", type=float, default=10)
34 | main(args.parse_args())
35 |
--------------------------------------------------------------------------------
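For reference, plot_ppl.py infers its input layout from the CSV itself: the first column holds the model name (a "NousResearch/" prefix, if present, is stripped for the legend), the remaining column headers are context lengths, and each cell is a perplexity value. A minimal sketch of writing such a file with pandas; the file name and all numbers here are made up for illustration:

# Sketch only: produce a CSV in the layout plot_ppl.py reads.
# Column 0 = model name, remaining headers = context lengths, cells = perplexity.
import pandas as pd

rows = [
    ["NousResearch/Llama-2-7b-hf", 9.2, 7.8, 7.1, 6.9],
    ["Qwen2-7B-Instruct-224K", 8.8, 7.3, 6.6, 6.4],
]
pd.DataFrame(rows, columns=["model", 2048, 8192, 16384, 32768]).to_csv("ppl_results.csv", index=False)
# then, for example: python text_extend/plot_ppl.py ppl_results.csv --xmax 32768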
/vision_niah/data/haystack_embeddings/git_placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/data/haystack_embeddings/git_placeholder
--------------------------------------------------------------------------------
/vision_niah/data/haystack_videos/git_placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/data/haystack_videos/git_placeholder
--------------------------------------------------------------------------------
/vision_niah/data/needle_embeddings/git_placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/data/needle_embeddings/git_placeholder
--------------------------------------------------------------------------------
/vision_niah/eval.sh:
--------------------------------------------------------------------------------
1 | for MODEL_NAME in LongVA-7B
2 | do
3 | mkdir vision_niah/data/haystack_embeddings/$MODEL_NAME
4 | mkdir vision_niah/data/needle_embeddings/$MODEL_NAME
5 | python vision_niah/produce_haystack_embedding.py --model vision_niah/model_weights/$MODEL_NAME --output_dir vision_niah/data/haystack_embeddings/$MODEL_NAME --sampled_frames_num 3000 --pooling_size 2
6 | python vision_niah/produce_needle_embedding.py --model vision_niah/model_weights/$MODEL_NAME --output_dir vision_niah/data/needle_embeddings/$MODEL_NAME --pooling_size 2 --needle_dataset LongVa/v_niah_needles
7 |
8 | accelerate launch --num_processes 8 --config_file easy_context/accelerate_configs/deepspeed_inference.yaml --main_process_port 6000 vision_niah/eval_vision_niah.py \
9 | --model vision_niah/model_weights/$MODEL_NAME \
10 | --needle_embedding_dir vision_niah/data/needle_embeddings/$MODEL_NAME \
11 | --haystack_dir vision_niah/data/haystack_embeddings/$MODEL_NAME \
12 | --needle_dataset lmms-lab/v_niah_needles \
13 | --prompt_template qwen2 \
14 | --max_frame_num 3000 \
15 | --min_frame_num 200 \
16 | --frame_interval 200 \
17 | --depth_interval 0.2
18 | done
19 |
20 |
21 |
--------------------------------------------------------------------------------
/vision_niah/eval_vision_niah_sampling.py:
--------------------------------------------------------------------------------
1 | from longva.model.builder import load_pretrained_model
2 | from longva.mm_utils import tokenizer_image_token, process_images
3 | from longva.constants import IMAGE_TOKEN_INDEX
4 | from PIL import Image
5 | from datasets import load_dataset
6 | from decord import VideoReader, cpu
7 | import torch
8 | import numpy as np
9 | # fix seed
10 | torch.manual_seed(0)
11 |
12 | model_path = "lmms-lab/LongVA-7B"
13 | video_path = "vision_niah/data/haystack_videos/movie.mp4"
14 | haystack_frames = 256 # you can change this to several thousand as long as your GPU memory can handle it :)
15 | gen_kwargs = {"do_sample": False, "use_cache": True, "max_new_tokens": 1024}
16 | tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None, "llava_qwen", device_map="auto")
17 |
18 | preprompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
19 | postprompt = "<|im_end|>\n<|im_start|>assistant\n"
20 | #video input
21 |
22 |
23 | vniah_needle_dataset = load_dataset("lmms-lab/v_niah_needles")["test"]
24 | question = vniah_needle_dataset[1]["question"]
25 | image = vniah_needle_dataset[1]["image"].convert("RGB")
26 | answer = vniah_needle_dataset[1]["answer"]
27 | images_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"].to(model.device, dtype=torch.float16)
28 |
29 |
30 |
31 | prompt = preprompt + "<image>" + question + postprompt
32 | print(prompt)
33 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device)
34 | vr = VideoReader(video_path, ctx=cpu(0))
35 | total_frame_num = len(vr)
36 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, haystack_frames, dtype=int)
37 | frame_idx = uniform_sampled_frames.tolist()
38 | frames = vr.get_batch(frame_idx).asnumpy()
39 | video_tensor = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].to(model.device, dtype=torch.float16)
40 | # insert image_tensor in the middle
41 | video_tensor = torch.cat([video_tensor[:len(video_tensor)//2], images_tensor, video_tensor[len(video_tensor)//2:]], dim=0)
42 | # insert at the very end
43 | # video_tensor = torch.cat([video_tensor, images_tensor], dim=0)
44 | with torch.inference_mode():
45 | output_ids = model.generate(input_ids, images=[video_tensor], modalities=["video"], **gen_kwargs)
46 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
47 | print("Output:", outputs)
48 | print("Answer:", answer)
--------------------------------------------------------------------------------
/vision_niah/model_weights/git_placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/model_weights/git_placeholder
--------------------------------------------------------------------------------
/vision_niah/needle_datasets/dataset.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "path": "ucsd.jpeg",
4 | "prompt": "\nFind the frame of the 'While You Were Out' note. What is the name of the university on that note?\nA. University of california, los angeles\nB. University of california, san diego\nC. University of california, berkeley\nD. University of california, santa barbara\nAnswer with the option's letter from the given choices directly.",
5 | "answer": "B"
6 | },
7 | {
8 | "path": "sora_balloon.png",
9 | "prompt": "\nFind the frame of a couple in a wedding. In side the frame, there is a balloon on the bridegroom's head. What is the color of that ballon?\nAnswer the question using a single word or phrase.",
10 | "answer": "Yellow"
11 | },
12 | {
13 | "path": "selenium_green.jpg",
14 | "prompt": "\nFind the frame with the image of Selenium tablets. How many mg does each tablet contain?\nAnswer the question using a single word or phrase.",
15 | "answer": "200"
16 | },
17 | {
18 | "path": "panda_scientist.png",
19 | "prompt": "\nFind the frame of a scientist. The scientist is a...\nA. Bird\nB. Elephant\nC. Panda\nD. Dog\nAnswer with the option's letter from the given choices directly.",
20 | "answer": "C"
21 | },
22 | {
23 | "path": "teddy_bear_times_square.png",
24 | "prompt": "\nFind the frame of a teddy bear. Where is this teddy bear?\nA. Times Square\nB. Eiffel Tower\nC. Taj Mahal\nD. Sydney Opera House\nAnswer with the option's letter from the given choices directly.",
25 | "answer": "A"
26 | }
27 | ]
--------------------------------------------------------------------------------
/vision_niah/needle_datasets/generate_hf_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import json
17 | import os
18 | from PIL import Image
19 | import datasets
20 |
21 |
22 | # Find for instance the citation on arxiv or on the dataset repo/website
23 | _CITATION = """\
24 | @inproceedings{masry-etal-2022-chartqa,
25 | title = "{C}hart{QA}: A Benchmark for Question Answering about Charts with Visual and Logical Reasoning",
26 | author = "Masry, Ahmed and
27 | Long, Do and
28 | Tan, Jia Qing and
29 | Joty, Shafiq and
30 | Hoque, Enamul",
31 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2022",
32 | month = may,
33 | year = "2022",
34 | address = "Dublin, Ireland",
35 | publisher = "Association for Computational Linguistics",
36 | url = "https://aclanthology.org/2022.findings-acl.177",
37 | doi = "10.18653/v1/2022.findings-acl.177",
38 | pages = "2263--2279",
39 | }
40 | """
41 | _DESCRIPTION = "A largescale benchmark covering 9.6K human-written questions as well as 23.1K questions generated from human-written chart summaries."
42 |
43 |
44 | def get_builder_config(VERSION):
45 | builder_config = [
46 | datasets.BuilderConfig(
47 | name=f"V-NIAH",
48 | version=VERSION,
49 | description=f"V-NIAH",
50 | )
51 | ]
52 | return builder_config
53 |
54 |
55 | dataset_features = {
56 | "image": datasets.Image(),
57 | "question": datasets.Value("string"),
58 | "answer": datasets.Value("string"),
59 | }
60 |
61 |
62 | class ChartQA(datasets.GeneratorBasedBuilder):
63 | VERSION = datasets.Version("1.0.0")
64 |
65 | BUILDER_CONFIGS = get_builder_config(VERSION)
66 |
67 | def _info(self):
68 | features = datasets.Features(dataset_features)
69 | return datasets.DatasetInfo(
70 | # This is the description that will appear on the datasets page.
71 | description=_DESCRIPTION,
72 | # This defines the different columns of the dataset and their types
73 | features=features, # Here we define them above because they are different between the two configurations
74 | # If there's a common (input, target) tuple from the features, uncomment supervised_keys line below and
75 | # specify them. They'll be used if as_supervised=True in builder.as_dataset.
76 | # supervised_keys=("sentence", "label"),
77 | # Homepage of the dataset for documentation
78 | # Citation for the dataset
79 | citation=_CITATION,
80 | )
81 |
82 | def _split_generators(self, dl_manager):
83 | # If several configurations are possible (listed in BUILDER_CONFIGS), the configuration selected by the user is in self.config.name
84 |
85 | # dl_manager is a datasets.download.DownloadManager that can be used to download and extract URLS
86 | # It can accept any type or nested list/dict and will give back the same structure with the url replaced with path to local files.
87 | # By default the archives will be extracted and a path to a cached folder where they are extracted is returned instead of the archive
88 | image_path = "needle_datasets/images/"
89 | augmented_annotation_path = "needle_datasets/dataset.json"
90 | return [
91 | datasets.SplitGenerator(
92 | name=datasets.Split.TEST,
93 | gen_kwargs={
94 | "annotation": augmented_annotation_path,
95 | "image_path": image_path,
96 | },
97 | ),
98 | ]
99 |
100 | # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
101 | def _generate_examples(self, annotation, image_path):
102 | # The `key` is for legacy reasons (tfds) and is not important in itself, but must be unique for each example.
103 | with open(annotation, encoding="utf-8") as f:
104 | annotation = json.load(f)
105 | for index, data in enumerate(annotation):
106 | question = data["prompt"]
107 | answer = data["answer"]
108 | print(data["path"])
109 | now_data = {}
110 | now_data["image"] = Image.open(image_path + data["path"])
111 | now_data["question"] = question
112 | now_data["answer"] = answer
113 | yield index, now_data
114 |
115 |
116 | if __name__ == "__main__":
117 | from datasets import load_dataset
118 |
119 | data = load_dataset(
120 | "needle_datasets/generate_hf_dataset.py",
121 | )
122 | data.push_to_hub("LongVa/longva_needles", private=False)
--------------------------------------------------------------------------------
/vision_niah/needle_datasets/git_placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/git_placeholder
--------------------------------------------------------------------------------
/vision_niah/needle_datasets/images/astronaut.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/astronaut.png
--------------------------------------------------------------------------------
/vision_niah/needle_datasets/images/construction_site.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/construction_site.png
--------------------------------------------------------------------------------
/vision_niah/needle_datasets/images/dolphin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/dolphin.png
--------------------------------------------------------------------------------
/vision_niah/needle_datasets/images/panda_scientist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/panda_scientist.png
--------------------------------------------------------------------------------
/vision_niah/needle_datasets/images/selenium_green.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/selenium_green.jpg
--------------------------------------------------------------------------------
/vision_niah/needle_datasets/images/sora_balloon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/sora_balloon.png
--------------------------------------------------------------------------------
/vision_niah/needle_datasets/images/teddy_bear_times_square.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/teddy_bear_times_square.png
--------------------------------------------------------------------------------
/vision_niah/needle_datasets/images/ucsd.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/needle_datasets/images/ucsd.jpeg
--------------------------------------------------------------------------------
/vision_niah/niah_output/LLaVA-NeXT-Video-7B-32K/rope_theta_1000000/avg_accuracy.txt:
--------------------------------------------------------------------------------
1 | Average Accuracy: 0.14666666666666667
2 |
--------------------------------------------------------------------------------
/vision_niah/niah_output/LLaVA-NeXT-Video-7B-32K/rope_theta_1000000/heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/niah_output/LLaVA-NeXT-Video-7B-32K/rope_theta_1000000/heatmap.png
--------------------------------------------------------------------------------
/vision_niah/niah_output/LLaVA-NeXT-Video-7B-32K/rope_theta_100000000/avg_accuracy.txt:
--------------------------------------------------------------------------------
1 | Average Accuracy: 0.28444444444444444
2 |
--------------------------------------------------------------------------------
/vision_niah/niah_output/LLaVA-NeXT-Video-7B-32K/rope_theta_100000000/heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/niah_output/LLaVA-NeXT-Video-7B-32K/rope_theta_100000000/heatmap.png
--------------------------------------------------------------------------------
/vision_niah/niah_output/LongVA-7B/avg_accuracy.txt:
--------------------------------------------------------------------------------
1 | Average Accuracy: 0.908888888888889
2 |
--------------------------------------------------------------------------------
/vision_niah/niah_output/LongVA-7B/heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EvolvingLMMs-Lab/LongVA/f4c8781d22d45e3a6ffd017f04aa33dfcc432fc9/vision_niah/niah_output/LongVA-7B/heatmap.png
--------------------------------------------------------------------------------
/vision_niah/produce_haystack_embedding.py:
--------------------------------------------------------------------------------
1 |
2 | from longva.model.builder import load_pretrained_model
3 | from longva.mm_utils import tokenizer_image_token, get_model_name_from_path
4 | from decord import VideoReader, cpu
5 | import argparse
6 | import numpy as np
7 | from tqdm import tqdm
8 | import torch
9 | from longva.mm_utils import process_images
10 | import math
11 | from PIL import Image
12 | def load_video_batches(video_path, batch_size):
13 | vr = VideoReader(video_path, ctx=cpu(0))
14 | total_frame_num = len(vr)
15 | fps = round(vr.get_avg_fps())
16 | frame_idx = [i for i in range(0, len(vr), fps)]
17 | for start_idx in range(0, len(frame_idx), batch_size):
18 |         end_idx = min(start_idx + batch_size, len(frame_idx))  # clamp to the number of sampled frame indices
19 | frame_indices = frame_idx[start_idx:end_idx]
20 | batch_frames = vr.get_batch(frame_indices).asnumpy()
21 | yield batch_frames
22 |
23 | def main(args):
24 | video_path = args.video_path
25 | model_path = args.model
26 | model_name = "llava_qwen"
27 |
28 |     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, device_map="cuda:0")
29 |     del model.model.layers  # the language-model layers are not needed here; only the vision tower and projector produce embeddings
30 |     model.config.image_aspect_ratio = "pad"
31 |     model.config.mm_patch_merge_type = "flat"
32 | # Process video in batches
33 | batch_size = 32
34 | total_batches = (args.sampled_frames_num + batch_size - 1) // batch_size
35 | image_feature_list = []
36 | if args.add_newline_token:
37 |         newline_token_embedding = model.model.image_newline
38 | with torch.inference_mode():
39 | for i, video_batch in tqdm(enumerate(load_video_batches(video_path, batch_size)), total=total_batches, desc="Processing Video Batches"):
40 | images = [Image.fromarray(frame).convert("RGB") for frame in video_batch]
41 |             processed_images = process_images(images, image_processor, model.config).half()
42 | image_features = model.encode_images(processed_images)
43 | print(image_features.shape)
44 | if args.pooling_size != 0:
45 |                 B, num_tokens, F = image_features.shape
46 |
47 |                 image_features_spatial = image_features.view(B, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), F).permute(0, 3, 1, 2) # B, F, 24, 24
48 | image_features_spatial_pool = torch.nn.functional.avg_pool2d(image_features_spatial, args.pooling_size, args.pooling_size) # B, F, 12, 12
49 | image_features = image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() # B, 144, F
50 | if args.add_newline_token:
51 |                 image_features = torch.cat([image_features, newline_token_embedding.unsqueeze(0).expand(image_features.shape[0], 1, -1)], dim=1)
52 | image_feature_list.append(image_features.to(torch.bfloat16).to("cpu"))
53 |             if i + 1 >= total_batches:  # stop once roughly args.sampled_frames_num frames have been encoded
54 | break
55 | image_feature_list = torch.cat(image_feature_list, dim=0)
56 | torch.save(image_feature_list, f"{args.output_dir}/video_embeddings.pt")
57 |
58 | if __name__ == "__main__":
59 | parser = argparse.ArgumentParser()
60 | parser.add_argument("--model", type=str, default="output/LLaVA-NeXT-Video-7B-Vicuna")
61 | parser.add_argument("--video_path", type=str, default="vision_niah/data/haystack_videos/movie.mp4")
62 | parser.add_argument("--sampled_frames_num", type=int, default=7200)
63 | parser.add_argument("--output_dir", type=str, default="video_needle_haystack/data/haystack_vicuna_embeddings")
64 | parser.add_argument("--pooling_size", type=int, default=0)
65 | parser.add_argument("--add_newline_token", action="store_true")
66 | args = parser.parse_args()
67 | main(args)
68 |
--------------------------------------------------------------------------------
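Note: the script above writes one tensor of per-frame vision features with torch.save. A minimal sketch of how that output could be loaded and sanity-checked, assuming the default --output_dir shown in the argparse defaults (the per-frame token count depends on --pooling_size, e.g. 144 tokens per frame with --pooling_size 2 per the comments in the script):

    import torch

    # Path follows the script's default --output_dir; adjust if it was overridden.
    haystack = torch.load(
        "video_needle_haystack/data/haystack_vicuna_embeddings/video_embeddings.pt",
        map_location="cpu",
    )
    # Expected shape: (num_sampled_frames, tokens_per_frame, hidden_dim), stored in bfloat16.
    print(haystack.shape, haystack.dtype)
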
/vision_niah/produce_needle_embedding.py:
--------------------------------------------------------------------------------
1 |
2 | from longva.model.builder import load_pretrained_model
3 | from longva.mm_utils import tokenizer_image_token, get_model_name_from_path
4 | import json
5 | import argparse
6 | import numpy as np
7 | from tqdm import tqdm
8 | import torch
9 | from pathlib import Path
10 | from PIL import Image
11 | from datasets import load_dataset
12 | from longva.mm_utils import process_images
13 | import math
14 | def main(args):
15 | model_name = "llava_qwen"
16 |     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model, None, model_name, load_8bit=False, device_map="cuda:0")
17 |     model.config.image_aspect_ratio = "pad"
18 |     model.config.mm_patch_merge_type = "flat"
19 | dataset = load_dataset(args.needle_dataset)["test"]
20 | for index, instance in enumerate(dataset):
21 | image = instance["image"].convert("RGB")
22 | image = process_images([image], image_processor, model.config).half()
23 | image_features = model.encode_images(image)
24 | if args.pooling_size != 0:
25 |             B, num_tokens, F = image_features.shape
26 |             image_features_spatial = image_features.view(B, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), F).permute(0, 3, 1, 2) # B, F, 24, 24
27 | image_features_spatial_pool = torch.nn.functional.avg_pool2d(image_features_spatial, args.pooling_size, args.pooling_size) # B, F, 12, 12
28 | image_features = image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() # B, 144, F
29 | image_features = image_features.squeeze(0)
30 | torch.save(image_features, f"{args.output_dir}/{index}.pt")
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument("--model", type=str, default="output/LLaVA-NeXT-Video-7B-Vicuna")
36 | parser.add_argument("--needle_dataset", type=str, default="lmms-lab/v_niah_needles")
37 | parser.add_argument("--output_dir", type=str, default="video_needle_haystack/data/needle_vicuna_embeddings")
38 | parser.add_argument("--pooling_size", type=int, default=0)
39 | args = parser.parse_args()
40 | main(args)
41 |
--------------------------------------------------------------------------------
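Note: produce_needle_embedding.py saves one tensor per needle image as {index}.pt with shape (tokens_per_image, hidden_dim). A minimal sketch of how a needle might be spliced into the haystack embeddings at a chosen frame depth for a V-NIAH style probe; the paths use the two scripts' default --output_dir values, both scripts are assumed to have been run with the same --pooling_size so the token counts match, and the repository's actual evaluation code may combine them differently:

    import torch

    haystack = torch.load("video_needle_haystack/data/haystack_vicuna_embeddings/video_embeddings.pt", map_location="cpu")
    needle = torch.load("video_needle_haystack/data/needle_vicuna_embeddings/0.pt", map_location="cpu")

    # Insert the needle frame halfway through the haystack (hypothetical depth for illustration).
    depth = haystack.shape[0] // 2
    spliced = torch.cat(
        [haystack[:depth], needle.unsqueeze(0).to(haystack.dtype), haystack[depth:]],
        dim=0,
    )
    print(spliced.shape)  # one extra frame of tokens relative to the haystack
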